一种基于反馈调节提高生成对抗网络稳定性的方法与流程

文档序号:15559739发布日期:2018-09-29 01:58阅读:668来源:国知局

本发明属于深度学习领域,涉及一种基于反馈调节提高生成对抗网络稳定性的方法。



背景技术:

生成对抗网络(gan)是属于生成学习中的一种网络结构。生成对抗网络是由判别网络(d网络)和生成网络(g网络)这两个网络组成。g网络和d网络通过对抗训练的方式更新权值。gan需要交替训练g网络和d网络,两个网络会相互影响,这导致了gan的稳定性不足。

针对gan稳定性不足的现有改进方案多为对网络结构的改变,而对网络的训练方式没有任何改变。传统的生成对抗网络的训练方式是由人工经验调节在每一批数据迭代中g网络和d网络的权值更新次数。该方式不带反馈,无法达到自适应的平衡,无法保持训练的稳定性。在训练的不同阶段根据两个网络的当前优化程度,动态调节g网络和d网络的权值更新次数,作为一种新的思路,能够提高gan网络的稳定性。并且该方法能够与传统方法相结合。若与传统方法相结合,能更好提高网络的稳定性。



技术实现要素:

本发明针对现有技术的不足,提供一种基于反馈调节的生成对抗网络参数更新方法,该方法从一个新的角度提高生成对抗网络的稳定性,并且该方法可以与传统方法结合使用。

步骤(1)、构建一对生成对抗网络。

1.1生成对抗网络的构建

生成对抗网络包含一个生成网络(g网络),一个判别网络(d网络),其中生成网络输入噪声数据,输出生成数据;判别网络对输入数据输出判别结果,设置当输入数据为真实数据时判别结果为1,当输入数据为为生成数据时判别结果为0。所述的真实数据由所供给的数据集供给,所述的噪声数据由高斯噪声发生器产生。

步骤(2)、裁剪数据集中的数据到同一维度,并将真实数据和噪声数据进行分批。

2.1真实数据的裁剪

获取的数据集中每一份数据的维度可能不同,通过裁剪的方式,将每一份数据统一到相同的维度上,裁剪后的数据为所述的真实数据。

2.2对噪声数据和真实数据进行分批

由供给数据集裁剪得到的真实数据为n组,将n组数据平均分为k批。使每批数据中含有适当量的数据。(其中适当量的判别属于本领域技术人员的基础技能,一般适当量大于50小于200)。对应真实数据的分批方式,将噪声数据也分为k批,每批数据中的含有的数据组数量与真实数据每批中含有的数据组数量相同。

步骤(3)、将每次迭代中d网络和g网络的更新次数kd和kg初始值都设置为1。

步骤(4)、随机抽取1批真实数据和1批噪声数据,并开始迭代。

4.1随机抽取数据并迭代

从k批噪声数据中随机抽取1批噪声数据,并将噪声数据输入g网络中,输出1批生成数据。从k批真实数据中随机抽取1批真实数据,将抽取到的该批真实数据和生成的生成数据输入d网络。

步骤(5)、开始1次迭代并计算g网络和d网络的损失函数值errg和errd,以及其商e。

5.1生成对抗网络的损失函数

生成对抗网络中d网络和g网络的损失函数值为errd和errg,具体表示为:

其中m为一批数据中的数据组数,i为在该批数据中的第i组数据,zi为输入的第i组噪声数据,g(zi)为由生成网络对第i组噪声数据输出的生成数据,d(xi)为对输入的真实数据在当前d网络权值下的判别结果,d(g(zi))为输入的生成数据在当前d网络参数下的判别结果。errd越小,说明在当前d网络的权值下,d网络的分类能力越强,越能区分生成数据和真实数据。errg越小,说明在当前g网络权值下,g网络的生成能力越强,生成的数据越相似于真实数据,导致当前d网络无法区分。

5.2生成对抗网络损失函数的商e

损失函数值的大小可以反映当前网络权值下,网络性能的好坏,由于d网络性能过好,会导致g网络性能变坏。其中d网络性能过好的极端情况为:对输入的生成数据全部输出为1,对输入的生成数据全部输出为0。此时生成网络生成的数据全部不能被g网络输出1,为g网络的最坏情况。网络在训练过程中的相互影响,为了使其平衡训练,故将两个损失值的商e作为新的变量来表述g网络和d网络的相对网络性能好坏。e的具体表示为:

步骤(6)、将e作为反馈信号,调节kd和kg。

6.1参数的反馈更新方式

传统的不带反馈的方式为:在每一次迭代中d网络更新kd次,g网络更新kg次,其中kd和kg为常数,一经确定,不可更改。通过经验调节kd和kg的大小,得到相对稳定的更新方式。

