一种基于图注意力机制的多标记分类方法及系统

文档序号:37770734发布日期:2024-04-25 10:57阅读:9来源:国知局
一种基于图注意力机制的多标记分类方法及系统

本发明属于机器学习和深度学习,特别涉及一种基于图注意力机制的多标记分类方法及系统。


背景技术:

1、近些年来,多标记学习(multi-label learning)在生物医学、图像分类、机器学习等人工智能邻域引起了众多学者高度的关注。传统的单标记学习(single-labellearning)是指一个对象和一个标记相关联,而多标记学习是指一个对象可以和多个标记进行关联。对于多标记任务来说,由于多标记的对象具有多语义的特点,传统的单标记学习并不适合多标记学习任务,因此多标记学习算法引起了众多学者的广泛关注。

2、传统多标记算法分为两大类:首先是"问题转换"方法,将多标记问题转为单标记问题。例如,binary relevance(br)转化为二分类,calibrated label ranking转化为标记排序,random k-labelsets转化为多类别分类。其次是"算法适应"方法,改进单标记算法以适应多标记学习,如ml-knn改进knn,rank-svm利用支持向量机核学习,leda改造贝叶斯网络。然而,这些方法未充分考虑不同标记之间的差异和判别特征的不同,缺乏标记特定特征学习(label-specific feature learning)。标记特定特征学习致力于捕捉每个类标记的最相关和最具判别性的特征,因此备受研究者关注。这一方法有望更好地建模不同标记之间的关系,提高多标记学习的性能。

3、深度学习在传统分类任务中取得了巨大成功,能够通过构建深层潜在空间,来更好的捕捉特征与标记依赖关系。近年来,深度学习广泛应用于多标记分类,研究者利用其捕获标记和特征的关联,还研究如何捕获标记相关性。例如,使用图神经网络(gnn)嵌入标记,将标记语义和相关性融入潜在空间,指导多标记分类。但传统gnn通常只考虑0或1的连接关系,未充分利用标记相关性强弱。实际中,标记关联更复杂,如物理与数学成绩相关性高于物理与化学成绩。此外,标记在嵌入网络结构可能有语义损失,但多数网络忽略此问题,邻接特征信息的损失可能会对多标记分类网络做出错误的指导,从而降低网络的正确率。

4、另外,在以往多标记分类研究中通常假设标记权重是均匀分布的,但实际生活中每个对象对应的标记权重可能是不同的。例如,在图像标注中,不同标记在整幅图像中的占据比重可能不均衡。


技术实现思路

1、基于此,本发明实施例当中提供了一种基于图注意力机制的多标记分类方法及系统,以解决现有技术中,标记嵌入网络结构中的语义信息损失导致的网络正确率低的问题。

2、本发明实施例的第一方面提供了一种基于图注意力机制的多标记分类方法,所述方法包括:

3、将训练集中归一化处理后的多标记分类数据的特征以及标记图数据作为神经网络模型的输入,多标记分类数据的标记评分作为输出,以训练神经网络模型,得到目标神经网络模型,具体包括:获取多标记分类数据,并按照预设比例,将所述多标记分类数据划分为训练集和测试集,去除训练集和测试集中的异常数据,提取训练集和测试集中剩余的多标记分类数据的特征,并进行归一化处理;

4、提取训练集和测试集中剩余的多标记分类数据的标记,并进行编码,生成每个图数据中的节点特征,并确定两两节点特征的标记相关性,以构建由节点特征和标记相关性组成的标记图数据;

5、利用多层图注意力层及标记图数据来构建标记嵌入模块,并在所述标记嵌入模块的每一层均聚合所述节点特征的邻接特征信息;

6、通过计算每个标记和其它标记之间的平均条件概率,得到每个标记的重要度评分,并根据每个标记的重要度评分,生成每个标记的标记权重,其中,每个标记的重要度评分的计算公式为:

7、;

8、li和lj表示为两个不同标记的标记向量,表示为出现lj条件下出现li的概率,m表示为标签个数;

9、生成每个标记的标记权重的计算公式为:

10、;

11、lwi表示为标记权重;

12、提取训练集和测试集中剩余的多标记分类数据的多标记数据对象,并将多标记数据对象转化为多标记图数据对象,其中,利用k-means聚类将多标记数据对象进行划分,使得属于相同类别多标记数据对象之间产生一条边,并根据边将多标记数据对象转化为多标记图数据对象,边采用邻接矩阵表示;

13、利用多层图卷积层,将各多标记数据对象按照边的关系聚合邻接节点特征,得到一原始特征空间;

14、利用一个多层感知机将所述原始特征空间映射到潜在特征空间;

15、获取通过所述标记嵌入模块得来的邻接特征信息,并根据所述邻接特征信息指导所述潜在特征空间,以生成标记特定特征空间;

16、将所述标记特定特征空间内的标记特定特征生成对应标记的置信度评分,以得到各多标记分类数据的各标记的置信度评分;

