Hands-on torch.fx, part 2: FX quantization in practice

2022-09-22 06:04:21

Long time no see, everyone. Yes, I've put this one off for quite a while again.

This post picks up where the previous one, "Hands-on torch.fx, part 1: PyTorch's model optimization and quantization powerhouse", left off, and focuses on how to use FX to quantize a model.

Part of the reason this article took so long is that PyTorch's FX changes fairly often. While using it I kept patching my code to stay aligned with upstream, and with such frequent updates many small APIs shift from time to time. I held off because I worried the article would go stale, but after watching the project for a while I found that the main FX APIs and the overall workflow stay stable, so we should be fine.

This walkthrough is based on the FX code as of June 24, and we will run through a full quantization pass with it. Quantization support is one of FX's headline features, and it is considerably nicer to use than PyTorch's earlier Eager Mode Quantization. There are still plenty of gaps to fill, but it can already handle quantization for many common models.

The next article will cover fx2trt, which deploys FX-quantized models to TensorRT. That tool recently moved out of the main PyTorch repository and was merged into the pytorch/TensorRT repository; I will follow the new repository in later posts, but overall not much has changed.

A quick recap

Since it has been a while since the last post, let's briefly recap what FX provides:

  • A practical analysis of the features of program capture and transformation that are important for deep learning programs.
  • A Python-only program capture library that implements these features and can be customized to capture different levels of program detail
  • A simple 6 instruction IR for representing captured programs that focuses on ease of understanding and ease of doing static analysis
  • A code generation system for returning transformed code back to the host language’s ecosystem
  • Case studies in how torch.fx has been used in practice to develop features for performance optimization, program analysis, device lowering, and more

Those are FX's functional building blocks. In short, FX can trace your nn.Module, apply transformations to it, and then generate a new, transformed nn.Module. The previous post already covered some typical FX use cases:

  • automatically modifying a network
  • profiling a network
  • debugging a network
  • custom hooks, and so on

This article uses FX's transform, analysis, and codegen capabilities to produce a fully quantized model.

Frameworks that can do quantization

Besides FX, quite a few frameworks support quantization today. The training frameworks we use most, PyTorch and TensorFlow, both quantize natively, and many inference frameworks can quantize as well, such as ONNX Runtime and TVM. There are also several good domestic tools; personally I find PPQ the most usable: it supports multiple backends and ships plenty of tutorials, which makes it easy to get started.

There are plenty of other frameworks (or frameworks with their own quantization tooling) beyond these as well.

This article focuses on PyTorch's FX quantization tooling. As the natively supported quantization path in PyTorch it certainly has some advantages. Note, however, that FX is still under active development, with pull requests landing pretty much every day; I re-sync with upstream every so often and can barely keep up.

PyTorch quantization approaches

PyTorch currently supports two quantization approaches: Eager Mode and FX. Before FX existed everyone quantized with Eager Mode; since FX arrived, PyTorch officially recommends trying FX first:

New users of quantization are encouraged to try out FX Graph Mode Quantization first, if it does not work, user may try to follow the guideline of using FX Graph Mode Quantization or fall back to eager mode quantization.

To contrast the two:

Eager Mode's drawbacks are obvious:

  • You have to manually specify which nodes should and should not be quantized, and which nodes should and should not be fused (e.g. CONV+BN+RELU)
  • Special ops such as add and concat need extra handling
  • It cannot handle ops that are not wrapped in a module class, such as functional.conv2d or functional.linear

The core problem is the lack of automation: you have to write everything yourself:

import torch

# define a floating point model where some layers could be statically quantized
class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.relu = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        # manually mark where quantization starts
        x = self.quant(x)
        x = self.conv(x)
        x = self.relu(x)
        # manually mark where quantization ends
        x = self.dequant(x)
        return x

# create a model instance
model_fp32 = M()
# model must be set to eval mode for static quantization logic to work
model_fp32.eval()
model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')
# manually specify which layers to fuse
model_fp32_fused = torch.quantization.fuse_modules(model_fp32, [['conv', 'relu']])
model_fp32_prepared = torch.quantization.prepare(model_fp32_fused)
input_fp32 = torch.randn(4, 1, 4, 4)
model_fp32_prepared(input_fp32)
model_int8 = torch.quantization.convert(model_fp32_prepared)
res = model_int8(input_fp32)

For a small model this is manageable, but for a large model, unless it is simple enough to quantize end to end, you have to sprinkle many torch.quantization.QuantStub() markers throughout your torch.nn.Module to fine-tune the quantization strategy. This is quite similar to the QDQ scheme described in my earlier post on TensorRT-8 quantization details; in fact, the QDQ model that TensorRT consumed in that post was exported via FX. The difference is that FX generates and inserts the QDQ nodes automatically, whereas in Eager Mode you write them yourself, so FX saves a lot of manual work.

These are the officially summarized advantages of FX quantization; you can think of FX as a compiler:

  • Simple quantization flow, minimal manual steps
  • Unlocks the possibility of doing higher level optimizations like automatic precision selection

Whether in eager mode or FX, PyTorch supports three types of quantization:

  • dynamic quantization (weights quantized ahead of time; activations read/stored in floating point and quantized on the fly for compute); a minimal eager-mode example is sketched right after this list
  • static quantization (weights quantized, activations quantized, calibration required post training)
  • static quantization aware training (weights quantized, activations quantized, quantization numerics modeled during training)
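
As a quick illustration of the first type, dynamic quantization is a one-liner in eager mode; the static flows are what the rest of this post is about. This is just context, not part of the FX workflow:

import torch

model_fp32 = torch.nn.Sequential(torch.nn.Linear(128, 64), torch.nn.ReLU())
# Dynamic quantization: weights are quantized ahead of time, activations are
# quantized on the fly at inference time, and no calibration data is needed.
model_int8 = torch.quantization.quantize_dynamic(
    model_fp32, {torch.nn.Linear}, dtype=torch.qint8)
out = model_int8(torch.randn(1, 128))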

See the official documentation for the details; I won't repeat them here. In practice, static quantization and static quantization aware training correspond to what we usually call PTQ (post-training quantization) and QAT (quantization-aware training):

  • Post Training Quantization (apply quantization after training, quantization parameters are calculated based on sample calibration data)
  • Quantization Aware Training (simulate quantization during training so that the quantization parameters can be learned together with the model using training data)

FX supports both of these common quantization types.

TORCH-FX quantization

This post focuses on PTQ in FX, i.e. the post-training quantization method we use most often. Its advantage is that no training is required: the framework only needs to assemble the graph and collect quantization statistics through forward inference, with no backpropagation. QAT (quantization-aware training) is more involved and will be covered in a later article.

The basic code structure for PTQ with FX looks like this, and it is fairly simple overall:

import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx
float_model.eval()  # PTQ only needs inference (eval) mode
qconfig = get_default_qconfig("fbgemm")  # pick the quantization settings
qconfig_dict = {"": qconfig}             # quantization options for the whole model
def calibrate(model, data_loader):       # calibration helper
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)
prepared_model = prepare_fx(float_model, qconfig_dict)  # prepare the model: fuse e.g. CONV+BN+RELU, then insert observer nodes
calibrate(prepared_model, data_loader_test)  # run the calibration dataset through the model
quantized_model = convert_fx(prepared_model)  # convert the calibrated model into the quantized version

The code is straightforward: after setting up the config, call prepare_fx to put the model into a quantization-ready state (with observer nodes inserted), feed the calibration dataset through it, and finally convert the calibrated model, which now carries scale and zero-point values, into the actual quantized model.

The code above does not use the is_reference option, so the converted PyTorch model is a genuinely quantized model: all ops run in INT8 on the CPU. PyTorch supports the following INT8 backends:

  • x86 CPUs with AVX2 support or higher (without AVX2 some operations have inefficient implementations), via fbgemm
  • ARM CPUs (typically found in mobile/embedded devices), via qnnpack
  • (early prototype) support for NVidia GPU via TensorRT through fx2trt (to be open sourced)

If is_reference is set, the converted model only stores the quantization parameters but still executes FP32 ops (via quantize -> dequantize -> fp32-op -> quantize -> dequantize). This is usually called simulated quantization: the quantize -> dequantize pairs (fake quantization) model the quantization process and its error, the actual compute uses FP32 kernels, but both the input and the weight fed into each op have gone through a quantize/dequantize round trip.
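
To make the fake-quantize round trip concrete, here is a minimal sketch; the scale and zero_point values are made-up stand-ins for calibrated ones:

import torch

x = torch.randn(1, 3, 8, 8)
scale, zero_point = 0.05, 0                                        # stand-ins for calibrated values
xq = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8)  # quantize
x_sim = xq.dequantize()                                            # dequantize: back to fp32, now carrying quantization error
print((x - x_sim).abs().max())                                     # the simulated quantization error
# A reference module then runs the normal fp32 op on x_sim (and on a
# quantize/dequantize'd weight), which is exactly the "simulated" part.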

With is_reference set, the converted model carries fake-quantization (QDQ) nodes. Given a suitable converter, those QDQ nodes can be mapped to TensorRT's IQuantizeLayer and IDequantizeLayer, i.e. fx2trt can turn the model into a TensorRT engine; more on that later. For TensorRT quantization details you can also refer to my earlier post on TensorRT-8 quantization.

Since the next part converts the model to TensorRT, this step uses the same quantization strategy TensorRT does:

The overall quantization rules:

  • activations are quantized per-tensor, weights per-channel
  • symmetric int8 quantization, range [-128, 127]
  • the model being quantized is CenterNet with a ResNet-50 backbone, containing convolutions, deconvolutions, add, concat, and BN

Set up the FX quantization config accordingly:

# assumes torch.ao is available as `ao`, e.g. via `import torch.ao as ao`
qconfig = ao.quantization.qconfig.QConfig(
    activation=ao.quantization.observer.HistogramObserver.with_args(
        qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
    ),
    weight=ao.quantization.observer.default_per_channel_weight_observer
)

Then we add a separate setting for one specific op type in the model, torch.nn.ConvTranspose2d. This per-type qconfig is matched first and takes priority over the global qconfig; see the _propagate_qconfig_helper function for the details.

Why single out torch.nn.ConvTranspose2d? Because torch.fx quantizes torch.nn.ConvTranspose2d per-tensor by default, which hurts accuracy. Here I switch it to per-channel quantization and set the quantization axis to ch_axis=1.
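
The ch_axis=1 choice follows from the weight layout of transposed convolutions; a quick shape check (nothing FX-specific here):

import torch

conv = torch.nn.Conv2d(256, 128, kernel_size=3)
deconv = torch.nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
print(conv.weight.shape)    # torch.Size([128, 256, 3, 3]) -> output channels on dim 0
print(deconv.weight.shape)  # torch.Size([256, 128, 4, 4]) -> output channels on dim 1
# Default per-channel weight observers assume ch_axis=0 (fine for Conv2d); for
# ConvTranspose2d the per-channel scales should follow dim 1 instead.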

The full config looks like this:

prepared = prepare_fx(fx_model, {"": qconfig,
                                "object_type":[  # 這裡設定反折積的量化規則,注意看維度的per-channel量化ch_axis=1
                                (torch.nn.ConvTranspose2d,
                                    ao.quantization.qconfig.QConfig(
                                            activation=ao.quantization.observer.HistogramObserver.with_args(
                                                qscheme=torch.per_tensor_symmetric, dtype=torch.qint8, 
                                            ),
                                            weight=ao.quantization.observer.PerChannelMinMaxObserver.with_args(
                                                ch_axis=1, dtype=torch.qint8, qscheme=torch.per_channel_symmetric)) )
                                ]
                                },
                                example_inputs=(torch.randn(1, 3, 512, 512),),
                                backend_config_dict=get_tensorrt_backend_config_dict()
                                )

With the config in place, quantization can begin.

The overall quantization flow

The whole flow has these steps:

  • fuse the model, i.e. the usual graph optimizations such as conv+bn folding, using FX to transform the model
  • insert quantization observer ops (observers)
  • feed calibration data through the model to collect min/max statistics for weights and activations
  • fold the quantization statistics gathered during inference back into each layer

Let's first look at the original model: CenterNet-ResNet50 with the typical backbone + neck + head layout. The neck is an upsampling stage built mostly from deconvolutions, and the head is an ordinary head (the most common structure: convolutions plus activations, with a final conv producing the required output channels). Rather than drawing a diagram, the printed structure is clear enough:

CenterNet(
  (backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
     ...
  (upsampler): UpsampleLayer(
    (deconv_layers): Sequential(
      (0): Conv2d(2048, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ConvTranspose2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
      (4): ConvTranspose2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (5): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU(inplace=True)
      (7): ConvTranspose2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (8): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (9): ReLU(inplace=True)
    )
  )
  (head): Head(
    (hm): Sequential(
      (0): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
    )
    (reg): Sequential(
      (0): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
    )
    (wh): Sequential(
      (0): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
    )
  )
))

Fuse: graph optimization

This step is general graph optimization. It is not tied to quantization; you can apply it to a non-quantized model too, and doing so also benefits quantization.

From the model structure in the previous section we can spot some generic, applicable graph optimizations:

  • conv+bn+relu
  • convtranspose+bn
  • bn+relu

There are of course more aggressive optimizations, but since FX may not be the framework that ultimately runs the quantized model (the FX-quantized model may be migrated to another framework such as TensorRT), platform-specific optimizations cannot be applied here.

FX's current fusion patterns cover the basics: CONV+BN+RELU, CONV+BN, CONV+RELU, and so on. They also include the usual fusion tricks, such as folding BN into the preceding layer (a sketch of the folding math follows the mapping table below):

# pytorch/torch/ao/quantization/fuser_method_mappings.py
DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = {
    (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,
    (nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,
    (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn,
    (nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv1d, nn.ReLU): sequential_wrapper2(nni.ConvReLU1d),
    (nn.Conv2d, nn.ReLU): sequential_wrapper2(nni.ConvReLU2d),
    (nn.Conv3d, nn.ReLU): sequential_wrapper2(nni.ConvReLU3d),
    (nn.Linear, nn.BatchNorm1d): fuse_linear_bn,
    (nn.Linear, nn.ReLU): sequential_wrapper2(nni.LinearReLU),
    (nn.BatchNorm2d, nn.ReLU): sequential_wrapper2(nni.BNReLU2d),
    (nn.BatchNorm3d, nn.ReLU): sequential_wrapper2(nni.BNReLU3d),
    (nn.ConvTranspose1d, nn.BatchNorm1d): fuse_convtranspose_bn,
    (nn.ConvTranspose2d, nn.BatchNorm2d): fuse_convtranspose_bn,
    (nn.ConvTranspose3d, nn.BatchNorm3d): fuse_convtranspose_bn,
}
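
fuse_conv_bn essentially folds the BN statistics into the conv's weight and bias. A minimal sketch of that math in eval mode (my own illustration, not the exact PyTorch implementation):

import torch

def fold_bn_into_conv(conv: torch.nn.Conv2d, bn: torch.nn.BatchNorm2d) -> torch.nn.Conv2d:
    # BN(conv(x)) = gamma * (W x + b - mean) / sqrt(var + eps) + beta
    #             = (gamma / std) * W x + (gamma / std) * (b - mean) + beta
    std = (bn.running_var + bn.eps).sqrt()
    scale = bn.weight / std                                       # gamma / sqrt(var + eps)
    fused = torch.nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                            conv.stride, conv.padding, conv.dilation, conv.groups, bias=True)
    b = conv.bias if conv.bias is not None else torch.zeros(conv.out_channels)
    fused.weight.data = conv.weight * scale.reshape(-1, 1, 1, 1)  # scale each output channel
    fused.bias.data = (b - bn.running_mean) * scale + bn.bias
    return fused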

FX's fusion-matching code is shown below. It reverses the graph (a DAG) and matches patterns backwards starting from the output. The matching itself is simple and brute force: a for loop over the nodes is all it takes:

  for node in reversed(graph.nodes):
      if node.name not in match_map:
          for pattern, value in patterns.items():
              matched_node_pattern: List[Node] = []
              if is_match(modules, node, pattern):
                  apply_match(pattern, node, (node, pattern, value(node)), matched_node_pattern, node_to_subpattern)
                  break

After fusion, the new nodes are copied into a new graph, fused_graph, from which the fused GraphModule is built:

# pytorch/torch/ao/quantization/fx/fuse.py
# find the matching (fusion) pairs
fusion_pairs = _find_matches(
    input_root, input_graph, fusion_pattern_to_fuse_handler_cls)
fused_graph = Graph()  # create a brand-new graph here
env: Dict[Any, Any] = {}  # env tracks the nodes already copied into the new (fused) graph

for node in input_graph.nodes:
    maybe_last_node, pattern, matched_node_pattern, obj, node_to_subpattern = \
        fusion_pairs.get(node.name, (None, None, None, None, None))
    # get the corresponding subpattern for the current node
    if node_to_subpattern is not None:
        node_subpattern = node_to_subpattern.get(node, None)
    else:
        node_subpattern = None
    if maybe_last_node is node:
        assert obj is not None
        root_node_getter = fusion_pattern_to_root_node_getter.get(pattern, default_root_node_getter)
        root_node = root_node_getter(matched_node_pattern)  # type: ignore[index]
        extra_inputs_getter = fusion_pattern_to_extra_inputs_getter.get(pattern, None)
        extra_inputs = []
        if extra_inputs_getter is not None:
            extra_inputs = extra_inputs_getter(matched_node_pattern)
        # TODO: add validation that root_node is a module and has the same type
        # as the root_module in the configuration
        env[node.name] = obj.fuse(
            load_arg, named_modules, fused_graph, root_node, extra_inputs, matched_node_pattern,  # type: ignore[arg-type]
            fuse_custom_config, fuser_method_mapping, is_qat)
    elif maybe_last_node is None or node_subpattern is MatchAllNode:
        # nodes that are not the root of a fused pattern are copied into the new graph here
        env[node.name] = fused_graph.node_copy(node, load_arg)
    # node matched in patterns and is not root is removed here

Here is the model after graph optimization:

GraphModule(
  (backbone): Module(
    (conv1): ConvReLU2d(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
      (1): ReLU(inplace=True)
    )
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Module(
      (0): Module(
        (conv1): ConvReLU2d(
          (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
          (1): ReLU(inplace=True)
        )
        (conv2): ConvReLU2d(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
        )
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
        (downsample): Module(
          (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (relu): ReLU(inplace=True)
      )
      ...
      (2): Module(
        (conv1): ConvReLU2d(
          (0): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1))
          (1): ReLU(inplace=True)
        )
        (conv2): ConvReLU2d(
          (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
        )
        (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1))
        (relu): ReLU(inplace=True)
      )
    )
  )
  (upsampler): Module(
    (deconv_layers): Module(
      (0): Conv2d(2048, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ConvTranspose2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): ConvTranspose2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (6): ReLU(inplace=True)
      (7): ConvTranspose2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (9): ReLU(inplace=True)
    )
  )
  (head): Module(
    (hm): Module(
      (0): ConvReLU2d(
        (0): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
      )
      (2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
    )
    (wh): Module(
      (0): ConvReLU2d(
        (0): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
      )
      (2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
    )
    (reg): Module(
      (0): ConvReLU2d(
        (0): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
      )
      (2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
    )
  )
)

You can see that ConvReLU2d is the result of fusing conv+relu or conv+bn+relu, and that the BN following each ConvTranspose2d has been folded into the ConvTranspose2d itself.

Inserting quantization observers

Once the model is fused we can start quantizing. First we need to insert observer ops into the model, which corresponds to the insert_observers_for_model step in the code.

Before quantization actually runs, FX validates the qconfig we passed in, including the deconvolution qconfig above (per-tensor for activations, per-channel for weights). Since our model contains deconvolutions, I patched the upstream check and commented out the torch.ao.quantization.PerChannelMinMaxObserver line so it would pass (there is a PR with a cleaner fix: https://github.com/pytorch/pytorch/pull/79233):

def assert_valid_qconfig(qconfig: Optional[QConfig],
                         mod: torch.nn.Module) -> None:
    """
    Verifies that this `qconfig` is valid.
    """
    if qconfig is None:
        return
    is_conv_transpose_mod = (
        isinstance(mod, torch.nn.ConvTranspose1d) or
        isinstance(mod, torch.nn.ConvTranspose2d) or
        isinstance(mod, torch.nn.ConvTranspose3d))
    if is_conv_transpose_mod:
        if qconfig.weight is None:
            # for now, we assume that any qconfig for ConvTranspose without a weight is valid
            return
        example_observer = qconfig.weight()
        is_per_channel = (
            # isinstance(example_observer, torch.ao.quantization.PerChannelMinMaxObserver) or  # removed: allow per-channel here
            isinstance(example_observer, torch.ao.quantization.MovingAveragePerChannelMinMaxObserver)
        )
        assert not is_per_channel, \
            'Per channel weight observer is not supported yet for ConvTranspose{n}d.'  # in practice it works fine

With that out of the way, the key function is insert_observers_for_model, which inserts the observer nodes. Weights do not need to be observed with inference data, so only observers for activations need to be inserted.

At this point the ops are still the original FP32 implementations, but observer nodes have been inserted at the appropriate places, so we can run inference to collect the activation (and weight) quantization statistics for PTQ.

# forward of the model after observer insertion (excerpt)
def forward(self, input):
    input_1 = input
    activation_post_process_0 = self.activation_post_process_0(input_1);  input_1 = None
    backbone_conv1 = self.backbone.conv1(activation_post_process_0);  activation_post_process_0 = None
    activation_post_process_1 = self.activation_post_process_1(backbone_conv1);  backbone_conv1 = None
    ...
    head_angle_2 = getattr(self.head.angle, "2")(activation_post_process_83);  activation_post_process_83 = None
    activation_post_process_84 = self.activation_post_process_84(head_angle_2);  head_angle_2 = None
    return (activation_post_process_78, activation_post_process_80, activation_post_process_82, activation_post_process_84)

Part of the model structure now looks like this. Notice the extra HistogramObserver modules, all named activation_post_process_xx, which observe the distribution of activation values.

(upsampler): Module(
    (deconv_layers): Module(
    (0): Conv2d(2048, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ConvTranspose2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): ConvTranspose2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): ConvTranspose2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): ReLU(inplace=True)
    )
)
(activation_post_process_71): HistogramObserver()
(activation_post_process_72): HistogramObserver()
(activation_post_process_73): HistogramObserver()
(activation_post_process_74): HistogramObserver()
(activation_post_process_75): HistogramObserver()
(activation_post_process_76): HistogramObserver()
(activation_post_process_77): HistogramObserver()
(head): Module(
    (hm): Module(
    (0): ConvReLU2d(
        (0): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
    )
    (2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
    )
    (wh): Module(
    (0): ConvReLU2d(
        (0): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
    )
    (2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
    )
    (reg): Module(
    (0): ConvReLU2d(
        (0): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
    )
    (2): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
    )
)
(activation_post_process_78): HistogramObserver()
(activation_post_process_79): HistogramObserver()
(activation_post_process_80): HistogramObserver()
(activation_post_process_81): HistogramObserver()
(activation_post_process_82): HistogramObserver()
(activation_post_process_83): HistogramObserver()
)

During collection:

  • activations are observed with HistogramObserver
  • weights are observed with PerChannelMinMaxObserver (a standalone demo of both is sketched below)
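
A standalone sketch of what these two observer types do; they only record statistics in forward() and return the input unchanged (the random tensors here stand in for real calibration data):

import torch
from torch.ao.quantization.observer import HistogramObserver, PerChannelMinMaxObserver

act_obs = HistogramObserver(qscheme=torch.per_tensor_symmetric, dtype=torch.qint8)
w_obs = PerChannelMinMaxObserver(ch_axis=0, dtype=torch.qint8,
                                 qscheme=torch.per_channel_symmetric)

act_obs(torch.randn(1, 64, 32, 32))      # records min/max plus a histogram, returns x unchanged
w_obs(torch.randn(64, 64, 3, 3))         # records per-channel min/max along ch_axis

print(act_obs.min_val, act_obs.max_val)  # scalar tensors
print(w_obs.min_val.shape)               # torch.Size([64]) -> one value per output channel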

Next we feed data through the model for calibration, just like regular inference: prepare image data, batch it up, and pass it in. Each input goes through the model's forward:

def forward(self, input):
    input_1 = input
    activation_post_process_0 = self.activation_post_process_0(input_1);  input_1 = None
    backbone_conv1 = self.backbone.conv1(activation_post_process_0);  activation_post_process_0 = None
    activation_post_process_1 = self.activation_post_process_1(backbone_conv1);  backbone_conv1 = None
    backbone_maxpool = self.backbone.maxpool(activation_post_process_1);  activation_post_process_1 = None
    ... 

In the first line, activation_post_process_0 = self.activation_post_process_0(input_1); actually calls into a HistogramObserver instance. Its forward mainly collects min and max statistics and returns the original input unchanged:

# HistogramObserver::forward
def forward(self, x_orig: torch.Tensor) -> torch.Tensor:
    if x_orig.numel() == 0:
        return x_orig
    x = x_orig.detach()
    min_val = self.min_val
    max_val = self.max_val
    same_values = min_val.item() == max_val.item()
    is_uninitialized = min_val == float("inf") and max_val == float("-inf")
    if is_uninitialized or same_values:
        min_val, max_val = torch.aminmax(x)
        self.min_val.resize_(min_val.shape)
        self.min_val.copy_(min_val)
        self.max_val.resize_(max_val.shape)
        self.max_val.copy_(max_val)
        assert (
            min_val.numel() == 1 and max_val.numel() == 1
        ), "histogram min/max values must be scalar."
        torch.histc(
            x, self.bins, min=int(min_val), max=int(max_val), out=self.histogram
        )
    else:
        new_min, new_max = torch.aminmax(x)
        combined_min = torch.min(new_min, min_val)
        combined_max = torch.max(new_max, max_val)
        # combine the existing histogram and new histogram into 1 histogram
        # We do this by first upsampling the histogram to a dense grid
        # and then downsampling the histogram efficiently
        (
            combined_min,
            combined_max,
            downsample_rate,
            start_idx,
        ) = self._adjust_min_max(combined_min, combined_max, self.upsample_rate)
        assert (
            combined_min.numel() == 1 and combined_max.numel() == 1
        ), "histogram min/max values must be scalar."
        combined_histogram = torch.histc(
            x, self.bins, min=int(combined_min), max=int(combined_max)
        )
        if combined_min == min_val and combined_max == max_val:
            combined_histogram += self.histogram
        else:
            combined_histogram = self._combine_histograms(
                combined_histogram,
                self.histogram,
                self.upsample_rate,
                downsample_rate,
                start_idx,
                self.bins,
            )

        self.histogram.detach_().resize_(combined_histogram.shape)
        self.histogram.copy_(combined_histogram)
        self.min_val.detach_().resize_(combined_min.shape)
        self.min_val.copy_(combined_min)
        self.max_val.detach_().resize_(combined_max.shape)
        self.max_val.copy_(combined_max)
    return x_orig

During inference only activation statistics are collected; PTQ is just forward passes that gather activation information, with no weight updates. In QAT the weights do get updated, but that is a topic for later.

Converting to a quantized model: convert

Once the statistics are collected, we need to turn the recorded min/max values into usable scale and zero-point values.
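
Each observer exposes this conversion through calculate_qparams(); a minimal sketch, with random data standing in for calibration and a simplified formula in the comment (not PyTorch's exact rounding):

import torch
from torch.ao.quantization.observer import HistogramObserver

obs = HistogramObserver(qscheme=torch.per_tensor_symmetric, dtype=torch.qint8)
obs(torch.randn(1, 64, 32, 32))              # pretend this is calibration data
scale, zero_point = obs.calculate_qparams()  # histogram -> clipping range -> (scale, zero_point)
# For symmetric qint8 this is roughly scale ≈ max(|min|, |max|) / 127 with zero_point = 0;
# HistogramObserver additionally searches for a clipping range that minimizes quantization error.
print(scale.item(), zero_point.item())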

The conversion itself is simple: call FX's convert_fx with is_reference=True, which means the converted model only carries the quantization parameters while still executing in FP32. This kind of model is exactly what we need for the later TensorRT conversion.

quantized_fx = convert_fx(model,
                is_reference=True,  # produce a reference (simulated) quantized model
                )
"""
Details, from the source comments:
We will convert an observed model (a module with observer calls) to a reference
quantized model, the rule is simple:
1. for each observer module call in the graph, we'll convert it to calls to
    quantize and dequantize functions based on the observer instance
2. for weighted operations like linear/conv, we need to convert them to reference
    quantized module, this requires us to know whether the dtype configured for the
    weight is supported in the backend, this is done in prepare step and the result
    is stored in observed_node_names, we can decide whether we need to swap the
    module based on this set
"""

So how does the conversion work? The many activation_post_process_xx layers can be turned into quantize and dequantize nodes. The function below does exactly that, using graph-manipulation methods such as with graph.inserting_before(node) and node.replace_all_uses_with to rewrite the model:

    def replace_observer_with_quantize_dequantize_node(
            model: torch.nn.Module,
            graph: Graph,
            node: Node,
            modules: Dict[str, torch.nn.Module],
            node_name_to_scope: Dict[str, Tuple[str, type]],
            qconfig_map: Dict[str, QConfigAny]) -> None:
        """ Replace activation_post_process module call node with quantize and
        dequantize node

        Before:
        ... -> observer_0(x) -> ...
        After:
        ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...
        """
        assert modules is not None
        assert isinstance(node.target, str)
        module_path, prefix = get_module_path_and_prefix(node, node_name_to_scope, qconfig_map)
        observer_module = modules[node.target]
        maybe_quantize_node_info = get_quantize_node_info(observer_module)
        # Skip replacing observers to quant/dequant nodes if the qconfigs of all
        # consumers and producers of this observer are None
        skip_replacement = all([
            has_none_qconfig(n, qconfig_map) for n in
            list(node.args) + list(node.users.keys())])
        ...
        else:
            # otherwise, we can convert the observer moduel call to quantize/dequantize node
            node_type, quantize_op, qparams = maybe_quantize_node_info
            # replace observer node with quant - dequant node
            with graph.inserting_before(node):
                input_node = node.args[0]
                inputs = [input_node]
                for key, value in qparams.items():
                    # TODO: we can add the information of whether a value needs to
                    # be registered as an attribute in qparams dict itself
                    if key in ['_scale_', '_zero_point_']:
                        # For scale and zero_point values we register them as buffers in the root module.
                        # TODO: maybe need more complex attr name here
                        qparam_node = create_getattr_from_value(model, graph, module_path + prefix + key, value)
                        inputs.append(qparam_node)
                    else:
                        # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
                        inputs.append(value)
                # build the quantize node and the dequantize node
                quantized_node = graph.create_node(node_type, quantize_op, tuple(inputs), {})
                dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
                node.replace_all_uses_with(dequantized_node)
                graph.erase_node(node)

What about the weighted layers? As convert_fx walks the model it checks what kind of node each one is: if it is an activation_post_process it goes into the replace_observer_with_quantize_dequantize_node function above, and if it is a weighted module it goes into convert_weighted_module:

    for node in list(model.graph.nodes):
        if node.op == 'placeholder':
            ...
        elif node.op == "output":
            ...
        elif node.op == "call_module":
            if is_activation_post_process(modules[node.target]):
                observed_node = node.args[0]
                if observed_node in statically_quantized_custom_module_nodes:
                    replace_observer_with_dequantize_node(node, model.graph)
                else:
                    replace_observer_with_quantize_dequantize_node(
                        model, model.graph, node, modules, node_name_to_scope,
                        qconfig_map)
            elif is_observed_standalone_module(modules[node.target]):
                convert_standalone_module(
                    node, modules, model, is_reference, backend_config_dict)
            elif type(modules[node.target]) in set(
                    root_module_classes).union(qat_module_classes).union(fused_module_classes):
                # extra check for fused module classes to make sure they are fused module classes
                # of target modules
                if type(modules[node.target]) in fused_module_classes and \
                   type(modules[node.target][0]) not in root_module_classes:
                    continue
                convert_weighted_module(
                    node, modules, observed_node_names, qconfig_map, backend_config_dict)
            elif type(modules[node.target]) in custom_module_classes:
                convert_custom_module(
                    node, model.graph, modules, custom_module_class_mapping,
                    statically_quantized_custom_module_nodes)

convert_weighted_module mainly handles the weight quantization parameters, driven by the qconfig we set up. For example, weight_post_process in the code below is in fact a PerChannelMinMaxObserver; when run it collects the layer's weight min/max, after which get_qparam_dict computes the scale and zero-point and stores them in wq_or_wq_dict:

# pytorch/torch/ao/quantization/fx/convert.py
    ...
    else:
        # weight_post_process is None means the original module is not a QAT module
        # we need to get weight_post_process from qconfig in this case
        if weight_post_process is None:
            weight_post_process = qconfig.weight()  # type: ignore[union-attr, operator]
        # run weight observer
        # TODO: This is currently a hack for QAT to get the right shapes for scale and zero point.
        # In the future, we should require the user to calibrate the model after calling prepare
        # Issue: https://github.com/pytorch/pytorch/issues/73941
        weight_post_process(float_module.weight)  # type: ignore[operator]
        wq_or_wq_dict = get_qparam_dict(weight_post_process)

    # We use the same reference module for all modes of quantization: static, dynamic, weight_only
    # root_module_to_quantized_reference_module: module mapping from root (floating point) module class
    # to quantized reference module class, e.g. nn.Conv2d to nn.quantized._reference.Conv2d
    root_module_to_quantized_reference_module = get_root_module_to_quantized_reference_module(backend_config_dict)
    ref_qmodule_cls = root_module_to_quantized_reference_module.get(type(float_module), None)
    assert ref_qmodule_cls is not None, f"No reference quantized module class configured for {type(float_module)}"
    ref_qmodule = ref_qmodule_cls.from_float(float_module, wq_or_wq_dict)  # type: ignore[attr-defined]
    if fused_module is not None:
        fused_module[0] = ref_qmodule  # type: ignore[operator]
    else:
        parent_name, name = _parent_name(node.target)
        setattr(modules[parent_name], name, ref_qmodule)

With wq_or_wq_dict in hand, get_root_module_to_quantized_reference_module looks up the quantized reference counterpart of each FP32 op. The mapping looks like this:

<class 'torch.nn.modules.conv.Conv1d'>: <class 'torch.nn.quantized._reference.modules.conv.Conv1d'>
<class 'torch.nn.modules.conv.ConvTranspose1d'>: <class 'torch.nn.quantized._reference.modules.conv.ConvTranspose1d'>
<class 'torch.nn.modules.conv.Conv2d'>: <class 'torch.nn.quantized._reference.modules.conv.Conv2d'>
<class 'torch.nn.modules.conv.ConvTranspose2d'>: <class 'torch.nn.quantized._reference.modules.conv.ConvTranspose2d'>
...

For conv2d, ref_qmodule_cls is torch.nn.quantized._reference.modules.conv.Conv2d. from_float(float_module, wq_or_wq_dict) takes the FP32 conv2d and builds the quantized version from its parameters plus the collected wq_or_wq_dict, then swaps it into the model in place of the FP32 op. At this point conv2d -> quantized-reference-conv2d, and both the convolutions and the deconvolutions have become reference versions.
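
Pieced together outside of FX, the swap looks roughly like this (a sketch; the get_qparam_dict import path is my assumption based on the convert.py snippet above):

import torch
from torch.ao.quantization.observer import PerChannelMinMaxObserver
from torch.ao.quantization.utils import get_qparam_dict          # assumed import path
from torch.nn.quantized._reference.modules.conv import Conv2d as RefConv2d

float_conv = torch.nn.Conv2d(64, 128, kernel_size=3)
w_obs = PerChannelMinMaxObserver(ch_axis=0, dtype=torch.qint8,
                                 qscheme=torch.per_channel_symmetric)
w_obs(float_conv.weight)                    # observe the weights once
wq_dict = get_qparam_dict(w_obs)            # qscheme / scale / zero_point / axis ...
ref_conv = RefConv2d.from_float(float_conv, wq_dict)
print(ref_conv._get_name())                 # "QuantizedConv2d(Reference)"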

The final reference quantized model

After the steps above, the model produced by convert_fx is, in effect, simulated quantization: the calibrated scale and zero-point values are used to model the quantization error, and at run time the model executes like this:

def forward(self, input):
    input_1 = input
    # first fetch the quantization parameters, scale and zero-point
    backbone_conv1_input_scale_0 = self.backbone_conv1_input_scale_0
    backbone_conv1_input_zero_point_0 = self.backbone_conv1_input_zero_point_0
    # quantize the input
    quantize_per_tensor = torch.quantize_per_tensor(input_1, backbone_conv1_input_scale_0, backbone_conv1_input_zero_point_0, torch.qint8);  
    input_1 = backbone_conv1_input_scale_0 = backbone_conv1_input_zero_point_0 = None
    # then dequantize the input again
    dequantize = quantize_per_tensor.dequantize();  quantize_per_tensor = None
    backbone_conv1 = self.backbone.conv1(dequantize);  dequantize = None
    ...
    dequantize_80 = quantize_per_tensor_83.dequantize();  quantize_per_tensor_83 = None
    head_angle_2 = getattr(self.head.angle, "2")(dequantize_80);  dequantize_80 = None
    head_angle_2_output_scale_0 = self.head_angle_2_output_scale_0
    head_angle_2_output_zero_point_0 = self.head_angle_2_output_zero_point_0
    quantize_per_tensor_84 = torch.quantize_per_tensor(head_angle_2, head_angle_2_output_scale_0, head_angle_2_output_zero_point_0, torch.qint8);  head_angle_2 = head_angle_2_output_scale_0 = head_angle_2_output_zero_point_0 = None
    dequantize_81 = quantize_per_tensor_78.dequantize();  quantize_per_tensor_78 = None
    dequantize_82 = quantize_per_tensor_80.dequantize();  quantize_per_tensor_80 = None
    dequantize_83 = quantize_per_tensor_82.dequantize();  quantize_per_tensor_82 = None
    dequantize_84 = quantize_per_tensor_84.dequantize();  quantize_per_tensor_84 = None
    return {'hm': dequantize_81, 'wh': dequantize_82, 'reg': dequantize_83, 'angle': dequantize_84}

Looking at the reference model structure after convert, everything that should be fused has been fused, and every parameterized compute layer has been replaced with its Quantizedxxxx(Reference) version. Others such as max pooling, add, and concat are left alone; TensorRT will handle them internally when we convert later:

GraphModule(
  (backbone): Module(
    (conv1): ConvReLU2d(
      (0): QuantizedConv2d(Reference)(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
      (1): ReLU(inplace=True)
    )
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Module(
      (0): Module(
        (conv1): ConvReLU2d(
          (0): QuantizedConv2d(Reference)(64, 64, kernel_size=(1, 1), stride=(1, 1))
          (1): ReLU(inplace=True)
        )
        (conv2): ConvReLU2d(
          (0): QuantizedConv2d(Reference)(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
        )
        (conv3): QuantizedConv2d(Reference)(64, 256, kernel_size=(1, 1), stride=(1, 1))
        (downsample): Module(
          (0): QuantizedConv2d(Reference)(64, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (relu): ReLU(inplace=True)
      )
    )
      ...
  (upsampler): Module(
    (deconv_layers): Module(
      (0): QuantizedConv2d(Reference)(2048, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): QuantizedConvTranspose2d(Reference)(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (3): ReLU(inplace=True)
      ...
  )
  (head): Module(
    (hm): Module(
      (0): ConvReLU2d(
        (0): QuantizedConv2d(Reference)(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
      )
      (2): QuantizedConv2d(Reference)(64, 2, kernel_size=(1, 1), stride=(1, 1))
    )
  )
    ...
))

We can also look at the model's IR at this point:

opcode         name                                                        target                                                                  args                                                                                                                                                                      kwargs
-------------  ----------------------------------------------------------  ----------------------------------------------------------------------  ------------------------------------------------------------------------------------------------------------------------------------------------------------------------  --------
placeholder    input_1                                                     input                                                                   ()                                                                                                                                                                        {}
get_attr       backbone_base_base_layer_0_input_scale_0                    backbone_base_base_layer_0_input_scale_0                                ()                                                                                                                                                                        {}
get_attr       backbone_base_base_layer_0_input_zero_point_0               backbone_base_base_layer_0_input_zero_point_0                           ()                                                                                                                                                                        {}
call_function  quantize_per_tensor                                         <built-in method quantize_per_tensor of type object at 0x7f4f0d8491a0>  (input_1, backbone_base_base_layer_0_input_scale_0, backbone_base_base_layer_0_input_zero_point_0, torch.qint8)                                                           {}
call_method    dequantize                                                  dequantize                                                              (quantize_per_tensor,)                                                                                                                                                    {}
call_module    backbone_base_base_layer_0                                  backbone.base.base_layer.0                                              (dequantize,)                                                                                                                                                             {}
get_attr       backbone_base_base_layer_0_output_scale_0                   backbone_base_base_layer_0_output_scale_0                               ()                                                                                                                                                                        {}
get_attr       backbone_base_base_layer_0_output_zero_point_0              backbone_base_base_layer_0_output_zero_point_0                          ()                                                                                                                                                                        {}
call_function  quantize_per_tensor_1                                       <built-in method quantize_per_tensor of type object at 0x7f4f0d8491a0>  (backbone_base_base_layer_0, backbone_base_base_layer_0_output_scale_0, backbone_base_base_layer_0_output_zero_point_0, torch.qint8)                                      {}
call_method    dequantize_1                                                dequantize                                                              (quantize_per_tensor_1,)                                                                                                                                                  {}
call_module    backbone_base_level0_0                                      backbone.base.level0.0                                                  (dequantize_1,)                                                                                                                                                           {}
get_attr       backbone_base_level0_0_output_scale_0                       backbone_base_level0_0_output_scale_0                                   ()                                                                                                                                                                        {}
get_attr       backbone_base_level0_0_output_zero_point_0                  backbone_base_level0_0_output_zero_point_0                              ()                                                                                                                                                                        {}
call_function  quantize_per_tensor_2                                       <built-in method quantize_per_tensor of type object at 0x7f4f0d8491a0>  (backbone_base_level0_0, backbone_base_level0_0_output_scale_0, backbone_base_level0_0_output_zero_point_0, torch.qint8)                                                  {}
call_method    dequantize_2                                                dequantize                                                              (quantize_per_tensor_2,)                                                                                                                                                  {}
call_module    backbone_base_level1_0                                      backbone.base.level1.0                                                  (dequantize_2,)                                                                                                                                                           {}
get_attr       backbone_base_level1_0_output_scale_0                       backbone_base_level1_0_output_scale_0                                   ()                                                                                                                                                                        {}
get_attr       backbone_base_level1_0_output_zero_point_0                  backbone_base_level1_0_output_zero_point_0                              ()                                                                                                                                                                        {}
call_function  quantize_per_tensor_3                                       <built-in method quantize_per_tensor of type object at 0x7f4f0d8491a0>  (backbone_base_level1_0, backbone_base_level1_0_output_scale_0, backbone_base_level1_0_output_zero_point_0, torch.qint8)                                                  {}
call_method    dequantize_3                                                dequantize                                                              (quantize_per_tensor_3,)                                                                                                                                                  {}
call_module    backbone_base_level2_downsample                             backbone.base.level2.downsample                                         (dequantize_3,)                                                                                                                                                           {}
get_attr       backbone_base_level2_downsample_output_scale_0              backbone_base_level2_downsample_output_scale_0                          ()                                                                                                                                                                        {}
get_attr       backbone_base_level2_downsample_output_zero_point_0         backbone_base_level2_downsample_output_zero_point_0                     ()                                                                                                                                                                        {}
call_function  quantize_per_tensor_4                                       <built-in method quantize_per_tensor of type object at 0x7f4f0d8491a0>  (backbone_base_level2_downsample, backbone_base_level2_downsample_output_scale_0, backbone_base_level2_downsample_output_zero_point_0, torch.qint8)                       {}                                                                                                                                            

With that we have the quantized model. Its type is GraphModule, which behaves like an nn.Module and has a forward function, so we can run it directly in PyTorch to check accuracy (a minimal eval-loop sketch follows the type check below). Keep in mind this only measures the accuracy of the simulated quantized model, i.e. whether the calibrated scale and zero-point values are sane; after converting to TensorRT the accuracy may differ slightly, since we do not know every op-level implementation detail inside the actual inference framework.

type(quantized_fx)
<class 'torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl'>
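
A minimal sketch of such an accuracy check; val_loader and the decode/mAP step are placeholders for your own CenterNet evaluation pipeline:

import torch

quantized_fx.eval()
with torch.no_grad():
    for image, target in val_loader:      # placeholder dataloader
        outputs = quantized_fx(image)     # runs the quantize -> dequantize -> fp32-op graph
        # ... decode the CenterNet heads and accumulate mAP exactly as in the fp32 evaluation ...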

One more note: FX quantization is quite friendly to TensorRT conversion. FX only has to map the QDQ layers of the quantized model to TensorRT's QDQ layers; the TensorRT network then contains TensorRT-defined QDQ layers, which TensorRT optimizes automatically, and in the final engine the QDQ parameters have been absorbed into the surrounding layers as part of its own graph optimization.

Running the simulated quantized model

I ran a quick test of the CenterNet model's accuracy on COCO before and after quantization, measuring mAP directly. The accuracy drop is within 1%, which is generally considered acceptable for detection models.

To reiterate: my quantized model is in reference mode, i.e. simulated quantization (since it will be converted to TensorRT later). The quantized model still computes in FP32; the ops simply perform quantize and dequantize steps around the computation.

Why do it this way? Because it lets us simulate the quantization error without hardware that can actually execute INT8 instructions, which makes problems easier to pinpoint. If accuracy is already bad in simulation, it will certainly be bad on hardware. The converse does not hold: if accuracy is fine in simulation but poor on hardware, the bug is in the INT8 kernel implementation.

Take conv2d as an example. PyTorch's simulated-quantization ops live under pytorch/torch/nn/quantized/_reference/modules/:

class Conv2d(_ConvNd, nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros',
                 device=None,
                 dtype=None,
                 weight_qparams: Optional[Dict[str, Any]] = None):
        nn.Conv2d.__init__(
            self, in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, padding_mode, device, dtype)
        self._init_weight_qparams(weight_qparams, device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        we have:
        w(float) -- quant - dequant \
        x(float) ------------- F.conv2d ---

        In the full model, we will see
        w(float) -- quant - *dequant \
        x -- quant --- *dequant --  *F.conv2d --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a quantized conv2d
        """
        weight_quant_dequant = self.get_weight()  # quantize and then dequantize the weight
        result = F.conv2d(
            x, weight_quant_dequant, self.bias, self.stride,
            self.padding, self.dilation, self.groups)
        return result

    def _get_name(self):
        return "QuantizedConv2d(Reference)"

    @classmethod
    def from_float(cls, float_conv, weight_qparams):
        return _ConvNd.from_float(cls, float_conv, weight_qparams)

The input to forward has already been quantized and dequantized by the previous layer, and the weight is likewise quantize + dequantize'd, while the conv2d itself runs in FP32; that is the simulation. We can also add a forward_fp32 member method that simply uses the raw FP32 weight, giving us the non-quantized output for comparison.
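
A sketch of that forward_fp32 helper (our own addition, not part of the reference module); it mirrors forward() but skips the weight fake-quantization:

import torch
import torch.nn.functional as F
from torch.nn.quantized._reference.modules.conv import Conv2d as RefConv2d

def forward_fp32(self, x: torch.Tensor) -> torch.Tensor:
    # same as forward(), but using the raw fp32 weight: no weight-quantization error
    return F.conv2d(x, self.weight, self.bias, self.stride,
                    self.padding, self.dilation, self.groups)

RefConv2d.forward_fp32 = forward_fp32  # patch it onto the reference Conv2d for debugging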

Debugging accuracy

With the reference model we can write a small tool of our own to check, layer by layer, how well the simulated quantized model matches the FP32 one. Following the ShapeProp class from the official tutorial, we can write something similar:

import torch
import torch.fx
from torch.fx.node import Node

from typing import Dict

class ShapeProp:
    """
    Shape propagation. This class takes a `GraphModule`.
    Then, its `propagate` method executes the `GraphModule`
    node-by-node with the given arguments. As each operation
    executes, the ShapeProp class stores away the shape and
    element type for the output values of each operation on
    the `shape` and `dtype` attributes of the operation's
    `Node`.
    """
    def __init__(self, mod):
        self.mod = mod
        self.graph = mod.graph
        self.modules = dict(self.mod.named_modules())

    def propagate(self, *args):
        args_iter = iter(args)
        env : Dict[str, Node] = {}

        def load_arg(a):
            return torch.fx.graph.map_arg(a, lambda n: env[n.name])

        def fetch_attr(target : str):
            target_atoms = target.split('.')
            attr_itr = self.mod
            for i, atom in enumerate(target_atoms):
                if not hasattr(attr_itr, atom):
                    raise RuntimeError(f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}")
                attr_itr = getattr(attr_itr, atom)
            return attr_itr

        # the part below is the main thing we modify
        for node in self.graph.nodes:
            if node.op == 'placeholder':
                result = next(args_iter)
            elif node.op == 'get_attr':
                result = fetch_attr(node.target)
            elif node.op == 'call_function':
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == 'call_method':
                self_obj, *args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*args, **kwargs)
            elif node.op == 'call_module':
                result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))

            # This is the only code specific to shape propagation.
            # you can delete this `if` branch and this becomes
            # a generic GraphModule interpreter.
            if isinstance(result, torch.Tensor):
                node.shape = result.shape
                node.dtype = result.dtype

            env[node.name] = result

        return load_arg(self.graph.result)

Modify the inference part inside the for loop as follows, where forward_fp32 is the extra FP32 method from the previous section and torch_cosine_similarity is a small helper of our own (sketched after the snippet), both used for comparison:

# op_sim is the module for the current node
result_fp32_layer = op_sim.forward_fp32(*load_arg(node.args), **load_arg(node.kwargs))
result_int8_layer = op_sim(*load_arg(node.args), **load_arg(node.kwargs))
result_fp32_model = op_sim.forward_fp32(*load_arg_fp32(node.args), **load_arg_fp32(node.kwargs))
activation_dif_accmulated = torch_cosine_similarity(result_int8_layer, result_fp32_model)
activation_dif_layer = torch_cosine_similarity(result_int8_layer, result_fp32_layer)
weight_dif = torch_cosine_similarity(op_sim.weight, op_sim.get_weight())
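
torch_cosine_similarity is not a library function; a minimal version could look like this (flatten both tensors and return a single similarity score):

import torch
import torch.nn.functional as F

def torch_cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # 1.0 means the two results point in exactly the same direction (no error)
    return F.cosine_similarity(a.flatten().float(), b.flatten().float(), dim=0)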

We compare three things:

  • the FP32 vs INT8 error of the current activation layer
  • the accumulated FP32 vs INT8 error up to the current activation layer
  • the weight error of the current layer

Below is the comparison for CenterNet on COCO. As a rule of thumb, a cosine similarity above 0.99 means the layer is fine:

Quantize similarity : 
dequantize [activation_dif_layer:0.9945, activation_dif_accmulated:0.9945]
backbone_conv1 [activation_dif_layer:1.0000, activation_dif_accmulated:0.9975, weight_dif:1.0000]
dequantize_1 [activation_dif_layer:0.9999, activation_dif_accmulated:0.9978]
backbone_maxpool [activation_dif_layer:0.9999, activation_dif_accmulated:0.9978, weight_dif:1.0000]
dequantize_2 [activation_dif_layer:1.0000, activation_dif_accmulated:0.9989]
backbone_layer1_0_conv1 [activation_dif_layer:0.9999, activation_dif_accmulated:0.9983, weight_dif:1.0000]
dequantize_3 [activation_dif_layer:0.9999, activation_dif_accmulated:0.9987]
backbone_layer1_0_conv2 [activation_dif_layer:1.0000, activation_dif_accmulated:0.9991, weight_dif:0.9999]
dequantize_4 [activation_dif_layer:0.9999, activation_dif_accmulated:0.9987]
...

Model visualization

TORCH-FX ships a graphviz-based visualization tool, FxGraphDrawer. Calling the following API draws the current FX model:

g = FxGraphDrawer(quantized_fx, "centernet_fx_quantize")
g.get_main_dot_graph().write_svg("centernet_fx_quantize.svg")

Here is the model with the quantization observer nodes inserted:

And this is the fused model with QDQ nodes after convert:

Afterword

torch.fx is a promising tool and its quantization support is decent, but in practice there are still many limitations, plenty of incomplete features, and a few bugs you have to wade through yourself.

With the FX from this February I could quantize a model and deploy it to TensorRT successfully, but updating a few months later I found a lot had changed and it took some effort to re-sync. Personally I think FX currently suits early adopters and people who like to tinker and get their hands dirty.

The next post will cover converting the FX-quantized model to TensorRT. There will be more pitfalls, of course, but see you in the next article~

References