神戸のデータ活用塾!KDL Data Blog

KDLが誇るデータ活用のプロフェッショナル達が書き連ねるブログです。

SAM(Segment Anything Model)の解説と実践(後編)

データインテリジェンスチームの黒臺(くろだい)です。 様々なものを対象としてセグメンテーションができるSegment Anything Model(以降SAMと表記)を本記事ではGoogle Colaboratoryで実装をします。SAMモデルの概要について、詳しい説明はこちらの前編記事を参照して下さい。
kdl-di.hatenablog.com


目次

前編のまとめ


前編の記事では、SAMのアルゴリズムはViT、CLIP、アテンション機構をベースとした応用がされていることを紹介しました。さらに、SAMモデルの学習用データに動物や風景などの画像が使用されていることも紹介しました。試しに町や建物などの画像を読み込ませると、下記の結果になります。

¹左図:元画像,右図:セグメンテーション後

弊社の公式マスコットキャラクターの「デジごん」。学習用データにおそらく含まれていないであろうデジごんをセグメンテーションすると、どうでしょうか。結果は下記のようになりました。
左図:元画像,右図:セグメンテーション後

セグメンテーションができていますね。
本記事では、3種類のセグメンテーション方法についてPythonを使って実装をしながら紹介をします。まず最初に、自動でセグメンテーションを行う方法について説明します。次に、プロンプトを使ってセグメンテーションを行う方法を2種類、紹介します。実装では、Google Colaboratoryを利用します。

実践編①自動でセグメンテーションを行う場合


今回は著者が撮影した画像を使います。1つの画像内に高周波成分を多く含んでいる、曲線・直線で構成された複雑な対象物が映っていることから、試験用の画像として選択しました。〈なお、1200×832ピクセルのサイズで試しました。〉

使用画像

それでは、画像内の様々なものをセグメンテーションできるか実践してみましょう。
Google Colaboratory上で、画像データを読み込みします。

import cv2

# 元の画像を読み込む
original_image = cv2.imread(r'/content/kobe.jpg')
 


次はSAMを利用できるようにします。

#SAMと依存関係のツールをダウンロードする
!pip install -q 'git+https://github.com/facebookresearch/segment-anything.git'

!pip install -q jupyter_bbox_widget roboflow dataclasses-json supervision

#SAMの重みづけデータをダウンロードする
!mkdir -p {HOME}/weights

!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P {HOME}/weights

import os
CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth")
print(CHECKPOINT_PATH, "; exist:", os.path.isfile(CHECKPOINT_PATH))
import torch
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
MODEL_TYPE = "vit_h"
 
#SAMモデルで利用したいモデルをインポートする
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
 
sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)


以上でSAMモデルをGoogle Colaboratory上で利用できるようになりました。次に画像を指定します。

import cv2
import supervision as sv

image_bgr = cv2.imread(IMAGE_NAME)
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

sam_result = mask_generator.generate(image_rgb)


マスク画像の結果をsam_resultの変数に格納しました。次は画像の出力方法を指定し、結果を確認します。うまくできるでしょうか?

mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
 
detections = sv.Detections.from_sam(sam_result=sam_result)
 
annotated_image = mask_annotator.annotate(scene=image_bgr.copy(), detections=detections)
 
sv.plot_images_grid(
    images=[image_bgr, annotated_image],
    grid_size=(1, 2),
    titles=['source image', 'segmented image']
)


今回利用した画像の出力結果は以下の通りになりました。左図は入力画像、右図は出力結果です。

セグメンテーション前・後の画像

右図の出力結果を見たところ、ほとんどはうまくセグメンテーションができていましたが、左側の画面奥側にある建物だけはセグメンテーションすることができませんでした。
セグメンテーション前・後の画像 白破線内マスク画像無し


上図の白点線で囲った部分は、セグメンテーションができていないエリアを指しています。これを解消するためには、特定の物体に注力してセグメンテーションを実行する必要があります。プロンプトを使う方法が有効です。SAMモデルのプロンプトを使ったセグメンテーションは2種類あります。「ポイントプロンプト」と「ボックスプロンプト」です。ポイントプロンプトの方が使いやすいため、ポイントプロンプトを中心に紹介します。

実践編②ポイントプロンプトを使いセグメンテーションを行う場合

ポイントプロンプトは、画像内の特定の位置を示し、SAMにセグメンテーションを指示する方法です。画像上に1点もしくは複数の点を指定すると、SAMはそのポイントを基に対象物のセグメンテーションマスク画像を生成します。正確な境界が必要な場合に有効です。手動でポイントプロンプト用の座標位置を決めても良いですが、今回はポイントプロプト用に中心点を自動で計算してみましょう。画像の対象物の中心点を計算する方法は多数考えられます。今回は画像の輪郭を抽出し、ボロノイ図を作成して中心点を計算しました。ボロノイ図とは、複数の母点に対して、空間を領域に分ける方法です。それぞれの領域内の点は、その領域内の母点に最も近い点になります。

#計算用のツールをダウンロードする
from scipy.spatial import Voronoi, voronoi_plot_2d
import matplotlib.pyplot as plt

