MindSpore Graph Learning

2022-11-14 12:10:14

Technical Background

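MindSpore Graph Learning is an efficient, easy-to-use graph learning framework built on MindSpore. Thanks to MindSpore's graph kernel fusion capability, MindSpore Graph Learning can apply compilation optimizations tailored to the execution patterns specific to graph models, which helps developers shorten training time. MindSpore Graph Learning also introduces a vertex-centric programming paradigm that offers a more native way to express graph neural networks, and it ships with built-in models covering most application scenarios, so developers can build graph neural networks with ease.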

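That is the official description of mindspore-gl. Its positioning is very close to that of DGL, and judging from the data in the paper cited as Reference 3, mindspore-gl is even more computationally efficient than DGL.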

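In conventional machine learning, we can efficiently run all kinds of operations, such as convolutions, on Tensors. But for a network with a graph structure, is there a more convenient way to compute directly on the graph, instead of first converting the graph into Tensor data and then processing those Tensors? mindspore-gl offers its own answer to this question. Let's look at how to install and use mindspore-gl.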

Installing mindspore-gl

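Although an official pip installation method exists, only a very limited set of versions is available from the package index. We therefore recommend building and installing from source, which also makes it easier to match the MindSpore version you already have installed. First, clone the repository and enter the graphlearning directory: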

$ git clone https://gitee.com/mindspore/graphlearning.git
Cloning into 'graphlearning'...
remote: Enumerating objects: 1275, done.
remote: Counting objects: 100% (221/221), done.
remote: Compressing objects: 100% (152/152), done.
remote: Total 1275 (delta 116), reused 127 (delta 68), pack-reused 1054
Receiving objects: 100% (1275/1275), 1.41 MiB | 316.00 KiB/s, done.
Resolving deltas: 100% (715/715), done.
$ cd graphlearning/
$ ll
total 112
drwxrwxr-x 12 dechin dechin  4096 Nov  9 17:19 ./
drwxrwxr-x 10 dechin dechin  4096 Nov  9 17:19 ../
-rwxrwxr-x  1 dechin dechin  1429 Nov  9 17:19 build.sh*
drwxrwxr-x  2 dechin dechin  4096 Nov  9 17:19 examples/
-rwxrwxr-x  1 dechin dechin  3148 Nov  9 17:19 FAQ_CN.md*
-rwxrwxr-x  1 dechin dechin  4148 Nov  9 17:19 faq.md*
drwxrwxr-x  8 dechin dechin  4096 Nov  9 17:19 .git/
-rwxrwxr-x  1 dechin dechin  1844 Nov  9 17:19 .gitignore*
drwxrwxr-x  2 dechin dechin  4096 Nov  9 17:19 images/
drwxrwxr-x  3 dechin dechin  4096 Nov  9 17:19 .jenkins/
-rw-rw-r--  1 dechin dechin 11357 Nov  9 17:19 LICENSE
drwxrwxr-x 11 dechin dechin  4096 Nov  9 17:19 mindspore_gl/
drwxrwxr-x 11 dechin dechin  4096 Nov  9 17:19 model_zoo/
-rwxrwxr-x  1 dechin dechin    52 Nov  9 17:19 OWNERS*
-rwxrwxr-x  1 dechin dechin  3648 Nov  9 17:19 README_CN.md*
-rwxrwxr-x  1 dechin dechin  4570 Nov  9 17:19 README.md*
drwxrwxr-x  4 dechin dechin  4096 Nov  9 17:19 recommendation/
-rwxrwxr-x  1 dechin dechin   922 Nov  9 17:19 RELEASE.md*
-rwxrwxr-x  1 dechin dechin   108 Nov  9 17:19 requirements.txt*
drwxrwxr-x  2 dechin dechin  4096 Nov  9 17:19 scripts/
-rwxrwxr-x  1 dechin dechin  4164 Nov  9 17:19 setup.py*
drwxrwxr-x  5 dechin dechin  4096 Nov  9 17:19 tests/
drwxrwxr-x  5 dechin dechin  4096 Nov  9 17:19 tools/

Then run the build script provided by the project:

$ bash build.sh 
mkdir: created directory '/home/dechin/projects/mindspore/graphlearning/output'
Collecting Cython>=0.29.24
  Downloading Cython-0.29.32-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl (2.0 MB)
     |████████████████████████████████| 2.0 MB 823 kB/s 
Collecting ast-decompiler>=0.6.0
  Downloading ast_decompiler-0.7.0-py3-none-any.whl (13 kB)
Collecting astpretty>=2.1.0
  Downloading astpretty-3.0.0-py2.py3-none-any.whl (4.9 kB)
Collecting scikit-learn>=0.24.2
  Downloading scikit_learn-1.1.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30.8 MB)
     |████████████████████████████████| 30.8 MB 2.6 MB/s 
Requirement already satisfied: numpy>=1.21.2 in /home/dechin/anaconda3/envs/mindspore16/lib/python3.9/site-packages (from -r /home/dechin/projects/mindspore/graphlearning/requirements.txt (line 5)) (1.23.2)
Collecting networkx>=2.6.3
  Downloading networkx-2.8.8-py3-none-any.whl (2.0 MB)
     |████████████████████████████████| 2.0 MB 4.6 MB/s 
Requirement already satisfied: scipy>=1.3.2 in /home/dechin/anaconda3/envs/mindspore16/lib/python3.9/site-packages (from scikit-learn>=0.24.2->-r /home/dechin/projects/mindspore/graphlearning/requirements.txt (line 4)) (1.5.3)
Collecting threadpoolctl>=2.0.0
  Downloading threadpoolctl-3.1.0-py3-none-any.whl (14 kB)
Collecting joblib>=1.0.0
  Downloading joblib-1.2.0-py3-none-any.whl (297 kB)
     |████████████████████████████████| 297 kB 2.2 MB/s 
Installing collected packages: threadpoolctl, joblib, scikit-learn, networkx, Cython, astpretty, ast-decompiler
Successfully installed Cython-0.29.32 ast-decompiler-0.7.0 astpretty-3.0.0 joblib-1.2.0 networkx-2.8.8 scikit-learn-1.1.3 threadpoolctl-3.1.0
running bdist_wheel
running build
running build_py
creating build
creating build/lib.linux-x86_64-3.9
...
removing build/bdist.linux-x86_64/wheel
mindspore_gl_gpu-0.1-cp39-cp39-linux_x86_64.whl
------Successfully created mindspore_gl package------

If you see the messages above, the build succeeded. Next, simply install the generated whl package with pip:

$ python3 -m pip install ./output/mindspore_gl_gpu-0.1-cp39-cp39-linux_x86_64.whl
Processing ./output/mindspore_gl_gpu-0.1-cp39-cp39-linux_x86_64.whl
Requirement already satisfied: Cython in /home/dechin/.local/lib/python3.9/site-packages (from mindspore-gl-gpu==0.1) (0.29.32)
Requirement already satisfied: astpretty in /home/dechin/.local/lib/python3.9/site-packages (from mindspore-gl-gpu==0.1) (3.0.0)
Requirement already satisfied: ast-decompiler>=0.3.2 in /home/dechin/.local/lib/python3.9/site-packages (from mindspore-gl-gpu==0.1) (0.7.0)
Requirement already satisfied: scikit-learn>=0.24.2 in /home/dechin/.local/lib/python3.9/site-packages (from mindspore-gl-gpu==0.1) (1.1.3)
Requirement already satisfied: threadpoolctl>=2.0.0 in /home/dechin/.local/lib/python3.9/site-packages (from scikit-learn>=0.24.2->mindspore-gl-gpu==0.1) (3.1.0)
Requirement already satisfied: joblib>=1.0.0 in /home/dechin/.local/lib/python3.9/site-packages (from scikit-learn>=0.24.2->mindspore-gl-gpu==0.1) (1.2.0)
Requirement already satisfied: scipy>=1.3.2 in /home/dechin/anaconda3/envs/mindspore16/lib/python3.9/site-packages (from scikit-learn>=0.24.2->mindspore-gl-gpu==0.1) (1.5.3)
Requirement already satisfied: numpy>=1.17.3 in /home/dechin/anaconda3/envs/mindspore16/lib/python3.9/site-packages (from scikit-learn>=0.24.2->mindspore-gl-gpu==0.1) (1.23.2)
Installing collected packages: mindspore-gl-gpu
Successfully installed mindspore-gl-gpu-0.1

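We can verify that mindspore-gl was installed successfully with the following command. The warnings that follow are emitted by MindSpore, not by mindspore-gl: they say that MindSpore cannot find CUDA 10.1 or 11.1 on LD_LIBRARY_PATH, which matters only if you intend to run on the GPU, and the import itself still succeeds, so in most cases they can be ignored: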

$ python3 -c 'import mindspore_gl'
[WARNING] ME(3662914:140594637309120,MainProcess):2022-11-09-17:22:29.348.03 [mindspore/run_check/_check_version.py:189] Cuda ['10.1', '11.1'] version(need by mindspore-gpu) is not found, please confirm that the path of cuda is set to the env LD_LIBRARY_PATH, please refer to the installation guidelines: https://www.mindspore.cn/install
[WARNING] ME(3662914:140594637309120,MainProcess):2022-11-09-17:22:29.350.73 [mindspore/run_check/_check_version.py:189] Cuda ['10.1', '11.1'] version(need by mindspore-gpu) is not found, please confirm that the path of cuda is set to the env LD_LIBRARY_PATH, please refer to the installation guidelines: https://www.mindspore.cn/install
[WARNING] ME(3662914:140594637309120,MainProcess):2022-11-09-17:22:29.351.54 [mindspore/run_check/_check_version.py:189] Cuda ['10.1', '11.1'] version(need by mindspore-gpu) is not found, please confirm that the path of cuda is set to the env LD_LIBRARY_PATH, please refer to the installation guidelines: https://www.mindspore.cn/install
[WARNING] ME(3662914:140594637309120,MainProcess):2022-11-09-17:22:29.352.40 [mindspore/run_check/_check_version.py:189] Cuda ['10.1', '11.1'] version(need by mindspore-gpu) is not found, please confirm that the path of cuda is set to the env LD_LIBRARY_PATH, please refer to the installation guidelines: https://www.mindspore.cn/install
[WARNING] ME(3662914:140594637309120,MainProcess):2022-11-09-17:22:29.352.94 [mindspore/run_check/_check_version.py:189] Cuda ['10.1', '11.1'] version(need by mindspore-gpu) is not found, please confirm that the path of cuda is set to the env LD_LIBRARY_PATH, please refer to the installation guidelines: https://www.mindspore.cn/install
[WARNING] ME(3662914:140594637309120,MainProcess):2022-11-09-17:22:29.353.43 [mindspore/run_check/_check_version.py:189] Cuda ['10.1', '11.1'] version(need by mindspore-gpu) is not found, please confirm that the path of cuda is set to the env LD_LIBRARY_PATH, please refer to the installation guidelines: https://www.mindspore.cn/install
[WARNING] ME(3662914:140594637309120,MainProcess):2022-11-09-17:22:29.353.91 [mindspore/run_check/_check_version.py:189] Cuda ['10.1', '11.1'] version(need by mindspore-gpu) is not found, please confirm that the path of cuda is set to the env LD_LIBRARY_PATH, please refer to the installation guidelines: https://www.mindspore.cn/install
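Because the import test above is buried under MindSpore's warnings, a slightly more explicit check can be handy. The sketch below uses only Python's standard importlib.metadata module to report installed versions. The distribution name mindspore-gl-gpu is taken from the pip output above and mindspore-gpu from the warning text; the plain mindspore name is an assumption and may not match your setup.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is not installed."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Names to probe; "mindspore" is an assumed fallback name.
    for name in ("mindspore-gl-gpu", "mindspore-gpu", "mindspore"):
        print(name, installed_version(name))
```

A None next to mindspore-gl-gpu would mean the wheel was installed into a different Python environment than the one running the check.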

A simple mindspore-gl example

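Let's start with a very basic case: the simplest fully connected graph, a triangle. Its vertices are numbered 0, 1, and 2, and their values are 1, 2, and 3. One point to note, however, is that graphs built by mindspore-gl are directed, so the three edges used below (0→1, 1→2, 2→0) form a directed cycle. If we need an undirected graph, we have to manually copy every edge in the reverse direction and concatenate the copies with the originals. A typical way to use mindspore-gl is to define the graph structure, a GraphField, from a sparse COO edge list (one array of source indices and one of destination indices), and then pass the graph to a GNNCell as an input argument.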
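The copy-and-concatenate step for an undirected graph can be sketched with plain Python lists. In practice the same concatenation would be done on the index Tensors before building the GraphField; this list version and the helper name to_undirected are illustrative assumptions, not part of mindspore-gl.

```python
from typing import List, Tuple


def to_undirected(src: List[int], dst: List[int]) -> Tuple[List[int], List[int]]:
    """Append every edge again in the reverse direction (hypothetical helper)."""
    # Original edges first, then their reversed copies: (s, d) -> (d, s).
    return src + dst, dst + src


# The directed triangle 0 -> 1 -> 2 -> 0.
src_idx = [0, 1, 2]
dst_idx = [1, 2, 0]

u_src, u_dst = to_undirected(src_idx, dst_idx)
n_edges = len(u_src)  # the edge count doubles from 3 to 6
print(u_src, u_dst, n_edges)
```

Note that n_edges passed to GraphField has to be updated to the doubled count, and that a self-loop would end up duplicated by this naive approach.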

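During computation, mindspore-gl first runs a compilation step. mindspore-gl lets users traverse all nodes of a graph, or the neighbors of a node, with a very simple for loop, and it optimizes and compiles that operation behind the scenes. To show the result of compilation and how concise the syntax is, mindspore-gl prints a side-by-side comparison during compilation: the vertex-centric code on the left, and on the right the equivalent code you would have to write without mindspore-gl. The comparison makes it clear how much mindspore-gl improves programming convenience.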
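For the triangle session that follows, the compiled body on the right-hand side reduces to hd * hs, while the per-edge gathering is done before the call by node_feat[src_idx] and node_feat[dst_idx]. The framework-free sketch below spells out that per-edge computation in plain Python; it is an illustration of the semantics, not mindspore-gl code, and the function name edge_product is hypothetical.

```python
from typing import List


def edge_product(x: List[float], src: List[int], dst: List[int]) -> List[float]:
    """For every edge e = (src[e] -> dst[e]), return x[dst[e]] * x[src[e]]."""
    hs = [x[s] for s in src]  # source value of each edge, like node_feat[src_idx]
    hd = [x[d] for d in dst]  # destination value of each edge, like node_feat[dst_idx]
    return [a * b for a, b in zip(hd, hs)]  # the compiled body: hd * hs


node_feat = [1.0, 2.0, 3.0]
result = edge_product(node_feat, src=[0, 1, 2], dst=[1, 2, 0])
print(result)  # [2.0, 6.0, 3.0]
```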

In [1]: import mindspore as ms

In [2]: from mindspore_gl import Graph, GraphField

In [3]: from mindspore_gl.nn import GNNCell

In [4]: n_nodes = 3

In [5]: n_edges = 3

In [6]: src_idx = ms.Tensor([0, 1, 2], ms.int32)

In [7]: dst_idx = ms.Tensor([1, 2, 0], ms.int32)

In [8]: graph_field = GraphField(src_idx, dst_idx, n_nodes, n_edges)

In [9]: node_feat = ms.Tensor([[1], [2], [3]], ms.float32)

In [10]: class TestSetVertexAttr(GNNCell):
    ...:     def construct(self, x, y, g: Graph):
    ...:         g.set_src_attr({"hs": x})
    ...:         g.set_dst_attr({"hd": y})
    ...:         return [v.hd for v in g.dst_vertex] * [u.hs for u in g.src_vertex]
    ...: 

In [11]: ret = TestSetVertexAttr()(node_feat[src_idx], node_feat[dst_idx], *graph_field.get_graph()).asnumpy().tolist()
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|    def construct(self, x, y, g: Graph):                                                      1   ||  1      def construct(                                                                           |
|                                                                                                  ||             self,                                                                                |
|                                                                                                  ||             x,                                                                                   |
|                                                                                                  ||             y,                                                                                   |
|                                                                                                  ||             src_idx,                                                                             |
|                                                                                                  ||             dst_idx,                                                                             |
|                                                                                                  ||             n_nodes,                                                                             |
|                                                                                                  ||             n_edges,                                                                             |
|                                                                                                  ||             UNUSED_0=None,                                                                       |
|                                                                                                  ||             UNUSED_1=None,                                                                       |
|                                                                                                  ||             UNUSED_2=None                                                                        |
|                                                                                                  ||         ):                                                                                       |
|                                                                                                  ||  2          SCATTER_ADD = ms.ops.TensorScatterAdd()                                              |
|                                                                                                  ||  3          SCATTER_MAX = ms.ops.TensorScatterMax()                                              |
|                                                                                                  ||  4          SCATTER_MIN = ms.ops.TensorScatterMin()                                              |
|                                                                                                  ||  5          GATHER = ms.ops.Gather()                                                             |
|                                                                                                  ||  6          ZEROS = ms.ops.Zeros()                                                               |
|                                                                                                  ||  7          FILL = ms.ops.Fill()                                                                 |
|                                                                                                  ||  8          MASKED_FILL = ms.ops.MaskedFill()                                                    |
|                                                                                                  ||  9          IS_INF = ms.ops.IsInf()                                                              |
|                                                                                                  ||  10         SHAPE = ms.ops.Shape()                                                               |
|                                                                                                  ||  11         RESHAPE = ms.ops.Reshape()                                                           |
|                                                                                                  ||  12         scatter_src_idx = RESHAPE(src_idx, (SHAPE(src_idx)[0], 1))                           |
|                                                                                                  ||  13         scatter_dst_idx = RESHAPE(dst_idx, (SHAPE(dst_idx)[0], 1))                           |
|        g.set_src_attr({'hs': x})                                                             2   ||  14         hs, = [x]                                                                            |
|        g.set_dst_attr({'hd': y})                                                             3   ||  15         hd, = [y]                                                                            |
|        return [v.hd for v in g.dst_vertex] * [u.hs for u in g.src_vertex]                    4   ||  16         return hd * hs                                                                       |
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [12]: print (ret)
[[2.0], [6.0], [3.0]]

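The result gives, for each of the three edges, the product of the values of its two end nodes. Besides node ids and node values, mindspore-gl also supports retrieving quantities such as a node's neighbors and its degree; see the figure below (taken from Reference 2):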
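Neighbors and degrees are both derivable from the COO arrays alone. The plain-Python sketch below computes them for the edges 0→1, 1→2, 2→0, 3→1, 4→2 of the aggregation example later in this post, counting all five node ids. The helper names are hypothetical and do not correspond to mindspore-gl's own attribute names.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def degrees(src: List[int], dst: List[int], n_nodes: int) -> Tuple[List[int], List[int]]:
    """In-degree and out-degree of every node, counted from a COO edge list."""
    in_deg = [0] * n_nodes
    out_deg = [0] * n_nodes
    for s, d in zip(src, dst):
        out_deg[s] += 1  # the edge leaves s
        in_deg[d] += 1   # the edge enters d
    return in_deg, out_deg


def in_neighbors(src: List[int], dst: List[int]) -> Dict[int, List[int]]:
    """Map each destination node to the list of its in-neighbors."""
    nbrs: Dict[int, List[int]] = defaultdict(list)
    for s, d in zip(src, dst):
        nbrs[d].append(s)
    return dict(nbrs)


src_idx = [0, 1, 2, 3, 4]
dst_idx = [1, 2, 0, 1, 2]
in_deg, out_deg = degrees(src_idx, dst_idx, n_nodes=5)
print(in_deg, out_deg)                  # [1, 2, 2, 0, 0] [1, 1, 1, 1, 1]
print(in_neighbors(src_idx, dst_idx))   # {1: [0, 3], 2: [1, 4], 0: [2]}
```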

Beyond the basic API, it is also worth getting familiar with the error messages that may come up when using mindspore-gl:

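mindspore-gl also has a feature that is very useful for large datasets. Here I only outline the idea, since I have not yet run into a scenario that needs it: a large graph is split, according to each node's number of neighbors, into blocks of different sizes for storage and computation. On one hand, this avoids dynamic shapes, since the graph may keep changing. On the other hand, neighbor counts in real graphs are usually far from uniform: a few nodes are very densely connected, while most are sparse. If the shape had to be fixed, every node would have to be padded up to the largest neighbor count, which silently wastes a huge amount of storage. Block-wise storage can greatly reduce GPU memory usage and also speed up computation.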
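The post does not show mindspore-gl's API for this, so the sketch below only illustrates the general idea in plain Python. It assumes a hypothetical bucketing rule (capacity = next power of two at or above the degree) and a hypothetical skewed graph, then compares padding every node to the global maximum degree against padding each bucket only to its own capacity.

```python
from collections import defaultdict
from typing import Dict, List


def bucket_capacity(degree: int) -> int:
    """Smallest power of two >= degree (a hypothetical bucketing rule)."""
    cap = 1
    while cap < degree:
        cap *= 2
    return cap


def bucket_by_degree(neighbors: List[List[int]]) -> Dict[int, List[int]]:
    """Group node ids by the padded capacity their neighbor list needs."""
    buckets: Dict[int, List[int]] = defaultdict(list)
    for node, nbrs in enumerate(neighbors):
        buckets[bucket_capacity(len(nbrs))].append(node)
    return dict(buckets)


def pad_block(neighbors: List[List[int]], nodes: List[int], cap: int, fill: int = -1) -> List[List[int]]:
    """Build one fixed-shape block of len(nodes) rows and cap columns, padded with fill."""
    return [neighbors[n] + [fill] * (cap - len(neighbors[n])) for n in nodes]


# Hypothetical skewed graph: node 0 is a hub with 8 neighbors,
# nodes 1..8 each have a single neighbor (node 0).
neighbors = [list(range(1, 9))] + [[0] for _ in range(8)]

buckets = bucket_by_degree(neighbors)
blocks = {cap: pad_block(neighbors, nodes, cap) for cap, nodes in buckets.items()}

global_slots = len(neighbors) * max(len(n) for n in neighbors)  # everyone padded to max degree
bucketed_slots = sum(cap * len(nodes) for cap, nodes in buckets.items())
print(global_slots, bucketed_slots)  # 72 16
```

Each block has a static shape, which is what avoids dynamic shapes; the power-of-two rule bounds the padding inside a bucket to less than 2x while keeping the number of distinct block shapes small.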


Finally, let's look at a simple aggregation example, which computes, for each node, the sum of the values of its in-neighbors:

import mindspore as ms
from mindspore_gl import Graph, GraphField
from mindspore_gl.nn import GNNCell

# Edges 0->1, 1->2, 2->0, 3->1, 4->2. Nodes 3 and 4 only act as sources;
# the output is sized by n_nodes, so it has one row for each of nodes 0, 1, 2.
n_nodes = 3
n_edges = 5  # number of entries in src_idx / dst_idx

src_idx = ms.Tensor([0, 1, 2, 3, 4], ms.int32)
dst_idx = ms.Tensor([1, 2, 0, 1, 2], ms.int32)

graph_field = GraphField(src_idx, dst_idx, n_nodes, n_edges)
node_feat = ms.Tensor([[1], [2], [3], [4], [5]], ms.float32)  # values of nodes 0..4

class GraphConvCell(GNNCell):
    def construct(self, x, y, g: Graph):
        g.set_src_attr({"hs": x})
        g.set_dst_attr({"hd": y})
        # For every destination vertex v, sum hs over its in-neighbors u.
        return [g.sum([u.hs for u in v.innbs]) for v in g.dst_vertex]

ret = GraphConvCell()(node_feat[src_idx], node_feat[dst_idx], *graph_field.get_graph()).asnumpy().tolist()
print(ret)

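As shown, a single g.sum call is all it takes: it is easy to write, and the code is highly readable. Running the script, saved as test_msgl_01.py, prints the compilation comparison followed by the result: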

$ python3 test_msgl_01.py
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|    def construct(self, x, y, g: Graph):                                                      1   ||  1      def construct(                                                                           |
|                                                                                                  ||             self,                                                                                |
|                                                                                                  ||             x,                                                                                   |
|                                                                                                  ||             y,                                                                                   |
|                                                                                                  ||             src_idx,                                                                             |
|                                                                                                  ||             dst_idx,                                                                             |
|                                                                                                  ||             n_nodes,                                                                             |
|                                                                                                  ||             n_edges,                                                                             |
|                                                                                                  ||             UNUSED_0=None,                                                                       |
|                                                                                                  ||             UNUSED_1=None,                                                                       |
|                                                                                                  ||             UNUSED_2=None                                                                        |
|                                                                                                  ||         ):                                                                                       |
|                                                                                                  ||  2          SCATTER_ADD = ms.ops.TensorScatterAdd()                                              |
|                                                                                                  ||  3          SCATTER_MAX = ms.ops.TensorScatterMax()                                              |
|                                                                                                  ||  4          SCATTER_MIN = ms.ops.TensorScatterMin()                                              |
|                                                                                                  ||  5          GATHER = ms.ops.Gather()                                                             |
|                                                                                                  ||  6          ZEROS = ms.ops.Zeros()                                                               |
|                                                                                                  ||  7          FILL = ms.ops.Fill()                                                                 |
|                                                                                                  ||  8          MASKED_FILL = ms.ops.MaskedFill()                                                    |
|                                                                                                  ||  9          IS_INF = ms.ops.IsInf()                                                              |
|                                                                                                  ||  10         SHAPE = ms.ops.Shape()                                                               |
|                                                                                                  ||  11         RESHAPE = ms.ops.Reshape()                                                           |
|                                                                                                  ||  12         scatter_src_idx = RESHAPE(src_idx, (SHAPE(src_idx)[0], 1))                           |
|                                                                                                  ||  13         scatter_dst_idx = RESHAPE(dst_idx, (SHAPE(dst_idx)[0], 1))                           |
|        g.set_src_attr({'hs': x})                                                             2   ||  14         hs, = [x]                                                                            |
|        g.set_dst_attr({'hd': y})                                                             3   ||  15         hd, = [y]                                                                            |
|        return [g.sum([u.hs for u in v.innbs]) for v in g.dst_vertex]                         4   ||  16         SCATTER_INPUT_SNAPSHOT1 = GATHER(hs, src_idx, 0)                                     |
|                                                                                                  ||  17         return SCATTER_ADD(                                                                  |
|                                                                                                  ||                 ZEROS(                                                                           |
|                                                                                                  ||                     (n_nodes,) + SHAPE(SCATTER_INPUT_SNAPSHOT1)[1:],                             |
|                                                                                                  ||                     SCATTER_INPUT_SNAPSHOT1.dtype                                                |
|                                                                                                  ||                 ),                                                                               |
|                                                                                                  ||                 scatter_dst_idx,                                                                 |
|                                                                                                  ||                 SCATTER_INPUT_SNAPSHOT1                                                          |
|                                                                                                  ||             )                                                                                    |
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
[[3.0], [5.0], [7.0]]
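The right-hand column shows what g.sum turns into: a GATHER of the source features along src_idx, followed by a SCATTER_ADD into a zero tensor with n_nodes rows, indexed by dst_idx. The plain-Python sketch below mirrors those steps to make the arithmetic visible. It uses scalar features instead of [N, 1] tensors as a simplification and is not mindspore-gl code. Nodes 3 and 4 only ever appear as sources, so with n_nodes = 3 they contribute to the sums but get no output row.

```python
from typing import List


def gather_scatter_add(hs: List[float], src: List[int], dst: List[int], n_nodes: int) -> List[float]:
    """Sum source-node features into their destination nodes (in-neighbor sum)."""
    msg = [hs[s] for s in src]   # like GATHER(hs, src_idx, 0): one message per edge
    out = [0.0] * n_nodes        # like ZEROS((n_nodes,) + feature_shape)
    for d, m in zip(dst, msg):   # like SCATTER_ADD(out, scatter_dst_idx, msg)
        out[d] += m
    return out


hs = [1.0, 2.0, 3.0, 4.0, 5.0]
src_idx = [0, 1, 2, 3, 4]
dst_idx = [1, 2, 0, 1, 2]
result = gather_scatter_add(hs, src_idx, dst_idx, n_nodes=3)
print(result)  # [3.0, 5.0, 7.0]
```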

The figure below shows the topology of this example:

Summary

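From element-wise operations to matrix operations, then to tensor operations, and finally to the abstraction of graph operations, each stage in this evolution of computing paradigms has needed matching tools: numpy in the matrix era, MindSpore in the tensor era, and mindspore-gl in the graph era. We cannot necessarily say that one paradigm is more advanced than another, but for coders, "formula as code" is a perennial theme, and mindspore-gl does this job very well. Not only is graph-style code more readable, but judging from the benchmarks in Reference 3, its GPU performance is also substantially optimized.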

Copyright Notice

Original post: https://www.cnblogs.com/dechinphy/p/mindspore_gl.html

Author ID: DechinPhy

More original articles: https://www.cnblogs.com/dechinphy/

Tipping link: https://www.cnblogs.com/dechinphy/gallery/image/379634.html

Tencent Cloud column mirror: https://cloud.tencent.com/developer/column/91958

CSDN mirror: https://blog.csdn.net/baidu_37157624?spm=1008.2028.3001.5343

51CTO mirror: https://blog.51cto.com/u_15561675

References

  1. https://gitee.com/mindspore/graphlearning
  2. https://www.bilibili.com/video/BV14a411976w/
  3. Yidi Wu et al. Seastar: Vertex-Centric Programming for Graph Neural Networks.