Streamlitの変数がリセットされないようにsession単位で保存する

今回もStreamlitのお話です。前回、各種ウィジェットの使い方を紹介しましたが、ウィジェットを何かしら操作するとコードが一通り再実行されます。
ウィジェットの値自体は保存されているのですが、それ以外のPythonコード中の変数は全てリセットされてしまいます。

例えば、次のように、スライダーで値を選択してボタンを押したらそれが合計されていく、みたいなことをしようとした(しかしできてない)ソースを書いて見ます。

import streamlit as st


sum_ = 0
value = st.slider(label="スライダー", min_value=0, max_value=10)
if st.button("合計"):
    sum_ += value
    st.write(sum_)

これ、スライダーで値を選んで、「合計」ボタンを押したらどんどん値が足されてその結果が表示されそうなコードに見えませんか?

しかし実際はそのような挙動にならず、スライダーの値を変更したり、ボタンをクリックしたりすると一番最初のsum_=0の部分も再実行されるので値は積み重なっていかず、逐一リセットされます。

また、スライダー動かしただけ、の場合、st.write() の部分で書き出した数字まで消えます。(ボタン押したときだけ数字が表示される。)

このような場合に使うのが、st.session_stateです。
参考: Session State – Streamlit Docs

これはほとんど辞書のように使えます。
一番最初はst.session_stateが空っぽなのでif文で判定を入れて初期化し、それ以降は辞書の値を読み書きするのと同じイメージで実装していきます。先ほどのコードは次のようになります。

import streamlit as st


# session_state に値がない場合初期化する。
if "sum_" not in st.session_state:
    st.session_state.sum_ = 0

value = st.slider(label="スライダー", min_value=0, max_value=10)
if st.button("合計"):
    # session_stateのkeyに対応した値を更新する
    st.session_state.sum_ += value

# 表示
st.write(st.session_state.sum_)

スライダーを動かした時に消えないようにst.writeはif文の外に出しました。これで、合計ボタンを押すたびにスライダーで指定している値分だけ数字が足され、合計が表示されます。

これは単純に変数の値を保存しておくだけでなく、DBやWebなどの外部リソースからのデータ取得を伴う処理の効率化、要するにキャッシュにも使えます。

df = {DBからデータを取ってくる処理}
みたいなのをそのまま実装してしまうと、ウィジェットを何か動かすたびにSQLを発行してDBに負担をかけてしまいます。ここで、一度データを取得したらst.session_state.df などに値を保存しておき、値があったら再取得しないような実装にすると、DBなどの外部リソースへの負荷を減らし、レスポンス速度も改善できます。

ちなみに、値を消したかったら、対象のkeyに対してdelをやれば良いです。
del st.session_state.{key}

また、GUI上でも、右上の3点リーダーにClear cache というメニューがあるのでこれで消せます。

Streamlitの入力ウィジェット

今回もStreamlitの話です。

Streamlitではいろんな操作をインタラクティブに行えるようにさまざまなウィジェットが用意されています。ドキュメントのウィジェットのページはこちら。
参考: Input widgets – Streamlit Docs

テキストエリアやチェックボックス、ラジオボタンにドロップダウン、日時のセレクターなど主要なものは一通り揃っていると思います。

一個一個説明するよりも、ざっと使用例をお見せしたほうがいいと思うのでサンプルコードを最初に紹介します。基本的に、 st.{ウィジェット名}(ラベルなどの引数) で使って、戻り値を変数で受け取ってのちの処理を実装します。今回のサンプルコードでは全部書き出すようにしました。

import streamlit as st

# ボタン。ラベルを指定する。
if st.button("ボタン"):
    st.write("ボタンが押されました!")
else:
    st.write("ボタンがまだ押されていません。")

# チェックボックス。ラベルを指定する。
check = st.checkbox("チェックボックス")
if check:
    st.write("チェックされました")
else:
    st.write("チェックされていません")

# トグル。ラベルを指定する。
toggle = st.toggle("トグル")
if toggle:
    st.write("ONです")
else:
    st.write("OFFです")

# ラジオボタン。ラベルと選択肢を渡す。
color = st.radio(
    "色を選んでください",
    ("赤", "青", "緑")
)
st.write(f"選択された色: {color}")

