一种图像分类方法、图像分类模型的训练方法及设备与流程

文档序号:26920902发布日期:2021-10-09 16:39阅读:205来源:国知局
一种图像分类方法、图像分类模型的训练方法及设备与流程

1.本发明实施例涉及图像分类技术领域,特别涉及一种图像分类方法、图像分类模型的训练方法及设备。


背景技术:

2.在图像识别与分类领域中,机器学习范围内的深度学习是一种有效的方法,产生了很多优秀的算法和网络,包括常见的卷积神经网络(convolutional neural network,简称“cnn”)、循环神经网络(recurrent neural network)、生成对抗网络(generative adversarial networks)、深度强化学习(reinforcement learning)四大主流网络结构。
3.但由于某些应用图像领域(例如医学图像领域)样本数据繁杂、需要专业人员才能进行标注,导致样本数据标注代价巨大,无法轻易获取大量的标注数据。图像数据集中数据集标签少、特殊分类任务效果差。


技术实现要素:

4.本发明提供一种图像分类方法、图像分类模型的训练方法及设备,以解决现有技术中存在的上述问题。
5.第一方面,本发明实施例提供了一种图像分类方法,该方法包括:
6.s10:将待分类图像切分成多个patch,生成每个patch对应的patch向量;通过线性层对每patch向量进行降维,将多个降维后的patch向量进行拼接,得到第一序列向量;在所述第一序列向量的首部嵌入一个可变向量,得到第二序列向量,其中,所述可变向量与每个降维后的patch向量尺寸相同,且所述可变向量对应所述多个patch中最能代表所述待分类图像的特征的patch;
7.s20:初始化所述第二序列向量的位置编码向量,其中,所述位置编码向量中包含所述多个patch在所述待分类图像中的位置信息;将所述初始化后的位置编码向量嵌入到所述第二序列向量中,得到输入向量;
8.s30:将所述输入向量输入到transformer模型的编码器,得到编码向量;取所述编码向量的首部的可变向量作为所述待分类图像的特征向量;将所述特征向量输入到所述transformer模型的分类器,得到所述待分类图像的预测类别概率。
9.在一实施例中,s10包括:
10.s110:将尺寸为h
×
w
×
c的所述待分类图像切分成m个尺寸为p
×
p
×
c的patch,其中,h和w分别表示所述待分类图像的高度和宽度,c表示所述待分类图像的通道数,p表示每个patch的宽度;
11.s120:将每个patch展开成一个patch向量,通过所述线性层将每个patch向量降至d维,生成所述第一序列向量x1=[x1;x2;

;x
m
],其中,x
i
表示第i个patch的patch向量,i=1、2

m,表示维度为d的向量域;
[0012]
s130:在x1的首部嵌入所述可变向量x
class
,得到所述第二序列向量x2=[x
class
;x1;
x2;

;x
m
],其中,
[0013]
在一实施例中,s20包括:
[0014]
s21:初始化x
class
的位置编码向量p0,初始化x
i
的位置编码向量p
i
,其中,所述第二序列向量的位置编码向量p=[p0;p1;p2;

;p
m
],j=0、1、2

m,p
j
中包含p
j
对应的patch在所述待分类图像中的位置信息;
[0015]
s22:将p嵌入到x2中,得到所述输入向量x[x
class
+p0;x1+p1;x2+p2;

;x
m
+p
m
]。
[0016]
在一实施例中,所述transformer模型包含所述编码器和所述分类器,不包含解码器,其中,
[0017]
所述编码器包括串行排列的的多头自注意力(multiheaded self

attention,msa)和第一多层感知器(multilayer perceptron,mlp),所述msa的输出为所述第一mlp的输入;所述msa与所述第一mlp的内部均采用残差连接方式;所述msa和所述第一mlp之前均连接有一个归一化层(layernorm,ln),待处理信号经过一个ln后再输入所述msa或所述第一mlp进行处理;
[0018]
所述分类器包括第二mlp。
[0019]
第二方面,本发明实施例还提供了一种图像分类模型的训练方法。该方法包括:
[0020]
s01:获取一个训练数据集d,其中,所述训练数据集中包括有标签数据集d
l
和无标签数据集d
u
,每个训练数据为一幅训练图像,每个有标签数据d
l
的标签为d
l
的真实类别y
l

