本申请涉及但不限于人工智能,尤其涉及一种模型训练方法、图像分类方法、装置、设备及存储介质。
背景技术:
1、深度学习广泛应用于工业视觉,其在复杂场景的表现明显优于传统图像处理算法。但是随着训练数据的不断增加,模型在不同数据集上迁移不可避免地会存在遗忘性问题,即在新数据集上训练深度学习分类模型,训练得到的新深度学习分类模型虽能够精确识别新数据特征,遗忘了在旧数据上学习到的知识的问题。
2、目前为解决模型遗忘性问题,主要采用蒸馏方法和模型组合方法。蒸馏往往会牺牲模型的准确率,其最主要的原因是:当旧模型的输出和新模型的输出偏差非常大的时候,通过蒸馏方法强行让他们一致,往往会得到负面的结果。模型组合的方式则会增加推理成本,推理时间变长。
技术实现思路
1、有鉴于此,本申请实施例至少提供一种模型训练方法、图像分类方法、装置、设备及存储介质。
2、本申请实施例的技术方案是这样实现的:
3、第一方面,本申请实施例提供一种模型训练方法,所述方法包括:
4、获取样本数据集;其中,所述样本数据集包括原始数据集中的至少一个原始样本;基于所述样本数据集,基于目标损失函数对第二分类模型的网络参数进行迭代更新,得到图像分类模型;其中,所述目标损失函数包括差异抑制损失;所述差异抑制损失用于表征针对所述样本数据集中同一样本,所述第二分类模型与第一分类模型分别对应的类别得分之间的差异;所述第一分类模型是利用所述原始数据集训练得到的。
5、第二方面,本申请实施例提供一种图像分类方法,所述方法包括:
6、获取待分类的图像数据集;通过已训练的图像分类模型对所述图像数据集进行分类,得到所述图像数据集中每一图像的分类结果;其中,所述图像分类模型是基于上述第一方面所述的模型训练方法进行训练得到的。
7、第三方面,本申请实施例提供一种模型训练装置,所述装置包括:
8、样本获取模块,用于获取样本数据集;其中,所述样本数据集包括原始数据集中的至少一个原始样本;
9、模型训练模块,用于基于所述样本数据集,基于目标损失函数对第二分类模型的网络参数进行迭代更新,得到图像分类模型;其中,所述目标损失函数包括差异抑制损失;所述差异抑制损失用于表征针对所述样本数据集中同一样本,所述第二分类模型与第一分类模型分别对应的类别得分之间的差异;所述第一分类模型是利用所述原始数据集训练得到的。
10、第四方面,本申请实施例提供一种图像分类装置,所述装置包括:
11、数据获取模块,用于获取待分类的图像数据集;
12、图像分类模块,用于通过已训练的图像分类模型对所述图像数据集进行分类,得到所述图像数据集中每一图像的分类结果;其中,所述图像分类模型是基于上述第一方面所述的模型训练方法进行训练得到的。
13、第五方面,本申请实施例提供一种计算机设备,包括存储器和处理器,所述存储器存储有可在处理器上运行的计算机程序,所述处理器执行所述程序时实现上述第一方面或第二方面方法中的部分或全部步骤。
14、第六方面,本申请实施例提供一种计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时实现上述第一方面或第二方面方法中的部分或全部步骤。
15、本申请实施例中,在利用原始数据集训练得到第一分类模型的基础上,获取包括原始数据集中的至少一个原始样本的样本数据集对第二分类模型进行训练,在训练过程中通过计算第一分类模型和第二分类模型针对同一样本输出的类别得分的差异得到差异抑制损失,通过在损失函数中增加差异抑制损失,惩罚新旧模型对于同一样本输出的类别得分变化,从而使得第二分类模型在精确识别新数据特征的同时,保持在旧数据上的识别精度。
16、应当理解的是,以上的一般描述和后文的细节描述仅是示例性和解释性的,而非限制本公开的技术方案。
1.一种模型训练方法,其特征在于,所述方法包括:
2.根据权利要求1所述的方法,其特征在于,所述第二分类模型至少包括全连接层,所述基于所述样本数据集,基于目标损失函数对第二分类模型的网络参数进行迭代更新,得到图像分类模型,包括:
3.根据权利要求2所述的方法,其特征在于,所述基于所述第二类别得分,基于所述目标损失函数确定所述第二分类模型的学习损失值,包括:
4.根据权利要求2所述的方法,其特征在于,所述目标损失函数还包括拟合损失,所述拟合损失用于表征所述第二分类模型的预测类别与样本标签之间的差异;所述第二分类模型还包括所述全连接层之后的归一化层;
5.根据权利要求4所述的方法,其特征在于,所述基于所述第二类别得分和所述第二预测类别,基于所述目标损失函数确定所述第二分类模型的学习损失值,包括:
6.根据权利要求5所述的方法,其特征在于,所述基于所述目标样本对应的所述第二类别得分和所述第一类别得分,确定所述差异抑制损失,包括:
7.根据权利要求6所述的方法,其特征在于,所述基于所述变化距离和预设的焦点函数确定所述差异抑制损失,包括:
8.根据权利要求6或7所述的方法,其特征在于,所述预设的焦点函数为第一焦点参数、第二焦点参数和二值函数的线性组合,所述方法还包括:
9.根据权利要求8所述的方法,其特征在于,所述方法还包括:
10.一种图像分类方法,其特征在于,所述方法包括:
11.一种模型训练装置,其特征在于,所述装置包括:
12.一种图像分类装置,其特征在于,所述装置包括:
13.一种计算机设备,包括存储器和处理器,所述存储器存储有可在处理器上运行的计算机程序,其特征在于,所述处理器执行所述程序时实现权利要求1至9任一项所述方法中的步骤,或执行权利要求10所述方法中的步骤。
14.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,该计算机程序被处理器执行时实现权利要求1至9任一项所述方法中的步骤,或执行权利要求10所述方法中的步骤。