Pytorch基礎-張量基本操作

2023-01-05 15:01:01

授人以魚不如授人以漁,原汁原味的知識才更富有精華,本文只是對張量基本操作知識的理解和學習筆記,看完之後,想要更深入理解,建議去 pytorch 官方網站,查閱相關函數和操作,英文版在這裡,中文版在這裡。本文的程式碼是在 pytorch1.7 版本上測試的,其他版本一般也沒問題。

一,張量的基本操作

Pytorch 中,張量的操作分為結構操作和數學運算,其理解就如字面意思。結構操作就是改變張量本身的結構,數學運算就是對張量的元素值完成數學運算。

  • 常使用的張量結構操作:維度變換(tranposeview 等)、合併分割(splitchunk等)、索引切片(index_selectgather 等)。
  • 常使用的張量數學運算:標量運算、向量運算、矩陣運算。

二,維度變換

2.1,squeeze vs unsqueeze 維度增減

  • squeeze():對 tensor 進行維度的壓縮,去掉維數為 1 的維度。用法:torch.squeeze(a) 將 a 中所有為 1 的維度都刪除,或者 a.squeeze(1) 是去掉 a中指定的維數為 1 的維度。
  • unsqueeze():對資料維度進行擴充,給指定位置加上維數為 1 的維度。用法:torch.unsqueeze(a, N),或者 a.unsqueeze(N),在 a 中指定位置 N 加上一個維數為 1 的維度。

squeeze 用例程式如下:

a = torch.rand(1,1,3,3)
b = torch.squeeze(a)
c = a.squeeze(1)
print(b.shape)
print(c.shape)

程式輸出結果如下:

torch.Size([3, 3])
torch.Size([1, 3, 3])

unsqueeze 用例程式如下:

x = torch.rand(3,3)
y1 = torch.unsqueeze(x, 0)
y2 = x.unsqueeze(0)
print(y1.shape)
print(y2.shape)

程式輸出結果如下:

torch.Size([1, 3, 3])
torch.Size([1, 3, 3])

2.2,transpose vs permute 維度交換

torch.transpose() 只能交換兩個維度,而 .permute() 可以自由交換任意位置。函數定義如下:

transpose(dim0, dim1) → Tensor  # See torch.transpose()
permute(*dims) → Tensor  # dim(int). Returns a view of the original tensor with its dimensions permuted.

CNN 模型中,我們經常遇到交換維度的問題,舉例:四個維度表示的 tensor:[batch, channel, h, w]nchw),如果想把 channel 放到最後去,形成[batch, h, w, channel]nhwc),如果使用 torch.transpose() 方法,至少要交換兩次(先 1 3 交換再 1 2 交換),而使用 .permute() 方法只需一次操作,更加方便。例子程式如下:

import torch
input = torch.rand(1,3,28,32)                    # torch.Size([1, 3, 28, 32]
print(b.transpose(1, 3).shape)                   # torch.Size([1, 32, 28, 3])
print(b.transpose(1, 3).transpose(1, 2).shape)   # torch.Size([1, 28, 32, 3])
 
print(b.permute(0,2,3,1).shape)                  # torch.Size([1, 28, 28, 3]

三,索引切片

3.1,規則索引切片方式

張量的索引切片方式和 numpy、python 多維列表幾乎一致,都可以通過索引和切片對部分元素進行修改。切片時支援預設引數和省略號。範例程式碼如下:

>>> t = torch.randint(1,10,[3,3])
>>> t
tensor([[8, 2, 9],
        [2, 5, 9],
        [3, 9, 9]])
>>> t[0] # 第 1 行資料
tensor([8, 2, 9])
>>> t[2][2]
tensor(9)
>>> t[0:3,:]  # 第1至第3行,全部列
tensor([[8, 2, 9],
        [2, 5, 9],
        [3, 9, 9]])
>>> t[0:2,:]  # 第1行至第2行
tensor([[8, 2, 9],
        [2, 5, 9]])
>>> t[1:,-1]  # 第2行至最後行,最後一列
tensor([9, 9])
>>> t[1:,::2] # 第1行至最後行,第0列到最後一列每隔兩列取一列
tensor([[2, 9],
        [3, 9]])

以上切片方式相對規則,對於不規則的切片提取,可以使用 torch.index_select, torch.take, torch.gather, torch.masked_select

3.2,gather 和 torch.index_select 運算元

gather 運算元的用法比較難以理解,在翻閱了官方檔案和網上資料後,我有了一些自己的理解。

1,gather 是不規則的切片提取運算元(Gathers values along an axis specified by dim. 在指定維度上根據索引 index 來選取資料)。函數定義如下:

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

引數解釋

  • input (Tensor) – the source tensor.
  • dim (int) – the axis along which to index.
  • index (LongTensor) – the indices of elements to gather.

