概要
CNNを用いて、pngのキャラクター画像から多項分類をしました。画像関連の勉強のため。 対象は、Palworld, Pokemon, Dragon Quest, Digital monster, Yokai-watch
経緯
ここ1年半ほどポケモンにはまっています。初代の赤緑と2代目の金銀を小学生の頃にプレイして以来、しばらく離れていましたが、最近娘と一緒にSVをプレイし始め、それをきっかけに(独りで)アルセウスをプレイしたり、ポケモン関連のYouTube動画を視聴する毎日を過ごしています。
そんな中2024年1月にPalworldが発売される前後から、キャラクターがポケモンに似ているとの指摘が増えていました。*1確かに比較してみると非常によく似ているものの、ポケモンが1000種類も存在するので、どうしても似てしまうことも避けられないだろうとも思っていました。
少し調べたところ、法律的には厳しいんじゃないかという意見が多く、特に、知的財産権に強い任天堂がどのような対応をとるのか、司法でどのような判断が下されるのか、非常に興味深いと思っています。 近年では、生成AIの進展により、元データの権利に関する問題が複雑化しており、そもそもこれから長い期間ほっとトピックになっていくんでしょう。
一方で、これって機械学習で解けるんじゃね?という。以前からCNNは、MNISTのようなチュートリアルや、否定形の時系列データのグラフを画像変換するなど、いくつか触ってみたことがあるものの、画像そのものには触れたことがなかったので、これを機に試してみることにしました。
最初は、ポケモンなどの画像からGANで学習すれば、生成と同時にdiscriminator側でポケモンの判定も可能になるため、面白いなと思いました。(ついでに最近は逆強化学習の一環としてGAILを調査しており、その下調べとしてもあり)。ただ結局、家のPCの環境ではGANは少し厳しいか、少なくとも時間がかかりそうだったため、まずは単なるCNNを使用した分類問題に挑戦しています。
画像収集
まずは、ポケモンやPalworldに似ているキャラクターでかつ、その数が多そうなので、デジタルモンスター、ドラゴンクエスト、妖怪ウォッチを考えました(ついでにアンパンマンも考えましたが、さすがに絵柄が違うので除外)。これらをそれぞれ収集。
キャラクターの背景を白にして、128x128のRGB画像に変換しました。 元の画像の数は、それぞれ
ChatGPTから画像の数は同じくらいがいいとアドバイスがあったので、数が多いポケモン以外を左右反対も加えて倍にして、計2099枚。
方法
全体の20%をtestデータにし、tensorflowを用いて、最終的に下記のようになりました。
def convertX(path): img = load_img(path, target_size=(128, 128)) img_array = img_to_array(img) img_array /= 255.0 # 画像データを[0, 1]に標準化 return img_array # 画像のパスとラベルの取得 image_paths = df['path'].values labels = df.iloc[:, 1:].values # 'path'列以外の列をラベルとして取得 # ラベルのエンコーディング label_encoder = LabelEncoder() encoded_labels = label_encoder.fit_transform(labels.flatten()) num_classes = len(label_encoder.classes_) # 画像データの準備 X = np.array([convertX(path) for path in image_paths]) y = to_categorical(encoded_labels, num_classes=num_classes) # データの分割 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y) # CNNモデルの構築 model = Sequential() model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(128, 128, 3))) model.add(MaxPooling2D((2, 2))) model.add(Conv2D(64, (3, 3), activation='relu')) model.add(MaxPooling2D((2, 2))) model.add(Conv2D(128, (3, 3), activation='relu')) model.add(MaxPooling2D((2, 2))) model.add(Flatten()) model.add(Dense(128, activation='relu')) model.add(Dense(num_classes, activation='softmax')) # モデルのコンパイル model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy']) early_stopping = EarlyStopping(monitor='val_accuracy', patience=5, restore_best_weights=True) history = model.fit(X_train, y_train, epochs=15, batch_size=64, validation_data=(X_test, y_test), callbacks=[early_stopping])
AutoMLを使って探索もしましたが、まさかの初めChatGPTが提案したものから、batch sizeを少し大きくしたものからかわりませんでした。
全部で15epoch計算したところ、testデータのAccuracyが最大となった、epoch9のmodelを使用します。下は学習曲線
結果
まずは代表的な画像の分類結果。略称はそれぞれポケモン(PM)、Palworld(PW)、デジタルモンスター(DM)、ドラゴンクエスト(DQ)、妖怪ウォッチ(YW)。
画像内の数字は、[そのカテゴリーに分類された数 / 正確なカテゴリーのデータ数]。例えば、列がDQ,行がPMが1/823とは、823個のPMのうち1個がDQに分類された、ことを意味する。
対角に並んでるのは、正確に分類できるもの。例えばTrainデータ内の、クレベースはポケモンだが、ドラクエキャラに分類されていて、言われてみればそんな気もしてくる。同じくマホイップがPalword、マシマシラが妖怪ウォッチに分類されてる。
最終的な正解率は表のような感じで、全体でtest:84%。ポケモンが90%を超えるのに、Palworldは他よりもかなり悪くて、test:34%。狙った結果ではあるんだが、思いの他悪いものの、ポケモンのみではなく、デジモン以外は均等に間違えてる。
またデジモンも正解率65%と低く、キャラはポケモンと間違われているのが多い。感覚的にはデジモンのキャラはとんがっている部分が多くてドラクエと間違えるかと思っていた。
単純な正解率にしてますが、データ数に偏りがあるので、多項分類の評価方法をもっと考えたほうがよいかもですが、とりあえず。
Accuracy[%] | DM | DQ | PM | PW | YW | |
---|---|---|---|---|---|---|
Train | DM | 94 | 0 | 2 | 1 | 1 |
DQ | 0 | 99 | 0 | 0 | 0 | |
PM | 0 | 0 | 99 | 0 | 0 | |
PW | 0 | 2 | 1 | 95 | 0 | |
YW | 0 | 0 | 2 | 0 | 97 | |
Test | DM | 65 | 0 | 25 | 0 | 10 |
DQ | 0 | 96 | 1 | 0 | 1 | |
PM | 1 | 0 | 93 | 0 | 3 | |
PW | 2 | 9 | 29 | 34 | 25 | |
YW | 1 | 3 | 11 | 2 | 82 |
※合計100%にならない場合あります。
他の方にも指摘されているBushiが妖怪ウォッチと誤分類されてるから、割と納得感がある。
画像内数字は、[推定分類:その度合い]
分類がうまくいかない原因としてざっと考えられるものとしては、ゲームキャラクターのデザイナーが複数いる場合などが考えられるでしょうか?ポケモン内でも複数のデザイナーがいるはずなので、どうやって統一感を出してるのかなどがカギなのかもしれません。ポケモンのみを対象にして、デザイナー分類か、世代分類(例えば初代またはSV)できるかも問題としておもしろいかもしれませんね。
感想
- ざっとやったにもかかわらず、なかなかの精度なのはぱっと見すごいんだが、CNNは定型的な問題にむいてるからなんだろうか?
- やってる途中から気づいていたが、図形の時間軸を考えてないから、因果関係も当然わからない。複数の画像の類似度をかえすようなもののほうがよかったんだろうか?Cosine similarityなど。
- 余談だが、よくある問題ならChatGPT3.5でも十分使えた。マイナーそうな問題だとcode interpreterでやってもらわないと、うそばっかりなイメージだけど。
- 分類で99%とか言いながら間違えることが多すぎじゃね?Trainでは分類できてるから、そんな間違った推定度合いになるんだろうか?もっと自信なくてもいいのに。