一种迁移学习模型训练方法、装置及电子设备与流程

文档序号:31026692发布日期:2022-08-06 00:56阅读:157来源:国知局
一种迁移学习模型训练方法、装置及电子设备与流程

1.本技术涉及人工智能(artificial intelligence,ai)技术领域,尤其涉及一种迁移学习模型训练方法、装置及电子设备。


背景技术:

2.时间序列数据存在于大量实际应用中,如人类活动识别中的可穿戴医疗设备数据、金融预测中的价格波动数据以及自动化系统中的各类传感器数据等。近年来,基于时间序列数据的深度学习模型已广泛应用于时间序列分类和预测中,如变分递归神经网络、长短期记忆网络和一维卷积网络等。深度学习模型通常需要大量的标记数据才能达到良好的效果。然而,相比于标注其他种类的数据,由于时间序列数据特有的连续性,标注时间序列数据往往需要更多的专家并投入大量时间精力。


技术实现要素:

3.本技术提供了一种迁移学习模型训练方法、装置、电子设备、计算机存储介质、计算机产品及芯片,能够降低时间序列数据的标注难度,以及提升对时间序列数据预测的准确度。
4.第一方面,本技术提供一种迁移学习模型训练方法,该方法包括:
5.获取第一训练样本,第一训练样本为时间序列数据,第一训练样本包括源域的样本数据和目标域的样本数据;
6.将第一训练样本输入至待训练的特征提取模型,对第一训练样本进行特征提取,得到第一特征和第二特征,第一特征为源域和目标域共同所拥有的特征,第二特征为源域和目标域各自所特有的特征;
7.将源域的样本数据对应的第一特征输入至待训练的任务分类模型,得到第一预测结果,以及,根据第一预测结果和源域的样本数据的标签,确定第一预测损失;
8.将源域和目标域各自对应的第一特征输入至待训练的域鉴别模型,得到第二预测结果,以及,根据第二预测结果和源域和目标域各自对应的第一特征的来源,确定第二预测损失,其中,在域鉴别模型中采用对抗训练对源域和目标域各自对应的第一特征进行处理;
9.将源域和目标域各自对应的第二特征输入至待训练的域分类模型,得到第三预测结果,以及,根据第三预测结果和源域和所述目标域各自对应的第二特征的来源,确定第三预测损失;
10.根据第一预测损失、第二预测损失和第三预测损失,确定总预测损失;
11.以最小化总预测损失为目标,训练特征提取模型、任务分类模型、域鉴别模型和域分类模型。
12.这样,由于在训练网络模型过程中,是基于第一预测损失、第二预测损失和第三预测损失对各个模型进行训练,而域鉴别模型是采用对抗训练以混淆各个第一特征,使得模型不知道第一特征的来源,域分类模型是区分各个第二特征,使得模型知道第二特征的来
源,任务分类模型是完成对第一特征的分类,因此,最终训练得到的网络模型既可以完成分类任务,又可以不知道其所需分类的数据的来源,从而使得训练得到的网络模型能够准确对目标数据进行预测,提升了网络模型的预测准确度。可以通过特征提取模型可以分别提取出第一特征和第二特征,从而可以对样本数据进行解耦,进而可以基于目标域和源域共用的第一特征对各个模型进行训练,提升训练得到的网络模型对目标域数据上的预测准确率。
13.在一种可能的实现方式中,特征提取模型包括第一子模型和第二子模型;
14.将第一训练样本输入至待训练的特征提取模型,对第一训练样本进行特征提取,得到第一特征和第二特征,具体包括:
15.将第一训练样本输入至第一子模型,对第一训练样本进行特征提取,得到第一特征,以及,将第一训练样本输入至第二子模型,对第一训练样本进行特征提取,得到第二特征,其中,第一子模型和第二子模型间至少存在一次数据交互。由此以便于将第一特征和第二特征解耦。
16.在一种可能的实现方式中,将第一训练样本输入至第二子模型,具体包括:
17.基于第一训练样本,增加第一训练样本中的数据量,以得到第二训练样本;
18.将第二训练样本输入至第二子模型。由此以提高第二特征对域自身信息的刻画能力,以提升对第二特征的捕获能力。
19.在一种可能的实现方式中,第一子模型的网络深度大于第二子模型的网络深度。
20.在一种可能的实现方式中,将源域的样本数据对应的第一特征输入至待训练的任务分类模型,得到第一预测结果,具体包括:
21.对源域的样本数据对应的第一特征进行超球约束,以及将超球约束后的第一特征输入至待训练的任务分类模型,得到第一预测结果。
22.在一种可能的实现方式中,在总预测损失达到最小后,方法还包括:
23.将待预测的样本数据输入至特征提取模型,得到第三特征,待预测的样本数据为时间序列数据;
24.将第三特征输入至任务分类模型,得到第四预测结果。示例性的,第三特征可以为域不变特征。
25.第二方面,本技术提供一种迁移学习模型训练装置,该装置包括:
26.获取单元,用于获取第一训练样本,第一训练样本为时间序列数据,第一训练样本包括源域的样本数据和目标域的样本数据;
27.处理单元,用于将第一训练样本输入至待训练的特征提取模型,对第一训练样本进行特征提取,得到第一特征和第二特征,第一特征为源域和目标域共同所拥有的特征,第二特征为源域和目标域各自所特有的特征;
28.处理单元,还用于将源域的样本数据对应的第一特征输入至待训练的任务分类模型,得到第一预测结果,以及,根据第一预测结果和源域的样本数据的标签,确定第一预测损失;
29.处理单元,还用于将源域和目标域各自对应的第一特征输入至待训练的域鉴别模型,得到第二预测结果,以及,根据第二预测结果和源域和目标域各自对应的第一特征的来源,确定第二预测损失,其中,在域鉴别模型中采用对抗训练对源域和目标域各自对应的第
一特征进行处理;
30.处理单元,还用于将源域和目标域各自对应的第二特征输入至待训练的域分类模型,得到第三预测结果,以及,根据第三预测结果和源域和目标域各自对应的第二特征的来源,确定第三预测损失;
31.处理单元,还用于根据第一预测损失、第二预测损失和第三预测损失,确定总预测损失;
32.训练单元,用于以最小化总预测损失为目标,训练特征提取模型、任务分类模型、域鉴别模型和域分类模型。
33.在一种可能的实现方式中,特征提取模型包括第一子模型和第二子模型;
34.处理单元在将第一训练样本输入至待训练的特征提取模型,对第一训练样本进行特征提取,得到第一特征和第二特征时,具体用于:
35.将第一训练样本输入至第一子模型,对第一训练样本进行特征提取,得到第一特征,以及,将第一训练样本输入至第二子模型,对第一训练样本进行特征提取,得到第二特征,其中,第一子模型和第二子模型间至少存在一次数据交互。
36.在一种可能的实现方式中,处理单元在将第一训练样本输入至第二子模型时,具体用于:
37.基于第一训练样本,增加第一训练样本中的数据量,以得到第二训练样本;
38.将第二训练样本输入至第二子模型。
39.在一种可能的实现方式中,第一子模型的网络深度大于第二子模型的网络深度。
40.在一种可能的实现方式中,处理单元在将源域的样本数据对应的第一特征输入至待训练的任务分类模型,得到第一预测结果时,具体用于:
41.对源域的样本数据对应的第一特征进行超球约束,以及将超球约束后的第一特征输入至待训练的任务分类模型,得到第一预测结果。
42.在一种可能的实现方式中,装置还包括:
43.应用单元,用于在总预测损失达到最小后,将待预测的样本数据输入至特征提取模型,得到第三特征,待预测的样本数据为时间序列数据;
44.应用单元,还用于将第三特征输入至任务分类模型,得到第四预测结果。
45.第三方面,本技术提供一种电子设备,包括:至少一个存储器,用于存储程序;至少一个处理器,用于执行存储器存储的程序;其中,当存储器存储的程序被执行时,处理器用于执行第一方面或第一方面的任一种可能的实现方式所描述的方法。
46.第四方面,本技术提供一种计算机可读存储介质,计算机可读存储介质存储有计算机程序,当计算机程序在处理器上运行时,使得处理器执行第一方面或第一方面的任一种可能的实现方式所描述的方法。
47.第五方面,本技术提供一种计算机程序产品,其特征在于,当计算机程序产品在处理器上运行时,使得处理器执行第一方面或第一方面的任一种可能的实现方式所描述的方法。
48.第六方面,本技术提供一种芯片,其特征在于,包括至少一个处理器和接口;至少一个处理器通过接口获取程序指令或者数据;至少一个处理器用于执行程序行指令,以实现第一方面或第一方面的任一种可能的实现方式所描述的方法。
49.可以理解的是,上述第二方面至第六方面的有益效果可以参见上述第一方面中的相关描述,在此不再赘述。
附图说明
50.图1是本技术实施例提供的一种域自适应学习方法的过程示意图;
51.图2是本技术实施例提供的一种域自适应学习方法的架构示意图;
52.图3是本技术实施例提供的一种图2中所示的耦合交互网络210的网络结构示意图;
53.图4是本技术实施例提供的一种在图2所示的架构下域自适应学习方法的过程示意图;
54.图5是本技术实施例提供的一种进行超球约束的示意图;
55.图6是本技术实施例提供的一种迁移学习模型训练方法的流程示意图;
56.图7是本技术实施例提供的一种迁移学习模型训练装置的结构示意图;
57.图8是本技术实施例提供的一种芯片的结构示意图。
具体实施方式
58.本文中术语“和/或”,是一种描述关联对象的关联关系,表示可以存在三种关系,例如,a和/或b,可以表示:单独存在a,同时存在a和b,单独存在b这三种情况。本文中符号“/”表示关联对象是或者的关系,例如a/b表示a或者b。
59.本文中的说明书和权利要求书中的术语“第一”和“第二”等是用于区别不同的对象,而不是用于描述对象的特定顺序。例如,第一响应消息和第二响应消息等是用于区别不同的响应消息,而不是用于描述响应消息的特定顺序。
60.在本技术实施例中,“示例性的”或者“例如”等词用于表示作例子、例证或说明。本技术实施例中被描述为“示例性的”或者“例如”的任何实施例或设计方案不应被解释为比其它实施例或设计方案更优选或更具优势。确切而言,使用“示例性的”或者“例如”等词旨在以具体方式呈现相关概念。
61.在本技术实施例的描述中,除非另有说明,“多个”的含义是指两个或者两个以上,例如,多个处理单元是指两个或者两个以上的处理单元等;多个元件是指两个或者两个以上的元件等。
62.首先,对本技术中涉及的技术术语进行介绍。
63.(1)域自适应(domain adaptation,da)
64.域自适应:一种迁移学习的问题设定,特指从一个源域到一个目标域的迁移。其中,源域包含了样本数据和对应的标签目标域包含了样本数据和对应的标签域自适应研究如何结合源域和目标域的数据,使得在目标域上可以得到一个准确率更高的模型。
65.(2)无监督域自适应(unsupervised domain adaptation,uda)
66.无监督域自适应:是域自适应的一种更具体的问题设定,特指目标域标签不可获得情况下的情景。无监督领域自适应任务的目标是通过利用有标记
的源域数据来提高无标记的目标域数据上预测的效果。
67.(3)耦合交互网络
68.耦合交互网络:一种特征提取网络,包含了两个有交互的子网络,域不变网络(domain invariant network,din)和域特异网络(domain specific network,dsn)。其中,域不变网络可以用于提取域不变特征,域特异网络可以用于提取域特异特征。
69.(4)域不变特征
70.域不变特征:一种在源域和目标域上的标签具有判别力,但对源域和目标域不具备判别力的特征。也即是说,域不变特征是源域和目标域共同所拥有的特征。
71.(5)域特异特征
72.域特异特征:与域不变特征相对应,一种在源域和目标域上的标签不具有判别力,但对源域和目标域具备判别力的特征。也即是说,域特异特征是源域特有的特征,和/或,目标域特有的特征。
73.(6)域判别器
74.域判别器:一个用于区分源域和目标域的分类器。本技术实施例中涉及两个域判别器,域分类器(domain classifier)和域鉴别器(domain discriminator)。
75.其中,域分类器(domain classifier)主要用于区分域特异特征的来源,比如其是来自源域还是来自目标域。域分类器的输入的数据可以为:源域上的特异特征和目标域上的域特异特征。输入的数据的标签可以是数据的来源。
76.域鉴别器(domain discriminator)主要用于区分域不变特征的来源,比如其是来自源域还是来自目标域。域鉴别器的输入的数据可以为:源域上的域不变特征和目标域上的域不变特征。输入的数据的标签可以是数据的来源。在域鉴别器中可以使用对抗训练,以使其域不变特征不具备对源域和目标域的区分能力。
77.(7)对抗训练(或对抗学习)
78.对抗训练(或对抗学习):一种深度学习中的训练范式,通过让生成器生成的数据无法被判别器准确分类实现生成模型训练。在本技术实施例中,对抗训练可以被用于域鉴别器的训练,使其域不变特征不具备对源域和目标域的区分能力。
79.(8)时间序列数据
80.时间序列数据:一种带有时间先后顺序信息的数据,例如某地一月份的气温数据就是一个时间序列数据。时间序列数据可以是一维的(如每个时间点仅有气温信息),也可以是多维的(如每个时间点包含气温和湿度信息等)。在本技术实施例中,域自适应中的样本数据和都是时间序列数据。在一些实施例中,时间序列数据可以但不限于包括传感器检测到的数据(比如电压数据、电流数据、速度数据、位姿数据等)、用户的行为数据(比如购物数据、用电数据等)、金融数据(比如价格波动数据等)等等。
81.接着,对本技术涉及的技术方案进行介绍。
82.一般的,考虑到时序数据(即前述的时间序列数据)的序列结构,目前主要是利用递归神经网络作为特征提取器,以对时序数据进行特征提取。例如,codats提出利用一维卷积神经网络conv1d作为特征提取器,算法的整体架构如图1所示。继续参阅图1,这一方法包含三个部分:一个特征提取器(feature extractor),一个域判别器(domain classifier)和一个任务分类器(task classifier)。其中,特征提取器负责从时序数据中提取特征,域
判别器用于区分特征属于哪个域,任务分类器则需要预测源域的数据的标签。该方法采用对抗训练,特征抽取器需要学习到使域判别器无法正确区分源域和目标域的域不变特征。但codats这类算法的问题在于仅使用了单网络捕捉域不变特征,而没有考虑源域和目标域特有的特征,这导致该方法提取出的特征中仍存在域特异特征,致使目标域数据上的预测准确性下降。因此,如何捕获源域和目标域中域特有的特征及不随领域变化的特征是至关重要的。同时,该方法使用传统的对抗学习的思路,存在任务分类器的分类边界过近,模型泛化能力较差的问题。
83.有鉴于此,本技术实施例提出了一种域自适应学习方法,可以通过两个有交互的网络(即前述的耦合交互网络)分别提取出域不变和域特异特征,以对时序数据进行解耦,从而得到域不变特征及域特异特征,进而可以基于域不变特征和源域的标签训练得到所需的分类器,提升目标域数据上的预测准确率。另外,在学习域特异特征时,可以对时序数据进行裁剪拼接、随机打乱等操作进行数据扩增,从而提高对域自身信息的捕获能力。此外,在源域上可以将不同类别的样本驱赶到不同的超球内,实现类内样本的聚拢和类间样本的分隔,以提升模型的泛化能力。
84.示例性的,图2示出了本技术实施例提供的一种域自适应学习方法的架构示意图。如图2所示,该架构下主要包括:耦合交互网络210、任务分类器220、域鉴别器230和域分类器240。其中,任务分类器220、域鉴别器230和域分类器240均可以根据耦合交互网络210提取到的域不变特征和/或域特异特征训练得到。示例性的,任务分类器220、域鉴别器230和域分类器240均可以但不限于是基于多层感知机(multilayer perceptron,mlp)得到的神经网络。
85.域不变网络211可以用于提取域不变特征,域特异网络212可以用于提取域特异特征,两者在提取特征过程中可以进行交互。耦合交互网络210可以包括域不变网络211和域特异网络212。通过域不变网络211可以提取到源域的域不变特征和目标域的域不变特征;通过域特异网络212可以提取到源域的域特异特征和目标域的域特异特征。
86.在一些实施例中,由于域不变特征倾向于捕捉通用的、深层的特性,这些特征很难直接被发现并从数据中提取,而域特异特征的特性往往是一些局部的、浅层的信息,更容易被网络所捕捉。因此,可以使用浅层神经网络来捕获浅层的域特异特征,使用深层神经网络捕获深层的域不变特征。另外,理解每个领域的特性可以更好地提取域不变特征和域特异特征,因此,通过在两个网络中添加不同层之间的交互,可以使两个网络相互通信以共享其信息,以便更好的提取到域不变特征和域特异特征。
87.示例性的,图3示出了图2中所示的耦合交互网络210的一种网络结构示意图。如图3所示,耦合交互网络210中的域不变网络211是由4层卷积组成的深度较深的网络,其包括卷积层2111、2112、2113和2114,以及全连接层2115;域特异网络212是由2层卷积组成的深度较浅的网络,其包括卷积层2121和2122,以及全连接层2123。
88.在域不变网络211和域特异网络212间存在交互操作。其中,域不变网络211中的卷积层2112输出的中间变量可以作为卷积层2113的输入,也可以作为域特异网络212中卷积层2122的输入;域特异网络212中的卷积层2121输出的中间变量可以作为卷积层2122的输入,也可以作为域不变网络211中卷积层2113的输入,由此以实现两个网络间的交互操作。应理解的是,图3中对域不变网络211和域特异网络212的结构仅是示意性说明,并不构成对
本技术中技术方案的限定,其他深度的域不变网络211和域特异网络212均在本技术的保护范围之内。
89.在一些实施例中,域不变网络211和域特异网络212间的交互操作可表示为其中,表示域特异网络212的第k层输出,表示域不变网络211的第j层输出,f
k,j
是转换函数,f
k,j
可以通过一个神经网络或其他形式实现,运算符表示两个特征的简单拼接(或者按元素相加)其中,可以根据分类任务的复杂程度进行调整,此处不做限定。两个网络的深度和交互次数可以根据分类任务的复杂程度进行调整,此处不做限定。
90.在一些实施例中,继续参阅图2和图3,在将域不变网络211和域特异网络212设计为可以进行交互操作后,在训练过程中梯度信息会在两个网络之间传导。这种相互作用将导致域不变表示和域特异表示的混淆。为了解决这个问题,可以在交互层和导出信息的网络输出层之间添加一个残差连接(即图3中的虚线部分)。通过这种方式,残差连接中所涉及的目标函数的梯度可以直接影响交互层和交互之前的底层网络,从而使得交互前的特征倾向于保留更多的自身特征。
91.在一些实施例中,继续参阅图2和图3,可以将源域和目标域的时序数据(即源域数据和目标域数据)输入到域不变网络211中,以通过域不变网络211提取到源域的域不变特征和目标域的域不变特征。同样的,可以将源域数据和目标域数据输入到域特异网络212中,以通过域特异网络212提取到源域的域特异特征和目标域的域特异特征。在一些实施例中,为了提高域特异特征对域自身信息的刻画能力,以提升对域特异特征的捕获能力,可以先分别对源域数据和目标域数据进行数据扩增,再将扩增后的数据输入到域特异网络212中。示例性的,在对源域数据进行扩增时可以但不限于采用以下一项或多项方式进行扩增:a、对源域数据进行随机裁剪,并将裁剪出的子序列拼接成为一个新的时间序列,以实现对源域数据的数据扩增;b、随机打乱源域数据中一部分子序列的顺序;c、在源域数据中的时间序列上的每个值添加高斯噪声;d、随机翻转源域数据中时间序列中一部分子序列。在对目标域数据进行扩增时,可以参考对源域数据扩增的方式,此处不再赘述。
92.在域不变网络211提取到源域的域不变特征和目标域的域不变特征后,可以采用对抗训练的方式,并根据源域的域不变特征和目标域的域不变特征对域鉴别器230进行训练,从而得到所需的域鉴别器。示例性的,可以根据源域和目标域各自的域不变特征和各个域不变特征的来源,对域鉴别器230进行训练。
93.另外,也可以根据源域的域不变特征,对任务分类器220进行训练,以得到所需的任务分类器。示例性的,可以根据源域的域不变特征和源域中样本数据的标签,对域鉴别器230进行训练。在一些实施例中,为了实现类内样本的聚拢和类间样本的分隔,以提升模型的泛化能力,以及提升域不变特征的区分度,可以对训练任务分类器220所需的数据进行超球约束。其中,在进行超球约束时,可以将同一标签对应的各个样本数据归入到一个超球内,且不同的标签对应的超球之间分离,由此实现了对同一标签下的样本数据的聚拢和不同标签下的域不变特征间的分离,进而使得训练得到的任务分类器220可以准确地预测目标域上的标签。
94.在域特异网络212提取到源域的域特异特征和目标域的域特异特征后,可以根据源域的域特异特征和目标域的域特异特征对域分类器240进行训练,从而得到所需的域分
类器。示例性的,可以根据源域和目标域各自的域特异特征和各个域特异特征的来源,对域分类器240进行训练。
95.示例性的,图4示出了一种在图2所示的架构下域自适应学习方法的过程示意图。如图4所示,先将源域数据(即源域的时序数据)xs和目标域数据(即目标域的时序数据)x
t
输入到耦合交互网络210(即图中所示的coupled interactive network)。
96.输入到域不变网络211(即图中所示的domain-invariant network)的源域数据和目标域数据,经过域不变网络211的特征提取后,可以得到源域的域不变特征和目标域的域不变特征将和各个特征的来源作为域鉴别器230(即图中所示的domain discriminator)的输入,并采用对抗训练的方式对域鉴别器230进行训练,即可以得到所需的域鉴别器。将和源域数据中的样本标签作为任务分类器220(即图中所示的task classifier)的输入,对任务分类器220进行训练,即可以得到所需的任务分类器。其中,在将作为任务分类器220的输入时,可以先对进行超球约束(即图中所示的class-wise hypersphere),以将不同类别的样本数据(比如不同标签对应的域不变特征等)驱赶到不同的超球内,最后可以将进行超球约束后的数据作为任务分类器220的输入。在进行超球约束时,所采用的目标函数可以表示为:
[0097][0098]
其中,表示按(xs,ys)服从分布求期望,hv(xs)为域不变特征,表示类别为ys的圆心。
[0099]
输入到域特异网络212(即图中所示的domain-specific network)的源域数据和目标域数据,可以先进行数据扩增(即图中所示的date augmentation),然后再将扩增后的数据输入到域特异网络212。经过域特异网络212的特征提取后,可以得到源域的域特异特征和目标域的域特异特征将和各个特征的来源作为域分类器240(即图中所示的domain classifier)的输入,并对域分类器240进行训练,即可以得到所需的域分类器。
[0100]
可以理解的是,上述在对任务分类器220、域鉴别器230和域分类器240的训练过程中,同时也可以对耦合交互网络210进行训练,从而使得训练得到的模型可以对识别到目标域数据的标签,进而完成由源域向目标域的迁移学习。
[0101]
在一些实施例中,继续参阅4,在训练过程中,在训练任务分类器220(即图中所示的task classifier)时,通过任务分类器220中的神经网络对源域的域不变特征进行预测,以得到预测标签(即预测结果),以及利用一个损失函数对预测标签和源域的样本数据的标签(即输入到任务分类器220中的标签)进行处理,得到一个预测损失(以下简称“预测损失一”)。
[0102]
在训练域鉴别器230(即图中所示的domain discriminator)时,通过域鉴别器230中的神经网络对源域和目标域的域不变特征进行对抗训练,并预测,可以预测得到各个域不变特征的来源,以及利用一个损失函数对预测得到各个域不变特征的来源和各个域不变特征的真实来源(即输入到域鉴别器230中的各个域不变特征的来源)进行处理,得到一个预测损失(以下简称“预测损失一”)。
[0103]
在训练域分类器240(即图中所示的domain classifier)时,通过域分类器240中
的神经网络对源域和目标域的域特异特征进行预测,可以预测得到各个域特异特征的来源,以及利用一个损失函数对预测得到各个域特异特征的来源和各个域特异特征的真实来源(即输入到域分类器240中的各个域特异特征的来源)进行处理,得到一个预测损失(以下简称“预测损失三”)。
[0104]
然后,可以根据预测损失一、预测损失二和预测损失三,确定出一个总预测损失。比如,将预测损失一、预测损失二和预测损失三相加之和作为总预测损失等。
[0105]
最后,可以以最小化总预测损失为目标,对域不变网络211、域特异网络212、任务分类器220、域鉴别器230和域分类器240进行训练,以完成得到所需的网络模型。其中,训练得到的网络模型中至少包括:域不变网络211和任务分类器230。另外,由于在训练网络模型过程中,是基于预测损失一、预测损失二和预测损失三对域不变网络211和任务分类器230进行训练,而域鉴别器230是混淆各个域不变特征,使得模型不知道域不变特征的来源,域分类器240是区分各个域特异特征,使得模型知道域特异特征的来源,任务分类器230是完成对域不变特征的分类,因此,最终训练得到的网络模型既可以完成分类任务,又可以不知道其所需分类的数据的来源,从而使得网络模型能够准确对目标数据进行预测,提升了网络模型的预测准确度。
[0106]
在一些实施例中,为了实现域不变特征和域特异特征解耦的目标,在本技术实施例中添加了域鉴别器230和域分类器240。域鉴别器230的目的是让目标域的域不变特征与源域的域不变特征无法被区分。相反,域分类器240的目的是区分域特异网络生成的特征(即域特异特征)属于哪一个域。真正有效的域不变特征可以预测每个域中数据的标签,因此,可以利用来自源域的标签训练任务分类器220,并进一步为域鉴别器230增加了类别感知的超球模型(即基于任务分类器220对应的损失结果对域鉴别器230进行训练),以克服域鉴别器230带来的分类效果下降问题。具体地,如图5的(a)所示,可以在域不变特征的空间中,要求来自于同一类的源域样本被映射到同一超球(簇)中,从而使得对抗学习不能缩小不同类之间的距离,这样就保证了模型有较大的分类边界。这一约束只针对有标记的源域数据。对于目标域数据,如图5的(b)所示,一旦源域数据被映射到几个超球面中,每个目标域数据也会被放入其中一个超球中(如图5的(c)所示),以便混淆域鉴别器230。最终,所有源数据和目标数据都会被投影到超球中,同时保证了较大的分类边界(如图5的(d)所示)。
[0107]
接下来,基于上文所描述的内容,对本技术实施例提供的一种迁移学习模型训练方法进行介绍。可以理解的是,该方法是基于上文所描述的内容提出,该方法中的部分或全部内容可以参见上文中的描述。
[0108]
请参阅图6,图6是本技术实施例提供的一种迁移学习模型训练方法的流程示意图。可以理解,该方法可以通过任何具有计算、处理能力的装置、设备、平台、设备集群来执行。
[0109]
如图6所示,该迁移学习模型训练方法包括:
[0110]
s601、获取第一训练样本,第一训练样本为时间序列数据,第一训练样本包括源域的样本数据和目标域的样本数据。
[0111]
s602、将第一训练样本输入至待训练的特征提取模型,对第一训练样本进行特征提取,得到第一特征和第二特征,第一特征为源域和目标域共同所拥有的特征,第二特征为源域和目标域各自所特有的特征。示例性的,特征提取模型可以为上文所描述的耦合交互
网络210,第一特征可以为上文所描述的域不变特征,第二特征可以为上文所描述的域特异特征。
[0112]
作为一种可能的实现方式,特征提取模型可以包括第一子模型和第二子模型。此时,可以将第一训练样本输入至第一子模型,对第一训练样本进行特征提取,得到第一特征,以及,将第一训练样本输入至第二子模型,对第一训练样本进行特征提取,得到第二特征,其中,第一子模型和第二子模型间至少存在一次数据交互。示例性的,第一子模型可以为上文所描述的域不变网络211,第二子模型可以为上文所描述的域特异网络212。在一些实施例中,在将第一训练样本输入第二子模型时,可以向基于第一训练样本,增加第一训练样本中的数据量,以得到第二训练样本,然后再将第二训练样本输入至第二子模型。由此以提高第二特征对域自身信息的刻画能力,以提升对第二特征的捕获能力。
[0113]
s603、将源域的样本数据对应的第一特征输入至待训练的任务分类模型,得到第一预测结果,以及,根据第一预测结果和源域的样本数据的标签,确定第一预测损失。示例性的,源域的样本数据对应的第一特征可以为上文所描述的源域的域不变特征,任务分类模型可以为上文所描述的任务分类器220,第一预测损失可以为上文所描述的预测损失一。
[0114]
作为一种可能的实现方式,可以先对源域的样本数据对应的第一特征进行超球约束,然后在将超球约束后的第一特征输入至待训练的任务分类模型,得到第一预测结果。
[0115]
s604、将源域和目标域各自对应的第一特征输入至待训练的域鉴别模型,得到第二预测结果,以及,根据第二预测结果和源域和目标域各自对应的第一特征的来源,确定第二预测损失,其中,在域鉴别模型中采用对抗训练对源域和目标域各自对应的第一特征进行处理。示例性的,域鉴别模型可以为上文所描述的域鉴别器230,源域和目标域各自对应的第一特征的来源可以是指各个第一特征是来自源域还是来自目标域,第一预测损失可以为上文所描述的预测损失二。
[0116]
s605、将源域和目标域各自对应的第二特征输入至待训练的域分类模型,得到第三预测结果,以及,根据第三预测结果和源域和目标域各自对应的第二特征的来源,确定第三预测损失。示例性的,域分类模型可以为上文所描述的域分类器240,源域和目标域各自对应的第二特征的来源可以是指各个第二特征是来自源域还是来自目标域,第三预测损失可以为上文所描述的预测损失三。
[0117]
s606、根据第一预测损失、第二预测损失和第三预测损失,确定总预测损失。示例性的,可以但不限于将第一预测损失、第二预测损失和第三预测损失之和作为总预测损失。
[0118]
s607、以最小化总预测损失为目标,训练特征提取模型、任务分类模型、域鉴别模型和域分类模型。
[0119]
这样,由于在训练网络模型过程中,是基于第一预测损失、第二预测损失和第三预测损失对各个模型进行训练,而域鉴别模型是采用对抗训练以混淆各个第一特征,使得模型不知道第一特征的来源,域分类模型是区分各个第二特征,使得模型知道第二特征的来源,任务分类模型是完成对第一特征的分类,因此,最终训练得到的网络模型既可以完成分类任务,又可以不知道其所需分类的数据的来源,从而使得网络模型能够准确对目标数据进行预测,提升了网络模型的预测准确度。可以通过特征提取模型可以分别提取出第一特征和第二特征,从而可以对样本数据进行解耦,进而可以基于目标域和源域共用的第一特征对各个模型进行训练,提升训练得到的模型对目标域数据上的预测准确率。
[0120]
在一些实施例中,在前述的总预测损失达到最小后(即完成模型训练后),可以将待预测的样本数据输入至特征提取模型,得到第三特征,待预测的样本数据为时间序列数据;然后,在将第三特征输入至任务分类模型,得到第四预测结果。由此以完成对待预测的样本数据的预测。示例性的,第三特征可以理解为是上文所描述的域不变特征。
[0121]
基于上述实施例中的方法,本技术实施例提供了一种迁移学习模型训练装置。请参阅图7,图7是本技术实施例提供的一种迁移学习模型训练装置的结构示意图。如图7所示,该迁移学习模型训练装置700包括:获取单元710、处理单元720和训练单元730。
[0122]
其中,获取单元710可以用于获取第一训练样本,第一训练样本为时间序列数据,第一训练样本包括源域的样本数据和目标域的样本数据。
[0123]
处理单元720可以用于将第一训练样本输入至待训练的特征提取模型,对第一训练样本进行特征提取,得到第一特征和第二特征,第一特征为源域和目标域共同所拥有的特征,第二特征为源域和目标域各自所特有的特征。
[0124]
处理单元720还可以用于将源域的样本数据对应的第一特征输入至待训练的任务分类模型,得到第一预测结果,以及,根据第一预测结果和源域的样本数据的标签,确定第一预测损失。
[0125]
处理单元720还可以用于将源域和目标域各自对应的第一特征输入至待训练的域鉴别模型,得到第二预测结果,以及,根据第二预测结果和源域和目标域各自对应的第一特征的来源,确定第二预测损失,其中,在域鉴别模型中采用对抗训练对源域和目标域各自对应的第一特征进行处理;
[0126]
处理单元720还可以用于将源域和目标域各自对应的第二特征输入至待训练的域分类模型,得到第三预测结果,以及,根据第三预测结果和源域和目标域各自对应的第二特征的来源,确定第三预测损失;
[0127]
处理单元720还可以用于根据第一预测损失、第二预测损失和第三预测损失,确定总预测损失;
[0128]
训练单元730可以用于以最小化总预测损失为目标,训练特征提取模型、任务分类模型、域鉴别模型和域分类模型。
[0129]
在一些实施例中,特征提取模型包括第一子模型和第二子模型。处理单元720在将第一训练样本输入至待训练的特征提取模型,对第一训练样本进行特征提取,得到第一特征和第二特征时,具体可以用于:将第一训练样本输入至第一子模型,对第一训练样本进行特征提取,得到第一特征,以及,将第一训练样本输入至第二子模型,对第一训练样本进行特征提取,得到第二特征,其中,第一子模型和第二子模型间至少存在一次数据交互。
[0130]
在一些实施例中,处理单元720在将第一训练样本输入至第二子模型时,具体可以用于:基于第一训练样本,增加第一训练样本中的数据量,以得到第二训练样本;将第二训练样本输入至第二子模型。
[0131]
在一些实施例中,第一子模型的网络深度大于第二子模型的网络深度。
[0132]
在一些实施例中,处理单元720在将源域的样本数据对应的第一特征输入至待训练的任务分类模型,得到第一预测结果时,具体可以用于:对源域的样本数据对应的第一特征进行超球约束,以及将超球约束后的第一特征输入至待训练的任务分类模型,得到第一预测结果。
[0133]
在一些实施例中,该迁移学习模型训练装置700还可以包括:应用单元(图中未示出),该应用单元可以用于在总预测损失达到最小后,将待预测的样本数据输入至特征提取模型,得到第三特征,待预测的样本数据为时间序列数据;以及将第三特征输入至任务分类模型,得到第四预测结果。
[0134]
应当理解的是,上述装置用于执行上述实施例中的方法,装置中相应的程序模块,其实现原理和技术效果与上述方法中的描述类似,该装置的工作过程可参考上述方法中的对应过程,此处不再赘述。
[0135]
基于上述实施例中的方法,本技术实施例提供了一种电子设备。该电子设备可以包括:至少一个存储器,用于存储程序;至少一个处理器,用于执行存储器存储的程序;其中,当存储器存储的程序被执行时,处理器用于执行上述实施例中的方法。
[0136]
基于上述实施例中的方法,本技术实施例提供了一种计算机可读存储介质,计算机可读存储介质存储有计算机程序,当计算机程序在处理器上运行时,使得处理器执行上述实施例中的方法。
[0137]
基于上述实施例中的方法,本技术实施例提供了一种计算机程序产品,其特征在于,当计算机程序产品在处理器上运行时,使得处理器执行上述实施例中的方法。
[0138]
基于上述实施例中的方法,本技术实施例还提供了一种芯片。请参阅图8,图8为本技术实施例提供的一种芯片的结构示意图。如图8所示,芯片800包括一个或多个处理器801以及接口电路802。可选的,芯片800还可以包含总线803。其中:
[0139]
处理器801可能是一种集成电路芯片,具有信号的处理能力。在实现过程中,上述方法的各步骤可以通过处理器801中的硬件的集成逻辑电路或者软件形式的指令完成。上述的处理器801可以是通用处理器、数字通信器(dsp)、专用集成电路(asic)、现场可编程门阵列(fpga)或者其它可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。可以实现或者执行本技术实施例中的公开的各方法、步骤。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
[0140]
接口电路802可以用于数据、指令或者信息的发送或者接收,处理器801可以利用接口电路802接收的数据、指令或者其它信息,进行加工,可以将加工完成信息通过接口电路802发送出去。
[0141]
可选的,芯片800还包括存储器,存储器可以包括只读存储器和随机存取存储器,并向处理器提供操作指令和数据。存储器的一部分还可以包括非易失性随机存取存储器(nvram)。
[0142]
可选的,存储器存储了可执行软件模块或者数据结构,处理器可以通过调用存储器存储的操作指令(该操作指令可存储在操作系统中),执行相应的操作。
[0143]
可选的,接口电路802可用于输出处理器801的执行结果。
[0144]
需要说明的,处理器801、接口电路802各自对应的功能既可以通过硬件设计实现,也可以通过软件设计来实现,还可以通过软硬件结合的方式来实现,这里不作限制。
[0145]
应理解,上述方法实施例的各步骤可以通过处理器中的硬件形式的逻辑电路或者软件形式的指令完成。
[0146]
可以理解的是,上述实施例中各步骤的序号的大小并不意味着执行顺序的先后,各过程的执行顺序应以其功能和内在逻辑确定,而不应对本技术实施例的实施过程构成任
何限定。此外,在一些可能的实现方式中,上述实施例中的各步骤可以根据实际情况选择性执行,可以部分执行,也可以全部执行,此处不做限定。
[0147]
可以理解的是,本技术的实施例中的处理器可以是中央处理单元(central processing unit,cpu),还可以是其他通用处理器、数字信号处理器(digital signal processor,dsp)、专用集成电路(application specific integrated circuit,asic)、现场可编程门阵列(field programmable gate array,fpga)或者其他可编程逻辑器件、晶体管逻辑器件,硬件部件或者其任意组合。通用处理器可以是微处理器,也可以是任何常规的处理器。
[0148]
本技术的实施例中的方法步骤可以通过硬件的方式来实现,也可以由处理器执行软件指令的方式来实现。软件指令可以由相应的软件模块组成,软件模块可以被存放于随机存取存储器(random access memory,ram)、闪存、只读存储器(read-only memory,rom)、可编程只读存储器(programmable rom,prom)、可擦除可编程只读存储器(erasable prom,eprom)、电可擦除可编程只读存储器(electrically eprom,eeprom)、寄存器、硬盘、移动硬盘、cd-rom或者本领域熟知的任何其它形式的存储介质中。一种示例性的存储介质耦合至处理器,从而使处理器能够从该存储介质读取信息,且可向该存储介质写入信息。当然,存储介质也可以是处理器的组成部分。处理器和存储介质可以位于asic中。
[0149]
在上述实施例中,可以全部或部分地通过软件、硬件、固件或者其任意组合来实现。当使用软件实现时,可以全部或部分地以计算机程序产品的形式实现。所述计算机程序产品包括一个或多个计算机指令。在计算机上加载和执行所述计算机程序指令时,全部或部分地产生按照本技术实施例所述的流程或功能。所述计算机可以是通用计算机、专用计算机、计算机网络、或者其他可编程装置。所述计算机指令可以存储在计算机可读存储介质中,或者通过所述计算机可读存储介质进行传输。所述计算机指令可以从一个网站站点、计算机、服务器或数据中心通过有线(例如同轴电缆、光纤、数字用户线(dsl))或无线(例如红外、无线、微波等)方式向另一个网站站点、计算机、服务器或数据中心进行传输。所述计算机可读存储介质可以是计算机能够存取的任何可用介质或者是包含一个或多个可用介质集成的服务器、数据中心等数据存储设备。所述可用介质可以是磁性介质,(例如,软盘、硬盘、磁带)、光介质(例如,dvd)、或者半导体介质(例如固态硬盘(solid state disk,ssd))等。
[0150]
可以理解的是,在本技术的实施例中涉及的各种数字编号仅为描述方便进行的区分,并不用来限制本技术的实施例的范围。
当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1