一种基于元模块融合增量学习的图像分类方法

文档序号:31959419发布日期:2022-10-28 23:04阅读:107来源:国知局
一种基于元模块融合增量学习的图像分类方法

1.本发明涉及图像识别技术领域,通过有限网络模型扩张与分类器重训练,实现少量样本数据驱动的增量学习的图像识别。


背景技术:

2.近年来,神经网络模型已经在很多机器学习领域取得了巨大成功,如图像识别、目标检测、自然语言处理、姿态估计等。但目前神经网络模型依然有很多不足,灾难性遗忘即是一个亟待解决的重要问题。增量学习能够解决网络学习中灾难性遗忘问题
3.不遗忘学习方法(lwf)是首次把知识蒸馏的思路应用到增量学习中,仅利用现有新样本就可以在学习新任务的同时,对旧任务保持记忆。但由于完全没有使用旧的类别样本,随着类的不断增加,整体准确率也急剧下降。增量分类器和特征重表达学习(icarl)是最经典的基于样本回放的增量学习模型,它在算法层面借鉴保留了前例中的蒸馏技术。同时采用特征提取器和分类器分离方法,并在固定内存规模的情况下,通过筛选出具有代表性的旧样本和新样本组成新的训练集,实现增量学习,因此较前者在准确率上有所提升,代价便是增大了内存容量。基于空间蒸馏损失的方法(podnet),改进了特征的蒸馏方法,并将分类器与代理向量相结合,改进了分类器的损失形式,取得了不错的效果。小样本增量学习方法(fscil)沿用了特征提取器和分类器分开的思路。用拓扑关系来模拟特征空间上的关系,将特征提取后的特征空间上的位置做为神经气体网络的输入,以此输入分类器来分类。该方法在解决小样本增量学习问题上取得巨大成功。
4.最新的动态扩展重表达方法(der)通过模型结构扩展的方式为每一个新任务训练一个特征提取器。在每个增量任务时候对特征进行扩展,都将上一个阶段提取出的特征进行固定,并且运用新的特征提取器再对特征进行提取。这就使得模型在保持旧任务知识的同时可以获得适用于新增量任务的新知识。但由于该方法在训练网络时把每次的增量类别都整合到同一个网络中去,这就造成了当网络长期处于增量阶段时,类别间分类精度的相互干扰,网络增量越多,准确率下降越快。


技术实现要素:

5.为了克服现有技术的不足,本发明提供一种基于元模块融合增量学习的图像分类方法,能够使网络在长期增量阶段时,延缓精度下降,能有效的改善动态扩展重表达方法中存在的弊端。大量实验表明,本发明显著改进了识别精度。基于vgg网络和resnet,在cub、cifar-100和miniimagenet,
6.本发明解决其技术问题所采用的技术方案包括以下步骤:
7.步骤一、获取待分类图片,形成待分类图片集,设定每次增量学习的新添类别数量为k,增量学习的增长步数为t,最大增长步数为l
max