[0021]
s02:对每个有标签数据d
l
进行一次随机数据增强,得到增强后的有标签数据集对每个无标签数据d
u
进行k次随机数据增强,得到k个增强后的无标签数据集k=1,...,k,将所有d
u
的k个的并集记为将每个无标签数据d
u
的k个分别输入本发明实施例所述的图像分类方法对应的图像分类模型,得到k个预测类别,对所述k个预测类别取平均,将得到的平均值作为d
u
的伪标签;
[0022]
s03:将输入所述图像分类模型,得到中的每个数据的预测类别概率;利用中所有数据的预测类别概率和真实类别,计算交叉熵损失;
[0023]
s04:将输入所述图像分类模型,得到中的每个数据的预测类别概率,将所述预测类别概率中的最大概率值对应的类别作为的预测类别;利用中的所有数据的预测类别和伪标签,计算一致性损失;
[0024]
s05:将所述交叉熵损失和所述一致性损失的加权和作为本轮训练的总损失,对所述图像分类模型中的网络参数进行训练,其中,所述网络参数包括:所述线性层的参数、所述编码器的参数和所述分类器的参数;
[0025]
s06:返回s01,直到满足设定的终止条件,保存训练过程中总损失最小时的网络参数,将对应的图像分类模型作为训练好的图像分类模型。
[0026]
在一实施例中,所述随机数据增强包括图像位移、改变图像的亮度、改变图像的对比度和改变图像的饱和度中的至少一种方式的随机组合,其中,图像的位移、图像的亮度、图像的对比度和图像的饱和度的改变值均为预设范围内的随机数。
[0027]
在一实施例中,s03中,利用中所有数据的预测类别概率和真实类别,计算交叉熵损失,包括:
[0028]
根据公式(1),利用中所有数据的预测类别概率和真实类别,计算所述交叉熵损失loss
l

[0029][0030]
其中,n表示中的数据的个数,表示的真实类别,p
l,a
表示所述图像分类模型预测得到的的类别为的概率。
[0031]
在一实施例中,s04中,利用中的所有数据的预测类别和伪标签,计算一致性损失,包括:
[0032]
根据公式(2),利用中的所有数据的预测类别和伪标签,计算一致性损失loss
u

[0033][0034]
其中,m表示中的数据的个数,ω(
·
)表示坡度函数,t表示全局迭代次数,y
u,k,b
表示的预测类别,表示的伪标签。
[0035]
在一实施例中,在s01之前,所述训练方法还包括:
[0036]
s011:对所述图像分类模型进行初始化,利用大数据集对初始化的模型进行预训练,得到源模型;
[0037]
s012:复制所述源模型的中的transformer模型的编码器的参数,并初始化所述transformer的分类器的参数,得到中间模型;
[0038]
在s02中,将每个无标签数据d
u
的k个分别输入本发明实施例所述的图像分类方法对应的图像分类模型,包括:
[0039]
将所述k个分别输入所述中间模型。
[0040]
第三方面,本发明实施例还提供了一种计算机设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现上述实施例所述的图像分类方法,或实现所述实施例所述的图像分类模型的训练方法。
[0041]
本发明提出了一种基于transformer的半监督网络的图像分类方法与图像分类模型的训练方法。本发明具有如下有益效果:
[0042]
1.本发明针对图像分类领域的特殊性,利用注意力机制思想,将transformer模型引入到图像分类任务中,解决了传统深度学习模型提取图像的全局信息困难的问题,有效地关注图像的全局信息,同时更注重图像内容的连续性,从而提高了在图像分类的分类效果;
[0043]
2.本发明通过伪标签预测和consistency regularization的方式,解决了图像分类领域的有标记数据获取困难的问题,仅仅用少量的有标签数据即可完成深度学习训练过
程,实现了半监督网络学习,并且有良好学习效果;
[0044]
3.本发明设计了适用于图像数据的数据结构,在transformer模型的基础上增加了图像分块处理、可变(可学习)特征向量嵌入及图像位置信息编码操作,实现了transformer模型及自注意力机制进行图像分类中的应用;
[0045]
4.本发明采用基于图像的transformer的模型多次识别无标签数据,预测无标签数据的伪标签,并将现有的预测类别与伪标签进行对比,通过保证二者的一致性来约束网络模型,实现了从大量无标签数据中学习有益的信息;
[0046]
5.本发明将交叉熵损失与一致性损失联合起来对网络模型进行训练,通过交叉熵损失来实现有标签数据对网络模型的约束,通过一致性损失从无标签数据中提取有益的信息,实现了对训练数据的充分利用,在更全面的信息下也提高了网络的收敛速度和图像分类的准确性。
附图说明
[0047]
图1是本发明实施例提供的一种图像分类方法的流程图。
[0048]
图2是本发明实施例提供的一种图像分类模型的训练方法的流程图。
[0049]
图3是本发明实施例提供的一种图像分类模型的整个训练过程的流程图。
[0050]
图4是本发明实施例提供的另一种图像分类方法的流程图。
[0051]
图5为本发明实施例提供的一种计算机设备的结构示意图。
具体实施方式
[0052]
下面结合附图与实施例对本发明做进一步说明。在不冲突的情况下,本发明中的实施例及实施例中的特征可以相互组合。
[0053]
应该指出,以下详细说明都是示例性的,旨在对本发明提供进一步的说明。除非另有指明,本文使用的所有技术和科学术语具有与本发明所属技术领域的普通技术人员通常理解的相同含义。
[0054]
需要注意的是,这里所使用的术语仅是为了描述具体实施方式,而非意图限制根据本发明的示例性实施方式。如在这里所使用的,除非上下文另外明确指出,否则单数形式也意图包括复数形式,此外,还应当理解的是,术语“包括”和“具有”以及它们的任何变形,意图在于覆盖不排他的包含,例如,包含了一系列步骤或单元的过程、方法、系统、产品或设备不必限于清楚地列出的那些步骤或单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它步骤或单元。
[0055]
本发明实施例提出一种基于transformer的半监督网络的图像分类方法与相应的模型训练方法。
[0056]
由于某些应用图像领域(例如医学图像领域)样本数据繁杂、需要专业人员才能进行标注,导致样本数据标注代价巨大,无法轻易获取大量的标注数据。图像数据集中数据集标签少、特殊分类任务效果差。基于现有的少量标记样本数据和大量未标记的样本数据进行深度学习的半监督学习算法(semi

supervised learning,简称为“ssl”)可以利用仅仅一小部分的有标注数据就可以完成网络模型的训练。
[0057]
目前计算机视觉领域的transformer模型已经比肩甚至超越了传统的卷积神经网
络,达到了sota(state of the art,指在该项研究任务中,目前最好/最先进的技术)的水平。transformer模型具有cnn模型不具有的捕捉长期数据之间的依赖信息,容易获得全局图像的有效信息,是一种性能优异提取特征的模型。
[0058]
实施例一
[0059]
本实施例提出一种图像分类方法。图1是本发明实施例提供的一种图像分类方法的流程图。如图1所示,该方法包括s10

