特征提取模型训练方法、装置、电子设备及存储介质与流程

文档序号:36381995发布日期:2023-12-14 15:42阅读:29来源:国知局
特征提取模型训练方法、装置、电子设备及存储介质与流程

本技术涉及计算机,特别涉及一种特征提取模型训练方法、装置、电子设备及存储介质。


背景技术:

1、对象识别(如人脸识别)通常基于对象的特征表示实现,因此特征表示的质量将直接影响对象识别任务的效果。

2、相关技术中,为了得到高质量的特征表示需要大规模的预训练对象数据(如人脸图像)来训练特征提取模型,而在对象信息(如人脸信息)作为隐私信息受到保护时,预训练对象数据的规模很难达到相关技术中的数据规模要求,使得这种情况下基于相关技术中的预训练方式训练的特征提取模型的特征提取性能较差,无法提供满足下游对象识别任务的高质量特征表示。


技术实现思路

1、为了解决现有技术的问题,本技术实施例提供了一种特征提取模型训练方法、装置、电子设备及存储介质。所述技术方案如下:

2、一方面,提供了一种特征提取模型训练方法,所述方法包括:

3、获取样本图像对应的样本遮蔽图像;所述样本图像包含目标对象,所述样本遮蔽图像是对所述样本图像中的部分区域进行遮蔽处理得到的图像;

4、基于待训练特征提取网络对所述样本遮蔽图像中的未遮蔽区域进行特征提取,得到样本特征图;所述样本特征图包括所述未遮蔽区域的特征;

5、基于待训练特征风格编码网络对所述样本特征图进行特征风格编码,得到多个风格隐向量;所述多个风格隐向量表征所述目标对象的不同特征层级的特征;

6、基于所述多个风格隐向量进行图像重建,得到重建图像;基于所述样本图像和所述重建图像,分别确定图像重建损失和特征重建损失;

7、基于所述图像重建损失和特征重建损失进行网络参数的调整,直至满足预设训练结束条件,将训练结束时的所述待训练特征提取网络作为目标特征提取模型;所述网络参数包括所述待训练特征提取网络和所述待训练特征风格编码网络的网络参数。

8、另一方面,提供了一种特征提取模型训练装置,所述装置包括:

9、遮蔽图像获取模块,用于获取样本图像对应的样本遮蔽图像;所述样本图像包含目标对象,所述样本遮蔽图像是对所述样本图像中的部分区域进行遮蔽处理得到的图像;

10、特征提取模块,用于基于待训练特征提取网络对所述样本遮蔽图像中的未遮蔽区域进行特征提取,得到样本特征图;所述样本特征图包括所述未遮蔽区域的特征;

11、特征风格编码模块,用于基于待训练特征风格编码网络对所述样本特征图进行特征风格编码,得到多个风格隐向量;所述多个风格隐向量表征所述目标对象的不同特征层级的特征;

12、重建及损失确定模块,用于基于所述多个风格隐向量进行图像重建,得到重建图像;基于所述样本图像和所述重建图像,分别确定图像重建损失和特征重建损失;

13、参数调整模块,用于基于所述图像重建损失和特征重建损失进行网络参数的调整,直至满足预设训练结束条件,将训练结束时的所述待训练特征提取网络作为目标特征提取模型;所述网络参数包括所述待训练特征提取网络和所述待训练特征风格编码网络的网络参数。

14、在一些示例性的实施方式中,所述重建及损失确定模块,包括:

15、图像重建损失确定模块,用于基于所述样本图像中对应遮蔽区域的图像块与所述重建图像中对应所述遮蔽区域的图像块之间的差异,确定图像重建损失;

16、视觉特征提取模块,用于基于自注意力机制分别提取所述样本图像的视觉特征和所述重建图像的视觉特征;

17、特征重建损失确定模块,用于基于所述样本图像的视觉特征与所述重建图像的视觉特征之间的距离,确定特征重建损失。

18、在一些示例性的实施方式中,所述视觉特征提取模块,包括:

19、第一视觉特征提取子模块,用于将所述样本图像输入至自注意力网络,通过所述自注意力网络的多个自注意力层依次对所述样本图像进行特征提取,将所述多个自注意力层中至少一个预设自注意力层的输出作为所述样本图像的视觉特征;

20、第二视觉特征提取子模块,用于将所述重建图像输入至所述自注意力网络,通过所述自注意力网络的多个自注意力层依次对所述重建图像进行特征提取,将所述至少一个预设自注意力层的输出作为所述重建图像的视觉特征;

