PGL圖學習之圖神經網路GNN模型GCN、GAT[系列六]

2022-11-18 12:01:39

PGL圖學習之圖神經網路GNN模型GCN、GAT[系列六]

專案連結:一鍵fork直接跑程式 https://aistudio.baidu.com/aistudio/projectdetail/5054122?contributionType=1

0.前言-學術界業界論文發表情況

ICLR2023評審情況:

ICLR2023的評審結果已經正式釋出!今年的ICLR2023共計提交6300份初始摘要和4922份經過審查的提交,其中經過審查提交相比上一年增加了32.2%。在4922份提交內容中,99%的內容至少有3個評論,總共有超過18500個評論。按照Open Review評審制度,目前ICLR已經進入討論階段。

官網連結:https://openreview.net/group?id=ICLR.cc/2022/Conference

在4922份提交內容中,主要涉及13個研究方向,具體有:

1、AI應用應用,例如:語音處理、計算機視覺、自然語言處理等
2、深度學習和表示學習
3、通用機器學習
4、生成模型
5、基礎設施,例如:資料集、競賽、實現、庫等
6、科學領域的機器學習,例如:生物學、物理學、健康科學、社會科學、氣候/可持續性等
7、神經科學與認知科學,例如:神經編碼、腦機介面等
8、優化,例如:凸優化、非凸優化等
9、概率方法,例如:變分推理、因果推理、高斯過程等
10、強化學習,例如:決策和控制,計劃,分層RL,機器人等
11、機器學習的社會方面,例如:人工智慧安全、公平、隱私、可解釋性、人與人工智慧互動、倫理等
12、理論,例如:如控制理論、學習理論、演演算法博弈論。
13、無監督學習和自監督學習

ICLR詳細介紹:

ICLR,全稱為「International Conference on Learning Representations」(國際學習表徵會議),2013 年5月2日至5月4日在美國亞利桑那州斯科茨代爾順利舉辦了第一屆ICLR會議。該會議是一年一度的會議,截止到2022年它已經舉辦了10屆,而今年的(2023年)5月1日至5日,將在基加利會議中心完成ICLR的第十一屆會議。

該會議被學術研究者們廣泛認可,被認為是「深度學習的頂級會議」。為什麼ICLR為什麼會成為深度學習領域的頂會呢? 首先該會議由深度學習三大巨頭之二的Yoshua Bengio和Yann LeCun 牽頭創辦。其中Yoshua Bengio 是蒙特利爾大學教授,深度學習三巨頭之一,他領導蒙特利爾大學的人工智慧實驗室MILA進行 AI 技術的學術研究。MILA 是世界上最大的人工智慧研究中心之一,與谷歌也有著密切的合作。 Yann LeCun同為深度學習三巨頭之一的他現任 Facebook 人工智慧研究院FAIR院長、紐約大學教授。作為折積神經網路之父,他為深度學習的發展和創新作出了重要貢獻。

Keywords Frequency 排名前 50 的常用關鍵字(不區分大小寫)及其出現頻率:

可有看到圖、圖神經網路分別排在2、4。

1.Graph Convolutional Networks(GCN,圖折積神經網路)

GCN的概念首次提出於ICLR2017(成文於2016年):

Semi-Supervised Classification with Graph Convolutional Networks:https://arxiv.org/abs/1609.02907

圖資料中的空間特徵具有以下特點:
1) 節點特徵:每個節點有自己的特徵;(體現在點上)
2) 結構特徵:圖資料中的每個節點具有結構特徵,即節點與節點存在一定的聯絡。(體現在邊上)
總地來說,圖資料既要考慮節點資訊,也要考慮結構資訊,圖折積神經網路就可以自動化地既學習節點特徵,又能學習節點與節點之間的關聯資訊。

綜上所述,GCN是要為除CV、NLP之外的任務提供一種處理、研究的模型。
圖折積的核心思想是利用『邊的資訊』對『節點資訊』進行『聚合』從而生成新的『節點表示』。

1.1原理簡介

  • 假如我們希望做節點相關的任務,就可以通過 Graph Encoder,在圖上學習到節點特徵,再利用學習到的節點特徵做一些相關的任務,比如節點分類、關係預測等等;

  • 而同時,我們也可以在得到的節點特徵的基礎上,做 Graph Pooling 的操作,比如加權求和、取平均等等操作,從而得到整張圖的特徵,再利用得到的圖特徵做圖相關的任務,比如圖匹配、圖分類等。

