一种学生模型训练方法、装置、设备及介质与流程

文档序号:36318341发布日期:2023-12-08 11:49阅读:62来源:国知局
一种学生模型训练方法与流程

本发明涉及模型训练,尤其涉及一种学生模型训练方法、装置、设备及介质。


背景技术:

1、在深度学习发展过程中,越来越趋向于构建神经网络更深的大型模型,以求达到很好的非线性表达。但由于模型落地应用时,大模型的实时性不佳,资源消耗较大。因而产生了模型蒸馏方法,该方法旨在用一个体量较小的模型去学习一个体量较大的模型的输出,学习其泛化能力。在模型蒸馏领域,原来的较大的效果非常好的模型叫做教师模型,体量较小,用来蒸馏教师模型中的信息的称作学生模型。

2、一般情况下,在对学生模型进行训练时,会通过训练数据所属的原始样本标签类别、以及根据教师模型的输出结果确定的该输出结果对应的类别,对学生模型进行训练,虽然一般情况教师模型效果已经达到理想要求(比如输出结果的准确率大于90%),即使这样,仍存在部分错判的情况,与实际的原始样本标签类别存在偏差,导致学生模型在训练时,也会学习到教师模型错判的部分,从而导致训练完成的学生模型的输出结果不够准确。


技术实现思路

1、本申请实施例提供了一种学生模型训练方法、装置、设备及介质,用以解决现有技术中训练完成的学生模型的输出结果不够准确的问题。

2、第一方面,本申请实施例提供了一种学生模型训练方法,所述方法包括:

3、获取训练集中任一训练数据以及训练数据所属的原始标签;

4、将训练数据输入到教师模型中,获取教师模型的第一预测标签;将训练数据输入到学生模型中,获取学生模型的第二预测标签;

5、若原始标签和第一预测标签不同,则采用根据原始标签和第二预测标签确定的第一目标损失值,对学生模型进行训练。

6、第二方面,本申请实施例还提供了一种学生模型训练装置,所述装置包括:

7、第一获取模块,用于获取训练集中任一训练数据以及训练数据所属的原始标签;

8、第二获取模块,用于将训练数据输入到教师模型中,获取教师模型的第一预测标签;将训练数据输入到学生模型中,获取学生模型的第二预测标签;

9、训练模块,用于若原始标签和第一预测标签不同,则采用根据原始标签和第二预测标签确定的第一目标损失值,对学生模型进行训练。

10、第三方面,本申请实施例还提供了一种电子设备,所述电子设备至少包括处理器和存储器,所述处理器用于执行存储器中存储的计算机程序时实现如上述任一项所述学生模型训练方法的步骤。

11、第四方面,本申请实施例还提供了一种计算机可读存储介质,其存储有计算机程序,所述计算机程序被处理器执行时实现如上述任一项所述学生模型训练方法的步骤。

12、由于在本申请实施例中,获取训练集中任一训练数据以及训练数据所属的原始标签;将训练数据输入到教师模型中,获取教师模型的第一预测标签;将训练数据输入到学生模型中,获取学生模型的第二预测标签;若原始标签和第一预测标签不同,则采用根据原始标签和第二预测标签确定的目标损失值,对学生模型进行训练。由于当训练数据的原始标签和教师模型的第一预测标签不同时,即当教师模型出现错判时,根据原始标签和第二预测标签对学生模型进行训练,可以减小教师模型错误的输出结果对学生模型带来的影响,提高了学生模型的准确性。



技术特征:

1.一种学生模型训练方法,其特征在于,所述方法包括:

2.根据权利要求1所述的方法,其特征在于,所述方法还包括:

3.根据权利要求1所述的方法,其特征在于,所述采用根据所述原始标签和所述第二预测标签确定的第一目标损失值,对所述学生模型进行训练,包括:

4.根据权利要求2所述的方法,其特征在于,所述采用根据所述原始标签、所述第一预测标签和所述第二预测标签确定的第二目标损失值,对所述学生模型进行训练,包括:

5.根据权利要求3或4所述的方法,其特征在于,所述损失函数满足如下公式:total_loss=α×loss1+β×multiply(f,loss2),其中total_loss表示损失函数,α和β表示权重,multiply()表示标量积,f表示标识矩阵,loss1表示与所述原始标签和所述第二预测标签有关的损失值,loss2表示与所述第一预测标签和所述第二预测标签有关的损失值。

6.根据权利要求5所述的方法,其特征在于,所述标识矩阵f在所述原始标签和所述第一预测标签不同时取值为0,在所述原始标签和所述第一预测标签相同时取值为1。

7.根据权利要求5所述的方法,其特征在于,所述loss1满足如下公式:loss1=ce(s′,y),其中ce()表示交叉熵函数,s′表示所述第二预测标签,y表示所述原始标签;

8.一种学生模型训练装置,其特征在于,所述装置包括:

9.一种电子设备,其特征在于,所述电子设备至少包括处理器和存储器,所述处理器用于执行存储器中存储的计算机程序时实现如权利要求1-7任一项所述的学生模型训练方法的步骤。

10.一种计算机存储介质,其特征在于,其存储有可由电子设备执行的计算机程序,当所述程序在所述电子设备上运行时,使得所述电子设备执行权利要求1-7任一项所述的学生模型训练方法的步骤。


技术总结
本申请实施例提供了一种学生模型训练方法、装置、设备及介质,用以解决现有技术中训练完成的学生模型的输出结果不够准确的问题。获取训练集中任一训练数据以及训练数据所属的原始标签;将训练数据输入到教师模型中,获取教师模型的第一预测标签;将训练数据输入到学生模型中,获取学生模型的第二预测标签;若原始标签和第一预测标签不同,则采用根据原始标签和第二预测标签确定的目标损失值,对学生模型进行训练。由于当训练数据的原始标签和教师模型的第一预测标签不同时,即当教师模型出现错判时,根据原始标签和第二预测标签对学生模型进行训练,可以减小教师模型错误的输出结果对学生模型带来的影响,提高了学生模型的准确性。

技术研发人员:施丽佳,蔡锋,王兴科,王有元
受保护的技术使用者:中国电信股份有限公司技术创新中心
技术研发日:
技术公布日:2024/1/15
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1