gather 運算元的注意事項:

  • 輸入 input 和索引 index 具有相同數量的維度,即 input.shape = index.shape
  • 對於任意維數,只要 d != dim,index.size(d) <= input.size(d),即對於可以不用索引維數 d 上的全部資料。
  • 輸出 out 和 索引 index 具有相同的形狀。輸入和索引不會相互廣播。

對於 3D tensor,output 值的定義如下:
gather 的官方定義如下:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]   # if dim == 2

通過理解前面的一些定義,相信讀者對 gather 運算元的用法有了一個基本瞭解,下面再結合 2D 和 3D tensor 的用例來直觀理解運算元用法。
(1),對於 2D tensor 的例子:

>>> import torch
>>> a = torch.arange(0, 16).view(4,4)
>>> a
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
>>> index = torch.tensor([[0, 1, 2, 3]])  # 選取對角線元素
>>> torch.gather(a, 0, index)
tensor([[ 0,  5, 10, 15]])

output 值定義如下:

# 按照 index = tensor([[0, 1, 2, 3]])順序作用在行上索引依次為0,1,2,3
a[0][0] = 0
a[1][1] = 5
a[2][2] = 10
a[3][3] = 15

(2),索引更復雜的 2D tensor 例子:

>>> t = torch.tensor([[1, 2], [3, 4]])
>>> t
tensor([[1, 2],
        [3, 4]])
>>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[ 1,  1],
        [ 4,  3]])

output 值的計算如下:

output[i][j] = input[i][index[i][j]]  # if dim = 1
output[0][0] = input[0][index[0][0]] = input[0][0] = 1
output[0][1] = input[0][index[0][1]] = input[0][0] = 1
output[1][0] = input[1][index[1][0]] = input[1][1] = 4
output[1][1] = input[1][index[1][1]] = input[1][0] = 3

總結:可以看到 gather 是通過將索引在指定維度 dim 上的值替換為 index 的值,但是其他維度索引不變的情況下獲取 tensor 資料。直觀上可以理解為對矩陣進行重排,比如對每一行(dim=1)的元素進行變換,比如 torch.gather(a, 1, torch.tensor([[1,2,0], [1,2,0]])) 的作用就是對 矩陣 a 每一行的元素,進行 permtute(1,2,0) 操作。
2,理解了 gather 再看 index_select 就很簡單,函數作用是返回沿著輸入張量的指定維度的指定索引號進行索引的張量子集。函數定義如下:

torch.index_select(input, dim, index, *, out=None) → Tensor

函數返回一個新的張量,它使用資料型別為 LongTensorindex 中的條目沿維度 dim 索引輸入張量。返回的張量具有與原始張量(輸入)相同的維數。 維度尺寸與索引長度相同; 其他尺寸與原始張量中的尺寸相同。範例程式碼如下:

>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
        [-0.4664,  0.2647, -0.1228, -1.1068],
        [-1.1734, -0.6571,  0.7230, -0.6004]])
>>> indices = torch.tensor([0, 2])
>>> torch.index_select(x, 0, indices)
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
        [-1.1734, -0.6571,  0.7230, -0.6004]])
>>> torch.index_select(x, 1, indices)
tensor([[ 0.1427, -0.5414],
        [-0.4664, -0.1228],
        [-1.1734,  0.7230]])

四,合併分割

4.1,torch.cat 和 torch.stack

可以用 torch.cat 方法和 torch.stack 方法將多個張量合併,也可以用 torch.split方法把一個張量分割成多個張量。torch.cattorch.stack 有略微的區別,torch.cat 是連線,不會增加維度,而 torch.stack 是堆疊,會增加一個維度。兩者函數定義如下:

# Concatenates the given sequence of seq tensors in the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be empty.
torch.cat(tensors, dim=0, *, out=None) → Tensor
# Concatenates a sequence of tensors along **a new** dimension. All tensors need to be of the same size.
torch.stack(tensors, dim=0, *, out=None) → Tensor

torch.cattorch.stack 用法範例程式碼如下:

>>> a = torch.arange(0,9).view(3,3)
>>> b = torch.arange(10,19).view(3,3)
>>> c = torch.arange(20,29).view(3,3)
>>> cat_abc = torch.cat([a,b,c], dim=0)
>>> print(cat_abc.shape)
torch.Size([9, 3])
>>> print(cat_abc)
tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [10, 11, 12],
        [13, 14, 15],
        [16, 17, 18],
        [20, 21, 22],
        [23, 24, 25],
        [26, 27, 28]])
>>> stack_abc = torch.stack([a,b,c], axis=0)  # torch中dim和axis引數名可以混用
>>> print(stack_abc.shape)
torch.Size([3, 3, 3])
>>> print(stack_abc)
tensor([[[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8]],

        [[10, 11, 12],
         [13, 14, 15],
         [16, 17, 18]],

        [[20, 21, 22],
         [23, 24, 25],
         [26, 27, 28]]])