s30。
[0060]
s10:将待分类图像切分成多个patch,生成每个patch对应的patch向量;通过线性层对每patch向量进行降维,将多个降维后的patch向量进行拼接,得到第一序列向量;在所述第一序列向量的首部嵌入一个可变向量,得到第二序列向量,其中,所述可变向量与每个降维后的patch向量尺寸相同,且所述可变向量对应所述多个patch中最能代表所述待分类图像的特征的patch。
[0061]
可选地,将待分类图像切分成多个patch,并且铺平成一个序列,接着通过学习好的线性投影进行降维操作,减少维度,得到一个序列向量。最后在序列向量的首部嵌入一个同patch大小的向量,这个向量初始化是随机的,在预测阶段中是可变的,最终得到的patch向量是所有patch中最具有分类代表性的一个patch对应的向量。这个向量是专门用来做解码功能的,它与transformer的编码器对应。
[0062]
s20:初始化所述第二序列向量的位置编码向量,其中,所述位置编码向量中包含所述多个patch在所述待分类图像中的位置信息;将所述初始化后的位置编码向量嵌入到所述第二序列向量中,得到输入向量。
[0063]
s30:将所述输入向量输入到transformer模型的编码器,得到编码向量;取所述编码向量的首部的可变向量作为所述待分类图像的特征向量;将所述特征向量输入到所述transformer模型的分类器,得到所述待分类图像的预测类别概率。
[0064]
可选地,将处理后的图像数据向量输入transformer模型的编码器内得到编码后的结果,其中,编码后的结果同样为向量形式,同输入向量的维度一致。再取位置零处的patch向量作为整个图片的特征向量,并且输入到transformer模型的分类器,最后得到预测类别概率。“位置零处”的patch向量是指第二序列向量的首部位置的嵌入向量。
[0065]
在一实施例中,s10包括s110

s130。
[0066]
s110:将尺寸为h
×
w
×
c的所述待分类图像切分成m个尺寸为p
×
p
×
c的patch,其中,h和w分别表示所述待分类图像的高度和宽度,c表示所述待分类图像的通道数,p表示每个patch的宽度。
[0067]
s120:将每个patch展开成一个patch向量,通过所述线性层将每个patch向量降至d维,生成所述第一序列向量x1=[x1;x2;

;x
m
],其中,x
i
表示第i个patch的patch向量,i=1、2

m,表示维度为d的向量域。
[0068]
s130:在x1的首部嵌入所述可变向量x
class
,得到所述第二序列向量x2=[x
class
;x1;x2;

;x
m
],其中,
[0069]
由于transformer模块需要连续化输入,因此需要把输入图像进行切分。可选地,将图像切分成等大的正方形patch,使得一张原始图片大小从h
×
w
×
c切分成m个p
×
p
×
c大小的patch。其中h、w分别是图像的高度和宽度,c表示图像通道数,p表示正方形patch的宽
度,则m=wh/p2。可以根据具实际情况w和h的大小来选择p。可选地,将p选取为2的整数次幂。
[0070]
接着用一个线性层进行数据降至d维,减少无用特征输入。最后在输入向量的起始位置嵌入一个可变的特征向量,用于输出进行分类预测的依据x,即x[x
class
;x1;x2;

;x
n
],
[0071]
需要说明的是,这里的“连续化输入”是指transformer的输入需要满足连续化要求。例如,输入一句话时,其中的每个单词都具有关系性、连续性。在本发明中,将transformer模型用于图像任务,因此需要把一幅图像切分成多个patch后排列成一个序列,也就如同“输入一句话”一样,多个patch之间具有关联性和连续性。
[0072]
在一实施例中,s20包括:s21

s22。
[0073]
s21:初始化x
class
的位置编码向量p0,初始化x
i
的位置编码向量p
i
,其中,所述第二序列向量的位置编码向量p=[p0;p1;p2;

;p
m
],j=0、1、2

