FlaskとONNXモデルを活用したリアルタイム物体検出Webアプリケーションの開発

本記事では、ディープラーニングモデル(例: RT-DETR)とPythonのWebフレームワークであるFlaskを組み合わせて、簡易的な物体検出Webアプリケーションを構築する方法を解説します。ユーザーが画像をアップロードし、検出の信頼度を指定すると、アプリケーションがONNX形式のモデルを使用して物体を検出し、その結果を表示します。

物体検出モデルのONNX形式への変換と推論

モデルを本番環境にデプロイする際、ONNX(Open Neural Network Exchange)形式は、モデルのフレームワーク非依存性と推論効率の向上に貢献します。PyTorchなどの学習フレームワークで作成されたモデルをONNX形式に変換し、onnxruntimeを用いて推論を実行する手順を説明します。

PyTorchモデルのONNXエクスポート

通常、PyTorchで学習されたモデルは.pthファイルとして保存されます。これをONNX形式に変換することで、異なるプラットフォームやデバイスでの高速な推論が可能になります。以下は、概念的な物体検出モデルをONNX形式にエクスポートする例です。実際のRT-DETRモデルも同様の手順で変換できます。

import torch
import torch.nn as nn
import onnx

class SimpleDetectionModel(nn.Module):
    def __init__(self, num_classes=80): # 例: COCOデータセットのクラス数
        super().__init__()
        # 簡略化されたバックボーンとヘッド
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        # ダミーの検出ヘッド出力 (ONNXエクスポートの構造を示すため)
        # 実際のモデルでは、推論結果から直接これらの形式に変換されます
        self.dummy_conv_reg = nn.Conv2d(64, 4, kernel_size=1) # バウンディングボックス座標
        self.dummy_conv_cls = nn.Conv2d(64, num_classes, kernel_size=1) # クラススコア

    def forward(self, input_images, original_image_sizes):
        # input_images: (N, C, H, W)
        # original_image_sizes: (N, 2) 例: [[height, width]]
        features = self.feature_extractor(input_images)
        
        # ONNXエクスポートのためのダミー出力
        # 実際の検出モデルでは、これらの値はモデルの推論ロジックから生成されます
        batch_size = input_images.shape[0]
        num_detections = 100 # 例として1画像あたり最大100個の検出
        
        dummy_detection_boxes = torch.rand(batch_size, num_detections, 4) # [x1, y1, x2, y2]
        dummy_detection_scores = torch.rand(batch_size, num_detections)
        dummy_detection_labels = torch.randint(0, self.dummy_conv_cls.out_channels, (batch_size, num_detections))

        return dummy_detection_labels, dummy_detection_boxes, dummy_detection_scores

def export_pytorch_to_onnx(model_instance, output_path="detection_model.onnx", img_input_size=(640, 640)):
    dummy_input_images = torch.randn(1, 3, *img_input_size)
    dummy_input_sizes = torch.tensor([[img_input_size[0], img_input_size[1]]], dtype=torch.float32)

    torch.onnx.export(
        model_instance,
        (dummy_input_images, dummy_input_sizes),
        output_path,
        input_names=['input_images', 'original_image_sizes'],
        output_names=['detection_labels', 'detection_boxes', 'detection_scores'],
        dynamic_axes={
            'input_images': {0: 'batch_size'},
            'original_image_sizes': {0: 'batch_size'},
            'detection_labels': {0: 'batch_size'},
            'detection_boxes': {0: 'batch_size'},
            'detection_scores': {0: 'batch_size'}
        },
        opset_version=16,
        verbose=False
    )
    print(f"モデルをONNX形式で '{output_path}' にエクスポートしました。")

    # オプション: ONNXモデルの整合性チェック
    onnx_model = onnx.load(output_path)
    onnx.checker.check_model(onnx_model)
    print('ONNXモデルのチェックが成功しました。')

# (注: 上記コードは、実際のモデルの構造に合わせて調整する必要があります。)

onnxruntimeを使用した前向き推論

ONNXモデルの推論には、Microsoftが提供するonnxruntimeが推奨されます。これは、高性能なクロスプラットフォーム推論エンジンです。

import onnxruntime as ort
from PIL import Image, ImageDraw, ImageFont
from torchvision.transforms import ToTensor
import numpy as np
import time
import os

