一种基于多助教模型知识蒸馏训练的文本分类方法

文档序号:30581960发布日期:2022-06-29 12:46阅读:180来源:国知局
一种基于多助教模型知识蒸馏训练的文本分类方法

1.本发明涉及一种基于多助教模型知识蒸馏训练的文本分类方法,属于自然语言文本处理技术领域。


背景技术:

2.随着自然语言处理领域预训练语言模型的发展,自然语言处理领域众多模型的性能有着巨大的飞跃,目前较为流行的预训练语言模型在众多自然语言处理领域任务,例如文本分类、命名体识别、信息抽取、词性标注,乃至问答领域、推荐领域、机器翻译领域等中都有应用,并且也已经渐渐开始在工业界中落地,例如gpt-3模型api的开放、文心模型的上线等,这些无不暗示着预训练语言模型的广泛应用前景。
3.但与此同时,伴随模型性能巨大提高的同时,所带来的是模型参数量以几何倍数的暴涨,短短几年,预训练语言模型参数量已经达到千亿级别,并已向着万亿门槛迈进。模型参数量的骤增,使得模型对相应的软硬件配置要求变得越来越高,这对工业界模型落地、以及众多软硬件资源不充分的科研工作者是极其不友好的,并且过大的模型使得模型的推理速度也无法得到有效提高。
4.现有的知识蒸馏方法往往只是单一的将教师模型直接蒸馏到学生模型上,这导致随着教师模型性能的提高、参数量的增加,教师模型与学生模型之间的差异性将成倍放大,从而妨碍学生模型学习教室模型性能。又或者如cn110826344b-神经网络模型压缩方法、语料翻译方法及其装置中提到的方法,使用多级中间教师模型进行过渡来减少模型差异性造成的蒸馏性能损失,但这种方法并未考虑到多个中间模型在层层递进的蒸馏过程中,层数越多,深层级的中间教师模型相比于原始教师模型性能差距也越大,多层知识蒸馏传递过程造成的性能损失也会越多。


技术实现要素:

5.本发明所要解决的技术问题是提供一种基于多助教模型知识蒸馏训练的文本分类方法,采用全新设计逻辑,能够有效地加快模型推理速度,提高模型文本分类准确度。
6.本发明为了解决上述技术问题采用以下技术方案:本发明设计了一种基于多助教模型知识蒸馏训练的文本分类方法,按如下步骤a至步骤d,训练获得预设各目标类型文本分类模型;然后应用预设各目标类型文本分类模型,针对待分类文本,实现获取待分类文本对应各目标类型文本分类模型下的类别;
7.步骤a.基于预设数量的各样本文本,以及各样本文本分别对应预设分类下的相应真实类别,以样本文本为输入,样本文本所对应预设分类下相应类别为输出,针对包含第一预设层数transformer层的第一bert模型进行训练,获得第一bert模型所对应训练后的主教师模型,然后进入步骤b;
8.步骤b.基于图神经网络模型、包含第二预设层数transformer层的第二bert模型,以及各样本文本、各样本文本分别对应预设分类下的相应真实类别,针对图神经网络模型,
以样本文本为输入,样本文本所对应预设分类下相应类别为输出,结合输入图神经网络模型样本文本真实类别与图神经网络模型输出预测类别之间的交叉熵损失、图神经网络模型训练过程与第二bert模型训练过程间相互学习损失所构成的副教师目标损失函数,针对图神经网络模型进行训练,获得图神经网络模型所对应训练后的副教师模型;
9.同时,针对第二bert模型,以样本文本为输入,样本文本所对应预设分类下相应真实类别为输出,结合主教师模型中多层渐进蒸馏损失、输入第二bert模型样本文本真实类别与第二bert模型输出的预测类别之间的交叉熵损失、第二bert模型与主教师模型预测层输出logits之间的蒸馏损失以及第二bert模型训练过程与图神经网络模型训练过程间相互学习损失所构成的助教蒸馏目标损失函数,针对第二bert模型进行多层渐进蒸馏训练,获得第二bert模型所对应训练后的助教模型,其中,第二预设层数小于第一预设层数;
10.然后进入步骤c;
11.步骤c.基于各样本文本,以及各样本文本分别对应预设分类下的相应真实类别,针对包含第三预设层数transformer层的第三bert模型,以样本文本为输入,样本文本所对应预设分类下相应类别为输出,结合助教模型中多层渐进蒸馏损失、输入第三bert模型样本文本真实类别与第三bert模型输出的预测类别之间的交叉熵损失、第三bert模型与主教师助教模型预测层输出logits之间的蒸馏损失以及第三bert模型与副教师模型预测层输出logits之间蒸馏损失所构成的第一蒸馏目标损失函数,针对第三bert模型进行多层渐进蒸馏训练,获得第三bert模型所对应训练后的第一学生模型,然后进入步骤d,其中,第三预设层数小于第二预设层数;
12.步骤d.基于预设第一分类模型、第二分类模型,以及各样本文本、各样本文本分别对应预设分类下的相应真实类别,针对第一分类模型,以样本文本为输入,样本文本所对应预设分类下相应类别为输出,结合输入第一分类模型样本文本真实类别与第一分类模型输出的预测类别之间的交叉熵损失、第一分类模型预测层输出的logits与第一学生模型中预测层输出的logits之间的蒸馏损失、第一分类模型预测层输出的logits与副教师模型预测层输出logits之间蒸馏损失、以及第一分类模型训练过程中与第二分类模型训练过程间的相互学习损失所构成的第二蒸馏目标损失函数,针对第一分类模型进行蒸馏训练,获得第一分类模型所对应训练后的第二学生模型;
13.同时针对第二分类模型,以样本文本为输入,样本文本所对应预设分类下相应类别为输出,结合输入第二分类模型样本文本真实类别与第二分类模型输出的预测类别之间的交叉熵损失、第二分类模型输出的预测层logits与第一学生模型中输出的预测层logits之间蒸馏损失、第二分类模型预测层输出的logits与副教师模型预测层输出logits之间蒸馏损失、以及第二分类模型训练过程与第一分类模型训练过程间相互学习损失所构成的第三蒸馏目标损失函数,针对第二分类模型进行蒸馏训练,获得第二分类模型所对应训练后的第三学生模型;
14.所获第一学生模型、第二学生模型、第三学生模型即为所获预设各目标类型文本分类模型。
15.作为本发明的一种优选技术方案:所述步骤a至步骤d中预设数量的各样本文本、以及各样本文本分别对应预设分类下的相应真实类别,首先删除各样本文本中预设无意义类型词、以及空字符,更新各样本文本,然后去除重复的样本文本,更新获得剩余的各样本
文本、以及各样本文本分别对应预设分类下的相应真实类别,最后用于各步骤中的训练操作。
16.作为本发明的一种优选技术方案:所述步骤b中,关于针对图神经网络模型进行训练,获得副教师模型的过程中,由输入图神经网络模型样本文本真实类别与图神经网络模型输出预测类别之间的交叉熵损失、与图神经网络模型训练过程和第二bert模型训练过程间相互学习损失所构成的副教师目标损失函数l
ass_t
如下:
[0017][0018]
其中,λ表示预设控制交叉熵损失与相互学习损失权重大小的超参数,表示图神经网络模型训练过程与第二bert模型训练过程间相互学习的损失函数;与分别表示第二bert模型训练过程中输出的预测层logits与图神经网络模型训练过程中输出的预测层logits;表示图神经网络模型训练过程中根据样本文本训练输出的预测类别与真实类别之间的交叉熵损失函数;label
ass_t
表示输入图神经网络样本文本所对应的真实类别。
[0019]
作为本发明的一种优选技术方案:所述步骤b中,关于针对第二bert模型进行训练,获得助教模型的过程中,由主教师模型中多层渐进蒸馏损失、第二bert模型与主教师模型预测层输出logits之间的蒸馏损失、输入第二bert模型样本文本真实类别与第二bert模型输出预测类别之间的交叉熵损失、以及第二bert模型训练过程与图神经网络模型训练过程间相互学习损失所构成的助教蒸馏目标损失函数l
ass_2
如下:
[0020][0021]
其中,表示第二bert模型训练过程中根据样本文本训练输出的预测类别与真实类别之间的交叉熵损失函数,label
ass_2
表示输入第二bert模型样本文本所对应的真实类别;表示度量第二bert模型与主教师模型预测层之间logits的损失函数,表示主教师模型输出的预测层logits;表示图神经网络模型训练过程与第二bert模型训练过程间相互学习的损失函数,与分别表示第二bert模型训练过程中输出的预测层logits与图神经网络模型训练过程中输出的预测层logits;表示主教师模型与第二bert模型训练过程中中间隐藏层之间的损失函数,与分别表示主教师模型中间隐藏层输出的根据样本文本训练得到的logits、以及第二bert模型训练过程中中间隐藏层输出的根据样本文本训练得到的logits;α1、β1、γ1分别表示预设控制助教蒸馏目标损失函数l
ass_2
中交叉熵损失权重大小的超参数、多层渐进蒸馏损失权重大小的超参数、以及预测层损失与相互学习损失两者之间权重大小的超参数。
[0022]
作为本发明的一种优选技术方案:所述步骤c中,关于针对第三bert模型进行训练,获得第一学生模型的过程中,由助教模型中多层渐进蒸馏损失、输入第三bert模型样本文本真实类别与第三bert模型输出的预测类别之间的交叉熵损失、第三bert模型与助教模型预测层输出logits之间的蒸馏损失以及第三bert模型与副教师模型预测层输出logits
之间蒸馏损失所构成的第一蒸馏目标损失函数l
s1
如下:
[0023][0024]
其中,表示第三bert模型训练过程中根据样本文本训练输出的预测类别与真实类别之间的交叉熵损失函数,label
s1
表示输入第三bert模型样本文本所对应的真实类别;表示度量第三bert模型与助教模型输出预测层之间logits的损失函数,表示助教模型输出的预测层logits;表示度量副教师模型与第三bert模型训练过程中输出预测层之间logits的损失函数,与分别表示第三bert模型训练过程中输出的预测层logits与副教师模型输出的预测层logits;表示助教模型与第三bert模型训练过程中隐藏层输出logits之间的损失函数,与分别表示助教模型中中间隐藏层输出的logits、以及第二bert模型训练过程中中间隐藏层输出的logits;α2、β2、γ2分别表示预设控制第一蒸馏目标损失函数l
s1
中交叉熵损失权重大小的超参数、多层渐进蒸馏损失权重大小的超参数、第三bert模型与主教师模型输出的预测层logits之间的蒸馏损失以及第三bert模型与副教师模型输出的预测层logits之间蒸馏损失两者之间权重大小的超参数。
[0025]
作为本发明的一种优选技术方案:所述步骤d中,关于针对第一分类模型进行训练,获得第二学生模型的过程中,由输入第一分类模型样本文本真实类别与第一分类模型输出的预测类别之间的交叉熵损失、第一分类模型预测层输出的logits与第一学生模型中预测层输出的logits之间的蒸馏损失、第一分类模型预测层输出的logits与副教师模型预测层输出logits之间蒸馏损失、以及第一分类模型训练过程中与第二分类模型训练过程间的相互学习损失所构成的第二蒸馏目标损失函数l
s2
如下:
[0026][0027]
其中,表示第一分类模型训练过程中根据样本文本训练输出的预测类别与真实类别之间的交叉熵损失函数,label
s2
表示输入第一分类模型样本文本所对应的真实类别;表示度量第一分类模型与第一学生模型预测层之间输出logits的损失函数;表示度量副教师模型与第一分类模型训练过程中输出预测层logits之间的损失函数,与分别表示第一分类模型训练过程中输出的预测层logits、以及副教师模型输出的预测层logits;表示第一分类模型与第二分类模型训练过程间相互学习的损失函数,与分别表示第一学生模型输出的预测层logits、以及第二分类模型训练过程输出的预测层logits;α3、β3、γ3分别表示预设控制第二蒸馏目标损失函数l
s2
中交叉熵损失权重大小的超参数、第一分类模型与第一学生模型输出的预测层logits之间的蒸馏损失以及第一分类模型与副教师模型输出的预测层logits之间蒸馏损失两者之间权重大小的超
参数、相互学习损失权重大小的超参数。
[0028]
作为本发明的一种优选技术方案:所述步骤d中,关于针对第二分类模型进行训练,获得第三学生模型的过程中,由输入第二分类模型样本文本真实类别与第二分类模型输出的预测类别之间的交叉熵损失、第二分类模型输出的预测层logits与第一学生模型中输出的预测层logits之间蒸馏损失、第二分类模型预测层输出的logits与副教师模型预测层输出logits之间蒸馏损失、以及第二分类模型训练过程与第一分类模型训练过程间相互学习损失所构成的第三蒸馏目标损失函数l
s3
如下:
[0029][0030]
其中,表示第二分类模型训练过程中根据样本文本训练输出的预测类别与真实类别之间的交叉熵损失函数,label
s3
表示输入第二分类模型样本文本所对应的真实类别;表示度量第二分类模型与第一学生模型预测层之间logits的损失函数;表示度量副教师模型与第二分类模型训练过程中输出预测层logits之间的损失函数,与分别表示第二分类模型训练过程中输出的预测层logits、以及副教师模型输出的预测层logits;表示第二分类模型与第一分类模型训练过程间相互学习的损失函数,与分别表示第一分类模型输出的预测层logits、以及第一学生模型输出的预测层logits;α4、β4、γ4分别表示预设控制第三蒸馏目标损失函数l
s3
中交叉熵损失权重大小的超参数、第二分类模型与第一学生模型输出的预测层logits之间的蒸馏损失以及第二分类模型与副教师模型输出的预测层logits之间蒸馏损失两者之间权重大小的超参数、相互学习损失权重大小的超参数。
[0031]
作为本发明的一种优选技术方案:所述第一bert模型包含12层transformer层,所述第二bert模型包含8层transformer层,所述第三bert模型包含5层transformer层;
[0032]
其中,第一bert模型中的第2层transformer层、第4层transformer层、第6层transformer层、第8层transformer层、第9层transformer层、第10层transformer层、第11层transformer层依次一一对应第二bert模型中的第1层transformer层、第2层transformer层、第3层transformer层、第4层transformer层、第5层transformer层、第6层transformer层、第7层transformer层;
[0033]
第二bert模型中的第2层transformer层、第4层transformer层、第6层transformer层、第7层transformer层依次一一对应第三bert模型中的第1层transformer层、第2层transformer层、第3层transformer层、第4层transformer层。
[0034]
作为本发明的一种优选技术方案:所述助教模型的输出、以及各个学生模型的输出分别均是对抗扰动叠加在其模型embedding层后模型的最终迭代输出,其叠加对抗扰动计算公式为:
[0035][0036]
其中,δx指代迭代叠加在embedding层输出x后的对抗扰动项;∈代表权重参数,|
|g||2表示对梯度g求解2范数;l(
·
)代表损失函数,