m,p
j
中包含p
j
对应的patch在所述待分类图像中的位置信息。
[0074]
s22:将p嵌入到x2中,得到所述输入向量x[x
class
+p0;x1+p1;x2+p2;

;x
m
+p
m
]。
[0075]
经过s10的分块操作,必然会丢失图像原本的位置信息,因此需要在输入向量中增加可学习的位置编码向量。弥补了丢失的位置信息,同时可学习的设定保证了获得价值最高的位置信息。patch的嵌入向量(即patch向量)和patch的位置编码向量一同作为transformer模型的编码器的输入,此过程可表示为x=[x
class
+p0;x1+p1;x2+p2;

;x
n
+p
n
]。
[0076]
需要说明的是,这里的“可学习的设定”在序列首部嵌入的可变向量在网络学习训练过程中会一直变化,根据注意力机制着重哪个patch,那么这个向量就会更新为这个patch对应的特征向量。
[0077]
在一实施例中,所述transformer模型包含所述编码器和所述分类器,不包含解码器。
[0078]
所述编码器包括串行排列的的msa和第一mlp,所述msa的输出为所述第一mlp的输入。所述msa与所述第一mlp的内部均采用残差连接方式。所述msa和所述第一mlp之前均连接有一个ln,待处理信号经过一个ln后再输入所述msa或所述第一mlp进行处理。所述分类器包括第二mlp。
[0079]
可选地,将s20得到的结果向量输入transformer模型中获得类别概率。transformer模型包括两部分:一是编码器,另一个是分类器,编码器与分类器串联。编码器包括msa和第一mlp,msa和第一mlp均是残差连接方式,msa的输出为第一mlp的输入。msa和第一mlp都经过ln进行图像通道归一化操作。分类器包括第二mlp。
[0080]
可选地,编码器部分负责提取图像的全局信息以及对任务有帮助的图像区域,而分类器负责根据图像特征进行分类获得其对应每类的概率值。整个transformer模型并没有设计解码器组件,由于s10中提前嵌入了一个可变向量,由此向量充当解码器组件,作用是选一个最有效分类的patch。这一设置使得整个模型更加简单、有效。
[0081]
本发明提出了一种基于transformer的半监督网络的图像分类方法。本发明具有如下有益效果:
[0082]
1.本发明实施例针对图像分类领域的特殊性,利用注意力机制思想,将
transformer模型引入到图像分类任务中,解决了传统深度学习模型提取图像的全局信息困难的问题,有效地关注图像的全局信息,同时更注重图像内容的连续性,从而提高了在图像分类的分类效果;
[0083]
2.本发明实施例利用训练好的模型进行图像分类,在模型的训练过程中,通过伪标签预测和consistency regularization的方式,解决了图像分类领域的有标记数据获取困难的问题,仅仅用少量的有标签数据即可完成深度学习训练过程,实现了半监督网络学习,并且有良好学习效果;
[0084]
3.本发明实施例设计了适用于图像数据的数据结构,在transformer模型的基础上增加了图像分块处理、可变(可学习)特征向量嵌入及图像位置信息编码操作,实现了transformer模型及自注意力机制进行图像分类中的应用;
[0085]
4.本发明实施例利用训练好的模型进行图像分类,在模型的训练过程中,采用基于图像的transformer的模型多次识别无标签数据,预测无标签数据的伪标签,并将现有的预测类别与伪标签进行对比,通过保证二者的一致性来约束网络模型,实现了从大量无标签数据中学习有益的信息;
[0086]
5.本发明实施例利用训练好的模型进行图像分类,在模型的训练过程中,将交叉熵损失与一致性损失联合起来对网络模型进行训练,通过交叉熵损失来实现有标签数据对网络模型的约束,通过一致性损失从无标签数据中提取有益的信息,实现了对训练数据的充分利用,在更全面的信息下也提高了网络的收敛速度和图像分类的准确性。
[0087]
实施例二
[0088]
本实施例提供一种图像分类模型的训练方法,用于对实施例一所述的图像分类方法所构成的图像分类模型进行训练。图2是本发明实施例提供的一种图像分类模型的训练方法的流程图。如图2所示,该方法包括步骤s01

s06。
[0089]
s01:获取一个训练数据集d,其中,所述训练数据集中包括有标签数据集d
l
和无标签数据集d
u
,每个训练数据为一幅训练图像,每个有标签数据d
l
的标签为d
l
的真实类别y
l

