我已經訓練完成了yolov5檢測和resnet152分類的模型,下面開始對一張圖片進行檢測分類。
首先用yolo演算法對貓和狗進行檢測,然後將檢測到的目標進行裁剪,再用resnet152對裁剪後的圖片進行分類。
首先我有以下這些訓練好的模型
貓狗檢測的,貓的分類,狗的分類
我的預測檔案my_detect.py
"""Detect cats and dogs with YOLOv5, then classify the breed of each detection."""
import os
import sys
from functools import lru_cache
from pathlib import Path

import torch

from tools_detect import (
    draw_box_and_save_img,
    dataLoad,
    predict_classify,
    detect_img_2_classify_img,
    get_time_uuid,
)

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from models.common import DetectMultiBackend
from utils.general import non_max_suppression
from utils.plots import save_one_box
import config as cfg

conf_thres = cfg.conf_thres
iou_thres = cfg.iou_thres
detect_size = cfg.detect_img_size
classify_size = cfg.classify_img_size

# Detections whose box covers less than this fraction of the image are skipped.
# draw_box_and_save_img applies the same 0.01 threshold, which keeps the
# numbering of the returned labels in step with the numbers drawn on the image.
MIN_BOX_AREA_RATIO = 0.01


@lru_cache(maxsize=2)
def _load_detector(detect_weights, device, imgsz):
    """Load and warm up the detector once per (weights, device, input size)."""
    model = DetectMultiBackend(detect_weights, device=device)
    model.warmup(imgsz=(1, 3, *imgsz))  # warmup
    return model


def detect_img(img, device, detect_weights='', detect_class=None, save_dir=''):
    """Detect targets in an encoded image and classify every detection.

    Args:
        img: Encoded image bytes (the raw content of an image file).
        device: Device passed to the detector and the classifiers.
        detect_weights: Path of the YOLOv5 detection weights.
        detect_class: Display names of the detection classes, indexed by the
            detector's class id. Index 0 is routed to the cat classifier and
            every other index to the dog classifier. Defaults to
            ``cfg.detect_class``.
        save_dir: File path the annotated image is written to.

    Returns:
        A tuple ``(ret_li, ret_bytes)``. ``ret_li`` holds one dict per kept
        detection with the keys ``'details'`` (fine-grained class name),
        ``'label'`` (detection class name plus a running number) and
        ``'conf'`` (detection confidence rounded to 2 decimals).
        ``ret_bytes`` is whatever ``draw_box_and_save_img`` returns for the
        annotated image.
    """
    if detect_class is None:
        # A shared mutable default ([]) would be visible across calls.
        detect_class = cfg.detect_class

    # Load data.
    imgsz = (detect_size, detect_size)
    im0s, im = dataLoad(img, imgsz, device)

    # Load the model (cached, so repeated calls do not re-read the weights).
    model = _load_detector(detect_weights, device, imgsz)
    names = model.names

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        pred = model(im, augment=False, visualize=False)
        pred = non_max_suppression(pred, conf_thres, iou_thres, None, False, max_det=1000)

    im0 = im0s.copy()
    # Draw the boxes and save the image. This call also rescales the boxes in
    # `pred` to the coordinates of im0 in place, which the cropping below
    # relies on.
    ret_bytes = draw_box_and_save_img(pred, names, detect_class, save_dir, im0, im)

    ret_li = []
    im0_arc = int(im0.shape[0]) * int(im0.shape[1])
    count = 1
    for det in reversed(pred[0]):
        # Skip targets that are too small.
        xyxy_arc = (int(det[2]) - int(det[0])) * (int(det[3]) - int(det[1]))
        if xyxy_arc / im0_arc < MIN_BOX_AREA_RATIO:
            continue

        # Crop the detection out of the original image.
        im_crop = save_one_box(det[:4], im0, file=Path('im.jpg'), gain=1.1, pad=10,
                               square=False, BGR=False, save=False)
        # Convert the crop to the classifier's input size and to a tensor.
        im_crop = detect_img_2_classify_img(im_crop, classify_size, device)

        label = detect_class[int(det[-1])]
        # Pick the fine-grained classifier that matches the detected class.
        if label == detect_class[0]:
            weight, class_names = cfg.cat_weight, cfg.cat_class
        else:
            weight, class_names = cfg.dog_weight, cfg.dog_class
        classify_predict = predict_classify(weight, im_crop, device)

        ret_li.append({
            'details': class_names[int(classify_predict)],
            'label': label + str(count),
            'conf': round(float(det[-2]), 2),
        })
        count += 1
    return ret_li, ret_bytes


def start_predict(img, save_dir=''):
    """Run detection and classification on `img` using the values in config.

    Args:
        img: Encoded image bytes.
        save_dir: File path the annotated image is written to.

    Returns:
        The ``(ret_li, ret_bytes)`` tuple produced by ``detect_img``.
    """
    return detect_img(img, cfg.device, cfg.detect_weight, cfg.detect_class, save_dir)


