基于多阶段特征融合的知识蒸馏方法、设备及介质

文档序号:37194366发布日期:2024-03-01 13:07阅读:17来源:国知局
基于多阶段特征融合的知识蒸馏方法、设备及介质

本发明属于计算机,尤其涉及基于多阶段特征融合的知识蒸馏方法、设备及介质。


背景技术:

1、近年来,深度学习领域中的卷积神经网络极大地促进了计算机视觉的发展,并在分类识别、目标检测和目标分割等方面得到了广泛的应用。然而,受限于边缘计算设备的算力约束和显存限制,卷积网络大模型难以部署应用。如何平衡计算开销和模型性能仍是一个十分具有挑战性的问题,而知识蒸馏是一种有效解决方法。知识蒸馏通过教师教授学生的形式将隐含知识从教师网络大模型传递到学生网络小模型上,进而大幅提升学生模型性能。该方法简单有效,被广泛应用于卷积网络和视觉任务上。

2、常见的知识蒸馏方法一般分为两类,一类是基于软标签分类知识,该类方法使用不同的温度软化教师网络和学生网络输出的分类标签,由此通过缩减两者软化标签间的最终分类知识差异性,来提升学生网络的识别精度。另一类是基于中间层特征,一般而言,教师网络与学生网络的结构和学习过程有一定的相似性,学生网络可以学习教师网络在中间特征层中隐含的知识,取得更优的学习过程,实现自身的精度提升效果。

3、教师网络与学生网络在相同阶段之间的特征分布往往存在较大的区别,而且同一网络在不同阶段之间的特征分布也各有侧重点,深层特征注重概念信息,浅层特征注重纹理信息,这带来了特征知识分布差异性问题,导致学生网络难以直接学习教师网络的特征隐含知识。


技术实现思路

1、本发明所解决的技术问题在于提供一种基于多阶段特征融合的知识蒸馏方法、设备及介质,以解决由于教师网络与学生网络之间存在特征分布差异性,导致学生网络难以充分学习教师网络中间层特征隐含知识的问题。

2、本发明提供的基础方案:基于多阶段特征融合的知识蒸馏方法,包括:

3、s1:获取原始数据集,并对原始数据集进行预处理;

4、s2:采用原始数据集训练教师网络模型,获取训练完成的教师网络模型;

5、s3:冻结训练完成的教师网络模型的预训练权重,使用多阶段特征融合框架、跨阶段特征融合注意力模块和相同阶段融合特征对比损失函数训练学生网络模型,生成训练完成的学生网络模型;

6、s4:运行训练完成的学生网络模型,并在推理阶段只保留学生网络架构。

7、进一步,所述s3包括:

8、s3-1:冻结训练完成的教师网络模型的预训练权重;

9、s3-2:将教师网络模型和学生网络模型均使用多阶段特征融合框架,并构建学生网络特征融合和教师网络特征融合;

10、s3-3:通过跨阶段特征融合注意力模块对学生网络模型进行训练;

11、s3-4:构建相同阶段融合特征对比损失函数对s3-3训练完成的学生网络模型进行损失验证。

12、进一步,所述s3-2具体为:

13、定义教师网络模型为t,学生网络模型为s,教师网络模型t和学生网络模型s均包括n个特征输出阶段和n个对应的融合模块ffai,i∈n,其对应的第i层特征分别为ti和si;

14、设置首个融合模块具有一个输入入口,其余均有两个输入入口,且最后一个融合模块有一个输出出口,其余均有两个输出出口,记第i层融合输出特征为和

15、构建学生网络特征融合和教师网络特征融合,所述学生网络特征融合计算公式为:

16、

17、所述教师网络特征融合的计算公式为:

18、

19、进一步,所述s3-3具体为:

20、跨阶段特征融合注意力模块中,包括两个不同阶段特征i1和i2,且i1和i2的尺寸和通道数均不同;

21、通过卷积和归一化处理将输入特征i1的尺寸和通道数调整为与输入特征i2相一致,并相加得到初步融合特征i;

22、通过并联的通道注意力机制ac和空间注意力机制as进行处理,将并联结果相加得到融合特征f;

23、再通过卷积和归一化处理,分别生成尺寸和通道数一般不同的两个输出特征f1和f2,并且融合特征模块ffa1输入特征只有i1,融合特征模块ffan输出特征只有f1。

24、进一步,所述跨阶段特征融合注意力模块的表达公式为:

25、(f1,f2)=as(s(i1)+i2)+ac(s(i1)+i2)。

26、进一步,所述s3-4具体为:

27、构建相同阶段融合特征对比损失函数;

28、通过相同阶段融合特征对比损失函数分别将教师网络模型和学生网络模型的第i阶段融合特征对应的tfi和sfi按照预设的处理做lmse相似度匹配;