[0090]
s02:对每个有标签数据d
l
进行一次随机数据增强,得到增强后的有标签数据集对每个无标签数据d
u
进行k次随机数据增强,得到k个增强后的无标签数据集k=1,...,k,将所有d
u
的k个的并集记为将每个无标签数据d
u
的k个分别输入实施例一所述的图像分类方法对应的图像分类模型,最终得到k个预测类别,对所述k个预测类别取平均,将得到的平均值作为d
u
的伪标签。
[0091]
可选地,将所有的无标签数据进行随机数据增强,重复k次,然后把增强后的无标签数据输入模型中进行预测,得到k个预测类别,最后进行取平均操作作为无标签数据的伪标签。需要说明的是,由于在代码实现过程中已将类别数字化,因此取类别的平均值可以预测出无标签数据所述的类别。
[0092]
可选地,首先对原始数据集做数据增强处理。有标签数据集d
l
={d1,d2,

,d
n1
}(其中,n1表示有标签数据的数量)。无标签数据集为d
u
={d
n1+1
,d
n1+2
,

,d
n2
}(其中,n2

n1表示无标签数据的数量)。将数据集d
l
作一次随机数据增强操作,得到集合将数据集x
u
做k次随机数据增强操作,得到k个集合
k∈(1,...,k)。然后将输入所述图像分类模型的初始化网络中进行伪标签预测,得到k∈(1,...,k)。最后利用k次预测结果进行取平均得到最终的伪标签,即的伪标签,即需要说明的是,这里的“初始化网络”是指先利用大数据集对网络模型进行预训练后再用来作具体的分类任务。关于整个网络模型的训练阶段,将在后面进行详细描述。
[0093]
在一实施例中,所述随机数据增强包括图像位移、改变图像的亮度、改变图像的对比度和改变图像的饱和度中的至少一种方式的随机组合,其中,图像的位移、图像的亮度、图像的对比度和图像的饱和度的改变值均为预设范围内的随机数。
[0094]
对于基于伪标签和预测标签一致性实现半监督学习算法来说,随机数据增强的好坏很大程度上决定了算法的好坏。本发明针对图像领域数据集特点设计了合理的数据增强方法。
[0095]
s03:将输入所述图像分类模型,得到中的每个数据的预测类别概率;利用中所有数据的预测类别概率和真实类别,计算交叉熵损失。
[0096]
可选地,利用全部有标签数据的预测类别概率,计算其概率最大值所对应类别为预测类别,利用预测类别与真实标签类别进行交叉熵损失计算。交叉熵损失函数可以约束网络模型对有标记数据类别的预测与真实样本类别,使得网络模型输出更加逼近真实样本数据分布。
[0097]
在一实施例中,s03中,利用中所有数据的预测类别概率和真实类别,计算交叉熵损失,包括:根据公式(1),利用中所有数据的预测类别概率和真实类别,计算所述交叉熵损失loss
l

[0098][0099]
其中,n表示中的数据的个数,表示的真实类别,p
l,a
表示所述图像分类模型预测得到的的类别为的概率。
[0100]
s04:将输入所述图像分类模型,得到中的每个数据的预测类别概率,将所述预测类别概率中的概率最大值所对应的类别作为的预测类别;利用中的所有数据的预测类别和伪标签,计算一致性损失。
[0101]
可选地,利用全部无标签数据的预测类别概率,计算其概率最大值所对应类别为预测类别。将现在输出的预测结果(预测类别)与历史输出的预测结果(伪标签)做一致性损失计算;一致性损失函数可以约束网络模型对无标签数据类别的预测与历史输出的预测结果,使得它们尽量保持一致。由于同一数据的预测结果不变性,它们应当保持一致。基于此原理,可以挖掘无标签数据的有益信息,并且不需要已知标签信息。
[0102]
在一实施例中,s04中,利用中的所有数据的预测类别和伪标签,计算一致性
损失,包括:根据公式(2),利用中的所有数据的预测类别和伪标签,计算一致性损失loss
u

[0103][0104]
其中,m表示中的数据的个数,m=(n2

n1)
×
k,ω(
·
)表示坡度函数,t表示全局迭代次数,y
u,k,b
表示的预测类别,表示的伪标签。
[0105]
需要说明的是,交叉熵损失只能用有标签数据计算,因为它需要用到数据的真实标签信息。如果使用伪标签信息,那么会造成强噪声干扰,不利于模型训练。而一致性损失只用到了无标签数据结果,因为有标签数据的价值信息已经被交叉熵损失利用了,而无标签数据的伪标签信息还没有被利用。
[0106]
s05:将所述交叉熵损失和所述一致性损失的加权和作为本轮训练的总损失,对所述图像分类模型中的网络参数进行训练,其中,所述网络参数包括:所述线性层的参数、所述编码器的参数和所述分类器的参数。
[0107]
可选地,将交叉熵损失和一致性损失加权和做为总损失,不断进行训练,直到训练轮次达到设定值,保存其最小损失值时得网络模型。两种损失函数结合起来,可以同时使用有标签数据和无标签数据学习训练,同时得到一个批次内的有标签数据和无标签数据的有益信息,更正模型参数,为下一轮训练做准备。
[0108]
可选地,将交叉熵损失loss
l
和一致性损失loss
u
加权和做为总损失loss=loss
l
+λloss
u
(其中λ是超参数),不断进行训练,使得loss呈现下降趋势,直到训练轮次达到设定值或者loss呈现平稳趋势。
[0109]
s06:返回s01,直到满足设定的终止条件,保存训练过程中总损失最小时的网络参数,将对应的图像分类模型作为训练好的图像分类模型。
[0110]
在一实施例中,在s01之前,所述训练方法还包括:s011