# グレースケールに変換
gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

# エッジ検出
edges = cv2.Canny(gray_image, 100, 200)

# 輪郭を検出
contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

# 中心点をリストに格納
points = []
for contour in contours:
    M = cv2.moments(contour)
    if M['m00'] != 0:
        cx = int(M['m10'] / M['m00'])
        cy = int(M['m01'] / M['m00'])
        points.append([cx, cy])
points = np.array(points)

# ボロノイ図を生成
vor = Voronoi(points)

# ボロノイ領域の中心に点を打つ
for region in vor.regions:
    if not -1 in region and len(region) > 0:
        polygon = [vor.vertices[i] for i in region]
        polygon = np.array(polygon)
        centroid = np.mean(polygon, axis=0)
        cv2.circle(image, (int(centroid[0]), int(centroid[1])), 5, (0, 255, 0), -1)

# 結果を表示
cv2_imshow(image)
cv2.waitKey(0)
cv2.destroyAllWindows()


中心点を計算した結果
画像内の緑色の点が、ボロノイ領域の中心点です。対象画像の枚数が多い場合は、OpenCVを使うことで自動で中心点を計算でき効率的です。中心点の結果を、ポイントプロンプトとして入力する方法は下記の通りです。

# 中心点の結果
centerpoints = [] #中心点の計算結果をここに格納する
# 中心点をnumpyの配列に変換する
centerpoints = np.array(centerpoints)

# SAMモデルの初期化
sam = sam_model_registry[model_type](checkpoint=checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)

# SAMに画像をセット
predictor.set_image(image_rgb)

# 各中心点をSAMモデルに渡す
for i, centerpoints in enumerate(centerpoints[:3]):  # 最初の3点のみ処理
    input_point = np.array([centroid])
    input_label = np.array([1])  # 中心点が指している対象物をセグメンテーションするように指定
    
    # セグメンテーションマスク画像を生成
    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,
    )
   # 最高スコアのマスクを表示
    best_mask_idx = np.argmax(scores)

    plt.figure(figsize=(10, 8))
    plt.imshow(image_rgb)
    show_mask(masks[best_mask_idx], plt.gca())
    plt.axis('on')
    plt.show()


ここで作成したマスク画像を確認しましょう。

マスク画像の結果画像

画像内の緑星マークは、プロンプトで使った中心点の位置を表します。
ポイントプロンプト結果画像

ポイントプロンプトを使うことで、実践編①の方法でセグメンテーションができていなかったエリアに、セグメンテーションをすることが可能になります。

実践編③ボックスプロンプトを使いセグメンテーションを行う場合

ボックスプロンプトは補足としてご紹介します。ボックスプロンプトは、対象物を囲む矩形を[ x1, y1, x2, y2]の形式でSAMにセグメンテーションを指示する方法です。このプロンプトは、近接しているオブジェクトや部分的に隠れているオブジェクトを対象とするときに有用です。今回の画像では、ビル・街灯・屋根などをそれぞれ長方形のボックスで位置を指定するイメージです。例えばYoloモデルなどの物体検出タスクと、ボックスプロンプトを併用することで、より精度が高くセグメンテーションを実行できます。ボックスプロンプトをモデルに指定する方法は下記の通りです。

# ボックスの座標を定義
box = [x1, y1, x2, y2]

# モデルにボックスを入力
segmentation_result = model.predict(image, box=box)

# 結果を表示
display(segmentation_result)


ボックスプロンプトと画像処理を組み合わせて画像内の人工物のみセグメンテーションを行う、人のみセグメンテーションを行う、といった応用方法が考えられます。

まとめ

今回紹介をしたSAMモデルは、一般的な画像を対象としてセグメンテーションができます。本記事では、Google Colaboratoryを使い画像のSAMを利用したセグメンテーションをしました。さらにプロンプトを使いセグメンテーションを実装しました。プロンプトを使うセグメンテーションでは、特定の物体に注力してセグメンテーションを実行したい時に有用です。背景を効率よく切り取る場合にも有用です。他には、セグメンテーション用の学習データを作成する場合など、他の画像認識モデルを作成する準備でも応用可能です。
2024年7月に、Segment Anything Model 2(SAM2)が発表され、さらに動画内の物体検出をリアルタイムにできるようになりました。機会があれば別の記事でご紹介できればと思います。

最後に

画像以外のデータ活用についても、課題の発見から解決方法の提案まで幅広くご相談を承っております。ご興味のある方はぜひお問い合わせください。

参考資料

1.segment-anything Githubリポジトリ https://github.com/facebookresearch/segment-anything/tree/main/demo
2.how-to-use-segment-anything-model-sam https://blog.roboflow.com/how-to-use-segment-anything-model-sam/

出典

¹SIPI Image Database - Misc https://sipi.usc.edu/database/database.php?volume=misc/


profimg

黒臺万悠

データインテリジェンスチーム所属
元医療従事者、転職後組み込み系の開発を経験し、大学院卒業後KDLへ入社しました。データ分析のトピック全般をブログで紹介予定です。