一种基于增量学习的目标检测方法和装置

文档序号:26141458发布日期:2021-08-03 14:26阅读:118来源:国知局
一种基于增量学习的目标检测方法和装置

本发明属于目标识别领域,具体涉及一种基于增量学习的目标检测方法和装置。



背景技术:

传统的目标检测任务中,都是预先定义好要检测的物体类别并收集好相应数据后,对模型进行训练的。利用预先定义好的数据集训练好模型,并且部署到摄像头、卫星、无人机等终端上时,一旦遇到之前数据集中不存在的物体类别时,模型对新类别的检测效果就会很差。收集到原有数据集类别中的新样本时,模型也无法及时利用新样本进行更新。这些问题导致检测模型鲁棒性不高,无法成为自动化更高的系统。

针对上述问题,传统的解决方案是每当收集到新类别的数据时,就与旧类别的数据集合并为一个更大的数据集,对模型进行重新训练与部署。但是这样会导致不断重新对模型架构进行设计,训练时系统的存储要求更高,训练更加耗时,模型的部署周期更长,难度加大。

基于增量学习的目标检测系统可以不断利用新获取的样本对自身进行更新,而不需要重新包括旧的数据集,也不需要重新设计新的架构,这样可以有效减轻系统的存储负担与训练时间,更快捷地进行部署。利用增量学习方式获得的目标检测模型,在新获取到的类别上具有良好的检测效果,同时还能在原有类别的数据集上保留检测能力。因此,目标检测的增量学习研究已经成为新的研究热点。

但是,应用增量学习方法到传统的目标检测模型上时,会遇到灾难遗忘的问题,即在旧类别的数据上训练好的模型,利用新类别的样本对模型的参数进行微调时,模型在旧类别的检测效果就会急剧下降。

目前,针对灾难遗忘的研究多集中于物体分类,针对目标检测问题却鲜有研究。常用的具有较好检测效果的目标检测模型通常包括两个阶段,第一阶段生成目标候选区域(rpn网络),第二阶段对目标候选区域进行进一步地修正。

因此,如何针对具有生成目标候选区域和对目标候选区域进行进一步地修正这两阶段的目标检测模型引入增量学习方法,使目标检测模型不借助于旧类别的数据,仅利用新类别数据在新类别上获得好的检测效果,同时保留针对旧类别的检测能力,是当前亟待解决的问题。



技术实现要素:

鉴于上述,本发明的目的是提供一种基于增量学习的目标检测方法和装置,在不借助于旧类别样本的情况下,仅利用新类别样本进行目标检测模型的训练,以获得在新类别上具有良好检测效果且同时保留旧类别检测能力的目标检测模型。

第一方面,本发明实施例提供了一种基于增量学习的目标检测方法,包括以下步骤:

利用旧类别样本图像对目标检测网络进行训练得到原始模型;

在原始模型的输出层增加新类别样本图像的新类别检测分支,并初始化新类别检测分支参数,得到增量学习模型;

利用新类别样本图像训练增量学习模型,训练时,以新类别样本图像在增量学习模型的旧类别检测分支输出与在原始模型的预测输出的逼近误差、新类别样本图像在增量学习模型的新类别检测分支的检测误差构建损失函数,来优化训练增量学习模型参数,得到参数确定的目标检测模型;

利用目标检测模型对测试样本图像进行目标检测。

一个实施例中,所述目标检测网络采用fpn的网络,包括特征提取模块、rpn模块、cls模块,其中,特征提取模块用于提取输入样本图像的特征图,fpn模块用于根据输入的特征图生成感兴趣区域并进行分类输出和回归输出,cls模块用于对输入的感兴趣区域进一步修正,并进修正分类输出和回归输出;

在构建增量学习模型时,分别在fpn模块和cls模块的输出层增加新类别检测分支,同时保留旧类别检测分支。

一个实施例中,在训练增量学习模型时,首先进行增量学习模型的预训练阶段,具体包括:固定特征特征提取模块、rpn模块和cls模块的的旧类别检测分支不变,利用新类别样本图像优化rpn模块和cls模块的新类别检测分支直至收敛。

一个实施例中,在进行增量学习模型的初始化阶段,获得新类别样本图像在原始模型的预测输出,以构建损失函数,其中,预测输出包括分类输出和回归输出。

一个实施例中,训练增量学习模型时,依据预训练阶段确定的增量学习模型,获得新类别样本图像分别在rpn模块和cls模块的旧类别检测分支的预测输出,获得新类别样本图像分别在rpn模块和cls模块的新类别检测分支的预测输出,其中,包括分类输出和回归输出;

依据新类别样本图像在原始模型的预测输出、在增量学习模型的rpn模块和cls模块的旧类别检测分支的预测输出的逼近误差构建旧类别损失;

依据新类别样本图像在增量学习模型的rpn模块和cls模块的旧类别检测分支的预测输出与标签的检测误差构建新类别损失;

综合旧类别损失和新类别损失构建的总损失函数来优化增量学习模型的网络参数。

一个实施例中,构建的总损失函数loss为:

loss=λolossold+lossnew

其中,λo为平衡超参数,lossold为旧类别损失,具体为:

lossnew为新类别损失,具体为:

lossnew=sigmoid(y′n_rpn,yn)+smoothl1(b′n_rpn,bn)+softmax(y′n_cls,yn)+smoothl1(b′n_cls,bn)

其中,y′o_rpn、b′o_rpn分别表示原始模型中rpn模块的分类输出和回归输出,分别表示增量学习模型中rpn模块的旧类别检测分支的分类输出和回归输出,y′o_cls、b′o_cls分别表示原始模型中cls模块的分类输出和回归输出,分别表示增量学习模型中cls模块的旧类别检测分支的分类输出和回归输出;y′n_rpn、b′n_rpn分别表示增量学习模型中rpn模块的新类别检测分支的分类输出和回归输出,y′n_cls、b′n_cls分别表示增量学习模型中cls模块的新类别检测分支的分类输出和回归输出,yn、bn分别表示新类别样本图像的分类标签和回归标签;在lossold中,rpn模块的回归输出和分类输出均采用smoothl1误差函数;cls模块的回归输出采用smoothl1误差函数,分类输出采用基于知识蒸馏的交叉熵损失函数;在lossnew中,rpn模块和cls模块的回归输出均采用smoothl1损失函数,rpn模块的分类输出采用基于sigmoid的交叉熵损失函数,cls模块的分类输出采用基于softmax的交叉熵损失函数。

一个实施例中,在训练增量学习模型时,依据损失函数采用随机梯度下降算法,对增量学习模型的网络参数进行更新。

第二方面,本发明实施例提供了一种基于增量学习的目标检测装置,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上执行的计算机程序,所述处理器执行所述计算机程序时实现第一方面所述的基于增量学习的目标检测方法的步骤。

实施例提供的上述技术方案具有的有益效果至少包括:对目标检测网络进行增量学习,获得对新类别的检测能力,同时保留旧类别的检测能力,有效解决了灾难遗忘问题,进而提高目标检测模型的鲁棒性,提高检测精度。另一方面,基于fpn的网络结构在多个特征尺度上对目标进行检测,能够有效检测尺度变化较大的目标,因此获得了相比于其他增量目标检测模型更好地检测效果。

附图说明

为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图做简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动前提下,还可以根据这些附图获得其他附图。

图1是一实施例中增量学习模型的结构示意及训练过程图;

图2是一实施例中增量学习模型中rpn模块增加新类别检测分支示意图;

图3是一实施例中增量学习模型中cls模块增加新类别检测分支示意图。

具体实施方式

为使本发明的目的、技术方案及优点更加清楚明白,以下结合附图及实施例对本发明进行进一步的详细说明。应当理解,此处所描述的具体实施方式仅仅用以解释本发明,并不限定本发明的保护范围。

为了解决因为灾难遗忘问题导致的目标检测模型的鲁棒性低,进而影响检测精度的问题。实施例提供了一种基于增量学习的目标检测方法和装置,在不借助于旧类别样本的情况下,仅利用新类别样本进行目标检测模型的训练,以获得在新类别上具有良好检测效果且同时保留旧类别检测能力的目标检测模型。

