前回に引き続き、多言語 Universal Sentence Encoder の話です。
テキストをベクトル化しただけで終わるとつまらないので、これを使って、先日のライブドアニュースコーパスの記事分類をやってみました。
最初、本文でやろうとしたのですが、文ベクトルを得るのに結構時間がかかったので、記事タイトルでカテゴリー分類をやってみます。
すごく適当ですが、512次元のベクトルに変換したデータに対してただのニューラルネットワークで学習してみました。
まずはデータの準備からです。
import pandas as pd
import tensorflow_hub as hub
# import numpy as np
import tensorflow_text
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
# ライブドアニュースコーパス データを読み込む
df = pd.read_csv("./livedoor_news_corpus.csv")
# 訓練データと評価データに分割する
df_train, df_test = train_test_split(df, test_size=0.2, stratify=df.category)
df_train = df_train.copy()
df_test = df_test.copy()
df_train.reset_index(inplace=True, drop=True)
df_test.reset_index(inplace=True, drop=True)
# USEモデルの読み込みと、テキストのベクトル化
url = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3"
embed = hub.load(url)
X_train = embed(df_train.title).numpy()
X_test = embed(df_test.title).numpy()
# 正解ラベル(記事カテゴリ)を One-Hot 表現に変換
ohe = OneHotEncoder()
ohe.fit(df_train.category.values.reshape(-1, 1))
y_train = ohe.transform(df_train.category.values.reshape(-1, 1)).toarray()
y_test = ohe.transform(df_test.category.values.reshape(-1, 1)).toarray()
あとはモデルを作って学習していきます。
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Dropout
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
# モデルの作成
model = Sequential()
model.add(Input(shape=(512, )))
model.add(Dropout(0.3))
model.add(Dense(128, activation='tanh'))
model.add(Dropout(0.4))
model.add(Dense(32, activation='tanh'))
model.add(Dropout(0.5))
model.add(Dense(9, activation='softmax'))
print(model.summary())
"""
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dropout (Dropout) (None, 512) 0
_________________________________________________________________
dense (Dense) (None, 128) 65664
_________________________________________________________________
dropout_1 (Dropout) (None, 128) 0
_________________________________________________________________
dense_1 (Dense) (None, 32) 4128
_________________________________________________________________
dropout_2 (Dropout) (None, 32) 0
_________________________________________________________________
dense_2 (Dense) (None, 9) 297
=================================================================
Total params: 70,089
Trainable params: 70,089
Non-trainable params: 0
_________________________________________________________________
"""
model.compile(
loss="categorical_crossentropy",
optimizer="adam",
metrics=['acc']
)
# 学習
early_stopping = EarlyStopping(
monitor='val_loss',
min_delta=0.0,
patience=5,
)
history = model.fit(X_train, y_train,
batch_size=128,
epochs=100,
verbose=2,
validation_data=(X_test, y_test),
callbacks=[early_stopping],
)
"""
Train on 5893 samples, validate on 1474 samples
Epoch 1/100
5893/5893 - 3s - loss: 1.9038 - acc: 0.3655 - val_loss: 1.5009 - val_acc: 0.5611
Epoch 2/100
5893/5893 - 1s - loss: 1.3758 - acc: 0.5564 - val_loss: 1.1058 - val_acc: 0.6771
# -- 中略 --
Epoch 28/100
5893/5893 - 1s - loss: 0.7199 - acc: 0.7611 - val_loss: 0.5913 - val_acc: 0.8012
Epoch 29/100
5893/5893 - 1s - loss: 0.7099 - acc: 0.7663 - val_loss: 0.5932 - val_acc: 0.7985
Epoch 30/100
5893/5893 - 1s - loss: 0.7325 - acc: 0.7597 - val_loss: 0.5935 - val_acc: 0.8005
"""
かなり適当なモデルですが、それでもテストデータで80%くらい正解できたようですね。
classification_reportもみておきましょう。
print(classification_report(model.predict_classes(X_test),y_test.argmax(axis=1),target_names=ohe.categories_[0]))
"""
precision recall f1-score support
dokujo-tsushin 0.77 0.81 0.79 166
it-life-hack 0.80 0.82 0.81 169
kaden-channel 0.80 0.80 0.80 174
livedoor-homme 0.56 0.72 0.63 79
movie-enter 0.85 0.79 0.82 187
peachy 0.67 0.69 0.68 166
smax 0.90 0.89 0.89 175
sports-watch 0.88 0.89 0.89 177
topic-news 0.88 0.75 0.81 181
accuracy 0.80 1474
macro avg 0.79 0.80 0.79 1474
weighted avg 0.81 0.80 0.80 1474
"""
どのカテゴリを、どのカテゴリーに間違えたのかを確認したのが次の表です。
df_test["predict_category"] = model.predict_classes(X_test)
df_test["predict_category"] = df_test["predict_category"].apply(lambda x: ohe.categories_[0][x])
print(pd.crosstab(df_test.category, df_test.predict_category).to_html())
predict_category | dokujo-tsushin | it-life-hack | kaden-channel | livedoor-homme | movie-enter | peachy | smax | sports-watch | topic-news |
---|---|---|---|---|---|---|---|---|---|
category | |||||||||
dokujo-tsushin | 134 | 2 | 2 | 4 | 5 | 21 | 0 | 1 | 5 |
it-life-hack | 1 | 139 | 13 | 4 | 2 | 5 | 9 | 1 | 0 |
kaden-channel | 1 | 12 | 139 | 4 | 1 | 1 | 6 | 0 | 9 |
livedoor-homme | 6 | 3 | 5 | 57 | 9 | 14 | 1 | 2 | 5 |
movie-enter | 0 | 2 | 1 | 1 | 148 | 8 | 1 | 5 | 8 |
peachy | 21 | 1 | 4 | 7 | 14 | 114 | 1 | 2 | 5 |
smax | 0 | 8 | 9 | 0 | 0 | 1 | 156 | 0 | 0 |
sports-watch | 1 | 1 | 0 | 2 | 2 | 2 | 0 | 158 | 14 |
topic-news | 2 | 1 | 1 | 0 | 6 | 0 | 1 | 8 | 135 |
独女通信とPeachyとか、ITライフハックと家電チャンネルなど、記事タイトルだけだと間違えても仕方がないような誤判定があるくらいで概ね正しそうです。