一种分类模型训练方法、装置、设备及介质与流程

文档序号:26264940发布日期:2021-08-13 19:16阅读:174来源:国知局
一种分类模型训练方法、装置、设备及介质与流程

本申请涉及分类器技术领域,特别涉及一种分类模型训练方法、装置、设备及介质。



背景技术:

随着云计算、物联网、移动通信和智能终端等信息技术的快速发展,以社交网络、社区和博客为代表的新型应用得到广泛使用。这些应用不断产生大量数据,方便用图来建模分析。其中,顶点表示个人或团体,连接边表示他们之间的联系;顶点上通常附有标签信息,用以表示所建模对象的年龄、性别、位置、兴趣爱好和宗教信仰,以及其他许多可能的特征。这些特征从各个方面反映了个人的行为偏好,理想情况下,每个社交网络用户都附有所有与自己特征相关的标签。但现实情况却并非如此。这是因为,用户出于保护个人隐私的目的,越来越多的社交网络用户在分享个人信息时,显得更加谨慎,导致社交网络媒体仅能搜集用户的部分信息。因此,如何根据已知用户的标签信息,推测剩余用户的标签,显得尤为重要和迫切。该问题即顶点分类问题。

目前,通过图神经网络解决顶点分类问题已成为研究热点。图神经网络通常由输入层、一个或多个隐藏层,以及输出层组成。例如,参见图1所示,图1为现有技术中的一种图神经网络结构图,图1展示了一种典型的图卷积神经网络的结构,它由一个输入层(inputlayer)、两个图卷积层(gconvlayer),和一个输出层(outputlayer)组成。其中,输入层读取n*d维的顶点特征矩阵,图卷积层对顶点特征矩阵进行特征提取,经由非线性激活函数如relu变换后传递给下一个图卷积层,最后,输出层即任务层,完成特定的任务如顶点分类、聚类等,图1中展示的是一个顶点分类任务层,输出每个顶点的类别标签。当前,如何提高分类准确度是需要解决的问题。



技术实现要素:

有鉴于此,本申请的目的在于提供一种分类模型训练方法、装置、设备及介质,能够提升分类模型的分类准确度。其具体方案如下:

第一方面,本申请公开了一种分类模型训练方法,包括:

基于图数据集构建顶点特征矩阵、邻接矩阵以及顶点标签矩阵;其中,所述顶点标签矩阵包括所述图数据集每个顶点的标签信息;

将所述顶点特征矩阵、所述邻接矩阵以及所述顶点标签矩阵输入至分类模型中的teacher图小波神经网络进行有监督训练,并在训练过程中确定出相应的有监督训练损失;

将所述顶点特征矩阵、所述邻接矩阵输入至分类模型中的student图小波神经网络进行无监督训练,并在训练过程中确定出相应的无监督训练损失;

基于所述有监督训练损失以及所述无监督训练损失确定目标训练损失;

当所述目标训练损失收敛,则输出当前的分类模型,得到训练后分类模型。

可选的,所述在训练过程中确定出相应的有监督训练损失,包括:

在训练过程中,基于teacher图小波神经网络的第一顶点标签预测结果与所述顶点标签矩阵确定出相应的有监督训练损失;

相应的,所述在训练过程中确定出相应的无监督训练损失,包括:

在训练过程中,基于student图小波神经网络的第二顶点标签预测结果与所述第一顶点标签预测结果确定出相应的无监督训练损失。

可选的,还包括:

在训练过程中,利用所述第一顶点标签预测结果更新所述顶点标签矩阵;

当所述目标训练损失收敛,则输出当前的顶点标签矩阵,得到每个无类别标签的顶点的类别预测结果。

可选的,所述方法还包括:

利用切比雪夫多项式计算所述图数据集的图小波变换基,以及图小波逆变换基;

相应的,teacher图小波神经网络以及student图小波神经网络在训练过程中基于所述图小波变换基和图小波逆变换基进行图卷积操作。

可选的,所述方法还包括:

获取所述图小波变换基的计算公式;

其中,所述计算公式为基于谱理论定义的公式。

可选的,teacher图小波神经网络以及student图小波神经网络均包括输入层,若干图卷积层,以及输出层;

