基于图卷积网络的无监督域自适应的分类方法

文档序号:30580820发布日期:2022-06-29 12:05阅读:226来源:国知局
基于图卷积网络的无监督域自适应的分类方法

1.本技术涉及深度学习技术领域,特别是涉及一种基于图卷积网络的无监督域自适应的分类方法。


背景技术:

2.无监督域自适应任务,即利用源域的信息来辅助完成目标域上的任务,其中源域的样本是标记的或者是部分标记的,目标域的样本是无标记的。无监督域自适应的主要挑战是如何对齐源域和目标域的数据分布。
3.对于无监督域适应任务,一般深度学习方法通常会将源域和目标域中的样本转换到同一公共空间中。例如曾等人为了减少公共空间中源域和目标域之间的分布差异在网络中共享参数层上设计了最大平均差异(mmd,maximum mean discrepancy)损失。ganin等人设计了一个域鉴别器来区分每个样本来自哪个域,并提出了一个梯度反转层(grl,gradient reversal layer)来最大化域分类损失以减少域之间的分布差异。丁等人提出了一种自适应探索(ae,adaptive exploration)方法,通过最大化所有行人图像之间的距离并最小化相似行人图像之间的距离来解决行人重识别的域转移问题。尽管深度学习方法在减少域差异方面取得了一些进展,但源域中的标签率还是会影响无监督域适应任务的预测结果。源域中的标签率越低,目标域的预测结果越差。
4.随着图神经网络的引入,kipf等人提出的图卷积网络(gcn,graph convolutional network)在半监督分类任务中取得了理想的结果。在领域自适应任务中,给定少量标记的源数据,图卷积网络通常能够通过在源网络中传播样本信息来构建一个性能良好的分类器。例如戴等人结合图卷积网络和对抗域自适应模型来减少分布差异并进行准确的标签预测。
5.现有的基于图卷积网络的无监督域自适应的分类方法关注于两个域之间的公共信息,而没有去对域的特定信息加以利用。而且没有进一步关注类别级别的分布对齐问题,这可能会导致跨域的同一类样本的分布负对齐,并且可能不利于目标域的任务,从而导致训练的基于图卷积的无监督域自适应分类模型的性能低。


技术实现要素:

6.基于此,有必要针对上述技术问题,提供一种能够提高训练的基于图卷积的无监督域自适应分类模型性能的基于图卷积网络的无监督域自适应的分类方法。
7.一种基于图卷积网络的无监督域自适应的分类方法,所述方法包括:
8.获取源域中的样本数据和目标域中样本数据作为训练数据;
9.根据所述源域和所述目标域中样本数据间的相似性分别更新两个域中样本的图连接关系;
10.将所述源域和所述目标域中的样本数据输入到域自适应网络中进行训练,所述域自适应网络是基于图卷积网络的无监督域自适应网络,所述域自适应网络包括:跨域特征
提取模型、源域特征提取模型、分类模型、域对抗鉴别模型、类对齐模型;
11.训练所述域自适应网络不断更新迭代所述域自适应网络中的参数,当所述域自适应网络达到收敛条件时,获得域自适应分类模型;
12.输入待分类数据至所述域自适应分类模型进行分类,获得所述待分类数据的分类结果。
13.在其中一个实施例中,所述跨域特征提取模型提取所述源域和所述目标域公共的样本特征,所述源域特征提取模型提取所述源域特定的样本特征,所述分类模型计算分类损失值,所述域对抗鉴别模型计算域特征对齐损失值,所述类对齐模型计算类特征对齐损失值。
14.在其中一个实施例中,所述总损失值为特征差异性损失值、分类损失值、域特征对齐损失值和类特征对齐损失值的和,其中,所述特征差异性损失值是所述源域的样本数据输入跨域特征提取模型和所述源域特征提取模型得到的特征差异,分类损失值是基于源域的样本数据输入分类模型的。
15.在其中一个实施例中,所述域自适应网络的构建方式包括:
16.源域的样本数据和目标域的样本数据输入到跨域特征提取模型中得到源域和目标域的公共嵌入特征表示;
17.源域的样本数据输入到源域特征提取模型中得到所述源域的特定嵌入特征表示;
18.计算所述源域的公共嵌入特征表示和所述特定嵌入特征表示的差异性构建特征差异性损失函数;
19.将目标域的样本数据输入到源域特征提取模型中得到带源域风格的目标域嵌入特征表示,再通过注意力机制与所述目标域的公共嵌入特征表示结合为目标域的嵌入特征表示;同时通过注意力机制将所述源域的公共嵌入特征表示与源域的特定嵌入特征表示结合为源域的嵌入特征表示;
20.将得到的所述源域和所述目标域的嵌入特征表示输入到分类模型中,所述源域的嵌入特征表示中有类别标签的部分构建分类损失函数,源域的其余无类别标签的部分和目标域的嵌入特征表示则生成其特征表示所对应的伪类别标签;
21.将得到的所述源域和所述目标域的公共嵌入特征表示输入到域对抗鉴别模型中,构建域特征对齐损失函数;
22.将所述源域的样本数据和所述目标域的样本数据按照类别标签和伪类别标签中的类别进行分组,同时将不同分组的样本的嵌入特征表示输入到类对齐模型中,构建类特征对齐损失函数。
23.在其中一个实施例中,所述跨域特征提取模型是由两层的图卷积神经网络的共享网络组成的,所述源域的样本数据与目标域的样本数据都输入到共享网络中得到其公共嵌入特征表示;
24.所述源域特征提取模型是由两层的图卷积神经网络模型组成的。
25.在其中一个实施例中,所述特征差异性损失函数为:
[0026][0027]
式中,表示源域的公共嵌入特征表示,表示源域的特定嵌入特征表示,lm表
示特征差异性损失函数,t表示转置运算。
[0028]
在其中一个实施例中,所述分类损失函数为:
[0029][0030]
式中,表示源域的有类别标签样本的嵌入特征表示,为分类模型测得的分类结果,为源域属于第k类的类别标签,k∈[1,c],c为样本的总类别数,n
sl
为源域中有类别标签的样本个数,ls表示分类损失函数。
[0031]
在其中一个实施例中,所述域特征对齐损失函数为:
[0032][0033]
式中,z
ci
表示第i个公共嵌入特征表示,gd(zi)为域对抗鉴别模型测得的结果,为输入的公共特征表示属于源域的域标签还是目标域的域标签,ns为源域的样本总个数,n
t
为目标域的样本总个数,ld表示域特征对齐损失函数。
[0034]
在其中一个实施例中,所述类特征对齐损失函数为:
[0035][0036]
式中,lc表示类特征对齐损失函数,表示类别标签或者伪类别标签为第k类样本的源域嵌入特征表示,表示伪类别标签为第k类样本的目标域嵌入特征表示,和分别为和的概率分布,c为样本的总类别数。
[0037]
在其中一个实施例中,所述不断更新迭代所述域自适应网络中的参数的表达式为:
[0038]
min(ls+λl
m-βld+γlc)
[0039]
式中,lm表示特征差异性损失函数,ls表示分类损失函数,ld表示域特征对齐损失函数,lc表示类特征对齐损失函数,λ,β和γ分别为对应损失函数之间的平衡因子。
[0040]
上述基于图卷积网络的无监督域自适应的分类方法,通过获取源域中的样本数据和目标域中样本数据作为训练数据;根据源域和目标域中样本数据间的相似性分别更新两个域中样本的图连接关系;将源域和目标域中的样本数据输入到域自适应网络中进行训练,域自适应网络是基于图卷积网络的无监督域自适应网络,域自适应网络包括:跨域特征提取模型、源域特征提取模型、分类模型、域对抗鉴别模型、类对齐模型;训练域自适应网络不断更新迭代域自适应网络中的参数,当域自适应网络达到收敛条件时,获得域自适应分类模型;输入待分类数据至域自适应分类模型进行分类,获得待分类数据的分类结果。提高了基于图卷积的无监督域自适应分类模型性能,进一步提高对数据分类的准确性。
附图说明
[0041]
图1为一个实施例中基于图卷积网络的无监督域自适应的分类方法的流程示意图;
[0042]
图2为一个实施例中域自适应网络的构建方式的流程示意图。
具体实施方式
[0043]
为了使本技术的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本技术进行进一步详细说明。应当理解,此处描述的具体实施例仅仅用以解释本技术,并不用于限定本技术。
[0044]
本技术提供的基于图卷积网络的无监督域自适应的分类方法,可以应用于终端或服务器。其中,终端可以但不限于是各种个人计算机、笔记本电脑、智能手机、平板电脑和便携式可穿戴设备,服务器可以用独立的服务器或者是多个服务器组成的服务器集群来实现。
[0045]
在一个实施例中,如图1所示,提供了一种基于图卷积网络的无监督域自适应的分类方法,以该方法应用于终端为例进行说明,包括以下步骤:
[0046]
步骤s220,获取源域中的样本数据xs和目标域中样本数据x
t
作为训练数据。
[0047]
其中,源域中的样本数据和目标域中样本数据的类别以及类别数相同。源域中的部分样本数据有类别标签,目标域中样本数据没有类别标签。类别标签是指用于标记样本属于哪一类别的标签。样本数据的类型可以是文本数据,也可以是图片数据,还可以是音频数据,根据分类任务的需要,确定样本数据的类型。如:需要训练用于分类论文属于哪个学科的分类模型时,可以将标记有属于哪个学科的论文作为源域中的样本数据,将没有标记属于哪个学科的论文作为目标域中样本数据,执行步骤s240至步骤s280,获得用于对论文属于哪个学科进行分类的域自适应分类模型。
[0048]
步骤s240,根据源域和目标域中样本数据间的相似性分别去更新两个域中样本的图连接关系as和a
t

