神戸のデータ活用塾!KDL Data Blog

KDLが誇るデータ活用のプロフェッショナル達が書き連ねるブログです。

【やってみた】YOLOv5+ByteTrackでオブジェクトトラッキング!

f:id:kdl-di:20211202151138p:plain

こんにちは、株式会社神戸デジタル・ラボ DataIntelligenceチーム(以降DIチーム)の原口です。

今回は、流行のByteTrack*1とYOLOv5*2を組み合わせてオブジェクトトラッキングをやってみました!

実際のトラッキング動画と合わせてご紹介していきます!

f:id:kdl-di:20211203112352g:plain

本稿は、以下の条件に当てはまる方を想定しています。

  • PyTorchが利用できる方
  • 新しい技術に興味がある方
  • オブジェクトトラッキングに興味がある方

オブジェクトトラッキングとは?

近年様々なところで利用される物体検出AI。

物体検出AIで検出された物体を追跡することをオブジェクトトラッキング(物体追跡)と言います。

このオブジェクトトラッキング、どういった場面で使われているのでしょうか?

例えば防犯カメラ。一般的な物体検出では、その瞬間人が何人写っているかは分かりますが、ある人物がどこからどの方向に移動しているかは認識できません。

この問題にオブジェクトトラッキングを利用すると、ある人物がどこからどの方向に移動したのか、どういう経路で移動したのかを追跡できるようになります。

またスポーツでは、テニスの球やバドミントンのシャトルを追跡することができます。

これにより、「自身の打った球がどういったコースを通って相手に渡ったか、その際にどう打ち返してきたか」などをデータとして取得することが可能になります。

オブジェクトトラッキングの流れ

多様な場面で利用できるオブジェクトトラッキング。オブジェクトトラッキングはいったいどのようにして物体を追跡しているのでしょうか?

まずは一般的なオブジェクトトラッキングの流れについて、動画を用いて物体追跡をする場合を例に説明します。

f:id:kdl-di:20211209181959p:plain

オブジェクトトラッキングでは、動画の各フレーム*3に対して二段階の処理を行います。

  • Step1:物体を検出する
  • Step2:検出した物体を追跡する

Step1:物体を検出する

Step1では、物体検出AIが動画の各フレームに存在する物体を検出します。

物体検出AIは物体を検出することができますが、各フレームで見つけた物体の関係性は理解できないため、同一物体を連続的に追跡することができません。 f:id:kdl-di:20211209182216p:plain そこで物体追跡を行うアルゴリズムを利用します!

Step2:検出した物体を追跡する。

物体追跡アルゴリズムは、以下の特徴があります

  • 新しく検出したものにID(識別番号)を割り振る

  • すでにIDが割り振られている物体を再度検出した場合は同じIDを割り振り、追跡する

例えば1フレーム目の検出では、すべての検出結果が新しく検出したものになるので、検出した順番にIDが割り振られていきます。 f:id:kdl-di:20211209184415p:plain

続いて2フレーム目。2フレーム目では物体検出AIの検出結果と1フレーム目の追跡結果を用いてIDの割り振りを行います。

物体追跡アルゴリズムは1フレーム目の検出結果と2フレーム目の検出結果の関係性を理解し、同一物体に同一のIDを割り振り、物体追跡を行います。 f:id:kdl-di:20211209182432p:plain

このような処理を動画全体に行うことで、物体追跡が可能になります!

実装

今回は物体検出AIにYOLOv5、物体追跡アルゴリズムにByteTrackを用いて実装を行います。

実験はすべてGoogle Colaboratory*4上で行っています。 実験時の環境は次のようになっています。

  • Python : 3.7.12
  • PyTorch : 1.9.0
  • Cuda*5 : 11.1

まずは環境構築

まずは物体検出AIと物体追跡アルゴリズムを準備しましょう。GitHub*6のリポジトリからByteTrackをクローンしましょう!

