遷移學習(ADDA)《Adversarial Discriminative Domain Adaptation》【已復現遷移】

2023-01-29 06:01:10

論文資訊

論文標題:Adversarial Discriminative Domain Adaptation
論文作者:Eric Tzeng, Judy Hoffman, Kate Saenko, Trevor Darrell
論文來源:CVPR 2017
論文地址:download 
論文程式碼:download
參照次數:3257

1 簡介

  本文主要探討的是:源域和目標域特徵提取器共用引數的必要性。

  源域和目標域特徵提取器共用引數的代表——DANN。

2 對抗域適應

  標準監督損失訓練源資料:

    $\underset{M_{s}, C}{\text{min}} \quad \mathcal{L}_{\mathrm{cls}}\left(\mathbf{X}_{s}, Y_{t}\right)=  \mathbb{E}_{\left(\mathbf{x}_{s}, y_{s}\right) \sim\left(\mathbf{X}_{s}, Y_{t}\right)}-\sum\limits _{k=1}^{K} \mathbb{1}_{\left[k=y_{s}\right]} \log C\left(M_{s}\left(\mathbf{x}_{s}\right)\right)\quad\quad(1)$

  域對抗:首先使得域鑑別器分類準確,即最小化交叉熵損失 $\mathcal{L}_{\operatorname{adv}_{D}}\left(\mathbf{X}_{s}, \mathbf{X}_{t}, M_{s}, M_{t}\right)$:

    $\begin{array}{l}\mathcal{L}_{\text {adv }_{D}}\left(\mathbf{X}_{s}, \mathbf{X}_{t}, M_{s}, M_{t}\right)= -\mathbb{E}_{\mathbf{x}_{s} \sim \mathbf{X}_{s}}\left[\log D\left(M_{s}\left(\mathbf{x}_{s}\right)\right)\right] -\mathbb{E}_{\mathbf{x}_{t} \sim \mathbf{X}_{t}}\left[\log \left(1-D\left(M_{t}\left(\mathbf{x}_{t}\right)\right)\right)\right]\end{array} \quad\quad(2)$

  其次,源對映和目標對映根據一個受約束的對抗性目標進行優化(使得域鑑別器損失最大)。

  域對抗技術的通用公式如下:

    $\begin{array}{l}\underset{D}{\text{min}}  & \mathcal{L}_{\mathrm{adv}_{D}}\left(\mathbf{X}_{s}, \mathbf{X}_{t}, M_{s}, M_{t}\right) \\\underset{M_{s}, M_{t}}{\text{min}}  & \mathcal{L}_{\mathrm{adv}_{M}}\left(\mathbf{X}_{s}, \mathbf{X}_{t}, D\right) \\\text { s.t. } & \psi\left(M_{s}, M_{t}\right)\end{array}\quad\quad(3)$

2.1 源域和目標域對映

  

  歸結為三個問題:

    • 選擇生成式模型還是判別式模型?
    • 針對源域與目標域的對映是否共用引數?
    • 損失函數如何定義?

2.2 Adversarial losses

  回顧DANN 的訓練方式:DANN 的梯度反轉層優化對映,使鑑別器損失最大化

    $\mathcal{L}_{\text {adv }_{M}}=-\mathcal{L}_{\mathrm{adv}_{D}}\quad\quad(6)$

  這個目標可能有問題,因為在訓練的早期,鑑別器快速收斂,導致梯度消失。

  當訓練 GANs 時,而不是直接使用 minimax,通常是用帶有倒置標籤[10]的標準損失函數來訓

  回顧 GAN :GAN將優化分為兩個獨立的目標,一個用於生成器,另一個用於鑑別器。訓練生成器的時候,其中 $\mathcal{L}_{\mathrm{adv}_{D}}$ 保持不變,但 $\mathcal{L}_{\mathrm{adv}_{M}}$ 變成:

    $\mathcal{L}_{\mathrm{adv}_{M}}\left(\mathbf{X}_{s}, \mathbf{X}_{t}, D\right)=-\mathbb{E}_{\mathbf{x}_{t} \sim \mathbf{X}_{t}}\left[\log D\left(M_{t}\left(\mathbf{x}_{t}\right)\right)\right] \quad\quad(7)$

  Note:$\mathbf{x}_{t}$ 代表噪聲資料,這裡是使得噪聲資料儘可能迷惑鑑別器。

adversarial_loss = torch.nn.BCELoss()  # 損失函數(二分類交叉熵損失)
generator = Generator()           #生成器
discriminator = Discriminator()   #鑑別器

optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))  # 生成器優化器
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))   # 鑑別器優化器

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        # Adversarial ground truths
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)  #torch.Size([64, 1])
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)   #torch.Size([64, 1])
        real_imgs = Variable(imgs.type(Tensor))     #torch.Size([64, 1, 28, 28])   真實資料

        # ----------------------> 訓練生成器  [生成器使用噪聲資料,使得其儘可能為真,迷惑鑑別器]
        optimizer_G.zero_grad()
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))    #torch.Size([64, 100])
        gen_imgs = generator(z)        #torch.Size([64, 1, 28, 28])
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        # ----------------------> 訓練鑑別器  [ 儘可能將真實資料和噪聲資料區分開]
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()
GAN code

  本文采用的方法類似於  GAN 。

3 對抗性域適應

  與之前方法不同: 

  

  本文方法:

  

  首先:Pretrain ,使用源域訓練一個分類器;[ 公式 9 第一個子公式]

  其次:Adversarial Adaption 

    • 使用源域和目標域資料,訓練一個域鑑別器 Discriminator ,是的鑑別器儘可能區分源域和目標域資料 ;[ 公式 9 第二個子公式]  
    • 使用目標域資料,訓練目標域特徵提取器,儘可能使得域鑑別器區分不出目標域樣本;[ 公式 9 第三個子公式]  

  最後:Testing,在目標域上做 Eval;

  ADDA對應於以下無約束優化:

    $\begin{array}{l}\underset{M_{s}, C}{\text{min}} \quad \mathcal{L}_{\mathrm{cls}}\left(\mathbf{X}_{s}, Y_{s}\right) &=&-\mathbb{E}_{\left(\mathbf{x}_{s}, y_{s}\right) \sim\left(\mathbf{X}_{s}, Y_{s}\right)} \sum_{k=1}^{K} \mathbb{1}_{\left[k=y_{s}\right]} \log C\left(M_{s}\left(\mathbf{x}_{s}\right)\right) \\\underset{D}{\text{min}}  \quad\mathcal{L}_{\text {adv }_{D}}\left(\mathbf{X}_{s}, \mathbf{X}_{t}, M_{s}, M_{t}\right)&=& -\mathbb{E}_{\mathbf{x}_{s} \sim \mathbf{X}_{s}}\left[\log D\left(M_{s}\left(\mathbf{x}_{s}\right)\right)\right] \text { - } \mathbb{E}_{\mathbf{x}_{t} \sim \mathbf{X}_{t}}\left[\log \left(1-D\left(M_{t}\left(\mathbf{x}_{t}\right)\right)\right)\right] \\\underset{M_{t}}{\text{min}}  \quad \mathcal{L}_{\operatorname{adv}_{M}}\left(\mathbf{X}_{s}, \mathbf{X}_{t}, D\right)&=& -\mathbb{E}_{\mathbf{x}_{t} \sim \mathbf{X}_{t}}\left[\log D\left(M_{t}\left(\mathbf{x}_{t}\right)\right)\right] \\\end{array} \quad\quad(9)$

    tgt_encoder.train()
    discriminator.train()

    # setup criterion and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer_tgt = optim.Adam(tgt_encoder.parameters(),lr=params.c_learning_rate,betas=(params.beta1, params.beta2))
    optimizer_discriminator = optim.Adam(discriminator.parameters(),lr=params.d_learning_rate,betas=(params.beta1, params.beta2))
    len_data_loader = min(len(src_data_loader), len(tgt_data_loader))  #149

    for epoch in range(params.num_epochs):
        # zip source and target data pair
        data_zip = enumerate(zip(src_data_loader, tgt_data_loader))
        for step, ((images_src, _), (images_tgt, _)) in data_zip:
            # 2.1 訓練域鑑別器,使得域鑑別器儘可能的準確
            images_src = make_variable(images_src)
            images_tgt = make_variable(images_tgt)
            discriminator.zero_grad()
            feat_src,feat_tgt = src_encoder(images_src) ,tgt_encoder(images_tgt)   # 源域特徵提取  # 目標域特徵提取
            feat_concat = torch.cat((feat_src, feat_tgt), 0)
            pred_concat = discriminator(feat_concat.detach())    # 域分類結果

            label_src = make_variable(torch.ones(feat_src.size(0)).long())   #假設源域的標籤為 1
            label_tgt = make_variable(torch.zeros(feat_tgt.size(0)).long())  #假設目標域域的標籤為 0
            label_concat = torch.cat((label_src, label_tgt), 0)

            loss_critic = criterion(pred_concat, label_concat)
            loss_critic.backward()
            optimizer_discriminator.step()     # 域鑑別器優化

            pred_cls = torch.squeeze(pred_concat.max(1)[1])
            acc = (pred_cls == label_concat).float().mean()

            # 2.2 train target encoder # 使得目標域特徵生成器,儘可能使得域鑑別器區分不出源域和目標域樣本
            optimizer_discriminator.zero_grad()
            optimizer_tgt.zero_grad()
            feat_tgt = tgt_encoder(images_tgt)
            pred_tgt = discriminator(feat_tgt)
            label_tgt = make_variable(torch.ones(feat_tgt.size(0)).long())   #假設目標域域的標籤為 1(錯誤標籤),使得域鑑別器鑑別錯誤
            loss_tgt = criterion(pred_tgt, label_tgt)
            loss_tgt.backward()
            optimizer_tgt.step()  # 目標域 encoder 優化
ADDA Code