一种图像分类残差神经网络训练实现方法与流程

文档序号:34361398发布日期:2023-06-04 17:16阅读:47来源:国知局
一种图像分类残差神经网络训练实现方法与流程

本发明提出的是一种图像分类残差神经网络训练实现方法,属于图像分类及神经网络。


背景技术:

1、随着人工智能的发展与进步,深度神经网络算法获得了广泛的应用,在各种领域取得显著的成果,并且被广泛应用于人工智能、自动控制、机器人、统计学等各个领域的信息处理中;以图像分类为例,如vgg网络、resnet网络、resnext网络等经典图像分类网络,对特定的图像分类任务使用特定的数据集对图像分类网络进行训练,在普通的训练方式下,分类精度的高低,取决于图像分类网络自身结构。

2、为追求更好地分类性能,通过增加网络深度可以较好提高网络性能;虽然深层网络比浅层网络性能更优,但当网络深度过度增加时,会引发神经网络退化以及梯度消失等问题;2015年,由kaiming he等人提出的resnet网络,通过堆叠残差结构,加深网络深度的同时,梯度消失的难题也得到了缓解,进一步提升了神经网络在图像分类任务上的表现;目前,已有更大的残差神经网络(resnet1000)被提出,提高的性能极其有限,通常以海量的计算和存储为代价;对于实时性要求较高且资源受限的移动端设备或嵌入式设备,想实现大型神经网络的应用异常困难;因此,如何能够在不增加网络深度条件下,提升resnet网络精度就显得非常重要。


技术实现思路

1、本发明提出的是一种图像分类残差神经网络训练实现方法,其目的旨在不增加网络深度条件下提升resnet网络的整体准确率。

2、本发明的技术解决方案:一种图像分类残差神经网络训练实现方法,该方法包括:

3、1、在resnet网络内部引入若干个分类器;所述resnet网络包括第一卷积层、最大池化层、第一模块、第二模块、第三模块、第四模块、第一自适应平均池化层、第四全连接层、第四概率转换函数;

4、2、使用图像数据集对若干个分类器进行联合训练;

5、3、通过真实标签(label)对resnet网络和若干个分类器进行监督以及通过resnet网络预测结果对resnet网络进行监督,得到联合训练的总损失;

6、4、在总损失作用下,更新resnet网络权重。

7、进一步地,所述在resnet网络内部引入若干个分类器,具体包括:在第一模块的基础上依次增加第一浅层模块、第一全连接层、第一概率转换函数作为第一分类器;在第二模块的基础上依次增加第二浅层模块、第二全连接层、第二概率转换函数作为第二分类器;在第三模块的基础上依次增加第三浅层模块、第三全连接层、第三概率转换函数作为第三分类器。

8、进一步地,所述使用图像数据集对若干个分类器进行联合训练,具体包括:

9、2-1、将图像数据集的训练集均分为a部分训练集和b部分训练集;

10、2-2、b部分训练集中的图像首先经过第一卷积层、最大池化层、第一模块获得第一模块的输出特征图;

11、2-3、将第一模块的输出特征图作为第二模块的输入特征图和第一浅层模块的输入特征图;

12、2-4、第一浅层模块的输入特征图依次经过第一浅层模块、第一全连接层、第一概率转换函数后获得第一分类器的输出结果,第一分类器的输出结果即为第一分类结果;

13、2-5、第二模块的输入特征图经过第二模块后获得第二模块的输出特征图;

14、2-6、将第二模块的输出特征图作为第三模块的输入特征图和第二浅层模块的输入特征图;

15、2-7、第二浅层模块的输入特征图依次经过第二浅层模块、第二全连接层、第二概率转换函数后获得第二分类器的输出结果,第二分类器的输出结果即为第二分类结果;

16、2-8、第三模块的输入特征图经过第三模块后获得第三模块的输出特征图;

17、2-9、将第三模块的输出特征图作为第四模块的输入特征图和第三浅层模块的输入特征图;

18、2-10、第三浅层模块的输入特征图依次经过第三浅层模块、第三全连接层、第三概率转换函数后获得第三分类器的输出结果,第三分类器的输出结果即为第三分类结果;

19、2-11、第四模块的输入特征图经过第四模块后获得第四模块的输出特征图;

20、2-12、第四模块的输出特征图再依次经过第一自适应平均池化层、第四全连接层、第四概率转换函数后获得resnet网络的输出结果,resnet网络的输出结果即为第四分类结果。

21、进一步地,所述使用图像数据集对若干个分类器进行联合训练,具体还包括:

