本发明涉及语音分类,具体涉及语音分类模型训练方法、语音分类方法、装置及设备。
背景技术:
1、端到端深度神经网络已成为语音分类领域中的一种流行框架,与传统分类识别框架相比,它可以简化模型的构建和训练流程。在实际应用中,许多场合需要现有的语音分类模型既可以分类新场景下的语音数据,又能够保持原有场景的分类准确率。
2、相关技术中,由于新数据集的分布不一致等问题往往导致灾难性遗忘,即旧数据集的错误率急剧增大,因而多通过增量学习方法解决灾难性遗忘问题。目前的增量学习方法一般需要接触较多旧数据集或者与旧数据集一起联合训练,但在实际应用中往往受限且消耗的时间和计算成本很大。
技术实现思路
1、有鉴于此,本发明提供了语音分类模型训练方法、语音分类方法、装置及设备,以解决目前增量学习消耗时间和计算成本很大的问题。
2、第一方面,本发明提供了一种语音分类模型训练方法,方法包括:
3、获取训练数据,训练数据为新场景数据与分类错误数据;
4、获取教师模型与学生模型,教师模型为已训练好的语音分类模型,学生模型的结构与教师模型的结构相同;
5、将训练数据输入教师模型与学生模型,得到教师模型的预测结果、学生模型的预测结果、教师模型的中间层参数矩阵与学生模型的中间层参数矩阵;
6、基于教师模型的预测结果、学生模型的预测结果、教师模型的中间层参数矩阵与学生模型的中间层参数矩阵,计算得到学生模型与教师模型之间的交叉熵损失、第一蒸馏损失、梯度类激活损失与第二蒸馏损失;
7、基于交叉熵损失、第一蒸馏损失、梯度类激活损失与第二蒸馏损失,计算得到最终损失;
8、基于最终损失,对学生模型进行梯度回传,得到目标语音分类模型。
9、在本发明中,通过仅采用新场景数据及分类错误数据,对已有的语音分类模型进行微调,避免了使用较多旧数据集或者与旧数据集一起联合训练,进而降低了时间与计算成本。通过采用新场景数据及分类错误数据对语音分类模型进行训练,解决了该模型在实际应用中出现的分类错误情况,同时提升了在新场景下的分类效果。通过仅使用新场景数据及分类错误数据进行训练,克服了混合新旧场景数据集重新训练方法中遇到的训练数据不平衡的问题与数据量大导致的训练的速度较慢问题。通过联合交叉熵损失和蒸馏损失,在提升对新数据的分类效果的同时,保证模型在原场景下的效果,解决迁移学习及fine-tuning微调训练导致的灾难性遗忘问题。
10、在一种可选的实施方式中,基于教师模型的预测结果与学生模型的预测结果,计算得到学生模型与教师模型之间的第一蒸馏损失,包括:
11、对学生模型的预测结果进行激活,得到第一结果矩阵;
12、对教师模型的预测结果进行归一化处理,得到第二结果矩阵;
13、基于第一结果矩阵与第二结果矩阵,计算得到学生模型与教师模型之间的第一蒸馏损失。
14、在该方式中,通过计算得到学生模型与教师模型之间预测结果的蒸馏损失,对学生模型的预测结果进行限制,限制学生模型预测结果的变化。
15、在一种可选的实施方式中,基于教师模型的中间层参数矩阵与学生模型的中间层参数矩阵,计算得到学生模型与教师模型之间的梯度类激活损失,包括:
16、基于教师模型的中间层参数矩阵与学生模型的中间层参数矩阵,获取教师模型的编码器注意力矩阵与学生模型的编码器注意力矩阵;
17、基于教师模型的编码器注意力矩阵与学生模型的编码器注意力矩阵,计算得到教师模型的梯度热力图与学生模型的梯度热力图;
18、计算教师模型的梯度热力图与学生模型的梯度热力图之间的差值,得到学生模型与教师模型之间的梯度类激活损失。
19、在该方式中,通过使用编码器的注意力矩阵作为权重,计算加权后的梯度作为梯度热力图,计算得到教师模型与学生模型之间的热力图损失,通过教师模型与学生模型之间的热力图损失对学生模型进行训练,从而限制学生模型梯度的变化,减小学生模型参数的变化。
20、在一种可选的实施方式中,基于教师模型的中间层参数矩阵与学生模型的中间层参数矩阵,计算得到学生模型与教师模型之间的第二蒸馏损失,包括:
21、对教师模型的中间层参数矩阵与学生模型的中间层参数矩阵进行主成分提取,得到教师模型各中间层的重要特征与学生模型各中间层的重要特征;
22、计算教师模型各中间层的重要特征与学生模型各中间层的重要特征之间的最短变化路径;
23、计算最短变化路径均值与教师模型各中间层的重要特征之商,得到学生模型与教师模型之间的第二蒸馏损失。
24、在该方式中,通过计算得到学生模型与教师模型之间重要特征变化路径的蒸馏损失,对学生模型的参数变化进行限制。
25、在一种可选的实施方式中,基于交叉熵损失、第一蒸馏损失、梯度类激活损失与第二蒸馏损失,计算得到最终损失,包括:
26、通过公式
27、loss=ce_weight*ce_loss+(alpha*loss1+beta*loss2+gamma*loss3)*(1-ce_weight)计算得到最终损失;
28、其中,ce_weight为交叉熵损失对应权重,ce_loss为交叉熵损失,alpha为第一蒸馏损失对应权重,loss1为第一蒸馏损失,beta为梯度类激活损失对应权重,loss2为梯度类激活损失,gamma为第二蒸馏损失对应权重,loss3为第二蒸馏损失。
29、在该方式中,通过利用交叉熵损失优化语音对新数据的分类效果,同时利用蒸馏损失来降低模型在增量学习中的遗忘。其中蒸馏损失包括:同时使用基于预测结果的第一蒸馏损失、梯度加权类激活映射方法得到的梯度类激活损失和基于重要特征变化路径的第二蒸馏损失,分别限制模型的最终预测结果、模型训练时的梯度改变和模型参数的变化。
30、第二方面,本发明提供了一种语音分类方法,方法包括:
31、获取待分类语音数据;
32、将待分类语音数据输入语音分类模型中,得到待分类语音数据的语音分类结果,其中语音分类模型是利用第一方面任意一项的语音分类模型训练方法训练得到的。
33、在本发明中,通过利用新场景数据及分类错误数据训练得到的目标语音分类模型,可以得到更为准确的语音分类结果,更适配在新场景和分类错误情况下进行语音分类。
34、第三方面,本发明提供了一种语音分类模型训练装置,装置包括:
35、数据获取模块,用于获取训练数据,训练数据为新场景数据与分类错误数据;
36、模型获取模块,用于获取教师模型与学生模型,教师模型为已训练好的语音分类模型,学生模型的结构与教师模型的结构相同;
37、数据输入模块,用于将训练数据输入教师模型与学生模型,得到教师模型的预测结果、学生模型的预测结果、教师模型的中间层参数矩阵与学生模型的中间层参数矩阵;
38、损失计算模块,用于基于教师模型的预测结果、学生模型的预测结果、教师模型的中间层参数矩阵与学生模型的中间层参数矩阵,计算得到学生模型与教师模型之间的交叉熵损失、第一蒸馏损失、梯度类激活损失与第二蒸馏损失;
39、最终损失计算模块,用于基于交叉熵损失、第一蒸馏损失、梯度类激活损失与第二蒸馏损失,计算得到最终损失;
40、梯度回传模块,用于基于最终损失,对学生模型进行梯度回传,得到目标语音分类模型。
41、第四方面,本发明提供了一种语音分类装置,装置包括:
42、语音数据获取模块,用于获取待分类语音数据;
43、语音分类模块,用于将待分类语音数据输入语音分类模型中,分类得到待分类语音数据的语音分类结果,其中语音分类模型是利用第三方面任意一项的语音分类模型训练装置训练得到的。
44、第五方面,本发明提供了一种计算机设备,包括:存储器和处理器,存储器和处理器之间互相通信连接,存储器中存储有计算机指令,处理器通过执行计算机指令,从而执行上述第一方面及其对应的任一实施方式的语音分类模型训练方法或者执行第二方面中的语音分类方法。
45、第六方面,本发明提供了一种计算机可读存储介质,该计算机可读存储介质上存储有计算机指令,计算机指令用于使计算机执行上述第一方面及其对应的任一实施方式的语音分类模型训练方法或者执行第二方面中的语音分类方法。