Extracting intermediate-layer outputs from a pretrained model (PyTorch)

2023-03-12 12:00:31

1 Extracting outputs by iterating over child modules

For simple models, you can simply iterate over the child modules and collect the output of the module with the matching name, without modifying the model at all. The drawback of this approach is that only the outputs of top-level child modules are reachable; for a model that wraps many layers inside nn.Sequential(), the output of a specific inner layer cannot be obtained.

Example: extracting the layer1 output of resnet18

import torch
from torchvision.models import resnet18

model = resnet18(pretrained=True)
model.eval()  # eval mode so BatchNorm uses its running statistics
print("model:", model)
out = []
x = torch.randn(1, 3, 224, 224)
return_layer = "layer1"
for name, module in model.named_children():
    x = module(x)  # pass the running tensor through each top-level child in order
    if name == return_layer:
        out.append(x.detach())  # detach() rather than the discouraged .data
        break
print(out[0].shape)  # torch.Size([1, 64, 56, 56])
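
The limitation above is easy to see on VGG: a minimal check (assuming torchvision's vgg16_bn) shows that named_children() exposes only three top-level modules, so the pooling layers buried inside the features Sequential cannot be reached this way.

from torchvision.models import vgg16_bn

model = vgg16_bn()
# Only the top-level children are visible; the conv/pool layers all live
# inside the "features" Sequential and cannot be addressed by name here.
print([name for name, _ in model.named_children()])  # ['features', 'avgpool', 'classifier']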

2 The IntermediateLayerGetter class

torchvision provides the IntermediateLayerGetter class for this. It has the same limitation: only the outputs of top-level child modules are reachable, and for a model that wraps many layers inside nn.Sequential(), the output of a specific inner layer cannot be obtained.

from torchvision.models._utils import IntermediateLayerGetter

PyTorch source code of the IntermediateLayerGetter class:

from collections import OrderedDict
from typing import Dict

from torch import Tensor, nn


class IntermediateLayerGetter(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model

    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.

    Args:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).
    """
    _version = 2
    __annotations__ = {
        "return_layers": Dict[str, str],
    }

    def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError("return_layers are not present in model")
        orig_return_layers = return_layers
        return_layers = {str(k): str(v) for k, v in return_layers.items()}

        # Rebuild the backbone, dropping every module after the last requested layer
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                break

        super(IntermediateLayerGetter, self).__init__(layers)
        self.return_layers = orig_return_layers

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        out = OrderedDict()
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                out_name = self.return_layers[name]
                out[out_name] = x
        return out
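
A side effect of the rebuild loop in __init__ is that every child after the last requested layer is dropped, so unused computation is trimmed away. A quick sketch (the layer choice here is just for illustration):

import torch
from torchvision.models import resnet18
from torchvision.models._utils import IntermediateLayerGetter

backbone = IntermediateLayerGetter(resnet18(), return_layers={"layer1": "feat1"})
# layer2..layer4, avgpool and fc were never added to the ModuleDict.
print(list(backbone.keys()))  # ['conv1', 'bn1', 'relu', 'maxpool', 'layer1']
out = backbone(torch.randn(1, 3, 224, 224))
print(out["feat1"].shape)  # torch.Size([1, 64, 56, 56])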

Example: using IntermediateLayerGetter to adapt resnet34 as a U-Net backbone (full code on gitee)

import torch
from torchvision.models import resnet34
from torchvision.models._utils import IntermediateLayerGetter

model = resnet34()
# Names of the top-level child modules whose outputs form the encoder stages.
stage_indices = ['relu', 'layer1', 'layer2', 'layer3', 'layer4']
return_layers = {str(j): f"stage{i}" for i, j in enumerate(stage_indices)}
model = IntermediateLayerGetter(model, return_layers=return_layers)
input = torch.randn(1, 3, 224, 224)
output = model(input)
print([(k, v.shape) for k, v in output.items()])
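
With a 224×224 input this prints five stages at halving resolutions: stage0 [1, 64, 112, 112], stage1 [1, 64, 56, 56], stage2 [1, 128, 28, 28], stage3 [1, 256, 14, 14] and stage4 [1, 512, 7, 7], which are exactly the skip-connection resolutions a U-Net decoder consumes.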

3 The create_feature_extractor function

The create_feature_extractor method builds a new module that returns the requested intermediate nodes of a given model as a dict, with user-specified strings as keys and the corresponding outputs as values. It is more general than IntermediateLayerGetter and is not limited to the outputs of top-level child modules. In the VGG below, for example, the pooling layers all sit inside the features child module, where the methods above cannot reach them, so create_feature_extractor is the recommended approach.

Example: following the FCN paper with VGG as the backbone, extract the outputs of three pooling layers

import torch
from torchvision.models import vgg16_bn
from torchvision.models.feature_extraction import create_feature_extractor

model = vgg16_bn()
# In vgg16_bn, features[23], features[33] and features[43] are the 3rd, 4th and 5th max-pooling layers.
model = create_feature_extractor(model, {"features.43": "pool5", "features.33": "pool4", "features.23": "pool3"})
input = torch.randn(1, 3, 224, 224)
output = model(input)
print([(k, v.shape) for k, v in output.items()])
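
To find out which node names are legal to request, torchvision ships a companion helper, get_graph_node_names, which traces the model and returns the node names for train and eval mode (the two lists can differ, e.g. around dropout):

from torchvision.models import vgg16_bn
from torchvision.models.feature_extraction import get_graph_node_names

model = vgg16_bn()
train_nodes, eval_nodes = get_graph_node_names(model)
# The names printed here ("features.23", "features.33", ...) are exactly
# what create_feature_extractor expects as keys in its return-nodes dict.
print(train_nodes[:6])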

4 Hook functions

  A hook function is a function predefined in a program that sits inside the existing control flow (the program exposes a hook point). To have something concrete happen at that point, we attach, or register, our own implementation to the hook so that the hook function can invoke it. Hooking is a general programming mechanism and is not tied to any particular language.
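
As a language-agnostic sketch of the idea (all names here are made up for illustration): the program owns the control flow and merely exposes a registration point; whatever we register gets called when the flow reaches the hook.

# Toy illustration of the hook mechanism; names are hypothetical.
_hooks = []

def register_hook(fn):
    _hooks.append(fn)  # attach the user's implementation to the hook point

def process(x):
    y = x * 2              # the fixed, pre-existing program flow
    for fn in _hooks:      # the exposed hook point
        fn(y)
    return y

register_hook(lambda v: print("intermediate value:", v))
process(3)  # prints: intermediate value: 6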

  PyTorch's hook programming can capture and even modify intermediate activations and gradients without changing the network structure. In PyTorch, Module objects offer register_forward_hook(hook) and register_backward_hook(hook); both operate on nn.Module instances, such as convolution layers (nn.Conv2d), fully connected layers (nn.Linear), pooling layers (nn.MaxPool2d, nn.AvgPool2d), activation layers (nn.ReLU), or small blocks defined with nn.Sequential. register_forward_hook captures the output of the forward pass, i.e. feature maps/activations; register_backward_hook captures the output of the backward pass, i.e. gradients. (Only register_forward_hook is covered here; see the references for the rest.)

Example: capturing the input and output of resnet18's avgpool layer

import torch
from torchvision.models import resnet18

model = resnet18()
fmap_block = dict()  # holds the captured feature maps

def forward_hook(module, input, output):
    # `input` is a tuple of the module's positional inputs; `output` is the tensor it returns
    fmap_block['input'] = input
    fmap_block['output'] = output

layer_name = 'avgpool'
for name, module in model.named_modules():
    if name == layer_name:
        module.register_forward_hook(hook=forward_hook)

input = torch.randn(64, 3, 224, 224)
output = model(input)
print(fmap_block['input'][0].shape)  # torch.Size([64, 512, 7, 7])
print(fmap_block['output'].shape)    # torch.Size([64, 512, 1, 1])
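
register_forward_hook returns a handle whose remove() method detaches the hook once it is no longer needed. The same pattern captures gradients with a backward hook; a minimal sketch using register_full_backward_hook (the non-deprecated replacement for register_backward_hook in recent PyTorch):

import torch
from torchvision.models import resnet18

model = resnet18()
grads = {}

def backward_hook(module, grad_input, grad_output):
    # grad_output is a tuple holding the gradient w.r.t. the module's output
    grads['avgpool'] = grad_output[0]

handle = model.avgpool.register_full_backward_hook(backward_hook)
model(torch.randn(2, 3, 224, 224)).sum().backward()  # trigger the backward pass
print(grads['avgpool'].shape)  # torch.Size([2, 512, 1, 1])
handle.remove()  # detach the hook when done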


References

1. Extracting the output of a specific intermediate layer from a pretrained PyTorch model

2. PyTorch's hook technique: getting the output of a specific intermediate layer from a pretrained/trained model