按照 pix2pix 的要求劃分訓練/驗證/測試(train/val/test)資料夾

2020-10-02 12:00:53

本文實現以下功能:
先用這篇部落格中的方法生成 A、B 兩組檔名一一對應的成對影像;在按照 pix2pix 的資料準備要求整理資料集之前,需要先把資料劃分為訓練集(train)、驗證集(val)與測試集(test)。


import os
import random
import shutil
from shutil import copy2


def data_set_split(src_data_folder, target_data_folder, train_scale=0.8, val_scale=0.1, test_scale=0.1,
                   seed=None):
    '''
    Split a paired pix2pix dataset into train/val/test folders.

    ``src_data_folder`` must contain an ``A`` and a ``B`` subfolder whose
    images share file names. The files of ``A`` are shuffled and divided by
    the given ratios; each ``B`` file with the same name is copied into the
    same split, so A/B pairs stay together.

    Output layout: ``<target>/<entry>/{train,val,test}`` for every entry of
    ``src_data_folder``. Split folders that already exist are deleted and
    recreated, so previous results are discarded.

    :param src_data_folder: source folder, e.g. ``Result`` (contains ``A`` and ``B``)
    :param target_data_folder: destination folder, e.g. ``to/data``; created if missing
    :param train_scale: fraction of images placed in ``train``
    :param val_scale: fraction of images placed in ``val``
    :param test_scale: fraction reported for ``test``; the test split actually
        receives every image left over after train and val
    :param seed: optional seed for a reproducible shuffle; ``None`` (default)
        uses the global ``random`` state
    :return: None
    :raises FileNotFoundError: if an ``A`` image has no same-named ``B`` image
    '''
    print("開始資料集劃分")
    class_names = os.listdir(src_data_folder)
    split_names = ['train', 'val', 'test']

    # Create <target>/<entry>/<split>; any split folder that already exists is
    # removed first so stale files from an earlier run do not remain.
    for class_name in class_names:
        class_split_path = os.path.join(target_data_folder, class_name)
        os.makedirs(class_split_path, exist_ok=True)
        for split_name in split_names:
            split_path = os.path.join(class_split_path, split_name)
            if os.path.exists(split_path):
                shutil.rmtree(split_path)
            os.mkdir(split_path)

    # Split by the file names of A; B is copied by matching name.
    A_class_data_path = os.path.join(src_data_folder, 'A')
    B_class_data_path = os.path.join(src_data_folder, 'B')
    # Only regular files are images; sub-directories would make copy2 fail.
    A_all_data = [name for name in os.listdir(A_class_data_path)
                  if os.path.isfile(os.path.join(A_class_data_path, name))]
    A_data_length = len(A_all_data)
    shuffler = random if seed is None else random.Random(seed)
    shuffler.shuffle(A_all_data)

    # split name -> (A destination folder, B destination folder)
    split_folders = {
        split_name: (os.path.join(target_data_folder, 'A', split_name),
                     os.path.join(target_data_folder, 'B', split_name))
        for split_name in split_names
    }

    # Integer boundaries: positions [0, train_end) -> train,
    # [train_end, val_end) -> val, the rest -> test. round() absorbs float
    # noise such as 10 * (0.8 + 0.1) == 9.000000000000002.
    train_end = int(round(A_data_length * train_scale))
    val_end = int(round(A_data_length * (train_scale + val_scale)))

    counts = dict.fromkeys(split_names, 0)
    for idx, file_name in enumerate(A_all_data):
        A_src_img_path = os.path.join(A_class_data_path, file_name)
        B_src_img_path = os.path.join(B_class_data_path, file_name)
        # Check the pair before copying anything so an unpaired A image is
        # not left behind in the output.
        if not os.path.isfile(B_src_img_path):
            raise FileNotFoundError(
                "B 類中找不到與 A 類對應的檔案: {}".format(B_src_img_path))

        if idx < train_end:
            split_name = 'train'
        elif idx < val_end:
            split_name = 'val'
        else:
            split_name = 'test'

        A_dst_folder, B_dst_folder = split_folders[split_name]
        copy2(A_src_img_path, A_dst_folder)
        copy2(B_src_img_path, B_dst_folder)
        counts[split_name] += 1

    print("A類按照{}:{}:{}的比例劃分完成,一共{}張圖片".format(
        train_scale, val_scale, test_scale, A_data_length))
    print("訓練集{}:{}張".format(split_folders['train'][0], counts['train']))
    print("驗證集{}:{}張".format(split_folders['val'][0], counts['val']))
    print("測試集{}:{}張".format(split_folders['test'][0], counts['test']))
    print("B 類的訓練集、驗證集、測試集完全按照 A 類的檔名稱對應分類!")


if __name__ == '__main__':
    src_data_folder = "Result"
    target_data_folder = "to/data"

    # Create the target directory (including parents) if it does not exist;
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(target_data_folder, exist_ok=True)
    data_set_split(src_data_folder, target_data_folder)