np.repeat()的座標軸問題

2020-10-12 13:00:29

numpy模組中的repeat函數,總是會出現設定axis座標軸的情況,這時的座標軸有時候就顯的十分混亂,每到此處就不知道該給axis什麼值。特寫一篇部落格來詳細說明這個問題。
程式碼如下:

import numpy as np


class Debug:
    def __init__(self):
        self.array1 = np.array([[1, 2], [3, 4]])

    def mainProgram(self):
        print("The value of array1 is: ")
        print(self.array1)
        print("The repeated array is: ")
        array2 = np.repeat(self.array1, repeats=1)
        print(array2)


if __name__ == '__main__':
    main = Debug()
    main.mainProgram()
"""
The value of array1 is: 
[[1 2]
 [3 4]]
The repeated array is: 
[1 2 3 4]
"""    

我們可以看到我們輸入的是一個二維陣列,當我們設定repeats值為1時,輸出結果變成了一個一維陣列,因此這時的np.repeats函數類似numpy.ndarray.flatten()函數的功能。
接下來我們研究一下關於axis座標軸地問題,二維情況,程式碼如下:

import numpy as np


class Debug:
    def __init__(self):
        self.array1 = np.array([[1, 2], [3, 4]])

    def mainProgram(self):
        print("The value of array1 is: ")
        print(self.array1)
        print("The array2 is: ")
        array2 = np.repeat(self.array1, repeats=2, axis=0)
        print(array2)
        print("The array3 is: ")
        array3 = np.repeat(self.array1, repeats=2, axis=1)
        print(array3)


if __name__ == '__main__':
    main = Debug()
    main.mainProgram()
"""
The value of array1 is: 
[[1 2]
 [3 4]]
The array2 is: 
[[1 2]
 [1 2]
 [3 4]
 [3 4]]
The array3 is: 
[[1 1 2 2]
 [3 3 4 4]]
"""

我們可以看到,axis=0時表示沿著y方向重複,axis=1時表示沿著x方向重複。我們可以對比numpy陣列的座標軸表示,二維時,座標軸為(y, x),從左向右第一個引數0便代表y軸,1代表x軸。
接下來我們研究一下三維情況,程式碼如下:

import numpy as np


class Debug:
    def __init__(self):
        self.array1 = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])

    def mainProgram(self):
        print("The value of array1 is: ")
        print(self.array1)
        print("The array2 is: ")
        array2 = np.repeat(self.array1, repeats=2, axis=0)
        print(array2)
        print("The array3 is: ")
        array3 = np.repeat(self.array1, repeats=2, axis=1)
        print(array3)
        print("The array4 is: ")
        array4 = np.repeat(self.array1, repeats=2, axis=2)
        print(array4)


if __name__ == '__main__':
    main = Debug()
    main.mainProgram()
"""
The value of array1 is: 
[[[1 2]
  [3 4]]

 [[5 6]
  [7 8]]]
The array2 is: 
[[[1 2]
  [3 4]]

 [[1 2]
  [3 4]]

 [[5 6]
  [7 8]]

 [[5 6]
  [7 8]]]
The array3 is: 
[[[1 2]
  [1 2]
  [3 4]
  [3 4]]

 [[5 6]
  [5 6]
  [7 8]
  [7 8]]]
The array4 is: 
[[[1 1 2 2]
  [3 3 4 4]]

 [[5 5 6 6]
  [7 7 8 8]]]
"""

我們可以看到,axis=0對應與沿著z軸重複,axis=1對應沿著y軸重複,axis=2對應沿著x軸重複。對比numpy座標軸的表示,我們知道三維座標軸為(z, y, x),所以從左向右,0對應z軸,1對應y軸,2對應x軸。

如果大家覺得有用,請高擡貴手給一個贊讓我上推薦讓更多的人看到吧~