kerasでモデルを構築したとき、構築したモデルが意図した構造になっているかどうか可視化して確認する方法です。
Sequentialモデルであれば、 .summary()で十分なことが多いのですが、functional APIを使って複雑なモデルを作る場合に重宝します。
kerasのドキュメントを見ると、そのままズバリな名前で 可視化 のページがあり、plot_modelという関数が説明されています。
可視化 – Keras Documentation
「graphvizを用いて」と書かれている通り、graphvizがインストールされている必要がありますが、
このほか pydot というライブラリも必要なのでpip等でインストールしておきましょう。
(他サイトなどでpydotは開発が止まっていて動かないからpydotplusを使う、といった趣旨の記事を見かけますが、
現在はpydotの開発が再開されているようでpydotで動きます。)
さて、graphvizとpydotが入ったら、早速ちょっとだけ複雑なモデルを作ってみて、可視化してみましょう。
一応 model.summary() の結果も表示してみました。
まずは可視化対象のモデル構築から。
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.layers import Add
i0 = Input(shape=(64, ))
i1 = Input(shape=(64, ))
x0 = Concatenate()([i0, i1])
x1 = Dense(32, activation="tanh")(x0)
x2 = Dense(32, activation="tanh")(x1)
x3 = Add()([x1, x2])
x4 = Dense(1, activation="sigmoid")(x3)
model = Model([i0, i1], x4)
print(model.summary())
# 以下出力結果
"""
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 64)] 0
__________________________________________________________________________________________________
input_2 (InputLayer) [(None, 64)] 0
__________________________________________________________________________________________________
concatenate (Concatenate) (None, 128) 0 input_1[0][0]
input_2[0][0]
__________________________________________________________________________________________________
dense (Dense) (None, 32) 4128 concatenate[0][0]
__________________________________________________________________________________________________
dense_1 (Dense) (None, 32) 1056 dense[0][0]
__________________________________________________________________________________________________
add (Add) (None, 32) 0 dense[0][0]
dense_1[0][0]
__________________________________________________________________________________________________
dense_2 (Dense) (None, 1) 33 add[0][0]
==================================================================================================
Total params: 5,217
Trainable params: 5,217
Non-trainable params: 0
__________________________________________________________________________________________________
None
"""
Connected to に複数レイヤー入っているとぱっと見わかりにくいですね。
次にplot_model使ってみます。
show_shapes オプションを使って、入出力の形も表示してみました。
from tensorflow.keras.utils import plot_model
plot_model(
model,
show_shapes=True,
)
出力されたのがこちら。
モデルの形をイメージしやすいですね。