用于自然语言处理的模型蒸馏方法及装置、介质、设备与流程

文档序号:37377930发布日期:2024-03-22 10:30阅读:8来源:国知局
用于自然语言处理的模型蒸馏方法及装置、介质、设备与流程

所属的技术人员能够理解,本公开的各个方面可以实现为系统、方法或程序产品。因此,本公开的各个方面可以具体实现为以下形式,即:完全的硬件实施方式、完全的软件实施方式(包括固件、微代码等),或硬件和软件方面结合的实施方式,这里可以统称为“电路”、“模块”或“系统”。下面参照图7来描述根据本公开的这种实施方式的电子设备700。图7显示的电子设备700仅仅是一个示例,不应对本公开实施例的功能和使用范围带来任何限制。如图7所示,电子设备700以通用计算设备的形式表现。电子设备700的组件可以包括但不限于:上述至少一个处理单元710、上述至少一个存储单元720、连接不同系统组件(包括存储单元720和处理单元710)的总线730以及显示单元740。其中,所述存储单元存储有程序代码,所述程序代码可以被所述处理单元710执行,使得所述处理单元710执行本说明书上述“示例性方法”部分中描述的根据本公开各种示例性实施方式的步骤。例如,所述处理单元710可以执行如图1中所示的:步骤s110,获取训练样本集,利用所述训练样本集对待训练的多任务教师模型进行预训练,获得预训练的多任务教师模型;所述多任务教师模型用于执行m种自然语言处理任务,所述多任务教师模型包括与所述m种自然语言处理任务分别对应的m个子模型;所述训练样本集包含m个子集,不同的所述子集用于训练不同的所述子模型,m为大于1的整数;步骤s120,利用所述训练样本集对所述预训练的多任务教师模型和待训练的学生模型进行整体蒸馏训练,获得训练好的多任务教师模型和初次蒸馏后的学生模型;所述待训练的学生模型对应的自然语言处理任务为所述m种自然语言处理任务中任一种;步骤s130,利用与所述初次蒸馏后的学生模型对应的自然语言处理任务相匹配的目标子集和所述训练好的多任务教师模型,对所述初次蒸馏后的学生模型进行再次蒸馏训练,获得蒸馏好的学生模型。存储单元720可以包括易失性存储单元形式的可读介质,例如随机存取存储单元(ram)7201和/或高速缓存存储单元7202,还可以进一步包括只读存储单元(rom)7203。存储单元720还可以包括具有一组(至少一个)程序模块7205的程序/实用工具7204,这样的程序模块7205包括但不限于:操作系统、一个或者多个应用程序、其它程序模块以及程序数据,这些示例中的每一个或某种组合中可能包括网络环境的实现。总线730可以为表示几类总线结构中的一种或多种,包括存储单元总线或者存储单元控制器、外围总线、图形加速端口、处理单元或者使用多种总线结构中的任意总线结构的局域总线。电子设备700也可以与一个或多个外部设备800(例如键盘、指向设备、蓝牙设备等)通信,还可与一个或者多个使得用户能与该电子设备700交互的设备通信,和/或与使得该电子设备700能与一个或多个其它计算设备进行通信的任何设备(例如路由器、调制解调器等等)通信。这种通信可以通过输入/输出(i/o)接口750进行。并且,电子设备700还可以通过网络适配器760与一个或者多个网络(例如局域网(lan),广域网(wan)和/或公共网络,例如因特网)通信。如图所示,网络适配器760通过总线730与电子设备700的其它模块通信。应当明白,尽管图中未示出,可以结合电子设备700使用其它硬件和/或软件模块,包括但不限于:微代码、设备驱动器、冗余处理单元、外部磁盘驱动阵列、raid系统、磁带驱动器以及数据备份存储系统等。本领域技术人员在考虑说明书及实践这里公开的发明后,将容易想到本公开的其他实施例。本公开旨在涵盖本公开的任何变型、用途或者适应性变化,这些变型、用途或者适应性变化遵循本公开的一般性原理并包括本公开未公开的本中的公知常识或惯用技术手段。说明书和实施例仅被视为示例性的,本公开的真正范围和精神由权利要求指出。


背景技术:

1、模型蒸馏是一种将复杂的神经网络模型简化的技术,在这个过程中,一个复杂的模型(即多任务教师模型)被用来指导一个简化的模型(即学生模型)学习,模型蒸馏的基本原理是将多任务教师模型的知识转移给学生模型。

2、目前,在自然语言处理领域中,一般都是在多任务教师模型训练结束、功能完备之后才进行蒸馏,该方案割裂了多任务教师模型和蒸馏得到的学生模型之间的关系,难以保证学生模型的蒸馏效果。

3、鉴于此,本领域亟需开发一种新的用于自然语言处理的模型蒸馏方法及装置。

4、需要说明的是,上述背景技术部分公开的信息仅用于加强对本公开的背景的理解。


技术实现思路

1、本公开的目的在于提供一种用于自然语言处理的模型蒸馏方法、用于自然语言处理的模型蒸馏装置、计算机存储介质及电子设备,进而至少在一定程度上克服由于相关技术的限制而导致的蒸馏效果差的技术问题。

