基于知识蒸馏的模型训练方法、装置、设备及介质与流程

文档序号:36650489发布日期:2024-01-06 23:34阅读:29来源:国知局
基于知识蒸馏的模型训练方法、装置、设备及介质与流程

本技术涉及人工智能,特别涉及一种基于知识蒸馏的模型训练方法、装置、设备及介质。


背景技术:

1、知识蒸馏(knowledge distillation,kd)作为一种模型压缩技术,其核心思想是将大型神经网络(也被称为通用模型或教师模型)学习到的知识迁移到小型神经网络(也被称为垂直模型或学生模型)。换言之,知识蒸馏的目标是将通用模型学习到的知识迁移到垂直模型上,以使垂直模型在保持较高性能的同时具有较低的计算复杂性。

2、目前知识蒸馏技术已应用于众多领域,比如自然语言处理领域等。而无论应用于何种领域,如何实现更高效的知识迁移和更优的垂直模型性能一直是本领域的一个研究热点。即,如何通过一种新的模型训练方法来提高垂直模型的训练效率和性能是本领域的一个关注焦点。


技术实现思路

1、本技术实施例提供了一种基于知识蒸馏的模型训练方法、装置、设备及介质,提高了学生模型的训练效率和性能。所述技术方案如下所示。

2、一方面,提供了一种基于知识蒸馏的模型训练方法,所述方法包括如下步骤。

3、获取第一训练数据集和第二训练数据集;

4、基于所述第一训练数据集训练第一深度学习模型,得到第一教师模型;

5、在学生模型的训练过程中,根据第一类指标获取知识蒸馏强度;其中,所述知识蒸馏强度用于反映知识蒸馏过程中传递知识的程度;所述第一类指标包括所述学生模型的训练状态和模型性能、知识蒸馏过程中的温度参数以及模型训练参数中的至少一种;

6、在所述知识蒸馏强度的约束下,基于所述第二训练数据集和所述第一教师模型的输出概率分布,训练用于执行目标深度学习任务的第二深度学习模型,得到所述学生模型。

7、在一种可能的实现方式中,所述第一教师模型用于执行多种深度学习任务;所述第二训练数据集中的训练数据来源于所述第一训练数据集;所述第一训练数据集中的训练数据未被标注;所述第二训练数据集中的训练数据已被标注;或,

8、所述第一训练数据集和所述第二训练数据集为与所述目标深度学习任务匹配的同一数据集。

9、在一种可能的实现方式中,所述方法还包括:

10、在所述学生模型的训练过程中,根据第二类指标获取知识蒸馏率;其中,所述知识蒸馏率用于控制知识蒸馏的速度;所述第二类指标包括所述学生模型的训练进度、模型性能和模型训练参数中的至少一种;

11、所述在所述知识蒸馏强度的约束下,基于所述第二训练数据集和所述第一教师模型的输出概率分布,训练用于执行目标深度学习任务的第二深度学习模型,得到所述学生模型,包括:

12、在所述知识蒸馏强度和所述知识蒸馏率的约束下,基于所述第二训练数据集和所述第一教师模型的输出概率分布,训练所述第二深度学习模型,得到所述学生模型。

13、在一种可能的实现方式中,所述基于所述第二训练数据集和所述第一教师模型的输出概率分布,训练用于执行目标深度学习任务的第二深度学习模型,得到所述学生模型,包括:

14、基于所述第一教师模型的输出概率分布和所述第二深度学习模型预测的输出概率,构建第一损失函数;

15、基于所述第二训练数据集和所述第二深度学习模型预测的输出概率,构建第二损失函数;

16、获取所述第一损失函数的第一权重和所述第二损失函数的第二权重;

17、基于所述第一权重和所述第二权重,对所述第一损失函数和所述第二损失函数进行加权,得到目标损失函数;

18、通过所述目标损失函数迭代获取损失值,直至满足训练停止条件,得到所述学生模型。

19、在一种可能的实现方式中,所述获取所述第一损失函数的第一权重和所述第二损失函数的第二权重,包括:

20、周期性获取所述学生模型在指定数据集上的性能变化;根据所述性能变化,确定所述第一权重和所述第二权重;其中,所述指定数据集为验证数据集或测试数据集;或,

21、基于所述模型训练参数,确定所述第一权重和所述第二权重;或,

22、在模型训练过程中配置权重参数;基于所述权重参数,在模型训练过程中确定所述第一权重和所述第二权重。

23、在一种可能的实现方式中,所述训练状态包括所述学生模型的训练进度和损失变化情况;所述根据第一类指标获取知识蒸馏过程中的知识蒸馏强度,包括:

24、根据所述第一类指标中的每一项分别获取知识蒸馏强度,得到与所述第一类指标中包含的指标项数匹配的多个知识蒸馏强度;

25、获取所述第一类指标中的每一项对知识蒸馏强度的影响权重;

26、根据所述第一类指标中每一项对应的知识蒸馏强度和影响权重,确定当前的知识蒸馏强度。

