本发明涉及人工智能,尤其涉及一种基于知识增强的连续学习软标签构建方法。
背景技术:
1、尽管深度学习在分类和检测领域取得了显著成就,但大多数算法都是基于封闭环境中的固定类别。真实的场景是一个开放和动态的环境,总是有新的类别出现。当神经网络模型应用于实际任务时,需要在新的数据集上进行更新。如果我们直接对模型进行微调,那么前一项任务的准确性就会降低,这种情况被称为灾难性遗忘。直接联合训练将造成巨大的训练成本。持续学习就是为了解决这个问题。连续学习的目标是在不忘记旧任务的情况下学习新任务,它已被应用于许多领域。
2、目前,连续学习的方法已经取得了一些进展。但他们中的大多数人都专注于学习策略。在图像分类任务中,他们遵循多分类问题的默认配置,并使用一个基于softmax损失的热编码器。这些方法将神经网络模型输出与groundtruth的一次性编码相匹配,称为硬标签。但对于连续学习任务,多个任务按顺序出现,并且类别是逐步学习的。由于缺乏完整的先前数据,无法通过前一类和当前类之间的关联,而导致遗忘的问题。
技术实现思路
1、有鉴于此,本发明提供了一种基于知识增强的连续学习软标签构建方法,以解决现有技术中由于缺乏先前数据的完整性,而无法考虑前一类和当前类之间的关联,从而导致遗忘的技术问题。
2、本发明提供了一种基于知识增强的连续学习软标签构建方法,包括:
3、s1.随机初始化语义软标签,计算语义gram矩阵,通过所述语义gram矩阵、词向量gram矩阵和相应类别平滑后的语义软标签,获得优化后的语义软标签损失函数;
4、以及随机初始化知识蒸馏软标签,计算知识蒸馏gram矩阵,通过所述知识蒸馏gram矩阵、嵌入gram矩阵和相应类别平滑后的知识蒸馏软标签,获得优化后的知识蒸馏软标签损失函数;
5、s2.将所述优化后的语义软标签损失函数与所述优化后的知识蒸馏软标签损失函数结合,获得总损失函数;
6、s3.采用所述总损失函数进行新任务的训练。
7、进一步地,所述随机初始化语义软标签,通过所述语义gram矩阵、词向量gram矩阵和相应类别平滑后的语义软标签,获得优化后的语义软标签损失函数,包括:
8、a1.随机初始化所述语义软标签,确定不同类别的语义软标签之间的相关性,得到所述语义gram矩阵;
9、a2.采用外部词向量模型获得对应类别的词向量,确定不同类别的词向量之间的相关性,得到所述词向量gram矩阵,其中,所述外部词向量模型为clip或bert;
10、a3.计算语义gram矩阵和词向量gram矩阵之间的欧几里得距离,获得中间过程语义软标签;
11、a4.采用softmax函数对所述中间过程语义软标签归一化,获得优化后的中间过程语义软标签,对于每个类别,采用该类别优化后的中间过程语义软标签平滑相应的原始硬标签,获得该类别平滑后的语义软标签;a5.基于所有类别平滑后的语义软标签,获得所述优化后的语义软标签损失函数。
12、进一步地,语义gram矩阵和词向量gram矩阵之间的欧几里得距离的表达式如下:
13、
14、其中,表示中间过程的语义损失函数,表示语义gram矩阵,表示词向量gram矩阵。
15、进一步地,优化后的中间过程语义软标签的表达式如下:
16、
17、其中,表示所述优化后的中间过程语义软标签,k表示中间过程语义软标签,表示将中间过程语义软标签除以温度系数t,进行数学上的操作,q(x)表示硬标签。
18、进一步地,所述优化后的语义软标签损失函数的表达式如下:
19、
20、其中,表示超参数,q(x)表示硬标签,p(x)表示相应类别神经网络模型的输出,表示硬标签和相应类别神经网络模型的输出的kl散度,表示优化后的中间过程语义软标签和相应类别神经网络模型输出的kl散度。
21、进一步地,所述随机初始化知识蒸馏软标签,通过所述知识蒸馏gram矩阵、嵌入gram矩阵和相应类别平滑后的知识蒸馏软标签,获得优化后的知识蒸馏软标签损失函数,包括:
22、b1.随机初始化知识蒸馏软标签,确定不同类别的知识蒸馏软标签之间的相关性,得到所述知识蒸馏gram矩阵;
23、b2.将旧任务和新任务不同类别的聚类中心输入至旧神经网络模型,获取每个类别的嵌入特征,确定不同类别嵌入特征之间的相关性,得到所述嵌入gram矩阵;
24、b3.计算所述知识蒸馏gram矩阵和嵌入gram矩阵之间的欧几里得距离,获得中间过程知识蒸馏软标签;
25、b4.采用softmax函数对所述中间过程知识蒸馏软标签归一化,获得优化后的中间过程知识蒸馏软标签,对于每个类别,采用该类别优化后的中间过程知识蒸馏软标签平滑相应原始的硬标签,获得该类别平滑后的知识蒸馏软标签;
26、b5.基于所有类别平滑后的知识蒸馏软标签,获得优化后的知识蒸馏软标签损失函数。
27、进一步地,所述知识蒸馏gram矩阵和嵌入gram矩阵之间的欧几里得距离的表达式如下:
28、
29、其中,表示中间过程的知识蒸馏损失函数,表示知识蒸馏gram矩阵,表示嵌入gram矩阵。
30、进一步地,所述优化后的中间过程知识蒸馏软标签的表达式如下:
31、
32、其中,表示优化后的中间过程知识蒸馏软标签,表示中间过程知识蒸馏软标签,温度t被加到softmax中以缩放整体分布。
33、进一步地,优化后的知识蒸馏软标签损失函数的表达式如下:
34、
35、其中,表示超参数,表示优化后的中间过程知识蒸馏标签和相应类别神经网络模型的输出的kl散度,表示优化后的中间过程知识蒸馏软标签。
36、进一步地,所述总损失函数的表达式如下:
37、
38、其中,表示总损失函数,表示优化后的中间过程知识蒸馏标签和相应类别神经网络模型的输出的kl散度。
39、本发明与现有技术相比存在的有益效果是:
40、1、本发明的方法通过对标签进行平滑有助于提高神经网络模型对新任务中样本的泛化能力;
41、2、本发明采用知识嵌入的方法来反映类别相关性,有助于新任务的学习和对旧任务学习的关联信息;
42、3、本发明的方法获得的总损失函数,通过新任务的学习和对旧任务的学习之间的关联性,解决了由于缺乏先前数据的完整性,而无法考虑前一类和当前类之间的关系的问题,避免了灾难性遗忘。
1.一种基于知识增强的连续学习软标签构建方法,其特征在于,包括:
2.根据权利要求1所述的基于知识增强的连续学习软标签构建方法,其特征在于,所述随机初始化语义软标签,通过所述语义gram矩阵、词向量gram矩阵和相应类别平滑后的语义软标签,获得优化后的语义软标签损失函数包括:
3.根据权利要求2所述的基于知识增强的连续学习软标签构建方法,其特征在于,
4.根据权利要求2所述的基于知识增强的连续学习软标签构建方法,其特征在于,优化后的中间过程语义软标签的表达式如下:
5.根据权利要求1所述的基于知识增强的连续学习软标签构建方法,其特征在于,所述优化后的语义软标签损失函数的表达式如下:
6.根据权利要求1所述的基于知识增强的连续学习软标签构建方法,其特征在于,所述随机初始化知识蒸馏软标签,通过所述知识蒸馏gram矩阵、嵌入gram矩阵和相应类别平滑后的知识蒸馏软标签,获得优化后的知识蒸馏软标签损失函数,包括:
7.根据权利要求6所述的基于知识增强的连续学习软标签构建方法,其特征在于,所述知识蒸馏gram矩阵和嵌入gram矩阵之间的欧几里得距离的表达式如下:
8.根据权利要求5所述的基于知识增强的连续学习软标签构建方法,其特征在于,所述优化后的中间过程知识蒸馏软标签的表达式如下:
9.根据权利要求1所述的基于知识增强的连续学习软标签构建方法,其特征在于,优化后的知识蒸馏软标签损失函数的表达式如下:
10.根据权利要求1所述的基于知识增强的连续学习软标签构建方法,其特征在于,所述总损失函数的表达式如下: