最近發現一個叫GroundingDINO的開集目標檢測演演算法,所謂開集目標檢測就是能檢測的目標類別不侷限於訓練的類別,這個演演算法可以通過輸入文字的prompt然後輸出對應的目標框。可以用來做預標註或者其他應用,比如我們要訓練某個細分場景的演演算法時,我們找不到足夠的已經標註的資料,就可以先用這個演演算法預打標, 與SAM結合,還能做根據text去分割出物體。
GroundingDINO:https://github.com/IDEA-Research/GroundingDINO
原始的專案是一個python指令碼,不適合單人使用,而不是和團隊一起使用。服務化之後,其他人可以通過http請求的方式來存取,而不需要每個人都搭建環境,也便於批次處理資料。
最簡單的是通過flask api把python指令碼包裝一層,這種方式實現簡單,但擴充套件性不夠,比如如果想要動態組batch,就需要自己寫這部分邏輯。更好的方式是使用成熟的模型推理服務,TorchServe就是其中的一種,比較適合pytorch模型(其實其他格式比如onnx也可以),使用TorchServe,我們只用寫好模型的預處理、推理和後處理邏輯,其他的比如範例擴充套件、動態batch、資源監控這些都不需要我們自己實現。我們有其他模型,也可以用同樣的方式服務起來,而不需要為每個模型都寫一個服務。因此本文選擇TorchServe來作為模型的推理服務。
克隆文末的專案後按順序執行下面步驟:
新建一個weights目錄,並把下面的模型放入:
wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
新建一個bert-base-uncased 目錄,下載bert模型:
https://huggingface.co/bert-base-uncased/tree/main
config.json
pytorch_model.bin
tokenizer_config.json
tokenizer.json
vocab.txt
Dockerfile:
FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel
ARG DEBIAN_FRONTEND=noninteractive
#for Chinese User, uncomment this line
# COPY sources.list /etc/apt/sources.list
RUN apt update && \
apt install openjdk-17-jdk -y
RUN apt install git -y
#install python packages
COPY requirements.txt /root/
RUN pip install -r /root/requirements.txt --no-cache -i https://repo.huaweicloud.com/repository/pypi/simple/
docker build -t torchserve:groundingdino .
docker run --rm -it -v $(pwd):/data -w /data torchserve:groundingdino bash -c "torch-model-archiver --model-name groundingdino --version 1.0 --serialized-file weights/groundingdino_swint_ogc.pth --handler grounding_dino_handler.py --extra-files GroundingDINO_SwinT_OGC.py,bert-base-uncased/*"
執行完畢後,將得到一個groundingdino.mar檔案。
根據需要修改服務的設定
docker run -d --name groundingdino -v $(pwd)/model_store:/model_store -p 8080:8080 -p 8081:8081 -p 8082:8082 torchserve:groundingdino bash -c "torchserve --start --foreground --model-store /model_store --models groundingdino=groundingdino.mar"
import requests
import base64
import time
# URL for the web service
url = "http://ip:8080/predictions/groundingdino"
headers = {"Content-Type": "application/json"}
# Input data
with open("test.jpg", "rb") as f:
image = f.read()
data = {
"image": base64.b64encode(image).decode("utf-8"), # base64 encoded image or BytesIO
"caption": "steel pipe", # text prompt, split by "." for multiple phrases
"box_threshold": 0.25, # threshold for object detection
"caption_threshold": 0.25 # threshold for text similarity
}
# Make the request and display the response
resp = requests.post(url=url, headers=headers, json=data)
outputs = resp.json()
'''
the outputs will be like:
{
"boxes": [[0.0, 0.0, 1.0, 1.0]], # list of bounding boxes in xyxy format
"scores": [0.9999998807907104], # list of object detection scores
"phrases": ["steel pipe"] # list of text phrases
}
'''
本文來自部落格園,作者:haoliuhust,轉載請註明原文連結:https://www.cnblogs.com/haoliuhust/p/17435504.html