pyMCの記事2記事目です。
前回非常にシンプルなモデルで推論をやりましたが、今回はその結果を可視化する便利なライブラリである、ArviZの紹介です。
ArviZはpyMCに限らず、ベイズモデルの分析や可視化、比較等を行うライブラリです。ベイズ推論の結果の分析に特化しているだけあって非常に多くの機能を持っています。
参考: ArviZ: Exploratory analysis of Bayesian models — ArviZ 0.18.0 documentation
今回は特に利用頻度が高いと思われるメソッドに絞って紹介していきます。
例としては前回の記事で作った単純なモデルを使います。前回のコードを走らせてサンプリングした結果が trace という変数に入ってるという前提で見ていってください。
推論結果のサマリーをまとめる
最初に紹介するのは推論結果を統計値で返してくれる、az.summary()です。これだけは可視化ではない(グラフ等での表示ではない)のですがよく使うのでこの記事で紹介します。
参考: arviz.summary — ArviZ 0.18.0 documentation
pm.summary()とほとんど同じ挙動なのですが、トレース結果の統計値をDataFrame形式で返してくれます。
import pymc as pm
import arviz as az
az.summary(trace)
# 以下結果
"""
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
mu 3.157 0.192 2.774 3.501 0.003 0.002 4182.0 2738.0 1.0
sigma 1.933 0.133 1.687 2.184 0.002 0.002 3934.0 3009.0 1.0
"""
各変数の平均値等が確認できるので便利ですね。
サンプルの系列(トレース)を可視化する
次に紹介するのは az.plot_trace()です。MCMCの可視化としては一番ポピュラーなやつではないでしょうか。
参考: arviz.plot_trace — ArviZ 0.18.0 documentation
今回用意している例では2変数を4系列で1000ステップサンプリングしていますので、その分布とトレースを一気に可視化してくれます。
手元で試したところ、ちょっとラベルが重なっていたので、matplotlibのメソッドを一つ呼び出して調整しています。
import matplotlib.pyplot as plt
az.plot_trace(trace)
plt.tight_layout()
plt.show()
出力結果がこちらです。
いい感じですね。
事後分布を可視化する
さっきのトレースの左半分にも表示されてはいるのですが、本当に欲しい結果は事後分布です。それを表示することに特化しているのが az.plot_posterior() です。
参考: arviz.plot_posterior — ArviZ 0.18.0 documentation
やってみます。
az.plot_posterior(trace)
plt.show()
出力がこちら。事後分布と変数名、期待値等やhdiなどを可視化してくれましたね。
フォレストプロットで可視化する
フォレストプロットとは何か?というのは実物を見ていただいた方が早いと思うのでやってみます。2変数なのでいまいちありが分かりにくいと思いますが、変数の数が多いとこれは非常に便利です。
参考: arviz.plot_forest — ArviZ 0.18.0 documentation
az.plot_forest(trace)
plt.show()
結果がこちら。
サンプリングの系列ごとに可視化してくれていますね。引数で、 combined=True を一緒に渡すと、系列をまとめて変数ごとに集計してくれますよ。
kind等の引数で見た目を変えていくこともできるのでドキュメントを参照していろいろ試してみてください。
分布を可視化する
最後は、ちょっと特殊です。AzviZにはnumpyの配列などを受け取って単純に分布を表示するメソッドなども用意されています。それが、az.plot_dist()です。
参考: arviz.plot_dist — ArviZ 0.18.0 documentation
これは、numpy配列(要するにarray)を受け取るので、先ほどまでの例のようにtraceをそのまま渡せません。pyMCの事後分布を可視化したいのであれば、traceからサンプリングした結果の部分を自分で取り出して渡す必要があります。
例えば、muの方であればこのようになります。
az.plot_dist(trace.posterior['mu'].values.ravel())
plt.show()
結果がこちら。
これはpyMCの結果以外でも汎用的に使えるやつなので一緒に紹介しました。
その他の補足
今回の記事では、例としたモデルが非常に単純だったので使いませんでしたが、大規模なものになると全変数表示すると潰れてしまって読み取れないということが起きます。
そのような場合、それぞれのメソッドが var_names という引数で出力する変数を絞り込めるようになっているで使ってみてください。
また、多くのメソッドは ax 引数などを受け取れるようになっているで、出力先のaxを指定してmatplotlibの機能で出力を加工することなどもできます。
それ以外にも各メソッドさまざまなオプションを持っているのでぜひドキュメントを参照しながら使いこなしてみてください。