# 1個の値を選択するドロップリスト。
fruits = st.selectbox(
    "フルーツを選んでください",
    ("りんご", "オレンジ", "バナナ")
)

st.write(f"選択されたフルーツ: {fruits}")
# 複数選択可能なドロップリスト。ラベルと選択肢を指定する。
colors = st.multiselect(
    "色を選んでください",
    ["赤", "青", "緑", "黄", "黒", "白"]
)
# 選択した値は配列で返される。
st.write(f"選択された色: {colors}")

# スライダー。ラベルと最小値, 最大値, 初期値, ステップ を指定する。(省略可)
# 下の例は0才から100才まで5才刻みで選択でき、初期値が20才
age = st.slider('年齢を選択してください', 0, 100, 20, 5)
st.write(f"選択された年齢: {age}")

# 日付入力。
date = st.date_input("日付")
st.write(f"選択された日付: {date}")

# 時間入力。
time = st.time_input("時刻")
st.write(f"選択された時刻: {time}")

# テキストボックス(1行)
text_single = st.text_input("1行のテキストを入力できます。")
st.write(f"入力されたテキスト: {text_single}")

# テキストボックス(複数行)
text_multi= st.text_area("複数行のテキストを入力できます。")
# 出力はマークダウンなので、改行したい場合は行末に半角スペース2個必要です。
st.write(f"入力されたテキスト: {text_multi}")

これを実行すると次のような画面になります。

コードからイメージした通りのものになっているのではないでしょうか。

画面内のウィジェットのうちどれか一つを操作すると、画面全体が再描写されます。他のウィジェットの現時点の値はリセットされないので安心です。(ただ、それ以外の変数の値などはリセットされます。)

先週のグラフの可視化の記事と組み合わせると、ウィジェットで描写するデータを絞り込んで対象のデータ分のグラフを書く、といった使い方ができますね。

Streamlitの標準機能によるグラフ描写

Streamlitの記事2本目です。今回はデータの可視化として、Streamlit標準のグラフ描写機能と少し他のデータ表示方法を紹介します。

前回の記事でも描きましたが、Streamlitにはmatplotlibのグラフを表示する機能があります。また、これ以外にもgraphvizとかplotlyとかのグラフを表示する機能もあって、それらを使えば良いからというのもあってか標準のグラフ作成機能は作成できるグラフの種類がかなり限られます。

具体的には、エリアグラフ/棒グラフ/折れ線グラフ/地図/散布図の5種類です。
ドキュメントはこちら

今回の記事ではこの5種類のグラフと、あと指標の数値をそのまま表示する機能、そして画像データ(行列データ)を表示するサンプルコードをそのまま紹介します。

サクッと動かせるのでコピペして試してみてください。

import streamlit as st
import pandas as pd
import numpy as np

# データの準備
data = pd.DataFrame(
    np.random.randn(20, 3),
    columns=['a', 'b', 'c']
)

# エリアグラフ
st.area_chart(data)
# 棒グラフ
st.bar_chart(data)
# 折れ線グラフ
st.line_chart(data)
# 散布図(x軸、y軸に利用したい列を指定して使う)
st.scatter_chart(data, x="a", y="b")

# 地図表記用の緯度経度データの作成
data_map = pd.DataFrame(
    {
        'lat': [37.76, 37.76],
        'lon': [-122.4, -122.41]
    }
)
# 地図
st.map(data_map)

# メトリックを表示。valueが値で、deltaで変化幅を表示可能
st.metric(label="為替", value="165円", delta="-2円")

# imageで画像データを表示可能。identityはサンプルとして用意した斜め線の図(単位行列)
st.image(1-np.identity(100), caption="斜線")

念のためですが、実行方法は $streamlit run {ファイル名} です。

Streamlit入門

以前から気になっていたのですが、Streamlitというライブラリを最近本格的に使い始めました。これは、簡単にWebアプリケーションを作成できるPythonライブラリなのですが、データの可視化や分析を行うアプリケーションの作成に使うことを念頭において開発されており、僕らの業務と大変相性の良いライブラリです。

インストールとサンプルの実行

インストールはPyPIからpipで行えます。
参考: streamlit · PyPI

インストールしたあと、PyPIのサイトに掲載されているコマンドでサンプルを起動できます。

$ pip install streamlit
$ streamlit hello

