一种模型蒸馏方法、系统以及文本检索方法与流程

文档序号:29632946发布日期:2022-04-13 16:26阅读:来源:国知局

技术特征:
1.一种模型蒸馏方法,其特征在于,包括如下步骤:s1.构建一个学生模型以及多个教师预训练模型,采用对比学习方法对多个教师预训练模型进行训练,训练完成得到多个教师模型;采用对比学习方法对学生模型进行训练,训练得到预蒸馏学生模型,在对比学习训练中,获取多个教师预训练模型的平均教师模型相似度矩阵,获取学生模型的学生模型相似度矩阵;s2.获取预蒸馏学生模型的第一输出以及多个教师模型的第二输出;对多个教师模型的第二输出求取平均,得到平均输出;s3.根据平均教师模型相似度矩阵、学生模型相似度矩阵、学生模型的第一输出、多个教师模型的平均输出以及样本数据的真实标签获取集成蒸馏损失函数,并根据集成蒸馏损失函数训练所述预蒸馏学生模型,并得到蒸馏学生模型。2.根据权利要求1所述的基于对比学习与集成蒸馏的文本检索方法,其特征在于,s1具体为:获取教师预训练模型以及学生模型的训练样本数据集,多个教师预训练的训练样本数据集的数据量相同;分别对学生模型以及多个教师预训练模型进行对比学习训练,其训练方式如下:a.在一个batch内,选定输入样本,分别两次向教师预训练模型或学生模型输入输入样本,通过dropout方法得到输入样本的正样本对;计算正样本对的余弦相似度;b.随机采样输入样本所在的同一batch内的其他输入样本作为负样本;c.根据正样本对的余弦相似度以及负样本的余弦相似度设计对比学习损失函数,根据对比学习损失函数训练教师预训练模型以及学生模型。3.根据权利要求2所述的模型蒸馏方法,其特征在于,所述对比学习损失函数具体为:其中,sim()表示余弦相似度函数,z是dropoutmask,z

是dropout mask;以及构成正样本对,为负样本,τ为温度超参数。4.根据权利要求1所述的基于对比学习与集成蒸馏的文本检索方法,其特征在于,所述集成蒸馏损失函数具体为:loss=β*((1-α)*loss1+α*loss2)+(1-β)*batch_loss其中,loss1=ce(q,y);loss2=ce(p,q);batch_loss=mse(m_t,m_s),上式中,q为学生模型的输出logits,p为多个教师模型输出的logits的平均logits,y为样本数据的真正标签,m_t为平均教师模型相似度矩阵,m_s为学生模型相似度矩阵,α以及β分别为超参数。5.根据权利要求1所述的模型蒸馏方法,其特征在于,其中获取平均教师相似度矩阵包
括如下步骤:a.获取同一个batch两次输入一个教师预训练模型的输出,分别为第一向量矩阵以及第二向量矩阵,向量矩阵用(batch_size,dimension)表示;b.计算第一向量矩阵以及第二向量矩阵的相似度,得到教师模型相似度矩阵,表示为(batch_size,batch_size);c.计算多个教师模型余弦相似度矩阵的平均,得到平均教师模型相似度矩阵;其中batch_size为batch的大小,dimension为教师预训练模型输出的向量维度。6.根据权利要求1所述的模型蒸馏方法,其特征在于,所述学生模型包括多层全连接层,多层所述全连接层用于对齐教师模型的输出维度。7.一种文本检索方法,其特征在于,应用如权利要求1-6任意一项所述模型蒸馏方法获得蒸馏学生模型,包括如下步骤:s1.获取查询文本以及候选文本;s2.将所述查询文本以及候选文本输入蒸馏学生模型,获得蒸馏学生模型输出的所述查询文本与所述候选文本的匹配率。8.一种模型蒸馏系统,其特征在于,应用如权利要求1-6任意一项所述的模型蒸馏方法,包括:获取模块,所述获取模块用于获取教师预训练模型以及学生模型的训练样本数据集;对比学习模块,所述对比学习模块用于构建教师预训练模型或学生模型的正样本对以及负样本,设计对比学习损失函数,训练教师预训练模型或学生模型;蒸馏模块,用于获取教师模型的输出、预蒸馏学生模型的输出、学生模型相似度矩阵以及平均教师模型相似度矩阵,设计集成蒸馏损失函数,根据集成蒸馏损失函数训练学生模型。

技术总结
本发明提供了一种模型蒸馏方法、系统以及文本检索方法,其模型蒸馏方法包括如下步骤:S1.构建一个学生模型以及多个教师预训练模型,采用对比学习方法对多个教师预训练模型以及学生模型进行训练,在对比学习训练中,获取平均教师模型相似度矩阵以及学生模型相似度矩阵;S2.获取预蒸馏学生模型的第一输出以及多个教师模型的平均输出;S3.获取集成蒸馏损失函数,并根据集成蒸馏损失函数训练所述预蒸馏学生模型,并得到蒸馏学生模型;本发明通过对比学习训练得到平均教师模型相似度矩阵与学生模型相似度矩阵,将其加入集成蒸馏损失函数中,通过集成蒸馏损失函数训练预蒸馏学生模型,使得学生模型能够更好的拟合教师模型的检索能力。索能力。索能力。


技术研发人员:郭湘 黄鹏 江岭
受保护的技术使用者:成都晓多科技有限公司
技术研发日:2021.12.29
技术公布日:2022/4/12
当前第2页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1