一种基于多助教模型知识蒸馏训练的文本分类方法

文档序号:30581960发布日期:2022-06-29 12:46阅读:来源:国知局

技术特征:
1.一种基于多助教模型知识蒸馏训练的文本分类方法,其特征在于:按如下步骤a至步骤d,训练获得预设各目标类型文本分类模型;然后应用预设各目标类型文本分类模型,针对待分类文本,实现获取待分类文本对应各目标类型文本分类模型下的类别;步骤a.基于预设数量的各样本文本,以及各样本文本分别对应预设分类下的相应真实类别,以样本文本为输入,样本文本所对应预设分类下相应类别为输出,针对包含第一预设层数transformer层的第一bert模型进行训练,获得第一bert模型所对应训练后的主教师模型,然后进入步骤b;步骤b.基于图神经网络模型、包含第二预设层数transformer层的第二bert模型,以及各样本文本、各样本文本分别对应预设分类下的相应真实类别,针对图神经网络模型,以样本文本为输入,样本文本所对应预设分类下相应类别为输出,结合输入图神经网络模型样本文本真实类别与图神经网络模型输出预测类别之间的交叉熵损失、图神经网络模型训练过程与第二bert模型训练过程间相互学习损失所构成的副教师目标损失函数,针对图神经网络模型进行训练,获得图神经网络模型所对应训练后的副教师模型;同时,针对第二bert模型,以样本文本为输入,样本文本所对应预设分类下相应真实类别为输出,结合主教师模型中多层渐进蒸馏损失、输入第二bert模型样本文本真实类别与第二bert模型输出的预测类别之间的交叉熵损失、第二bert模型与主教师模型预测层输出logits之间的蒸馏损失以及第二bert模型训练过程与图神经网络模型训练过程间相互学习损失所构成的助教蒸馏目标损失函数,针对第二bert模型进行多层渐进蒸馏训练,获得第二bert模型所对应训练后的助教模型,其中,第二预设层数小于第一预设层数;然后进入步骤c;步骤c.基于各样本文本,以及各样本文本分别对应预设分类下的相应真实类别,针对包含第三预设层数transformer层的第三bert模型,以样本文本为输入,样本文本所对应预设分类下相应类别为输出,结合助教模型中多层渐进蒸馏损失、输入第三bert模型样本文本真实类别与第三bert模型输出的预测类别之间的交叉熵损失、第三bert模型与主教师助教模型预测层输出logits之间的蒸馏损失以及第三bert模型与副教师模型预测层输出logits之间蒸馏损失所构成的第一蒸馏目标损失函数,针对第三bert模型进行多层渐进蒸馏训练,获得第三bert模型所对应训练后的第一学生模型,然后进入步骤d,其中,第三预设层数小于第二预设层数;步骤d.基于预设第一分类模型、第二分类模型,以及各样本文本、各样本文本分别对应预设分类下的相应真实类别,针对第一分类模型,以样本文本为输入,样本文本所对应预设分类下相应类别为输出,结合输入第一分类模型样本文本真实类别与第一分类模型输出的预测类别之间的交叉熵损失、第一分类模型预测层输出的logits与第一学生模型中预测层输出的logits之间的蒸馏损失、第一分类模型预测层输出的logits与副教师模型预测层输出logits之间蒸馏损失、以及第一分类模型训练过程中与第二分类模型训练过程间的相互学习损失所构成的第二蒸馏目标损失函数,针对第一分类模型进行蒸馏训练,获得第一分类模型所对应训练后的第二学生模型;同时针对第二分类模型,以样本文本为输入,样本文本所对应预设分类下相应类别为输出,结合输入第二分类模型样本文本真实类别与第二分类模型输出的预测类别之间的交叉熵损失、第二分类模型输出的预测层logits与第一学生模型中输出的预测层logits之间
蒸馏损失、第二分类模型预测层输出的logits与副教师模型预测层输出logits之间蒸馏损失、以及第二分类模型训练过程与第一分类模型训练过程间相互学习损失所构成的第三蒸馏目标损失函数,针对第二分类模型进行蒸馏训练,获得第二分类模型所对应训练后的第三学生模型;所获第一学生模型、第二学生模型、第三学生模型即为所获预设各目标类型文本分类模型。2.根据权利要求1所述一种基于多助教模型知识蒸馏训练的文本分类方法,其特征在于:所述步骤a至步骤d中预设数量的各样本文本、以及各样本文本分别对应预设分类下的相应真实类别,首先删除各样本文本中预设无意义类型词、以及空字符,更新各样本文本,然后去除重复的样本文本,更新获得剩余的各样本文本、以及各样本文本分别对应预设分类下的相应真实类别,最后用于各步骤中的训练操作。3.根据权利要求1所述一种基于多助教模型知识蒸馏训练的文本分类方法,其特征在于:所述步骤b中,关于针对图神经网络模型进行训练,获得副教师模型的过程中,由输入图神经网络模型样本文本真实类别与图神经网络模型输出预测类别之间的交叉熵损失、与图神经网络模型训练过程和第二bert模型训练过程间相互学习损失所构成的副教师目标损失函数l
ass_t
如下:其中,λ表示预设控制交叉熵损失与相互学习损失权重大小的超参数,表示图神经网络模型训练过程与第二bert模型训练过程间相互学习的损失函数;与分别表示第二bert模型训练过程中输出的预测层logits与图神经网络模型训练过程中输出的预测层logits;表示图神经网络模型训练过程中根据样本文本训练输出的预测类别与真实类别之间的交叉熵损失函数;label
ass_t
表示输入图神经网络样本文本所对应的真实类别。4.根据权利要求1所述一种基于多助教模型知识蒸馏训练的文本分类方法,其特征在于:所述步骤b中,关于针对第二bert模型进行训练,获得助教模型的过程中,由主教师模型中多层渐进蒸馏损失、第二bert模型与主教师模型预测层输出logits之间的蒸馏损失、输入第二bert模型样本文本真实类别与第二bert模型输出预测类别之间的交叉熵损失、以及第二bert模型训练过程与图神经网络模型训练过程间相互学习损失所构成的助教蒸馏目标损失函数l
ass_2
如下:其中,表示第二bert模型训练过程中根据样本文本训练输出的预测类别与真实类别之间的交叉熵损失函数,label
ass_2
表示输入第二bert模型样本文本所对应的真实类别;表示度量第二bert模型与主教师模型预测层之间logits的损失函数,表示主教师模型输出的预测层logits;表示图神经网络模型训练过程与第二bert模型训练过程间相互学习的损失函数,与分别表示第二bert模型训练过程中输出的预测
层logits与图神经网络模型训练过程中输出的预测层logits;表示主教师模型与第二bert模型训练过程中中间隐藏层之间的损失函数,与分别表示主教师模型中间隐藏层输出的根据样本文本训练得到的logits、以及第二bert模型训练过程中中间隐藏层输出的根据样本文本训练得到的logits;α1、β1、γ1分别表示预设控制助教蒸馏目标损失函数l
ass_2
中交叉熵损失权重大小的超参数、多层渐进蒸馏损失权重大小的超参数、以及预测层损失与相互学习损失两者之间权重大小的超参数。5.根据权利要求1所述一种基于多助教模型知识蒸馏训练的文本分类方法,其特征在于:所述步骤c中,关于针对第三bert模型进行训练,获得第一学生模型的过程中,由助教模型中多层渐进蒸馏损失、输入第三bert模型样本文本真实类别与第三bert模型输出的预测类别之间的交叉熵损失、第三bert模型与助教模型预测层输出logits之间的蒸馏损失以及第三bert模型与副教师模型预测层输出logits之间蒸馏损失所构成的第一蒸馏目标损失函数l
s1
如下:其中,表示第三bert模型训练过程中根据样本文本训练输出的预测类别与真实类别之间的交叉熵损失函数,label
s1
表示输入第三bert模型样本文本所对应的真实类别;表示度量第三bert模型与助教模型输出预测层之间logits的损失函数,表示助教模型输出的预测层logits;表示度量副教师模型与第三bert模型训练过程中输出预测层之间logits的损失函数,与分别表示第三bert模型训练过程中输出的预测层logits与副教师模型输出的预测层logits;表示助教模型与第三bert模型训练过程中隐藏层输出logits之间的损失函数,与分别表示助教模型中中间隐藏层输出的logits、以及第二bert模型训练过程中中间隐藏层输出的logits;α2、β2、γ2分别表示预设控制第一蒸馏目标损失函数l
s1
中交叉熵损失权重大小的超参数、多层渐进蒸馏损失权重大小的超参数、第三bert模型与主教师模型输出的预测层logits之间的蒸馏损失以及第三bert模型与副教师模型输出的预测层logits之间蒸馏损失两者之间权重大小的超参数。6.根据权利要求1所述一种基于多助教模型知识蒸馏训练的文本分类方法,其特征在于:所述步骤d中,关于针对第一分类模型进行训练,获得第二学生模型的过程中,由输入第一分类模型样本文本真实类别与第一分类模型输出的预测类别之间的交叉熵损失、第一分类模型预测层输出的logits与第一学生模型中预测层输出的logits之间的蒸馏损失、第一分类模型预测层输出的logits与副教师模型预测层输出logits之间蒸馏损失、以及第一分类模型训练过程中与第二分类模型训练过程间的相互学习损失所构成的第二蒸馏目标损失函数l
s2
如下:
其中,表示第一分类模型训练过程中根据样本文本训练输出的预测类别与真实类别之间的交叉熵损失函数,label
s2
表示输入第一分类模型样本文本所对应的真实类别;表示度量第一分类模型与第一学生模型预测层之间输出logits的损失函数;表示度量副教师模型与第一分类模型训练过程中输出预测层logits之间的损失函数,与分别表示第一分类模型训练过程中输出的预测层logits、以及副教师模型输出的预测层logits;表示第一分类模型与第二分类模型训练过程间相互学习的损失函数,与分别表示第一学生模型输出的预测层logits、以及第二分类模型训练过程输出的预测层logits;α3、β3、γ3分别表示预设控制第二蒸馏目标损失函数l
s2
中交叉熵损失权重大小的超参数、第一分类模型与第一学生模型输出的预测层logits之间的蒸馏损失以及第一分类模型与副教师模型输出的预测层logits之间蒸馏损失两者之间权重大小的超参数、相互学习损失权重大小的超参数。7.根据权利要求1所述一种基于多助教模型知识蒸馏训练的文本分类方法,其特征在于:所述步骤d中,关于针对第二分类模型进行训练,获得第三学生模型的过程中,由输入第二分类模型样本文本真实类别与第二分类模型输出的预测类别之间的交叉熵损失、第二分类模型输出的预测层logits与第一学生模型中输出的预测层logits之间蒸馏损失、第二分类模型预测层输出的logits与副教师模型预测层输出logits之间蒸馏损失、以及第二分类模型训练过程与第一分类模型训练过程间相互学习损失所构成的第三蒸馏目标损失函数l
s3
如下:其中,表示第二分类模型训练过程中根据样本文本训练输出的预测类别与真实类别之间的交叉熵损失函数,label
s3
表示输入第二分类模型样本文本所对应的真实类别;表示度量第二分类模型与第一学生模型预测层之间logits的损失函数;表示度量副教师模型与第二分类模型训练过程中输出预测层logits之间的损失函数,与分别表示第二分类模型训练过程中输出的预测层logits、以及副教师模型输出的预测层logits;表示第二分类模型与第一分类模型训练过程间相互学习的损失函数,与分别表示第一分类模型输出的预测层logits、以及第一学生模型输出的预测层logits;α4、β4、γ4分别表示预设控制第三蒸馏目标损失函数l
s3
中交叉熵损失权重大小的超参数、第二分类模型与第一学生模型输出的预测层logits之间的蒸馏损失以及第二分类模型与副教师模型输出的预测层logits之间蒸馏损失两者之间权重大小的超参数、相互学习损失权重大小的超参数。8.根据权利要求1所述一种基于多助教模型知识蒸馏训练的文本分类方法,其特征在于:所述第一bert模型包含12层transformer层,所述第二bert模型包含8层transformer层,所述第三bert模型包含5层transformer层;
其中,第一bert模型中的第2层transformer层、第4层transformer层、第6层transformer层、第8层transformer层、第9层transformer层、第10层transformer层、第11层transformer层依次一一对应第二bert模型中的第1层transformer层、第2层transformer层、第3层transformer层、第4层transformer层、第5层transformer层、第6层transformer层、第7层transformer层;第二bert模型中的第2层transformer层、第4层transformer层、第6层transformer层、第7层transformer层依次一一对应第三bert模型中的第1层transformer层、第2层transformer层、第3层transformer层、第4层transformer层。9.根据权利要求1所述一种基于多助教模型知识蒸馏训练的文本分类方法,其特征在于:所述助教模型的输出、以及各个学生模型的输出分别均是对抗扰动叠加在其模型embedding层后模型的最终迭代输出,其叠加对抗扰动计算公式为:其中,δx指代迭代叠加在embedding层输出x后的对抗扰动项;∈代表权重参数,||g||2表示对梯度g求解2范数;l(
·
)代表损失函数,表示对损失函数求解偏导,y是文本样本所属类别真实标签,θ即为模型参数。10.根据权利要求1所述一种基于多助教模型知识蒸馏训练的文本分类方法,其特征在于:所述第一分类模型为bilstm模型,所述第二分类模型为fasttext模型。

技术总结
本发明专利涉及一种基于多助教模型知识蒸馏训练的文本分类方法,首先根据样本数据,分别针对主教师模型、副教师模型进行训练,接着根据样本数据,结合副教师模型与主教师模型对助教模型的联合渐进蒸馏,同时副教师模型与助教模型之间进行相互学习;再通过对副教师模型与助教模型联合渐进蒸馏得到第一学生模型,并继续对第一学生模型与副教师模型进行联合蒸馏,得到第二学生模型与第三学生模型,并相互学习;最后得到文本分类精度高的第一学生模型、推理速度快的第三学生模型、以及处于两者之间水平的第二学生模型;在实际应用中,将文本输入相应学生模型,得到相应类型下的文本分类结果,有效加快了模型推理速度,提高学生模型文本分类准确度。型文本分类准确度。型文本分类准确度。


技术研发人员:高尚兵 张骏强 苏睿 王媛媛 张海艳 马甲林 张正伟 朱全银
受保护的技术使用者:淮阴工学院
技术研发日:2022.03.30
技术公布日:2022/6/28
当前第2页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1