結構面白いサンプルなので、これからStreamlitを使っていこうというモチベーションを上げる意味でも一度試すことをお勧めします。

ここから超基礎的な使い方の説明に入ります。(上記のサンプルよりしょぼくてすいません。今後の記事でもう少し色々解説します。)

最もシンプルなアプリの実装

最初に、起動の確認としてただテキストを表示するだけのアプリを作ってみましょう。

app.py というファイルに以下のコードを書いて保存します。

import streamlit as st

st.title('初めてのStreamlitアプリケーション')
st.write('こんにちは、Streamlit!')

そして、次のコマンドで起動します。

$ streamlit run app.py

これで、タイトルとテキストを表示するだけのアプリが起動します。(localhostの8501番ポートです)

インタラクティブなウィジェットの利用

少しだけ動きを出してみます。ドキュメントのウィジェットのページ に色々紹介されているのですが、ここではスライダーを試します。選択した値を表示してみましょう。

import streamlit as st

st.title('スライダーの例')

value = st.slider('数値を選んでください', 0, 100, 50)
st.write('選択した数値:', value)

これを同じように実行すると、スライダーが表示され、0から100の整数(初期値が50)を一つ選べ、スライダーを動かすとその下のテキストボックスに選んだ数値が表示されます。

データフレームの表示

次は、データフレームの表示方法を紹介します。StreamlitはPandas (と他にもPyArrow, Snowpark,PySpark) のDataFrameを表示する専用のメソッドを持っているのです。

参考: st.dataframe – Streamlit Docs

import streamlit as st
import pandas as pd

st.title('データフレームの表示')

data = {
    '名前': ['Alice', 'Bob', 'Charlie'],
    '年齢': [24, 27, 22],
    '得点': [88, 92, 85]
}
df = pd.DataFrame(data)

st.write('データフレーム:')
st.dataframe(df)

これでデータフレームが表示されます。

matplotlibのグラフの表示

最後にmatplotlibのグラフを表示する方法を紹介します。一応、データフレームとグラフが表示できたら超最低限のダッシュボードは作れます。

matplotlibのグラフは、st.pyplotメソッドで表示します。とりあえずsinのグラフでも表示しておきましょう。

import streamlit as st
import matplotlib.pyplot as plt
import numpy as np

# タイトルを設定
st.title('matplotlibのグラフを表示する例')

# データを作成
x = np.linspace(0, 10, 100)
y = np.sin(x)

# matplotlibでグラフを作成
fig, ax = plt.subplots()
ax.plot(x, y, label='sin(x)')
ax.set_xlabel('X軸')
ax.set_ylabel('Y軸')
ax.set_title('Sine Wave')
ax.legend()

# Streamlitでグラフを表示
st.pyplot(fig)

matplotlibの使い方に慣れている人(ぼくもそうです)はとりあえずいつものノリでグラフを書いて渡すだけで表示できるので便利です。

これ以外にも、Streamlit自体の機能でのグラフ描写等も行えるので今後の記事で紹介していきたいと思います。

http.serverでCGIを動かす

今回までhttp.serverの話です。

前回の記事でカスタムハンドラーを作成する方法を紹介しましたのでPythonを使って動的にサイトを作ることもできるようになりましたが、これ以外にもCGIを使って動的なサイトを作ることもできます。
参考: Pythonのhttp.serverモジュールでカスタムハンドラーを実装する方法

要するに普通にPythonファイルをドキュメントルートに設置しておいて、それを動かせるわけですね。

一番シンプルな方法は、 http.serverをコマンドで起動する際に –cgi オプションをつけることです。デフォルトでは、ドキュメントルート直下の、 /cgi-bin と /htbin の 二つのディレクトリに配置されたファイルはCGIとして処理されるようになります。

この二つのディレクトリに固定されるのは、class http.server.CGIHTTPRequestHandler の、cgi_directories ってプロパティにそう指定されているからです。逆にこれ以外のディレクトリに.pyファイルを配置していてもhtmlファイルと同じようにただそのファイルの中身が返されます。

やってみましょう。

cgi-binというディレクトリを作成して、その直下に sample-cgi.py というファイル名で以下のスクリプトを書いておきます。

#!/usr/bin/env python

print("Content-Type: text/html; charset=utf-8\n")
print("<html><body>")
print("<h1>CGIスクリプト実行!</h1>")
print("</body></html>")

そして、このファイルにchmod 744 で実行権限をつけておきます。これ重要です。

そして、http.serverを起動します。

 % python -m http.server --cgi

こうすると localhost:8000/cgi-bin/sample-cgi.py にブラウザでアクセスすると、
CGIスクリプト実行! の文字が表示されます。

CGIで作成するとファイル名とURLがそのままシンプルに対応していくのでいくつも作る場合はシンプルで良いですね。

実は、class http.server.CGIHTTPRequestHandler というのを使うと、CGIディレクトリの指定とかがもっと柔軟に行えるのですが、http.server使ってそこまで凝ったことをすることもないんじゃないかなぁと思うので簡潔ですが今回の記事はここまでとします。

Pythonのhttp.serverモジュールでカスタムハンドラーを実装する方法

前回の記事に続いて、http.serverの話です。せっかくPythonを使ってWebサーバーを立てるわけですからファイルの内容を表示する静的サイトだけでなく、クエリパラメーターやフォームからPOSTされたデータを処理して表示する動的サイトの作り方を軽く紹介しておきます。

ただ、前回の記事でも書きました通り、http.server自体がプロダクション環境に適さない簡易的なものなのであくまでもちょっとした手元のツール等での利用に止めることを推奨します。

カスタムハンドラーの作成

http.serverのBaseHTTPRequestHandlerを継承してカスタムハンドラーを作成することで、特定の処理に対して独自の処理を行えます。

例えば、以下の内容で、sample1.py というファイルを作ってみましょう。

from http.server import BaseHTTPRequestHandler, HTTPServer
import urllib.parse

class MyHandler(BaseHTTPRequestHandler):
    def do_GET(self):
        if self.path == "/hello":
            self.send_response(200)
            self.send_header("Content-type", "text/html; charset=utf-8")
            self.end_headers()
            self.wfile.write("こんにちは!".encode("utf-8"))
        else:
            super().do_GET()

PORT = 8000
with HTTPServer(("", PORT), MyHandler) as httpd:
    print(f"Serving on port {PORT}")
    httpd.serve_forever()

そして、このファイルを実行します。

% python sample1.py

そうすると、`http://localhost:8000/hello` にアクセすると、こんにちは! のメッセージが表示されます。あとはプログラムで出力したい文字列を作成すれば任意のhtmlを返せますし、テンプレートファイルを読み込んでそれを表示すると言ったこともできます。

クエリパラメータの処理

次は、クエリパラメーターを受け取ってそれに応じた表示をするようにしてみましょう。

ファイル名は sample2.py等で作ります。実行方法は同じようにpythonコマンドにファイル名を渡すだけです。

from http.server import BaseHTTPRequestHandler, HTTPServer
import urllib.parse

class MyHandler(BaseHTTPRequestHandler):
    def do_GET(self):
        parsed_path = urllib.parse.urlparse(self.path)
        query_params = urllib.parse.parse_qs(parsed_path.query)
        self.send_response(200)
        self.send_header("Content-type", "text/html; charset=utf-8")
        self.end_headers()
        response = f"Path: {parsed_path.path}, Query parameters: {query_params}"
        self.wfile.write(response.encode("utf-8"))


PORT = 8000
with HTTPServer(("", PORT), MyHandler) as httpd:
    print(f"Serving on port {PORT}")
    httpd.serve_forever()

pathとパラメーターを受け取ってそれを表示するようにしてみました。先ほどの例と同様に、
% python sample2.py で起動して、 http://localhost:8000/hello?name=%E3%82%86%E3%81%86%E3%81%9F%E3%82%8D%E3%81%86 にアクセスすると、
Path: /hello, Query parameters: {‘name’: [‘ゆうたろう’]}
という表示が得られます。

上記のスクリプトでパスとパラメーターが取得できているのであとはそれを自由に活用するコードを書くだけです。

POSTリクエストの処理(フォームデータの処理)

最後にポストされたデータの処理方法を書いておきます。

これはポストするフォームも必要なのでそちらから用意します。
index.html という名前で次のファイルを作っておいてください。

<!DOCTYPE html>
<html>
<head>
    <title>Form Submission</title>
</head>
<body>
    <form action="/submit" method="post">
        <label for="name">Name:</label>
        <input type="text" id="name" name="name"><br>
        <label for="age">Age:</label>
        <input type="text" id="age" name="age"><br>
        <input type="submit" value="Submit">
    </form>