22、2-13、a部分训练集中的图像依次经过第一卷积层、最大池化层、第一模块、第二模块、第三模块、第四模块、第一自适应平均池化层、第四全连接层、第四概率转换函数获得resnet网络的预测结果;resnet网络的预测结果即resnet网络预测的分类结果。

23、进一步地,所述图像数据集包括训练集和测试集。

24、进一步地,所述第一模块、第二模块、第三模块、第四模块各自均包括若干卷积层(conv);第一模块、第二模块、第三模块、第四模块中每个模块的输入特征图进入相应模块后均通过相应模块内部的卷积层进行处理获得相应模块的输出特征图。

25、进一步地,所述第一浅层模块、第二浅层模块、第三浅层模块各自均包括若干卷积层和一个自适应平均池化层;所述第一浅层模块、第二浅层模块、第三浅层模块中每个浅层模块通过各自内部的卷积层和自适应平均池化层对经过的图像进行处理。

26、进一步地,所述第一概率转换函数、第二概率转换函数、第三概率转换函数均为softmax函数;所述第四概率转换函数在softmax函数中额外引入参数temp,如式(1)所示:

27、

28、式(1)中,zj指第j类别的输出结果,若分类类别总数为m,则j的取值范围为[1,m];zn为图像经过resnet网络之后第n类别的输出结果,pron为resnet网络输出第n类别的输出概率,temp是额外引入的参数,temp的取值要大于0;所述类别指图像数据集的m个类别。

29、进一步地,所述通过真实标签对resnet网络和若干个分类器进行监督以及通过resnet网络预测结果对resnet网络进行监督,得到联合训练的总损失,具体包括:

30、3-1、将第一分类结果与真实标签进行比较得到第一误差损失值,将第二分类结果与真实标签进行比较得到第二误差损失值,将第三分类结果与真实标签进行比较得到第三误差损失值,将第四分类结果与真实标签进行比较得到第四误差损失值;

31、3-2、将第四分类结果与resnet网络的预测结果进行比较得到第五误差损失值;

32、3-3、利用公式(2)得到联合训练的总损失;所述公式(2)具体如下:

33、

34、公式(2)中loss为总损失,loss1为第一误差损失值,loss2为第二误差损失值,loss3为第三误差损失值,loss4为第四误差损失值,lossn为第五误差损失值;qtn表示b部分训练集中的图像训练resnet网络时第四概率转换函数的输出,此处第四概率转换函数中temp=2;公式(2)中n表示对应于不同系列的resnet网络;tn表示a部分训练集中的图像训练resnet网络时第四概率转换函数的输出,此处第四概率转换函数中temp=2;cr表示交叉熵损失函数;kl表示kl散度;qi中i=1,2,3;i=1时qi表示b部分训练集中的图像训练resnet网络时第一分类器中第一概率转换函数的输出,i=2时qi表示b部分训练集中的图像训练resnet网络时第二分类器中第二概率转换函数的输出,i=3时qi表示b部分训练集中的图像训练resnet网络时第三分类器中第三概率转换函数的输出,q表示b部分训练集中的图像训练resnet网络时第四概率转换函数中temp=1时的输出;y表示b部分训练集中的图像的真实标签。

35、进一步地,所述在总损失作用下,更新resnet网络权重,具体包括:

36、4-1、利用图像数据集对resnet网络进行若干批次训练;每个批次训练均包括m轮训练,每个批次训练中具体训练的轮数m=训练集中图像的个数/每轮训练中处理的图像个数,每轮都会训练batchsize张图像,batchsize表示每轮训练中处理的图像个数;

37、4-2、每个批次训练中每轮训练时都利用总损失更新一次resnet网络权重,每个批次训练结束之后都会用图像数据集中的测试集进行测试,得到每个批次相应的测试准确率,直到若干批次训练结束,取所有批次中最高的测试准确率作为resnet网络最后的准确率。

38、本发明的有益效果:

39、1)本发明通过在resnet网络框架内部额外引入模块结构,构建多个分类器联合训练,相较于原训练方式,能够辅助提高resnet网络的整体准确率;

40、2)本发明通过仿真表明,在cifar100数据集上,整体准确率得到一定的提升;resnet18网络的准确率由77.6%提高到79.0%,提升幅度为1.4%;resnet50网络的准确率由77.1%提高到79.5%,提升幅度为2.4%;

41、3)本发明具有普适性,可广泛应用于基于resnet网络的图像分类、图像去噪、图像分割及图像超分辨率等任务。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1