if __name__ == '__main__':
    name = get_time_uuid()
    save_dir = f'./save/{name}.jpg'
    # Make sure the output folder exists before anything is written to it.
    Path(save_dir).parent.mkdir(parents=True, exist_ok=True)

    # Other sample images:
    #   ./test_img/hashiqi20230312_00010.jpg
    #   ./test_img/kejiquan20230312_00046.jpg
    path = r'./test_img/hashiqi20230312_00116.jpg'
    with open(path, 'rb') as f:
        img = f.read()

    img_ret_li, img_bytes = start_predict(img, save_dir=save_dir)
    print(img_ret_li)
我的tools_detect.py檔案
"""Helper functions for the detect-then-classify pipeline."""
import datetime
import os
import random
import sys
import time
from functools import lru_cache
from pathlib import Path

import torch
from PIL import Image
from torch import nn

from utils.augmentations import letterbox

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from utils.general import (cv2, scale_boxes, xyxy2xywh)
from utils.plots import Annotator, colors
import numpy as np

# Boxes covering less than this fraction of the image are neither drawn nor
# written to the label file.
MIN_BOX_AREA_RATIO = 0.01


def bytes_to_ndarray(byte_img):
    """Decode encoded image bytes into a numpy image via cv2.imdecode.

    Raises:
        ValueError: If the bytes cannot be decoded as an image.
    """
    buffer = np.frombuffer(byte_img, dtype=np.uint8)
    image = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
    if image is None:
        # imdecode signals failure by returning None instead of raising.
        raise ValueError('byte_img could not be decoded as an image')
    return image


def ndarray_to_bytes(ndarray_img):
    """Encode a numpy image as JPEG and return the encoded bytes.

    Raises:
        ValueError: If JPEG encoding fails.
    """
    ok, buf = cv2.imencode(".jpg", ndarray_img)
    if not ok:
        raise ValueError('image could not be encoded as JPEG')
    return buf.tobytes()


def get_time_uuid():
    """Return a time-based id string.

    The id is the current local time as ``YYYYMMDDHHMMSS`` followed by six
    microsecond digits and three random digits, e.g.
    ``20220525140635467912`` + ``123``.

    The random suffix only lowers the collision risk; under high concurrency
    use more random digits.
    """
    # strftime('%f') always yields six digits; str(datetime) drops the
    # fractional part entirely when the microsecond value is 0.
    stamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')
    return stamp + str(random.randint(100, 999))


def dataLoad(img, img_size, device, half=False):
    """Decode image bytes and build the detector's input tensor.

    Args:
        img: Encoded image bytes.
        img_size: Target size handed to ``letterbox``.
        device: Device the tensor is moved to.
        half: Use fp16 instead of fp32 when true.

    Returns:
        ``(image, im)``: the decoded image and the letterboxed tensor scaled
        to 0.0 - 1.0 with a leading batch dimension.
    """
    image = bytes_to_ndarray(img)
    im = letterbox(image, img_size)[0]  # padded resize
    im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
    im = np.ascontiguousarray(im)  # contiguous
    im = torch.from_numpy(im).to(device)
    im = im.half() if half else im.float()  # uint8 to fp16/32
    im /= 255  # 0 - 255 to 0.0 - 1.0
    if len(im.shape) == 3:
        im = im[None]  # expand for batch dim
    return image, im


def draw_box_and_save_img(pred, names, class_names, save_dir, im0, im):
    """Draw the detections on the image, save it and write a label file.

    The boxes in ``pred`` are rescaled in place from the size of ``im`` to
    the size of ``im0``. For every kept box one normalized
    ``class x y w h`` line is appended to ``labels/<file name>.txt`` next to
    the saved image.

    Args:
        pred: Per-image detections with rows ``(x1, y1, x2, y2, conf, cls)``.
        names: Class names of the detector (used as the annotator's example
            text).
        class_names: Display names indexed by class id, used in box labels.
        save_dir: File path the annotated image is written to.
        im0: Original image the boxes are drawn on.
        im: Network input tensor; its spatial shape is the source scale of
            the boxes.

    Returns:
        The annotated image encoded as JPEG bytes.
    """
    save_path = save_dir
    fontpath = "./simsun.ttc"
    for det in pred:
        annotator = Annotator(im0, line_width=3, example=str(names), font=fontpath, pil=True)
        if len(det):
            # Rescale the boxes from the network input size to the im0 size.
            det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
            im0_arc = int(im0.shape[0]) * int(im0.shape[1])
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh

            base_path, file_name = os.path.split(save_path)
            # splitext keeps file names that contain additional dots intact.
            file_name = os.path.splitext(file_name)[0]
            label_dir = os.path.join(base_path, 'labels')
            # makedirs also creates a missing parent folder and tolerates an
            # existing one.
            os.makedirs(label_dir, exist_ok=True)
            txt_path = os.path.join(label_dir, file_name)

            label_lines = []
            count = 1
            for *xyxy, conf, cls in reversed(det):
                # Skip targets that are too small.
                xyxy_arc = (int(xyxy[2]) - int(xyxy[0])) * (int(xyxy[3]) - int(xyxy[1]))
                if xyxy_arc / im0_arc < MIN_BOX_AREA_RATIO:
                    continue
                c = int(cls)  # integer class
                label = f"{class_names[c]}{count} {round(float(conf), 2)}"
                annotator.box_label(xyxy, label, color=colors(c, True))
                count += 1

                xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                line = (cls, *xywh)  # label format
                label_lines.append(('%g ' * len(line)).rstrip() % line + '\n')

            # Convert the annotated image back to a numpy array once.
            im0 = annotator.result()
            # Write all labels with a single open. Nothing is written when
            # every box was filtered out, so no empty label file is created.
            if label_lines:
                with open(f'{txt_path}.txt', 'a') as f:
                    f.writelines(label_lines)
        # Save the image whether or not anything was detected.
        cv2.imwrite(save_path, im0)
    ret_bytes = ndarray_to_bytes(im0)
    return ret_bytes


