使用梯度上升欺騙神經網路,讓網路進行錯誤的分類

2020-09-22 11:00:09

在本教學中,我將將展示如何使用梯度上升來解決如何對輸入進行錯誤分類。

出如何使用梯度上升改變一個輸入分類

神經網路是一個黑盒。理解他們的決策需要創造力,但他們並不是那麼不透明。

在本教學中,我將向您展示如何使用反向傳播來更改輸入,使其按照想要的方式進行分類。

人類的黑盒

首先讓我們以人類為例。如果我向你展示以下輸入:

很有可能你不知道這是5還是6。事實上,我相信我可以讓你們相信這也可能是8。

現在,如果你問一個人,他們需要做什麼才能把一個東西變成5,你可能會在視覺上做這樣的事情:

如果我想讓你把這個變成8,你可以這樣做:

現在,用幾個if語句或檢視幾個係數不容易解釋這個問題的答案。 並且對於某些型別的輸入(影象,聲音,視訊等),可解釋性無疑會變得更加困難,但並非不可能。

神經網路怎麼處理

一個神經網路如何回答我上面提出的同樣的問題?要回答這個問題,我們可以用梯度上升來做。

這是神經網路認為我們需要修改輸入使其更接近其他分類的方式。

由此產生了兩個有趣的結果。首先,黑色區域是我們需要去除畫素密度的網路物體。第二,黃色區域是它認為我們需要增加畫素密度的地方。

我們可以在這個梯度方向上採取一步,新增梯度到原始影象。當然,我們可以一遍又一遍地重複這個過程,最終將輸入變為我們所希望的預測。

你可以看到圖片左下角的黑斑和人類的想法非常相似。

讓輸入看起來更像8怎麼樣?這是網路認為你必須改變輸入的方式。

值得注意的是,在左下角有一團黑色的物質在中間有一團明亮的物質。如果我們把這個和輸入相加,我們得到如下結果:

在這種情況下,我並不特別相信我們已經將這個5變成了8。但是,我們減少了5的概率,說服你這個是8的論點肯定會更容易使用 右側的圖片,而不是左側的圖片。

梯度

在迴歸分析中,我們通過係數來了解我們所學到的知識。在隨機森林中,我們可以觀察決策節點。

在神經網路中,它歸結為我們如何創造性地使用梯度。為了對這個數位進行分類,我們根據可能的預測生成了一個分佈。

這就是我們說的前向傳播

在前進過程中,我們計算輸出的概率分佈

程式碼類似這樣:

現在假設我們想要欺騙網路,讓它預測輸入x的值為「5」,實現這一點的方法是給它一個影象(x),計算對影象的預測,然後最大化預測標籤「5」的概率。

為此,我們可以使用梯度上升來計算第6個索引處(即label = 5) §相對於輸入x的預測的梯度。

為了在程式碼中做到這一點,我們將輸入x作為引數輸入到神經網路,選擇第6個預測(因為我們有標籤:0,1,2,3,4,5,…),第6個索引意味著標籤「5」。

視覺上這看起來像:

程式碼如下:

當我們呼叫.backward()時,所發生的過程可以通過前面的動畫視覺化。

現在我們計算了梯度,我們可以視覺化並繪製它們:

由於網路還沒有經過訓練,所以上面的梯度看起來像隨機噪聲……但是,一旦我們對網路進行訓練,梯度的資訊會更豐富:

通過回撥實現自動化

這是一個非常有用的工具,幫助闡明在你的網路訓練中發生了什麼。在這種情況下,我們想要自動化這個過程,這樣它就會在訓練中自動發生。

為此,我們將使用PyTorch Lightning來實現我們的神經網路:

import torch
import torch.nn.functional as F
import pytorch_lightning as pl

class LitClassifier(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        result = pl.TrainResult(loss)

        # enable the auto confused logit callback
        self.last_batch = batch
        self.last_logits = y_hat.detach()

        result.log('train_loss', loss, on_epoch=True)
        return result
        
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        result = pl.EvalResult(checkpoint_on=loss)
        result.log('val_loss', loss)
        return result

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.005)

可以將自動繪製出此處描述內容的複雜程式碼,抽象為Lightning中的Callback。 Callback回撥是一個小程式,您可能會在訓練的各個部分呼叫它。

在本例中,當處理訓練批次處理時,我們希望生成這些影象,以防某些輸入出現混亂。。

import torch
from pytorch_lightning import Callback
from torch import nn