</body>
</html>

そして作成するpythonファイル、 sample3.pyを用意します。

from http.server import BaseHTTPRequestHandler, HTTPServer
import urllib.parse

class MyHandler(BaseHTTPRequestHandler):
    def do_GET(self):
        if self.path == "/":
            self.send_response(200)
            self.send_header("Content-type", "text/html; charset=utf-8")
            self.end_headers()
            with open("index.html", "rb") as file:
                self.wfile.write(file.read())
        else:
            self.send_response(404)
            self.end_headers()

    def do_POST(self):
        if self.path == "/submit":
            content_length = int(self.headers["Content-Length"])
            post_data = self.rfile.read(content_length)
            data = urllib.parse.parse_qs(post_data.decode("utf-8"))
            self.send_response(200)
            self.send_header("Content-type", "text/html; charset=utf-8")
            self.end_headers()
            response = f"Received: {data}"
            self.wfile.write(response.encode("utf-8"))
        else:
            self.send_response(404)
            self.end_headers()

PORT = 8000
with HTTPServer(("", PORT), MyHandler) as httpd:
    print(f"Serving on port {PORT}")
    httpd.serve_forever()

これで、 localhost:8000 にアクセスすると、 do_GETメソッドが実行されてindex.htmlのファイルの中身(フォーム)が表示され、フォームにデータをPOSTすると、do_POSTメソッドが実行されてフォームで送信した内容が表示されます。

簡単ではありますが、以上がhttp.serverモジュールを利用したカスタムハンドラーの作成やクエリパラメーターの処理、POSTリクエストの処理方法でした。

Pythonの標準ライブラリで手軽にWebサーバーを立てる方法

掲題の通りの記事です。Macには標準でhttpdが入ってるのでどこまでニーズがあるのか不明ですが、Python環境があると標準ライブラリで簡易的なWebサーバーを立てることができます。

手元のメモ等をテキストファイルで保存している場合、そのディレクトリをドキュメントルートとしてサーバーを起動すると、ブラウザから参照できるようになるので便利なことがあります。

もちろん、Web開発をやる人であればローカルでWebサービスのテストを行うと言った普通のWebサービスとしての活用もできますね。

利用するのは http.server モジュールです。
参考: http.server — HTTP servers — Python 3.12.3 ドキュメント

ドキュメントを開くといきなり、
警告 http.server is not recommended for production. It only implements basic security checks.
と出てくる通り、本番環境での活用は推奨されていない、あくまでも簡易的なものです。

本当に一番シンプルな使い方は、 `python -m http.server` というコマンドを打つだけです。

