神戸のデータ活用塾!KDL Data Blog

KDLが誇るデータ活用のプロフェッショナル達が書き連ねるブログです。

【やってみた】YOLOXでリアルタイム推論

株式会社神戸デジタル・ラボ DataIntelligenceチームの原口です。

今回はYOLOXを用いてリアルタイム推論にチャレンジします。

YOLOXとは

https://github.com/Megvii-BaseDetection/YOLOX/raw/main/assets/logo.png

YOLOXとは2021年に発表されたYOLO系の物体検出モデルです。

YOLOv5・v8と比較するとモデルの学習準備が大変なところもありますが、ライセンスもApache-2.0 licenseと商用利用しやすい点が今でも人気の理由の1つです。

環境構築

YOLOXを使うための環境構築を実施します。まずはGitHubからYOLOXをクローンしましょう。

git clone https://github.com/Megvii-BaseDetection/YOLOX

クローンができれば、作成された「YOLOX」フォルダに移動します。

インストール実行前にリアルタイム推論に余分なライブラリをrequirements.txtからコメントアウトしましょう。(ONNX関連のライブラリはインストール時にエラーが発生することが多いです。今回はONNXを利用しないのてコメントアウトしました。pycocotoolsは環境によってこのままではインストールできない場合があるのでコメントアウトしました)

# TODO: Update with exact module version
numpy
torch>=1.7
opencv_python
loguru
tqdm
torchvision
thop
ninja
tabulate
psutil
tensorboard


# 下3つをコメントアウト
# verified versions
# pycocotools corresponds to https://github.com/ppwwyyxx/cocoapi
# pycocotools>=2.0.2
# onnx>=1.13.0
# onnx-simplifier==0.4.10

上記対応ができればインストールを実行しましょう。

python setup.py develop

Finished processing dependencies for yolox==0.3.0と表示されればインストールは完了です。

最後にCOCOデータセットのラベルを利用するためにpycocotoolsをインストールしましょう。

python setup.py develop内ではインストールが正しく実行されない場合があるため、別途インストールします。

pip install pycocotools

学習済みモデルを準備

続いて検出に利用する学習済みモデルを準備します。今回は最も軽量でCPUで高速に動作可能なYOLOX-Nanoを利用します。

学習済みモデルのダウンロードはYOLOXのGitHub、README内にあるモデル紹介の表(こちら)の右側にある「github」リンクをクリックすることでダウンロードできます。

学習済みモデル一覧

デモを確認

インストール・学習済みモデルの準備が完了したのでデモコードを確認してみましょう・・・。

webカメラの表記。専用コードを作らなくても良い???

なんとデモコード内に「webcam」の表記が。独自にコードを書かなくてもリアルタイム推論できそうです。

まずはデモコードのリアルタイム推論を試してみましょう。

python .\tools\demo.py webcam -n yolox-nano -c .\yolox_nano.pth

ここで

  • webcam:デモモードを「image・video・webcam」の3タイプから切り替えられます。
  • -n:モデル名を選択します。
  • -c:学習済みモデルのパスを指定します。

リアルタイム推論・・・。動いた・・・!

おお!1コマンドでリアルタイム推論がこんなに簡単に出来てしまう!

本記事は以上・・・、と行きたいところですがこのデモでは検出結果の座標データが得られないため、「検出した物体を切り出す・検出結果を用いて新たにアプリケーションを開発する」ということが出来ません

検出デモを動かすという意味では「demo.py」は非常に優秀ですが、ソリューションの一部で使うには少し物足りない感じがします。

よって座標データを取れるようにYOLOXのリアルタイム推論を改造しましょう!(ついでにコードの軽量化もしましょう)

コンソールでCtrl + Cでデモを止めてから次の章に移りましょう!

リアルタイム推論の実装

Step1:demo.pyを眺める

demo.pyでリアルタイム推論が実行されていたので、このコードの中から必要なものだけを切り出せば最小限のリアルタイム推論コードを構築出来そうです。

そのためにもまずはdemo.pyを眺めましょう。

if __name__ == "__main__":
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)

    main(exp, args)

この部分はデモコードのが実行された際に最初に実行される部分です。

make_parser().parse_args()は引数を管理しています。

get_expは何でしょうか?少し中身を確認してみましょう。

