""" @Description : 構建Dataset類,不同的任務,dataset自行編寫,如基於csv,文字等載入標籤,均可從cfg組態檔中讀取後,自行擴充套件編寫 編寫自定義Dataset類時,初始化引數需定義為source_img, cfg。否則資料載入通用模組,data_load_service.py模組會報錯。 source_img :傳入的影象地址資訊 cfg:傳入的設定類資訊,針對不同的任務,可能生成的label模式不同,可基於設定類指定label的載入模式,最終為訓練的影象初始化label (使用者自定義實現) 本例為驗證碼載入類:基於檔名稱生成標籤(如驗證碼:0AaW_54463.png,標籤值為:0AaW,返回one-hot編碼) """import torch from torch.utils.data.dataset import Dataset import torchvision.transforms as transforms import cv2 from universe.data_load.normalize_adapter import NormalizeAdapter from PIL import Image from universe.utils.utils import one_hot classTrainDataset(Dataset): """ 構建一個 載入原始圖片的dataSet物件 此函數可載入 訓練集資料,基於路徑識別驗證碼真實的label,label在轉換為one-hot編碼 若 驗證集邏輯與訓練集邏輯一樣,驗證集可使用TrainDataset,不同,則需自定義一個,參考如下EvalDataset """def__init__(self, source_img, cfg): self.source_img = source_img self.cfg = cfg self.transform = createTransform(cfg, TrainImgDeal) def__getitem__(self, index): img = cv2.imread(self.source_img[index]) if self.transform isnotNone: img = self.transform(img) # ../ data / train\Qigj_73075.png label = self.source_img[index].split("_")[0][-4:] target = torch.Tensor(one_hot(label)) return img, target, self.source_img[index] def__len__(self): returnlen(self.source_img) classEvalDataset(Dataset): """ 構建一個 載入原始圖片的dataSet物件 此函數可載入 驗證集資料,基於路徑識別驗證碼真實的label,label在轉換為one-hot編碼 """def__init__(self, source_img, cfg): self.source_img = source_img self.cfg = cfg # 若驗證集圖片處理邏輯(增強,調整)與 訓練集不同,可自定義一個EvalImgDeal self.transform = createTransform(cfg, TrainImgDeal) def__getitem__(self, index): img = cv2.imread(self.source_img[index]) if self.transform isnotNone: img = self.transform(img) # ../ data / train\Qigj_73075.png label = self.source_img[index].split("_")[0][-4:] target = torch.Tensor(one_hot(label)) return img, target, self.source_img[index] def__len__(self): returnlen(self.source_img) classPredictDataset(Dataset): """ 構建一個 載入預測圖片的dataSet物件 此函數可載入 測試集資料,應用集資料(返回影象資訊) """def__init__(self, source_img,cfg): self.source_img = source_img # 
若預測集圖片處理邏輯(增強,調整)與 訓練集不同,可自定義一個PredictImgDeal self.transform = createTransform(cfg, TrainImgDeal) def__getitem__(self, index): img = cv2.imread(self.source_img[index]) if self.transform isnotNone: img = self.transform(img) # 用於記錄實際的label值(因為應用資料也是指令碼生成的,所以可以知道正確的驗證碼) real_label = self.source_img[index].split("_")[0][-4:] return img, real_label, self.source_img[index] def__len__(self): returnlen(self.source_img) classTrainImgDeal: def__init__(self, cfg): img_size = cfg['target_img_size'] self.h = img_size[0] self.w = img_size[1] def__call__(self, img): img = cv2.resize(img, (self.h, self.w)) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img = Image.fromarray(img) return img defcreateTransform(cfg, img_deal): my_normalize = NormalizeAdapter.getNormalize(cfg['model_name']) transform = transforms.Compose([ img_deal(cfg), transforms.ToTensor(), my_normalize, ]) return transform
1/100 [9600/10000 (96%)] - ETA: 0:00:19, loss: 0.0003, acc: 0.9234 LR: 0.001000 [VAL] loss: 0.00012, acc: 81.060% 2/100 [9600/10000 (96%)] - ETA: 0:00:23, loss: 0.0000, acc: 0.9977 LR: 0.001000 [VAL] loss: 0.00011, acc: 82.000% 3/100 [9600/10000 (96%)] - ETA: 0:00:20, loss: 0.0000, acc: 0.9978 LR: 0.001000 [VAL] loss: 0.00012, acc: 79.520% 4/100 [9600/10000 (96%)] - ETA: 0:00:18, loss: 0.0000, acc: 0.9888 LR: 0.001000 [VAL] loss: 0.00013, acc: 78.020% 5/100 [9600/10000 (96%)] - ETA: 0:00:18, loss: 0.0000, acc: 0.9824 LR: 0.001000 [VAL] loss: 0.00012, acc: 80.260% 6/100 [9600/10000 (96%)] - ETA: 0:00:19, loss: 0.0000, acc: 0.9903 LR: 0.001000 [VAL] loss: 0.00013, acc: 80.040% 7/100 [9600/10000 (96%)] - ETA: 0:00:21, loss: 0.0000, acc: 0.9923 LR: 0.000100 [VAL] loss: 0.00010, acc: 83.900% 8/100 [9600/10000 (96%)] - ETA: 0:00:20, loss: 0.0000, acc: 0.9977 LR: 0.000100 [VAL] loss: 0.00009, acc: 84.280% 9/100 [9600/10000 (96%)] - ETA: 0:00:18, loss: 0.0000, acc: 0.9987 LR: 0.000100 [VAL] loss: 0.00009, acc: 84.400% 10/100 [9600/10000 (96%)] - ETA: 0:00:20, loss: 0.0000, acc: 0.9992 LR: 0.000100 [VAL] loss: 0.00009, acc: 84.600% 11/100 [9600/10000 (96%)] - ETA: 0:00:19, loss: 0.0000, acc: 0.9993 LR: 0.000100 [VAL] loss: 0.00009, acc: 84.460% 12/100 [9600/10000 (96%)] - ETA: 0:00:19, loss: 0.0000, acc: 0.9995 LR: 0.000100 [VAL] loss: 0.00009, acc: 84.600% 13/100 [9600/10000 (96%)] - ETA: 0:00:21, loss: 0.0000, acc: 0.9998 LR: 0.000100 [VAL] loss: 0.00009, acc: 85.100% 14/100 [9600/10000 (96%)] - ETA: 0:00:19, loss: 0.0000, acc: 0.9996 LR: 0.000100 [VAL] loss: 0.00009, acc: 84.720% 15/100 [9600/10000 (96%)] - ETA: 0:00:21, loss: 0.0000, acc: 0.9998 LR: 0.000100 [VAL] loss: 0.00009, acc: 85.140% 16/100 [9600/10000 (96%)] - ETA: 0:00:21, loss: 0.0000, acc: 0.9998 LR: 0.000100 [VAL] loss: 0.00009, acc: 84.720% 17/100 [9600/10000 (96%)] - ETA: 0:00:21, loss: 0.0000, acc: 0.9999 LR: 0.000100 [VAL] loss: 0.00009, acc: 85.220% 18/100 [9600/10000 (96%)] - ETA: 0:00:21, loss: 
0.0000, acc: 0.9999 LR: 0.000100 [VAL] loss: 0.00009, acc: 84.900% 19/100 [9600/10000 (96%)] - ETA: 0:00:21, loss: 0.0000, acc: 0.9999 LR: 0.000100 [VAL] loss: 0.00009, acc: 84.980% 20/100 [9600/10000 (96%)] - ETA: 0:00:21, loss: 0.0000, acc: 1.0000 LR: 0.000100 [VAL] loss: 0.00009, acc: 85.280% 21/100 [9600/10000 (96%)] - ETA: 0:00:21, loss: 0.0000, acc: 0.9999 LR: 0.000100 [VAL] loss: 0.00009, acc: 85.140% 22/100 [9600/10000 (96%)] - ETA: 0:00:21, loss: 0.0000, acc: 1.0000 LR: 0.000100 [VAL] loss: 0.00009, acc: 85.140% 23/100 [9600/10000 (96%)] - ETA: 0:00:21, loss: 0.0000, acc: 0.9998 LR: 0.000100 [VAL] loss: 0.00009, acc: 84.880% 24/100 [9600/10000 (96%)] - ETA: 0:00:20, loss: 0.0000, acc: 1.0000 LR: 0.000100 [VAL] loss: 0.00010, acc: 85.120% 25/100 [9600/10000 (96%)] - ETA: 0:00:20, loss: 0.0000, acc: 1.0000 LR: 0.000010 [VAL] loss: 0.00009, acc: 85.160% 26/100 [9600/10000 (96%)] - ETA: 0:00:21, loss: 0.0000, acc: 1.0000 LR: 0.000010 [VAL] loss: 0.00009, acc: 85.180% 27/100 [9600/10000 (96%)] - ETA: 0:00:21, loss: 0.0000, acc: 1.0000 LR: 0.000010 [VAL] loss: 0.00009, acc: 85.220% [INFO] Early Stop with patient 7 , best is Epoch - 20 :0.852800 -------------------------------------------------- {'model_name': 'mobilenetv3', 'GPU_ID': '', 'class_number': 248, 'random_seed': 42, 'cfg_verbose': True, 'num_workers': 8, 'train_path': 'data/train', 'val_path': 'data/val', 'test_path': 'data/test', 'label_type': 'DIR', 'label_path': '', 'pretrained': 'output/mobilenetv3_e21_0.84700.pth', 'try_to_train_items': 10000, 'save_best_only': True, 'save_one_only': True, 'save_dir': 'output/', 'metrics': ['acc'], 'loss': 'CE', 'show_heatmap': False, 'show_data': False, 'target_img_size': [224, 224], 'learning_rate': 0.001, 'batch_size': 64, 'epochs': 100, 'optimizer': 'Adam', 'scheduler': 'default-0.1-3', 'warmup_epoch': 0, 'weight_decay': 0, 'k_flod': 5, 'start_fold': 0, 'early_stop_patient': 7, 'use_distill': 0, 'label_smooth': 0, 'class_weight': None, 'clip_gradient': 0, 
'freeze_nonlinear_epoch': 0, 'dropout': 0.5, 'mixup': False, 'cutmix': False, 'sample_weights': None, 'model_path': '../../config/weight/mobilenet/mobilenetv3_e22_1.00000.pth', 'TTA': False, 'merge': False, 'test_batch_size': 1} -------------------------------------------------- Process finished with exit code 0
def predict(cfg):
    """Run prediction on the predict set and write the results to a CSV file.

    The results returned by the runner are turned into a two-column table
    (``image_id``, ``label``) and saved as ``pre.csv`` inside
    ``cfg['save_dir']``.

    Args:
        cfg: Configuration mapping; this function itself reads
            ``cfg['model_path']`` and ``cfg['save_dir']`` and hands the whole
            mapping to the project services.
    """
    initConfig(cfg)
    net = ModelService(cfg)
    loader_service = DataLoadService(cfg)
    predict_loader = loader_service.getPredictDataloader(PredictDataset)
    runner = RunnerCaptchaService(cfg, net)
    # NOTE(review): modelLoad is called as a bare name, exactly as in the
    # original text; confirm it is not meant to be a method of `runner`.
    modelLoad(cfg['model_path'])

    predictions = runner.predict(predict_loader)
    print(len(predictions))

    # to csv: the dict keys become the index, which is then exposed as the
    # 'image_id' column next to the single 'label' column.
    table = pd.DataFrame.from_dict(predictions, orient='index', columns=['label'])
    table = table.reset_index()
    table = table.rename(columns={'index': 'image_id'})
    csv_path = os.path.join(cfg['save_dir'], 'pre.csv')
    table.to_csv(csv_path, index=False, header=True)


if __name__ == '__main__':
    predict(cfg)
微信讚賞
支付寶讚賞