17、将训练集中归一化处理后的多标记分类数据的特征以及标记图数据作为神经网络模型的输入,各多标记分类数据的各标记的置信度评分作为输出,同时,设置优化器、训练轮数及神经网络模型参数,以训练神经网络模型,得到目标神经网络模型,所述目标神经网络模型的分类损失根据标记权重确定;

18、将待分类的数据输入目标神经网络模型中,根据输出的标记评分来划分相关标记和无关标记。

19、进一步的,所述提取训练集和测试集中剩余的多标记分类数据的标记,并进行编码,生成每个图数据中的节点特征,并确定两两节点特征的标记相关性,以构建由节点特征和标记相关性组成的标记图数据的步骤中,确定两两节点特征的标记相关性的表达式为:

20、;

21、其中,和表示为两个不同标记的标记向量,表示和的相关性。

22、进一步的,所述邻接矩阵表示为:

23、;

24、其中,kmeans表示为k-means聚类函数,xi表示为多标记数据对象i,xj表示为多标记数据对象j。

25、进一步的,所述利用多层图注意力层及标记图数据来构建标记嵌入模块的步骤中,每层图注意力层的表达式为:

26、;

27、其中,表示为sigmoid函数,ni表示为标记i的邻居标记,表示为标记i和标记j之间的注意力评分,w表示为对应权重矩阵,表示为第n层标记i对应节点的特征向量;

28、的表达式为:

29、;

30、其中,表示为待学习的注意力权重的向量,||表示为拼接操作,表示为标记i对应节点的特征向量,表示为标记j对应节点的特征向量,表示为标记g对应节点的特征向量。

31、进一步的,所述利用多层图注意力层及标记图数据来构建标记嵌入模块,并在所述标记嵌入模块的每一层均聚合所述节点特征的邻接特征信息的步骤中,在标记嵌入模块中引入多头注意力机制,其中,在标记嵌入模块的第一图注意力层中,采用拼接策略,第一图注意力层的表达式为:

32、;

33、其中,k表示为注意力头的个数,表示为标记i和标记j之间的注意力评分的k次方,表示为对应注意力的权重,在标记嵌入模块的第二图注意力层中,将所述拼接策略改为平均策略来更新最后一层的节点特征,第二图注意力层的表达式为:

34、。

35、进一步的,所述利用一个多层感知机将所述原始特征空间映射到潜在特征空间的过程表示为:

36、;

37、;

38、其中,z1表示为多层感知机,表示为多层感知机对应的线性层,和分别表示为可供学习的第一权重和第一偏置值,多层感知机的每一层采用leakyrelu来激活,x表示为特征向量。

39、进一步的,所述获取通过所述标记嵌入模块得来的邻接特征信息,并根据所述邻接特征信息指导所述潜在特征空间,以生成标记特定特征空间的步骤中,首先通过一层线性层来将标记嵌入转化为每个标记对每个潜在特征重要度的评分,并利用sigmoid函数来激活,该过程表示为:

40、;

41、其中,lend表示为标记嵌入模块的最后一层,和分别表示为可供学习的第二权重和第二偏置值,随后通过计算每个标记对每个潜在特征的重要度评分与特征嵌入的哈德玛积,完成了标记语义的指导过程,并利用一层全连接层通过leakyrelu来激活以获取最终的标记特定特征,该过程表示为:

42、;

43、其中,表示为哈德玛积,和分别表示为可供学习的第三权重和第三偏置值。

44、本发明实施例的第二方面提供了一种基于图注意力机制的多标记分类系统,用于实现第一方面所述的基于图注意力机制的多标记分类方法,所述系统包括:

45、训练模块,用于将训练集中归一化处理后的多标记分类数据的特征以及标记图数据作为神经网络模型的输入,多标记分类数据的标记评分作为输出,以训练神经网络模型,得到目标神经网络模型;

46、划分模块,用于将待分类的数据输入目标神经网络模型中,根据输出的标记评分来划分相关标记和无关标记。

47、本发明实施例的第三方面提供了一种计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行时实现第一方面提供的基于图注意力机制的多标记分类方法。

48、本发明实施例的第四方面提供了一种电子设备,包括存储器、处理器以及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现第一方面提供的基于图注意力机制的多标记分类方法。

49、本发明实施例当中提供的一种基于图注意力机制的多标记分类方法及系统,该方法利用了神经网络强大学习能力来捕获标记和特征之间的依赖关系,能够更为高效准确的学习出标记特定特征。同时,为了利用标记之间的相关性来指导多标记的分类过程,利用多层图注意力层来进行标记嵌入,在聚合领域特征的过程中,利用注意力评分来优化每个节点的连接,并利用多头注意力机制来获得丰富而稳定的节点表达。另外,为了保证标记节点特征在标记嵌入前后,其相关性依然能够保持一致性,利用一种一致性标记损失来解决这一问题,能够正确指导标记的分类过程,从而最终有效解决标记嵌入网络结构中的语义信息损失导致的网络正确率低的问题,另外,还考虑了标记之间可能存在的不均衡权重,以进一步提高准确性。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1