8.其中,数据集d代表总的图像数据集,n代表图像类别,dn代表第n类图像的数据集,样本总数为s,(xs,ys)代表样本输入以及对应
标签,k代表每次增量学习的新添类别数量,t代表增量学习的增长步数;
9.步骤二、依次将步骤一中分类图片集输入至增量分类神经网络vgg网络或resnet,训练元模型;
10.步骤三、在元模型的分类器后添加修正器,利用保留数据训练修正器。
11.在每次增量学习之后扩展修正器的输出,并重训练修正器;
12.步骤3.1、训练修正器;对步骤二中所有训练过的数据集进行抽样得到保留数据集dr,在网络的分类层后添加一层全连接层fc做为修正器c,利用保留数据训练该fc层参数,训练方式采用交叉熵损失函数训练w个epoch,学习率从λ开始;
13.步骤3.2、如果训练步数t小于最大增长步数l
max
,则返回步骤二进行增量学习的元模块训练,即步骤二中的步骤2.2,如果训练步数大于等于最大增长步数l
max
,则完成元模型mi的训练;
14.步骤四、训练门控选择层,实现对元模型的融合;
15.步骤4.1、重复步骤二,直到训练完所有数据得到多个元模块每个元模块mi中包含l
max
次增量学习,总的增量学习次数是t,得到的元模块数量是
16.步骤4.2、维持已训练网络模型特征提取层参数不变,在θu对应的特征提取层后添加门控分类层g;
17.步骤4.3、在总的数据集中抽取部分样本组成新的保留数据集dr训练新添的门控分类层g,训练损失函数为pi=η-mi,其中η表示输出向量,mi表示所有输出向量的平均,n表示训练元门控分类层g时输入样本数量;
18.步骤4.4、在测试阶段,输入图像依据门控分类层g的输出结果,选择对应的元模块,经过元模块的分类层,确定对应的具体图像类别。
19.所述步骤二中,增量分类神经网络训练元模型的具体步骤如下:
20.步骤2.1、训练初始网络;选择步骤一的数据输入初始神经网络vgg或resnet中,神经网络采用随机初始化,采用交叉熵损失函数训练w个epoch,学习率从λ开始,得到神经网络特征提取层的参数θf=[θu,θs]和分类层参数θc;
[0021]
步骤2.2、训练增量网络;保持初始网络特征提取层的后1/2层结构不变,θs是神经网络特征提取层的后1/2层的参数,选择新的增量类别扩展初始网络结构的前部分,即θu对应的特征提取层,利用新增数据训练新扩展层的参数,训练方式采用交叉熵损失函数训练w个epoch,学习率从λ开始。
[0022]
所述epoch的w取值为小于等于100。
[0023]
所述学习率λ取值为0.01。
[0024]
本发明的有益效果在于通过提供一种基于元模块融合增量学习的图像分类方法,解决了现有的增量学习中图像分类精度下降过快的算法问题。通过将多个元模型融合的方式实现增量学习能有效的减少参数增长速度,延缓灾难性遗忘问题,保持分类精度在可靠范围内。相比与现有的方法能够在内存规模,网络模型规模,分类精度上达到一个较合适的平衡点。
[0025]
相比于精度相当的算法例如,动态扩展重表达方法,能在内存规模和网络模型上形成优势,相比与内存规模相当的算法例如,增量分类器和特征重表达学习,能在精度和计算速断上形成优势。总的来说,本发明能在保持精度较高的情况下,实现在内存规模,网络增长规模和计算速度等多个上面的优势。
附图说明
[0026]
图1为本发明总体的算法实现步骤图。
[0027]
图2为元模型网络融合训练过程示意图。
具体实施方式
[0028]
下面结合附图和实施例对本发明进一步说明。
[0029]
为使本发明实施例的目的、技术方案和优点更加清楚,下面将结合附图对本发明的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
[0030]
目前已经成熟的增量学习分类算法无法兼顾分类准确率,存储成本和计算成本这三个要求,通常分类准确率高的算法的存储成本和计算成本往往高,而分类准确率低的算法成本往往也会随之降低;但在实际需求中,我们追求的往往是分类准确率高、并且存储成本和计算成本还低,对此本技术提出如下技术方案。为便于对本实施例进行理解,首先对本发明实施例所公开的基于元模块融合增量学习的图像分类算法进行详细介绍。
[0031]
图1所展示的总的增量学习图像分类算法的流程示意图,该算法包括如下步骤:
[0032]
第一步是元模型训练:将待学习的图像依次输入网络,再网络规模不断增长中学习新的特征,保留旧的特征;
[0033]
1)cifar-100数据集包含100个不同图像类别,即n=100;将其分成5个大类,即i=5;每大类中包含了20类小类。每小类中含有该类数据的数量是相同的,均为500张32
×
32的彩色图像,即s=500。
[0034]
2)选择其中的一个大类作为每个元模型训练的数据,将大类中包含的20类小类分为4组,即l
max
=4;每组5类小类,即k=5。通过依次增量输入的方式在vgg网络和resnet上训练元模型;
[0035]
具体参数选择如下:
[0036]
梯度下降算法选择:sgd
[0037]
批量大小:128
[0038]
学习率:0.01。并且学习率在30、60、和90个epoch后的以0.1的速率开始衰减正则系数λ:0.75
[0039]
3)第一组的训练过程如上,依次第二,第三,第四,第五组都是直接迁移使用第一组训练的特征提取层的前段参数。在本例中使用了vgg-16特征提取的后6层的卷积层作为共享参数;
[0040]
4)将剩下的四个大类依次按照步骤2)和步骤3)中的描述训练生成元模型;
[0041]
第二步训练修正器:在已训练的元模块的分类器后添加一层全连接层(fc),随机初始化全连接层的网络参数并保持特征提取器和分类器中参数不变。对每一类数据保留部分样本,这里选择每类数据集训练样本数的1/10。利用交叉熵损失函数作为损失函数,并保持特征提取器和分类器中参数不变;
[0042]
具体参数如下:梯度下降算法选择sgd,批量大小为256,学习率为0.01;
[0043]
第三步是元模型融合:将vgg-16网络训练的五个元模型融合,整合成一个完全的增量学习网络,图2为元模型融合训练过程,其中m
old
表示已有的元模块,m
new
表示新训练得到的元模块,g表示门控选择层,该层能将新旧元模块关联起来,在测试阶段以便选择合适的元模块。
[0044]
3.1)每个小类中利用特征分布最近中心原则选择50张图片作为保留图像;
[0045]
3.2)设计一个具有三个特征提取层和一个分类层的旁支网络,并连接到每个元模型前7层的特征提取层后,利用空间特征分布距离作为损失函数,训练门控分类层;
[0046]
3.3)门控分类层的作用在于融合各个独立的元模块,在测试阶段,输入图像能根据其分类结果选择对应的元模块,再根据元模块的分类结果确定其类别;
[0047]
以上所述实施例,仅为本发明的具体实施方式,用以说明本发明的技术方案,而非对其限制,本发明的保护范围并不局限于此,尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:任何熟悉本技术领域的技术人员在本发明揭露的技术范围内,其依然可以对前述实施例所记载的技术方案进行修改或可轻易想到变化,或者对其中部分技术特征进行等同替换;而这些修改、变化或者替换,并不使相应技术方案的本质脱离本发明实施例技术方案的精神和范围,都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应所述以权利要求的保护范围为准。
当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1