一种基于图自注意力机制的异质图网络节点分类方法

文档序号:35628376发布日期:2023-10-06 00:44阅读:30来源:国知局
一种基于图自注意力机制的异质图网络节点分类方法

本发明涉及交通流预测,特别指一种基于图自注意力机制的异质图网络节点分类方法。


背景技术:

1、图神经网络(graph neural networks,简称gnns)是一类用于图数据挖掘的深度学习方法,被广泛应用于众多领域并且取得了很好的成果。在异质图网络上进行节点分类是gnns的一项重要任务,异质图网络指由不同类型的节点和边(关系)组成的图网络,存在于许多现实世界的场景中,如社交网络中的用户和用户之间的多种关系,化合物分子中不同类型的原子和化学键等。异质图网络节点分类的目标是将所有节点分类到对应的类别中,从而更好地理解和学习异质图网络的结构和特征。

2、异质图网络的节点分类可应用在不同领域,如金融风险评估、推荐系统、医疗诊断等。在金融风险评估领域,可以使用异质图网络表示用户、资产和交易等信息,并通过节点分类来评价客户的信用等级和风险水平;在推荐系统领域,可以使用异质图网络表示用户、商品和用户商品交互信息,并通过节点分类来得到用户的兴趣和购买行为;在医疗诊断领域,可以使用异质图网络表示疾病、症状、药物等信息,并通过节点分类来预测疾病的类型以及严重程度。异质图网络的节点分类具有现实意义,可以更好地帮助我们理解及分析复杂的图结构数据,从而在多个领域实现更准确的预测。

3、由于异质图网络中的节点和边具有不同的类型,因此在进行节点分类时,不仅需要考虑节点的特征,同时需要考虑节点间复杂的异构信息。例如,在社交网络中,用户节点可能具有不同的类型以及节点间存在不同的联系等异构信息,而节点特征可能具有如年龄、性别、职业、爱好等属性,这些属性和异构信息可以作为图的特征输入到gnns中,以帮助提高分类的准确性。

4、异质图网络节点分类的一大难点是如何使用其丰富的异构信息提高分类的准确率,传统方法通常是使用异质图网络上的元路径来定义不同类型节点之间的关系,并利用元路径推导出节点之间的相似性,然后使用gnns对节点进行编码和分类,但传统方法无法捕捉到异质图网络中节点的高阶语义信息,无法学习到元路径以外的一些节点特征表示信息,导致使用传统方法进行交通流预测时,预测(节点分类)的准确率不尽如人意。

5、因此,如何提供一种基于图自注意力机制的异质图网络节点分类方法,实现提升交通流预测的准确率,成为一个亟待解决的技术问题。


技术实现思路

1、本发明要解决的技术问题,在于提供一种基于图自注意力机制的异质图网络节点分类方法,实现提升交通流预测的准确率。

2、本发明是这样实现的:一种基于图自注意力机制的异质图网络节点分类方法,包括如下步骤:

3、步骤s1、获取大量的交通异质图网络的数据集,按预设比例将所述数据集划分为训练集、验证集和测试集,从所述训练集、验证集和测试集中分别提取交通异质图网络的节点特征矩阵和邻接矩阵集合;

4、步骤s2、基于全局自注意力模块、图自注意力模块以及输出模块创建一异质图自注意力网络模型;

5、步骤s3、利用所述训练集对异质图自注意力网络模型进行训练,利用所述验证集对训练后的异质图自注意力网络模型进行验证;

6、步骤s4、利用所述测试集对验证后的异质图自注意力网络模型进行测试,并不断优化所述异质图自注意力网络模型的超参数;

7、步骤s5、利用测试后的所述异质图自注意力网络模型进行交通异质图网络的节点分类,进而进行交通流预测。

8、进一步的,所述步骤s1中,所述预设比例为2:1:7。

9、进一步的,所述步骤s1中,所述节点特征矩阵为:

10、x∈rn×d;

11、所述邻接矩阵集合为不同类型边的邻接矩阵集合,公式为:

12、;

13、其中,x表示节点特征;r表示实数;n表示节点数量;d表示节点特征的输入维度;a表示邻接矩阵;k表示异质图的边的类型数;k表示邻接矩阵编号。

14、进一步的,所述步骤s2中,所述全局自注意力模块用于学习交通异质图网络中各节点在全局的节点特征依赖和节点特征表示;

15、所述全局自注意力模块的学习过程为:

16、s211、将所述节点特征矩阵x分别通过三个可学习的矩阵wq、wk、wv投影为q、k、v:

17、q=xwq,k=xwk,v=xwv;

18、其中,wq∈rd×dk;wk∈rd×dk;wv∈rd×dv;dk=dv=d;

19、s212、对所述q、k、v应用归一化的点乘注意力机制计算自注意力矩阵sattn:

20、;

21、其中,softmax()表示归一化指数函数;t表示矩阵转置操作;

22、s213、并行执行多次归一化的点乘注意力机制,把计算得到的各所述自注意力矩阵sattn相加取均值,得到节点嵌入xmhead:

23、;

24、其中,xmhead∈rn×d,表示经过多头注意力机制学习得到的节点嵌入;head表示多头注意力机制的头数;w0∈rd×dv;

25、s214、对所述节点嵌入xmhead与q做残差连接后进行归一化,得到节点嵌入xn1:

26、xn1=norm(q+xmhead(q,k,v));

27、其中,xn1∈rn×d,表示经过第一次归一化后得到的节点嵌入;norm()表示归一化函数;

28、s215、将所述节点嵌入xn1输送到由两层线性连接层组成的前馈网络,并在两个所述线性连接层之间使用激活函数relu来增加全局自注意力模块的非线性,得到节点嵌入xffn:

29、xffn=linear(relu(linear(xn1)));

30、其中,xffn∈rn×d,表示经过前馈网络后得到的节点嵌入;linear()表示线性连接层;

31、s216、对所述节点嵌入xffn与xn1做残差连接后进行归一化,得到节点嵌入xn2:

32、xn2=norm(xn1+xffn);

33、s217、对所述节点特征矩阵x和节点嵌入xn2进行拼接,得到节点特征表示xg:

34、xg=x‖xn2;

35、其中,xg∈rn×2d;‖表示拼接操作。

36、进一步的,所述步骤s2中,所述图自注意力模块用于学习交通异质图网络中不同类型边和节点特征的表示;

37、所述图自注意力模块的学习过程为:

38、s221、把不同类型边所构成的邻接矩阵a聚合在一起,得到新的邻接矩阵ac:

39、ac=conv(a;wc)=awc;

40、其中,ac∈rn×n;conv()表示卷积函数;wc∈rk×1×1,表示可学习的参数矩阵;

41、s222、在所述邻接矩阵ac、节点特征表示xg的基础上,利用图卷积层学习交通异质图网络的节点以及其一阶邻居的特征信息,得到节点嵌入xc:

42、xc=relu(graphconv(xg;ac))=relu(acxgw);

43、其中,xc∈rn×dout,表示经过图卷积层学习得到的节点嵌入;dout表示输出的嵌入维度;graphconv()表示图卷积操作;w∈r2d×dout,表示图卷积的权重矩阵;

44、s223、给定节点嵌入xc=[x1,x2…xn]t∈rn×dout,xn∈rdout,表示节点n的特征表示;对于存在连接边的节点i和节点j,使用可学习参数wq、wk、bq、bk,将节点i的特征xi和节点j的特征xj分别转化为qi和kj:

45、qi=wqxi+bq;

46、kj=wkxj+bk;

47、其中,qi∈rdout,kj∈rdout,均为向量;

48、s224、将所述邻接矩阵ac通过可学习参数we、be转换为边缘特征eij,将所述边缘特征eij加入向量kj,得到向量kj’:

49、eij=weaij+be;

50、kj’=kj+eij;

51、其中,aij为邻接矩阵ac中的元素值,表示节点i和节点j之间存在相连的边;

52、s225、计算从节点j到节点i的每一条边的归一化点乘注意力αij:

53、;

54、;

55、其中,exp()表示以自然常数e为底的指数函数;n(i)表示节点i基于邻接矩阵ac的一阶邻居节点;

56、s226、通过可学习参数wv、bv将节点j的特征xj转换为vj:

57、vj=wvxj+bv;

58、其中,vj∈rdout;

59、s227、基于所述vj、αij、eij计算多头注意力,得到节点嵌入zi:

60、;

61、s228、对所述节点嵌入zi引入门控单元gate以及残差连接,得到节点嵌入:

62、ri=wrxi+br;

63、di=zi‖ri‖(zi-ri);

64、;

65、;

66、其中,wr、br、wg均为可学习参数,且wg∈r3dout;i表示节点编号;t表示转置操作;‖表示拼接操作;d表示拼接操作后得到的矢量;

67、s229、对所述节点嵌入进行归一化,得到节点嵌入zi:

68、;

69、其中,zi∈rdout;

70、s230、重复两次s221-s229的学习过程,在经过所述图自注意力模块的学习后,获得所有节点最终的节点嵌入z,z∈rn×dout。

71、进一步的,所述步骤s2中,所述输出模块用于预测节点类别;

72、所述输出模块的计算过程为:

73、将所述节点嵌入z输入两个全连接层和softmax函数得到预测的节点类别p:

74、p=softmax(linear(linear(z)));

75、其中,p∈r1×n,n表示节点类别数。

76、进一步的,所述步骤s4中,所述超参数至少包括随机失活率、权值衰减率以及学习率。

77、本发明的优点在于:

78、通过获取大量的交通异质图网络的数据集并划分为训练集、验证集和测试集,从训练集、验证集和测试集中分别提取异质图网络的节点特征矩阵和邻接矩阵集合;基于全局自注意力模块、图自注意力模块以及输出模块创建一异质图自注意力网络模型,利用训练集对异质图自注意力网络模型进行训练,利用验证集对训练后的异质图自注意力网络模型进行验证,利用测试集对验证后的异质图自注意力网络模型进行测试,并不断优化异质图自注意力网络模型的超参数,最后利用测试后的异质图自注意力网络模型进行交通流预测;由于全局自注意力模块用于学习交通异质图网络中各节点在全局的节点特征依赖和节点特征表示,图自注意力模块用于学习交通异质图网络中不同类型边和节点特征的表示,在整个学习过程中不需要使用元路径,并能够更好学习交通异质图网络丰富的特征信息和高阶语义信息,具有更强大的异质图网络的节点特征学习能力,进而极大的提升了交通流预测的准确率。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1