模型训练方法、设备和存储介质与流程

文档序号:33753352发布日期:2023-04-18 13:56阅读:31来源:国知局
模型训练方法、设备和存储介质与流程

本公开涉及计算机,具体涉及深度学习、机器学习等人工智能,尤其涉及模型训练方法、设备和存储介质。


背景技术:

1、目前,通常采用成式对抗网络(generative adversarial network,gan)模型来生成图像。其中,通常训练gan模型通常需要依赖巨大的训练数据,然而,很多实际任务中通常只有非常有限的样本,例如,罕见的物体、特殊风格的图像等。

2、相关技术中,在目标任务所对应的样本量较少的情况下,通常采用基于大数据训练所得到的源gan模型的网络参数对目标任务所对应的目标gan模型进行初始化,并基于目标任务对应的样本图像数据对目标gan模型进行训练。然而,上述方式训练得到的目标gan模型的泛化能力较差,容易出现过拟合现象。


技术实现思路

1、本公开提供了一种用于模型训练方法、设备和存储介质。

2、根据本公开的一方面,提供了一种模型训练方法,包括:获取与预训练好的源生成式对抗网络gan模型相同的目标gan模型;获取目标任务对应的样本图像集合;针对所述样本图像集合中的各个样本图像,确定所述源gan模型的生成器在生成所述样本图像时所使用的目标噪声变量;确定所述目标噪声变量所服从的数据分布;根据所述数据分布和所述样本图像集合对所述目标gan模型进行训练。

3、根据本公开的另一方面,提供了一种模型训练装置,包括:第一获取模块,用于获取与预训练好的源生成式对抗网络gan模型相同的目标gan模型;第二获取模块,用于获取目标任务对应的样本图像集合;第一确定模块,用于针对所述样本图像集合中的各个样本图像,确定所述源gan模型的生成器在生成所述样本图像时所使用的目标噪声变量;第二确定模块,用于确定所述目标噪声变量所服从的数据分布;训练模块,用于根据所述数据分布和所述样本图像集合对所述目标gan模型进行训练。

4、根据本公开的另一方面,提供了一种电子设备,包括:至少一个处理器;以及与所述至少一个处理器通信连接的存储器;其中,所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行本公开的模型训练方法。

5、根据本公开的另一方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质,所述计算机指令用于使所述计算机执行本公开实施例公开的模型训练方法。

6、根据本公开的另一方面,提供了一种计算机程序产品,包括计算机程序,所述计算机程序被处理器执行时实现本公开的模型训练方法。

7、应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。



技术特征:

1.一种模型训练方法,包括:

2.根据权利要求1所述的方法,其中,所述针对所述样本图像集合中的各个样本图像,确定所述源gan模型的生成器在生成所述样本图像时所使用的目标噪声变量,包括:

3.根据权利要求1所述的方法,其中,在所述根据所述数据分布和所述样本图像集合对所述目标gan模型进行训练之前,所述方法还包括:

4.根据权利要求1所述的方法,其中,所述根据所述数据分布和所述样本图像集合对所述目标gan模型进行训练,包括:

5.根据权利要求4所述的方法,其中,所述根据所述第一分类结果、所述特征向量和所述第二分类结果,对所述目标gan模型的生成器和判别器进行交替训练,直至满足训练结束条件,包括:

6.根据权利要求4所述的方法,其中,所述根据所述第一分类结果、所述特征向量和所述第二分类结果,确定所述目标gan模型的总损失值,包括:

7.根据权利要求6所述的方法,其中,所述根据所述特征矩阵,确定所述目标gan模型的第二损失值,包括:

8.根据权利要求1-7中任一项所述的方法,其中,在对所述目标gan进行t轮训练时,针对第t轮训练,在根据所述数据分布和所述样本图像集合对所述目标gan模型进行训练之前,所述方法还包括:

9.一种模型训练装置,包括:

10.根据权利要求9所述的装置,其中,所述第一确定模块,具体用于:

11.根据权利要求9所述的装置,其中,所述装置还包括:

12.根据权利要求9所述的装置,其中,所述训练模块,包括:

13.根据权利要求12所述的装置,其中,所述训练子模块,包括:

14.根据权利要求12所述的装置,其中,所述确定单元,包括:

15.根据权利要求14所述的装置,其中,所述第三确定子单元,具体用于:

16.根据权利要求9-15中任一项所述的装置,其中,在对所述目标gan进行t轮训练时,针对第t轮训练,所述装置还包括:

17.一种电子设备,包括:

18.一种存储有计算机指令的非瞬时计算机可读存储介质,所述计算机指令用于使所述计算机执行权利要求1-8中任一项所述的方法。

19.一种计算机程序产品,包括计算机程序,其特征在于,所述计算机程序被处理器执行时实现权利要求1-8中任一项所述方法的步骤。


技术总结
本公开提供了一种模型训练方法、设备和存储介质,涉及深度学习、机器学习等人工智能技术领域。具体实现方案为:获取与预训练好的源生成式对抗网络GAN模型相同的目标GAN模型;获取目标任务对应的样本图像集合;针对样本图像集合中的各个样本图像,确定源GAN模型的生成器在生成样本图像时所使用的目标噪声变量;确定目标噪声变量所服从的数据分布;根据数据分布和样本图像集合对目标GAN模型进行训练,由此,基于样本图像集合在源GAN模型中所学习到的数据分布来对目标GAN模型进行训练,可更好地利用源GAN模型的信息,实现对源GAN模型的信息的继承以及目GAN标模型的自适应调整,避免目标GAN模型出现过拟合,提高目标GAN模型的泛化能力。

技术研发人员:李兴建,张泽人,窦德景
受保护的技术使用者:北京百度网讯科技有限公司
技术研发日:
技术公布日:2024/1/13
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1