print(exp)
╒═══════════════════╤═══════════════════════════════════════════════════════════════════════════════════════════════════════════════════════╕
│ keys              │ values                                                                                                                │
╞═══════════════════╪═══════════════════════════════════════════════════════════════════════════════════════════════════════════════════════╡
│ seed              │ None                                                                                                                  │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ output_dir        │ './YOLOX_outputs'                                                                                                     │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ print_interval    │ 10                                                                                                                    │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ eval_interval     │ 10                                                                                                                    │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ dataset           │ None                                                                                                                  │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ num_classes       │ 80                                                                                                                    │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ depth             │ 0.33                                                                                                                  │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ width             │ 0.25                                                                                                                  │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ act               │ 'silu'                                                                                                                │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ data_num_workers  │ 4                                                                                                                     │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ input_size        │ (416, 416)                                                                                                            │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ multiscale_range  │ 5                                                                                                                     │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ data_dir          │ None                                                                                                                  │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ train_ann         │ 'instances_train2017.json'                                                                                            │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ val_ann           │ 'instances_val2017.json'                                                                                              │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ test_ann          │ 'instances_test2017.json'                                                                                             │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ mosaic_prob       │ 0.5                                                                                                                   │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ mixup_prob        │ 1.0                                                                                                                   │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ hsv_prob          │ 1.0                                                                                                                   │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ flip_prob         │ 0.5                                                                                                                   │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ degrees           │ 10.0                                                                                                                  │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ translate         │ 0.1                                                                                                                   │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ mosaic_scale      │ (0.5, 1.5)                                                                                                            │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ enable_mixup      │ False                                                                                                                 │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ mixup_scale       │ (0.5, 1.5)                                                                                                            │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ shear             │ 2.0                                                                                                                   │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ warmup_epochs     │ 5                                                                                                                     │
├───────────────────┼──────────────────��────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ max_epoch         │ 300                                                                                                                   │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ warmup_lr         │ 0                                                                                                                     │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ min_lr_ratio      │ 0.05                                                                                                                  │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ basic_lr_per_img  │ 0.00015625                                                                                                            │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ scheduler         │ 'yoloxwarmcos'                                                                                                        │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ no_aug_epochs     │ 15                                                                                                                    │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ ema               │ True                                                                                                                  │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ weight_decay      │ 0.0005                                                                                                                │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ momentum          │ 0.9                                                                                                                   │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ save_history_ckpt │ True                                                                                                                  │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ exp_name          │ 'yolox_nano'                                                                                                          │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ test_size         │ (416, 416)                                                                                                            │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ test_conf         │ 0.01                                                                                                                  │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ nmsthre           │ 0.65                                                                                                                  │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ random_size       │ (10, 20)                                                                                                              │
├───────────────────┼───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ model             │ YOLOX(                                                                                                                │
│                   │   (backbone): YOLOPAFPN(                                                                                              │
│                   │     (backbone): CSPDarknet(                                                                                           │
│                   │       (stem): Focus(                                                                                                  │
│                   │         (conv): BaseConv(                                                                                             │
│                   │           (conv): Conv2d(12, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)                       │                                                                                                            │
╘═══════════════════╧═══════════════════════════════════════════════════════════════════════════════════════════════════════════════════════╛

どうやらモデルの設定などが格納されているようです。

引き続きmain()の中を確認しましょう。

model = exp.get_model()
model.eval()
ckpt = torch.load(ckpt_file, map_location="cpu")
model.load_state_dict(ckpt["model"])

お、これはモデルを呼び出していそうです。メモしておきましょう。

predictor = Predictor(
        model, exp, COCO_CLASSES, trt_file, decoder,
        args.device, args.fp16, args.legacy,
    )

モデルと実験ファイルを利用してPredictorと呼ばれるものを準備しています。この中身をもう少し調査しましょう。

def inference(self, img):
    img_info = {"id": 0}
    if isinstance(img, str):
        img_info["file_name"] = os.path.basename(img)
        img = cv2.imread(img)
    else:
        img_info["file_name"] = None

    height, width = img.shape[:2]
    img_info["height"] = height
    img_info["width"] = width
    img_info["raw_img"] = img

    ratio = min(self.test_size[0] / img.shape[0], self.test_size[1] / img.shape[1])
    img_info["ratio"] = ratio

    img, _ = self.preproc(img, None, self.test_size)
    img = torch.from_numpy(img).unsqueeze(0)
    img = img.float()
    if self.device == "gpu":
        img = img.cuda()
        if self.fp16:
            img = img.half()  # to FP16

    with torch.no_grad():
        t0 = time.time()
        outputs = self.model(img)
        if self.decoder is not None:
            outputs = self.decoder(outputs, dtype=outputs.type())
        outputs = postprocess(
            outputs, self.num_classes, self.confthre,
            self.nmsthre, class_agnostic=True
        )
        logger.info("Infer time: {:.4f}s".format(time.time() - t0))
    return outputs, img_info

Predictor内に推論関数を見つけました。画像の前処理はself.preproc()で行われています。

preproc()自体はValTransformで定義されているのでこちらも確認しましょう。

img, _ = preproc(img, input_size, self.swap)
if self.legacy:
    img = img[::-1, :, :].copy()
    img /= 255.0
    img -= np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
    img /= np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
return img, np.zeros((1, 5))

画像をpreprocを利用して変換しているようです。(最新モデルではself.legacyはFalseになります。)

preproc自体は

def preproc(img, input_size, swap=(2, 0, 1)):
    if len(img.shape) == 3:
        padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
    else:
        padded_img = np.ones(input_size, dtype=np.uint8) * 114

    r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
    resized_img = cv2.resize(
        img,
        (int(img.shape[1] * r), int(img.shape[0] * r)),
        interpolation=cv2.INTER_LINEAR,
    ).astype(np.uint8)
    padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img

    padded_img = padded_img.transpose(swap)
    padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
    return padded_img, r

で定義されていますね。どうやらリサイズ後にアスペクト比を維持できるよう足りない部分を埋めているようですね。

最後に描画部分を確認します。先ほどのPredictor内でvisual関数なるものがありました。これがキーになっているかも・・、データサイエンティストの勘がさえわたる瞬間です。

def visual(self, output, img_info, cls_conf=0.35):
    ratio = img_info["ratio"]
    img = img_info["raw_img"]
    if output is None:
        return img
    output = output.cpu()

    bboxes = output[:, 0:4]

    # preprocessing: resize
    bboxes /= ratio

    cls = output[:, 6]
    scores = output[:, 4] * output[:, 5]

    vis_res = vis(img, bboxes, scores, cls, cls_conf, self.cls_names)
    return vis_res

vis関数が実際にボックスを描画してそうです。

この関数内では検出結果の生データを扱っていますね。どうやらoutput[:, 0:4]が検出したboundingBoxを表しており、output[:, 4]*output[:, 5]でscoreを計算、output[:, 6]が検出したクラスを表しているようです。

ここまでで必要な関数が全て出そろいました。次はこれらを組み合わせて手作りリアルタイム推論を行いましょう!

Step2:実装する

Step1の調査で処理の流れは以下のようになっていることが分かりました。

処理の流れ

ではこれに合わせて実装しましょう!

import cv2
import numpy as np
import torch

from yolox.data.data_augment import preproc
from yolox.data.datasets import COCO_CLASSES
from yolox.exp import get_exp
from yolox.utils import postprocess, vis

exp = get_exp(None, "yolox-nano")
model = exp.get_model()
model.eval()
ckpt = torch.load("./yolox_nano.pth", map_location="cpu")
model.load_state_dict(ckpt["model"])

cap = cv2.VideoCapture(0)
while True:
    ret, frame = cap.read()
    img, r = preproc(frame, (416, 416))
    img = torch.from_numpy(img).unsqueeze(0)
    img = img.float()
    with torch.no_grad():
        outputs = model(img)
        outputs = postprocess(outputs,
                              exp.num_classes,
                              0.1,
                              exp.nmsthre,
                              class_agnostic=True
                              )
        if outputs[0] is None:
            cv2.imshow("test", frame)
            cv2.waitKey(1)
            continue
        bboxes = outputs[0][:, 0:4]/r
        cls = outputs[0][:, 6]
        scores = outputs[0][:, 4] * outputs[0][:, 5]
        vis_res = vis(frame, bboxes, scores, cls, 0.1, COCO_CLASSES)
        cv2.imshow("test", vis_res)
        cv2.waitKey(1)

必要なものだけを切り出してきたのでかなりすっきりしたコードになりました。では実行してみましょう。

python realtime.py

独自コードによる検出結果

正しく検出できました。検出結果のデータは以下のコード内に格納されているので、必要に応じてこの値を参照すれば様々なソリューションに組み込めそうです。

検出結果のデータ

まとめ

今回はYOLOXを用いてリアルタイム推論にチャレンジしました。

元からデモコードがある場合もありますが、細かな調整や検出結果をデータとして取得したい場合は今回のように独自にコードを組む必要があります。

そのような場合はデモコードを切り貼りすることで対応しましょう!