timmのconvnextv2のトレーニング済みモデルを使って画像分析をしてみる
timmとはPyTorch向けのライブラリで、画像認識系の深層学習の実装を用意してくれている感じのライブラリになっています。
対応している手法を見てみると2024年現在ではvit, resnet, swin, convnextなどをサポートしているようです。
主にtimmは新たにモデルをトレーニングして作成する際に使うと思うのですが、すでにトレーニング済みのモデルが用意されていて試しにそれを使うこともできます。
個人的にtimmの学習済みモデルをfine-tuning(微調整)して見たいと思ったので、とりあえずモデルを動かしてみることにしました。
ConvNeXt V2
今回使っているConvNeXT V2という手法は、昔から存在するCNNという画像認識手法を活用した2022年に提案された新しい手法です。画像認識の分野では新しいTransformer系よりも精度が良いようです(2022年くらいでは)。
ConvNeXtはCNN系なので精度比でモデルサイズが小さくて済むようです。そのためメモリにも優しいと思われます。
手法の技術的な詳細はこれらのサイトに書いてありますが筆者はあまり理解していません。
https://arxiv.org/abs/2301.00808
https://arxiv.org/abs/2201.03545
画像認識をしてみる
まずPythonに必要なものをインストールします。
pip install torch timm pillow torchvision
学習済みモデルはImageNet-1kというデータセットを使っているようなので、ラベルの情報が必要です。こちらにImageNet-1kのラベルをまとめてくれている人がいるのでダウンロードしてlabels.pyという名前に変更します。
https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a
適当なPythonソースコードを作成します。
import sys
from itertools import zip_longest
import timm
import torch
from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
# 下記からダウンロードしたラベルをlabels.pyとして保存して使用する
# https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a
from labels import labels
model = timm.create_model('convnextv2_tiny.fcmae', pretrained=True)
# model = timm.create_model("convnextv2_huge.fcmae_ft_in22k_in1k_512", pretrained=True)
model.eval()
# モデル向けに画像を変換してくれる関数を作成
transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model), is_training=False)
# 変換する
image = Image.open(sys.argv[1]).convert('RGB')
image = transform(image).unsqueeze(0)
# 推論
out = model(image)
# 予想された確率の合計が1になるように変換
probabilities = out.softmax(dim=1)
# 上位5つの推論結果を取得
top5_probabilities, top5_class_indices = torch.topk(probabilities * 100, k=5)
for c, v in zip_longest(top5_class_indices[0], top5_probabilities[0]):
label_id = c.item() # ラベルのid
label_name = labels[label_id] # ラベルの英名
label_probability = v.item() # 確率
print(label_id, label_name, label_probability)
適当な画像で推論してみます。
python predict.py image.jpg
ImageNet-1kは小さいデータセットなので動物などわかりやすい画像で試してみるのが良さそうです。
めちゃめちゃな結果が返ってきたと思います。
試しに使用した学習済みモデル(convnextv2_tiny.fcmae)が非力だからです。
# model = timm.create_model('convnextv2_tiny.fcmae', pretrained=True)
model = timm.create_model("convnextv2_huge.fcmae_ft_in22k_in1k_512", pretrained=True)
金魚の画像を投入してみた結果です。
1 goldfish, Carassius auratus 85.73846435546875
392 rock beauty, Holocanthus tricolor 8.172449111938477
393 anemone fish 0.17695628106594086
973 coral reef 0.15994583070278168
うまく推論できていそうです。(金魚は学名だとCarassius auratusと言うのか…)
timmの詳細
timmの使い方はここに書いてあります。
https://huggingface.co/docs/hub/timm
timmの推論済みConvNeXTモデルの詳細は下記に書いてあります。
https://huggingface.co/timm/convnextv2_tiny.fcmae