27、在一种可能的实现方式中,所述在所述知识蒸馏强度的约束下,基于所述第二训练数据集和所述第一教师模型的输出概率分布,训练用于执行目标深度学习任务的第二深度学习模型,得到所述学生模型,包括:

28、获取所述第一教师模型的中间层特征表示和所述第二深度学习模型的中间层特征表示;

29、在所述知识蒸馏强度的约束下,基于所述第二训练数据集、所述第一教师模型的输出概率分布和中间层特征表示、所述第二深度学习模型的中间层特征表示,训练所述第二深度学习模型,得到所述学生模型。

30、在一种可能的实现方式中,所述方法还包括:

31、构建元学习任务,所述元学习任务包括支持集和查询集;其中,所述支持集包括多个子任务,每个子任务配置有不同的第一类指标;所述查询集包括用于训练所述子任务的任务样本;

32、基于所述支持集训练所述元学习模型以及基于所述查询集进行模型性能评估,得到训练好的元学习模型;

33、在所述学生模型的训练过程中,调用训练好的元学习模型,基于所述目标深度学习任务的任务特点或所述第二训练数据集的数据分布,获取知识蒸馏过程中的知识蒸馏强度。

34、在一种可能的实现方式中,所述方法还包括:

35、获取多个训练数据集;

36、基于所述多个训练数据集训练多个深度学习模型,得到多个第二教师模型;

37、所述基于所述第二训练数据集和所述第一教师模型的输出概率分布,训练用于执行目标深度学习任务的第二深度学习模型,得到所述学生模型,包括:

38、根据所述学生模型的训练进度和模型性能,获取所述第一教师模型和所述多个第二教师模型中每个模型的权重;

39、基于获取到的每个模型的权重,对所述第一教师模型和所述多个第二教师模型的输出概率进行加权,得到融合后的输出概率分布;

40、基于所述第二训练数据集和所述融合后的输出概率分布,训练所述第二深度学习模型,得到所述学生模型。

41、在一种可能的实现方式中,所述方法还包括:

42、在新增训练数据的情况下,基于新增的训练数据对所述第一教师模型进行模型微调,得到更新后的第一教师模型;

43、所述基于所述第二训练数据集和所述第一教师模型的输出概率分布,训练用于执行目标深度学习任务的第二深度学习模型,得到所述学生模型,包括:

44、基于所述第二训练数据集和所述更新后的第一教师模型的输出概率分布,训练所述第二深度学习模型,得到所述学生模型。

45、在一种可能的实现方式中,所述方法还包括:

46、基于新增的训练数据,对训练好的学生模型进行模型微调,得到更新后的学生模型。

47、在一种可能的实现方式中,所述方法还包括:

48、接收与所述目标深度学习任务匹配的输入数据;其中,所述输入数据为文本、图像、音频和视频中的至少一种;

49、调用所述学生模型根据所述输入数据执行所述目标深度学习任务。

50、另一方面,提供了一种基于知识蒸馏的模型训练装置,所述装置包括如下模块。

51、第一获取模块,被配置为获取第一训练数据集和第二训练数据集;

52、第一训练模块,被配置为基于所述第一训练数据集训练第一深度学习模型,得到第一教师模型;

53、第二获取模块,被配置为在学生模型的训练过程中,根据第一类指标获取知识蒸馏强度;其中,所述知识蒸馏强度用于反映知识蒸馏过程中传递知识的程度;所述第一类指标包括所述学生模型的训练状态和模型性能、知识蒸馏过程中的温度参数以及模型训练参数中的至少一种;

54、第二训练模块,被配置为在所述知识蒸馏强度的约束下,基于所述第二训练数据集和所述第一教师模型的输出概率分布,训练用于执行目标深度学习任务的第二深度学习模型,得到所述学生模型。

55、在一种可能的实现方式中,所述第一教师模型用于执行多种深度学习任务;所述第二训练数据集中的训练数据来源于所述第一训练数据集;所述第一训练数据集中的训练数据未被标注;所述第二训练数据集中的训练数据已被标注;或,所述第一训练数据集和所述第二训练数据集为与所述目标深度学习任务匹配的同一数据集。

56、在一种可能的实现方式中,所述第一获取模块,还被配置为在所述学生模型的训练过程中,根据第二类指标获取知识蒸馏率;其中,所述知识蒸馏率用于控制知识蒸馏的速度;所述第二类指标包括所述学生模型的训练进度、模型性能和模型训练参数中的至少一种;

57、所述第二训练模块,被配置为在所述知识蒸馏强度和所述知识蒸馏率的约束下,基于所述第二训练数据集和所述第一教师模型的输出概率分布,训练所述第二深度学习模型,得到所述学生模型。

58、在一种可能的实现方式中,所述第二训练模块,被配置为:

59、基于所述第一教师模型的输出概率分布和所述第二深度学习模型预测的输出概率,构建第一损失函数;

60、基于所述第二训练数据集和所述第二深度学习模型预测的输出概率,构建第二损失函数;

