在處理大規模資料時,資料無法全部載入記憶體,我們通常用兩個選項
tfrecords
tf.data.Dataset.from_generator()
tfrecords的並行化使用前文已經有過介紹,這裡不再贅述。如果我們不想生成tfrecord中間檔案,那麼生成器就是你所需要的。
本文主要記錄針對 from_generator()
的並行化方法,在 tf.data
中,並行化主要通過 map
和 num_parallel_calls
實現,但是對一些場景,我們的generator()
中有一些處理邏輯,是無法直接並行化的,最簡單的方法就是將generator()
中的邏輯抽出來,使用map
實現。
對generator()
中的複雜邏輯,我們對其進行簡化,即僅在生成器中做一些下標取值的型別操作,將generator()
中處理部分使用py_function
包裹(wrapped) ,然後呼叫map處理。
def func(i):
i = i.numpy() # Decoding from the EagerTensor object
x, y = your_processing_function(training_set[i])
return x, y
z = list(range(len(training_set))) # The index generator
dataset = tf.data.Dataset.from_generator(lambda: z, tf.uint8)
dataset = dataset.map(lambda i: tf.py_function(func=func,
inp=[i],
Tout=[tf.uint8,
tf.float32]
),
num_parallel_calls=tf.data.AUTOTUNE)
由於隱式推斷的原因,有時tensor的輸出shape是未知的,需要額外處理
dataset = dataset.batch(8)
def _fixup_shape(x, y):
x.set_shape([None, None, None, nb_channels]) # n, h, w, c
y.set_shape([None, nb_classes]) # n, nb_classes
return x, y
dataset = dataset.map(_fixup_shape)
為什麼需要 tf.py_function
,先來看下tf.Tensor
與tf.EagerTensor
EagerTensor是實時的,可以在任何時候獲取到它的值,即通過numpy獲取
Tensor是非實時的,它是靜態圖中的元件,只有當喂入資料、運算完成才能獲得該Tensor的值,
map中對映的函數運算,而僅僅是告訴dataset,你每一次拿出來的樣本時要先進行一遍function運算之後才使用的,所以function的呼叫是在每次迭代dataset的時候才呼叫的,屬於靜態圖邏輯
tensorflow.python.framework.ops.EagerTensor
tensorflow.python.framework.ops.Tensor
tf.py_function
在這裡起了什麼作用?
Wraps a python function into a TensorFlow op that executes it eagerly.
剛才說到map資料靜態圖邏輯,預設引數都是Tensor。而 使用tf.py_function()
包裝後,引數就變成了EagerTensor。
【2】https://blog.csdn.net/qq_27825451/article/details/105247211
【3】https://www.tensorflow.org/guide/data_performance