% python -m http.server
Serving HTTP on :: port 8000 (http://[::]:8000/) ...

上記の通り、デフォルトでは8000番ポートを使ってWebサーバーが起動します。
ドキュメントルートはコマンドを実行したディレクトリなので、 index.html ファイルを置いておくと、ブラウザで、 http://localhost:8000/ にアクセスするとそのindex.htmlファイルの中身が表示されます。

index.html ファイルがないと、 そのディレクトリに置かれているファイルリストが、
Directory listing for /
として表示されます。

これは通常の世界に公開するWebサーバーではディレクトリ構造が明らかになってしまうセキュリティ上欠陥のある仕様ですが、手元のメモへのアクセスのために使っている僕にとっては大変便利な仕様です。

下記のように、数値を渡すことでポート番号を変更することもできます。

% python -m http.server 9000

また、ドキュメントルートをカレントディレクトリではなく指定した場所にしたい場合は、-d か –directory 引数を使って次のように書きます。相対パスと絶対パスのどちらもサポートされています。

% python -m http.server --directory /tmp/

(自分の慣れだけの問題なのですが)個人的にはhttpdのサービスの起動をするよりも手軽だと感じているので、ちょいとWebサーバー建てたいな、という場面があればぜひ試してみてください。

最高密度区間 HDI (Highest Density Interval) について

これはpyMCに限った話ではないのですが、pyMC + ArviZ を使っていると頻繁に目にする指標なのでほぼpyMC関連の記事6記事目と見ていただいて良いと思います。

pyMCの結果を見ているとhdiという指標を頻繁に目にします。僕はこれについてあまり馴染みがなかったので説明を残しておこうと思います。

各記事のpyMCのサンプリング結果のsummaryをDataFrameにしたときも登場していますし、可視化したらグラフ中にも登場しますね。

参考: ArviZを使ってpyMCの推論結果を可視化する

このHDIは、確率分布の区間を切り取ったものです。指定の確率、(一般的には95%だそうですが、) ArviZでは94%がデフォルトで指定されており、その区間に含まれる確率が94%となるような区間を示します。

似たような概念で信頼区間や予測区間というのもありますが、違いはその区間の切り取り方です。信頼区間や予測区間では大抵、区間を切り取るとき、確率分布の上側と下側から同じ確率を切り抜きます。94%区間を取りたいのであれば、下から3%と上から3%を外して残りを取り出します。

それに対して、最高密度区間(HDI)では、その名の通り、確率密度が高い部分から優先的に確保して、指定の確率、今の例であれば94%となるように区間を切り出します。

山が一つあるような形の確率分布$f(x)$であれば、次を満たす区間$[a, b]$だと言えるでしょうか。$f(x)$が連続とか滑らかとか適切な過程が他にも必要かもしれません。

$$\begin{align}
&\int_{a}^{b} f(x)dx = 0.94\\
&f(a) = f(b)\\
&a < c < b \Rightarrow f(a) = f(b) < f(c)
\end{align}$$

注意として、確率分布が複数の山を持つような形だった場合、HDIは単一の区間ではなく、ふくすの区間に分かれて定義されることがあります。pyMCやArviZがそのような例に対応しているかどうかはちょっと確認しきれていません。(まだそういう例に出くわしたことがないので)

HDIの良い性質として、その指定の確率を含む区間の切り取り方としては一番狭い幅の区間になるというものがあります。ベイズ推論に限った話ではなく、何か数値を予想するのであればできるだけ狭いレンジに絞り込んで推論したいのでこれはありがたいですね。

pyMCで変化点検出

pyMC5の記事の5記事目です。
今回はpyMCをつかって変更点の検出をやってみます。といっても、題材が変更点の検出というだけでメインで取り上げたいのはデータ列の途中で期待値が変わるモデルを扱いたいってのと、事後分布からのサンプリングをやりたいっていう点の2点です。

題材の紹介

題材としては、少し古い本なのですが「Pythonで体験するベイズ推論」という、キャメロン・デビッドソン=ピロンさんの本から持ってきます。これの、受信するメッセージ数の変化がテーマです。

なぜこんな古い本を題材にしたかというと、新しい本を題材にするとそのままコードを書き写す感じになるので勉強にならないからです。この本のコードはpyMC3なのでpyMC5に焼き直すとコードは大幅に変わります。

サンプルデータはこちらのリポジトリにあります。

ただ、ここからダウンロドしてくるのも手間ですし、ただの整数列なのでそのままこの記事にも書いておきます。ちなみに、全部で74日分です。

data = [13, 24, 8, 24, 7, 35, 14, 11, 15, 11, 22, 22, 11,
 57, 11, 19, 29, 6, 19, 12, 22, 12, 18, 72, 32, 9, 7, 13,
 19, 23, 27, 20, 6, 17, 13, 10, 14, 6, 16, 15, 7, 2, 15,
 15, 19, 70, 49, 7, 53, 22, 21, 31, 19, 11, 18, 20, 12, 35,
 17, 23, 17, 4, 2, 31, 30, 13, 27, 0, 39, 37, 5, 14, 13, 22
]

結果は最後に図示しますが、これ途中から件数が増えているように見えるのですよ。
それが、何日目から増えていて、その前後は期待値何件だったのか、というのをpyMCで推論していきます。

pyMCによるモデリング

ではやっていきましょう。

データが整数値なので、メッセージの件数はポアソン分布に従うとし、前半と後半で期待値$\lambda$が違うものとします。それぞれの$\lambda$はガンマ分布から持ってきましょう。
そして、変化する日はtauという変数名で離散一様分布からサンプリングします。

tauより前の日と以降の日で期待値が変わる点については、pm.math.switchを使って実装しました。この辺は本のコードと違います。

import matplotlib.pyplot as plt
import pymc as pm
import arviz as az


model = pm.Model()

with model:
    # 変化前の期待値
    lambda_1 = pm.Gamma('lambda_1', alpha=2, beta=1)
    # 変化後の期待値
    lambda_2 = pm.Gamma('lambda_2', alpha=2, beta=1)
    # 変化日
    tau = pm.DiscreteUniform('tau', lower=0, upper=len(data)-1)

    # idxは各データ点のインデックス
    idx = np.arange(len(data))
    # lambda_の値がtauの値によって決定される
    lambda_ = pm.math.switch(tau >= idx, lambda_1, lambda_2)
    
    # データの尤度
    obs = pm.Poisson('obs', mu=lambda_, observed=data)
    # サンプリング
    trace = pm.sample(draws=10000, tune=10000, chains=2)

サンプリングは少し多めに行っています。
結果を見ておきましょう。

print(az.summary(trace))
"""
	mean	sd	hdi_3%	hdi_97%	mcse_mean	mcse_sd	ess_bulk	ess_tail	r_hat
tau	43.221	0.875	42.000	44.000	0.016	0.011	3054.0	4512.0	1.0
lambda_1	17.406	0.617	16.275	18.576	0.004	0.003	18857.0	14591.0	1.0
lambda_2	22.035	0.859	20.354	23.597	0.007	0.005	15301.0	14084.0	1.0
"""

az.plot_trace(trace)
plt.tight_layout()

トレースを可視化したものがこちらです。

一瞬怪しいところがありますがtau = 44 あたりで変更しているのがわかりますね。

そこを閾にlambdaが増えていそうです。(17.4くらいから22くらいに。)

0日目から40日目くらいはでは、期待値17.4くらいで、45日目以降は22くらいと見て良さそうです。さて、その途中の41〜44日目はどう考えましょう?となった時に便利なのが事後分布からのサンプリングです。tauの分布を考慮していちいち計算しなくても、さっとサンプリングしてしまうことでそれぞれの日のメッセージ受信数の期待値の概算がわかります。

事後分布からのサンプリングに使うのが、 pm.sample_posterior_predictive です。これはモデルではなく、traceの方を渡してサンプリングを行います。サンプリングされるサイズが、最初に推論した時のステップ数に依存してしまうので、実はさっきサンプリングする時にdrawを大きめの値にしていたのです。draws=10000, chains=2 だったので、20000サンプルが得られます。

参考: pymc.sample_posterior_predictive — PyMC dev documentation

さて、やってみます。

# 既存のモデルとトレースを使用して、事後予測サンプルを生成
with model:
    # traceから事後予測サンプルを生成する際の修正
    posterior_predictive = pm.sample_posterior_predictive(trace, var_names=['obs'])

# 'obs'キーを使用して事後予測サンプルのデータを取得
posterior_obs = posterior_predictive.posterior_predictive['obs'].values

# 事後予測サンプルの期待値(平均)を計算
expected_values = np.mean(posterior_obs, axis=(0, 1))

axis=(0, 1) としているのは元の posterior_obs.shape が (2, 10000, 74) だからです。

これで、obsのサンプリング結果の期待値が得られました。元のデータと合わせて可視化してみましょう。

fig = plt.figure(facecolor="w")
ax = fig.add_subplot(1, 1, 1)
ax.bar(range(len(data)),data)
ax.plot(expected_values, c="orange")
plt.show()

いい感じですね。

pyMCで線形回帰分析

pyMCの記事4記事目です。これまでの記事では観測値だけ与えてそれを生成する確率分布を考えてきましたが、今回は観測値だけでなく何か特徴量を持つデータを考えます。
その最も単純な例として1変数の線形回帰をやってみましょう。特徴量が増えて重回帰分析になってもほとんど同じように対応できるので汎用性は高いと思います。

データの準備

何のデータを使ってもいいのですが、今回はscikit-learnのiris使います。3種類のアヤメのうち、virginicaに絞って、petal length (cm)からsepal length (cm)を予測するモデルを考えてみましょう。(相関係数が0.86くらいあって予測が簡単なのです。)

次のようにしてデータを取得します。

%%pycodestyle
import pandas as pd
from sklearn.datasets import load_iris


# irisデータ取得
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df["label"] = iris.target
# virginica に絞る
df = df[df.label == 2].reset_index(drop=True)

# 回帰分析に使う特徴量xと目的変数yを取得
x = df["petal length (cm)"].values
y = df["sepal length (cm)"].values

モデルの実装

データが揃ったらpyMCでモデルを作っていきます。

回帰係数を a, 定数項をb、誤差をε とするとモデルの式はこのようになりますね。
$$y= ax + c + \epsilon$$

ベイズで行いますので、それぞれに事前分布が必要です。
a, c の事前分布は期待値が0、標準偏差が10の正規分布としましょう。
そして、誤差項εは、期待値が0で、標準偏差がσの正規分布に従うとし、このσは標準偏差が10の半正規分布に従うとします。

これを実装してきますが、今回新たに使うのは、特徴料等の定数を格納する pm.ConstantData と 数式を定義できるpm.Deterministic です。

参考:
pymc.ConstantData — PyMC dev documentation
pymc.Deterministic — PyMC v5.6.0 documentation

正確には、 ConstantData の方は使わなくてもいいのですが、明示的に書いておくとモデルを可視化した時に定数部分も表示されるので便利です。

実際にコードを見ていただくと使い方がわかると思うのでやっていきましょう。例によって、Graphvizで可視化してJpyterで表示しています。

import numpy as np
import matplotlib.pyplot as plt
import pymc as pm
import arviz as az


model = pm.Model()

with model:
    x_data = pm.ConstantData("x_data", x)
    y_data = pm.ConstantData("y_data", y)

    # 回帰係数と定数項
    a = pm.Normal("a", mu=0, sigma=10)
    c = pm.Normal("c", mu=0, sigma=10)

    # yの期待値
    mu = pm.Deterministic("mu", a*x_data + c)
    # 誤差
    sigma = pm.HalfNormal("sigma", sigma=10)

    # 観測値
    obs = pm.Normal("obs", mu=mu, sigma=sigma, observed=y_data)

# モデルの可視化
g = pm.model_to_graphviz(model)
display(g)

参考ですが、ConstantDataを使わない場合はこうなります。

model2 = pm.Model()

with model2:

    # 回帰係数と定数項
    a = pm.Normal("a", mu=0, sigma=10)
    c = pm.Normal("c", mu=0, sigma=10)

    # yの期待値
    mu = pm.Deterministic("mu", a*x + c)
    # 誤差
    sigma = pm.HalfNormal("sigma", sigma=10)

    # 観測値
    obs = pm.Normal("obs", mu=mu, sigma=sigma, observed=y)

g = pm.model_to_graphviz(model2)
display(g)

モデルができたのでサンプリングして結果を見ていきましょう。

with model:
    trace = pm.sample(random_seed=42, chains=2)


display(az.summary(trace, var_names=["a", "c", "sigma"]))
"""
	mean	sd	hdi_3%	hdi_97%	mcse_mean	mcse_sd	ess_bulk	ess_tail	r_hat
a	0.995	0.084	0.828	1.143	0.003	0.002	641.0	636.0	1.0
c	1.066	0.470	0.216	1.988	0.019	0.013	648.0	679.0	1.0
sigma	0.331	0.035	0.259	0.391	0.001	0.001	554.0	349.0	1.0
"""

az.plot_trace(trace, var_names=["a", "c", "sigma"], compact=False)
plt.tight_layout()

いい感じに推定できていますね。

回帰直線の可視化

せっかく単回帰したので、回帰直線を可視化してみたいと思います。上記のsummaryのa,cで可視化してもいいのですがせっかくなのでサンプリングの各ステップの値で可視化してみましょう。


# a, cの各ステップの値を取得
a_list = trace.posterior.a.values.ravel().reshape(-1, 1)
c_list = trace.posterior.a.values.ravel().reshape(-1, 1)

# xの範囲
x_values = np.array([4.4, 7.0])

# a, c の各値からyの値を計算
y_preds = x_values * a_list + c_list

# 回帰直線と元のデータを可視化
for y_pred in y_preds:
    plt.plot(x_values, y_pred, lw=1, alpha=0.01, c="c")
plt.scatter(x, y)
plt.show()

なかなか妥当な結果が得られましたね。

まとめ

今回は線形回帰を題材として取り上げましたが、線形回帰に限らず特徴量を使うモデリングは同じようにして実装していくことができます。pm.Deterministicを使うと一気に実装の幅が広がりますのでぜひ試してみてください。