>>> chunk_abc = torch.chunk(cat_abc, 3, dim=0)
>>> chunk_abc
(tensor([[0, 1, 2],
         [3, 4, 5],
         [6, 7, 8]]),
 tensor([[10, 11, 12],
         [13, 14, 15],
         [16, 17, 18]]),
 tensor([[20, 21, 22],
         [23, 24, 25],
         [26, 27, 28]]))

4.2,torch.split 和 torch.chunk

torch.split()torch.chunk() 可以看作是 torch.cat() 的逆運算。split() 作用是將張量拆分為多個塊,每個塊都是原始張量的檢視。split() 函數定義如下:

"""
Splits the tensor into chunks. Each chunk is a view of the original tensor.
If split_size_or_sections is an integer type, then tensor will be split into equally sized chunks (if possible). Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by split_size.
If split_size_or_sections is a list, then tensor will be split into len(split_size_or_sections) chunks with sizes in dim according to split_size_or_sections.
"""
torch.split(tensor, split_size_or_sections, dim=0)

chunk() 作用是將 tensordim(行或列)分割成 chunkstensor 塊,返回的是一個元組。chunk() 函數定義如下:

torch.chunk(input, chunks, dim=0) → List of Tensors
"""
Splits a tensor into a specific number of chunks. Each chunk is a view of the input tensor.
Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by chunks.
Parameters:
    input (Tensor) – the tensor to split
    chunks (int) – number of chunks to return
    dim (int) – dimension along which to split the tensor
"""

範例程式碼如下:

>>> a = torch.arange(10).reshape(5,2)
>>> a
tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7],
        [8, 9]])
>>> torch.split(a, 2)
(tensor([[0, 1],
         [2, 3]]),
 tensor([[4, 5],
         [6, 7]]),
 tensor([[8, 9]]))
>>> torch.split(a, [1,4])
(tensor([[0, 1]]),
 tensor([[2, 3],
         [4, 5],
         [6, 7],
         [8, 9]]))
>>> torch.chunk(a, 2, dim=1)
(tensor([[0],
        [2],
        [4],
        [6],
        [8]]), 
tensor([[1],
        [3],
        [5],
        [7],
        [9]]))

五,折積相關運算元

5.1,上取樣方法總結

上取樣大致被總結成了三個類別:

  1. 基於線性插值的上取樣:最近鄰演演算法(nearest)、雙線性插值演演算法(bilinear)、雙三次插值演演算法(bicubic)等,這是傳統影象處理方法。
  2. 基於深度學習的上取樣(轉置折積,也叫反折積 Conv2dTranspose2d等)
  3. Unpooling 的方法(簡單的補零或者擴充操作)

計算效果:最近鄰插值演演算法 < 雙線性插值 < 雙三次插值。計算速度:最近鄰插值演演算法 > 雙線性插值 > 雙三次插值。

5.2,F.interpolate 取樣函數

Pytorch 老版本有 nn.Upsample 函數,新版本建議用 torch.nn.functional.interpolate,一個函數可實現客製化化需求的上取樣或者下取樣功能,。

F.interpolate() 函數全稱是 torch.nn.functional.interpolate(),函數定義如下:

def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None):  # noqa: F811
    # type: (Tensor, Optional[int], Optional[List[float]], str, Optional[bool], Optional[bool]) -> Tensor
    pass

引數解釋如下:

  • input(Tensor):輸入張量資料;
  • size: 輸出的尺寸,資料型別為 tuple: ([optional D_out], [optional H_out], W_out),和 scale_factor 二選一。
  • scale_factor:在高度、寬度和深度上面的放大倍數。資料型別既可以是 int——表明高度、寬度、深度都擴大同一倍數;也可是 tuple`——指定高度、寬度、深度等維度的擴大倍數。
  • mode: 上取樣的方法,包括最近鄰(nearest),線性插值(linear),雙線性插值(bilinear),三次線性插值(trilinear),預設是最近鄰(nearest)。
  • align_corners: 如果設為 True,輸入影象和輸出影象角點的畫素將會被對齊(aligned),這隻在 mode = linear, bilinear, or trilinear 才有效,預設為 False

例子程式如下:

import torch.nn.functional as F
x = torch.rand(1,3,224,224)
y = F.interpolate(x * 2, scale_factor=(2, 2), mode='bilinear').squeeze(0)
print(y.shape)   # torch.Size([3, 224, 224)

5.3,nn.ConvTranspose2d 反折積

轉置折積(有時候也稱為反折積,個人覺得這種叫法不是很規範),它是一種特殊的折積,先 padding 來擴大影象尺寸,緊接著跟正向折積一樣,旋轉折積核 180 度,再進行折積計算。

參考資料