实施例提供的基于增量学习的目标检测方法,包括以下步骤:

步骤1,利用旧类别样本图像对目标检测网络进行训练得到原始模型。

实施例中,提供的目标检测网络采用featurepyramidnetwork(fpn)特征金字塔网络,如图1所示包括特征提取模块、regionproposalnetwork(rpn)区域生成网络模块、classification(cls)分类回归模块,其中,特征提取模块作为共享模块,其参数表示为θs,用于提取输入样本图像的特征图。在特征提取模块中,如图1所示,包括4层卷积层,每层卷积层进一步通过卷积操作和上采样操作与低层级的卷积层相加和,之后再通过卷积操作送入各个层级共享的rpn模块中;rpn模块如图2所示,包括由全连接层构成的rpn隐藏层,以及针对旧类别的分类回归输出分支和针对新类别的分类回归输出分支,分别产生新老类别的候选框,经过对感兴趣区域roi池化操作后送入后续cls模块;cls模块用于对输入的感兴趣区域进一步修正,如图3所示,包括两层全连接层和新旧类别的分类回归输出,其中分类输出预测感兴趣区域所属类别的概率、回归输出预测感兴趣区域具体的位置坐标。

利用旧类别样本图像对目标检测网络进行训练,直到网络收敛,保存训练好的模型参数得到原始模型,其中,模型参数包括特征提取模块的参数θs,rpn模块中旧类别检测分支的参数θo_prn,cls模块中旧类别检测分支的参数θo_cls。

步骤2,在原始模型的输出层增加新类别样本图像的新类别检测分支,并初始化新类别检测分支参数,得到增量学习模型。

实施例中,当扩展网络用于检测新类别时,在原始模型的rpn模块和cls模块中添加用于检测新类别的新类别检测分支。如图2所示,在rpn模块的输出层中,保留针对旧类别的旧类别检测分支,其参数表示为θo_prn,同时添加针对新类别的新类别检测分支,其参数表示为θn_rpn。如图3所示,即在cls模块的输出层中,保留针对旧类别的检旧类别测分支,其参数表示为θo_cls,同时添加针对新类别的新类别检测分支,其参数表示为θn_cls。

在原始模型的输出层增加新类别样本图像的新类别检测分支后,需要初始化新类别检测分支参数,实施例中随机初始化θn_rpn,θn_cls。

步骤3,利用新类别样本图像训练增量学习模型。

实施例中,新类别样本图像表示为xn,yn,bn,其中,yn,bn标注的新类别样本图像xn的分类标签和位置坐标。在训练增量学习模型之前,利用原始模型,获得新类别样本图像在原始模型的预测输出,其中预测输出包括分类输出和回归输出。

具体地,利用原始模型,获得新类别样本图像在原始模型的rpn模块的预测输出,即y′o_rpn,b′o_rpn,roisprevious,其中,y′o_rpn,b′o_rpn分别是rpn模块输出层中的分类输出与回归输出,roisprevious=rpn模块(xn,yn,bn,θs,θo_rpn)为原始模型在新类别样本上获得的感兴趣区域。

利用原始模型,获得新类别样本图像在原始模型的cls模块的预测输出,即y′o_cls,b′o_cls=cls模块(roisprevious,yn,bn,θs,θo_cls),其中y′o_cls,b′o_cls分别是cls模块输出层中的分类输出与回归输出。

实施例中,训练增量学习模型时,首先进行增量学习模型的预训练阶段,具体包括:固定特征特征提取模块、rpn模块和cls模块的旧类别检测分支的参数θs,θo_rpn,θo_cls不变,利用新类别样本图像优化rpn模块和cls模块的新类别检测分支的参数θn_rpn,θn_cls直至收敛。

训练增量学习模型时,依据预训练阶段确定的增量学习模型,获得新类别样本图像分别在rpn模块和cls模块的旧类别检测分支的预测输出,获得新类别样本图像分别在rpn模块和cls模块的新类别检测分支的预测输出,其中,包括分类输出和回归输出。