git clone https://github.com/ifzhang/ByteTrack

クローンが出来たら、requirements.txtに従って環境を構築します。

requirements.txtには、GitHubからクローンしてきたプログラムを動かすために必要なモジュールが書かれており、それらを一括でインストールする際に利用します。

順番にインストールしていきましょう。

pip install -qr https://raw.githubusercontent.com/ultralytics/yolov5/master/requirements.txt 
cd ./ByteTrack
pip install -r requirements.txt

インストールが出来ましたら、そのままByteTrackで必要なモジュールをインストールしましょう!

python setup.py develop
pip install cython
pip install 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
pip install cython_bbox

Trackerの準備

ByteTrackを利用するには、実際に物体追跡を行うTrackerを呼び出す必要があります。

まずはTrackerを呼び出す準備をしましょう!

import argparse

def make_parser():
    parser = argparse.ArgumentParser("YOLOX Eval")
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")

    # distributed
    parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
    parser.add_argument("--dist-url",default=None,type=str,help="url used to set up distributed training",)
    parser.add_argument("-b", "--batch-size", type=int, default=64, help="batch size")
    parser.add_argument("-d", "--devices", default=None, type=int, help="device for training")
    parser.add_argument("--local_rank", default=0, type=int, help="local rank for dist training")
    parser.add_argument( "--num_machines", default=1, type=int, help="num of node for training")
    parser.add_argument("--machine_rank", default=0, type=int, help="node rank for multi-node training")
    parser.add_argument("-f","--exp_file",default=None,type=str,help="pls input your expriment description file",)
    
    # tracking args
    parser.add_argument("--track_thresh", type=float, default=0.6, help="tracking confidence threshold")
    parser.add_argument("--track_buffer", type=int, default=30, help="the frames for keep lost tracks")
    parser.add_argument("--match_thresh", type=float, default=0.9, help="matching threshold for tracking")
    parser.add_argument("--min-box-area", type=float, default=100, help='filter out tiny boxes')
    parser.add_argument("--mot20", dest="mot20", default=False, action="store_true", help="test mot20.")
    return parser

前半はByteTrackで利用されていたYOLOXの設定です。今回はYOLOXの代わりにYOLOv5を利用します。

後半はTrackerの設定です。Trackerが物体を追跡する際の閾値や、追跡対象を見失った際に何フレームまで追跡情報を保持するかなどの設定を行うことができます。

ここまで準備が出来ましたら、いよいよByteTrackを呼び出しましょう!

from yolox.tracker.byte_tracker import BYTETracker
args = make_parser().parse_args()
tracker = BYTETracker(args)

これでトラッキングの準備が完了しました!

YOLOv5の準備

YOLOv5の準備はとっても簡単です!

呼び出しは(モジュールimport抜きで)たった一行!

こんなにもすごい技術をたった一行で使えるなんて、なんていい世の中なのだ・・・!

import torch
model = torch.hub.load('ultralytics/yolov5', 'yolov5x') 

実際にトラッキングしてみる!

いよいよ実際にトラッキングをしていきます!

import cv2
import sys
from yolox.utils.visualize import plot_tracking
from yolox.tracking_utils.timer import Timer
import os

save_path = "/path/to/dir"
video_path = "/path/to/video"

os.makedirs(save_path,exist_ok = True) 

# 動画ファイル読込
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
  print("何らかの理由で読み込めません")
fps = cap.get(cv2.CAP_PROP_FPS)

