資料集選用 CIFAR-10。CIFAR-10 是由 Alex Krizhevsky、Vinod Nair 與 Geoffrey Hinton 收集的通用物體識別電腦視覺資料集，共包含 60000 張 32×32 的 RGB 彩色圖片，分為 10 個類別；其中 50000 張作為訓練集，10000 張作為測試集。
模型類別需要繼承 nn.Module。
import torch
from torch import nn
class Lenet5(nn.Module):
    """LeNet-5 style CNN for 3x32x32 images (CIFAR-10) that outputs 10 class logits.

    Layout:
        conv_unit: Conv(3->6, k5) -> MaxPool(2) -> Conv(6->16, k5) -> AvgPool(2)
                   [b, 3, 32, 32] -> [b, 16, 5, 5]
        fc_unit:   Linear(400->120) -> ReLU -> Linear(120->84) -> ReLU -> Linear(84->10)

    The network returns raw logits. Train it with nn.CrossEntropyLoss, which
    applies the softmax internally, instead of adding a softmax layer here.
    """

    def __init__(self):
        super().__init__()
        # Spatial size: 32 -conv5-> 28 -pool2-> 14 -conv5-> 10 -pool2-> 5
        # NOTE(review): unlike most LeNet-5 variants, no activation follows the
        # convolutions. Adding one would change results and would also shift the
        # Sequential indices used as keys in saved state_dicts.
        conv_layers = [
            nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
        ]
        self.conv_unit = nn.Sequential(*conv_layers)

        flat_features = 16 * 5 * 5  # channels * height * width after conv_unit
        self.fc_unit = nn.Sequential(
            nn.Linear(flat_features, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10),
        )

    def forward(self, x):
        """Map a batch of images to class logits.

        :param x: image batch of shape [b, 3, 32, 32]
        :return: unnormalised logits of shape [b, 10]
        """
        features = self.conv_unit(x)                  # [b, 16, 5, 5]
        flat = features.view(x.size(0), 16 * 5 * 5)   # [b, 400]
        return self.fc_unit(flat)                     # [b, 10]
def main():
    """Smoke test: run two random 3x32x32 images through Lenet5 and print the logits shape."""
    model = Lenet5()
    sample = torch.rand(2, 3, 32, 32)
    print("lenet_out:", model(sample).shape)


if __name__ == '__main__':
    main()
import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from lenet5 import Lenet5
import torch.nn.functional as F
from torch import nn,optim
def _build_loaders(data_root, batch_size):
    """Return (train_loader, test_loader) for CIFAR-10, downloading into data_root if needed."""
    # CIFAR-10 images are already 32x32, so Resize is effectively a no-op here.
    # It is kept so the pipeline matches the original script.
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ])
    train_set = datasets.CIFAR10(data_root, train=True, transform=transform, download=True)
    test_set = datasets.CIFAR10(data_root, train=False, transform=transform, download=True)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    # Evaluation order does not affect accuracy, so the test set is not shuffled.
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader


def _train_one_epoch(model, loader, criteon, optimizer, device):
    """Run one optimisation pass over loader and return the last batch's loss.

    Returns None if the loader yields no batches.
    """
    model.train()
    loss = None
    for x, label in loader:
        x, label = x.to(device), label.to(device)
        logits = model(x)                 # [b, 10]
        loss = criteon(logits, label)
        optimizer.zero_grad()             # clear gradients left over from the previous step
        loss.backward()
        optimizer.step()
    return loss


def _evaluate(model, loader, device):
    """Return the top-1 accuracy of model on loader (0.0 if the loader is empty)."""
    model.eval()
    total_correct = 0
    total_num = 0
    with torch.no_grad():
        for x, label in loader:
            x, label = x.to(device), label.to(device)
            pred = model(x).argmax(dim=1)             # [b, 10] -> [b]
            total_correct += (pred == label).sum().item()
            total_num += x.size(0)
    return total_correct / total_num if total_num else 0.0


def main(batch_size=32, epochs=1000, learn_rate=1e-3, data_root='cifar'):
    """Train Lenet5 on CIFAR-10.

    After every epoch, print the last training batch's loss and the test-set accuracy.

    Args:
        batch_size: mini-batch size used by both the train and the test loader.
        epochs: number of passes over the training set.
        learn_rate: Adam learning rate.
        data_root: directory CIFAR-10 is downloaded to / loaded from.
    """
    train_loader, test_loader = _build_loaders(data_root, batch_size)

    # Sanity check of one batch: x is [batch_size, 3, 32, 32] and label is [batch_size].
    x, label = next(iter(train_loader))
    print("x:", x.shape, "label:", label.shape)

    # Fall back to CPU so the script also runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Lenet5().to(device)
    print(model)
    # CrossEntropyLoss applies the softmax itself, so the model outputs raw logits.
    criteon = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learn_rate)

    for epoch in range(epochs):
        loss = _train_one_epoch(model, train_loader, criteon, optimizer, device)
        if loss is not None:
            print(epoch, loss.item())
        acc = _evaluate(model, test_loader, device)
        print("epoch:", epoch, "acc:", acc)


if __name__ == '__main__':
    main()