29、结合真实标签与学生分类结果的交叉熵损失函数,以及权重调节超参数,构建完整损失函数,对学生网络模型进行损失验证;

30、所述按照预设的处理具体为:

31、不做处理,保留tfi和sfi;

32、不改变特征空间尺寸,在通道上进行压缩处理,得到tfi1和sfi1;

33、不改变通道数、在空间上进行压缩,得到tfi2和sfi2,结合权重调节超参数λ,构成n个阶段融合特征对比函数。

34、进一步,所述相同阶段融合特征对比损失函数的计算公式为:

35、lscm=lmse(tfi,sfi)+λlmse(tfi1,sfi1)+λlmse(tfi2,sfi2)

36、其中,lscm表示相同阶段融合特征对比损失函数,λ表示权重调节超参数;

37、所述完整损失函数的计算公式为:

38、ltotal=lce+αlscm

39、其中,ltotal表示完整损失函数,lce表示交叉熵损失函数,α表示完整损失函数对应的权重调节超参数。

40、进一步,所述s4中在推理阶段只保留学生网络架构具体为:在学生网络模型推理阶段,剪去教师网络模型和多阶段特征融合框架,只保留学生网络架构部分。

41、一种电子设备,包括处理器和存储器,所述存储器中存储程序或指令,所述处理器通过调用所述存储器存储的程序或指令,执行如上所述的基于多阶段特征融合的知识蒸馏方法。

42、一种计算机可读存储介质,所述计算机可读存储介质存储程序或指令,所述程序或指令使计算机执行如上所述的基于多阶段特征融合的知识蒸馏方法。

43、本发明的原理及优点在于:在本申请中,通过多阶段特征融合框架,分别在教师网络和学生网络实现特征知识从浅层到深层的跨阶段知识传递,进而可以让学生网络的单一阶段从教师的不同阶段学习特征隐含知识,增强学生模型的泛化性和学习能力。通过跨阶段特征融合注意力模块可以实现相邻阶段特征之间的有机融合和有益知识增强,再搭配上相同阶段融合特征之间的空间和通道对比损失函数,可以让学生网络从通道和空间两个角度来学习教师网络的特征和对比两者之间的特征差异性,实现学生模型的进一步效果提升,并增强其模型泛化性。



技术特征:

1.基于多阶段特征融合的知识蒸馏方法,其特征在于:包括:

2.根据权利要求1所述的基于多阶段特征融合的知识蒸馏方法,其特征在于:所述s3包括:

3.根据权利要求2所述的基于多阶段特征融合的知识蒸馏方法,其特征在于:所述s3-2具体为:

4.根据权利要求3所述的基于多阶段特征融合的知识蒸馏方法,其特征在于:所述s3-3具体为:

5.根据权利要求4所述的基于多阶段特征融合的知识蒸馏方法,其特征在于:所述跨阶段特征融合注意力模块的表达公式为:

6.根据权利要求5所述的基于多阶段特征融合的知识蒸馏方法,其特征在于:所述s3-4具体为:

7.根据权利要求6所述的基于多阶段特征融合的知识蒸馏方法,其特征在于:所述相同阶段融合特征对比损失函数的计算公式为:

8.根据权利要求7所述的基于多阶段特征融合的知识蒸馏方法,其特征在于:所述s4中在推理阶段只保留学生网络架构具体为:在学生网络模型推理阶段,剪去教师网络模型和多阶段特征融合框架,只保留学生网络架构部分。

9.一种电子设备,其特征在于:包括处理器和存储器,所述存储器中存储程序或指令,所述处理器通过调用所述存储器存储的程序或指令,执行如上权利要求1-7任一项所述的基于多阶段特征融合的知识蒸馏方法。

10.一种计算机可读存储介质,其特征在于:所述计算机可读存储介质存储程序或指令,所述程序或指令使计算机执行如上权利要求1-7任一项所述的基于多阶段特征融合的知识蒸馏方法。


技术总结
本发明属于计算机技术领域,尤其涉及基于多阶段特征融合的知识蒸馏方法、设备及介质,首先获取原始数据集,并对原始数据集进行预处理;然后采用原始数据集训练教师网络模型,获取训练完成的教师网络模型;接着冻结训练完成的教师网络模型的预训练权重,使用多阶段特征融合框架、跨阶段特征融合注意力模块和相同阶段融合特征对比损失函数训练学生网络模型,生成训练完成的学生网络模型;最后运行训练完成的学生网络模型,并在推理阶段只保留学生网络架构。本发明能够解决由于教师网络与学生网络之间存在特征分布差异性,导致学生网络难以充分学习教师网络中间层特征隐含知识的问题。

技术研发人员:李刚,王坤,徐传运,何攀,阮子涵,吕鹏飞,蒋建忠
受保护的技术使用者:重庆理工大学
技术研发日:
技术公布日:2024/2/29
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1