# 推論処理をカプセル化する関数
def execute_object_detection(image_path: str, onnx_model_path: str, conf_threshold: float = 0.6):
    target_img_dim = (640, 640) # モデルが期待する入力画像サイズ
    detection_classes = ['car', 'truck', 'bus', 'person', 'bicycle', 'motorbike', 'traffic light', 'stop sign', 'dog', 'cat'] # 例示用のクラスリスト

    # 画像の前処理
    original_img = Image.open(image_path).convert('RGB')
    resized_img = original_img.resize(target_img_dim)
    input_tensor = ToTensor()(resized_img).unsqueeze(0).numpy() # バッチ次元を追加

    # 元の画像サイズ (モデルによっては必要)
    image_size_tensor = np.array([[target_img_dim[0], target_img_dim[1]]], dtype=np.float32)

    # ONNX Runtimeセッションの初期化
    session_options = ort.SessionOptions()
    try:
        if ort.get_device() == 'GPU' and 'CUDAExecutionProvider' in ort.get_available_providers():
            print("GPUでONNX推論を実行します。")
            inference_session = ort.InferenceSession(onnx_model_path, session_options, providers=['CUDAExecutionProvider'])
        else:
            print("CPUでONNX推論を実行します。")
            inference_session = ort.InferenceSession(onnx_model_path, session_options, providers=['CPUExecutionProvider'])
    except Exception as e:
        print(f"ONNX Runtimeセッションの初期化中にエラーが発生しました: {e}")
        print("CPUExecutionProviderで再試行します。")
        inference_session = ort.InferenceSession(onnx_model_path, session_options, providers=['CPUExecutionProvider'])


    input_feed_dict = {
        inference_session.get_inputs()[0].name: input_tensor,
        inference_session.get_inputs()[1].name: image_size_tensor
    }
    output_names = [out.name for out in inference_session.get_outputs()]

    # 推論実行
    start_time = time.time()
    outputs = inference_session.run(output_names=output_names, input_feed=input_feed_dict)
    end_time = time.time()
    inference_duration_ms = (end_time - start_time) * 1000
    print(f"推論時間: {inference_duration_ms:.2f} ms")

    # 推論結果の解析 (出力順序: detection_labels, detection_boxes, detection_scores)
    labels, boxes, scores = outputs

    result_img = original_img.copy()
    draw_context = ImageDraw.Draw(result_img)

    # 検出ボックスを元の画像サイズにスケールするための係数
    original_width, original_height = original_img.size
    scale_x = original_width / target_img_dim[0]
    scale_y = original_height / target_img_dim[1]

    detected_object_count = 0
    # バッチ処理を想定 (通常はバッチサイズ1)
    for i in range(len(labels)):
        current_scores = scores[i]
        current_labels = labels[i]
        current_boxes = boxes[i]

        # 信頼度閾値でフィルタリング
        high_conf_indices = np.where(current_scores > conf_threshold)[0]
        filtered_labels = current_labels[high_conf_indices]
        filtered_boxes = current_boxes[high_conf_indices]
        filtered_scores = current_scores[high_conf_indices]

        detected_object_count += len(filtered_labels)

        for label_idx, box_coords, score_val in zip(filtered_labels, filtered_boxes, filtered_scores):
            x1, y1, x2, y2 = box_coords
            # バウンディングボックス座標を元の画像サイズに変換
            scaled_x1, scaled_y1, scaled_x2, scaled_y2 = x1 * scale_x, y1 * scale_y, x2 * scale_x, y2 * scale_y

            class_name = detection_classes[label_idx] if label_idx < len(detection_classes) else f"Unknown({label_idx})"
            display_text = f"{class_name}: {score_val:.2f}"

            # 検出ボックスの描画
            draw_context.rectangle([(scaled_x1, scaled_y1), (scaled_x2, scaled_y2)], outline='red', width=2)
            
            # テキストの描画 (背景付きで視認性向上)
            try:
                # PIL < 9.0 (draw.textsizeは非推奨)
                text_width, text_height = draw_context.textsize(display_text, font=ImageFont.load_default())
            except AttributeError:
                # PIL >= 9.0
                text_bbox = draw_context.textbbox((0,0), display_text, font=ImageFont.load_default())
                text_width = text_bbox[2] - text_bbox[0]
                text_height = text_bbox[3] - text_bbox[1]

            text_bg_coords = [scaled_x1, scaled_y1 - text_height - 5, scaled_x1 + text_width + 5, scaled_y1]
            draw_context.rectangle(text_bg_coords, fill='red')
            draw_context.text((scaled_x1 + 2, scaled_y1 - text_height - 3), display_text, fill='white', font=ImageFont.load_default())

    fps = 1000 / inference_duration_ms if inference_duration_ms > 0 else 0
    return result_img, detected_object_count, fps

# (このスクリプトは 'detection_module.py' として保存され、Flaskアプリから呼び出されます)

FlaskによるWebサービスの構築

次に、Flaskを使用して、画像アップロード、推論実行、結果表示を行うWebサーバーを構築します。