class ConfusedLogitCallback(Callback):

    def __init__(
            self,
            top_k,
            projection_factor=3,
            min_logit_value=5.0,
            logging_batch_interval=20,
            max_logit_difference=0.1
    ):
        super().__init__()
        self.top_k = top_k
        self.projection_factor = projection_factor
        self.max_logit_difference = max_logit_difference
        self.logging_batch_interval = logging_batch_interval
        self.min_logit_value = min_logit_value

    def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        # show images only every 20 batches
        if (trainer.batch_idx + 1) % self.logging_batch_interval != 0:
            return

        # pick the last batch and logits
        x, y = batch
        try:
            logits = pl_module.last_logits
        except AttributeError as e:
            m = """please track the last_logits in the training_step like so:
                def training_step(...):
                    self.last_logits = your_logits
            """
            raise AttributeError(m)

        # only check when it has opinions (ie: the logit > 5)
        if logits.max() > self.min_logit_value:
            # pick the top two confused probs
            (values, idxs) = torch.topk(logits, k=2, dim=1)

            # care about only the ones that are at most eps close to each other
            eps = self.max_logit_difference
            mask = (values[:, 0] - values[:, 1]).abs() < eps

            if mask.sum() > 0:
                # pull out the ones we care about
                confusing_x = x[mask, ...]
                confusing_y = y[mask]

                mask_idxs = idxs[mask]

                pl_module.eval()
                self._plot(confusing_x, confusing_y, trainer, pl_module, mask_idxs)
                pl_module.train()

    def _plot(self, confusing_x, confusing_y, trainer, model, mask_idxs):
        from matplotlib import pyplot as plt

        confusing_x = confusing_x[:self.top_k]
        confusing_y = confusing_y[:self.top_k]

        x_param_a = nn.Parameter(confusing_x)
        x_param_b = nn.Parameter(confusing_x)

        batch_size, c, w, h = confusing_x.size()
        for logit_i, x_param in enumerate((x_param_a, x_param_b)):
            x_param = x_param.to(model.device)
            logits = model(x_param.view(batch_size, -1))
            logits[:, mask_idxs[:, logit_i]].sum().backward()

        # reshape grads
        grad_a = x_param_a.grad.view(batch_size, w, h)
        grad_b = x_param_b.grad.view(batch_size, w, h)

        for img_i in range(len(confusing_x)):
            x = confusing_x[img_i].squeeze(0).cpu()
            y = confusing_y[img_i].cpu()
            ga = grad_a[img_i].cpu()
            gb = grad_b[img_i].cpu()

            mask_idx = mask_idxs[img_i].cpu()

            fig, axarr = plt.subplots(nrows=2, ncols=3, figsize=(15, 10))
            self.__draw_sample(fig, axarr, 0, 0, x, f'True: {y}')
            self.__draw_sample(fig, axarr, 0, 1, ga, f'd{mask_idx[0]}-logit/dx')
            self.__draw_sample(fig, axarr, 0, 2, gb, f'd{mask_idx[1]}-logit/dx')
            self.__draw_sample(fig, axarr, 1, 1, ga * 2 + x, f'd{mask_idx[0]}-logit/dx')
            self.__draw_sample(fig, axarr, 1, 2, gb * 2 + x, f'd{mask_idx[1]}-logit/dx')

            trainer.logger.experiment.add_figure('confusing_imgs', fig, global_step=trainer.global_step)

    @staticmethod
    def __draw_sample(fig, axarr, row_idx, col_idx, img, title):
        im = axarr[row_idx, col_idx].imshow(img)
        fig.colorbar(im, ax=axarr[row_idx, col_idx])
        axarr[row_idx, col_idx].set_title(title, fontsize=20)

但是,通過安裝pytorch-lightning-bolts,我們讓它變得更容易了

!pip install pytorch-lightning-bolts
from pl_bolts.callbacks.vision import ConfusedLogitCallback

trainer = Trainer(callbacks=[ConfusedLogitCallback(1)])

把它們放在一起

最後,我們可以訓練我們的模型,並在判斷邏輯產生混亂時自動生成影象。

# data
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train, val = random_split(dataset, [55000, 5000])

# model
model = LitClassifier()

# attach callback
trainer = Trainer(callbacks=[ConfusedLogitCallback(1)])

# train!
trainer.fit(model, DataLoader(train, batch_size=64), DataLoader(val, batch_size=64))

tensorboard會自動生成如下圖片:

看看這個是不是變得不一樣了

作者:William Falcon

完整程式碼:https://colab.research.google.com/drive/16HVAJHdCkyj7W43Q3ZChnxZ7DOwx6K5i?usp=sharing

deephub翻譯組