Streamlitの変数がリセットされないようにsession単位で保存する

今回もStreamlitのお話です。前回、各種ウィジェットの使い方を紹介しましたが、ウィジェットを何かしら操作するとコードが一通り再実行されます。
ウィジェットの値自体は保存されているのですが、それ以外のPythonコード中の変数は全てリセットされてしまいます。

例えば、次のように、スライダーで値を選択してボタンを押したらそれが合計されていく、みたいなことをしようとした(しかしできてない)ソースを書いて見ます。

import streamlit as st


sum_ = 0
value = st.slider(label="スライダー", min_value=0, max_value=10)
if st.button("合計"):
    sum_ += value
    st.write(sum_)

これ、スライダーで値を選んで、「合計」ボタンを押したらどんどん値が足されてその結果が表示されそうなコードに見えませんか?

しかし実際はそのような挙動にならず、スライダーの値を変更したり、ボタンをクリックしたりすると一番最初のsum_=0の部分も再実行されるので値は積み重なっていかず、逐一リセットされます。

また、スライダー動かしただけ、の場合、st.write() の部分で書き出した数字まで消えます。(ボタン押したときだけ数字が表示される。)

このような場合に使うのが、st.session_stateです。
参考: Session State – Streamlit Docs

これはほとんど辞書のように使えます。
一番最初はst.session_stateが空っぽなのでif文で判定を入れて初期化し、それ以降は辞書の値を読み書きするのと同じイメージで実装していきます。先ほどのコードは次のようになります。

import streamlit as st


# session_state に値がない場合初期化する。
if "sum_" not in st.session_state:
    st.session_state.sum_ = 0

value = st.slider(label="スライダー", min_value=0, max_value=10)
if st.button("合計"):
    # session_stateのkeyに対応した値を更新する
    st.session_state.sum_ += value

# 表示
st.write(st.session_state.sum_)

スライダーを動かした時に消えないようにst.writeはif文の外に出しました。これで、合計ボタンを押すたびにスライダーで指定している値分だけ数字が足され、合計が表示されます。

これは単純に変数の値を保存しておくだけでなく、DBやWebなどの外部リソースからのデータ取得を伴う処理の効率化、要するにキャッシュにも使えます。

df = {DBからデータを取ってくる処理}
みたいなのをそのまま実装してしまうと、ウィジェットを何か動かすたびにSQLを発行してDBに負担をかけてしまいます。ここで、一度データを取得したらst.session_state.df などに値を保存しておき、値があったら再取得しないような実装にすると、DBなどの外部リソースへの負荷を減らし、レスポンス速度も改善できます。

ちなみに、値を消したかったら、対象のkeyに対してdelをやれば良いです。
del st.session_state.{key}

また、GUI上でも、右上の3点リーダーにClear cache というメニューがあるのでこれで消せます。

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です