NumPy高階索引


NumPy - 高階索引

如果一個ndarray是非元組序列,資料型別為整數或布林值的ndarray,或者至少一個元素為序列物件的元組,我們就能夠用它來索引ndarray。高階索引始終返回資料的副本。 與此相反,切片只提供了一個檢視。

有兩種型別的高階索引:整數和布林值。

整數索引

這種機制有助於基於 N 維索引來獲取陣列中任意元素。 每個整數陣列表示該維度的下標值。 當索引的元素個數就是目標ndarray的維度時,會變得相當直接。

以下範例獲取了ndarray物件中每一行指定列的一個元素。 因此,行索引包含所有行號,列索引指定要選擇的元素。

範例 1

import numpy as np 

x = np.array([[1,  2],  [3,  4],  [5,  6]]) 
y = x[[0,1,2],  [0,1,0]]  
print y

輸出如下:

[1  4  5]

該結果包括陣列中(0,0)(1,1)(2,0)位置處的元素。

下面的範例獲取了 4X3 陣列中的每個角處的元素。 行索引是[0,0][3,3],而列索引是[0,2][0,2]

範例 2

import numpy as np 
x = np.array([[  0,  1,  2],[  3,  4,  5],[  6,  7,  8],[  9,  10,  11]])  
print  '我們的陣列是:'  
print x 
print  '\n' 
rows = np.array([[0,0],[3,3]]) 
cols = np.array([[0,2],[0,2]]) 
y = x[rows,cols]  
print  '這個陣列的每個角處的元素是:'  
print y

輸出如下:

我們的陣列是:                                                                 
[[ 0  1  2]                                                                   
 [ 3  4  5]                                                                   
 [ 6  7  8]                                                                   
 [ 9 10 11]]

這個陣列的每個角處的元素是:                                      
[[ 0  2]                                                                      
 [ 9 11]]

返回的結果是包含每個角元素的ndarray物件。

高階和基本索引可以通過使用切片:或省略號...與索引陣列組合。 以下範例使用slice作為列索引和高階索引。 當切片用於兩者時,結果是相同的。 但高階索引會導致複製,並且可能有不同的記憶體布局。

範例 3

import numpy as np 
x = np.array([[  0,  1,  2],[  3,  4,  5],[  6,  7,  8],[  9,  10,  11]])  
print  '我們的陣列是:'  
print x 
print  '\n'  
# 切片
z = x[1:4,1:3]  
print  '切片之後,我們的陣列變為:'  
print z 
print  '\n'  
# 對列使用高階索引 
y = x[1:4,[1,2]] 
print  '對列使用高階索引來切片:'  
print y

輸出如下:

我們的陣列是:
[[ 0  1  2] 
 [ 3  4  5] 
 [ 6  7  8]
 [ 9 10 11]]

切片之後,我們的陣列變為:
[[ 4  5]
 [ 7  8]
 [10 11]]

對列使用高階索引來切片:
[[ 4  5]
 [ 7  8]
 [10 11]]

布林索引

當結果物件是布林運算(例如比較運算子)的結果時,將使用此型別的高階索引。

範例 1

這個例子中,大於 5 的元素會作為布林索引的結果返回。

import numpy as np 
x = np.array([[  0,  1,  2],[  3,  4,  5],[  6,  7,  8],[  9,  10,  11]])  
print  '我們的陣列是:'  
print x 
print  '\n'  
# 現在我們會列印出大於 5 的元素  
print  '大於 5 的元素是:'  
print x[x >  5]

輸出如下:

我們的陣列是:
[[ 0  1  2] 
 [ 3  4  5] 
 [ 6  7  8] 
 [ 9 10 11]] 

大於 5 的元素是:
[ 6  7  8  9 10 11]

範例 2

這個例子使用了~(取補運算子)來過濾NaN

import numpy as np 
a = np.array([np.nan,  1,2,np.nan,3,4,5])  
print a[~np.isnan(a)]

輸出如下:

[ 1.   2.   3.   4.   5.]

範例 3

以下範例顯示如何從陣列中過濾掉非複數元素。

import numpy as np 
a = np.array([1,  2+6j,  5,  3.5+5j])  
print a[np.iscomplex(a)]

輸出如下:

[2.0+6.j  3.5+5.j]