其中,所述图卷积层用于在训练过程中对该层的输入数据依次进行特征变换以及图卷积操作处理。

可选的,所述方法还包括:

在训练过程中,基于注意力机制利用所述teacher图小波神经网络训练得到的图卷积层的卷积核确定所述student图小波神经网络中对应的图卷积层的卷积核。

第二方面,本申请公开了一种分类模型训练装置,包括:

训练数据构建模块,用于基于图数据集构建顶点特征矩阵、邻接矩阵以及顶点标签矩阵;其中,所述顶点标签矩阵包括所述图数据集每个顶点的标签信息;

分类模型训练模块,用于将所述顶点特征矩阵、所述邻接矩阵以及所述顶点标签矩阵输入至分类模型中的teacher图小波神经网络进行有监督训练,并在训练过程中确定出相应的有监督训练损失;将所述顶点特征矩阵、所述邻接矩阵输入至分类模型中的student图小波神经网络进行无监督训练,并在训练过程中确定出相应的无监督训练损失;基于所述有监督训练损失以及所述无监督训练损失确定目标训练损失;当所述目标训练损失收敛,则输出当前的分类模型,得到训练后分类模型。

第三方面,本申请公开了一种电子设备,包括:

存储器,用于保存计算机程序;

处理器,用于执行所述计算机程序,以实现前述的分类模型训练方法。

第四方面,本申请公开了一种计算机可读存储介质,用于保存计算机程序,所述计算机程序被处理器执行时实现前述的分类模型训练方法。

可见,本申请先基于图数据集构建顶点特征矩阵、邻接矩阵以及顶点标签矩阵;其中,所述顶点标签矩阵包括所述图数据集每个顶点的标签信息,之后将所述顶点特征矩阵、所述邻接矩阵以及所述顶点标签矩阵输入至分类模型中的teacher图小波神经网络进行有监督训练,并在训练过程中确定出相应的有监督训练损失;将所述顶点特征矩阵、所述邻接矩阵输入至分类模型中的student图小波神经网络进行无监督训练,并在训练过程中确定出相应的无监督训练损失;基于所述有监督训练损失以及所述无监督训练损失确定目标训练损失;当所述目标训练损失收敛,则输出当前的分类模型,得到训练后分类模型。这样,将图数据集的顶点特征矩阵、邻接矩阵输入图神经网络进行训练,利用了图拓扑结构和顶点特征,在训练的时候,利用了有监督训练和无监督训练,充分发挥有监督训练和无监督训练各自的优势,能够提升分类模型的分类准确度。

附图说明

为了更清楚地说明本申请实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据提供的附图获得其他的附图。

图1为现有技术中的一种图神经网络结构图;

图2为本申请公开的一种分类模型训练方法流程图;

图3为本申请公开的一种具体的分类模型训练方法流程图;

图4为本申请公开的一种分类模型结构图;

图5为本申请公开的一种具体的分类模型结构图;

图6为本申请公开的一种具体的分类模型训练方法流程图;

图7为本申请公开的一种分类模型训练装置结构示意图;

图8为本申请公开的一种电子设备结构图。

具体实施方式

下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。

参见图2所示,本申请实施例公开了一种分类模型训练方法,包括:

步骤s11:基于图数据集构建顶点特征矩阵、邻接矩阵以及顶点标签矩阵;其中,所述顶点标签矩阵包括所述图数据集每个顶点的标签信息;

其中,所述标签信息表示对应的类别标签或无类别标签。

在具体的实施方式中,假设有图数据集为v表示顶点集合,v分为少量具有类别标签的顶点集合和大部分无类别标签的顶点集合两部分,并满足e表示连接边集合。除标签外,g的每个顶点v均拥有d个特征,所有顶点的特征构成了维的顶点特征矩阵,记为xg的邻接矩阵记为a,元素表示顶点ij之间的连接边的权重。根据已有标签的顶点集合,构建维的顶点标签矩阵y,其中,表示图中所有顶点个数,c表示所有顶点的标签类别数,矩阵元素表示顶点i的类别标签是否为,当顶点i已有类别标签时,置对应的第j列元素为1,其余列元素为0。即有:

当顶点i为无类别标签时,将该行对应的每一列元素都置为0。

步骤s12:将所述顶点特征矩阵、所述邻接矩阵以及所述顶点标签矩阵输入至分类模型中的teacher图小波神经网络进行有监督训练,并在训练过程中确定出相应的有监督训练损失。

步骤s13:将所述顶点特征矩阵、所述邻接矩阵输入至分类模型中的student图小波神经网络进行无监督训练,并在训练过程中确定出相应的无监督训练损失。

在具体的实施方式中,在训练过程中,基于teacher图小波神经网络的第一顶点标签预测结果与所述顶点标签矩阵确定出相应的有监督训练损失;基于student图小波神经网络的第二顶点标签预测结果与所述第一顶点标签预测结果确定出相应的无监督训练损失。

具体的,第一顶点标签预测结果和顶点标签矩阵进行比较,计算有监督训练损失,第二顶点标签预测结果与所述第一顶点标签预测结果比较,计算无监督学习损失。

步骤s14:基于所述有监督训练损失以及所述无监督训练损失确定目标训练损失。

在具体的实施方式中,目标训练损失的计算公式如下:

;

其中,表示有监督训练损失,表示无监督训练损失,为一个常数,用于调节无监督训练损失在目标损失中所占的比例。表示第一顶点标签预测结果,表示第二顶点标签预测结果。

其中,均为维的矩阵,并且,中每个列向量表示所有顶点属于类别j的概率,即它的第个元素表示顶点i属于类别的概率。

需要指出的是,本申请实施例可以将teacher图小波神经网络以及student图小波神经网络的输出层定义为

其中,

为图小波变换基,为图小波逆变换基,表示第l层图卷积层的卷积核矩阵,表示第l层顶点特征变换结果,teacher图小波神经网络以及student图小波神经网络均包括l层图卷积层。

并且,有监督训练损失函数基于交叉熵原理,计算顶点实际标签概率分布和预测标签概率分布的差异程度;无监督训练损失函数计算相同坐标元素之间差值的平方和。

这样,当整个网络训练结束时,两个网络的输出结果一致或差别可忽略不计。可以teacher图小波神经网络的输出为整个网络模型的输出。

本实施例在训练过程中,利用所述第一顶点标签预测结果更新所述顶点标签矩阵,具体的,对于无类别标签的顶点,即对于,将第一顶点标签预测结果中概率最大的类别作为该顶点的最新类别,更新顶点标签矩阵。

步骤s15:当所述目标训练损失收敛,则输出当前的分类模型,得到训练后分类模型。

并且,当所述目标训练损失收敛,则输出当前的顶点标签矩阵,得到每个无类别标签的顶点的类别预测结果。

在具体的实施方式中,当目标训练损失达到预设阈值或者迭代次数达到指定迭代最大值,则目标训练损失收敛,训练结束。其中,预设阈值通常为一个较小的值,此时,对于无类别标签的顶点,根据当前的顶点标签矩阵,得到其应归属的类别。

也即,本申请将无标签顶点的预测融合进训练过程:在训练过程中,根据每次的训练结果更新顶点标签矩阵,训练结束后即可获得任意一个无标签顶点的类别标签。

其中,在具体的实施方式中,可以先根据按照特定策略如正态分布随机初始化、xavier初始化或he初始化,对图小波神经网络各层网络参数进行初始化。在训练的过程中,可以根据特定策略如sgd(即stochasticgradientdescent,随机梯度下降)、mgd(即momentumgradientdescent,动量梯度下降)、nesterovmomentum(牛顿动量)、adagrad(即adaptivegradientalgorithm,自适应梯度算法)、rmsprop(即rootmeansquareprop,前向均方根梯度下降算法)和adam(即adaptivemomentestimation,自适应矩估计)或bgd(即batchgradientdescent,批量梯度下降)等,对图小波神经网络各层网络参数进行修正和更新,以优化损失函数值。

