本发明涉及图像生成,尤其涉及一种带辅助分类器的置信度引导条件生成对抗网络的训练方法。
背景技术:
1、目前,生成对抗网络(generative adversarial network, gan)是一种流行的高保真图像生成模型,近年来得到了广泛的研究。尽管其它生成模型,如扩散模型,最近也因其在生成高质量图像方面的有效性而引起了很多关注,但生成对抗网络gan在实际应用中仍然具有几个显著的优势,例如其较低的训练和推理计算复杂度。gan的核心思想是利用对抗博弈的方法同时训练生成器和判别器:生成器利用随机噪声产生假数据以欺骗判别器;同时,判别器试图区分真实和虚假的数据。原始生成对抗网络的目标损失函数为:
2、
3、一般来说,生成对抗网络可以分为非条件生成对抗网络和条件生成对抗网络(conditional generative adversarial network,cgans)。非条件生成对抗网络接受无监督(无标签)的真实数据以及低维隐变量(通常为高斯或均匀分布随机向量)作为输入,生成与训练集中真实数据分布一致的图像;而条件生成对抗网络通过引入有监督学习,可以将类别标签或者图像的某种性质作为条件输入,生成指定类别或性质的图像;同时相比于非条件生成对抗网络往往具有更好的生成质量。尽管基于对抗学习的生成对抗网络在图像生成领域获得了巨大的成功,其本身还存在一些有待解决的问题:例如训练稳定性问题,模式坍塌等。ac-gan(auxiliary classifier generative adversarial network,带辅助分类器的条件生成对抗网络)作为一个具有代表性的带分类器的生成对抗网络,使用一个辅助分类器来学习条件标签分布,以指导生成器生成特定类的图像。虽然ac-gan可以实现较好的生成质量,但最近的研究表明,在实践中使用ac-gan经常会遇到两个问题:(1) 生成器的性能在早期训练阶段突然下降,即早期训练崩溃;(2) 生成器往往生成低多样性的数据。
技术实现思路
1、本发明的目的是提供一种带辅助分类器的置信度引导条件生成对抗网络的训练方法,该方法通过设计新的分类损失函数,避免特征表示大的特征范数,解决早期训练崩溃和过度自信问题。
2、本发明的目的是通过以下技术方案实现的:
3、一种带辅助分类器的置信度引导条件生成对抗网络的训练方法,所述方法包括:
4、步骤1、为条件生成对抗网络cgans设计新的损失函数,并引入一个超参数解决训练收敛问题;
5、步骤2、当带辅助分类器的置信度引导条件生成对抗网络,即cg-gan的分类器对于生成数据的置信度超过所引入的超参数时,所述损失函数将会抑制分类器对生成数据的置信度,通过抑制分类器对生成数据的置信度来隐式影响对真实数据的置信度;
6、步骤3、自定义一个先验标签分布,所述先验标签分布基于之前所引入的超参数,通过在cg-gan上增加一项反向或正向的kl散度作为正则化项来使优化生成数据输出的分布学习所述先验标签分布,提高cg-gan的分类能力。
7、由上述本发明提供的技术方案可以看出,上述方法通过设计新的分类损失函数,避免特征表示大的特征范数,解决早期训练崩溃和过度自信问题,在提高条件生成对抗网络训练稳定性的同时,进一步提高了条件生成性能。
1.一种带辅助分类器的置信度引导条件生成对抗网络的训练方法,其特征在于,所述方法包括:
2.根据权利要求1所述带辅助分类器的置信度引导条件生成对抗网络的训练方法,其特征在于,在步骤1中,
3.根据权利要求1所述带辅助分类器的置信度引导条件生成对抗网络的训练方法,其特征在于,在步骤2中,对于损失函数,采用蒙特卡洛采样将其写成经验loss形式,对于一个真实数据集内第个样本,同样有第个生成样本,首先依据损失函数定义:
4.根据权利要求1所述带辅助分类器的置信度引导条件生成对抗网络的训练方法,其特征在于,在步骤3中,在抑制分类器对生成数据的置信度时,需要提高分类器对生成数据的分类性能,具体来说: