一种基于联合训练生成对抗网络的半监督图像分类方法与流程

文档序号:23654924发布日期:2021-01-15 13:51阅读:91来源:国知局
一种基于联合训练生成对抗网络的半监督图像分类方法与流程

本发明属于图像处理技术领域,具体涉及一种基于联合训练生成对抗网络的半监督图像分类方法。



背景技术:

作为计算机视觉领域最常见的任务之一,图像分类通过提取原始图像的特征并根据特征进行分类。传统的特征提取主要是通过对图像的颜色、纹理、局部特征等几个方面进行分析处理实现的,例如尺度不变特征变换法,方向梯度法以及局部二值法等。但是这些特征都是人为设计的特征,很大程度上靠人类对识别目标的先验知识进行设计,具有一定的局限性。随着大数据时代的到来,基于深度学习的图像分类方法具有对大量复杂数据进行处理和表征的能力,能够有效学习目标的特征信息,从而大大提高图像分类的精度。

深度学习以数据驱动方式进行训练学习,对标签数据依赖性强,而实际应用中往往难以获取大量的标签数据。当样本数量不足时,深度网络模型容易过拟合,导致分类性能较差。生成对抗网络,也称gan网络,是由goodfellow等在2014年提出的,由一个生成器和一个判别器构成。生成器根据输入数据分布来生成尽可能逼真的伪数据,判别器用于判断输入数据是真实数据还是生成器生成的伪数据。在训练期间,生成器不断尝试通过产生越来越好的假图片来超越判别器,与此同时判别器逐渐更好的检测并正确分类真假图片,生成器和判别器经过博弈对抗达到纳什均衡,此时生成的数据能够拟合真实的数据分布。gan网络在训练时既能够生成样本,又能够提高特征提取能力,可以用来解决数据样本少的问题。但gan网络还存在稳定性差和依赖标签数据的问题,不能直接应用于分类任务中。

针对gan网络稳定性差的问题,目前已经有多种方法通过改进gan网络结构或优化算法来解决。但是目前针对依赖标签数据的问题,并没有有效的分类方法,因此亟需一种在一定程度上减小网络对标签数据的依赖、且能提高网络分类准确率的改进的gan网络。



技术实现要素:

本发明所要解决的技术问题在于针对上述现有技术中的不足,提供一种基于联合训练生成对抗网络的半监督图像分类方法,其结构简单、设计合理,采用判别器d1和判别器d2联合训练,以减小单个判别器误差对生成对抗网络的影响;利用大量无标签数据和少量标签数据进行联合训练,能够学习到泛化能力较强的模型,在一定程度上减小生成对抗网络对标签数据的依赖,利用无标签数据在训练时扩充标签数据集,加快网络收敛,提高生成对抗网络的分类准确率。

为解决上述技术问题,本发明采用的技术方案是:一种基于联合训练生成对抗网络的半监督图像分类方法,其特征在于:包括以下步骤:

步骤一、设置生成对抗网络,包括生成器g、判别器d1和判别器d2,设置生成对抗网络的训练初始参数;

步骤二、获取训练集和测试集,训练集包括标签数据集l和无标签数据集u,将标签数据集l打乱随机分为标签子样本集l1和标签子样本集l2,其中,标签子样本集l1、l2包括k类标签数据;将无标签数据集u打乱随机分为无标签子样本集u1和无标签子样本集u2,其中,无标签子样本集u1包括g个无标签数据,无标签子样本集u2包括r个无标签数据;

步骤三、训练生成器g:

步骤301、将随机高斯噪声z输入生成器g生成伪数据g(z);

步骤302、将伪数据g(z)输入到判别器d1,判别器d1对伪数据g(z)进行判别得到d1(g(z));

步骤303、将伪数据g(z)输入到判别器d2,判别器d2对伪数据g(z)进行判别得到d2(g(z));

步骤304、计算生成器g的损失minlg;

步骤305、更新生成器g的训练参数;

步骤四、训练判别器d1和判别器d2:

步骤401、将标签子样本集l1输入到判别器d1,判别器d1输出k+1维分类预测概率{l11,...l1i,...l1k,l1(k+1)},其中l11至l1k表示标签子样本集l1中k类标签数据的置信度,l1(k+1)表示伪数据g(z)由判别器d1判定为“伪”的置信度;

步骤402、将无标签子样本集u1中的第n个无标签数据输入到判别器d1,判别器d1针对第n个无标签数据输出k+1维分类预测概率{h11-n,...h1i-n,...h1k-n,h1(k+1)-n},若max{h11-n,...h1j-n,...h1g-n}>η,则将无标签子样本集u1中第n个无标签数据加入标签子样本集l2中max{h11-n,...h1j-n,...h1g-n}所对应的标签类别,η表示置信度阈值,1≤n≤g;

步骤403、将标签子样本集l2输入到判别器d2,判别器d2输出k+1维分类预测概率{l21,...l2i,...l2k,l2(k+1)},其中l21至l2k表示标签子样本集l2中k类标签数据的置信度,l2(k+1)表示伪数据g(z)由判别器d2判定为“伪”的置信度;

步骤404、将无标签子样本集u2中的第m个无标签数据输入到判别器d2,判别器d2针对第m个无标签数据输出k+1维分类预测概率{h21-m,...h2i-m,...h2k-m,h2(k+1)-m},若max{h21-m,...h2j-m,...h2g-m}>η,则将无标签子样本集u2中的第m个无标签数据加入标签子样本集l1中max{h21-m,...h2j-m,...h2g-m}所对应的标签类别,η表示置信度阈值,1≤m≤r;

步骤405、计算判别器总损失maxld;

步骤406、更新判别器d1和判别器d2的训练参数;

步骤五、迭代更新:

步骤501、若判别器损失maxld收敛,结束迭代,得到训练好的生成对抗网络,否则进入步骤502;

步骤502、迭代执行步骤二到步骤五,每次迭代后,迭代次数加1,直到迭代次数等于最大迭代次数,迭代结束。

步骤六、利用测试集对生成对抗网络进行测试,生成对抗网络输出对测试集的分类结果,获得生成对抗网络的分类精度。

上述的一种基于联合训练生成对抗网络的半监督图像分类方法,其特征在于:步骤304中生成器g的损失minlg的计算公式为:其中fu(·)表示判别器du中间层的特征值,u=1、2。

上述的一种基于联合训练生成对抗网络的半监督图像分类方法,其特征在于:步骤403中判别器总损失的计算公式为其中表示判别器监督损失,其中yi表示标签数据集l中第i维数据的标签,du(xi)表示判别器du判别标签数据的标签为第i维的概率,maxlunsupd表示判别器无监督损失,y′i表示判别器前一次迭代时判别无标签数据的类别为第i维。

上述的一种基于联合训练生成对抗网络的半监督图像分类方法,其特征在于:所述生成器g的网络结构依次为:输入层→全连接层→上采样层→卷积层conv1→上采样层→卷积层conv2→卷积层conv3。

上述的一种基于联合训练生成对抗网络的半监督图像分类方法,其特征在于:所述判别器d1和判别器d2的网络结构相同,判别器d1的网络结构依次为:输入层→卷积层conv1→卷积层conv2→卷积层conv3。

本发明与现有技术相比具有以下优点:

1、本发明的结构简单、设计合理,实现及使用操作方便。

2、本发明的基于联合训练生成对抗网络中,采用判别器d1和判别器d2进行联合训练,判别器的总损失为判别器d1损失和判别器d2损失的均值,以消除单个判别器存在的分布误差,从而以减小单个判别器误差对生成对抗网络的影响,提高判别器训练的稳定性。

3、本发明设置置信度阈值η,对每次迭代得到的无标签样本集的分类结果进行置信度判断,如果大于该置信度阈值,则将该标签数据加入到标签样本集中继续迭代训练,利用无标签样本集扩充标签样本集,从而加快生成对抗网络收敛,提高图像分类效率。

综上所述,本发明结构简单、设计合理,采用判别器d1和判别器d2联合训练,以减小单个判别器误差对生成对抗网络的影响;利用大量无标签数据和少量标签数据进行联合训练,能够学习到泛化能力较强的模型,在一定程度上减小生成对抗网络对标签数据的依赖,利用无标签数据在训练时扩充标签数据集,加快网络收敛,提高生成对抗网络的分类准确率。

下面通过附图和实施例,对本发明的技术方案做进一步的详细描述。

附图说明

图1为本发明的方法流程图。

图2为本发明生成器的结构示意图。

图3为本发明判别器的结构示意图。

具体实施方式

下面结合附图及本发明的实施例对本发明的方法作进一步详细的说明。

需要说明的是,在不冲突的情况下,本申请中的实施例及实施例中的特征可以相互组合。下面将参考附图并结合实施例来详细说明本发明。

需要注意的是,这里所使用的术语仅是为了描述具体实施方式,而非意图限制根据本申请的示例性实施方式。如在这里所使用的,除非上下文另外明确指出,否则单数形式也意图包括复数形式,此外,还应当理解的是,当在本说明书中使用术语“包含”和/或“包括”时,其指明存在特征、步骤、操作、器件、组件和/或它们的组合。

需要说明的是,本申请的说明书和权利要求书及上述附图中的术语“第一”、“第二”等是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的数据在适当情况下可以互换,以便这里描述的本申请的实施方式例如能够以除了在这里图示或描述的那些以外的顺序实施。此外,术语“包括”和“具有”以及他们的任何变形,意图在于覆盖不排他的包含,例如,包含了一系列步骤或单元的过程、方法、系统、产品或设备不必限于清楚地列出的那些步骤或单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它步骤或单元。

为了便于描述,在这里可以使用空间相对术语,如“在……之上”、“在……上方”、“在……上表面”、“上面的”等,用来描述如在图中所示的一个器件或特征与其他器件或特征的空间位置关系。应当理解的是,空间相对术语旨在包含除了器件在图中所描述的方位之外的在使用或操作中的不同方位。例如,如果附图中的器件被倒置,则描述为“在其他器件或构造上方”或“在其他器件或构造之上”的器件之后将被定位为“在其他器件或构造下方”或“在其他器件或构造之下”。因而,示例性术语“在……上方”可以包括“在……上方”和“在……下方”两种方位。该器件也可以其他不同方式定位(旋转90度或处于其他方位),并且对这里所使用的空间相对描述作出相应解释。

如图1所示,本发明的一种基于联合训练生成对抗网络的半监督图像分类方法,包括以下步骤:

步骤一、设置生成对抗网络,包括生成器g、判别器d1和判别器d2,设置生成对抗网络的训练初始参数。

在本申请基于联合训练生成对抗网络中,采用了判别器d1和判别器d2进行联合训练,以减小单个判别器误差对生成对抗网络的影响。判别器d1和判别器d2共享同一个生成器g,同时判别器d1和判别器d2的网络结构和训练初始参数设为相同。

步骤二、获取训练集和测试集,训练集包括标签数据集l和无标签数据集u,将标签数据集l打乱随机分为标签子样本集l1和标签子样本集l2,其中,标签子样本集l1、l2包括k类标签数据;将无标签数据集u打乱随机分为无标签子样本集u1和无标签子样本集u2,其中,无标签子样本集u1包括g个无标签数据,无标签子样本集u2包括r个无标签数据。

需要说明的是,将标签数据集l和无标签数据集u的顺序打乱随机分为两个子集,然后分别输入到判别器d1和判别器d2中,可以保证训练过程中,判别器d1和判别器d2是动态变化的。

步骤三、训练生成器g:

步骤301、将随机高斯噪声z输入生成器g生成伪数据g(z)。基于联合训练生成对抗网络的生成器g框架如图2所示,需要说明的是,所述生成器g的网络结构依次为:输入层→全连接层→上采样层→卷积层conv1→上采样层→卷积层conv2→卷积层conv3。

具体实施时,生成器g的输入为(128,100)的随机噪声,首先通过(100,8192)的全连接层得到(128,8192)的张量,经过维度转换得到维度为(128,128,8,8)的图像,经过两次上采样操作和三次步长为1的3×3卷积核的卷积操作后得到维度为(128,3,32,32)的图像,其中每次完成卷积操作后都是用归一化操作加入relu激活函数,最后一层通过tanh激活函数输出伪数据g(z)。

步骤302、将伪数据g(z)输入到判别器d1,判别器d1对伪数据g(z)进行判别得到d1(g(z));

步骤303、将伪数据g(z)输入到判别器d2,判别器d2对伪数据g(z)进行判别得到d2(g(z));

步骤304、计算生成器g的损失minlg:原始生成对抗网络中生成器的损失表示为为了让生成器生成的数据分布更接近真实数据的统计分布,采用特征匹配的方法对生成器的损失进行约束,定义特征匹配损失为:其中fu(·)表示判别器du中间层的特征值,u=1、2。因此生成器g的损失minlg的计算公式为:

步骤305、更新生成器g的训练参数。

步骤四、训练判别器d1和判别器d2:

步骤401、将标签子样本集l1输入到判别器d1,判别器d1输出k+1维分类结果{l11,...l1i,...l1k,l1(k+1)},其中l11至l1k表示标签子样本集l1中k类标签数据的置信度,l1(k+1)表示伪数据g(z)由判别器d1判定为“伪”的置信度;

步骤402、将无标签子样本集u1中的第n个无标签数据输入到判别器d1,判别器d1针对第n个无标签数据输出k+1维分类预测概率{h11-n,...h1i-n,...h1k-n,h1(k+1)-n},若max{h11-n,...h1j-n,...h1g-n}>η,则将无标签子样本集u1中第n个无标签数据加入标签子样本集l2中max{h11-n,...h1j-n,...h1g-n}所对应的标签类别,η表示置信度阈值,1≤n≤g。

具体实施时,如图3所示,判别器d1和判别器d2的网络结构相同,判别器d1的网络结构依次为:输入层→卷积层conv1→卷积层conv2→卷积层conv3→全连接层→softmax分类器。

判别器d1的输入为大小为32×32的3通道rgb彩色图像,其维度为(128,3,32,32),经过四次步长为2的3×3的卷积核的卷积操作,最终输出图像维度为(128,128,2,2),其中每次完成卷积操作后都加入leakyrelu激活函数和dropout操作以防止过拟合,而除了首次卷积不使用归一化外,其余卷积操作后都是用归一化。

设置置信度阈值η,对每次迭代得到的无标签子样本集u1的分类结果进行置信度判断,如果大于该置信度阈值η,则将该标签数据加入到标签子样本集l2中继续迭代训练,利用无标签子样本集u1扩充标签子样本集l2,从而加快生成对抗网络收敛。

步骤403、将标签子样本集l2输入到判别器d2,判别器d2输出k+1维分类预测概率{l21,...l2i,...l2k,l2(k+1)},其中l21至l2k表示标签子样本集l2中k类标签数据的置信度,l2(k+1)表示伪数据g(z)由判别器d2判定为“伪”的置信度;

步骤404、将无标签子样本集u2中的第m个无标签数据输入到判别器d2,判别器d2针对第m个无标签数据输出k+1维分类预测概率{h21-m,...h2i-m,...h2k-m,h2(k+1)-m},若max{h21-m,...h2j-m,...h2g-m}>η,则将无标签子样本集u2中的第m个无标签数据加入标签子样本集l1中max{h21-m,...h2j-m,...h2g-m}所对应的标签类别,η表示置信度阈值,1≤m≤r;

同理,设置置信度阈值η,步骤404中的置信度阈值η与步骤402中的置信度阈值η相同。对每次迭代得到的无标签子样本集u2的分类结果进行置信度判断,如果大于该置信度阈值η,则将该标签数据加入到标签子样本集l1中继续迭代训练,利用无标签子样本集u2扩充标签子样本集l1,从而加快生成对抗网络收敛。

步骤405、判别器总损失的计算公式为其中表示判别器监督损失,对于判别器的监督损失,需要加入标签信息,因此监督损失以交叉熵的形式定义为,其中yi表示标签数据集l中第i维数据的标签,du(xi)表示判别器du判别标签数据的标签为第i维的概率。表示判别器无监督损失,基于联合训练的生成对抗网络需要判别无标签数据的类别标签,以此判别器的无监督损失既判断真伪,也判断类别概率,所以无监督损失由两部分组成,考虑到两个判别器联合训练的情况,无监督损失定义为:y′i表示判别器前一次迭代时判别无标签数据的类别为第i维。

需要说明的是,判别器的总损失maxld为判别器d1损失和判别器d2损失的均值,以消除单个判别器存在的分布误差。

步骤406、更新判别器d1和判别器d2的训练参数。需要说明的是,所述判别器d1和判别器d2的初始训练参数相同及网络结构相同,在训练过程中动态变化,判别器d1和判别器d2参数共享。

本申请通过判别器d1和判别器d2的联合训练,一方面可以消除单个判别器存在的分布误差,提高判别器训练的稳定性;另一方面,利用无标签数据在训练时扩充标签数据集l,能够加快网络收敛。因此,本申请基于联合训练的生成对抗网络模型能够充分利用少量标签数据的标签信息和大量无标签数据的分布信息来获取整个样本的特征分布,迭代更新扩充了标签子样本集,从而进一步提高在小样本条件下网络图像分类的精度。

步骤五、迭代更新:

步骤501、若判别器损失maxld收敛,结束迭代,得到训练好的生成对抗网络,否则进入步骤502;

步骤502、迭代执行步骤二到步骤五,每次迭代后,迭代次数加1,直到迭代次数等于最大迭代次数,迭代结束,得到训练好的生成对抗网络。

步骤六、利用测试集对训练好的生成对抗网络进行测试,生成对抗网络输出对测试集的分类结果,获得生成对抗网络的分类精度。

以上所述,仅是本发明的实施例,并非对本发明作任何限制,凡是根据本发明技术实质对以上实施例所作的任何简单修改、变更以及等效结构变化,均仍属于本发明技术方案的保护范围内。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1