Universal Sentence Encoder を使ってニュース記事分類

前回に引き続き、多言語 Universal Sentence Encoder の話です。
テキストをベクトル化しただけで終わるとつまらないので、これを使って、先日のライブドアニュースコーパスの記事分類をやってみました。
最初、本文でやろうとしたのですが、文ベクトルを得るのに結構時間がかかったので、記事タイトルでカテゴリー分類をやってみます。

すごく適当ですが、512次元のベクトルに変換したデータに対してただのニューラルネットワークで学習してみました。

まずはデータの準備からです。


import pandas as pd
import tensorflow_hub as hub
# import numpy as np
import tensorflow_text
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder

# ライブドアニュースコーパス データを読み込む
df = pd.read_csv("./livedoor_news_corpus.csv")
# 訓練データと評価データに分割する
df_train, df_test = train_test_split(df, test_size=0.2, stratify=df.category)
df_train = df_train.copy()
df_test = df_test.copy()
df_train.reset_index(inplace=True, drop=True)
df_test.reset_index(inplace=True, drop=True)

# USEモデルの読み込みと、テキストのベクトル化
url = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3"
embed = hub.load(url)

X_train = embed(df_train.title).numpy()
X_test = embed(df_test.title).numpy()

# 正解ラベル(記事カテゴリ)を One-Hot 表現に変換

ohe = OneHotEncoder()
ohe.fit(df_train.category.values.reshape(-1, 1))
y_train = ohe.transform(df_train.category.values.reshape(-1, 1)).toarray()
y_test = ohe.transform(df_test.category.values.reshape(-1, 1)).toarray()

あとはモデルを作って学習していきます。


from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Dropout
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix

# モデルの作成
model = Sequential()
model.add(Input(shape=(512, )))
model.add(Dropout(0.3))
model.add(Dense(128, activation='tanh'))
model.add(Dropout(0.4))
model.add(Dense(32, activation='tanh'))
model.add(Dropout(0.5))
model.add(Dense(9, activation='softmax'))
print(model.summary())
"""
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dropout (Dropout)            (None, 512)               0         
_________________________________________________________________
dense (Dense)                (None, 128)               65664     
_________________________________________________________________
dropout_1 (Dropout)          (None, 128)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 32)                4128      
_________________________________________________________________
dropout_2 (Dropout)          (None, 32)                0         
_________________________________________________________________
dense_2 (Dense)              (None, 9)                 297       
=================================================================
Total params: 70,089
Trainable params: 70,089
Non-trainable params: 0
_________________________________________________________________
"""

model.compile(
    loss="categorical_crossentropy",
    optimizer="adam",
    metrics=['acc']
)

# 学習
early_stopping = EarlyStopping(
                        monitor='val_loss',
                        min_delta=0.0,
                        patience=5,
                )

history = model.fit(X_train, y_train,
                    batch_size=128,
                    epochs=100,
                    verbose=2,
                    validation_data=(X_test, y_test),
                    callbacks=[early_stopping],
                    )

"""
Train on 5893 samples, validate on 1474 samples
Epoch 1/100
5893/5893 - 3s - loss: 1.9038 - acc: 0.3655 - val_loss: 1.5009 - val_acc: 0.5611
Epoch 2/100
5893/5893 - 1s - loss: 1.3758 - acc: 0.5564 - val_loss: 1.1058 - val_acc: 0.6771

# -- 中略 --

Epoch 28/100
5893/5893 - 1s - loss: 0.7199 - acc: 0.7611 - val_loss: 0.5913 - val_acc: 0.8012
Epoch 29/100
5893/5893 - 1s - loss: 0.7099 - acc: 0.7663 - val_loss: 0.5932 - val_acc: 0.7985
Epoch 30/100
5893/5893 - 1s - loss: 0.7325 - acc: 0.7597 - val_loss: 0.5935 - val_acc: 0.8005
"""

かなり適当なモデルですが、それでもテストデータで80%くらい正解できたようですね。
classification_reportもみておきましょう。


print(classification_report(model.predict_classes(X_test),y_test.argmax(axis=1),target_names=ohe.categories_[0]))
"""
                precision    recall  f1-score   support

dokujo-tsushin       0.77      0.81      0.79       166
  it-life-hack       0.80      0.82      0.81       169
 kaden-channel       0.80      0.80      0.80       174
livedoor-homme       0.56      0.72      0.63        79
   movie-enter       0.85      0.79      0.82       187
        peachy       0.67      0.69      0.68       166
          smax       0.90      0.89      0.89       175
  sports-watch       0.88      0.89      0.89       177
    topic-news       0.88      0.75      0.81       181

      accuracy                           0.80      1474
     macro avg       0.79      0.80      0.79      1474
  weighted avg       0.81      0.80      0.80      1474
"""

どのカテゴリを、どのカテゴリーに間違えたのかを確認したのが次の表です。


df_test["predict_category"] = model.predict_classes(X_test)
df_test["predict_category"] = df_test["predict_category"].apply(lambda x: ohe.categories_[0][x])

print(pd.crosstab(df_test.category, df_test.predict_category).to_html())
predict_category dokujo-tsushin it-life-hack kaden-channel livedoor-homme movie-enter peachy smax sports-watch topic-news
category
dokujo-tsushin 134 2 2 4 5 21 0 1 5
it-life-hack 1 139 13 4 2 5 9 1 0
kaden-channel 1 12 139 4 1 1 6 0 9
livedoor-homme 6 3 5 57 9 14 1 2 5
movie-enter 0 2 1 1 148 8 1 5 8
peachy 21 1 4 7 14 114 1 2 5
smax 0 8 9 0 0 1 156 0 0
sports-watch 1 1 0 2 2 2 0 158 14
topic-news 2 1 1 0 6 0 1 8 135

独女通信とPeachyとか、ITライフハックと家電チャンネルなど、記事タイトルだけだと間違えても仕方がないような誤判定があるくらいで概ね正しそうです。

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です