# app.py
from flask import Flask, request, render_template, redirect, url_for
import os
import time
from werkzeug.utils import secure_filename
from PIL import Image

# 検出モジュールをインポート (上記のexecute_object_detection関数を含むファイル)
import detection_module

app = Flask(__name__)
app.config['UPLOAD_FOLDER'] = 'static/uploads'
app.config['RESULT_FOLDER'] = 'static/results'
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 最大ファイルサイズ16MB

# アップロード・結果保存用ディレクトリの作成
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(app.config['RESULT_FOLDER'], exist_ok=True)

ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'}

def allowed_file(filename):
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS

@app.route('/', methods=['GET'])
def home_page():
    return render_template('upload_interface.html')

@app.route('/detect', methods=['POST'])
def handle_detection_request():
    if 'image_file' not in request.files:
        return redirect(url_for('home_page')) # ファイルが選択されていない場合はリダイレクト

    uploaded_file = request.files['image_file']
    if uploaded_file.filename == '':
        return redirect(url_for('home_page')) # ファイル名がない場合はリダイレクト

    if uploaded_file and allowed_file(uploaded_file.filename):
        # 安全なファイル名を生成し、アップロードパスを決定
        filename = secure_filename(uploaded_file.filename)
        timestamp_filename = f"{int(time.time())}_{filename}"
        upload_path = os.path.join(app.config['UPLOAD_FOLDER'], timestamp_filename)
        uploaded_file.save(upload_path)

        # 信頼度閾値を取得
        confidence_str = request.form.get("confidence_threshold", "0.6")
        confidence_value = float(confidence_str)

        # ONNXモデルパス (app.pyと同じディレクトリにあると仮定)
        onnx_model_file = "detection_model.onnx"
        if not os.path.exists(onnx_model_file):
            print("ONNXモデルが見つかりません。デモンストレーション用にダミーモデルを作成します。")
            dummy_model = detection_module.SimpleDetectionModel()
            detection_module.export_pytorch_to_onnx(dummy_model, onnx_model_file)

        # 物体検出の実行
        processed_image_pil, detected_objects_count, detection_fps = detection_module.execute_object_detection(
            upload_path, onnx_model_file, confidence_value
        )

        # 結果画像を保存
        result_image_filename = f"detected_{timestamp_filename}"
        result_image_save_path = os.path.join(app.config['RESULT_FOLDER'], result_image_filename)
        processed_image_pil.save(result_image_save_path)

        # 結果表示ページへリダイレクト
        return render_template(
            "detection_result.html",
            result_image_url=url_for('static', filename=f'results/{result_image_filename}'),
            object_count=detected_objects_count,
            inference_speed=f"{detection_fps:.2f}"
        )
    return redirect(url_for('home_page')) # 失敗時のフォールバック

if __name__ == '__main__':
    app.run(debug=True, port=5000)

フロントエンドの実装

Webアプリケーションは、画像をアップロードするフォームと、検出結果を表示するページの2つのHTMLテンプレートで構成されます。

画像アップロードページ (upload_interface.html)

