深度學習(四)——torchvision中資料集的使用

2023-07-14 06:00:44

一、 科研資料集

下載連結:

https://pytorch.org/vision/stable/index.html

本文中我們使用的是\(CIFAR\)資料集

二、CIFAR10資料集詳解

具體網站:

CIFAR10 — Torchvision 0.15 documentation

1. 引數詳解

  • torchvision中每個資料集的引數都是大同小異的,這裡只介紹CIFAR10資料集

  • 該資料集的資料格式為PIL格式

class torchvision.datasets.CIFAR10(root:str,train:bool=True,transform:Optional[Callable]=None,target_transform:Optional[Callable]=None,download:bool=False)
  • 內建函數:

    • root(string):必須設定,輸入資料集下載後存放在電腦中的路徑

    • train(bool):True代表建立的一個訓練集(train);False代表建立一個測試集(test)。

    • transform:對資料集中的資料進行變換

    • target_transform:對標籤(target)資料進行變換

    • download(bool):True的時候會自動從網上下載這個資料集,False的時候則不會下載該資料集。

  • 程式碼範例:

    • 執行後直接下載資料集

    • 需要注意的是,如果下載速度過慢,則可以在執行後,把彈出的網址單拎出來,放到迅雷等軟體上進行下載

import torchvision

#設定訓練集
#root:設定為相對路徑,會在該.py檔案下設定一個名為dataset的檔案存放CIFAR10資料
#train: True,資料集為訓練集
#download: 下載該資料集
train_set=torchvision.datasets.CIFAR10(root="./dataset",train=True,download=True)

#設定測試集;train=False
test_set=torchvision.datasets.CIFAR10(root="./dataset",train=False,download=True)
  • 資料標籤檢視:

    • 在執行上面的程式碼下載好資料集後,輸入print(test_set[0),並使用一下pycharm的dubug功能,不難發現:

    • 也就是說,資料標籤有'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'十類,分別用整數0~9來表示

    • 資料集包含的所有標籤也可以用下面的程式碼列印出來:

print(test_set.classes)
#[Run] [airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
  • 某條資料的PIL Image、標籤的獲取方法:img,target=test_set[索引]
img,target=test_set[0]
print(img)
print(target,test_set.classes[target])

#[Run]
#<PIL.Image.Image image mode=RGB size=32x32 at 0x1DDF9FCD640>
#3  cat
  • 顯示圖片:
img.show()

三、使用transform處理多組影象資料

程式碼範例

  • 首先使用\(Compose\)去定義如何處理PIL影象資料

  • 然後代入\(torchvision.datasets.CIFAR10\)中,處理裡面的影象資料

#首先用Compose處理影象資料,可以先轉為tensor格式,然後再裁剪等,這裡只轉tensor格式
import torchvision
dataset_transform=torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

#定義transform=dataset_transform,使得影象資料型別轉換為Compose中處理過後的
train_set=torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=dataset_transform,download=True)
test_set=torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=dataset_transform,download=True)
  • 對處理過後的影象進行視覺化操作
from torch.utils.tensorboard import SummaryWriter
writer=SummaryWriter("p10")
for i in range(10): #顯示test_set資料集中的前十張圖片
    img,target=test_set[i]
    writer.add_image("test_set",img,i)
writer.close()