模型训练方法、装置、电子设备、存储介质及程序产品与流程

文档序号:37803268发布日期:2024-04-30 17:13阅读:7来源:国知局
模型训练方法、装置、电子设备、存储介质及程序产品与流程

本技术涉及计算机,尤其涉及一种基于人工智能的模型训练方法、装置、电子设备、计算机可读存储介质及计算机程序产品。


背景技术:

1、在对业务数据进行分类处理时,常用分类模型实现分类处理,即通过分类模型的特征处理网络提取业务数据的数据特征,通过分类模型的分类网络对数据特征进行分类。在使用源域数据对分类模型进行训练后,分类模型能够在源域样本上进行精确的分类处理。随着时间的推移,业务数据不断增加,其中会出现新增的不同类别的目标域数据,此时需要充分利用这部分无标签的目标域数据对分类模型进行不断更新,使分类模型能够不断适应新增的目标域数据,以保证分类模型在目标域数据上的分类结果依然准确。

2、然而,由于目标域数据和源域数据的数据集分布存在显著差异,导致使用源域数据训练的分类模型在目标域数据的泛化能力不足,主要是因为分类模型的特征处理网络无法对目标域数据进行精确的特征提取处理。已有技术中使用最小化最大平均差异准则或对抗性训练,来缩小目标域数据和源域数据之间的差异,但已有技术都是从数据样本角度和目标业务和原业务之间的差异角度进行设计,未考虑样本在由源域与目标域组成的整体数据集中的特征信息,导致分类模型的特征处理网络在目标域数据上的泛化能力弱。


技术实现思路

1、本技术实施例提供一种基于人工智能的模型训练方法、装置、电子设备、计算机可读存储介质及计算机程序产品,能够提高分类模型的特征处理网络在目标域数据上的泛化能力,并提高对应分类服务的计算资源利用率。

2、本技术实施例的技术方案是这样实现的:

3、本技术实施例提供一种基于人工智能的模型训练方法,所述方法包括:

4、获取样本集合,其中,所述样本集合包括源域样本以及目标域样本,所述源域样本与所述目标域样本来源于不同业务领域;

5、通过特征处理网络对所述样本集合中每个样本进行样本特征提取处理,得到每个所述样本的样本特征,并通过所述特征处理网络对所述样本集合中每个所述样本的样本特征进行图特征提取处理,得到每个所述样本的图结构特征;

6、基于每个所述样本的图结构特征,确定每个样本对的样本间转移概率,并基于每个所述目标域样本的图结构特征对每个所述目标域样本进行分类处理,得到每个所述目标域样本的预测标签;

7、基于每个所述源域样本的真实标签以及每个所述目标域样本的预测标签,确定每个所述样本对的标签损失;

8、基于每个所述样本对的样本间转移概率以及每个所述样本对的标签损失,确定样本差异损失,并基于所述样本差异损失对所述特征处理网络进行更新,得到经过更新的特征处理网络。

9、本技术实施例提供一种基于人工智能的模型训练方法,所述方法包括:

10、获取待分类数据;

11、通过特征处理网络对所述待分类数据进行特征提取处理,得到所述待分类数据的样本特征,并通过所述特征处理网络对所述待分类数据的样本特征进行图特征提取处理,得到所述待分类数据的图结构特征;

12、其中,所述特征处理网络是通过本技术实施例提供的基于人工智能的模型训练方法得到的;

13、基于所述待分类数据的图结构特征,对所述待分类数据进行分类处理,得到所述待分类数据的预测标签。

14、本技术实施例提供一种基于人工智能的模型训练装置,包括:

15、获取模块,用于获取样本集合,其中,所述样本集合包括源域样本以及目标域样本,所述源域样本与所述目标域样本来源于不同业务领域;

16、特征处理模块,用于通过特征处理网络对所述样本集合中每个样本进行样本特征提取处理,得到每个所述样本的样本特征,并通过所述特征处理网络对所述样本集合中每个所述样本的样本特征进行图特征提取处理,得到每个所述样本的图结构特征;

17、分类处理模块,用于基于每个所述样本的图结构特征,确定每个样本对的样本间转移概率,并基于每个所述目标域样本的图结构特征对每个所述目标域样本进行分类处理,得到每个所述目标域样本的预测标签;

18、损失处理模块,用于基于每个所述源域样本的真实标签以及每个所述目标域样本的预测标签,确定每个所述样本对的标签损失;

