前向操作符過載自動微分實現

2022-05-26 21:00:45

前向操作符過載自動微分實現

在這篇文章裡,ZOMI會介紹是怎麼實現自動微分的,因為程式碼量非常小,也許你也可以寫一個玩玩。前面的文章當中,已經把自動微分的原理深入淺出的講了一下,也參照了非常多的論文。有興趣的可以順著綜述A survey這篇深扒一下。

前向自動微分原理

瞭解自動微分的不同實現方式非常有用。在這裡呢,我們將介紹主要的前向自動微分,通過Python這個高階語言來實現操作符過載。在正反向模式中的這篇的文章中,我們介紹了前向自動微分的基本數學原理。

前向模式(Forward Automatic Differentiation,也叫做 tangent mode AD)或者前向累積梯度(前向模式)

前向自動微分中,從計算圖的起點開始,沿著計算圖邊的方向依次向前計算,最終到達計算圖的終點。它根據自變數的值計算出計算圖中每個節點的值 以及其導數值,並保留中間結果。一直得到整個函數的值和其導數值。整個過程對應於一元複合函數求導時從最內層逐步向外層求導。

 

簡單確實簡單,可以總結前向自動微分關鍵步驟為:

  • 分解程式為一系列已知微分規則的基礎表示式的組合
  • 根據已知微分規則給出各基礎表示式的微分結果
  • 根據基礎表示式間的資料依賴關係,使用鏈式法則將微分結果組合完成程式的微分結果

而通過Python高階語言,進行操作符過載後的關鍵步驟其實也相類似:

  • 分解程式為一系列已知微分規則的基礎表示式組合,並使用高階語言的過載操作
  • 在過載運算操作的過程中,根據已知微分規則給出各基礎表示式的微分結果
  • 根據基礎表示式間的資料依賴關係,使用鏈式法則將微分結果組合完成程式的微分結果

具體實現

首先呢,我們需要載入通用的numpy庫,用於實際運算的,如果不用numpy,在python中也可以使用math來代替。

import numpy as np

前向自動微分又叫做tangent mode AD,所以我們準備一個叫做ADTangent的類,這類初始化的時候有兩個引數,一個是 x,表示輸入具體的數值;另外一個是 dx,表示經過對自變數 x 求導後的值。

需要注意的是,操作符過載自動微分不像原始碼轉換可以給出求導的公式,一般而言並不會給出求導公式,而是直接給出最後的求導值,所以就會有 dx 的出現。

class ADTangent:

    # 自變數 x,對自變數進行求導得到的 dx
    def __init__(self, x, dx):
        self.x = x
        self.dx = dx

    # 過載 str 是為了方便列印的時候,看到輸入的值和求導後的值
    def __str__(self):
        context = f'value:{self.x:.4f}, grad:{self.dx}'
        return context

下面是核心程式碼,也就是操作符過載的內容,在 ADTangent 類中通過 Python 私有函數過載加號,首先檢查輸入的變數 other 是否屬於 ADTangent,如果是那麼則把兩者的自變數 x 進行相加。

其中值得注意的就是 dx 的計算,因為是正向自動微分,因此每一個前向的計算都會有對應的反向求導計算。求導的過程是這個程式的核心,不過不用擔心的是這都是最基礎的求導法則。最後返回自身的物件 ADTangent(x, dx)。

    def __add__(self, other):
        if isinstance(other, ADTangent):
            x = self.x + other.x
            dx = self.dx + other.dx
        elif isinstance(other, float):
            x = self.x + other
            dx = self.dx
        else:
            return NotImplementedError
        return ADTangent(x, dx)

下面則是對減號、乘法、log、sin幾個操作進行操作符過載,正向的過載的過程比較簡單,基本都是按照上面的 add 的程式碼討論來實現。

    def __sub__(self, other):
        if isinstance(other, ADTangent):
            x = self.x - other.x
            dx = self.dx - other.dx
        elif isinstance(other, float):
            x = self.x - other
            ex = self.dx
        else:
            return NotImplementedError
        return ADTangent(x, dx)

    def __mul__(self, other):
        if isinstance(other, ADTangent):
            x = self.x * other.x
            dx = self.x * other.dx + self.dx * other.x
        elif isinstance(other, float):
            x = self.x * other
            dx = self.dx * other
        else:
            return NotImplementedError
        return ADTangent(x, dx)

    def log(self):
        x = np.log(self.x)
        dx = 1 / self.x * self.dx
        return ADTangent(x, dx)

    def sin(self):
        x = np.sin(self.x)
        dx = self.dx * np.cos(self.x)
        return ADTangent(x, dx)

以公式5為例:

因為是基於 ADTangent 類進行了操作符過載,因此在初始化自變數 x 和 y 的值需要使用 ADTangent 來初始化,然後通過程式碼 f = ADTangent.log(x) + x * y - ADTangent.sin(y) 來實現。

由於這裡是求 f 關於自變數 x 的導數,因此初始化資料的時候呢,自變數 x 的 dx 設定為1,而自變數 y 的 dx 設定為0。

x = ADTangent(x=2., dx=1)
y = ADTangent(x=5., dx=0)
f = ADTangent.log(x) + x * y - ADTangent.sin(y)
print(f)

value:11.6521, grad:5.5

從輸出結果來看,正向計算的輸出結果是跟上面圖相同,而反向的導數求導結果也與上圖相同。下面一個是 Pytroch 的實現結果對比,最後是MindSpore的實現結果對比。

可以看到呢,上面的簡單實現的自動微分結果和 Pytroch 、MindSpore是相同的。還是很有意思的。

Pytroch 對公式1的自動微分結果:

import torch
from torch.autograd import Variable

x = Variable(torch.Tensor([2.]), requires_grad=True)
y = Variable(torch.Tensor([5.]), requires_grad=True)
f = torch.log(x) + x * y - torch.sin(y)
f.backward()

print(f)
print(x.grad)
print(y.grad)

輸出結果:

tensor([11.6521], grad_fn=<SubBackward0>)
tensor([5.5000])
tensor([1.7163])

MindSpore 對公式1的自動微分結果:

import numpy as np
import mindspore.nn as nn
from mindspore import Parameter, Tensor

class Fun(nn.Cell):
    def __init__(self):
        super(Fun, self).__init__()

    def construct(self, x, y):
        f = ops.log(x) + x * y - ops.sin(y)
        return f

x = Tensor(np.array([2.], np.float32))
y = Tensor(np.array([5.], np.float32))
f = Fun()(x, y)

grad_all = ops.GradOperation()
grad = grad_all(Fun())(x, y)

print(f)
print(grad[0])

輸出結果:

[11.65207]
5.5