可见,本申请实施例先基于图数据集构建顶点特征矩阵、邻接矩阵以及顶点标签矩阵;其中,所述顶点标签矩阵包括所述图数据集每个顶点的标签信息,之后将所述顶点特征矩阵、所述邻接矩阵以及所述顶点标签矩阵输入至分类模型中的teacher图小波神经网络进行有监督训练,并在训练过程中确定出相应的有监督训练损失;将所述顶点特征矩阵、所述邻接矩阵输入至分类模型中的student图小波神经网络进行无监督训练,并在训练过程中确定出相应的无监督训练损失;基于所述有监督训练损失以及所述无监督训练损失确定目标训练损失;当所述目标训练损失收敛,则输出当前的分类模型,得到训练后分类模型。这样,将图数据集的顶点特征矩阵、邻接矩阵输入图神经网络进行训练,利用了图拓扑结构和顶点特征,在训练的时候,利用了有监督训练和无监督训练,充分发挥有监督训练和无监督训练各自的优势,能够提升分类模型的分类准确度。

参见图3所示,本申请实施例公开了一种具体的分类模型训练方法,包括:

步骤s21:获取图小波变换基的计算公式。

其中,所述计算公式为基于谱理论定义的公式。

需要指出的是,通过傅里叶变换定义的图卷积操作在顶点域局部性差,利用谱理论定义图小波变换的基底,保证了图卷积计算的局部性。

步骤s22:利用切比雪夫多项式计算所述图数据集的图小波变换基,以及图小波逆变换基。

在具体的实施方式中,图小波变换基的计算公式为,其中,表示从图数据集g中提取的图小波变换基,u表示由对图数据集g的拉普拉斯矩阵进行特征分解得到的特征向量所组成的矩阵;d是一个对角阵,其主对角线上的n个元素分别表示n个顶点的度数,其余元素均为零。是缩放尺度为r的缩放矩阵,并设是对图g的拉普拉斯矩阵进行特征分解得到的特征值;图小波逆变换基可以通过将中的替换为求得。由于矩阵的特征分解计算开销较大,为避免此开销,利用切比雪夫多项式,且,,来近似计算图小波变换基,以及图小波逆变换基。

相应的,teacher图小波神经网络以及student图小波神经网络在训练过程中基于所述图小波变换基和图小波逆变换基进行图卷积操作。

需要指出的是,现有技术中在图卷积操作的过程中图傅里叶变换是低效的,因为拉普拉斯矩阵的特征向量矩阵是稠密的,而本实施例基于所述图小波变换基和图小波逆变换基进行图卷积操作,图小波变换基和图小波逆变换基是稀疏的,所以能够提升图卷积操作的运算效率。

步骤s23:基于图数据集构建顶点特征矩阵、邻接矩阵以及顶点标签矩阵;其中,所述顶点标签矩阵包括所述图数据集每个顶点的标签信息,所述标签信息表示对应的类别标签或无类别标签。

步骤s24:将所述顶点特征矩阵、所述邻接矩阵以及所述顶点标签矩阵输入至分类模型中的teacher图小波神经网络进行有监督训练,并在训练过程中确定出相应的有监督训练损失。

步骤s25:将所述顶点特征矩阵、所述邻接矩阵输入至分类模型中的student图小波神经网络进行无监督训练,并在训练过程中确定出相应的无监督训练损失。

在具体的实施方式中,teacher图小波神经网络以及student图小波神经网络均包括输入层,若干图卷积层,以及输出层;其中,所述图卷积层用于在训练过程中对该层的输入数据依次进行特征变换以及图卷积操作处理。具体的,可以包括1个输入层,个图卷积层,以及输出层。

也即,本申请实施例中,图卷积层对该层的输入数据先进行特征变换,然后图卷积操作处理,这样将图卷积层分为特征变换和图卷积操作先后两个处理阶段,能够减少网络参数,从而减低模型运算量,提升模型训练效率。

其中,在层图卷积层中:

特征变换:

图卷积:

其中,分别为第l层图隐藏层的输入和输出数据,且;为第l层待训练的特征变换矩阵,为第l层特征变换结果,t表示矩阵的转置操作。

需要指出的是,现有技术中图卷积层定义通常未区分特征变换和卷积操作,结合本申请实施例中的图小波变换基,如果不将图卷积层分为特征变换和图卷积操作先后两个处理阶段。以如下公式定义图卷积层:

其中,x表示顶点特征矩阵,m表示图卷积层的序数,f是图卷积核矩阵,h是激活函数。在采用上述方式定义的图卷积层中,包含的参数个数是,其中n表示图中顶点的个数,p表示该层输入的顶点特征维度,q表示该层输出的顶点特征维度。而本申请实施例将特征变换从图卷积操作剥离出来,每一个图卷积层的参数个数就变成了

另外,在具体的实施方式中,在训练过程中,基于注意力机制利用所述teacher图小波神经网络训练得到的图卷积层的卷积核确定所述student图小波神经网络中对应的图卷积层的卷积核。

具体的,分类模型可以包括teacher图小波神经网络、student图小波神经网络,以及连接teacher图小波神经网络与student图小波神经网络每一对图卷积层的注意力网络。

需要说明的是,设是第l层的图卷积核矩阵,为一个对角阵。从信号处理角度看,对角线上元素可视为图的频率,表示该频率对应的特征向量的重要性。记teacher图小波神经网络和student图小波神经网络第l层的卷积核矩阵分别为tlsl,分别由teacher图小波神经网络该层的卷积核和student图小波神经网络该层的卷积核对角化得到,两者均是n维的列向量。

本实施例中,可以基于注意力机制进行注意力转移(attentiontransfer):teacher图小波神经网络将每一层将学习到的卷积核转移给student图小波神经网络的相应层,也即,student图小波神经网络向teacher图小波神经网络学习,促使提高整个网络的性能。具体地,可设计一个单层的前馈神经网络,其输入层负责读取teacher图小波神经网络和student图小波神经网络第l层卷积核;其隐藏层用于实现注意力函数,以便得到两个向量之间的注意力权重

进一步的,通过softmax函数对注意力权重进行归一化得到归一化的注意力权重为

其中,表示的第i个分量,表示的第i个分量。则有:

其中,表示student图小波神经网络向teacher图小波神经网络,学习到的第l层卷积核。

需要指出的是,注意力机制的加入,促进student图小波神经网络快速利用teacher图小波神经网络掌握的知识,提高训练速度。

步骤s26:基于所述有监督训练损失以及所述无监督训练损失确定目标训练损失。

步骤s27:当所述目标训练损失收敛,则输出当前的分类模型,得到训练后分类模型。

例如,参见图4所示,图4为本申请实施例公开的一种分类模型结构图。teacher图小波神经网络gwnt,student图小波神经网络gwns,进一步的,参见图5所示,图5为本申请实施例公开的一种具体的分类模型结构图。分类模型由teacher图小波神经网络gwnt和,student图小波神经网络gwns,以及连接两个网络每一对图卷积层的注意力网络组成。gwnt根据有标签的图顶点进行有监督学习,预测准确度较高;gwns在gwnt的指导下(利用其预测结果)利用无标签的图顶点进行无监督学习,以期提高预测准确度,获得更好的顶点分类模型。注意力网络用于gwnt将每一层学习到的“知识”即卷积核转移给gwns对应层,也即gwns向gwnt学习。gwnt和gwns均包含1个输入层、l个图卷积层以及1个输出层。输入层主要用于读取待分类图数据,包括表示图拓扑结构的邻接矩阵a和顶点特征矩阵x。图卷积层中将图卷积操作分解为特征变换和图卷积先后两个阶段。输出层用于输出预测结果。

并且,整个分类模型中,每一层的网络参数均包含特征变换矩阵(包括teacher图小波神经网络的和student图小波神经网络的),卷积核(卷积核和卷积核),进而利用卷积核更新卷积核矩阵,以及注意力网络参数。在初始化阶段,初始前述网络参数,在训练过程中,更新前述网络参数。

例如,参见图6所示,本申请实施例公开了一种具体的分类模型训练方法流程图,对于一个给定的图数据集g,以其邻接矩阵a、顶点特征矩阵x以及顶点标签矩阵y作为输入,送入网络进行前向传播,计算所有顶点属于每一类别的预测结果,更新预测结果矩阵的同时,计算有监督学习部分的损失和无监督学习部分的损失,从而得到总的网络损失函数值,按照一定策略更新各层网络参数,直至网络误差达到一个指定的较小值或迭代次数达到指定的最大值时,训练结束。

