一种模型蒸馏方法、系统以及文本检索方法与流程

文档序号:29632946发布日期:2022-04-13 16:26阅读:128来源:国知局
一种模型蒸馏方法、系统以及文本检索方法与流程

1.本发明涉及自然语言处理技术领域,具体而言,涉及一种模型蒸馏方法、系统以及文本检索方法。


背景技术:

2.文本检索需求常常是线上文本检索,需要进行线上部署;研究发现,采用对比学习的方式对文本检索模型进行训练,能够得到更好的文本检索效果。但当前对比学习的模型均采用预训练模型,预训练模型体量很大,计算过于复杂,导致响应速度慢,服务器负载过高;为了解决这些问题,本技术结合对比训练方式提出一种模型蒸馏方式,将经过对比学习后的预训练模型蒸馏到一个小模型上,使得小模型能够学习到预训练模型的文本检索能力,且解决线上部署的问题。


技术实现要素:

3.本发明的目的在于提供模型蒸馏方法、系统以及文本检索方法,其通过将教师模型经过对比学习所学习到的知识迁移到学生模型上,且通过对比学习训练,得到教师模型相似度矩阵以及学生模型相似度矩阵,以此设计集成蒸馏损失函数,能够更好拟合教师模型的文本检索能力。
4.本发明的实施例通过以下技术方案实现:
5.第一方面,提供一种模型蒸馏方法,包括如下步骤:
6.s1.构建一个学生模型以及多个教师预训练模型,采用对比学习方法对多个教师预训练模型进行训练,训练完成得到多个教师模型;采用对比学习方法对学生模型进行训练,训练得到预蒸馏学生模型,在对比学习训练中,获取多个教师预训练模型的平均教师模型相似度矩阵,获取学生模型的学生模型相似度矩阵;
7.s2.获取预蒸馏学生模型的第一输出以及多个教师模型的第二输出;对多个教师模型的第二输出求取平均,得到平均输出;
8.s3.根据平均教师模型相似度矩阵、学生模型相似度矩阵、学生模型的第一输出、多个教师模型的平均输出以及样本数据的真实标签获取集成蒸馏损失函数,并根据集成蒸馏损失函数训练所述预蒸馏学生模型,并得到蒸馏学生模型。
9.优选地,所述教师预训练模型为bert模型。
10.优选地,s1具体为:获取教师预训练模型以及学生模型的训练样本数据集,多个教师预训练的训练样本数据集的数据量相同;分别对学生模型以及多个教师预训练模型进行对比学习训练,其训练方式如下:
11.a.在一个batch内,选定输入样本,分别两次向教师预训练模型或学生模型输入输入样本,通过dropout方法得到输入样本的正样本对;计算正样本对的余弦相似度;
12.b.随机采样输入样本所在的同一batch内的其他输入样本作为负样本;
13.c.根据正样本对的余弦相似度以及负样本的余弦相似度设计对比学习损失函数,
根据对比学习损失函数训练教师预训练模型以及学生模型。
14.优选地,所述对比学习损失函数具体为:
[0015][0016]
其中,sim()表示余弦相似度函数,z是dropout mask,
[0017]z′
是dropout mask;以及构成正样本对,为负样本,τ为温度超参数。
[0018]
优选地,所述集成蒸馏损失函数具体为:
[0019]
loss=β*((1-α)*loss1+α*loss2)+(1-β)*batch_loss
[0020]
其中,loss1=ce(q,y);
[0021]
loss2=ce(p,q);
[0022]
batch_loss=mse(m_t,m_s),
[0023]
上式中,q为学生模型的输出logits,p为多个教师模型输出的logits的平均logits,y为样本数据的真正标签,m_t为平均教师模型相似度矩阵,m_s为学生模型相似度矩阵,α以及β分别为超参数。
[0024]
优选地,其中获取平均教师相似度矩阵包括如下步骤:
[0025]
a.获取同一个batch两次输入一个教师预训练模型的输出,分别为第一向量矩阵以及第二向量矩阵,向量矩阵用(batch_size,dimension)表示;
[0026]
b.计算第一向量矩阵以及第二向量矩阵的相似度,得到教师模型相似度矩阵,表示为(batch_size,batch_size);
[0027]
c.计算多个教师模型余弦相似度矩阵的平均,得到平均教师模型相似度矩阵;
[0028]
获取学生相似度矩阵包括如下步骤:
[0029]
a.获取同一个batch两次输入学生的输出,分别为第三向量矩阵以及第四向量矩阵,向量矩阵用(batch_size,dimension)表示;
[0030]
b.计算第三向量矩阵以及第四向量矩阵的相似度,得到学生模型相似度矩阵,表示为(batch_size,batch_size);
[0031]
其中batch_size为batch的大小,dimension为教师预训练模型输出的向量维度。
[0032]
优选地,所述学生模型包括多层全连接层,多层所述全连接层用于对齐教师模型的输出维度。
[0033]
第二方面,提供一种文本检索方法,包括如下步骤:
[0034]
s1.获取查询文本以及候选文本;
[0035]
s2.将所述查询文本以及候选文本输入蒸馏学生模型,获得蒸馏学生模型输出的所述查询文本与所述候选文本的匹配率。
[0036]
第三方面,提供一种模型蒸馏系统,包括:
[0037]
获取模块,所述获取模块用于获取教师预训练模型以及学生模型的训练样本数据
集;
[0038]
对比学习模块,所述对比学习模块用于构建教师预训练模型或学生模型的正样本对以及负样本,设计对比学习损失函数,训练教师预训练模型或学生模型;
[0039]
蒸馏模块,用于获取教师模型的输出、预蒸馏学生模型的输出、学生模型相似度矩阵以及平均教师模型相似度矩阵,设计集成蒸馏损失函数,根据集成蒸馏损失函数训练学生模型。
[0040]
本发明实施例的技术方案至少具有如下优点和有益效果:
[0041]
1.通过将教师模型经过对比学习所学习到的知识迁移到学生模型上,其特点是参数少,运算简单,响应速度快;
[0042]
2.本技术基于对比学习得到教师模型相似度矩阵以及学生模型相似度矩阵,以此来设计集成蒸馏损失函数,使得学生模型利用对比学习的优势,在一个batch内进一步拟合教师模型的检索能力;
[0043]
3.本技术采用不同的训练样本数据集进行训练,使得学生模型学习到更全方面的知识;
[0044]
本发明设计合理、结构简单。
附图说明
[0045]
图1为本发明实施例1提供的模型蒸馏方法的流程示意图;
[0046]
图2为本发明实施例1提供的学生模型的框架图;
具体实施方式
[0047]
为使本发明实施例的目的、技术方案和优点更加清楚,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。通常在此处附图中描述和示出的本发明实施例的组件可以以各种不同的配置来布置和设计。
[0048]
实施例1
[0049]
第一方面,如图1所述,一种模型蒸馏方法,包括如下步骤:
[0050]
s1.构建一个学生模型以及多个教师预训练模型,采用对比学习方法对多个教师预训练模型进行训练,训练完成得到多个教师模型;采用对比学习方法对学生模型进行训练,训练得到预蒸馏学生模型,在对比学习训练中,获取多个教师预训练模型的平均教师模型相似度矩阵,获取学生模型的学生模型相似度矩阵;
[0051]
在本实施例中,建立10个教师预训练模型,教师预训练模型采用bert模型,用10个教师预训练模型训练学生模型。采用对比学习的方法分别对教师预训练模型以及学生模型进行训练;
[0052]
首先需要获取教师预训练模型以及学生模型的训练样本数据集,10个教师预训练的训练样本数据集的数据不同但数据量相同;分别对学生模型以及多个教师预训练模型进行对比学习训练。
[0053]
利用不同的训练样本数据集对教师预训练模型进行对比训练,使得学生模型能够学习到更全方面的知识。
[0054]
在本实施例中,以训练其中一个教师预训练模型为例,介绍其训练方式如下:
[0055]
a.在一个batch中,选定输入样本,分别两次向教师预训练模型或学生模型输入输入样本,在本实施例中,输入样本为一个句子,(即同一个样本句子分两次输入教师预训练模型)。通过dropout方法得到输入样本的正样本对;计算正样本对的余弦相似度。
[0056]
具体的,同一个样本句子输入bert模型两次,采用两个不同的dropout mask,得到两个不同的句向量,输出的两个句向量作为正样本,两个句向量构成正样本对,采用余弦相似度公式计算正样本对的余弦相似度。其余弦相似度公式如下:
[0057][0058]
其中,h1,h2分别输出的两个正样本;
[0059]
在本实施例中,教师预训练模型采用的是bert模型,bert模型本身就具有随机dropout功能,所以一个输入样本只需要走两次bert模型就能得到两个不同但相似的表示向量。
[0060]
在本技术的其他实施例中,需向教师预训练模型添加dropout功能,进行对比学习训练。
[0061]
b.随机采样输入样本所在的同一batch内的其他输入样本作为负样本,计算负样本与正样本的余弦相似度。
[0062]
c.根据正样本对的余弦相似度以及负样本的余弦相似度设计对比学习损失函数,根据对比学习损失函数训练教师预训练模型以及学生模型。
[0063]
所述对比学习损失函数具体为:
[0064][0065]
其中,sim()表示余弦相似度函数,z是dropout mask,
[0066]z′
是dropout mask;以及构成正样本对,为负样本,τ为温度超参数。
[0067]
在对比学习训练中,获取平均教师模型相似度矩阵以及学生模型相似度矩阵,其中获取平均教师相似度矩阵包括如下步骤:
[0068]
a.获取同一个batch两次输入一个教师预训练模型的输出,分别为第一向量矩阵以及第二向量矩阵,向量矩阵用(batch_size,dimension)表示;
[0069]
b.计算第一向量矩阵以及第二向量矩阵的相似度,得到教师模型相似度矩阵,表示为(batch_size,batch_size);
[0070]
c.计算多个教师模型相似度矩阵的平均,得到平均教师模型相似度矩阵。
[0071]
获取学生相似度矩阵包括如下步骤:
[0072]
a.获取同一个batch两次输入学生模型的输出,分别为第三向量矩阵以及第四向量矩阵,向量矩阵用(batch_size,dimension)表示;
[0073]
b.计算第三向量矩阵以及第四向量矩阵的相似度,得到学生模型相似度矩阵,表示为(batch_size,batch_size);
[0074]
具体的,一个batch表示一批输入样本,batch_size为batch的大小,dimension为模型输出的向量维度,即针对一个batch,模型最后都会生成维度为(batch_size,batch_size)的相似度矩阵。
[0075]
在本实施例中,计算了第一向量矩阵以及第二向量矩阵的余弦相似度,在本技术的其他实施例中,还可以计算第一向量矩阵以及第二向量矩阵的距离相似度等。
[0076]
在本实施例中,第一向量矩阵以及第二向量矩阵相乘,得到余弦相似度矩阵。
[0077]
本技术基于对比学习得到教师模型相似度矩阵以及学生模型相似度矩阵,以此来设计集成蒸馏损失函数,使得学生模型利用对比学习的优势,在一个batch内进一步拟合教师模型的检索能力。
[0078]
s2.获取预蒸馏学生模型的第一输出以及多个教师模型的第二输出;对多个教师模型的第二输出求取平均,得到平均;。
[0079]
s3.根据平均教师模型相似度矩阵、学生模型相似度矩阵、学生模型的第一输出、多个教师模型的平均输出以及样本数据的真实标签获取集成蒸馏损失函数,并根据集成蒸馏损失函数训练所述预蒸馏学生模型,并得到蒸馏学生模型;在本实施例中,训练样本数据集包括查询文本以及至少一个候选文本,其真正标签即为至少一个候选文本中哪些是与查询文本相匹配的文本,哪些是不匹配的;
[0080]
具体的,集成蒸馏损失函数为:
[0081]
loss=β*((1-α)*loss1+α*loss2)+(1-β)*batch_loss
[0082]
其中,loss1=ce(q,y);
[0083]
loss2=ce(p,q);
[0084]
batch_loss=mse(m_t,m_s),
[0085]
上式中,q为学生模型的输出logits,p为多个教师模型输出的logits的平均logits,y为样本数据的真正标签,m_t为平均教师模型相似度矩阵,m_s为学生模型相似度矩阵,α以及β分别为超参数,ce(,)为交叉熵损失函数,mse(,)为均方误差损失函数。超参数的范围取(0,1)。
[0086]
p=(log its1+log its2+...+log its10)/10。
[0087]
在本实施例中取α=0.1,β=0.5。
[0088]
如图2所示,在本实施例中,所述学生模型包括多层全连接层,多层所述全连接层用于对齐教师模型的输出维度。学生模型还包括嵌入层embedding,卷积层cnn,更优的全连接层为三层。特别的,学生模型也加入了dropout方式进行对比学习训练。在本实施例中,学生模型在embedding层以及后面的三层全连接层mlp加入dropout。
[0089]
第二方面,提供一种文本检索方法,包括如下步骤:
[0090]
s1.获取查询文本以及候选文本;
[0091]
s2.将所述查询文本以及候选文本输入蒸馏学生模型,获得蒸馏学生模型输出的所述查询文本与所述候选文本的匹配率。
[0092]
第三方面,提供一种模型蒸馏系统,包括:
[0093]
获取模块,所述获取模块用于获取教师预训练模型以及学生模型的训练样本数据集;
[0094]
对比学习模块,所述对比学习模块用于构建教师预训练模型或学生模型的正样本对以及负样本,设计对比学习损失函数,训练教师预训练模型或学生模型;
[0095]
蒸馏模块,用于获取教师模型的输出、预蒸馏学生模型的输出、学生模型相似度矩阵以及平均教师模型相似度矩阵,设计集成蒸馏损失函数,根据集成蒸馏损失函数训练学生模型。
[0096]
综上,本技术具有如下优点,
[0097]
1.通过将教师模型经过对比学习所学习到的知识迁移到学生模型上,其特点就是参数少,运算简单,响应速度快;
[0098]
2.本技术基于对比学习得到教师模型相似度矩阵以及学生模型相似度矩阵,以此来设计集成蒸馏损失函数,使得学生模型利用对比学习的优势,在一个batch内进一步拟合教师模型的检索能力;
[0099]
3.本技术采用不同的训练样本数据集进行训练,使得学生模型学习到更全方面的知识。
[0100]
以上仅为本发明的优选实施例而已,并不用于限制本发明,对于本领域的技术人员来说,本发明可以有各种更改和变化。凡在本发明的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。
当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1