21、所述特征重建损失确定模块,包括:

22、第一确定模块,用于基于每个所述预设自注意力层针对所述样本图像的输出与针对所述重建图像的输出之间的距离,确定每个所述预设自注意力层对应的损失;

23、第二确定模块,用于基于各所述预设注意力层对应的损失之和,得到所述特征重建损失。

24、在一些示例性的实施方式中,所述重建及损失确定模块,包括:

25、重建模块,用于将所述多个风格隐向量输入至已训练图像生成网络,通过所述已训练图像生成网络中与每个所述特征层级对应的图像生成子网络对每个所述特征层级对应的风格隐向量进行图像重建,得到每个所述特征层级的重建子图像;

26、融合模块,用于将各所述特征层级的重建子图像进行融合,得到重建图像;

27、其中,所述已训练图像生成网络为基于所述目标对象的训练图像集对基于风格的生成对抗网络进行训练得到。

28、在一些示例性的实施方式中,所述特征风格编码模块,包括:

29、中间特征提取模块,用于基于待训练特征风格编码网络中的多个卷积神经网络分别对所述样本特征图进行所述不同特征层级的特征提取,得到每个所述特征层级的中间特征图;

30、编码处理模块,用于对于每个所述特征层级的中间特征图,基于所述特征层级对应的多个编码网络分别对所述中间特征图进行编码处理,得到所述特征层级对应的多个隐向量;其中,所述多个编码网络与所述已训练图像生成网络中对应所述特征层级的图像生成子网络中的多个卷积层一一对应;

31、风格隐向量确定模块,用于基于所述已训练图像生成网络中,对应每个特征层级的图像生成子网络中的多个卷积层的逐级连接顺序,对每个所述特征层级对应的多个隐向量进行拼接,得到每个所述特征层级的风格隐向量。

32、在一些示例性的实施方式中,所述特征提取模块,包括:

33、未遮蔽区域提取模块,用于从所述样本遮蔽图像中提取未遮蔽区域,得到多个未遮蔽图像块;

34、块特征提取模块,用于基于待训练特征提取网络对所述多个未遮蔽图像块分别进行特征提取,得到每个未遮蔽图像块对应的一维特征向量;

35、预设向量获取模块,用于获取表征所述样本遮蔽图像中遮蔽区域的一维预设向量;

36、二维图像重塑模块,用于基于所述一维预设向量,以及每个所述未遮蔽图像块对应的一维特征向量和在所述样本遮蔽图像中的位置进行二维图像重塑,得到样本特征图。

37、在一些示例性的实施方式中,所述遮蔽图像获取模块,包括:

38、图像获取模块,用于获取所述样本图像以及预设掩码图;

39、遮蔽处理模块,用于对所述样本图像和所述预设掩码图进行点乘处理,得到样本遮蔽图像。

40、另一方面,提供了一种电子设备,包括处理器和存储器,所述存储器中存储有至少一条指令或者至少一段程序,所述至少一条指令或者所述至少一段程序由所述处理器加载并执行以实现上述任一方面的特征提取模型训练方法。

41、另一方面,提供了一种计算机可读存储介质,所述计算机可读存储介质中存储有至少一条指令或者至少一段程序,所述至少一条指令或者所述至少一段程序由处理器加载并执行以实现如上述任一方面的特征提取模型训练方法。

42、另一方面,提供了一种计算机程序产品或计算机程序,该计算机程序产品或计算机程序包括计算机指令,该计算机指令存储在计算机可读存储介质中。电子设备的处理器从计算机可读存储介质读取该计算机指令,处理器执行该计算机指令,使得该电子设备执行上述任一方面的特征提取模型训练方法。

43、本技术实施例通过基于待训练特征提取网络对样本遮蔽图像中的未遮蔽区域进行特征提取得到包括未遮蔽区域特征的样本特征图,并基于待训练特征风格编码网络对样本特征图进行特征风格编码得到多个风格隐向量,该多个风格隐向量指示目标对象的不同特征层级的特征,进而基于多个风格隐向量进行图像重建得到重建图像,基于原始的样本图像与重建图像分别确定图像重建损失和特征重建损失,并基于图像重建损失和特征重建损失对待训练特征提取网络和待训练特征风格编码网络的网络参数进行调整,并将训练结束时的待训练特征提取网络作为目标特征提取模型,从而即使训练数据有限也能得到高特征提取性能的特征提取模型,为下游对象识别任务提供高质量的对象特征表示。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1