風格遷移StyleTransfer和Pytorch實現

2020-10-14 11:00:20

風格遷移及Pytorch實現

風格遷移,就是利用演演算法學習一幅畫的風格,然後再把這種風格應用到另外一張圖片上。

本篇文章會介紹其原理,並使用Pytorch實現。

在這裡插入圖片描述

在折積中,淺層特徵越具體,深層特徵則越抽象);從風格角度來說,淺層特徵則記錄著顏色紋理等資訊,而深層特徵則會記錄更高階的資訊。

主要方式則是,隨機一張圖片,通過優化內容損失和風格損失,改變該圖,使其內容接近內容圖片,風格上接近風格圖片。

內容損失:直接計算特徵圖的歐式距離

風格損失:計算特徵圖的格拉姆矩陣的歐式距離

格拉姆矩陣的計算方式:

def get_gram_matrix(f_map):
    n, c, h, w = f_map.shape
    if n == 1:
        f_map = f_map.reshape(c, h * w)
        gram_matrix = torch.mm(f_map, f_map.t())
        return gram_matrix
    else:
        raise ValueError('批次應該為1,但是傳入的不為1')

將特徵圖reshape,將寬高的維度合在一起,然後計算其與自身轉置的矩陣乘法即可。

遷移出預先訓練好的VGG19的模型。並輸出五個不同維度的特徵圖。

from torchvision.models import vgg19
from torch import nn
from torchvision.utils import save_image
import torch
import cv2


class VGG19(nn.Module):
    def __init__(self):
        super(VGG19, self).__init__()
        a = vgg19(True)
        a = a.features
        self.layer1 = a[:4]
        self.layer2 = a[4:9]
        self.layer3 = a[9:18]
        self.layer4 = a[18:27]
        self.layer5 = a[27:36]

    def forward(self, input_):
        out1 = self.layer1(input_)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)
        out5 = self.layer5(out4)
        return out1, out2, out3, out4, out5

將圖片直接定義為網路引數,來訓練它。這裡直接從原始內容圖訓練,也可以使用白噪聲。

class GNet(nn.Module):
    def __init__(self, image):
        super(GNet, self).__init__()
        self.image_g = nn.Parameter(image.detach().clone())
        # self.image_g = nn.Parameter(torch.rand(image.shape))  # 也可以初始化一張白噪聲訓練 

    def forward(self):
        return self.image_g.clamp(0, 1)  # 為了限定數值範圍。

定義載入圖片函數:

def load_image(path):
    image = cv2.imread(path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = torch.from_numpy(image).float() / 255
    image = image.permute(2, 0, 1).unsqueeze(0)
    return image

需要使用圖片需要保持形狀一致

首先載入內容圖片風格圖片,再範例化VGG19網路圖片,圖片直接從原內容圖開始訓練。

範例化優化器損失函數

image_content = load_image('c.jpg').cuda()
image_style = load_image('s.jpg').cuda()
net = VGG19().cuda()
g_net = GNet(image_content).cuda()
optimizer = torch.optim.Adam(g_net.parameters())
loss_func = nn.MSELoss().cuda()

計算風格圖片的輸入VGG19的輸出,並得到其格拉姆矩陣

s1, s2, s3, s4, s5 = net(image_style)
s1 = get_gram_matrix(s1).detach().clone()
s2 = get_gram_matrix(s2).detach().clone()
s3 = get_gram_matrix(s3).detach().clone()
s4 = get_gram_matrix(s4).detach().clone()
s5 = get_gram_matrix(s5).detach().clone()

計算內容圖片輸入VGG19的輸出

c1, c2, c3, c4, c5 = net(image_content)
c1 = c1.detach().clone()
c2 = c2.detach().clone()
c3 = c3.detach().clone()
c4 = c4.detach().clone()
c5 = c5.detach().clone()

訓練該圖片。

i = 0
while True:
    """生成圖片,計算損失"""
    image = g_net()
    out1, out2, out3, out4, out5 = net(image)

    """計算分格損失"""
    loss_s1 = loss_func(get_gram_matrix(out1), s1)
    loss_s2 = loss_func(get_gram_matrix(out2), s2)
    loss_s3 = loss_func(get_gram_matrix(out3), s3)
    loss_s4 = loss_func(get_gram_matrix(out4), s4)
    loss_s5 = loss_func(get_gram_matrix(out5), s5)
    loss_s = 0.1*loss_s1 + 0.1*loss_s2 + 0.6*loss_s3 + 0.1*loss_s4 + 0.1*loss_s5

    """計算內容損失"""
    loss_c1 = loss_func(out1, c1)
    loss_c2 = loss_func(out2, c2)
    loss_c3 = loss_func(out3, c3)
    loss_c4 = loss_func(out4, c4)
    loss_c5 = loss_func(out5, c5)
    loss_c = 0.05 * loss_c1 + 0.05 * loss_c2 + 0.15 * loss_c3 + 0.3 * loss_c4 + 0.45 * loss_c5

    """總損失"""
    loss = 0.5*loss_c + 0.5*loss_s

    """清空梯度、計算梯度、更新引數"""
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(i, loss.item(), loss_c.item(), loss_s.item())
    if i % 1000 == 0:
        save_image(image, f'{i}.jpg', padding=0, normalize=True, range=(0, 1))
    i += 1

分別計算風格損失和內容損失,然後求得總損失,優化該損失。

基本迭代一千次即可出效果。

內容圖片為:

在這裡插入圖片描述

幾個圖片的效果展示:

風格圖片生成圖片
在這裡插入圖片描述
/在這裡插入圖片描述
/在這裡插入圖片描述
/在這裡插入圖片描述
在這裡插入圖片描述在這裡插入圖片描述
/在這裡插入圖片描述
在這裡插入圖片描述 />在這裡插入圖片描述
在這裡插入圖片描述在這裡插入圖片描述
/在這裡插入圖片描述
在這裡插入圖片描述在這裡插入圖片描述
在這裡插入圖片描述在這裡插入圖片描述
在這裡插入圖片描述在這裡插入圖片描述
在這裡插入圖片描述在這裡插入圖片描述

調整各個損失不同的比例係數,能夠達到不同的效果。可酌情嘗試。