s012。
[0111]
s011:对所述图像分类模型进行初始化,利用大数据集对初始化的模型进行预训练,得到源模型。
[0112]
s012:复制所述源模型的中的transformer模型的编码器的参数,并初始化所述transformer的分类器的参数,得到中间模型。
[0113]
这时,在s02中,将每个无标签数据d
u
的k个分别输入实施例一所述的图像分类方法对应的图像分类模型,包括:将所述k个分别输入所述中间模型。
[0114]
图3是本发明实施例提供的一种图像分类模型的整个训练过程的流程图。下面将结合图3,对整个图像分类模型的完整训练过程进行说明。模型的完整的训练过程需要经过初始化、预训练、复制、微调四个环节。
[0115]
首先对模型的中间层以及输出层的参数进行初始化,然后用大数据集进行模型的预训练,训练完成后获得源模型以及参数。接着复制源模型的中间层参数并且初始化输出层组成中间模型。最后用任务数据集对中间模型进行训练,微调中间层参数,学习目标输出层的参数,获得鲁棒性能优良的目标模型。
[0116]
在本发明实施例中,模型的中间层包括:线性层和整个transformer模型的编码
器,输出层包括整个transformer的分类器。
[0117]
在预测无标签数据的伪标签的过程中,用到的是整个图像分类模型的中间模型。
[0118]
整个图像分类模型的流程可概括如下:(1)用户输入待测试图像数据进入分类系统,(2)分类系统内部自动进行图像分块处理、获取类别概率和确定预测类别三个过程,(3)输出预测类别与用户进行交互。
[0119]
本发明提出了一种基于transformer的半监督网络的图像分类模型的训练方法。本发明具有如下有益效果:
[0120]
1.本发明实施例针对图像分类领域的特殊性,利用注意力机制思想,将transformer模型引入到图像分类任务中,解决了传统深度学习模型提取图像的全局信息困难的问题,有效地关注图像的全局信息,同时更注重图像内容的连续性,从而提高了在图像分类的分类效果;
[0121]
2.本发明实施例通过伪标签预测和consistency regularization的方式,解决了图像分类领域的有标记数据获取困难的问题,仅仅用少量的有标签数据即可完成深度学习训练过程,实现了半监督网络学习,并且有良好学习效果;
[0122]
3.本发明实施例设计了适用于图像数据的数据结构,在transformer模型的基础上增加了图像分块处理、可变(可学习)特征向量嵌入及图像位置信息编码操作,实现了transformer模型及自注意力机制进行图像分类中的应用;
[0123]
4.本发明实施例采用基于图像的transformer的模型多次识别无标签数据,预测无标签数据的伪标签,并将现有的预测类别与伪标签进行对比,通过保证二者的一致性来约束网络模型,实现了从大量无标签数据中学习有益的信息;
[0124]
5.本发明实施例将交叉熵损失与一致性损失联合起来对网络模型进行训练,通过交叉熵损失来实现有标签数据对网络模型的约束,通过一致性损失从无标签数据中提取有益的信息,实现了对训练数据的充分利用,在更全面的信息下也提高了网络的收敛速度和图像分类的准确性。
[0125]
实施例三
[0126]
图4是本发明实施例提供的另一种图像分类方法的流程图。该方法基于transformer的半监督算法实现图像分类的网络学习过程,包括训练阶段和预测阶段。如图4所示,该方法包括s1

s8。
[0127]
s1:预测伪标签。首先将所有的无标签数据进行随机数据增强,重复k次,然后把增强后的无标签数据输入模型中进行预测,得到k个伪标签,最后进行取平均操作作为无标签数据的伪标签。
[0128]
s2:图像分块处理。将输入的图像切分成多个patch,并且铺平成一个序列,接着通过可学习的线性投影进行降维操作。最后在所有patch对应的序列向量的首部嵌入一个同patch大小的向量(简称为“patch嵌入向量”)。这个向量初始化是随机的,在训练过程中可学习(即是可变的)。这个向量是专门用来做解码功能,它与编码器对应,学习得到的patch嵌入向量是所有patch中最具有分类代表性的一个。
[0129]
s3:嵌入位置编码。初始化位置编码向量加入到图像分块处理操作后的序列向量,一同作为输入向量。
[0130]
s4:获取类别概率:将处理后的图像数据向量输入transformer模型的编码器内得
到编码后的结果,其中,编码后的结果同样为向量形式,同输入向量的维度一致。再取位置零处的patch嵌入向量作为整个图片的特征向量,并且输入到transformer模型的分类器,最后得到预测类别概率。
[0131]
s5:计算交叉熵损失。利用全部有标签数据的预测类别概率,其概率最大值所对应类别为预测类别。利用预测类别与真实标签类别进行交叉熵损失计算。交叉熵损失函数可以约束网络模型对有标记数据类别的预测与真实样本类别,使得网络模型输出更加逼近真实样本数据分布。
[0132]
s6:计算一致性损失。利用全部无标签数据的预测类别概率,其概率最大值所对应类别为预测类别。将现在输出的预测结果与历史输出的预测结果(伪标签)做一致性损失计算。一致性损失函数可以约束网络模型对无标签数据类别的预测与历史输出的预测结果,使得它们尽量保持一致。由于同一数据的预测结果不变性,它们应当保持一致。基于此原理,可以挖掘无标签数据的有益信息,并且不需要已知标签信息。
[0133]
s7:联合训练。将交叉熵损失和一致性损失加权和做为总损失,不断进行训练,直到训练轮次达到设定值。保存其最小损失值时得网络模型。两种损失函数结合起来,可以同时使用有标签数据和无标签数据学习训练,同时得到一个批次内的有标签数据和无标签数据的有益信息,更正模型参数,为下一轮训练做准备。
[0134]
s8:预测类别。利用训练好得网络模型对输入的图像数据进行预测,得到预测类别概率,将最大概率值对应的类别确定为预测结果。
[0135]
在上述方法中,s1和s7属于训练阶段,s8属于预测阶段。在预测阶段,将图像输入到训练好的网络模型后,在网络模型中只执行s2

s4。
[0136]
在一实施例中,在s1:预测伪标签的步骤中,首先,对原始数据集做数据增强处理。有标签数据集d
l
={d1,d2,

,d
n1
}(其中,n1表示有标签数据的数量)。无标签数据集为d
u
={d
n1+1
,d
n1+2
,

,d
n2
}(其中,n2

n1表示无标签数据的数量)。将数据集d
l
作一次随机数据增强操作,得到集合操作,得到集合将数据集x
u
做k次随机数据增强操作,得到k个集合做k次随机数据增强操作,得到k个集合k∈(1,...,k)。然后将输入所述图像分类模型的初始化网络中进行伪标签预测,得到k∈(1,...,k)。最后利用k次预测结果进行取平均得到最终的伪标签,即其中,“初始化网络”是指图3中的中间模型,即先用大数据集预训练后再用来处理具体的分类任务。
[0137]
对于基于伪标签和预测标签一致性实现半监督学习算法来说,随机数据增强的好坏很大程度上决定了算法的好坏。本发明实施例针对图像领域数据集的特点设计了合理的数据增强方法。
[0138]
随机数据增强包括图像的位移、改变图像的亮度、改变图像的对比度、改变图像的饱和度四种方式中的至少一种随机组合。其中,图像的位移、图像的亮度、图像的对比度、图像的饱和度的改变值全部采用一定范围内的随机数。
[0139]
在s2:图像分块处理的步骤中,由于transformer模型需要连续化输入,因此需要把输入图像进行切分,切分成等大的正方形patch,使得一张原始图片大小从h
×
w
×
c切分成m个p
×
p
×
c大小的patch。其中,h、w分别表示图像的高度和宽度,c表示图像的通道数,p
表示正方形patch的宽度,则m=wh/p2。可以根据具实际情况中的w和h的大小来选择p,一般p是2的整数次幂。接着用一个线性层将数据降至d维,减少无用特征输入。最后在输入向量的起始位置嵌入一个可学习(即可变的)的特征向量,用于输出进行分类预测的依据x,即x=[x
class
;x1;x2;

;x
n
],其中,“连续化输入”是指transformer用需要输入满足连续化要求,比如输入一句话,其中的每个单词都具有关系性、连续性。在本发明中,将transformer模型用于图像任务,因此需要把一幅图像切分成多个patch后排列成一个序列,也就如同“输入一句话”一样,多个patch之间具有关联性和连续性。
[0140]
在s2中,输入图像是指全部数据集,也就是包括了数据增强后的有标签数据和无标签数据
[0141]
在s3:嵌入位置编码的步骤中,经过s2的分块操作,必然会丢失图像原本的位置信息,因此需要在输入向量中增加可学习的位置编码向量。弥补了丢失的位置信息,同时,可学习的设定保证了获得价值最高的位置信息。patch的嵌入向量(即patch向量)和patch的位置编码向量一同作为transformer模型的编码器的输入,此过程可表示为x=[x
class
+p0;x1+p1;x2+p2;

;x
n
+p
n
]。“可学习的设定”是指这个向量在网络学习训练的过程中会一直变化,根据注意力机制着重哪个patch,那么这个向量就会更新为这个patch对应的特征向量。
[0142]
在s4:获取类别概率的步骤中,将s3的结果向量输入transformer模型中获得类别概率。transformer模型包括两部分:一是编码器,另一个是分类器,编码器与分类器串联。编码器包括msa和第一mlp,msa和第一mlp均是残差连接方式,msa的输出为第一mlp的输入。msa和第一mlp都经过ln进行图像通道归一化操作。分类器包括第二mlp。
[0143]
编码器部分负责提取图像的全局信息以及对任务有帮助的图像区域,而分类器负责根据图像特征进行分类获得其对应每类的概率值。整个transformer模型并没有设计解码器组件,由于s2中提前嵌入了一个可变向量,由此向量充当解码器组件,作用是选一个最有效分类的patch。这一设置使得整个模型更加简单、有效。
[0144]
在s5:计算交叉熵损失的步骤中,利用全部随机数据增强后的有标签数据(中的数据)的预测概率p
l
与真实标签类别计算交叉熵损失loss
l