19、更新处理模块,用于基于每个所述样本对的样本间转移概率以及每个所述样本对的标签损失,确定样本差异损失,并基于所述样本差异损失对所述特征处理网络进行更新,得到经过更新的特征处理网络。

20、在上述方案中,更新处理模块,用于基于初始化的样本特征网络获取预训练三联体损失,并基于所述预训练三联体损失对所述初始化的样本特征网络进行更新,得到经过一次更新的样本特征网络;基于初始化的源域分类网络获取预训练分类器损失,并基于所述预训练分类器损失对所述经过一次更新的样本特征网络以及初始化的图特征网络进行更新,得到经过二次更新的样本特征网络以及经过一次更新的图特征网络;基于初始化的判别网络获取预训练判别器损失,并基于所述预训练判别器损失对所述经过二次更新的样本特征网络以及所述经过一次更新的图特征网络进行更新,得到用于构成所述特征处理网络的样本特征网络以及用于构成所述特征处理网络的图特征网络;基于所述预训练分类器损失对所述初始化的源域分类网络进行更新,得到源域分类网络,其中,所述源域分类网络用于基于每个所述目标域样本的图结构特征对每个所述目标域样本进行分类处理;基于所述预训练判别器损失对所述初始化的判别网络进行更新,得到判别网络,其中,所述判别网络用于对所述目标域样本以及所述源域样本进行域判别处理。

21、在上述方案中,更新处理模块,用于通过初始化的样本特征网络对每个所述源域样本进行样本特征提取处理,得到每个所述源域样本的预训练样本特征;针对每个所述源域样本执行以下处理:获取与所述源域样本具有相同真实标签的正面源域样本以及与所述源域样本具有不同真实标签的负面源域样本;获取所述源域样本的预训练样本特征与所述负面源域样本的预训练样本特征之间的第一特征距离,以及所述源域样本的预训练样本特征与所述正面源域样本的预训练样本特征之间的第二特征距离;获取与所述第一特征距离负相关,且与所述第二特征距离正相关的预训练三联体损失。

22、在上述方案中,更新处理模块,用于通过初始化的图特征网络对所述源域样本的预训练样本特征进行图特征提取处理,得到所述源域样本的预训练图结构特征;通过初始化的源域分类网络对所述源域样本执行基于所述预训练图结构特征的分类处理,得到所述源域样本的预测标签;基于所述源域样本的预测标签与所述源域样本的真实标签之间的差异,确定所述预训练分类器损失。

23、在上述方案中,更新处理模块,用于通过所述初始化的判别网络对所述源域样本进行基于所述源域样本的预训练图结构特征的域判别处理,得到所述源域样本的预训练域判别结果;通过所述初始化的判别网络对所述目标域样本进行基于所述目标域样本的预训练图结构特征的域判别处理,得到所述目标域样本的预训练域判别结果;对所述源域样本的预训练域判别结果以及所述目标域样本的预训练域判别结果进行融合处理,得到所述预训练判别器损失。

24、在上述方案中,特征处理模块,用于针对每个所述源域样本,结合所述样本集合中每个所述源域样本的样本特征,对所述源域样本的样本特征进行图特征提取处理,得到所述源域样本的图结构特征;针对每个所述目标域样本,结合所述样本集合中每个所述目标域样本的样本特征,对所述目标域样本的样本特征进行图特征提取处理,得到所述目标域样本的图结构特征。

25、在上述方案中,特征处理模块,用于获取所述源域样本的样本特征与所述样本集合中每个所述源域样本的样本特征之间的第三特征距离;以所述样本集合中每个所述源域样本的样本特征为权重,将对应所述样本集合中每个所述源域样本的第三特征距离进行融合处理,得到对应所述源域样本的图结构特征。

26、在上述方案中,分类处理模块,用于对所述样本集合进行两两遍历处理,得到多个样本对;获取每个所述样本对对应的两个图结构特征之间的第四特征距离;将多个所述样本对的第四特征距离进行融合处理,得到第一融合结果;获取与每个所述样本对的第四特征距离正相关,且与所述第一融合结果负相关的数值作为每个所述样本对的样本间转移概率。

