風格遷移,就是利用演演算法學習一幅畫的風格,然後再把這種風格應用到另外一張圖片上。
本篇文章會介紹其原理,並使用Pytorch實現。
在折積中,淺層特徵越具體,深層特徵則越抽象);從風格角度來說,淺層特徵則記錄著顏色紋理等資訊,而深層特徵則會記錄更高階的資訊。
主要方式則是,隨機一張圖片,通過優化內容損失和風格損失,改變該圖,使其內容接近內容圖片,風格上接近風格圖片。
內容損失:直接計算特徵圖的歐式距離;
風格損失:計算特徵圖的格拉姆矩陣的歐式距離
格拉姆矩陣的計算方式:
def get_gram_matrix(f_map):
n, c, h, w = f_map.shape
if n == 1:
f_map = f_map.reshape(c, h * w)
gram_matrix = torch.mm(f_map, f_map.t())
return gram_matrix
else:
raise ValueError('批次應該為1,但是傳入的不為1')
將特徵圖reshape,將寬高的維度合在一起,然後計算其與自身轉置的矩陣乘法即可。
遷移出預先訓練好的VGG19的模型。並輸出五個不同維度的特徵圖。
from torchvision.models import vgg19
from torch import nn
from torchvision.utils import save_image
import torch
import cv2
class VGG19(nn.Module):
def __init__(self):
super(VGG19, self).__init__()
a = vgg19(True)
a = a.features
self.layer1 = a[:4]
self.layer2 = a[4:9]
self.layer3 = a[9:18]
self.layer4 = a[18:27]
self.layer5 = a[27:36]
def forward(self, input_):
out1 = self.layer1(input_)
out2 = self.layer2(out1)
out3 = self.layer3(out2)
out4 = self.layer4(out3)
out5 = self.layer5(out4)
return out1, out2, out3, out4, out5
將圖片直接定義為網路引數,來訓練它。這裡直接從原始內容圖訓練,也可以使用白噪聲。
class GNet(nn.Module):
def __init__(self, image):
super(GNet, self).__init__()
self.image_g = nn.Parameter(image.detach().clone())
# self.image_g = nn.Parameter(torch.rand(image.shape)) # 也可以初始化一張白噪聲訓練
def forward(self):
return self.image_g.clamp(0, 1) # 為了限定數值範圍。
定義載入圖片函數:
def load_image(path):
image = cv2.imread(path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = torch.from_numpy(image).float() / 255
image = image.permute(2, 0, 1).unsqueeze(0)
return image
需要使用圖片需要保持形狀一致
首先載入內容圖片和風格圖片,再範例化VGG19網路和圖片,圖片直接從原內容圖開始訓練。
範例化優化器和損失函數。
image_content = load_image('c.jpg').cuda()
image_style = load_image('s.jpg').cuda()
net = VGG19().cuda()
g_net = GNet(image_content).cuda()
optimizer = torch.optim.Adam(g_net.parameters())
loss_func = nn.MSELoss().cuda()
計算風格圖片的輸入VGG19的輸出,並得到其格拉姆矩陣。
s1, s2, s3, s4, s5 = net(image_style)
s1 = get_gram_matrix(s1).detach().clone()
s2 = get_gram_matrix(s2).detach().clone()
s3 = get_gram_matrix(s3).detach().clone()
s4 = get_gram_matrix(s4).detach().clone()
s5 = get_gram_matrix(s5).detach().clone()
計算內容圖片輸入VGG19的輸出
c1, c2, c3, c4, c5 = net(image_content)
c1 = c1.detach().clone()
c2 = c2.detach().clone()
c3 = c3.detach().clone()
c4 = c4.detach().clone()
c5 = c5.detach().clone()
訓練該圖片。
i = 0
while True:
"""生成圖片,計算損失"""
image = g_net()
out1, out2, out3, out4, out5 = net(image)
"""計算分格損失"""
loss_s1 = loss_func(get_gram_matrix(out1), s1)
loss_s2 = loss_func(get_gram_matrix(out2), s2)
loss_s3 = loss_func(get_gram_matrix(out3), s3)
loss_s4 = loss_func(get_gram_matrix(out4), s4)
loss_s5 = loss_func(get_gram_matrix(out5), s5)
loss_s = 0.1*loss_s1 + 0.1*loss_s2 + 0.6*loss_s3 + 0.1*loss_s4 + 0.1*loss_s5
"""計算內容損失"""
loss_c1 = loss_func(out1, c1)
loss_c2 = loss_func(out2, c2)
loss_c3 = loss_func(out3, c3)
loss_c4 = loss_func(out4, c4)
loss_c5 = loss_func(out5, c5)
loss_c = 0.05 * loss_c1 + 0.05 * loss_c2 + 0.15 * loss_c3 + 0.3 * loss_c4 + 0.45 * loss_c5
"""總損失"""
loss = 0.5*loss_c + 0.5*loss_s
"""清空梯度、計算梯度、更新引數"""
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(i, loss.item(), loss_c.item(), loss_s.item())
if i % 1000 == 0:
save_image(image, f'{i}.jpg', padding=0, normalize=True, range=(0, 1))
i += 1
分別計算風格損失和內容損失,然後求得總損失,優化該損失。
基本迭代一千次即可出效果。
內容圖片為:
幾個圖片的效果展示:
風格圖片 | 生成圖片 |
---|---|
/ | |
/ | |
/ | |
/ | |
/> | |
/ | |
調整各個損失不同的比例係數,能夠達到不同的效果。可酌情嘗試。