61、获取所述第一损失函数的第一权重和所述第二损失函数的第二权重;

62、基于所述第一权重和所述第二权重,对所述第一损失函数和所述第二损失函数进行加权,得到目标损失函数;

63、通过所述目标损失函数迭代获取损失值,直至满足训练停止条件,得到所述学生模型。

64、在一种可能的实现方式中,所述第二训练模块,被配置为:

65、周期性获取所述学生模型在指定数据集上的性能变化;根据所述性能变化,确定所述第一权重和所述第二权重;其中,所述指定数据集为验证数据集或测试数据集;或,

66、基于所述模型训练参数,确定所述第一权重和所述第二权重;或,

67、在模型训练过程中配置权重参数;基于所述权重参数,在模型训练过程中确定所述第一权重和所述第二权重。

68、在一种可能的实现方式中,所述训练状态包括所述学生模型的训练进度和损失变化情况;所述第二获取模块,被配置为:

69、根据所述第一类指标中的每一项分别获取知识蒸馏强度,得到与所述第一类指标中包含的指标项数匹配的多个知识蒸馏强度;

70、获取所述第一类指标中的每一项对知识蒸馏强度的影响权重;

71、根据所述第一类指标中每一项对应的知识蒸馏强度和影响权重,确定当前的知识蒸馏强度。

72、在一种可能的实现方式中,所述第二训练模块,被配置为:

73、获取所述第一教师模型的中间层特征表示和所述第二深度学习模型的中间层特征表示;

74、在所述知识蒸馏强度的约束下,基于所述第二训练数据集、所述第一教师模型的输出概率分布和中间层特征表示、所述第二深度学习模型的中间层特征表示,训练所述第二深度学习模型,得到所述学生模型。

75、在一种可能的实现方式中,所述装置还包括:第三获取模块;所述第三获取模块,被配置为:

76、构建元学习任务,所述元学习任务包括支持集和查询集;其中,所述支持集包括多个子任务,每个子任务配置有不同的第一类指标;所述查询集包括用于训练所述子任务的任务样本;

77、基于所述支持集训练所述元学习模型以及基于所述查询集进行模型性能评估,得到训练好的元学习模型;

78、在所述学生模型的训练过程中,调用训练好的元学习模型,基于所述目标深度学习任务的任务特点或所述第二训练数据集的数据分布,获取知识蒸馏过程中的知识蒸馏强度。

79、在一种可能的实现方式中,所述第一获取模块,还被配置为获取多个训练数据集;

80、所述第一训练模块,还被配置为基于所述多个训练数据集训练多个深度学习模型,得到多个第二教师模型;

81、所述第二训练模块,被配置为:

82、根据所述学生模型的训练进度和模型性能,获取所述第一教师模型和所述多个第二教师模型中每个模型的权重;

83、基于获取到的每个模型的权重,对所述第一教师模型和所述多个第二教师模型的输出概率进行加权,得到融合后的输出概率分布;

84、基于所述第二训练数据集和所述融合后的输出概率分布,训练所述第二深度学习模型,得到所述学生模型。

85、在一种可能的实现方式中,所述第一训练模块,还被配置为在新增训练数据的情况下,基于新增的训练数据对所述第一教师模型进行模型微调,得到更新后的第一教师模型;

86、所述第二训练模块,还被配置为基于所述第二训练数据集和所述更新后的第一教师模型的输出概率分布,训练所述第二深度学习模型,得到所述学生模型。

87、在一种可能的实现方式中,所述第二训练模块,还被配置为基于新增的训练数据对所述学生模型进行模型微调,得到更新后的学生模型。

88、在一种可能的实现方式中,所述装置还包括:处理模块;

89、所述处理模块,被配置为接收与所述目标深度学习任务匹配的输入数据;其中,所述输入数据为文本、图像、音频和视频中的至少一种;调用所述学生模型根据所述输入数据执行所述目标深度学习任务。

90、另一方面,提供了一种计算机设备,所述设备包括处理器和存储器,所述存储器中存储有至少一条程序代码,所述至少一条程序代码由所述处理器加载并执行以实现上述的基于知识蒸馏的模型训练方法。

91、另一方面,提供了一种计算机可读存储介质,所述存储介质中存储有至少一条程序代码,所述至少一条程序代码由处理器加载并执行以实现上述的基于知识蒸馏的模型训练方法。

92、另一方面,提供了一种计算机程序产品或计算机程序,该计算机程序产品或计算机程序包括计算机程序代码,该计算机程序代码存储在计算机可读存储介质中,计算机设备的处理器从计算机可读存储介质读取该计算机程序代码,处理器执行该计算机程序代码,使得该计算机设备执行上述的基于知识蒸馏的模型训练方法。

93、本技术实施例提供的基于知识蒸馏的模型训练方案,在学生模型的训练过程中能够自动调整知识蒸馏强度,这使得学生模型能够在知识蒸馏过程的不同阶段获得适当的监督信号,从而能够提高学生模型的训练效率和模型性能。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1