本公开涉及人工智能,具体而言,涉及一种网络训练方法、装置、电子设备及存储介质。
背景技术:
1、知识蒸馏也称为师生学习,是一种有效的模型压缩和模型精度提升的技术,通过知识蒸馏可以将知识从容量更高的教师模型转移到可部署性更强、容量较小的学生模型,进而来提升学生模型的性能。
2、经研究发现,针对密集视觉检测任务,由于密集视觉检测任务对于图像的定位信息更加敏感,目前的知识蒸馏方法主要基于对教师特征图的模仿。然而,该基于特征图的知识蒸馏,通常将完整的图像输入学生网络,然后进行逐像素一对一的空间模仿,该模仿过程相对简单,导致学生模型的学习能力得不到较好的挖掘,学生模型的性能较差。
技术实现思路
1、本公开实施例至少提供一种网络训练方法、装置、电子设备及存储介质,能够提升学生网络的性能。
2、公开实施例提供了一种网络训练方法,包括:
3、获取训练样本图像以及与所述训练样本图像对应的掩码样本图像;其中,所述掩码样本图像由所述训练样本图像以及目标掩码图像生成;
4、将所述掩码样本图像输入至第一网络,并基于所述第一网络的第一特征提取网络对所述掩码样本图像进行特征提取,得到与所述掩码样本图像对应的第一多层级特征图;
5、将所述样本图像输入至第二网络,并基于所述第二网络的第二特征提取网络对所述样本图像进行特征提取,得到与所述样本图像对应的第二多层级特征图;所述第二网络的规模大于所述第一网络的规模;
6、基于解码器以及与所述第一多层级特征图中每个特征图分别对应的掩码图像,对所述第一多层级特征图进行特征恢复处理,得到恢复处理后的第一多层级特征图;其中,所述第一多层级特征图中每个特征图对应的掩码图像由所述目标掩码图像分别进行缩放处理后得到;
7、基于所述恢复处理后的第一多层级特征图以及所述第二多层级特征图对所述第一网络进行训练,并重复上述步骤,直到所述第一网络的训练结果符合预设要求,得到训练好的第一网络。
8、本公开实施例中,第一网络也称学生网络,第二网络也称教师网络,在基于特征进行知识蒸馏的过程中,通过对训练样本图像进行掩码处理,并通过模仿第二网络的输出的第二层级特征图恢复被掩码区域对应的特征,进而可以增加特征模仿的难度,也即,在不改变第一网络的网络结构的前提下,通过单独的解码器对蒸馏过程进行增强,进而可以提升第一网络的学习能力,如此,即使在输入的待检测图像存在部分被遮盖情况下,通过训练好的第一网络也能够进行检测识别,从而提升了第一网络的检测性能。
9、在一种可能的实施方式中,所述第一特征提取网络包括呈金字塔结构的特征提取模块,所述将所述掩码样本图像输入至第一网络,并基于所述第一网络的第一特征提取网络对所述掩码样本图像进行特征提取,得到与所述掩码样本图像对应的第一多层级特征图,包括:
10、将所述掩码样本图像输入至第一网络,并基于所述特征提取模块对所述掩码样本图像进行特征提取,得到中间多层级特征图,并将所述中间多层级特征图作为所述第一多层级特征图;其中,所述中间多层级特征图包括多张尺寸不同的中间特征图。
11、本公开实施例中,由于所述特征提取模块呈金字塔结构设计,导致所述提取出来的第一多层级特征图包括多张尺寸不同的中间特征图,进而可以适用于各类密集视觉检测任务,如目标检测、实例分割和语义分割等。
12、在一种可能的实施方式中,所述第一特征提取网络包括呈金塔结构的特征提取模块以及多个掩码卷积模块;所述将所述掩码样本图像输入至第一网络,并基于所述第一网络的第一特征提取网络对所述掩码样本图像进行特征提取,得到与所述掩码样本图像对应的第一多层级特征图,包括:
13、将所述掩码样本图像以及所述目标掩码图像输入至第一网络,并基于所述特征提取模块对所述掩码样本图像进行特征提取,得到中间多层级特征图,其中,所述中间多层级特征图包括多张尺寸不同的中间特征图;
14、基于所述多个掩码卷积模块以及所述目标掩码图像,对所述中间多层级特征图中的每个中间特征图进行掩码处理,并将掩码处理后的中间多层级特征图,作为所述第一多层级特征图。
15、本实施方式中,除了能够提升第一网络的适用性之外,通过对每个中间特征图进行掩码处理后,可以避免掩盖区域和可见区域之间的混淆的特征交互,也即,由于第一网络中的用于特征提取的骨干网络中使用了掩码卷积,进而可以防止卷积过程中被掩盖的图像块受到其他可见图像块的影响,有助于进一步提升第一网络的性能表现。
16、在一种可能的实施方式中,所述基于所述多个掩码卷积模块以及所述目标掩码图像,对所述中间多层级特征图中的每个中间特征图进行掩码处理,包括:
17、根据所述中间多层级特征图中的每个中间特征图的尺寸大小,分别对所述目标掩码图像进行缩放处理,得到与所述中间多层级特征图中每个中间特征图分别对应的掩码图像;
18、针对每个中间特征图,基于所述掩码卷积模块对所述中间特征图以及与所述中间特征图对应的掩码图像进行点乘处理,得到掩码处理的中间特征图,并基于每个掩码处理的中间特征图得到所述掩码处理后的中间多层级特征图。
19、本实施方式中,在所述掩码卷积模块中对所述中间特征图以及与所述中间特征图对应的掩码图像进行点乘处理,即可得到掩码处理的中间特征图,进而可以提升对中间特征图进行掩码处理的效率。
20、在一种可能的实施方式中,所述解码器包括空间对齐模块、解码模块以及空间恢复模块,所述基于解码器以及与所述第一多层级特征图中每个特征图分别对应的掩码图像,对所述第一多层级特征图进行特征恢复处理,包括:
21、基于所述空间对齐模块将所述第一多层级特征图中的各个尺寸不同的特征图对齐到相同的空间分辨率,以使得所述第一多层级特征图中的各个特征图的尺寸大小对齐,得到空间对齐的多层级特征图;
22、基于所述空间对齐的多层级特征图分别对应的掩码图像,采用掩码标记替换所述空间对齐的多层级特征图中的掩码区域,得到带有掩码标记的多层级特征图,并基于所述解码模块对所述带有掩码标记的多层级特征图进行特征预测处理,得到特征预测处理的多层级特征图;
23、基于所述空间恢复模块将相同的空间分辨率的所述特征预测处理的多层级特征图恢复成原始尺寸的多层级特征图,得到所述恢复处理后的第一多层级特征图。
24、本实施方式中,通过对第一多层级特征图进行空间对齐后进行特征恢复操作,然后再将空间对齐后的特征图恢复至原始尺寸,不仅能够实现特征的恢复,还可以保证第一多层级特征图的尺寸,有助于确定后续的特征恢复损失。
25、在一种可能的实施方式中,所述基于所述空间对齐模块将所述第一多层级特征图中的各个尺寸不同的特征图对齐到相同的空间分辨率,包括:
26、针对所述第一多层级特征图中的每个特征图,将所述特征图与目标图像进行比较;
27、当所述特征图的尺寸大于所述目标图像的尺寸的情况下,对所述特征图进行降维处理,使得所述特征图的尺寸与所述目标图像的尺寸一致;或者,
28、当所述特征图的尺寸小于所述目标图像的尺寸的情况下,采用最近邻插值对所述特征图进行上采样处理,使得所述特征图的尺寸与所述目标图像的尺寸一致。
29、本公开实施例中,将所述第一多层级特征图中的每个特征图都对到目标图像尺寸,并在实现过程中对大于目标图像的特征图以及小于目标图像的特征图分别采用不同的方法,如此,不仅可以保证各个特征图之间的对齐还可以提升空间对齐的效率。
30、在一种可能的实施方式中,所述基于所述空间对齐模块将所述第一多层级特征图中的各个尺寸不同的特征图对齐到相同的空间分辨率之前,所述方法还包括:
31、将所述第一多层级特征图的通道数与所述第二多层级特征图的通道数对齐;和/或,对所述第一多层级特征图以及所述第二多层级特征图进行层归一化处理。
32、本公开实施例中,在将所述第一多层级特征图中的各个特征图对齐之前,还将第一多层级特征图的通道数与所述第二多层级特征图的通道数对齐;和/或,对所述第一多层级特征图以及所述第二多层级特征图进行层归一化处理,有助于提升空间对齐的精度以及效率。
33、在一种可能的实施方式中,所述基于所述空间对齐的多层级特征图分别对应的掩码图像,采用掩码标记替换所述空间对齐的多层级特征图中的掩码区域,得到带有掩码标记的多层级特征图,包括:
34、针对每个空间对齐的特征图,对所述空间对齐的特征图进行展开处理,得到一维的展开特征图,并基于与所述空间对齐的特征图对应的掩码图像,确定所述展开特征图中需要替换的掩码区域;
35、采用所述掩码标记对所述掩码区域进行替换,得到带有所述掩码标记的展开特征图,并基于各个带有所述掩码标记的展开特征图,得到所述带有掩码标记的多层级特征图。
36、本公开实施例中,通过将所述空间对齐的特征图进行展开处理,有助于对掩码区域进行掩码标记替换,进而有助于后续对掩码区域所对应的特征的预测。
37、在一种可能的实施方式中,所述得到带有所述掩码标记的展开特征图之后,所述方法还包括:
38、在所述带有所述掩码标记的展开特征图中加上余弦绝对位置编码,并基于预设绝对尺度通过插值的方式对所述带有所述掩码标记的展开特征图进行自适应调整,得到调整后的展开特征图;
39、所述基于各个带有所述掩码标记的展开特征图,得到所述带有掩码标记的多层级特征图,包括:
40、基于各个调整后的展开特征图,得到所述带有掩码标记的多层级特征图。
41、本公开实施例中,通过在所述带有所述掩码标记的展开特征图中加上余弦绝对位置编码,一方面可以便于对掩码区域的确定以及替换,另一方面还可以便于后续在特征预测后基于展开图得到未展开之前的特征图。
42、在一种可能的实施方式中,所述基于所述恢复处理后的第一多层级特征图以及所述第二多层级特征图对所述第一网络进行训练,包括:
43、确定所述恢复处理后的第一多层级特征图以及所述第二多层级特征图之间的特征恢复损失,并基于所述特征恢复损失对所述第一网络的参数进行调整。
44、本公开实施例中,通过第一多层级特征图以及所述第二多层级特征图之间的特征恢复损失,可以指导第一网络进一步模仿第二多层级特征图,进而有利于提升训练的效率。
45、在一种可能的实施方式中,所述方法还包括:
46、基于所述恢复处理后的第一多层级特征图,确定所述第一网络的任务损失;和/或,
47、通过全局上下文模块确定所述恢复处理后的第一多层级特征图以及所述第二多层级特征图分别对应的第一全局关系以及第二全局关系,并确定第一全局关系与所述第二全局关系之间的全局损失;
48、所述基于所述特征恢复损失对所述第一网络的参数进行调整,包括:
49、基于所述特征恢复损失、所述任务损失和/或所述全局损失,对所述第一网络的参数进行调整。
50、本公开实施例中,除了特征恢复损失之外,还根据所述任务损失和/或所述全局损失来对第一网络的参数进行调整,如此,基于多个损失可以提升对第一网络参数调整的精度,可以进一步提升第一网络的训练效率以及性能表现。
51、在一种可能的实施方式中,所述获取训练样本图像以及与所述训练样本图像对应的掩码样本图像,包括:
52、获取所述训练样本图像,并将所述训练样本图像进行划分,得到非重叠的多个图像块;
53、基于预设掩码比例,以随机采样的方式获取所述目标掩码图像,所述目标掩码图像包括用于指示对应的图像块被掩盖的掩码指示标识;
54、基于所述目标掩码图像中的掩码指示标识,对所述训练样本图像进行掩码处理,得到与所述训练样本图像对应的掩码样本图像。
55、本公开实施例中,在对训练样本图像进行掩码处理的过程中,以随机采样的方式获取目标掩码图像,可以避免在特生提取时对掩码本身特征的学习,有助于提升特征提取的精度。
56、在一种可能的实施方式中,所述方法还包括:
57、获取待检测图像,并基于所述训练好的第一网络对所述待检测图像执行图像检测任务;所述图像检测任务包括目标检测任务、语义分割任务或者实例分割任务。
58、本公开实施例中,基于所述训练好的第一网络对所述待检测图像执行图像检测任务,可以实现各类不同密集视觉任务的检测。
59、本公开实施例提供了一种网络训练装置,包括:
60、图像获取模块,用于获取训练样本图像以及与所述训练样本图像对应的掩码样本图像;其中,所述掩码样本图像由所述训练样本图像以及目标掩码图像生成;
61、第一提取模块,用于将所述掩码样本图像输入至第一网络,并基于所述第一网络的第一特征提取网络对所述掩码样本图像进行特征提取,得到与所述掩码样本图像对应的第一多层级特征图;
62、第二提取模块,用于将所述样本图像输入至第二网络,并基于所述第二网络的第二特征提取网络对所述样本图像进行特征提取,得到与所述样本图像对应的第二多层级特征图;所述第二网络的规模大于所述第一网络的规模;
63、特征预测模块,用于基于解码器以及与所述第一多层级特征图中每个特征图分别对应的掩码图像,对所述第一多层级特征图进行特征恢复处理,得到恢复处理后的第一多层级特征图;其中,所述第一多层级特征图中每个特征图对应的掩码图像由所述目标掩码图像分别进行缩放处理后得到;
64、网络训练模块,用于基于所述恢复处理后的第一多层级特征图以及所述第二多层级特征图对所述第一网络进行训练,并重复上述步骤,直到所述第一网络的训练结果符合预设要求,得到训练好的第一网络。
65、在一种可能的实施方式中,所述第一提取模块具体用于:
66、将所述掩码样本图像输入至第一网络,并基于所述特征提取模块对所述掩码样本图像进行特征提取,得到中间多层级特征图,并将所述中间多层级特征图作为所述第一多层级特征图;其中,所述中间多层级特征图包括多张尺寸不同的中间特征图。
67、在一种可能的实施方式中,所述第一提取模块具体用于:
68、将所述掩码样本图像以及所述目标掩码图像输入至第一网络,并基于所述特征提取模块对所述掩码样本图像进行特征提取,得到中间多层级特征图,其中,所述中间多层级特征图包括多张尺寸不同的中间特征图;
69、基于所述多个掩码卷积模块以及所述目标掩码图像,对所述中间多层级特征图中的每个中间特征图进行掩码处理,并将掩码处理后的中间多层级特征图,作为所述第一多层级特征图。
70、在一种可能的实施方式中,所述第一提取模块具体用于:
71、根据所述中间多层级特征图中的每个中间特征图的尺寸大小,分别对所述目标掩码图像进行缩放处理,得到与所述中间多层级特征图中每个中间特征图分别对应的掩码图像;
72、针对每个中间特征图,基于所述掩码卷积模块对所述中间特征图以及与所述中间特征图对应的掩码图像进行点乘处理,得到掩码处理的中间特征图,并基于每个掩码处理的中间特征图得到所述掩码处理后的中间多层级特征图。
73、在一种可能的实施方式中,所述解码器包括空间对齐模块、解码模块以及空间恢复模块,所述特征预测模块具体用于:
74、基于所述空间对齐模块将所述第一多层级特征图中的各个尺寸不同的特征图对齐到相同的空间分辨率,以使得所述第一多层级特征图中的各个特征图的尺寸大小对齐,得到空间对齐的多层级特征图;
75、基于所述空间对齐的多层级特征图分别对应的掩码图像,采用掩码标记替换所述空间对齐的多层级特征图中的掩码区域,得到带有掩码标记的多层级特征图,并基于所述解码模块对所述带有掩码标记的多层级特征图进行特征预测处理,得到特征预测处理的多层级特征图;
76、基于所述空间恢复模块将相同的空间分辨率的所述特征预测处理的多层级特征图恢复成原始尺寸的多层级特征图,得到所述恢复处理后的第一多层级特征图。
77、在一种可能的实施方式中,所述特征预测模块具体用于:
78、针对所述第一多层级特征图中的每个特征图,将所述特征图与目标图像进行比较;
79、当所述特征图的尺寸大于所述目标图像的尺寸的情况下,对所述特征图进行降维处理,使得所述特征图的尺寸与所述目标图像的尺寸一致;或者,
80、当所述特征图的尺寸小于所述目标图像的尺寸的情况下,采用最近邻插值对所述特征图进行上采样处理,使得所述特征图的尺寸与所述目标图像的尺寸一致。
81、在一种可能的实施方式中,所述特征预测模块具体还用于:
82、将所述第一多层级特征图的通道数与所述第二多层级特征图的通道数对齐;和/或,对所述第一多层级特征图以及所述第二多层级特征图进行层归一化处理。
83、在一种可能的实施方式中,所述特征预测模块具体用于:
84、针对每个空间对齐的特征图,对所述空间对齐的特征图进行展开处理,得到一维的展开特征图,并基于与所述空间对齐的特征图对应的掩码图像,确定所述展开特征图中需要替换的掩码区域;
85、采用所述掩码标记对所述掩码区域进行替换,得到带有所述掩码标记的展开特征图,并基于各个带有所述掩码标记的展开特征图,得到所述带有掩码标记的多层级特征图。
86、在一种可能的实施方式中,所述特征预测模块具体还用于:
87、在所述带有所述掩码标记的展开特征图中加上余弦绝对位置编码,并基于预设绝对尺度通过插值的方式对所述带有所述掩码标记的展开特征图进行自适应调整,得到调整后的展开特征图;
88、基于各个调整后的展开特征图,得到所述带有掩码标记的多层级特征图。
89、在一种可能的实施方式中,所述网络训练模块具体用于:
90、确定所述恢复处理后的第一多层级特征图以及所述第二多层级特征图之间的特征恢复损失,并基于所述特征恢复损失对所述第一网络的参数进行调整。
91、在一种可能的实施方式中,所述网络训练模块具体还用于:
92、基于所述恢复处理后的第一多层级特征图,确定所述第一网络的任务损失;和/或,通过全局上下文模块确定所述恢复处理后的第一多层级特征图以及所述第二多层级特征图分别对应的第一全局关系以及第二全局关系,并确定第一全局关系与所述第二全局关系之间的全局损失;
93、基于所述特征恢复损失、所述任务损失和/或所述全局损失,对所述第一网络的参数进行调整。
94、在一种可能的实施方式中,所述图像获取模块具体用于:
95、获取所述训练样本图像,并将所述训练样本图像进行划分,得到非重叠的多个图像块;
96、基于预设掩码比例,以随机采样的方式获取所述目标掩码图像,所述目标掩码图像包括用于指示对应的图像块被掩盖的掩码指示标识;
97、基于所述目标掩码图像中的掩码指示标识,对所述训练样本图像进行掩码处理,得到与所述训练样本图像对应的掩码样本图像。
98、在一种可能的实施方式中,所述装置还包括:
99、任务检测模块,用于获取待检测图像,并基于所述训练好的第一网络对所述待检测图像执行图像检测任务;所述图像检测任务包括目标检测任务、语义分割任务或者实例分割任务。
100、本公开实施例提供了一种电子设备,包括:处理器、存储器和总线,所述存储器存储有所述处理器可执行的机器可读指令,当电子设备运行时,所述处理器与所述存储器之间通过总线通信,所述机器可读指令被所述处理器执行时执行如上述任一可能的实施方式中所述的网络训练方法的步骤。
101、本公开实施例提供了一种计算机可读存储介质,该计算机可读存储介质上存储有计算机程序,该计算机程序被处理器运行时执行如上述任一可能的实施方式中所述的网络训练方法的步骤。
102、为使本公开的上述目的、特征和优点能更明显易懂,下文特举较佳实施例,并配合所附附图,作详细说明如下。