<!DOCTYPE html>
<html lang="ja">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>画像物体検出サービス</title>
    <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/css/bootstrap.min.css" rel="stylesheet" crossorigin="anonymous">
    <style>
        body { background-color: #f8f9fa; }
        .detection-form-container {
            max-width: 500px;
            margin: 50px auto;
            padding: 30px;
            border-radius: 8px;
            box-shadow: 0 4px 12px rgba(0,0,0,0.1);
            background-color: #ffffff;
        }
        .form-title {
            margin-bottom: 30px;
            color: #343a40;
            text-align: center;
        }
        .form-group label {
            font-weight: bold;
            margin-bottom: 8px;
        }
        .btn-submit {
            width: 100%;
            padding: 10px;
            font-size: 1.1em;
            margin-top: 20px;
        }
    </style>
</head>
<body>
    <div class="container">
        <div class="detection-form-container">
            <h2 class="form-title">画像物体検出アプリケーション</h2>
            <form action="/detect" method="POST" enctype="multipart/form-data">
                <div class="mb-3">
                    <label for="imageFileInput" class="form-label">画像をアップロードしてください</label>
                    <input class="form-control" type="file" id="imageFileInput" name="image_file" accept="image/*" required>
                </div>
                <div class="mb-3">
                    <label for="confidenceThreshold" class="form-label">信頼度閾値を選択してください</label>
                    <select class="form-select" id="confidenceThreshold" name="confidence_threshold">
                        <option value="0.5">0.5</option>
                        <option value="0.6" selected>0.6 (推奨)</option>
                        <option value="0.7">0.7</option>
                        <option value="0.8">0.8</option>
                        <option value="0.9">0.9</option>
                    </select>
                </div>
                <button type="submit" class="btn btn-primary btn-submit">検出を開始</button>
            </form>
        </div>
    </div>
    <script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/js/bootstrap.bundle.min.js" crossorigin="anonymous"></script>
</body>
</html>

検出結果表示ページ (detection_result.html)

<!DOCTYPE html>
<html lang="ja">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>検出結果</title>
    <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/css/bootstrap.min.css" rel="stylesheet" crossorigin="anonymous">
    <style>
        body { background-color: #f8f9fa; }
        .result-card {
            max-width: 900px;
            margin: 50px auto;
            padding: 30px;
            border-radius: 8px;
            box-shadow: 0 4px 12px rgba(0,0,0,0.1);
            background-color: #ffffff;
        }
        .result-title {
            margin-bottom: 30px;
            color: #343a40;
            text-align: center;
        }
        .detected-image {
            max-width: 100%;
            height: auto;
            border-radius: 5px;
            box-shadow: 0 2px 8px rgba(0,0,0,0.05);
        }
        .info-item {
            margin-bottom: 10px;
            font-size: 1.1em;
        }
        .info-item strong {
            color: #007bff;
        }
        .btn-continue {
            margin-top: 20px;
            padding: 10px 25px;
            font-size: 1.1em;
        }
    </style>
</head>
<body>
    <div class="container">
        <div class="result-card">
            <h2 class="result-title">物体検出結果</h2>
            <div class="row align-items-center">
                <div class="col-md-7">
                    <img src="{{ result_image_url }}" class="detected-image img-fluid" alt="検出結果画像">
                </div>
                <div class="col-md-5">
                    <div class="card-body">
                        <div class="info-item"><strong>検出された物体数:</strong> {{ object_count }}</div>
                        <div class="info-item"><strong>推論速度:</strong> {{ inference_speed }} フレーム/秒</div>
                        <a href="/" class="btn btn-primary btn-continue">別の画像を検出</a>
                    </div>
                </div>
            </div>
        </div>
    </div>
    <script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/js/bootstrap.bundle.min.js" crossorigin="anonymous"></script>
</body>
</html>

推論環境の最適化とトラブルシューティング

onnxruntimeを使用する際には、実行環境の選択がパフォーマンスに大きく影響します。

ONNX Runtimeプロバイダーの指定

onnxruntime.InferenceSessionを初期化する際、利用可能な実行プロバイダー(例: CUDAExecutionProvider for GPU, CPUExecutionProvider for CPU)を明示的に指定することが重要です。これにより、適切なハードウェアアクセラレーションを活用できます。

# 例: detection_module.py 内
import onnxruntime as ort

# ...

try:
    if ort.get_device() == 'GPU' and 'CUDAExecutionProvider' in ort.get_available_providers():
        print("GPUでONNX推論を実行します。")
        inference_session = ort.InferenceSession(onnx_model_path, options=session_options, providers=['CUDAExecutionProvider'])
    else:
        print("CPUでONNX推論を実行します。")
        inference_session = ort.InferenceSession(onnx_model_path, options=session_options, providers=['CPUExecutionProvider'])
except Exception as e:
    print(f"ONNX Runtimeセッションの初期化中にエラーが発生しました: {e}")
    print("CPUExecutionProviderで再試行します。")
    inference_session = ort.InferenceSession(onnx_model_path, options=session_options, providers=['CPUExecutionProvider'])

一般的なエラーとパフォーマンスの問題

  • ImportError: cannot import name 'create_and_register_allocator_v2': これは主にonnxruntime-gpuのバージョンとCUDA、CuDNNのバージョン不一致が原因で発生します。onnxruntimeの公式ドキュメントで、お使いのCUDA/CuDNNバージョンに対応するonnxruntime-gpuのバージョンを確認し、インストールし直す必要があります。 ONNX Runtime CUDA Execution Provider ドキュメント
  • GPU使用時のパフォーマンス低下: 稀に、onnxruntime-gpuを使用した場合にCPUよりも大幅にパフォーマンスが低下するケースがあります。これは、GPUドライバー、CUDA/CuDNNのインストール、またはモデルの最適化が不十分な場合に発生する可能性があります。特に、非常に小さなモデルやバッチサイズが1の場合、GPUへのデータ転送オーバーヘッドが純粋な計算時間よりも大きくなることがあります。GPUが正しく設定されているか、また最新のドライバーがインストールされているかを確認してください。

タグ: ONNX flask Python 物体検出 DeepLearning

6月20日 20:30 投稿