x表示对损失函数求解偏导,y是文本样本所属类别真实标签,θ即为模型参数。
[0037]
作为本发明的一种优选技术方案:所述第一分类模型为bilstm模型,所述第二分类模型为fasttext模型。
[0038]
本发明所述一种基于多助教模型知识蒸馏训练的文本分类方法,采用以上技术方案与现有技术相比,具有以下技术效果:
[0039]
(1)本发明所设计一种基于多助教模型知识蒸馏训练的文本分类方法,采用多训练范式,通过引入助教模型的方式,解决传统知识蒸馏在教师模型与学生模型之间由于结构和模型参数大小差异过大而引起的知识鸿沟问题,同时,为了避免助教模型过渡过程中会损失过多性能,引入额外的异构副教师模型对助教模型的训练进行指导,最大程度减小助教模型与学生模型参数量差异的同时,保留助教模型的文本分类任务性能;
[0040]
(2)本发明所设计一种基于多助教模型知识蒸馏训练的文本分类方法中,在transformer架构的模型进行知识蒸馏时,针对模型中间层知识的蒸馏,提出了一种渐进式蒸馏策略,在较浅层级的模型蒸馏时,此时知识表达较浅,学习难度较大,因此进行跳层蒸馏;在较深层级的模型蒸馏时,此时知识表达更为充分,这种深层transformer中蕴含更多知识,因此进行逐层蒸馏。通过这种策略使得学生模型文本分类精度更高;
[0041]
(3)本发明所设计一种基于多助教模型知识蒸馏训练的文本分类方法中,考虑到主教师模型以及助教模型和后续第一学生模型transformer架构存在模型参数量大,训练时文本长度有限制,并且占用gpu显存高的问题,构建了对gpu显存需求小的副教师模型,相同硬件条件下,副教师模型可以使用更大的批量数据训练,从而使得模型训练更充分,可以获得性能更优的模型。本发明为了缓解硬件配置对主教师模型以及助教模型训练的压力,让模型可以在硬件配置更低的环境下进行优化训练,构建对硬件配置要求低的副教师模型训练输出的较优的预测层logits对助教模型以及后续学生模型进行训练指导,从而使得transformer架构的助教模型以及学生模型在硬件条件不变的情况下,文本分类性能得到进一步提升;
[0042]
(4)本发明所设计一种基于多助教模型知识蒸馏训练的文本分类方法中,设计在模型蒸馏损失函数中不学习embedding词嵌入层的特征表示,而是通过在embedding词嵌入层构建对抗扰动叠加在embedding词嵌入层上,使得最终模型的输出鲁棒性更强,不仅减少了整个训练框架的参数量,加快了训练速度,而且使得最终训练得到的文本分类学生模型鲁棒性更好,泛化性能更强;
[0043]
(5)本发明所设计一种基于多助教模型知识蒸馏训练的文本分类方法中,在整个框架下的模型训练过程中,在损失函数中引入相互学习的思想,在对助教模型进行蒸馏训练时,相互学习使得两种异构的网络结构可以充分学习彼此的优点,进而让助教模型在减小参数量的同时,尽可能降低模型的损失;在对学生模型训练时,相互学习使得两个参数量较小的模型之间进行互相学习,由于三个学生模型架构都不同,相互学习可以进一步让学生模型充分学习不同架构的优势,从而提升模型文本分类性能,最终实现,一次知识蒸馏训练得到多个具有不同优点的学生模型,得以应用到不同需求的文本分类任务场景下,降低了重复训练的成本。
附图说明
[0044]
图1是本本发明设计一种基于多助教模型知识蒸馏训练的文本分类方法的流程示意图;
[0045]
图2是本发明设计中文本分类流程图;
[0046]
图3是本发明设计中的渐进蒸馏策略模型层级映射关系图。
具体实施方式
[0047]
下面结合说明书附图对本发明的具体实施方式作进一步详细的说明。
[0048]
本发明所设计一种基于多助教模型知识蒸馏训练的文本分类方法,实际应用当中,如图1、图2所示,按如下步骤a至步骤d,训练获得预设各目标类型文本分类模型;然后应用预设各目标类型文本分类模型,针对待分类文本,实现获取待分类文本对应各目标类型文本分类模型下的类别。
[0049]
步骤a.针对预设数量的各样本文本,以及各样本文本分别对应预设分类下的相应真实类别,首先删除各样本文本中预设无意义类型词、以及空字符,更新各样本文本,然后去除重复的样本文本,更新获得剩余的各样本文本、以及各样本文本分别对应预设分类下的相应真实类别。
[0050]
然后基于该各样本文本,以及各样本文本分别对应预设分类下的相应真实类别,以样本文本为输入,样本文本所对应预设分类下相应类别为输出,针对包含12层transformer层的第一bert模型进行训练,获得第一bert模型所对应训练后的主教师模型,然后进入步骤b。
[0051]
步骤b.基于图神经网络模型、包含8层transformer层的第二bert模型,以及各样本文本、各样本文本分别对应预设分类下的相应真实类别,针对图神经网络模型,以样本文本为输入,样本文本所对应预设分类下相应类别为输出,结合输入图神经网络模型样本文本真实类别与图神经网络模型输出预测类别之间的交叉熵损失、图神经网络模型训练过程与第二bert模型训练过程间相互学习损失所构成的副教师目标损失函数,针对图神经网络模型进行训练,获得图神经网络模型所对应训练后的副教师模型。
[0052]
实际应用当中,关于针对图神经网络模型进行训练,获得副教师模型的过程中,由输入图神经网络模型样本文本真实类别与图神经网络模型输出预测类别之间的交叉熵损失、与图神经网络模型训练过程和第二bert模型训练过程间相互学习损失所构成的副教师目标损失函数l
ass_t
如下:
[0053][0054]
其中,λ表示预设控制交叉熵损失与相互学习损失权重大小的超参数,表示图神经网络模型训练过程与第二bert模型训练过程间相互学习的损失函数;与分别表示第二bert模型训练过程中输出的预测层logits与图神经网络模型训练过程中输出的预测层logits;表示图神经网络模型训练过程中根据样本文本训练输出的预测类别与真实类别之间的交叉熵损失函数;label
ass_t
表示输入图神经网络样本文本所对应的真实类别。
[0055]
并且应用中,如图3所示,第一bert模型中的第2层transformer层、第4层transformer层、第6层transformer层、第8层transformer层、第9层transformer层、第10层transformer层、第11层transformer层依次一一对应第二bert模型中的第1层transformer层、第2层transformer层、第3层transformer层、第4层transformer层、第5层transformer层、第6层transformer层、第7层transformer层。
[0056]
同时,针对第二bert模型,以样本文本为输入,样本文本所对应预设分类下相应真实类别为输出,结合主教师模型中多层渐进蒸馏损失、输入第二bert模型样本文本真实类别与第二bert模型输出的预测类别之间的交叉熵损失、第二bert模型与主教师模型预测层输出logits之间的蒸馏损失以及第二bert模型训练过程与图神经网络模型训练过程间相互学习损失所构成的助教蒸馏目标损失函数,针对第二bert模型进行多层渐进蒸馏训练,获得第二bert模型所对应训练后的助教模型,其中,第二预设层数小于第一预设层数;然后进入步骤c。
[0057]
实际应用当中,关于针对第二bert模型进行训练,获得助教模型的过程中,由主教师模型中多层渐进蒸馏损失、第二bert模型与主教师模型预测层输出logits之间的蒸馏损失、输入第二bert模型样本文本真实类别与第二bert模型输出预测类别之间的交叉熵损失、以及第二bert模型训练过程与图神经网络模型训练过程间相互学习损失所构成的助教蒸馏目标损失函数l
ass_2
如下:
[0058][0059]
其中,表示第二bert模型训练过程中根据样本文本训练输出的预测类别与真实类别之间的交叉熵损失函数,label
ass_2
表示输入第二bert模型样本文本所对应的真实类别;表示度量第二bert模型与主教师模型预测层之间logits的损失函数,表示主教师模型输出的预测层logits;表示图神经网络模型训练过程与第二bert模型训练过程间相互学习的损失函数,与分别表示第二bert模型训练过程中输出的预测层logits与图神经网络模型训练过程中输出的预测层logits;表示主教师模型与第二bert模型训练过程中中间隐藏层之间的损失函数,与分别表示主教师模型中间隐藏层输出的根据样本文本训练得到的logits、以及第二bert模型训练过程中中间隐藏层输出的根据样本文本训练得到的logits;α1、β1、γ1分别表示预设控制助教蒸馏目标损失函数l
ass_2
中交叉熵损失权重大小的超参数、多层渐进蒸馏损失权重大小的超参数、以及预测层损失与相互学习损失两者之间权重大小的超参数。
[0060]
步骤c.基于各样本文本,以及各样本文本分别对应预设分类下的相应真实类别,针对包含5层transformer层的第三bert模型,以样本文本为输入,样本文本所对应预设分类下相应类别为输出,结合助教模型中多层渐进蒸馏损失、输入第三bert模型样本文本真实类别与第三bert模型输出的预测类别之间的交叉熵损失、第三bert模型与主教师助教模型预测层输出logits之间的蒸馏损失以及第三bert模型与副教师模型预测层输出logits之间蒸馏损失所构成的第一蒸馏目标损失函数,针对第三bert模型进行多层渐进蒸馏训
练,获得第三bert模型所对应训练后的第一学生模型,然后进入步骤d,其中,第三预设层数小于第二预设层数。
[0061]
实际应用当中,关于针对第三bert模型进行训练,获得第一学生模型的过程中,由助教模型中多层渐进蒸馏损失、输入第三bert模型样本文本真实类别与第三bert模型输出的预测类别之间的交叉熵损失、第三bert模型与助教模型预测层输出logits之间的蒸馏损失以及第三bert模型与副教师模型预测层输出logits之间蒸馏损失所构成的第一蒸馏目标损失函数l
s1
如下:
[0062][0063]
其中,表示第三bert模型训练过程中根据样本文本训练输出的预测类别与真实类别之间的交叉熵损失函数,label
s1
表示输入第三bert模型样本文本所对应的真实类别;表示度量第三bert模型与助教模型输出预测层之间logits的损失函数,表示助教模型输出的预测层logits;表示度量副教师模型与第三bert模型训练过程中输出预测层之间logits的损失函数,与分别表示第三bert模型训练过程中输出的预测层logits与副教师模型输出的预测层logits;表示助教模型与第三bert模型训练过程中隐藏层输出logits之间的损失函数,与分别表示助教模型中中间隐藏层输出的logits、以及第二bert模型训练过程中中间隐藏层输出的logits;α2、β2、γ2分别表示预设控制第一蒸馏目标损失函数l
s1
中交叉熵损失权重大小的超参数、多层渐进蒸馏损失权重大小的超参数、第三bert模型与主教师模型输出的预测层logits之间的蒸馏损失以及第三bert模型与副教师模型输出的预测层logits之间蒸馏损失两者之间权重大小的超参数。
[0064]
并且应用中,如图3所示,第二bert模型中的第2层transformer层、第4层transformer层、第6层transformer层、第7层transformer层依次一一对应第三bert模型中的第1层transformer层、第2层transformer层、第3层transformer层、第4层transformer层。
[0065]
步骤d.基于构成第一分类模型的bilstm模型、构成第二分类模型的fasttext模型,以及各样本文本、各样本文本分别对应预设分类下的相应真实类别,针对第一分类模型,以样本文本为输入,样本文本所对应预设分类下相应类别为输出,结合输入第一分类模型样本文本真实类别与第一分类模型输出的预测类别之间的交叉熵损失、第一分类模型预测层输出的logits与第一学生模型中预测层输出的logits之间的蒸馏损失、第一分类模型预测层输出的logits与副教师模型预测层输出logits之间蒸馏损失、以及第一分类模型训练过程中与第二分类模型训练过程间的相互学习损失所构成的第二蒸馏目标损失函数,针对第一分类模型进行蒸馏训练,获得第一分类模型所对应训练后的第二学生模型。
[0066]
实际应用当中,关于针对第一分类模型进行训练,获得第二学生模型的过程中,由输入第一分类模型样本文本真实类别与第一分类模型输出的预测类别之间的交叉熵损失、第一分类模型预测层输出的logits与第一学生模型中预测层输出的logits之间的蒸馏损
失、第一分类模型预测层输出的logits与副教师模型预测层输出logits之间蒸馏损失、以及第一分类模型训练过程中与第二分类模型训练过程间的相互学习损失所构成的第二蒸馏目标损失函数l
s2
如下:
[0067][0068]
其中,表示第一分类模型训练过程中根据样本文本训练输出的预测类别与真实类别之间的交叉熵损失函数,label
s2
表示输入第一分类模型样本文本所对应的真实类别;表示度量第一分类模型与第一学生模型预测层之间输出logits的损失函数;表示度量副教师模型与第一分类模型训练过程中输出预测层logits之间的损失函数,与分别表示第一分类模型训练过程中输出的预测层logits、以及副教师模型输出的预测层logits;表示第一分类模型与第二分类模型训练过程间相互学习的损失函数,与分别表示第一学生模型输出的预测层logits、以及第二分类模型训练过程输出的预测层logits;α3、β3、γ3分别表示预设控制第二蒸馏目标损失函数l
s2
中交叉熵损失权重大小的超参数、第一分类模型与第一学生模型输出的预测层logits之间的蒸馏损失以及第一分类模型与副教师模型输出的预测层logits之间蒸馏损失两者之间权重大小的超参数、相互学习损失权重大小的超参数。
[0069]
同时针对第二分类模型,以样本文本为输入,样本文本所对应预设分类下相应类别为输出,结合输入第二分类模型样本文本真实类别与第二分类模型输出的预测类别之间的交叉熵损失、第二分类模型输出的预测层logits与第一学生模型中输出的预测层logits之间蒸馏损失、第二分类模型预测层输出的logits与副教师模型预测层输出logits之间蒸馏损失、以及第二分类模型训练过程与第一分类模型训练过程间相互学习损失所构成的第三蒸馏目标损失函数,针对第二分类模型进行蒸馏训练,获得第二分类模型所对应训练后的第三学生模型。
[0070]
实际应用当中,关于针对第二分类模型进行训练,获得第三学生模型的过程中,由输入第二分类模型样本文本真实类别与第二分类模型输出的预测类别之间的交叉熵损失、第二分类模型输出的预测层logits与第一学生模型中输出的预测层logits之间蒸馏损失、第二分类模型预测层输出的logits与副教师模型预测层输出logits之间蒸馏损失、以及第二分类模型训练过程与第一分类模型训练过程间相互学习损失所构成的第三蒸馏目标损失函数l
s3
如下:
[0071][0072]
其中,表示第二分类模型训练过程中根据样本文本训练输出的预测类别与真实类别之间的交叉熵损失函数,label
s3
表示输入第二分类模型样本文本所对应的真实类别;表示度量第二分类模型与第一学生模型预测层之间logits的损失函数;表示度量副教师模型与第二分类模型训练过程中输出预测层logits之间的损失函数,与
分别表示第二分类模型训练过程中输出的预测层logits、以及副教师模型输出的预测层logits;表示第二分类模型与第一分类模型训练过程间相互学习的损失函数,与分别表示第一分类模型输出的预测层logits、以及第一学生模型输出的预测层logits;α4、β4、γ4分别表示预设控制第三蒸馏目标损失函数l
s3
中交叉熵损失权重大小的超参数、第二分类模型与第一学生模型输出的预测层logits之间的蒸馏损失以及第二分类模型与副教师模型输出的预测层logits之间蒸馏损失两者之间权重大小的超参数、相互学习损失权重大小的超参数。
[0073]
上述方案设计中,助教模型的输出、以及各个学生模型的输出分别均是对抗扰动叠加在其模型embedding层后模型的最终迭代输出,其叠加对抗扰动计算公式为:
[0074][0075]
其中,δx指代迭代叠加在embedding层输出x后的对抗扰动项;∈代表权重参数,||g||2表示对梯度g求解2范数;l(
·
)代表损失函数,

