株式会社神戸デジタル・ラボ DataIntelligenceチームの原口です。
今回はYOLOXを用いてリアルタイム推論にチャレンジします。
YOLOXとは
YOLOXとは2021年に発表されたYOLO系の物体検出モデルです。
YOLOv5・v8と比較するとモデルの学習準備が大変なところもありますが、ライセンスもApache-2.0 licenseと商用利用しやすい点が今でも人気の理由の1つです。
環境構築
YOLOXを使うための環境構築を実施します。まずはGitHubからYOLOXをクローンしましょう。
git clone https://github.com/Megvii-BaseDetection/YOLOX
クローンができれば、作成された「YOLOX」フォルダに移動します。
インストール実行前にリアルタイム推論に余分なライブラリをrequirements.txtからコメントアウトしましょう。(ONNX関連のライブラリはインストール時にエラーが発生することが多いです。今回はONNXを利用しないのてコメントアウトしました。pycocotoolsは環境によってこのままではインストールできない場合があるのでコメントアウトしました)
# TODO: Update with exact module version numpy torch>=1.7 opencv_python loguru tqdm torchvision thop ninja tabulate psutil tensorboard # 下3つをコメントアウト # verified versions # pycocotools corresponds to https://github.com/ppwwyyxx/cocoapi # pycocotools>=2.0.2 # onnx>=1.13.0 # onnx-simplifier==0.4.10
上記対応ができればインストールを実行しましょう。
python setup.py develop
Finished processing dependencies for yolox==0.3.0
と表示されればインストールは完了です。
最後にCOCOデータセットのラベルを利用するためにpycocotoolsをインストールしましょう。
python setup.py develop
内ではインストールが正しく実行されない場合があるため、別途インストールします。
pip install pycocotools
学習済みモデルを準備
続いて検出に利用する学習済みモデルを準備します。今回は最も軽量でCPUで高速に動作可能なYOLOX-Nanoを利用します。
学習済みモデルのダウンロードはYOLOXのGitHub、README内にあるモデル紹介の表(こちら)の右側にある「github」リンクをクリックすることでダウンロードできます。
デモを確認
インストール・学習済みモデルの準備が完了したのでデモコードを確認してみましょう・・・。
なんとデモコード内に「webcam」の表記が。独自にコードを書かなくてもリアルタイム推論できそうです。
まずはデモコードのリアルタイム推論を試してみましょう。
python .\tools\demo.py webcam -n yolox-nano -c .\yolox_nano.pth
ここで
- webcam:デモモードを「image・video・webcam」の3タイプから切り替えられます。
- -n:モデル名を選択します。
- -c:学習済みモデルのパスを指定します。
おお!1コマンドでリアルタイム推論がこんなに簡単に出来てしまう!
本記事は以上・・・、と行きたいところですがこのデモでは検出結果の座標データが得られないため、「検出した物体を切り出す・検出結果を用いて新たにアプリケーションを開発する」ということが出来ません。
検出デモを動かすという意味では「demo.py」は非常に優秀ですが、ソリューションの一部で使うには少し物足りない感じがします。
よって座標データを取れるようにYOLOXのリアルタイム推論を改造しましょう!(ついでにコードの軽量化もしましょう)
コンソールでCtrl + C
でデモを止めてから次の章に移りましょう!
リアルタイム推論の実装
Step1:demo.pyを眺める
demo.pyでリアルタイム推論が実行されていたので、このコードの中から必要なものだけを切り出せば最小限のリアルタイム推論コードを構築出来そうです。
そのためにもまずはdemo.pyを眺めましょう。
if __name__ == "__main__": args = make_parser().parse_args() exp = get_exp(args.exp_file, args.name) main(exp, args)
この部分はデモコードのが実行された際に最初に実行される部分です。
make_parser().parse_args()
は引数を管理しています。
get_exp
は何でしょうか?少し中身を確認してみましょう。
print(exp) ╒═══════════════════╤═══════════════════════════════════════════════════════════════════════════════════════════════════════════════════════╕ │ keys │ values │ ╞═══════════════════╪═══════════════════════════════════════════════════════════════════════════════════════════════════════════════════════╡ │ seed │ None │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ output_dir │ './YOLOX_outputs' │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ print_interval │ 10 │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ eval_interval │ 10 │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ dataset │ None │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ num_classes │ 80 │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ depth │ 0.33 │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ width │ 0.25 │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ act │ 'silu' │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ data_num_workers │ 4 │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ input_size │ (416, 416) │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ multiscale_range │ 5 │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ data_dir │ None │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ train_ann │ 'instances_train2017.json' │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ val_ann │ 'instances_val2017.json' │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ test_ann │ 'instances_test2017.json' │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ mosaic_prob │ 0.5 │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ mixup_prob │ 1.0 │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ hsv_prob │ 1.0 │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ flip_prob │ 0.5 │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ degrees │ 10.0 │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ translate │ 0.1 │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ mosaic_scale │ (0.5, 1.5) │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ enable_mixup │ False │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ mixup_scale │ (0.5, 1.5) │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ shear │ 2.0 │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ warmup_epochs │ 5 │ ├───────────────────┼──────────────────��────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ max_epoch │ 300 │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ warmup_lr │ 0 │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ min_lr_ratio │ 0.05 │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ basic_lr_per_img │ 0.00015625 │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ scheduler │ 'yoloxwarmcos' │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ no_aug_epochs │ 15 │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ ema │ True │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ weight_decay │ 0.0005 │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ momentum │ 0.9 │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ save_history_ckpt │ True │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ exp_name │ 'yolox_nano' │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ test_size │ (416, 416) │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ test_conf │ 0.01 │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ nmsthre │ 0.65 │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ random_size │ (10, 20) │ ├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤ │ model │ YOLOX( │ │ │ (backbone): YOLOPAFPN( │ │ │ (backbone): CSPDarknet( │ │ │ (stem): Focus( │ │ │ (conv): BaseConv( │ │ │ (conv): Conv2d(12, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) │ │ ╘═══════════════════╧═══════════════════════════════════════════════════════════════════════════════════════════════════════════════════════╛
どうやらモデルの設定などが格納されているようです。
引き続きmain()
の中を確認しましょう。
model = exp.get_model() model.eval() ckpt = torch.load(ckpt_file, map_location="cpu") model.load_state_dict(ckpt["model"])
お、これはモデルを呼び出していそうです。メモしておきましょう。
predictor = Predictor( model, exp, COCO_CLASSES, trt_file, decoder, args.device, args.fp16, args.legacy, )
モデルと実験ファイルを利用してPredictor
と呼ばれるものを準備しています。この中身をもう少し調査しましょう。
def inference(self, img): img_info = {"id": 0} if isinstance(img, str): img_info["file_name"] = os.path.basename(img) img = cv2.imread(img) else: img_info["file_name"] = None height, width = img.shape[:2] img_info["height"] = height img_info["width"] = width img_info["raw_img"] = img ratio = min(self.test_size[0] / img.shape[0], self.test_size[1] / img.shape[1]) img_info["ratio"] = ratio img, _ = self.preproc(img, None, self.test_size) img = torch.from_numpy(img).unsqueeze(0) img = img.float() if self.device == "gpu": img = img.cuda() if self.fp16: img = img.half() # to FP16 with torch.no_grad(): t0 = time.time() outputs = self.model(img) if self.decoder is not None: outputs = self.decoder(outputs, dtype=outputs.type()) outputs = postprocess( outputs, self.num_classes, self.confthre, self.nmsthre, class_agnostic=True ) logger.info("Infer time: {:.4f}s".format(time.time() - t0)) return outputs, img_info
Predictor
内に推論関数を見つけました。画像の前処理はself.preproc()
で行われています。
preproc()
自体はValTransform
で定義されているのでこちらも確認しましょう。
img, _ = preproc(img, input_size, self.swap) if self.legacy: img = img[::-1, :, :].copy() img /= 255.0 img -= np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1) img /= np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1) return img, np.zeros((1, 5))
画像をpreprocを利用して変換しているようです。(最新モデルではself.legacy
はFalseになります。)
preproc自体は
def preproc(img, input_size, swap=(2, 0, 1)): if len(img.shape) == 3: padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114 else: padded_img = np.ones(input_size, dtype=np.uint8) * 114 r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) resized_img = cv2.resize( img, (int(img.shape[1] * r), int(img.shape[0] * r)), interpolation=cv2.INTER_LINEAR, ).astype(np.uint8) padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img padded_img = padded_img.transpose(swap) padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) return padded_img, r
で定義されていますね。どうやらリサイズ後にアスペクト比を維持できるよう足りない部分を埋めているようですね。
最後に描画部分を確認します。先ほどのPredictor
内でvisual
関数なるものがありました。これがキーになっているかも・・、データサイエンティストの勘がさえわたる瞬間です。
def visual(self, output, img_info, cls_conf=0.35): ratio = img_info["ratio"] img = img_info["raw_img"] if output is None: return img output = output.cpu() bboxes = output[:, 0:4] # preprocessing: resize bboxes /= ratio cls = output[:, 6] scores = output[:, 4] * output[:, 5] vis_res = vis(img, bboxes, scores, cls, cls_conf, self.cls_names) return vis_res
vis
関数が実際にボックスを描画してそうです。
この関数内では検出結果の生データを扱っていますね。どうやらoutput[:, 0:4]
が検出したboundingBoxを表しており、output[:, 4]*output[:, 5]
でscoreを計算、output[:, 6]
が検出したクラスを表しているようです。
ここまでで必要な関数が全て出そろいました。次はこれらを組み合わせて手作りリアルタイム推論を行いましょう!
Step2:実装する
Step1の調査で処理の流れは以下のようになっていることが分かりました。
ではこれに合わせて実装しましょう!
import cv2 import numpy as np import torch from yolox.data.data_augment import preproc from yolox.data.datasets import COCO_CLASSES from yolox.exp import get_exp from yolox.utils import postprocess, vis exp = get_exp(None, "yolox-nano") model = exp.get_model() model.eval() ckpt = torch.load("./yolox_nano.pth", map_location="cpu") model.load_state_dict(ckpt["model"]) cap = cv2.VideoCapture(0) while True: ret, frame = cap.read() img, r = preproc(frame, (416, 416)) img = torch.from_numpy(img).unsqueeze(0) img = img.float() with torch.no_grad(): outputs = model(img) outputs = postprocess(outputs, exp.num_classes, 0.1, exp.nmsthre, class_agnostic=True ) if outputs[0] is None: cv2.imshow("test", frame) cv2.waitKey(1) continue bboxes = outputs[0][:, 0:4]/r cls = outputs[0][:, 6] scores = outputs[0][:, 4] * outputs[0][:, 5] vis_res = vis(frame, bboxes, scores, cls, 0.1, COCO_CLASSES) cv2.imshow("test", vis_res) cv2.waitKey(1)
必要なものだけを切り出してきたのでかなりすっきりしたコードになりました。では実行してみましょう。
python realtime.py
正しく検出できました。検出結果のデータは以下のコード内に格納されているので、必要に応じてこの値を参照すれば様々なソリューションに組み込めそうです。
まとめ
今回はYOLOXを用いてリアルタイム推論にチャレンジしました。
元からデモコードがある場合もありますが、細かな調整や検出結果をデータとして取得したい場合は今回のように独自にコードを組む必要があります。
そのような場合はデモコードを切り貼りすることで対応しましょう!