n= 0
timer = Timer()
results = []
frame_id = 0
while True:
    ret,frame_img = cap.read()
    if ret:
        ########YOLOv5による物体検出########
        detection = model(frame_img)
        pred = detection.pred[0][:,:5].detach().cpu()
        ##################################

        ##################ByteTrackによる物体追跡#################
        online_targets = tracker.update(pred,frame_img.shape,frame_img.shape)
        online_tlwhs = []
        online_ids = []
        online_scores = []
        for t in online_targets:
            tlwh = t.tlwh
            tid = t.track_id
            vertical = tlwh[2] / tlwh[3] > 1.6
            if tlwh[2] * tlwh[3] > args.min_box_area and not vertical:
                online_tlwhs.append(tlwh)
                online_ids.append(tid)
                online_scores.append(t.score)
        results.append((frame_id + 1, online_tlwhs, online_ids, online_scores))
        timer.toc()
        online_im = plot_tracking(frame_img, online_tlwhs, online_ids, frame_id=frame_id + 1,fps=1. / timer.average_time)
        #####################################################
        cv2.imwrite('{}/{:05}.jpg'.format(save_path,frame_id),online_im)
    else:
        break
    n += 1
    frame_id += 1
cap.release()

追跡結果の保存先と、利用する動画までのパスは上記コード内の

save_path = "/path/to/dir"
video_path = "/path/to/video"

を変更してください!

上記コードを実行すると、save_path内に追跡結果が保存されます。

動画として保存

最後に追跡結果を動画として保存します!

import glob
import os
from tqdm import tqdm

detect_dir = "/path/to/dir"
save_dir = "/path/to/dir"
img_list = sorted(glob.glob(os.path.join(detect_dir,'*')))

sample = cv2.imread(img_list[0])

img_h , img_w , _ = sample.shape

fourcc = cv2.VideoWriter_fourcc('m','p','4', 'v')
video  = cv2.VideoWriter(f'{save_dir}/ImgVideo.mp4', fourcc, fps, (img_w, img_h))

for img_file in tqdm(img_list):
    img = cv2.imread(img_file)
    video.write(img)

video.release()

detect_dirには検出結果を保存したディレクトリを、save_dirには作成した動画を保存する場所を記述してください。

確認

実際に作成した動画を確認してみましょう!

おおー、主観的にかなりいい精度!前列にいる方々は動画の最初から最後まで同じIDで追跡が行えています。

一方、後ろで動いているオレンジ色の方は、前方で動かれている方と重なってしまった際に追跡ができなくなり、たびたび追跡が途切れてしまっています。

ただ、部分的に隠れた場合はしっかり追跡できているので、完全に隠れるような場合を除くと高精度に追跡が可能なことを確認しました。

そしてこのデバイスと背景・・・・。どこかで見たような・・・?

・・・・・ハッ!これは、2021年11月14日に神戸商工会議所 神戸スポーツ産業懇話会が主催するバーチャルスポーツの可能性を探る「バーチャルスポーツHADO体験会 in KOBE」に参加した際の動画だ!!

ということで、良ければ以下の記事もご覧ください!

「バーチャルスポーツHADO体験会 in KOBE」開催、KDLの有志チームが3位に入賞しました! | Kobe Digital Labo 神戸デジタル・ラボ

まとめ

今回はYOLOv5とByteTrackを利用したオブジェクトトラッキングを実装しました。

オブジェクトトラッキング技術は、重なりがない場合はほぼ正確に追跡ができることを確認しました!

また重なりが存在する場合も、完全に重なるまでは追跡ができていることを確認しました!

これからも、新しい技術・面白い技術をどんどん実装していきたいと思います!

原口俊樹

データインテリジェンスチーム所属
データエンジニアを担当しています。画像認識を得意としており、画像認識・ニューラルネットワーク系の技術記事を発信していきます

*1:バイトトラック:オブジェクトトラッキング手法の一種

*2:ヨロバージョン5:物体検出AIの一種

*3:動画を構成する一枚一枚の画像のこと

*4:略称: Colab。ブラウザから Python を記述、実行できるサービス

*5:クーダ:グラフィック処理ユニット (GPU) 用に、NVIDIA が開発した並列コンピューティング プラットフォームおよびプログラミング モデル

*6:ギットハブ:プログラムのバージョン管理や公開を行うサイト