@lru_cache(maxsize=8)
def _load_classifier(model_path, device):
    """Load a classification model once per (path, device) in eval mode."""
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True,
    # which rejects fully pickled models - pass weights_only=False there if
    # loading fails; confirm against the installed version.
    if torch.cuda.is_available():
        model = torch.load(model_path)
    else:
        model = torch.load(model_path, map_location='cpu')
    model.to(device)
    model.eval()
    return model


def predict_classify(model_path, img, device):
    """Classify an image tensor with the model stored at ``model_path``.

    Args:
        model_path: Path of a model file saved with ``torch.save``.
        img: Input tensor for the model.
        device: Device the model is moved to.

    Returns:
        The index of the highest-scoring class along dim 1, with singleton
        dimensions squeezed out.
    """
    model = _load_classifier(model_path, device)
    # Inference only: skip gradient tracking.
    with torch.no_grad():
        predicts = model(img)
    _, preds = torch.max(predicts, 1)
    return torch.squeeze(preds)


def detect_img_2_classify_img(img, classify_size, device):
    """Turn a cropped HWC image into a ``(1, C, H, W)`` float32 tensor.

    The crop is resized to ``classify_size`` x ``classify_size``. Pixel
    values are neither scaled nor normalized here.
    NOTE(review): confirm this matches the preprocessing used when the
    classifiers were trained.
    """
    # A contiguous float32 copy; the incoming crop may be a strided view.
    crop = np.ascontiguousarray(img, dtype=np.float32)
    image = cv2.resize(crop, (classify_size, classify_size))
    image = image.transpose((2, 0, 1))  # HWC to CHW
    im = torch.from_numpy(image).unsqueeze(0)  # add batch dim
    return im.to(device)
我的config.py檔案
"""Configuration for the detect-then-classify pipeline."""
import os

import torch

# Folder holding all trained weights. The path is built with os.path.join so
# that the separator is right on every platform; a literal '.\weights' only
# names the intended folder on Windows.
base_path = os.path.join('.', 'weights')

# Detector weights and the display names of its classes (indexed by class id).
detect_weight = os.path.join(base_path, 'cat_dog_detect', 'best.pt')
detect_class = ['貓', '狗']

# Cat classifier weights and its class names (indexed by predicted class id).
cat_weight = os.path.join(base_path, 'cat_predict', 'best.pt')
cat_class = [
    '東方短毛貓', '亞洲豹貓', '加菲貓', '安哥拉貓', '布偶貓', '德文捲毛貓',
    '折耳貓', '無毛貓', '暹羅貓', '森林貓', '橘貓', '奶牛貓', '獰貓', '獅子貓',
    '狸花貓', '玳瑁貓', '白貓', '藍貓', '藍白貓', '藪貓', '金漸層貓',
    '阿比西尼亞貓', '黑貓',
]

# Dog classifier weights and its class names (indexed by predicted class id).
dog_weight = os.path.join(base_path, 'dog_predict', 'best.pt')
dog_class = [
    '中華田園犬', '博美犬', '吉娃娃', '哈士奇', '喜樂蒂', '巴哥犬', '德牧',
    '拉布拉多犬', '杜賓犬', '松獅犬', '柯基犬', '柴犬', '比格犬', '比熊',
    '法國鬥牛犬', '秋田犬', '約克夏', '羅威納犬', '臘腸犬', '薩摩耶',
    '西高地白梗犬', '貴賓犬', '邊境牧羊犬', '金毛犬', '阿拉斯加犬', '雪納瑞',
    '馬爾濟斯犬',
]

# Inference device. To use a GPU when one is available:
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')

# Detection thresholds handed to non-max suppression.
conf_thres = 0.5
iou_thres = 0.45

# Input edge lengths (pixels) for the detector and the classifiers.
detect_img_size = 416
classify_img_size = 160
整體檔案結構
其中models和utils資料夾都是yolov5原始碼的檔案
執行my_detect.py的結果