2、本公开的其他特性和优点将通过下面的详细描述变得显然,或部分地通过本公开的实践而习得。

3、根据本公开的第一方面,提供一种用于自然语言处理的模型蒸馏方法,包括:获取训练样本集,利用所述训练样本集对待训练的多任务教师模型进行预训练,获得预训练的多任务教师模型;所述多任务教师模型用于执行m种自然语言处理任务,所述多任务教师模型包括与所述m种自然语言处理任务分别对应的m个子模型;所述训练样本集包含m个子集,不同的所述子集用于训练不同的所述子模型,m为大于1的整数;

4、利用所述训练样本集对所述预训练的多任务教师模型和待训练的学生模型进行整体蒸馏训练,获得训练好的多任务教师模型和初次蒸馏后的学生模型;所述待训练的学生模型对应的自然语言处理任务为所述m种自然语言处理任务中任一种;

5、利用与所述初次蒸馏后的学生模型对应的自然语言处理任务相匹配的目标子集和所述训练好的多任务教师模型,对所述初次蒸馏后的学生模型进行再次蒸馏训练,获得蒸馏好的学生模型。

6、在本公开的示例性实施例中,每个所述子集中包含多个训练样本,每个所述训练样本包含样本输入信息和所述样本输入信息对应的真实标注标签。

7、在本公开的示例性实施例中,当所述待训练的学生模型的数目为1时;

8、所述利用所述训练样本集对所述预训练的多任务教师模型和待训练的学生模型进行整体蒸馏训练,获得训练好的多任务教师模型和初次蒸馏后的学生模型,包括:

9、从每个所述子集中选取训练样本;

10、将所述训练样本中的样本输入信息输入至所述预训练的多任务教师模型中,根据所述预训练的多任务教师模型对应的输出结果与所述真实标注标签之间的差异程度,获得第一损失值;

11、判断所述训练样本所属的子集关联的目标自然语言处理任务与所述待训练的学生模型对应的自然语言处理任务是否一致;

12、若一致,则将所述训练样本输入至所述待训练的学生模型中,根据所述待训练的学生模型对应的输出结果与所述真实标注标签之间的差异程度,确定第二损失值;

13、根据所述第二损失值对所述待训练的学生模型进行迭代训练,获得初次蒸馏后的学生模型;

14、根据所述第一损失值和所述第二损失值,对所述预训练的多任务教师模型进行迭代训练,获得训练好的多任务教师模型。

15、在本公开的示例性实施例中,所述根据所述第二损失值对所述待训练的学生模型进行迭代训练,获得初次蒸馏后的学生模型,包括:

16、根据所述第二损失值对所述待训练的学生模型进行迭代训练,以更新所述待训练的学生模型对应的模型参数;

17、直至所述第二损失值满足预设的第一收敛条件,获得所述初次蒸馏后的学生模型。

18、在本公开的示例性实施例中,所述根据所述第一损失值和所述第二损失值,对所述预训练的多任务教师模型进行迭代训练,获得训练好的多任务教师模型,包括:

19、根据所述第一损失值和所述第二损失值,确定整体损失值;

20、根据所述整体损失值对所述预训练的多任务教师模型进行迭代训练,以更新所述预训练的多任务教师模型对应的模型参数;

21、直至所述整体损失值满足预设的第二收敛条件,获得所述训练好的多任务教师模型。

22、在本公开的示例性实施例中,当所述待训练的学生模型的数目为n时,n为大于1的整数;

23、所述利用所述训练样本集对所述预训练的多任务教师模型和待训练的学生模型进行整体蒸馏训练,获得训练好的多任务教师模型和初次蒸馏后的学生模型,包括:

24、从每个所述子集中选取训练样本;

25、将所述训练样本中的样本输入信息输入至所述预训练的多任务教师模型中,根据所述预训练的多任务教师模型对应的输出结果与所述真实标注标签之间的差异程度,获得第一损失值;

26、从n个所述待训练的学生模型中逐次筛选目标学习模型,每次筛选出来的所述目标学生模型对应的自然语言处理任务与所述训练样本所属的子集关联的目标自然语言处理任务一致;

27、将所述训练样本中的样本输入信息输入至所述目标学生模型中,根据所述目标学生模型对应的输出结果与所述真实标注标签之间的差异程度,确定所述目标学生模型对应的第三损失值;

28、根据所述目标学生模型对应的第三损失值对所述目标学生模型进行迭代训练,获得初次蒸馏后的学生模型;

29、根据所述第一损失值和所述目标学生模型对应的第三损失值,对所述预训练的多任务教师模型进行迭代训练,获得训练好的多任务教师模型。

30、在本公开的示例性实施例中,所述目标子集中包含多个目标训练样本,每个所述目标训练样本包含目标样本输入信息和所述目标样本输入信息对应的真实标注标签;