圖遊走類演演算法主要的目的是在訓練得到節點 embedding 之後,再利用其做下游任務,也就是說區分為了兩個階段。

對於圖折積網路而言,則可以進行一個端到端的訓練,不需要對這個過程進行區分,那麼這樣其實可以更加針對性地根據具體任務來進行圖上的學習和訓練。

回顧折積神經網路在影象及文字上的發展

在影象上的二維折積,其實質就是折積核在二維影象上平移,將折積核的每個元素與被折積影象對應位置相乘,再求和,得到一個新的結果。其實它有點類似於將當前畫素點和周圍的畫素點進行某種程度的轉換,再得到當前畫素點更新後的一個值。

它的本質是利用了一維折積,因為文字是一維資料,在我們已知文字的詞表示之後,就在詞級別上做一維的折積。其本質其實和影象上的折積沒有什麼差別。
(注:折積核維度和紅框維度相同,2 * 6就是2 * 6)

影象折積的本質其實非常簡單,就是將一個畫素點周圍的畫素,按照不同的權重疊加起來,當然這個權重就是我們通常說的折積核。

其實可以把當前畫素點類比做圖的節點,而這個節點周圍的畫素則類比為節點的鄰居,從而可以得到圖結構折積的簡單的概念:

將一個節點周圍的鄰居按照不同的權重疊加起來

而影象上的折積操作,與圖結構上的折積操作,最大的一個區別就在於:

  • 對於影象的畫素點來說,它的周圍畫素點數量其實是固定的;
  • 但是對於圖而言,節點的鄰居數量是不固定的。

1.2圖折積網路的兩種理解方式

GCN的本質目的就是用來提取拓撲圖的空間特徵。 而圖折積神經網路主要有兩類,一類是基於空間域或頂點域vertex domain(spatial domain)的,另一類則是基於頻域或譜域spectral domain的。通俗點解釋,空域可以類比到直接在圖片的畫素點上進行折積,而頻域可以類比到對圖片進行傅立葉變換後,再進行折積。

所謂的兩類其實就是從兩個不同的角度理解,關於從空間角度的理解可以看本文的從空間角度理解GCN

vertex domain(spatial domain):頂點域(空間域)

基於空域折積的方法直接將折積操作定義在每個結點的連線關係上,它跟傳統的折積神經網路中的折積更相似一些。在這個類別中比較有代表性的方法有 Message Passing Neural Networks(MPNN)[1], GraphSage[2], Diffusion Convolution Neural Networks(DCNN)[3], PATCHY-SAN[4]等。

