基于Transformer的跨域双分支对抗域适应图像分类方法

文档序号:35421939发布日期:2023-09-13 09:14阅读:101来源:国知局
基于Transformer的跨域双分支对抗域适应图像分类方法

本发明涉及图像分类技术,特别涉及基于transformer的跨域双分支对抗域适应图像分类方法。


背景技术:

1、近年来,随着计算机算力提升和深度学习算法的发展,计算机在图像分类、目标检测、语义分割等多种视觉任务上都取得了巨大的提升,在一些任务中甚至超越了人类。目前深度学习算法已经被普遍运用在在线商务、深度翻译、语音识别、自动驾驶以及计算机辅助诊断等多个领域。

2、尽管深度学习在众多应用中取得了成功,但其出色的性能在很大程度上取决于大量的标记数据,然而在许多现实应用中,收集足够的带标记训练数据通常需要耗费大量的人力、物力和时间成本。除此之外,深度学习算法需要满足训练数据和测试数据独立同分布的假设,即真实采集的数据和训练用的训练数据从分布上一致。然而在深度学习被使用的诸多领域中,数据独立同分布的假设往往并不成立,受到分辨率、光照、视点、背景、天气条件等因素的影响等外部因素的影响,真实的测试数据往往无法满足一致性形成域偏移。域偏移的现象在日常应用中非常普遍,例如在人体姿态识别场景下,在室内采集的图像和室外采集的图像从数据分布上差异巨大,导致用室内标注数据训练的模型,对室外场景下人体姿态的识别能力大幅度下降。数据分布存在差异导致传统的深度学习算法训练得到的模型往往不能在相似的新领域中取得预期结果,这限制了深度学习模型的泛化能力和知识复用能力。

3、领域适应任务常用卷积神经网络(convolutional neural network,cnn)作为特征提取器,其分层设计可以提取丰富的抽象语义信息,且具有平移不变性与局部敏感性等优点。然而,受限于感受野范围,cnn无法充分利用上下文信息,缺乏长距离关系建模能力。由于transformer模型具有强大的上下文关系建模能力,将图像变为序列再使用该模型处理,可有效解决cnn中存在的长距离依赖问题。但全局具有二次计算复杂度,给硬件带来很大的计算压力,图像训练和推理效率较低。

4、近年来,transformer已用于域适应图像分类的研究。可迁移视觉transformer(transferable vision transformer,tvt),在视觉transformer(vit)的基础上将其最后一层改进为可迁移适配模块,使注意力集中在可迁移和可区分特征上,首次研究了vit跨域知识迁移能力。交叉域transformer(cross-domain transformer,cdtrans),将vit改为三分支结构分布学习源域、目标域域内和域间特征。双向交叉注意力transformer(bidirectional cross-attention transformer,bcat)使用四分支结构,在cdtrans的基础上使用两个对齐分支取得了更好的效果。

5、上述方法均基于transformer模型,并取得了很好的适应效果,但相关仍有待改进之处。然而上述方法算法中设计的分支过多或计算复杂度过高,都给硬件带来了巨大的计算代价,导致处理高分辨率图像时效率较低。对齐特征时使用单一域查询的交叉注意力难以准确在域间进行关系建模,不利于域间知识迁移。


技术实现思路

1、发明目的:针对以上问题,本发明目的是提供一种基于transformer的跨域双分支对抗域适应图像分类方法,将特征提取器精简为并行可交互的双分支结构,减少计算量,提升了图像训练和推理的效率;针对单一域查询的交叉注意力在域差异大的任务中表现不佳的问题,设计了跨域融合模块,将源域和目标域的查询融合为跨域统一查询量,平滑域间分布差异,融合两个域的查询信息,用于交叉注意力在域间进行关系建模,有利于迁移知识。

2、技术方案:本发明的一种基于transformer的跨域双分支对抗域适应图像分类方法,包括以下步骤:

3、步骤1,分别读取源域和目标域的图像,并对图像进行预处理;

4、步骤2,构建基于transformer的跨域双分支对抗域适应网络模型;其中跨域双分支对抗域适应模型包括双分支特征提取器、分类器和域判别器,双分支特征提取器包括域内特征提取模块和跨域融合模块;

5、步骤3,利用预处理后的源域和目标域图像对跨域双分支对抗域适应网络模型进行训练,将训练后的跨域双分支对抗域适应网络模型作为图像分类模型;

6、步骤4,将源域图像和目标域图像配对后分别输入至图像分类模型的双分支特征提取器中提取特征,将提取的特征输入至图像分类模型的分类器中,利用分类器预测目标域中图像的标签类别。

7、进一步,利用预处理后的源域和目标域图像对跨域双分支对抗域适应网络模型进行训练,具体包括以下子步骤:

8、步骤31,将预处理后的源域图像和目标域图像分别输入至域内特征提取模块中,利用双分支结构并行处理源域图像和目标域图像,分别获得不同层次的多尺度特征;

9、步骤32,将源域的多尺度特征和目标域的多尺度特征输入至跨域融合模块,利用跨域融合模块做分支间交互对齐域间特征,输出源域特征向量和目标域特征向量;

10、步骤33,利用有标签的源域特征向量采用标准的监督式交叉熵损失训练分类器;

11、步骤34,使用域标签训练判别器,反向传播过程经过梯度反转层,与双分支特征提取器进行最小最大博弈,直至分类损失收敛,结束训练。

12、进一步,步骤32具体包括:

13、将源域注意力模块的查询量与目标域注意力模块的查询量加权融合,获得域间共享查询量qf,表达式为:

14、qf=αqs+(1-α)qt

15、式中,α表示融合系数,qs表示源域的查询量,qt表示目标域的查询量;

16、将域间共享查询量作为两并行计算分支注意力的统一查询量,计算该域间统一查询量分别在源域和目标域键向量上的分布,建立源域和目标域之间的相关关系,表达式为:

17、

18、式中,attns表示源域注意力,attnt表示目标域注意力,ks和vs分别表示源域的键和值,kt和vt分别表示目标域的键和值,d表示查询和键向量的维度;

19、使用cdf在两分支间做数据交互,学习源域与目标域统一查询量和域不变特征,表达式为:

20、

21、

22、

23、

24、式中,和分别表示源域、目标域基于窗口的多头自注意力模块的输出,和分别表示源域、目标域l层的输出,cdf为跨域融合注意力机制,mlp为多层感知机,ln为层归一化。

25、进一步,步骤34具体包括:

26、优化判别器参数最小化判别损失,优化双分支特征提取器最大化判别损失,目标函数如下:

27、

28、式中,θf、θy、θd分别表示双分支特征提取器gf、分类器gy和域判别器gd的参数,lcls表示分类器损失,ladv表示域判别器损失;权重系数λ∈[0,1),迭代更新方式如下:

29、

30、其中,λ随训练过程逐渐增大,u表示当前迭代次数相对总迭代次数占比。

31、进一步,预处理包括分别对源域图像和目标域图像进行随机剪裁、随机翻转、随机遮挡以及亮度增强。

32、有益效果:本发明与现有技术相比,其显著优点是:本发明提出基于transformer的跨域双分支对抗域适应图像分类方法,用于解决无监督域适应图像分类问题,提出了双分支特征提取器,精简了计算分支,建立并行可交互的双分支结构,并引入局部注意力取代基于vit的域适应方法,将计算复杂度由二次降低至线性,大幅缓解了硬件的计算代价,提升了高分辨率场景图像的训练和推理效率;本发明相较主流域适应方法具有更高的精度,领域差异大的目标域有较强的适应性;通过训练时间对比试验,可以看出本发明方法的训练效率更高,以更短的训练时间获取了更好的适应效果。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1