31、所述利用与所述初次蒸馏后的学生模型对应的自然语言处理任务相匹配的目标子集和所述训练好的多任务教师模型,对所述初次蒸馏后的学生模型进行再次蒸馏训练,获得蒸馏好的学生模型,包括:

32、从所述目标子集中选取目标训练样本;

33、将所述目标训练样本中的目标样本输入信息输入所述初次蒸馏后的学生模型中,获得所述初次蒸馏后的学生模型的第一输出结果;

34、将所述目标训练样本中的目标样本输入信息输入所述训练好的多任务教师模型中,获得所述训练好的多任务教师模型的第二输出结果;

35、根据所述第一输出结果、所述第二输出结果和所述真实标注标签之间的差异程度,确定蒸馏损失值;

36、根据所述蒸馏损失值对所述初次蒸馏后的学生模型进行迭代训练,获得所述蒸馏好的学生模型。

37、在本公开的示例性实施例中,所述根据所述第一输出结果、所述第二输出结果和所述真实标注标签之间的差异程度,确定蒸馏损失值,包括:

38、根据所述第一输出结果与所述真实标注标签之间的差异程度,确定第四损失值;

39、根据所述第一输出结果与所述第二输出结果之间的差异程度,确定第五损失值;

40、根据所述第四损失值和所述第五损失值,确定所述蒸馏损失值。

41、在本公开的示例性实施例中,所述根据所述第四损失值和所述第五损失值,确定所述蒸馏损失值,包括:

42、获取所述第四损失值与预设权重系数之间的乘积值;

43、根据所述乘积值与所述第五损失值之间的累加值,确定所述蒸馏损失值。

44、在本公开的示例性实施例中,所述根据所述蒸馏损失值对所述初次蒸馏后的学生模型进行迭代训练,获得所述蒸馏好的学生模型,包括:

45、根据所述蒸馏损失值对所述初次蒸馏后的学生模型进行迭代训练,以更新所述初次蒸馏后的学生模型对应的模型参数;

46、直至所述蒸馏损失值满足预设的第三收敛条件时,获得所述蒸馏好的学生模型。

47、根据本公开的第二方面,提供一种用于自然语言处理的模型蒸馏装置,包括:

48、预训练模块,用于获取训练样本集,利用所述训练样本集对待训练的多任务教师模型进行预训练,获得预训练的多任务教师模型;所述多任务教师模型用于执行m种自然语言处理任务,所述多任务教师模型包括与所述m种自然语言处理任务分别对应的m个子模型;所述训练样本集包含m个子集,不同的所述子集用于训练不同的所述子模型,m为大于1的整数;

49、整体蒸馏训练模块,用于利用所述训练样本集对所述预训练的多任务教师模型和待训练的学生模型进行整体蒸馏训练,获得训练好的多任务教师模型和初次蒸馏后的学生模型;所述待训练的学生模型对应的自然语言处理任务为所述m种自然语言处理任务中任一种;

50、再次蒸馏训练模块,用于利用与所述初次蒸馏后的学生模型对应的自然语言处理任务相匹配的目标子集和所述训练好的多任务教师模型,对所述初次蒸馏后的学生模型进行再次蒸馏训练,获得蒸馏好的学生模型。

51、根据本公开的第三方面,提供一种计算机存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现上述第一方面所述的用于自然语言处理的模型蒸馏方法。

52、根据本公开的第四方面,提供一种电子设备,包括:处理器;以及存储器,用于存储所述处理器的可执行指令;其中,所述处理器配置为经由执行所述可执行指令来执行上述第一方面所述的用于自然语言处理的模型蒸馏方法。

53、由上述技术方案可知,本公开示例性实施例中的用于自然语言处理的模型蒸馏方法、用于自然语言处理的模型蒸馏装置、计算机存储介质及电子设备至少具备以下优点和积极效果:

54、在本公开的一些实施例所提供的技术方案中,一方面,通过获取训练样本集,利用训练样本集对待训练的用于执行m种自然语言处理任务的多任务教师模型进行预训练,获得预训练的多任务教师模型,之后,利用训练样本集对预训练的多任务教师模型和待训练的学生模型进行整体蒸馏训练,获得训练好的多任务教师模型和初次蒸馏后的学生模型,能够在多任务教师模型还未训练结束之前便加入了针对学生模型的蒸馏过程,使得学生模型在该阶段便已经学习到了通用知识,提升了学生模型与多任务教师模型之间的关联度,便于后续快速蒸馏得到学生模型。另一方面,利用与初次蒸馏后的学生模型对应的自然语言处理任务相匹配的目标子集和训练好的多任务教师模型,对初次蒸馏后的学生模型进行再次蒸馏训练,获得蒸馏好的学生模型,能够在整体蒸馏训练的基础上叠加再次蒸馏训练的蒸馏效果,使得学生模型的学习效果更好,模型准确率更高,并且,在该阶段无需像现有技术那样花费大量的时间去蒸馏原始的学生网络,从而,也能够提升模型的蒸馏效率。

55、本公开应当理解的是,以上的一般描述和后文的细节描述仅是示例性和解释性的,并不能限制本公开。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1