本发明涉及计算机视觉领域,特别地,涉及一种基于任务自适应知识蒸馏的目标检测方法。
背景技术:
1、目标检测作为一种从给定图像中定位所有目标并给出目标的类别的技术,是计算机视觉领域最基础和最重要的任务之一。目标检测技术在实际生活中有着广泛的应用,例如智能安防、自动驾驶、医疗诊断、工业检测等领域。
2、知识蒸馏作为一种模型压缩技术,旨在通过大模型向小模型传递知识,以提升小模型的性能。现有的目标检测知识蒸馏方法多关注特征上的知识蒸馏,这种方式没有考虑到不同架构检测器在特征上本身存在的差异,因此在实际场景特别是异构检测器知识蒸馏中效果不佳。
3、基于响应的目标检测知识蒸馏方案在结果层面上进行知识传递,可以不受检测器架构的限制。但由于主流方案在应用基于响应的目标检测知识蒸馏方案时没有考虑到目标检测本身的任务特性,导致了其性能较差,应用不广泛。
技术实现思路
1、针对以上问题,本发明提供了一种基于任务自适应知识蒸馏的目标检测方法。本发明具体采用的技术方案如下:
2、一种基于任务自适应知识蒸馏的目标检测方法,其包括以下步骤:
3、s1、获取用于训练目标检测器的训练数据集,所述训练数据集中的图像样本预先标注有图像中目标物体的位置及分类;
4、s2、使用所述训练数据集训练一个教师目标检测器,用于生成辅助学生目标检测器的位置和分类软标签;
5、s3、利用使用所述训练数据集训练学生目标检测器,且在训练过程中输入教师目标检测器提供的位置和分类软标签,在学生目标检测器的分类头上计算基于二元交叉熵的分类蒸馏损失函数,在学生目标检测器的定位头上计算基于交并比的定位蒸馏损失函数,将分类蒸馏损失函数和定位蒸馏损失函数加入学生目标检测器自身的目标损失函数上作为总损失函数进行优化,从而通过教师目标检测器提供的软标签信息增强学生目标检测器在目标检测中的性能表现;
6、s4、利用训练完成后的学生目标检测器对待检测图像进行目标检测。
7、作为优选,所述s1中的训练数据集包括多张图像样本,且每张图像样本预先标注有目标物体的位置及分类标签{oi,ci}。
8、作为优选,所述s2中,教师目标检测器的训练包括以下子步骤:
9、s21、将图像样本i输入教师目标检测器中,教师目标检测器由特征提取器f(·)和检测头h(·)组成,检测器首先从图像样本i中提取特征f=f(i),然后将特征输入检测头生成最终预测p=h(f),最终预测p中包含图像中目标的位置及分类预测值;
10、s22、根据最终预测p和输入图像样本i对应的标签{oi,ci},计算损失函数值,并利用反向传播算法更新教师目标检测器的特征提取器f(·)和检测头h(·);
11、s23、不断重复s21和s22,直到达到训练轮次上限或检测器的损失函数已经收敛后停止训练,得到训练后的教师目标检测器。
12、作为优选,所述s3中,基于二元交叉熵的分类蒸馏损失函数的计算方式如下:
13、s311、针对训练过程中的每一个图像样本x,分别将其输入学生目标检测器和训练后的教师目标检测器中,各自经过特征提取器f(·)和检测头h(·)生成分类谱分类谱l中包含n个特征图位置上的k维分类逻辑值,其中n=h×w,h、w和k分别表示特征提取器提取的特征图高度、特征图宽度和分类的类别数;
14、s312、对教师目标检测器和学生目标检测器各自得到的分类谱分别执行归一化操作,将分类谱中每个位置的k维分类逻辑值转换为k维分类分数,从而将两个分类谱转换为分类分数矩阵:
15、pt′=rrotsig(lt)
16、ps′=protsig(ls)
17、式中:protsig()表示sigmoid归一化操作,lt和ls分别为教师目标检测器和学生目标检测器各自得到的分类谱l;
18、s313、依据s312中得到的分类分数,计算分类谱中每个位置(i,j)处的分类蒸馏损失,计算式为:
19、
20、式中:分别表示分类分数矩阵pt′和ps′中位置(i,j)处的值;
21、s314、根据教师目标检测器和学生目标检测器的分类分数差值矩阵w,加权计算各点位分类蒸馏的值,从而得到基于二元交叉熵的分类蒸馏损失,分类蒸馏损失函数形式为:
22、w=|pt′-ps′|,
23、
24、式中:wi,j表示分类分数差值矩阵w中位置(i,j)处的值,||表示取绝对值操作。作为优选,所述s3中,基于交并比的定位蒸馏损失函数的计算方式如下:
25、s321、针对训练过程中的每一个图像样本x,分别将其输入学生目标检测器和训练后的教师目标检测器中,各自经过特征提取器f(·)和检测头h(·)生成定位谱定位谱o中含有含n个特征图位置上的定位预测值,其维度为4;
26、s322、将教师目标检测器的定位预测值通过其自身的解码器解码为回归框:
27、
28、其中ai表示第i个锚点,表示教师目标检测器的第i个定位预测值,表示教师目标检测器的第i个回归框;
29、s323、将学生目标检测器的定位预测值解码为回归框:
30、
31、其中ai表示第i个锚点,表示学生目标检测器的第i个定位预测值,表示教师目标检测器的第i个回归框;
32、s324、计算和之间的交并比:
33、
34、s325、依据s324中计算得到的两个检测器各自定位回归框的交并比,计算基于交并比的定位蒸馏损失,分类蒸馏损失函数形式为:
35、
36、其中max(w.,j)是分类分数差值矩阵w的第j列w.,j中的最大值。
37、作为优选,所述s3中,学生目标检测器训练时的总损失函数形式如下:
38、
39、其中:和分别表示学生目标检测器自身的分类损失函数和定位损失函数,α1和α2是两个超参数,分别表示分类蒸馏损失和定位蒸馏损失的损失权重。
40、作为优选,权重值α1=1.0,权重值α2=4.0。
41、本发明提出了更适应目标检测任务的知识蒸馏方案,对分类头和定位头各自进行了针对性的蒸馏损失函数设计,能够给目标检测器带来更好的知识蒸馏性能增益。相比于传统目标检测知识蒸馏方案,本发明具有如下有益效果:
42、首先,本发明提出了一种基于任务自适应知识蒸馏的目标检测的可行方案。
43、其次,本发明充分考虑了目标检测任务特性,并针对性地设计了在分类头和回归头上的适应任务的知识蒸馏损失函数。
44、最后,本发明提出的目标检测知识蒸馏方案通用性广,能够简单地扩展到异构目标检测器到知识蒸馏方案上,并且能获得良好的效果。
1.一种基于任务白适应知识蒸馏的目标检测方法,其特征在于包括以下步骤:
2.根据权利要求1所述的一种基于任务白适应知识蒸馏的目标检测方法,其特征在于所述s1中的训练数据集包括多张图像样本,且每张图像样本预先标注有目标物体的位置及分类标签{oi,ci}。
3.根据权利要求1所述的一种基于任务白适应知识蒸馏的目标检测方法,其特征在于所述s2中,教师目标检测器的训练包括以下子步骤:
4.根据权利要求1所述的一种基于任务白适应知识蒸馏的目标检测方法,其特征在于所述s3中,基于二元交叉熵的分类蒸馏损失函数的计算方式如下:
5.根据权利要求4所述的一种基于任务白适应知识蒸馏的目标检测方法,其特征在于所述s3中,基于交并比的定位蒸馏损失函数的计算方式如下:
6.根据权利要求1所述的一种基于任务白适应知识蒸馏的目标检测方法,其特征在于所述s3中,学生目标检测器训练时的总损失函数形式如下:
7.根据权利要求6所述的一种基于任务白适应知识蒸馏的目标检测方法,其特征在于,权重值α1=1.0,权重值α2=4.0。