[0145][0146]
其中,n表示中的数据的个数,表示的真实类别,p
l,a
表示所述图像分类模型预测得到的的类别为的概率。
[0147]
在s6:计算一致性损失的步骤中,利用全部随机数据增强后无标签数据(中的数据)的预测类别y
u,k
,与s1步骤的伪标签预测结果做一致性损失计算:
[0148][0149]
其中,m表示中的数据的个数,m=(n2

n1)
×
k,ω(
·
)表示坡度函数,t
表示全局迭代次数,y
u,k,b
表示的预测类别,表示的伪标签。
[0150]
交叉熵损失只能用有标签数据计算,因为它需要用到数据的真实标签信息。如果使用伪标签信息,那么会造成强噪声干扰,不利于模型训练。而一致性损失只用到了无标签数据结果,因为有标签数据的价值信息已经被交叉熵损失利用了,而无标签数据的伪标签信息还没有被利用。
[0151]
在s7:联合训练的步骤中,将交叉熵损失loss
l
和一致性损失loss
u
加权和做为总损失loss=loss
l
+λloss
u
(其中λ是超参数),不断进行训练,使得loss呈现下降趋势,直到训练轮次达到设定值或者loss呈现平稳趋势。保存其最小损失值时得网络模型。
[0152]
在s8:预测类别的步骤中,将待分类的图像数据输入已训练好的网络模型中进行预测,得到类别概率,将概率值最大的类别作为预测结果。
[0153]
需要说明的是,在利用训练好的模型进行预测时,patch向量是可变的,通过tranformer模型最终更新为待分类的图像中最具有分类代表性的一个patch对应的特征向量;同时,在对模型进行训练时,patch向量也是可学习的,在整个训练过程中不断更新,这一更新不但包括在tranformer模型中的更新,还包括在训练过程中因网络梯度下降而获得的更新。
[0154]
本发明提出了一种基于transformer的半监督网络的图像分类模型的训练方法。本发明具有如下有益效果:
[0155]
1.本发明实施例针对图像分类领域的特殊性,利用注意力机制思想,将transformer模型引入到图像分类任务中,解决了传统深度学习模型提取图像的全局信息困难的问题,有效地关注图像的全局信息,同时更注重图像内容的连续性,从而提高了在图像分类的分类效果;
[0156]
2.本发明实施例通过伪标签预测和consistency regularization的方式,解决了图像分类领域的有标记数据获取困难的问题,仅仅用少量的有标签数据即可完成深度学习训练过程,实现了半监督网络学习,并且有良好学习效果;
[0157]
3.本发明实施例设计了适用于图像数据的数据结构,在transformer模型的基础上增加了图像分块处理、可学习特征向量嵌入及图像位置信息编码操作,实现了transformer模型及自注意力机制进行图像分类中的应用;
[0158]
4.本发明实施例采用基于图像的transformer的模型多次识别无标签数据,预测无标签数据的伪标签,并将现有的预测类别与伪标签进行对比,通过保证二者的一致性来约束网络模型,实现了从大量无标签数据中学习有益的信息;
[0159]
5.本发明实施例将交叉熵损失与一致性损失联合起来对网络模型进行训练,通过交叉熵损失来实现有标签数据对网络模型的约束,通过一致性损失从无标签数据中提取有益的信息,实现了对训练数据的充分利用,在更全面的信息下也提高了网络的收敛速度和图像分类的准确性。
[0160]
实施例四
[0161]
图5为本发明实施例提供的一种计算机设备的结构示意图。如图5所示,该设备包括处理器510和存储器520。处理器510的数量可以是一个或多个,图5中以一个处理器510为例。
[0162]
存储器520作为一种计算机可读存储介质,可用于存储软件程序、计算机可执行程序以及模块,如本发明实施例一、三所述的图像分类方法的程序指令/模块,或实施例二所述的图像分类模型的训练方法的程序指令/模块。
[0163]
相应地,处理器510通过运行存储在存储器520中的软件程序、指令以及模块,实现本发明实施例一、三所述的图像分类方法,或实施例二所述的图像分类模型的训练方法。
[0164]
存储器520可主要包括存储程序区和存储数据区,其中,存储程序区可存储操作系统、至少一个功能所需的应用程序;存储数据区可存储根据终端的使用所创建的数据等。此外,存储器520可以包括高速随机存取存储器,还可以包括非易失性存储器,例如至少一个磁盘存储器件、闪存器件、或其他非易失性固态存储器件。在一些实例中,存储器520可进一步包括相对于处理器510远程设置的存储器,这些远程存储器可以通过网络连接至设备/终端/服务器。上述网络的实例包括但不限于互联网、企业内部网、局域网、移动通信网及其组合。
[0165]
本领域内的技术人员应明白,本发明的实施例可提供为方法、系统、或计算机程序产品。因此,本发明可采用硬件实施例、软件实施例、或结合软件和硬件方面的实施例的形式。而且,本发明可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器和光学存储器等)上实施的计算机程序产品的形式。
[0166]
以上所述仅为本发明的优选实施例而已,并不用于限制本发明,对于本领域的技术人员来说,本发明可以有各种更改和变化。凡在本发明的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。
当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1