【煉丹Trick】EMA的原理與實現

2022-07-10 12:00:32

在進行深度學習訓練時,同一模型往往可以訓練出不同的效果,這就是煉丹這件事的玄學所在。使用一些trick能夠讓你更容易追上目前SOTA的效果,一些流行的開原始碼中已經整合了不少trick,值得學習一番。本節介紹EMA這一方法。

1.原理:

EMA也就是指數移動平均(Exponential moving average)。其公式非常簡單,如下所示:

\(\theta_{\text{EMA}, t+1} = (1 - \lambda) \cdot \theta_{\text{EMA}, t} + \lambda \cdot \theta_{t}\)

\(\theta_{t}\)是t時刻的網路引數,\(\theta_{\text{EMA}, t}\)是t時刻滑動平均後的網路引數,那麼t+1時刻的滑動平均結果就是這兩者的加權融合。這裡 \(\lambda\)通常會取接近於1的數,比如0.9995,數位越大平均的效果就比較強。

值得注意的是,這裡可以看成有兩個模型,基礎模型其引數按照常規的前後向傳播來更新,另外一個模型則是基礎模型的滑動平均版本,它並不直接參與前後向傳播,僅僅是利用基礎模型的引數結果來更新自己。

EMA為什麼會有效呢?大概是因為在訓練的時候,會使用驗證集來衡量模型精度,但其實驗證集精度並不和測試集一致,在訓練後期階段,模型可能已經在測試集最佳精度附近波動,所以使用滑動平均的結果會比使用單一結果更加可靠。感興趣的話可以看看這幾篇論文,論文1,論文2,論文3

2.實現:

Pytorch其實已經為我們實現了這一功能,為了避免自己造輪子可能引入的錯誤,這裡直接學習一下官方的程式碼。這個類的名稱就叫做AveragedModel。程式碼如下所示。
我們需要做的是提供avg_fn這個函數,avg_fn用來指定以何種方式進行平均。

class AveragedModel(Module):
    """
    You can also use custom averaging functions with `avg_fn` parameter.
    If no averaging function is provided, the default is to compute
    equally-weighted average of the weights.
    """
    def __init__(self, model, device=None, avg_fn=None, use_buffers=False):
        super(AveragedModel, self).__init__()
        self.module = deepcopy(model)
        if device is not None:
            self.module = self.module.to(device)
        self.register_buffer('n_averaged',
                             torch.tensor(0, dtype=torch.long, device=device))
        if avg_fn is None:
            def avg_fn(averaged_model_parameter, model_parameter, num_averaged):
                return averaged_model_parameter + \
                    (model_parameter - averaged_model_parameter) / (num_averaged + 1)
        self.avg_fn = avg_fn
        self.use_buffers = use_buffers

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def update_parameters(self, model):
        self_param = (
            itertools.chain(self.module.parameters(), self.module.buffers())
            if self.use_buffers else self.parameters()
        )
        model_param = (
            itertools.chain(model.parameters(), model.buffers())
            if self.use_buffers else model.parameters()
        )
        for p_swa, p_model in zip(self_param, model_param):
            device = p_swa.device
            p_model_ = p_model.detach().to(device)
            if self.n_averaged == 0:
                p_swa.detach().copy_(p_model_)
            else:
                p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_,
                                                 self.n_averaged.to(device)))
        self.n_averaged += 1


@torch.no_grad()
def update_bn(loader, model, device=None):
    r"""Updates BatchNorm running_mean, running_var buffers in the model.

    It performs one pass over data in `loader` to estimate the activation
    statistics for BatchNorm layers in the model.
    Args:
        loader (torch.utils.data.DataLoader): dataset loader to compute the
            activation statistics on. Each data batch should be either a
            tensor, or a list/tuple whose first element is a tensor
            containing data.
        model (torch.nn.Module): model for which we seek to update BatchNorm
            statistics.
        device (torch.device, optional): If set, data will be transferred to
            :attr:`device` before being passed into :attr:`model`.

    Example:
        >>> loader, model = ...
        >>> torch.optim.swa_utils.update_bn(loader, model)

    .. note::
        The `update_bn` utility assumes that each data batch in :attr:`loader`
        is either a tensor or a list or tuple of tensors; in the latter case it
        is assumed that :meth:`model.forward()` should be called on the first
        element of the list or tuple corresponding to the data batch.
    """
    momenta = {}
    for module in model.modules():
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module.running_mean = torch.zeros_like(module.running_mean)
            module.running_var = torch.ones_like(module.running_var)
            momenta[module] = module.momentum

    if not momenta:
        return

    was_training = model.training
    model.train()
    for module in momenta.keys():
        module.momentum = None
        module.num_batches_tracked *= 0

    for input in loader:
        if isinstance(input, (list, tuple)):
            input = input[0]
        if device is not None:
            input = input.to(device)

        model(input)

    for bn_module in momenta.keys():
        bn_module.momentum = momenta[bn_module]
    model.train(was_training)

這裡同樣參考官方的範例程式碼,給出滑動平均的實現。ExponentialMovingAverage繼承了AveragedModel,並且複寫了init方法,其實更直接的方法是將ema_avg函數作為引數傳遞給AveragedModel,這裡可能是為了可讀性,避免出現一個孤零零的ema_avg函數。

class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
    """Maintains moving averages of model parameters using an exponential decay.
    ``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
    `torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
    is used to compute the EMA.
    """

    def __init__(self, model, decay, device="cpu"):
        def ema_avg(avg_model_param, model_param, num_averaged):
            return decay * avg_model_param + (1 - decay) * model_param

        super().__init__(model, device, ema_avg, use_buffers=True)

如何使用呢?方式是比較簡單的,首先是利用當前模型建立出一個滑動平均模型。

model_ema = utils.ExponentialMovingAverage(model, device=device, decay=ema_decay)

然後是進行基礎模型的前後向傳播,更新結束後再對滑動平均版的模型進行引數更新。

output = model(image)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
model_ema.update_parameters(model)