具体地,利用增量学习模型,获得新类别样本图像在rpn模块中旧类别分支上的输出,即roiscurrent_old,,其中代表旧类别分支的分类输出与回归输出,roiscurrent_old=rpn模块(xn,yn,bn,)代表在旧类别分支上产生的感兴趣区域;

利用增量学习模型,获得新类别样本图像在cls模块中旧类别分支上的输出,即模块(roisprevious,yn,bn,),其中分别代表旧类别目标的分类输出与回归输出。

利用增量学习模型,获得新类别样本图像在rpn模块中,新类别分支上的输出,即y′n_rpn,b′n_rpn,roiscurrent_new,其中y′n_rpn,b′n_rpn代表新类别分支的分类输出与回归输出,roiscurrent_new=rpn模块(xn,yn,bn,θn_cls)代表在新类别分支上产生的感兴趣区域。

利用增量学习模型,获得新类别样本图像在cls模块中,新类别分支上的输出,即y′n_cls,b′n_cls,其中y′n_cls,b′n_cls=cls模块(roiscurrent_new,yn,bn,θn_cls)分别代表新类别目标的分类输出与回归输出。

基于以上的分类输出和回归输出构建新类别损失和旧类别损失以得到总损失,利用随机梯度下降算法对模型参数进行更新。即:其中λo为用于平衡新旧任务的超参。

实施例中,采用的总损失函数loss为:

loss=λolossold+lossnew

其中,λo为平衡超参数,lossold为旧类别损失,具体为:

lossnew为新类别损失,具体为:

lossnew=sigmoid(y′n_rpn,yn)+smoothl1(b′n_rpn,bn)+softmax(y′n_cls,yn)+smoothl1(b′n_cls,bn)

其中,在利用新类别样本训练,使训练模型在旧类别分支上的输出逼近保存模型输出的损失函数lossold中,rpn模块的回归输出和分类输出均采用smoothl1误差函数;cls模块的回归输出采用smoothl1误差函数,分类输出采用基于知识蒸馏的交叉熵损失函数;采用基于知识蒸馏的损失函数,能够更好地捕捉训练模型和保存模型输出值中较小的部分,使得训练模型的输出更好地逼近保存模型的输出。利用新类别样本在训练模型的新类别分支上训练模型针对新类别目标的检测能力的损失函数lossnew中,rpn模块和cls模块的回归输出均采用smoothl1损失函数,rpn模块的分类输出采用基于sigmoid的交叉熵损失函数,cls模块的分类输出采用基于softmax的交叉熵损失函数。

针对每一个新类别样本图像都采用上述训练步骤进行训练,直到增量学习模型达到收敛,获得最后的目标检测模型。

步骤4,利用目标检测模型对测试样本图像进行目标检测。

当训练获得目标检测模型之后,将测试样本图像输入至目标检测模型,经计算获得目标检测结果。

实施例还提供了一种基于增量学习的目标检测装置,包括存储器、处理器以及存储在存储器中并可在所述处理器上执行的计算机程序,处理器执行所述计算机程序时实现基于增量学习的目标检测方法步骤,具体包括:

步骤1,利用旧类别样本图像对目标检测网络进行训练得到原始模型;

步骤2,在原始模型的输出层增加新类别样本图像的新类别检测分支,并初始化新类别检测分支参数,得到增量学习模型;

步骤3,利用新类别样本图像训练增量学习模型;

步骤4,利用目标检测模型对测试样本图像进行目标检测。

实际应用中,计算机存储器可以为在近端的易失性存储器,如ram,还可以是非易失性存储器,如rom,flash,软盘,机械硬盘等,还可以是远端的存储云。计算机处理器可以为中央处理器(cpu)、微处理器(mpu)、数字信号处理器(dsp)、或现场可编程门阵列(fpga),即可以通过这些处理器实现基于增量学习的目标检测方法的步骤。

以上所述的具体实施方式对本发明的技术方案和有益效果进行了详细说明,应理解的是以上所述仅为本发明的最优选实施例,并不用于限制本发明,凡在本发明的原则范围内所做的任何修改、补充和等同替换等,均应包含在本发明的保护范围之内。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1