例如,基于本申请实施例的方法利用科技论文集训练分类模型并预测无标签的科技论文的类别标签。

(1)下载引文网络数据集citeseer,包含共分为六个类别的3312篇科技论文以及4732条论文间的引用关系;利用bag-of-words(词袋模型)为每篇论文构建其特征向量x,所有文档的特征向量组成特征矩阵x。根据论文间的引用关系,构建其邻接矩阵a。目标是将每个文档归类,每个类别随机抽取20个实例作为标记数据,将1000个实例作为测试数据,其余用作未标记的数据;构建顶点标签矩阵y

(2)定义网络结构:基于前述公开内容定义图卷积层、输出层以及网络损失函数。

(3)利用切比雪夫多项式近似计算图小波变换基底和图小波逆变换的基底。

(4)按照正则化初始化方法,对网络参数进行初始化。

(5)以axy作为网络输入,送入网络进行前向传播。其中,teacher图小波神经网络gwnt以axy作为输入,student图小波神经网络gwns以ax作为输入。每个网络根据图卷积层的定义,结合该层的输入特征矩阵,计算每一层的输出特征矩阵;按照输出层的定义,计算所有顶点属于每一类别的预测结果ztzs,并根据前述定义的网络损失函数计算有监督学习损失函数值、无监督学习函数损失值,进而得到整个网络的损失函数值;对于无标签顶点,取概率最大的那一类别作为该顶点的最新类别,并更新顶点标签矩阵y

(6)按照优化方法,计算损失函数关于网络参数的梯度,并后向传播,以便对网络参数进行优化,直至网络预测误差达到一个指定的较小值或迭代次数达到指定迭最大值时,训练结束。此时,对于无类别标签的顶点,可根据顶点标签矩阵y得到其应归属得类别。

当然,本申请不局限应用于实施例中列举的科学引文分类问题,还可应用于任意方便用图来建模表示的数据的分类问题,如蛋白质、图形图像等,以及用于研究传染性疾病和思想观点等在社交网络中随着时间传播扩散的规律、研究社交网络中的群体如何围绕特定利益或隶属关系形成社团,以及社团连接的强度;社交网络根据“人以群分”的规律,发现具有相似兴趣的人,向他们建议或推荐新的链接或联系;问答系统将问题引导给最有相关经验的人;广告系统向最有兴趣并愿意接受特定主题广告的个人显示广告等。

参见图7所示,本申请实施例公开了一种分类模型训练装置,包括:

训练数据构建模块11,用于基于图数据集构建顶点特征矩阵、邻接矩阵以及顶点标签矩阵;其中,所述顶点标签矩阵包括所述图数据集每个顶点的标签信息;

分类模型训练模块12,用于将所述顶点特征矩阵、所述邻接矩阵以及所述顶点标签矩阵输入至分类模型中的teacher图小波神经网络进行有监督训练,并在训练过程中确定出相应的有监督训练损失;将所述顶点特征矩阵、所述邻接矩阵输入至分类模型中的student图小波神经网络进行无监督训练,并在训练过程中确定出相应的无监督训练损失;基于所述有监督训练损失以及所述无监督训练损失确定目标训练损失;当所述目标训练损失收敛,则输出当前的分类模型,得到训练后分类模型。

可见,本申请实施例先基于图数据集构建顶点特征矩阵、邻接矩阵以及顶点标签矩阵;其中,所述顶点标签矩阵包括所述图数据集每个顶点的标签信息,之后将所述顶点特征矩阵、所述邻接矩阵以及所述顶点标签矩阵输入至分类模型中的teacher图小波神经网络进行有监督训练,并在训练过程中确定出相应的有监督训练损失;将所述顶点特征矩阵、所述邻接矩阵输入至分类模型中的student图小波神经网络进行无监督训练,并在训练过程中确定出相应的无监督训练损失;基于所述有监督训练损失以及所述无监督训练损失确定目标训练损失;当所述目标训练损失收敛,则输出当前的分类模型,得到训练后分类模型。这样,将图数据集的顶点特征矩阵、邻接矩阵输入图神经网络进行训练,利用了图拓扑结构和顶点特征,在训练的时候,利用了有监督训练和无监督训练,充分发挥有监督训练和无监督训练各自的优势,能够提升分类模型的分类准确度。

其中,分类模型训练模块12,具体用于在训练过程中,基于teacher图小波神经网络的第一顶点标签预测结果与所述顶点标签矩阵确定出相应的有监督训练损失;基于student图小波神经网络的第二顶点标签预测结果与所述第一顶点标签预测结果确定出相应的无监督训练损失。

分类模型训练模块12,还用于:在训练过程中,利用所述第一顶点标签预测结果更新所述顶点标签矩阵;当所述目标训练损失收敛,则输出当前的顶点标签矩阵,得到每个无类别标签的顶点的类别预测结果。

所述装置还包括图小波变换基计算模块,用于利用切比雪夫多项式计算所述图数据集的图小波变换基,以及图小波逆变换基;相应的,teacher图小波神经网络以及student图小波神经网络在训练过程中基于所述图小波变换基和图小波逆变换基进行图卷积操作。

所述装置还包括,图小波变换基公式获取模块,用于获取所述图小波变换基的计算公式;其中,所述计算公式为基于谱理论定义的公式。

在具体的实施方式中,teacher图小波神经网络以及student图小波神经网络均包括输入层,若干图卷积层,以及输出层;

其中,所述图卷积层用于在训练过程中对该层的输入数据依次进行特征变换以及图卷积操作处理。

分类模型训练模块12,还用于在训练过程中,基于注意力机制利用所述teacher图小波神经网络训练得到的图卷积层的卷积核确定所述student图小波神经网络中对应的图卷积层的卷积核。

参见图8所示,本申请实施例公开了一种电子设备20,包括处理器21和存储器22;其中,所述存储器22,用于保存计算机程序;所述处理器21,用于执行所述计算机程序,前述实施例公开的分类模型训练方法。

关于上述分类模型训练方法的具体过程可以参考前述实施例中公开的相应内容,在此不再进行赘述。

并且,所述存储器22作为资源存储的载体,可以是只读存储器、随机存储器、磁盘或者光盘等,存储方式可以是短暂存储或者永久存储。

另外,所述电子设备20还包括电源23、通信接口24、输入输出接口25和通信总线26;其中,所述电源23用于为所述服务器20上的各硬件设备提供工作电压;所述通信接口24能够为所述电子设备20创建与外界设备之间的数据传输通道,其所遵循的通信协议是能够适用于本申请技术方案的任意通信协议,在此不对其进行具体限定;所述输入输出接口25,用于获取外界输入数据或向外界输出数据,其具体的接口类型可以根据具体应用需要进行选取,在此不进行具体限定。

进一步的,本申请实施例还公开了一种计算机可读存储介质,用于保存计算机程序,其中,所述计算机程序被处理器执行时实现前述实施例公开的分类模型训练方法。

关于上述分类模型训练方法的具体过程可以参考前述实施例中公开的相应内容,在此不再进行赘述。

本说明书中各个实施例采用递进的方式描述,每个实施例重点说明的都是与其它实施例的不同之处,各个实施例之间相同或相似部分互相参见即可。对于实施例公开的装置而言,由于其与实施例公开的方法相对应,所以描述的比较简单,相关之处参见方法部分说明即可。

结合本文中所公开的实施例描述的方法或算法的步骤可以直接用硬件、处理器执行的软件模块,或者二者的结合来实施。软件模块可以置于随机存储器(ram)、内存、只读存储器(rom)、电可编程rom、电可擦除可编程rom、寄存器、硬盘、可移动磁盘、cd-rom、或技术领域内所公知的任意其它形式的存储介质中。

以上对本申请所提供的一种分类模型训练方法、装置、设备及介质进行了详细介绍,本文中应用了具体个例对本申请的原理及实施方式进行了阐述,以上实施例的说明只是用于帮助理解本申请的方法及其核心思想;同时,对于本领域的一般技术人员,依据本申请的思想,在具体实施方式及应用范围上均会有改变之处,综上所述,本说明书内容不应理解为对本申请的限制。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1