[0049]
其中,使用正点互信息(ppmi,positive pointwise mutual information)来计算样本数据之间的相似性。ppmi的计算公式如下:
[0050][0051]
式中,其中n为一个域内的样本个数,a
ij
为样本i和样本j连接的权重系数,ppmi
ij
为样本i和样本j的样本相似度,ppmi
ij
的值越大表明相似度越高。
[0052]
步骤s260,将源域和目标域中的样本数据输入到域自适应网络中进行训练,域自适应网络是基于图卷积网络的无监督域自适应网络,域自适应网络包括:跨域特征提取模型、源域特征提取模型、分类模型、域对抗鉴别模型、类对齐模型。
[0053]
步骤s280,训练域自适应网络不断更新迭代域自适应网络中的参数,当域自适应网络达到收敛条件时,获得域自适应分类模型。
[0054]
步骤s300,输入待分类数据至域自适应分类模型进行分类,获得待分类数据的分
类结果。
[0055]
其中,待分类数据是需要进行分类的数据,待分类数据可以是很多个数据,也可以是一个数据,如:需要对某一篇论文属于哪个学科进行分类,将该论文输入域自适应分类模型,输出该论文所属于的学科。
[0056]
上述基于图卷积网络的无监督域自适应的分类方法,通过获取源域中的样本数据和目标域中样本数据作为训练数据;根据源域和目标域中样本数据间的相似性分别更新两个域中样本的图连接关系;将源域和目标域中的样本数据输入到域自适应网络中进行训练,域自适应网络是基于图卷积网络的无监督域自适应网络,域自适应网络包括:跨域特征提取模型、源域特征提取模型、分类模型、域对抗鉴别模型、类对齐模型;训练域自适应网络不断更新迭代域自适应网络中的参数,当域自适应网络达到收敛条件时,获得域自适应分类模型;输入待分类数据至域自适应分类模型进行分类,获得待分类数据的分类结果。提高了基于图卷积的无监督域自适应分类模型性能,进一步提高对数据分类的准确性。
[0057]
在一个实施例中,跨域特征提取模型提取源域和目标域公共的样本特征,源域特征提取模型提取源域特定的样本特征,分类模型计算分类损失值,域对抗鉴别模型计算域特征对齐损失值,类对齐模型计算类特征对齐损失值。
[0058]
在一个实施例中,总损失值为特征差异性损失值、分类损失值、域特征对齐损失值和类特征对齐损失值的和,其中,特征差异性损失值是源域的样本数据输入跨域特征提取模型和源域特征提取模型得到的特征差异,分类损失值是基于源域的样本数据输入分类模型的。
[0059]
如图2所示,在一个实施例中,域自适应网络的构建方式包括:源域的样本数据和目标域的样本数据输入到跨域特征提取模型中得到源域和目标域的公共嵌入特征表示;源域的样本数据输入到源域特征提取模型中得到源域的特定嵌入特征表示;计算源域的公共嵌入特征表示和特定嵌入特征表示的差异性构建特征差异性损失函数;将目标域的样本数据输入到源域特征提取模型中得到带源域风格的目标域嵌入特征表示,再通过注意力机制与目标域的公共嵌入特征表示结合为目标域的嵌入特征表示;同时通过注意力机制将源域的公共嵌入特征表示与源域的特定嵌入特征表示结合为源域的嵌入特征表示;将得到的源域和目标域的嵌入特征表示输入到分类模型中,源域的嵌入特征表示中有类别标签的部分构建分类损失函数,源域的其余无类别标签的部分和目标域的嵌入特征表示则生成其特征表示所对应的伪类别标签;将得到的源域和目标域的公共嵌入特征表示输入到域对抗鉴别模型中,构建域特征对齐损失函数;将源域的样本数据和目标域的样本数据按照类别标签和伪类别标签中的类别进行分组,同时将不同分组的样本的嵌入特征表示输入到类对齐模型中,构建类特征对齐损失函数。
[0060]
其中,利用跨域特征提取模型得到源域和目标域的公共嵌入特征表示,再分别通过源域特征提取模型得到源域的特定嵌入特征表示和带源域风格的目标域的特定嵌入特征表示,通过注意力机制分别融合公共嵌入特征表示和特定嵌入特征表示,得到源域的嵌入特征表示和目标域的嵌入特征表示,从而混淆源域和目标域以缩小两个域的分布差异。将有源域中有类别标签的嵌入特征表示去训练分类模型,同时给两个域的无类别标签的样本生成伪类别标签。通过特征差异性损失函数使得公共嵌入特征表示和特定嵌入特征表示互斥,还分别设计了域特征对齐损失函数和类特征对齐损失函数消除域分布差异和相同类
的分布差异,在仅依赖源域少量标签样本的情况下提高了域自适应任务的准确性。
[0061]
在一个实施例中,跨域特征提取模型是由两层的图卷积神经网络(gcn)的共享网络组成的,源域的样本数据与目标域的样本数据都输入到共享网络中得到其公共嵌入特征表示;源域特征提取模型是由两层的图卷积神经网络模型组成的。
[0062]
其中,图卷积神经网络提取模型不同域样本的嵌入特征表示,挖掘样本之间的连接关系,促进了样本之间的信息传递。其源域的样本数据与目标域的样本数据的公共嵌入特征表示计算公式如下:
[0063][0064][0065]
其中,as为源域中样本之间的图连接关系,xs为源域中的样本数据,θ0为图卷积神经网络第一层的网络参数,θ1为图卷积神经网络第二层的网络参数,a
t
为目标域中样本的图连接关系,x
t
为目标域中样本数据,表示目标域的公共嵌入特征表示,表示源域的公共嵌入特征表示。
[0066]
源域特征提取模型是由两层的图卷积神经网络(gcn)构成的,其源域样本输入到该模型中得到源域特定嵌入特征表示目标域输入到该模型中得到带源域风格的目标域特定嵌入特征表示
[0067]
在一个实施例中,特征差异性损失函数为:
[0068][0069]
式中,表示源域的公共嵌入特征表示,表示源域的特定嵌入特征表示,lm表示特征差异性损失函数,t表示转置运算。
[0070]
其中,源域的嵌入特征表示zs是由注意力机制将源域的特定嵌入特征表示和源域的公共嵌入特征表示结合起来,目标域的嵌入特征表示z
t
是由注意力机制将带源域风格的目标域特定嵌入特征表示和目标域公共嵌入特征表示结合起来。其中,注意力机制的计算方法如下:
[0071][0072][0073]
式中,w1和w2为列向量,且w1+w2=1。
[0074]
在一个实施例中,分类损失函数为:
[0075][0076]
式中,表示源域的有类别标签样本的嵌入特征表示,为分类模型测得的分类结果,为源域属于第k类的类别标签,k∈[1,c],c为样本的总类别数,n
sl
为源域中有类别标签的样本个数,ls表示分类损失函数。
[0077]
在一个实施例中,域特征对齐损失函数为:
[0078][0079]
式中,z
ci
表示第i个公共嵌入特征表示,gd(zi)为域对抗鉴别模型测得的结果,为输入的公共特征表示属于源域的域标签还是目标域的域标签,ns为源域的样本总个数,n
t
为目标域的样本总个数,ld表示域特征对齐损失函数。
[0080]
其中,域标签是用于标识公共特征表示属于哪一个域的标识。
[0081]
在一个实施例中,类特征对齐损失函数为:
[0082][0083]
式中,lc表示类特征对齐损失函数,表示类别标签或者伪类别标签为第k类样本的源域嵌入特征表示,表示伪类别标签为第k类样本的目标域嵌入特征表示,和分别为和的概率分布,c为样本的总类别数。
[0084]
在一个实施例中,不断更新迭代域自适应网络中的参数的表达式为:
[0085]
min(ls+λl
m-βld+γlc)
[0086]
式中,lm表示特征差异性损失函数,ls表示分类损失函数,ld表示域特征对齐损失函数,lc表示类特征对齐损失函数,λ,β和γ分别为对应损失函数之间的平衡因子。
[0087]
上述基于图卷积网络的无监督域自适应的分类方法,使用图卷积神经网络提取不同域样本的嵌入特征表示,挖掘样本之间的连接关系,促进了样本之间的信息传递。其次将目标域经过源域特征提取模型得到带源域风格的特定嵌入特征表示,再利用差异性损失使得公共嵌入特征表示与特定的嵌入特征表示不相关性。本发明中通过对抗机制设立域对抗鉴别模型最大化域分类损失消除了域之间公共嵌入特征的分布差异。通过注意力机制融合公共和特定的嵌入特征表示为源域嵌入特征表示和目标域嵌入特征表示,同时设立分类模型对有类别标签样本进行分类计算分类损失和给无类别标签样本打上伪类别标签,其中分类损失确保分类模型的有效性。最后在发明中设立类对齐模型消除同类样本不同域之间的分布差异,在类别级别上对齐两个域的样本分布。进一步有效的提升了基于图卷积的无监督域自适应分类模型的性能。
[0088]
应该理解的是,虽然图1的流程图中的各个步骤按照箭头的指示依次显示,但是这些步骤并不是必然按照箭头指示的顺序依次执行。除非本文中有明确的说明,这些步骤的执行并没有严格的顺序限制,这些步骤可以以其它的顺序执行。而且,图1中的至少一部分步骤可以包括多个子步骤或者多个阶段,这些子步骤或者阶段并不必然是在同一时刻执行完成,而是可以在不同的时刻执行,这些子步骤或者阶段的执行顺序也不必然是依次进行,而是可以与其它步骤或者其它步骤的子步骤或者阶段的至少一部分轮流或者交替地执行。
[0089]
以上实施例的各技术特征可以进行任意的组合,为使描述简洁,未对上述实施例中的各个技术特征所有可能的组合都进行描述,然而,只要这些技术特征的组合不存在矛盾,都应当认为是本说明书记载的范围。
[0090]
以上所述实施例仅表达了本技术的几种实施方式,其描述较为具体和详细,但并不能因此而理解为对发明专利范围的限制。应当指出的是,对于本领域的普通技术人员来
说,在不脱离本技术构思的前提下,还可以做出若干变形和改进,这些都属于本技术的保护范围。因此,本技术专利的保护范围应以所附权利要求为准。
当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1