一种图节点分类模型训练方法、图节点分类方法
未命名
10-08
阅读:89
评论:0

1.本发明属于图机器学习、图表示学习技术领域,具体涉及一种基于图注意力和改进transformer图节点分类模型训练方法、图节点分类方法。
背景技术:
2.近年来,以卷积神经网络为代表的深度学习方法在机器学习领域取得了显著成效,特别是在计算机视觉、自然语言处理等方向均取得了突破性进展。但是,上述成果主要面向欧几里得空间数据例如视频、图像等,无法处理现实世界中更多存在的非欧几里得空间数据例如社交网络、分子结构、交通流网络等。此时,若直接将cnn等传统深度学习模型迁移到图数据任务中,其表现并不尽如人意。鉴于此,研究人员提出图神经网络作为图深度学习解决方案,并将其广泛应用于图节点分类任务中。
3.目前,在各类图神经网络中,如何构建节点邻居间高效的消息传递机制,一直是提高模型性能的主要关注点。例如,kipf等人将卷积操作应用于图数据提出了图卷积网络gcn,为神经网络迁移至图数据提供了标准范式。随后velickovic等人提出图注意力网络gat,通过学习节点邻居的不同聚合权重,自适应传递消息及更新中心节点特征。除此之外,现有技术中还公开了sgc模型,通过连续移除激活层及预计算邻接矩阵次幂来降低模型复杂度。然而,上述模型无法学习节点间的长距离依赖关系。而且,由于模型不断重复局部聚合操作,还易出现过平滑和过挤压问题。
4.在此背景下,transformer因其能捕获长距离依赖,减轻过平滑、过挤压现象正受到越来越多的关注。很多学者尝试将其与图表示学习进行深度融合,对图transformer进行了广泛研究。例如现有技术中公开了将gnn作为辅助模块引入transformer提出了graphtrans模型。然而,hussain等人则采用邻接矩阵的svd向量将其压缩成稠密位置嵌入,以此提高transformer对图拓扑结构的捕捉能力。此外,dwivedi等人还提出了graph transformer (gt),通过引入注意力屏蔽机制令中心节点仅关注一阶邻居,从而将全局自注意力机制改造为类似于gnn的结构。但是,上述方法仅在transformer中附加一些辅助模块实现图数据建模,从本质上说这并未改变其原有体系结构,数据建模精度还有待进一步提高。
技术实现要素:
5.针对上述技术问题,本发明提出一种基于图注意力和改进transformer的节点分类方法,能够同时兼顾小规模和中规模数据集的节点分类任务,实现分类性能的全面提升。
6.在文献引用、分子发现等不同规模的图节点分类任务场景中,分类网络需利用节点自身及邻居信息对该节点所属类别进行预测。本发明提出了一种图节点分类方法、一种基于图注意力和改进transformer的节点分类方法,以解决传统transformer模型对图结构数据建模精度不高的问题,并令其可同时兼顾小规模和中规模数据集的节点分类任务,实现分类性能的全面提升。
7.技术方案:
8.本发明的第一方面,提供了一种图节点分类模型训练方法,即一种基于图注意力和改进transformer的节点分类模型训练方法,如图1所示,包括以下步骤:
9.步骤s1,获取图结构源数据,所述图结构源数据的获取方式包括但不仅限于利用各类感知设备如传感器、或来自公开的图结构源数据库;所述图结构源数据具体类型包括但不仅限于论文引用、分子结构类图结构源数据,每类图结构源数据中的数据包括实体和实体间联系两类数据,将实体作为节点、实体间联系作为边进行数据预处理以构造图结构数据集g = (v, e),其中v = {v1, v2,
ꢀ…
, vn}、e = {e1, e2,
ꢀ…
, em}分别代表节点集和边集;
10.步骤s2,分割图结构数据集,将图结构数据集划分为训练集、验证集及测试集;
11.步骤s3,采用所述训练集训练获得初始节点分类模型,模型训练的超参数包括:多头注意力头数、初始学习率、衰减因子、丢弃率、层数;优选的,所述超参数的初始设置具体为:多头注意力头数为head_num = 8,初始学习率为learning_rate = 0.0005,衰减因子weight_decay = 0.5,丢弃率从0开始递增、层数从2开始递增,所述模型训练过程包括步骤:采用基于拓扑强化的节点嵌入步骤将图邻接矩阵a和度矩阵d所表示的结构信息压缩入节点位置编码;采用基于二级掩码的多头注意力机制处理,输入目标节点集和相应邻居节点集,针对性筛选邻居节点进行聚合及更新,输出各卷积层的目标节点输出嵌入表示;采用改进的transformer框架结构以跃连接方式进行节点特征融合,所述改进的transformer框架结构是由归一前置的ffn及层间残差构成的;
12.步骤s4,利用验证集对所述初始节点分类模型超参数进行优化设置获得优化的节点分类模型;令其可同时兼顾小规模和中规模数据集的节点分类任务,实现分类性能的全面提升;
13.步骤s5,利用测试集对所述优化的节点分类模型进行性能测试验证。
14.进一步的,所述基于拓扑强化的节点嵌入步骤如图2所示,包含以下步骤:
15.步骤s201:将训练集的图g的邻接矩阵a、度矩阵d作为变量计算图拉普拉斯矩阵δ,并对其进行特征值分解获取特征向量矩阵u,进而以u中k个最小非平凡特征向量ε1,
ꢀ…
, εk的升序构造拉普拉斯位置编码集λ,将图结构信息压缩入各节点位置编码;
16.步骤s202:利用隐空间映射矩阵w
λ
对拉普拉斯位置编码集λ中的各节点位置编码λ
v,k
, (v = 1,
ꢀ…
, n)进行维度对齐,获得对齐后的节点位置编码λv;
17.步骤s203:通过sum pooling在每层均将节点v的输入特征z
v(l)
与位置编码λv进行融合,进而获得该节点的拓扑强化特征h
v(l)
,实现图拓扑学习强化;其中,当l = 0时,z
v(0)
为节点初始输入特征,其值可由节点v对应的初始属性xv直接赋值获取,也可通过隐空间映射获得。
18.进一步的,所述基于二级掩码的多头注意力机制处理如图3所示,包含以下步骤:
19.步骤s301:在步骤s203获取基于拓扑强化的节点嵌入基础上,将目标节点集和相应邻居节点集作为该模块输入,针对当前层目标节点v及其邻居节点集进行组归一化操作,获取归一化输入gn(h
v(l)
)、gn(h
u(l)
),并利用线性映射构造qkv模型,其中q = gn(h
v(l)
)w
vq
、k = gn(h
u(l)
)w
vk
、v = gn(h
u(l)
)w
vv
;
20.步骤s302:根据点积操作计算各节点注意力分数矩阵(gn(h
v(l)
)w
vq
)
·
( gn(h
u(l)
)wvk
)
t
,并根据分值构造二级掩码指示矩阵m
(l)
进行聚合邻居筛选,进而采取维度缩放及softmax操作获得二级掩码矩阵;优选的,具体计算方法如下式所示:
[0021][0022]
步骤s303:根据所述二级掩码矩阵对目标节点v进行多头注意力计算,获得多头输出特征h
(l)v,cat
,并采用门控线性单元对h
(l)v,cat
进行激活,获取激活后的节点特征h
(l)v,glu
;
[0023]
步骤s304:利用感知机对激活后的节点特征h
(l)v,glu
进行线性映射,使其维度与隐层输出维度对齐,获取映射后该层目标节点的隐层输出特征h
(l)v,linear
。
[0024]
优选的,步骤s302中,所述构造二级掩码指示矩阵m
(l)
进行聚合邻居筛选具体步骤如图4所示:
[0025]
步骤s401:构造元素值均为0的二级掩码指示矩阵m
(l)
;
[0026]
步骤s402:根据步骤s302计算所得注意力分数矩阵,获取各邻居节点注意力分数su,(u = 1,
ꢀ…
, m),其中m为一阶邻居节点个数;
[0027]
步骤s403:比较各节点注意力分数,当su大于0时,对应位置的二级掩码指示矩阵元素值设置为1,表示该邻居节点参与后续特征传播及聚合过程;
[0028]
步骤s404:当su小于等于0时,对应位置的二级掩码指示矩阵元素值不变,表示舍弃该邻居节点。
[0029]
进一步的,采用改进的transformer框架结构以跃连接方式进行节点特征融合,如图5所示,包含以下步骤:
[0030]
步骤s501:以步骤s304获得的该层目标节点的隐层输出特征h
(l)v,linear
与该层目标节点输入特征h
(l)v
进行sum pooling操作的结果作为ffn模块的输入特征h
(l)v,ffn
,采用pre-norm结构,在位置前馈网络中增加组归一化前置操作gn(h
(l)v,ffn
),其中h
(l)v,ffn
为ffn模块的输入特征,并且,同步替换ffn模块中的层归一化为组归一化操作处理节点嵌入表示,然后输出目标节点输出特征z
v(l)
;
[0031]
步骤s502:利用跳跃残差模式,在相邻层构造层间残差网络,并且以本层输入特征h
v(l)
与输出特征z
v(l)
作为输入数据,采用sum pooling将h
v(l)
和z
v(l)
融合生成特征h
v,r(l+1)
作为残差输出特征;
[0032]
步骤s503:将步骤s202所得位置编码λv与各层残差输出特征h
v,r(l+1)
作为输入,进行sum pooling融合,构造基于拓扑学习强化的节点特征h
v(l+1)
,输出下层输入特征数据;
[0033]
步骤s504:判断是否为最后一层,若未到达最后一层,则跳转至步骤s301;若达到最后一层,则以步骤s503中最后一层的融合输出表示h
vl
作为mlp分类器输入,进行节点分类,获得模型输出结果。
[0034]
优选的,步骤s4中,若图结构数据为小规模数据集,本发明中所述小规模数据集指的是节点数小于10000的数据集,包含但不限于cora、citeseer、pubmed,所述超参数的优化设置包括:设置模型丢弃率dropout = 0.1,层数layers = 10;
[0035]
优选的,步骤s4中,若图结构数据为中规模数据集,本发明中所述中规模数据集指的是节点数大于10000的数据集,包含但不限于cluster、pattern,所述超参数的优化设置包括:设置模型丢弃率dropout = 0.1或0.01,层数layers = 12或13;
[0036]
本发明第二方面,提供了一种图节点分类方法,所述分类方法采用上述优化的节点分类模型如图6所示,所述分类方法包括步骤:
[0037]
获取分类任务数据,将任务数据中的实体数据预处理为节点、将实体间的联系数据预处理为边,构造用于节点分类的图结构数据;
[0038]
输入预处理后的任务数据,根据所述分类任务规模,选择上述优化后的图节点分类模型,输出分类结果。
[0039]
本发明的第三方面,提供了一种电子设备,包括:一个或多个处理器;存储装置,用于存储一个或多个程序;当一个或多个程序被一个或多个处理器执行,使得一个或多个处理器实现上述第一方面描述的图节点分类方法。
[0040]
有益效果
[0041]
本发明考虑的是在图节点分类任务中,通过一种基于图注意力和改进transformer的节点分类方法提高模型对图数据的建模精度,令其可同时兼顾小规模和中规模数据集的节点分类任务,实现分类性能的全面提升。
[0042]
对比传统的基于位置编码的图transformer节点嵌入,为避免传统方法容易产生节点特征趋同,进而导致模型对节点位置特征学习能力不足的问题,本发明申请在每层节点输入中均融入位置编码进行拓扑学习强化,以提高模型对图结构信息的学习能力,与此同时,基于拓扑学习强化,为了进一步捕捉节点间的长距离依赖关系,以解决传统图transformer仅通过一阶连通性构造注意力矩阵的缺陷,本发明将基于二级掩码的多头注意力机制并引入组归一化及门控线性单元等优化策略,以达到提高注意力模块学习能力的目的;为了进一步避免由上述改进可能导致的特征驱动问题,本发明还将层间残差结构引入框架,以跳跃连接方式进行隐层节点特征融合,从而提升多层卷积后节点特征的多样性,避免特征趋同所引起的过平滑和过挤压问题。
附图说明
[0043]
图1为本发明方法中图节点分类任务的处理流程图;
[0044]
图2为本发明方法中构建基于拓扑强化的节点嵌入流程图;
[0045]
图3为本发明方法中构建基于二级掩码的多头注意力机制流程图;
[0046]
图4为本发明方法中构造二级掩码指示矩阵流程图;
[0047]
图5为本发明方法中改进传统transformer框架结构流程图;
[0048]
图6为本发明方法中一种图节点分类方法流程图;
[0049]
图7为本发明方法中节点分类模型的网络结构示意图;
[0050]
图8为本发明方法中层间残差网络结构示意图;
[0051]
图9a为本发明方法与各对比模型在小规模数据集cora上的评估指标直方图;
[0052]
图9b为本发明方法与各对比模型在小规模数据集citeseer上的评估指标直方图;
[0053]
图9c为本发明方法与各对比模型在小规模数据集pubmed上的评估指标直方图;
[0054]
图10a为本发明方法与各对比模型在中规模数据集cluster上的评估指标直方图;
[0055]
图10b为本发明方法与各对比模型在中规模数据集pattern上的评估指标直方图;
[0056]
图11为数据集详细参数及训练集、验证集、测试集划分;
[0057]
图12为本发明方法与各对比模型在小规模数据集上的评估指标比较结果;
[0058]
图13为本发明方法与各对比模型在中规模数据集上的评估指标比较结果。
具体实施方式
[0059]
下面结合附图和表格对本发明的技术方案进行详细阐述。根据下列描述和权利要求书,本发明的优点和特征将表述地更清楚。需要说明的是,附图均采用简化的形式且使用非精准的比例,仅用以方便、明晰地辅助说明本发明实施的目的。
[0060]
图7是本发明节点分类模型的网络结构示意图,该模型由l层子层即l layers堆叠组合实现,每层子层即l layer均包含改进的多头自注意力层即h headers和前馈神经网络层即ffn,且输入特征均经过拓扑学习强化,子层层间则通过跳跃连接进行特征融合。最后,再通过mlp分类器得到分类结果cv。
[0061]
设通过各类传感器及爬虫等设备设施和技术采集信息所构造的图结构数据为g = (v, e),其中v = {v1, v2,
ꢀ…
, vn}、e = {e1, e2,
ꢀ…
, em}分别代表节点集和边集,a、d分别为其邻接矩阵和度矩阵。随后,采用拉普拉斯特征向量将图结构信息压缩入节点位置编码,构建基于拓扑强化的节点嵌入。图拉普拉斯矩阵计算公式如下:
[0062][0063]
其中,λ表示图拉普拉斯矩阵的特征值向量,u表示对应的特征向量矩阵。以u中k个最小非平凡特征向量ε1,
ꢀ…
, εk的升序构造拉普拉斯位置编码集λ = [λ
1,kt
,
ꢀ…
, λ
n,kt
] = [ε1,
ꢀ…
, εk],其中λ
v,k
(v = 1,
ꢀ…
, n)为节点v的拉普拉斯位置编码。
[0064]
利用隐空间映射对齐位置编码维度,计算表达式满足如下公式:
[0065][0066]
其中,w
λ
表示隐空间映射权值,a
λ
为映射偏置,λ
v,k
为节点v映射前的位置编码。
[0067]
通过sum pooling在每层均将节点输入特征与位置编码进行融合,则节点v在l层的拓扑强化特征h
v(l)
可表示为:
[0068][0069]
其中,z
v(l)
为l层的节点输入特征。当l = 0时,z
v(0)
为节点v的初始输入特征,其值可由节点v对应的初始属性xv直接赋值获取,也可通过隐空间映射z
v(0) = wvx
v + av获得。
[0070]
在确定各层节点输入嵌入后,进一步构造基于二级掩码的多头注意力机制。首先,针对各层输入嵌入进行组归一前置操作,并利用线性映射构造qkv模型,其中q = gn(h
v(l)
)w
vq
、k = gn(h
u(l)
)w
vk
、v = gn(h
u(l)
)w
vv
。此时,利用维度缩放即scale及softmax操作构造二级掩码矩阵对聚合邻居进行针对性筛选,具体构造表达式如下式:
[0071][0072]
其中,dk为w
vk
的列维度,m
(l)
为l层二级掩码的指示矩阵即mask,其构造策略为通过q
(l)g,v
(k
(l)g,v
)
t
所得注意力分数进行直接筛选判定,示性函数为:
[0073]
;
[0074]
其次,采用glu操作对多头注意力输出进行激活即glu&linear。此时,基于二级掩
码的多头注意力机制可表示如下:
[0075][0076][0077][0078]
其中,h
(l)v,cat
、h
(l)v,glu
、h
(l)v,linear
分别表示多头拼接、glu操作、线性映射后节点v的隐层特征,w
(l)1,glu
、w
(l)2,glu
、c
(l)
、d
(l)
分别为glu中的权值和偏置参数,w
(l)linear
为线性映射的权值参数。
[0079]
接下来,采用改进的transformer框架结构:由于多头注意力模块在输入时已经过gn处理,故本发明针对transformer框架稍作变化,将多头注意力输出目标节点输入的sum pooling结果直接作为ffn模块的输入特征。此时,采用pre-norm结构并替换层归一化的位置ffn可定义为:
[0080][0081]
其中,z
v(l)
表示ffn模块的输出,同时也为l层的隐层输出特征。w
(l)1,ffn
、w
(l)2,ffn
分别为ffn权值参数,b
(l)1,ffn
、b
(l)2,ffn
为相应的偏置。
[0082]
最后,以跳跃连接方式构造层间残差网络,并采用sum pooling融合节点特征作为下层输入,其计算公式如下所示:
[0083][0084]
鉴于特征趋同会令模型对图拓扑结构的捕捉能力减弱,故构造位置残差网络令每层均进行拓扑学习强化。如图8所示,不同于传统方式,本发明在分类模型的每个隐层中均累加接入位置编码连接,则层间残差网络计算公式为:
[0085][0086]
经过上述处理后,利用l层卷积迭代计算各节点最终输出嵌入。对于节点v,其分类预测结果可表示为:
[0087][0088]
其中,∈r1×c表示节点v的最终分类预测结果,c为分类标签数,wo、bo分别为mlp分类网络的权值参数和偏置。
[0089]
此时,根据已划分的训练集及验证集,利用交叉熵损失函数训练模型中超参数及学习参数,并通过测试集进行模型分类性能测试验证。训练过程的交叉熵损失函数计算公式如下:
[0090][0091]
其中,训练集为s,n = |s|为训练集样本个数。
[0092]
为验证本发明申请的节点分类方法能够同时兼顾小规模和中规模数据集的节点分类任务,实现分类性能的全面提升的有益效果,将本发明申请的方法在小规模数据集cora、citeseer、pubmed及中规模数据集pattern、cluster上进行实现,具体各类数据集详细参数及训练集、验证集、测试集划分如图11所示。并且,为验证本发明方法的性能,将本发明申请方法与gcn、gat、graph transformer (gt) 、graphsage、gatedgcn、gin共6类经典模型在accuracy,f1和recall三类评估指标上进行比较。
[0093]
在模型训练时,采用adam优化器进行模型优化,并且将多头个数设置为8,初始学习率和衰减因子分别设置为0.0005、0.5。同时,在小规模数据集cora、citeseer和pubmed上将dropout设置为0.1,layers设为10,而针对中规模数据集cluster、pattern,则将其dropout分别设为0.1、0.01,layers分别设为13、12。
[0094]
图12给出了本发明申请方法与各对比模型在三类小规模数据集cora、citeseer及pubmed上的评估指标比较结果,其中最佳性能指标以粗体表示。图9a、图9b和图9c则分别列出了本发明申请方法与各对比模型在三类小规模数据集cora、citeseer及pubmed上的更加详细的性能比较结果。可以看出,本发明申请方法在不同小规模数据集上均可获得最优性能,其与6类对比模型在三类评估指标上的最好结果相比均有大幅提升。
[0095]
此外,为验证本发明申请方法在中规模数据集上同样具有较高的有效性,图13还给出了针对cluster及pattern数据集的不同模型评估比较结果。根据图13数据,所提模型对于中规模数据集同样适用,仍可取得令人满意的节点分类效果。由图10a和图10b也可知,该方法在cluster及pattern数据集上均取得了各评估指标的最优性能,相较于各对比模型均可获得较大提升。
技术特征:
1.一种图节点分类模型训练方法,其特征在于,所述训练方法包括以下步骤:步骤s1,获取图结构源数据,所述图结构源数据的类型包括论文引用、分子结构类图结构源数据,每类图结构源数据中的数据包括实体和实体间联系两类数据,将实体作为节点、实体间联系作为边进行数据预处理以构造图结构数据集g = (v, e),其中v = {v1, v2,
ꢀ…
, v
n
}、e = {e1, e2,
ꢀ…
, e
m
}分别代表节点集和边集;步骤s2,分割图结构数据集,将图结构数据集划分为训练集、验证集及测试集;步骤s3,采用所述训练集训练获得初始节点分类模型,模型训练的超参数包括:多头注意力头数、初始学习率、衰减因子、丢弃率、层数;所述模型训练过程包括步骤:采用基于拓扑强化的节点嵌入步骤将图邻接矩阵a和度矩阵d所表示的结构信息压缩入节点位置编码;采用基于二级掩码的多头注意力机制处理,输入目标节点集和相应邻居节点集,针对性筛选邻居节点进行聚合及更新,输出各卷积层的目标节点输出嵌入表示;采用改进的transformer框架结构以跃连接方式进行节点特征融合,所述改进的transformer框架结构是由归一前置的ffn及层间残差构成的;步骤s4,利用验证集对所述初始节点分类模型的超参数进行优化设置获得优化的节点分类模型;令其可同时兼顾小规模和中规模数据集的节点分类任务,实现分类性能的全面提升;步骤s5,利用测试集对所述优化的节点分类模型进行性能测试验证。2.如权利要求书1所述的一种图节点分类模型训练方法,其特征在于,步骤s3中所述超参数的初始设置具体为:多头注意力头数为8,初始学习率为0.0005,衰减因子0.5,丢弃率从0开始递增、层数从2开始递增。3.如权利要求书1所述的一种图节点分类模型训练方法,其特征在于,所述基于拓扑强化的节点嵌入包含以下步骤:步骤s201:将训练集的图g的邻接矩阵a、度矩阵d作为变量计算图拉普拉斯矩阵δ,并对其进行特征值分解获取特征向量矩阵u,进而以u中k个最小非平凡特征向量ε1,
ꢀ…
, ε
k
的升序构造拉普拉斯位置编码集λ,将图结构信息压缩入各节点位置编码;步骤s202:利用隐空间映射矩阵w
λ
对拉普拉斯位置编码集λ中的各节点位置编码λ
v,k
,(v = 1,
ꢀ…
, n)进行维度对齐,获得对齐后的节点位置编码λ
v
;步骤s203:通过sum pooling在每层均将节点v的输入特征z
v(l)
与位置编码λ
v
进行融合,进而获得该节点的拓扑强化特征h
v(l)
,实现图拓扑学习强化;其中,当l= 0时,z
v(0)
为节点初始输入特征,其值可由节点v对应的初始属性x
v
直接赋值获取,也可通过隐空间映射获得。4.如权利要求书3所述的一种图节点分类模型训练方法,其特征在于,所述基于二级掩码的多头注意力机制处理包含以下步骤:步骤s301:在步骤s203获取基于拓扑强化的节点嵌入基础上,将目标节点集和相应邻居节点集作为输入,针对当前层目标节点v及其邻居节点集进行组归一化操作,获取归一化输入gn(h
v(l)
)、gn(h
u(l)
),并利用线性映射构造qkv模型,其中q = gn(h
v(l)
)w
vq
、k = gn(h
u(l)
)w
vk
、v = gn(h
u(l)
)w
vv
;步骤s302:根据点积操作计算各节点注意力分数矩阵(gn(h
v(l)
)w
vq
)
·
( gn(h
u(l)
)w
vk
)
t
,并根据分值构造二级掩码指示矩阵m
(l)
进行聚合邻居筛选,进而采取维度缩放及softmax操作获得二级掩码矩阵;
步骤s303:根据所述二级掩码矩阵对目标节点v进行多头注意力计算,获得多头输出特征h
(l)v,cat
,并采用门控线性单元对h
(l)v,cat
进行激活,获取激活后的节点特征h
(l)v,glu
;步骤s304:利用感知机对激活后的节点特征h
(l)v,glu
进行线性映射,使其维度与隐层输出维度对齐,获取映射后该层目标节点的隐层输出特征h
(l)v,linear
。5.如权利要求书4所述的一种图节点分类模型训练方法,其特征在于,步骤s302中,所述构造二级掩码指示矩阵m
(l)
进行聚合邻居筛选包含以下步骤:步骤s401:构造元素值均为0的二级掩码指示矩阵m
(l)
;步骤s402:根据步骤s302计算所得注意力分数矩阵,获取各邻居节点注意力分数s
u
,(u = 1,
ꢀ…
, m),其中m为一阶邻居节点个数;步骤s403:比较各节点注意力分数,当s
u
大于0时,对应位置的二级掩码指示矩阵元素值设置为1,表示该邻居节点参与后续特征传播及聚合过程;步骤s404:当s
u
小于等于0时,对应位置的二级掩码指示矩阵元素值不变,表示舍弃该邻居节点。6.如权利要求书4所述的一种图节点分类模型训练方法,其特征在于,采用改进的transformer框架结构以跃连接方式进行节点特征融合包含以下步骤:步骤s501:以步骤s304获得的该层目标节点的隐层输出特征h
(l)v,linear
与该层目标节点输入特征h
(l)v
进行sum pooling操作的结果作为ffn模块的输入特征h
(l)v,ffn
,采用pre-norm结构,在位置前馈网络中增加组归一化前置操作gn(h
(l)v,ffn
),其中h
(l)v,ffn
为ffn模块的输入特征,并且,同步替换ffn模块中的层归一化为组归一化操作处理节点嵌入表示,然后输出目标节点输出特征z
v(l)
;步骤s502:利用跳跃残差模式,在相邻层构造层间残差网络,并且以本层输入特征h
v(l)
与输出特征z
v(l)
作为输入数据,采用sum pooling将h
v(l)
和z
v(l)
融合生成特征h
v,r(l+1)
作为残差输出特征;步骤s503:将步骤s202所得位置编码λ
v
与各层残差输出特征h
v,r(l+1)
作为输入,进行sum pooling融合,构造基于拓扑学习强化的节点特征h
v(l+1)
,输出下层输入特征数据;步骤s504:判断是否为最后一层,若未到达最后一层,则跳转至步骤s301;若达到最后一层,则以步骤s503中最后一层的融合输出表示h
vl
作为mlp分类器输入,进行节点分类,获得模型输出结果。7.如权利要求书1所述的一种图节点分类模型训练方法,其特征在于,步骤s4中,图结构数据为小规模数据集,所述超参数的优化设置包括:设置模型丢弃率为0.1,层数为10。8.如权利要求书1所述的一种图节点分类模型训练方法,其特征在于,步骤s4中,图结构数据为中规模数据集,所述超参数的优化设置包括:设置模型丢弃率为0.1或0.01,层数为12或13。9.一种图节点分类方法,其特征在于,所述分类方法包括步骤:获取分类任务数据,将任务数据中的实体数据预处理为节点、将实体间的联系数据预处理为边,构造用于节点分类的图结构数据;输入预处理后的任务数据,根据所述分类任务规模,设置超参数后输入权利要求1中所述优化后的图节点分类模型,输出分类结果。10.一种电子设备,其特征在于,包括:一个或多个处理器;存储装置,用于存储一个或多个程序;当一个或多个程序被一个或多个处理器执行,使得一个或多个处理器实现权利
要求9所述的图节点分类方法。
技术总结
本发明属于图机器学习、图表示学习技术领域,提供了一种图节点分类模型训练方法、图节点分类方法,具体涉及一种基于图注意力和改进Transformer图节点分类模型训练方法、图节点分类方法,具体将基于二级掩码的图注意力机制及结构强化学习、层间残差等优化策略融入Transformer框架,构建一种改进的Transformer模型以提高其对图数据的建模精度,同时兼顾小规模和中规模图数据集的节点分类任务,实现分类性能的全面提升。类性能的全面提升。类性能的全面提升。
技术研发人员:李鑫 朱攀 陆伟 马召祎 赵晨廷 吕赛 李青松
受保护的技术使用者:南京邮电大学
技术研发日:2023.08.28
技术公布日:2023/10/5
版权声明
本文仅代表作者观点,不代表航家之家立场。
本文系作者授权航家号发表,未经原创作者书面授权,任何单位或个人不得引用、复制、转载、摘编、链接或以其他任何方式复制发表。任何单位或个人在获得书面授权使用航空之家内容时,须注明作者及来源 “航空之家”。如非法使用航空之家的部分或全部内容的,航空之家将依法追究其法律责任。(航空之家官方QQ:2926969996)
航空之家 https://www.aerohome.com.cn/
飞机超市 https://mall.aerohome.com.cn/
航空资讯 https://news.aerohome.com.cn/