x表示对损失函数求解偏导,y是文本样本所属类别真实标签,θ即为模型参数。
[0076]
即最终获得文本分类精度高的第一学生模型、推理速度快的第三学生模型、以及处于两者之间水平的第二学生模型,即第一学生模型侧重于分类精度,第三学生模型侧重于推理速度快,第二学生模型介于两者之间,实际应用当中,使用者可以根据实际情况,选择相应学生模型,针对待分类文本进行分类,获得相应分类结果。
[0077]
上述技术方案所设计基于多助教模型知识蒸馏训练的文本分类方法,在实际应用当中,执行包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,应用中,计算机程序被加载至处理器时,实现所设计基于多助教模型知识蒸馏训练的文本分类方法,获得侧重于不同方向的各个学生模型,并进行实际文本分类的应用。
[0078]
实际应用当中,为了更好的说明本方法的可行性与有效性,将本发明所设计基于多助教模型知识蒸馏训练的文本分类方法应用于实际当中,通过对128051条化工文本数据进行文本分类实验,结果表明使用本发明设计方法生成的学生模型,应用在文本分类任务上性能优于传统知识蒸馏方法获得的学生模型,准确率达到了86.63%。
[0079]
上面结合附图对本发明的实施方式作了详细说明,但是本发明并不限于上述实施方式,在本领域普通技术人员所具备的知识范围内,还可以在不脱离本发明宗旨的前提下做出各种变化。
当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1