PyTorchで学習したVGG16をLibTorchで推論する方法

VGG16LibTorchで推論

PyTorchで学習したVGG16モデルをLibTorchで読み込み、推論を行う方法についてまとめました。
今回はtorchvisionに実装されているVGG16モデルをLibTorchで読み込める形式に変換して推論を実行し、PyTorchで推論した結果と一致するか確認してみます。

1. PyTorchのVGG16モデル

まずはPyTorchのVGG16を呼び出して推論を実行してみます。
今回はこちらのゴールデンレトリバーの画像を分類してみます。
※こちらの画像を使用させていただいております。

なお出力はクラスIDのみなので、クラス名とIDが記載された「imagenet_class_index.json」を読み込み結果を確認します。

imagenet_class_index.jsonはこちらからダウンロードできます。

1.1. 前処理用クラス実装

VGG16はImageNetで学習されているので、学習した画像の形式に合わせるために前処理を行うためのクラスを実装します。

from torchvision import transforms

class ClassTransform():
    def __init__(self):
        resize = 224
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)

        self.base_transform = transforms.Compose([
            transforms.Resize(resize),
            transforms.CenterCrop(resize),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

    def __call__(self, img):
        # ここで[1 3 224 224]形式に変換して返す
        return self.base_transform(img).unsqueeze_(0)

1.2. 分類結果出力用クラス実装

次に出力結果とJSONファイルを照合して分類結果を出力するクラスを実装します。

import json
import numpy as np

class ClassPredictor():
    def __init__(self):
        self.class_index = json.load(open('imagenet_class_index.json','r'))

    def predict_max(self, out):
        maxid = np.argmax(out.detach().numpy())
        predicted_label_name = self.class_index[str(maxid)][1]
        return predicted_label_name,maxid

1.3. 推論処理実装

そして画像を読み込んで分類し、LibTorchで読み込めるTorchScript形式へ変換して保存する処理を実装します。

from PIL import Image
import torch
from Class_Transform import ClassTransform
from Class_Predictor import ClassPredictor
from torchvision.models.vgg import vgg16,VGG16_Weights

# VGG16読み込み
net = vgg16(weights=VGG16_Weights.IMAGENET1K_V1)

# 推論モード設定
net.eval()

# 画像読み込み
img_path = 'golden_retriever.jpg'
img = Image.open(img_path)

# 前処理実行
transform = ClassTransform()
inputs = transform(img)

# 推論実行
out = net(inputs)

# 分類結果出力
predictor = ClassPredictor()
result,id = predictor.predict_max(out)
print("分類結果 ",result," : ",id)

# LibTorch用モデルへ変換
script_model = torch.jit.script(net)
script_model.save("vgg.pt")

結果は以下の通り、ゴールデンレトリバーが分類されてvgg.ptファイルが出力されていると思います。

分類結果  golden_retriever  :  207

2. LibTorch用クラス分類ファイル作成

クラス分類用ファイル「imagenet_class_index.json」はJSONファイルですが、C++だと扱いづらい(標準でサポートされていない)ので 今回はCSVファイルに変換して読み込ませることにします。

import json
import csv

imagenet_class_index = json.load(open('imagenet_class_index.json','r'))

with open('imagenet_class_index.csv', 'w', newline="", encoding='utf-8') as f:
    writer = csv.writer(f, delimiter=",")
    for i in range(len(imagenet_class_index)):
        predicted_label_name = imagenet_class_index[str(i)][1]
        data = [str(i),imagenet_class_index[str(i)][1]]
        writer.writerow(data)

3. LibTorch推論処理実装

次にPyTorchで出力したvgg.ptとimagenet_class_index.csvを読み込み、LibTorchで画像を分類する処理を実装します。
画像の読み込みにはOpenCVを使用します。

OpenCVの環境構築はこちら

3.1. 前処理用関数実装

PyTorchで実装した前処理クラスに合わせて、こちらでも前処理用の関数を実装します。

torch::Tensor Preprocess(const cv::Mat& mat_img) {
    cv::Mat mat_img_f32;

    // 0.0~1.0へ変換
    mat_img.convertTo(mat_img_f32, CV_32F, 1.0 / 255);
    
    // 224×224へリサイズ
    cv::resize(mat_img_f32, mat_img_f32, cv::Size(224, 224));

    // BGR→RGBへ変換
    cv::cvtColor(mat_img_f32, mat_img_f32, cv::COLOR_BGR2RGB);

    // [H W C] → [C H W]へ変換
    auto img_tensor = torch::from_blob(mat_img_f32.data, { 224, 224, 3 }, torch::kFloat);
    img_tensor = img_tensor.permute({ 2, 0, 1 });

    // 平均と標準偏差を合わせる
    img_tensor[0] = img_tensor[0].sub_(0.485).div_(0.229);
    img_tensor[1] = img_tensor[1].sub_(0.456).div_(0.224);
    img_tensor[2] = img_tensor[2].sub_(0.406).div_(0.225);

    // [1 3 224 224]に変換して返す
    return img_tensor.unsqueeze(0).clone();
}

3.2. 分類結果出力用関数実装

ここでも同じように、PyTorchで実装した出力用クラスに合わせて、分類結果出力用の関数を実装します。

std::vector<std::string> Result_Judge(int predicted) {
    std::string line;
    std::vector<std::string> result;

    std::ifstream ifs_csv_file("imagenet_class_index.csv");

    while (getline(ifs_csv_file, line)) {
        std::istringstream i_stream(line);

        std::string str_buf;
        std::vector<std::string> v_str_buf;
        while (getline(i_stream, str_buf, ',')) {
            v_str_buf.push_back(str_buf);
        }

        if (atoi(v_str_buf[0].c_str()) == predicted) {
            result.push_back(v_str_buf[1]);
            result.push_back(v_str_buf[0]);
            break;
        }
    }

    return result;
}

3.3. 推論処理実装

そして画像を読み込んで推論を行う処理を実装します。

// Pytorchで出力したVGG16モデルの読み込み
torch::jit::script::Module module = torch::jit::load("vgg.pt");
// 推論モードに設定
module.eval();

// 画像読み込み(OpenCV)
cv::Mat mat_img = cv::imread("golden_retriever.jpg");
// 正規化処理
torch::Tensor input = Preprocess(mat_img);

std::vector<torch::jit::IValue> inputs;
inputs.push_back(input);

// 推論実行
auto output = module.forward(inputs).toTensor();
int predicted = output.argmax(1).item<int>();

std::vector<std::string> result = Result_Judge(predicted);

std::cout << "分類結果 " << result[0] << " : " << result[1] << std::endl;

return 0;

3.4. 推論処理全体と実行結果

プログラム全体は以下の通りです。

#include <iostream>
#include <torch/torch.h>
#include <torch/script.h>
#include <opencv2/opencv.hpp>
#include <string>
#include <fstream>
#include <sstream>

torch::Tensor Preprocess(const cv::Mat& mat_img) {
    cv::Mat mat_img_f32;

    // 0.0~1.0へ変換
    mat_img.convertTo(mat_img_f32, CV_32F, 1.0 / 255);
    
    // 224×224へリサイズ
    cv::resize(mat_img_f32, mat_img_f32, cv::Size(224, 224));

    // BGR→RGBへ変換
    cv::cvtColor(mat_img_f32, mat_img_f32, cv::COLOR_BGR2RGB);

    // [H W C] → [C H W]へ変換
    auto img_tensor = torch::from_blob(mat_img_f32.data, { 224, 224, 3 }, torch::kFloat);
    img_tensor = img_tensor.permute({ 2, 0, 1 });

    // 平均と標準偏差を合わせる
    img_tensor[0] = img_tensor[0].sub_(0.485).div_(0.229);
    img_tensor[1] = img_tensor[1].sub_(0.456).div_(0.224);
    img_tensor[2] = img_tensor[2].sub_(0.406).div_(0.225);

    // [1 3 224 224]に変換して返す
    return img_tensor.unsqueeze(0).clone();
}

std::vector<std::string> Result_Judge(int predicted) {
    std::string line;
    std::vector<std::string> result;

    std::ifstream ifs_csv_file("imagenet_class_index.csv");

    while (getline(ifs_csv_file, line)) {
        std::istringstream i_stream(line);

        std::string str_buf;
        std::vector<std::string> v_str_buf;
        while (getline(i_stream, str_buf, ',')) {
            v_str_buf.push_back(str_buf);
        }

        if (atoi(v_str_buf[0].c_str()) == predicted) {
            result.push_back(v_str_buf[1]);
            result.push_back(v_str_buf[0]);
            break;
        }
    }

    return result;
}

int main()
{
    // Pytorchで出力したVGG16モデルの読み込み
    torch::jit::script::Module module = torch::jit::load("vgg.pt");
    // 推論モードに設定
    module.eval();

    // 画像読み込み(OpenCV)
    cv::Mat mat_img = cv::imread("golden_retriever.jpg");
    // 正規化処理
    torch::Tensor input = Preprocess(mat_img);

    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(input);

    // 推論実行
    auto output = module.forward(inputs).toTensor();
    int predicted = output.argmax(1).item<int>();

    std::vector<std::string> result = Result_Judge(predicted);
    
    std::cout << "分類結果 " << result[0] << " : " << result[1] << std::endl;

    return 0;
}

実行結果は以下のようになりました。

分類結果 golden_retriever : 207

無事にPyTorchと結果が一致していることが確認できました。

今回は以上です。

4. 参考文献・参考サイト

書籍
・小川雄太郎,株式会社 マイナビ出版,「つくりながら学ぶ!PyTorchによる発展ディープラーニング」,2019.07.25

Webサイト
・LibTorchリファレンス
https://docs.pytorch.org/docs/stable/cpp_index.html

コメント

タイトルとURLをコピーしました