而带反馈的网络权值更新方式为:通过反馈e的值,动态改变kd和kg的大小。其中kd和kg具体表示为:

其中[e]为e的整数部分,[1/e]为1/e的整数部分。

步骤(7)、重复步骤4、步骤5、步骤6,直到本次网络训练结束,确保其动态稳定性。

本发明的有益效果如下:

本发明的关键在于自适应的调整d网络和g网络在每一次迭代中的更新次数。与在迭代中,d网络和g网络更新次数固定的传统方法相比更具有稳定性。本发明由于将稳定性着眼于更新次数,方法简单易于实现,对使用传统方法的工程无需重新构造,向下兼容,能够节省大量人力。并且可以与其它提高稳定性的方法相结合,进一步提高gan网络稳定性。

附图说明

图1为本发明的流程图。

图2为本发明反馈结构图。

具体实施方式

下面结合具体实施例对本发明做进一步的分析。

本实验将一组采集的人脸图片作为训练的样本数据集。在带反馈的生成对抗网络训练过程中具体包括以下步骤,如图1所示:

步骤(1)、构建一对生成对抗网络。

1.1生成对抗网络的构建

生成对抗网络包含一个生成网络(g网络),一个判别网络(d网络),生成网络输入噪声数据输出生成图片。判别网络输入可能为真实图片可能为生成图片,输出判别结果。其中对真实图片的判别结果为1,对生成图片的判别结果为0。所述的真实图片由人脸数据集供给,所述的噪声数据由高斯噪声发生器产生。

步骤(2)、裁剪人脸数据集中的图片,并将真实图片和噪声数据分批。

2.1真实图片的裁剪

获取的人脸数据集中图片大小各不相同,通过裁剪的方式,将每一张人脸图片统一到相同尺寸,裁剪后的人脸图片为所述的真实图片。

2.2对噪声数据和真实数据进行分批

由人脸数据集裁剪得到的真实图片为1000张图片,将这1000张图片平均分为200批,每批50张图片。产生1000份噪声数据,同样分为200批,每批50份噪声数据。

步骤(3)、将每次迭代中d网络和g网络的更新次数kd和kg初始值都设置为1。

步骤(4)、随机抽取1批真实图片和1批噪声数据,并开始迭代。

4.1随机抽取数据并迭代

从200批噪声数据中随机抽取1批噪声数据,并将噪声数据输入g网络中,输出1批生成图片。从200批真实图片中随机抽取1批真实图片,将抽取到的该批真实图片和生成图片输入d网络。

步骤(5)、开始一次迭代并计算g网络和d网络的损失函数值errg和errd,以及其商e。

5.1生成对抗网络的损失函数

生成对抗网络中d网络和g网络的损失函数值为errd和errg,具体表示为:

其中m为1批图片中的图片张数此处m=50,i为在该批图片中的第i张图片,zi为输入的第i组噪声数据,g(zi)为由生成网络对第i组噪声数据输出的生成数据,d(xi)为对输入的真实图片在当前d网络权值下的判别结果,d(g(zi))为输入的生成图片在当前d网络参数下的判别结果。errd越小,说明在当前d网络的权值下,d网络的分类能力越强,越能区分生成图片和真实图片。errg越小,说明在当前g网络权值下,g网络的生成能力越强,生成的图片越相似于真实图片,导致当前d网络无法区分。

5.2生成对抗网络损失函数的商e

损失函数值的大小可以反映当前网络权值下,网络性能的好坏,由于d网络性能过好,会导致g网络性能变坏。其中d网络性能过好的极端情况为:对输入的生成数据全部输出为1,对输入的生成数据全部输出为0。此时生成网络生成的数据全部不能被g网络输出1,为g网络的最坏情况。网络在训练过程中的相互影响,为了使其平衡训练,故将两个损失值的商e作为新的变量来表述g网络和d网络的相对网络性能好坏。e的具体表示为:

步骤(6)、将e作为反馈信号,调节kd和kg,如图2所示。

6.1参数的反馈更新方式

传统的不带反馈的方式为:在每一次迭代中d网络更新kd次,g网络更新kg次,其中kd和kg为常数,一经确定,不可更改。通过经验调节kd和kg的大小,得到相对稳定的更新方式。

而带反馈的网络权值更新方式为:通过反馈e的值,动态改变kd和kg的大小。其中kd和kg具体表示为:

其中[e]为e的整数部分,[1/e]为[1/e]的整数部分。

步骤(7)、重复步骤4、步骤5、步骤6,直到本次网络训练结束。确保其动态稳定性。

上述实施例并非是对于本发明的限制,本发明并非仅限于上述实施例,只要符合本发明要求,均属于本发明的保护范围。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1