27、在上述方案中,损失处理模块,用于基于每个所述源域样本的真实标签以及每个所述目标域样本的预测标签,确定对应每个类别的期望;对所述样本集合进行两两遍历处理,得到多个样本对,其中,所述样本对包括第一样本以及第二样本;针对每个所述样本对执行以下处理:从所述第一样本的真实标签或者预测标签中提取所述第一样本属于每个类别的第一概率,并从所述第二样本的真实标签或者预测标签中提取所述第二样本属于每个类别的第二概率;针对所述类别,获取与对应所述类别的第一概率以及第二概率正相关,且与所述类别的期望负相关的类别损失;对多个所述类别分别对应的类别损失进行融合处理,得到所述样本对的标签损失。

28、在上述方案中,损失处理模块,用于针对每个所述类别执行以下处理:从每个所述源域样本的真实标签中提取每个所述源域样本属于所述类别的第三概率,并从每个所述目标域样本的预测标签中提取每个所述目标域样本属于所述类别的第四概率;对多个所述第三概率以及多个所述第四概率进行融合处理,得到融合概率,并将所述融合概率与所述类别对应的标签值进行相乘,得到对应所述类别的期望。

29、在上述方案中,更新处理模块,用于针对每个所述样本对执行以下处理:对所述样本对的样本间转移概率以及所述样本对的标签损失进行相乘处理,得到所述样本对的子差异损失;对多个所述样本对的子差异损失进行融合处理,得到第二融合结果;获取与所述第二融合结果负相关的样本差异损失。

30、在上述方案中,更新处理模块,用于获取训练三联体损失;基于源域分类网络获取训练分类器损失;基于判别网络获取训练判别器损失;对所述训练三联体损失、所述训练分类器损失、所述训练判别器损失以及所述样本差异损失进行融合处理,得到综合损失;基于所述综合损失对所述特征处理网络进行更新,得到经过更新的特征处理网络;基于所述综合损失对所述判别网络、所述源域分类网络进行更新,得到更新后的判别网络以及更新后的源域分类网络。

31、本技术实施例提供一种基于人工智能的数据处理装置,包括:

32、获取模块,用于获取待分类数据;

33、特征处理模块,用于通过特征处理网络对所述待分类数据进行特征提取处理,得到所述待分类数据的样本特征,并通过所述特征处理网络对所述待分类数据的样本特征进行图特征提取处理,得到所述待分类数据的图结构特征;

34、其中,所述特征处理网络是通过本技术实施例提供的基于人工智能的模型训练方法得到的;

35、分类处理模块,用于基于所述待分类数据的图结构特征,对所述待分类数据进行分类处理,得到所述待分类数据的预测标签。

36、本技术实施例提供一种电子设备,所述电子设备包括:

37、存储器,用于存储计算机可执行指令;

38、处理器,用于执行所述存储器中存储的计算机可执行指令时,实现本技术实施例提供的基于人工智能的模型训练方法。

39、本技术实施例提供一种计算机可读存储介质,存储有计算机可执行指令,用于被处理器执行时实现本技术实施例提供的基于人工智能的模型训练方法。

40、本技术实施例提供一种计算机程序产品,包括计算机可执行指令,所述计算机可执行指令被处理器执行时,实现本技术实施例提供的基于人工智能的模型训练方法。

41、本技术实施例具有以下有益效果:

42、获取来源于不同业务领域的源域样本以及目标域样本,源域样本与目标域样本,通过特征处理网络对每个样本进行样本特征提取处理,得到每个样本的样本特征,以对每个样本的个体特征信息进行表征,并通过特征处理网络对样本集合中每个样本的样本特征进行图特征提取处理,得到每个样本的图结构特征,以对每个样本与相同域中的其他样本之间的关联特征信息进行表征,基于每个样本的图结构特征,确定每个样本对的样本间转移概率,以确定样本对中的两个样本之间的相似程度,并基于每个目标域样本的图结构特征对每个目标域样本进行分类处理,得到每个目标域样本的预测标签,以确定目标域样本的分类结果,基于每个源域样本的真实标签以及每个目标域样本的预测标签,确定每个样本对的标签损失,以表征样本对中两个样本之间的分类结果差异,基于每个样本对的样本间转移概率以及每个样本对的标签损失,确定样本差异损失,将样本之间的特征差异融合为样本差异损失,并基于样本差异损失对特征处理网络进行更新,得到经过更新的特征处理网络,使特征处理网络学习到相同域的样本之间以及不同域的样本之间的特征差异与特征共性,使特征处理网络能够对目标域样本进行精确的特征提取处理,从而提高分类模型的特征处理网络在目标域上的泛化能力,并提高对应分类服务的计算资源利用率。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1