spectral domain:頻域方法(譜方法

這就是譜域圖折積網路的理論基礎了。這種思路就是希望藉助圖譜的理論來實現拓撲圖上的折積操作。從整個研究的時間程序來看:首先研究GSP(graph signal processing)的學者定義了graph上的Fourier Transformation,進而定義了graph上的convolution,最後與深度學習結合提出了Graph Convolutional Network。

基於頻域折積的方法則從圖訊號處理起家,包括 Spectral CNN[5], Cheybyshev Spectral CNN(ChebNet)[6], 和 First order of ChebNet(1stChebNet)[7] 等

論文Semi-Supervised Classification with Graph Convolutional Networks就是一階鄰居的ChebNet

認真讀到這裡,腦海中應該會浮現出一系列問題:

Q1 什麼是Spectral graph theory?

Spectral graph theory請參考維基百科的介紹,簡單的概括就是藉助於圖的拉普拉斯矩陣的特徵值和特徵向量來研究圖的性質

Q2 GCN為什麼要利用Spectral graph theory?

這是論文(Semi-Supervised Classification with Graph Convolutional Networks)中的重點和難點,要理解這個問題需要大量的數學定義及推導

過程:

  • (1)定義graph上的Fourier Transformation傅立葉變換(利用Spectral graph theory,藉助圖的拉普拉斯矩陣的特徵值和特徵向量研究圖的性質)
  • (2)定義graph上的convolution折積

1.3 圖折積網路的計算公式

  • H代表每一層的節點表示,第0層即為最開始的節點表示
  • A表示鄰接矩陣,如下圖所示,兩個節點存在鄰居關係就將值設為1,對角線預設為1
  • D表示度矩陣,該矩陣除對角線外均為0,對角線的值表示每個節點的度,等價於鄰接矩陣對行求和
  • W表示可學習的權重

鄰接矩陣的對角線上都為1,這是因為新增了自環邊,這也是這個公式中使用的定義,其他情況下鄰接矩陣是可以不包含自環的。(包含了自環邊的鄰接矩陣)

度矩陣就是將鄰接矩陣上的每一行進行求和,作為對角線上的值。而由於我們是要取其-1/2的度矩陣,因此還需要對對角線上求和後的值做一個求倒數和開根號的操作,因此最後可以得到右邊的一個矩陣運算結果。

為了方便理解,我們可以暫時性地把度矩陣在公式中去掉:

  • 為了得到 H^{(l+1)}的第0行,我們需要拿出A的第0行與 H^{(l)}相乘,這是矩陣乘法的概念。
  • 接下來就是把計算結果相乘再相加的過程。
  • 這個過程其實就是訊息傳遞的過程:對於0號節點來說,將鄰居節點的資訊傳給自身

將上式進行拆分,A*H可以理解成將上一層每個節點的節點表示進行聚合,如圖,0號節點就是對上一層與其相鄰的1號、2號節點和它本身進行聚合。而度矩陣D存在的意義是每個節點的鄰居的重要性不同,根據該節點的度來對這些相鄰節點的節點表示進行加權,d越大,說明資訊量越小。

實際情況中,每個節點傳送的資訊所帶的資訊量應該是不同的。

圖折積網路將鄰接矩陣的兩邊分別乘上了度矩陣,相當於給這個邊加了權重。其實就是利用節點度的不同來調整資訊量的大小。

這個公式其實體現了:
一個節點的度越大,那麼它所包含的資訊量就越小,從而對應的權值也就越小。

怎麼理解這樣的一句話呢,我們可以設想這樣的一個場景。假如說在一個社群網路里,一個人認識了幾乎所有的人,那麼這個人能夠給我們的資訊量是比較小的。

也就是說,每個節點通過邊對外傳送相同量的資訊, 邊越多的節點,每條邊傳送出去的資訊量就越小。

1.4用多層圖網路完成節點分類任務

GCN演演算法主要包括以下幾步:

  • 第一步是利用上面的核心公式進行節點間特徵傳遞
  • 第二步對每個節點過一層DNN
  • 重複以上兩步得到L層的GCN
  • 獲得的最終節點表示H送入分類器進行分類

更詳細的資料參考:圖折積網路 GCN Graph Convolutional Network(譜域GCN)的理解和詳細推導

1.5 GCN引數解釋

主要是幫助大家理解訊息傳遞機制的一些引數型別。

這裡給出一個簡化版本的 GCN 模型,幫助理解PGL框架實現訊息傳遞的流程。


def gcn_layer(gw, feature, hidden_size, activation, name, norm=None):
    """
    描述:通過GCN層計算新的節點表示
    輸入:gw - GraphWrapper物件
         feature - 節點表示 (num_nodes, feature_size)
         hidden_size - GCN層的隱藏層維度 int
         activation - 啟用函數 str
         name - GCN層名稱 str
         norm - 標準化tensor float32 (num_nodes,),None表示不標準化
    輸出:新的節點表示 (num_nodes, hidden_size)
    """

    # send函數
    def send_func(src_feat, dst_feat, edge_feat):
        """
        描述:用於send節點資訊。函數名可自定義,參數列固定
        輸入:src_feat - 源節點的表示字典 {name:(num_edges, feature_size)}
             dst_feat - 目標節點表示字典 {name:(num_edges, feature_size)}
             edge_feat - 與邊(src, dst)相關的特徵字典 {name:(num_edges, feature_size)}
        輸出:儲存傳送資訊的張量或字典 (num_edges, feature_size) or {name:(num_edges, feature_size)}
        """
        return src_feat["h"] # 直接返回源節點表示作為資訊

    # send和recv函數是搭配實現的,send的輸出就是recv函數的輸入
    # recv函數
    def recv_func(msg):
        """
        描述:對接收到的msg進行聚合。函數名可自定義,參數列固定
        輸出:新的節點表示張量 (num_nodes, feature_size)
        """
        return L.sequence_pool(msg, pool_type='sum') # 對接收到的訊息求和

    ### 訊息傳遞機制執行過程
    # gw.send函數
    msg = gw.send(send_func, nfeat_list=[("h", feature)]) 
    """ 
    描述:觸發message函數,傳送訊息並將訊息返回
    輸入:message_func - 自定義的訊息函數
         nfeat_list - list [name] or tuple (name, tensor)
         efeat_list - list [name] or tuple (name, tensor)
    輸出:訊息字典 {name:(num_edges, feature_size)}
    """

    # gw.recv函數
    output = gw.recv(msg, recv_func)
    """ 
    描述:觸發reduce函數,接收並處理訊息
    輸入:msg - gw.send輸出的訊息字典
         reduce_function - "sum"或自定義的reduce函數
    輸出:新的節點特徵 (num_nodes, feature_size)

    如果reduce函數是對訊息求和,可以直接用"sum"作為引數,使用內建函數加速訓練,上述語句等價於 \
    output = gw.recv(msg, "sum")
    """

    # 通過以activation為啟用函數的全連線輸出層
    output = L.fc(output, size=hidden_size, bias_attr=False, act=activation, name=name)
    return output

2.Graph Attention Networks(GAT,圖注意力機制網路)

Graph Attention Networks:https://arxiv.org/abs/1710.10903

GCN網路中的一個缺點是邊的權重與節點的度度相關而且不可學習,因此有了圖注意力演演算法。在GAT中,邊的權重變成節點間的可學習的函數並且與兩個節點之間的相關性有關。

2.1.計算方法

注意力機制的計算方法如下:

首先將目標節點和源節點的表示拼接到一起,通過網路計算相關性,之後通過LeakyReLu啟用函數和softmax歸一化得到注意力分數,最後用得到的α進行聚合,後續步驟和GCN一致。

以及多頭Attention公式

2.2 空間GNN

空間GNN(Spatial GNN):基於鄰居聚合的圖模型稱為空間GNN,例如GCN、GAT等等。大部分的空間GNN都可以用訊息傳遞實現,訊息傳遞包括訊息的傳送和訊息的接受。

基於訊息傳遞的圖神經網路的通用公式:

2.3 訊息傳遞demo例子

2.4 GAT引數解釋

其中:

  • 在 send 函數中完成 LeakyReLU部分的計算;
  • 在 recv 函數中,對接受到的 logits 資訊進行 softmax 操作,形成歸一化的分數(公式當中的 alpha),再與結果進行加權求和。
def single_head_gat(graph_wrapper, node_feature, hidden_size, name):
    # 實現單頭GAT

    def send_func(src_feat, dst_feat, edge_feat):
        ##################################
        # 按照提示一步步理解程式碼吧,你只需要填###的地方

        # 1. 將源節點特徵與目標節點特徵concat 起來,對應公式當中的 concat 符號,可能用到的 API: fluid.layers.concat
        Wh = fluid.layers.concat(input=[src_feat["Wh"], dst_feat["Wh"]], axis=1)
    
        # 2. 將上述 Wh 結果通過全連線層,也就對應公式中的a^T

        alpha = fluid.layers.fc(Wh, 
                            size=1, 
                            name=name + "_alpha", 
                            bias_attr=False)

        # 3. 將計算好的 alpha 利用 LeakyReLU 函數啟用,可能用到的 API: fluid.layers.leaky_relu
        alpha = fluid.layers.leaky_relu(alpha, 0.2)
        
        ##################################
        return {"alpha": alpha, "Wh": src_feat["Wh"]}
    
    def recv_func(msg):
        ##################################
        # 按照提示一步步理解程式碼吧,你只需要填###的地方

        # 1. 對接收到的 logits 資訊進行 softmax 操作,形成歸一化分數,可能用到的 API: paddle_helper.sequence_softmax
        alpha = msg["alpha"]
        norm_alpha = paddle_helper.sequence_softmax(alpha)

        # 2. 對 msg["Wh"],也就是節點特徵,用上述結果進行加權
        output = norm_alpha * msg["Wh"]

        # 3. 對加權後的結果進行相加的鄰居聚合,可能用到的API: fluid.layers.sequence_pool
        output = fluid.layers.sequence_pool(output, pool_type="sum")
        ##################################
        return output
    
    # 這一步,其實對應了求解公式當中的Whi, Whj,相當於對node feature加了一個全連線層

    Wh = fluid.layers.fc(node_feature, hidden_size, bias_attr=False, name=name + "_hidden")
    # 訊息傳遞機制執行過程
    message = graph_wrapper.send(send_func, nfeat_list=[("Wh", Wh)])
    output = graph_wrapper.recv(message, recv_func)
    output = fluid.layers.elu(output)
    return output


def gat(graph_wrapper, node_feature, hidden_size):
    # 完整多頭GAT

    # 這裡設定多個頭,每個頭的輸出concat在一起,構成多頭GAT
    heads_output = []
    # 可以調整頭數 (8 head x 8 hidden_size)的效果較好 
    n_heads = 8
    for head_no in range(n_heads):
        # 請完成單頭的GAT的程式碼
        single_output = single_head_gat(graph_wrapper, 
                            node_feature, 
                            hidden_size, 
                            name="head_%s" % (head_no) )
        heads_output.append(single_output)
    
    output = fluid.layers.concat(heads_output, -1)
    return output


3.資料集介紹

3個常用的圖學習資料集,CORA, PUBMED, CITESEER。可以在論文中找到資料集的相關介紹。

今天我們來了解一下這幾個資料集

3.1Cora資料集

Cora資料集由機器學習論文組成,是近年來圖深度學習很喜歡使用的資料集。
在整個語料庫中包含2708篇論文,並分為以下七類:

  • 基於案例
  • 遺傳演演算法
  • 神經網路
  • 概率方法
  • 強化學習
  • 規則學習
  • 理論

論文之間互相參照,在該資料集中,每篇論文都至少參照了一篇其他論文,或者被其他論文參照,也就是樣本點之間存在聯絡,沒有任何一個樣本點與其他樣本點完全沒聯絡。如果將樣本點看做圖中的點,則這是一個連通的圖,不存在孤立點。這樣形成的網路有5429條邊。
在消除停詞以及除去檔案頻率小於10的詞彙,最終詞彙表中有1433個詞彙。每篇論文都由一個1433維的詞向量表示,所以,每個樣本點具有1433個特徵。詞向量的每個元素都對應一個詞,且該元素只有0或1兩個取值。取0表示該元素對應的詞不在論文中,取1表示在論文中。

資料集有包含兩個檔案:

  1. .content檔案包含以下格式的論文描述:

<paper_id> <word_attributes>+ <class_label>

每行的第一個條目包含紙張的唯一字串標識,後跟二進位制值,指示詞彙中的每個單詞在文章中是存在(由1表示)還是不存在(由0表示)。

最後,該行的最後一個條目包含紙張的類別標籤。因此資料集的$feature$應該為$2709×14332709 \times 14332709×1433$維度。
第一行為$idx$,最後一行為$label$。

  1. 那個.cites檔案包含語料庫的參照’圖’。每行以以下格式描述一個連結:

<被引論文編號> <引論文編號>

每行包含兩個紙質id。第一個條目是被參照論文的標識,第二個標識代表包含參照的論文。連結的方向是從右向左。

如果一行由「論文1 論文2」表示,則連結是「論文2 - >論文1」。可以通過論文之間的索引關係建立鄰接矩陣adj

下載連結:

https://aistudio.baidu.com/aistudio/datasetdetail/177587

相關論文:


3.2PubMed資料集

PubMed 是一個提供生物醫學方面的論文搜尋以及摘要,並且免費搜尋的資料庫。它的資料庫來源為MEDLINE。其核心主題為醫學,但亦包括其他與醫學相關的領域,像是護理學或者其他健康學科。

PUBMED資料集是基於PubMed 文獻資料庫生成的。它包含了19717篇糖尿病相關的科學出版物,這些出版物被分成三個類別。
這些出版物的互相參照網路包含了44338條邊。在消除停詞以及除去低頻詞彙,最終詞彙表中有500個詞彙。這些出版物用一個TF/IDF加權的詞向量來描述是否包含詞彙表中的詞彙。

下載連結:
https://aistudio.baidu.com/aistudio/datasetdetail/177591

相關論文:


3.3CiteSeer資料集

CiteSeer(又名ResearchIndex),是NEC研究院在自動引文索引(Autonomous Citation Indexing, ACI)機制的基礎上建設的一個學術論文數點陣圖書館。這個引文索引系統提供了一種通過引文連結的檢索文獻的方式,目標是從多個方面促進學術文獻的傳播和反饋。

在整個語料庫中包含3312篇論文,並分為以下六類:

  • Agents
  • AI
  • DB
  • IR
  • ML
  • HCI

論文之間互相參照,在該資料集中,每篇論文都至少參照了一篇其他論文,或者被其他論文參照,也就是樣本點之間存在聯絡,沒有任何一個樣本點與其他樣本點完全沒聯絡。如果將樣本點看做圖中的點,則這是一個連通的圖,不存在孤立點。這樣形成的網路有4732條邊。
在消除停詞以及除去檔案頻率小於10的詞彙,最終詞彙表中有3703個詞彙。每篇論文都由一個3703維的詞向量表示,所以,每個樣本點具有3703個特徵。詞向量的每個元素都對應一個詞,且該元素只有0或1兩個取值。取0表示該元素對應的詞不在論文中,取1表示在論文中。

下載連結:

https://aistudio.baidu.com/aistudio/datasetdetail/177589

相關論文


3.4 小結

資料集 圖數 節點數 邊數 特徵維度 標籤數
Cora 1 2708 5429 1433 7
Citeseer 1 3327 4732 3703 6
Pubmed 1 19717 44338 500 3

更多圖資料集:
https://linqs.org/datasets/

GCN常用資料集
KarateClub:資料為無向圖,來源於論文An Information Flow Model for Conflict and Fission in Small Groups
TUDataset:包括58個基礎的分類資料集集合,資料都為無向圖,如」IMDB-BINARY」,」PROTEINS」等,來源於TU Dortmund University
Planetoid:參照網路資料集,包括「Cora」, 「CiteSeer」 and 「PubMed」,資料都為無向圖,來源於論文Revisiting Semi-Supervised Learning with Graph Embeddings。節點代表檔案,邊代表參照關係。
CoraFull:完整的」Cora」參照網路資料集,資料為無向圖,來源於論文Deep Gaussian Embedding of Graphs: Unsupervised Inductive Learning via Ranking。節點代表檔案,邊代表參照關係。
Coauthor:共同作者網路資料集,包括」CS」和」Physics」,資料都為無向圖,來源於論文Pitfalls of Graph Neural Network Evaluation。節點代表作者,若是共同作者則被邊相連。學習任務是將作者對映到各自的研究領域中。
Amazon:亞馬遜網路資料集,包括」Computers」和」Photo」,資料都為無向圖,來源於論文Pitfalls of Graph Neural Network Evaluation。節點代表貨物i,邊代表兩種貨物經常被同時購買。學習任務是將貨物對映到各自的種類裡。
PPI:蛋白質-蛋白質反應網路,資料為無向圖,來源於論文Predicting multicellular function through multi-layer tissue networks
Entities:關係實體網路,包括「AIFB」, 「MUTAG」, 「BGS」 和「AM」,資料都為無向圖,來源於論文Modeling Relational Data with Graph Convolutional Networks
BitcoinOTC:資料為有向圖,包括138個」who-trusts-whom」網路,來源於論文EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs,資料連結為Bitcoin OTC trust weighted signed network

4.基於PGL的GNN演演算法實踐

4.1 GCN

圖折積網路 (GCN)是一種功能強大的神經網路,專為圖上的機器學習而設計。基於PGL復現了GCN演演算法,在引文網路基準測試中達到了與論文同等水平的指標。

搭建GCN的簡單例子:要構建一個 gcn 層,可以使用我們預定義的pgl.nn.GCNConv或者只編寫一個帶有訊息傳遞介面的 gcn 層。

!CUDA_VISIBLE_DEVICES=0 python train.py --dataset cora

模擬結果:

Dataset Accuracy
Cora 81.16%
Pubmed 79.34%
Citeseer 70.91%

4.2 GAT

圖注意力網路 (GAT)是一種對圖結構資料進行操作的新型架構,它利用掩蔽的自注意層來解決基於圖折積或其近似的先前方法的缺點。基於PGL,我們復現了GAT演演算法,在引文網路基準測試中達到了與論文同等水平的指標。

搭建單頭GAT的簡單例子:
要構建一個 gat 層,可以使用我們的預定義pgl.nn.GATConv或只編寫一個帶有訊息傳遞介面的 gat 層。

GAT模擬結果:

Dataset Accuracy
Cora 82.97%
Pubmed 77.99%
Citeseer 70.91%

專案連結:一鍵fork直接跑程式 https://aistudio.baidu.com/aistudio/projectdetail/5054122?contributionType=1

5.總結

本次專案講解了圖神經網路的原理並對GCN、GAT實現方式進行講解,最後基於PGL實現了兩個演演算法在資料集Cora、Pubmed、Citeseer的表現,在引文網路基準測試中達到了與論文同等水平的指標。

目前的資料集樣本節點和邊都不是很大,下個專案將會講解面對億級別圖應該如何去做。

參考連結:感興趣可以看看詳細的推到以及涉及有趣的問題