SciPyの optimize.minimizeを使って関数が最小値を取る点を探す

最近よく使っている、scipyの最適化関数の一つであるminimizeについて、まだ記事を書いてなかったので紹介します。

公式ドキュメントはこちらです。
参考: minimize — SciPy v1.14.1 Manual

これはタイトルの通りで、数値を返す関数を渡すとその関数が最小値をとる引数を探してくれるものです。ちなみに、最大値になる引数を探すメソッドはないので最大値を探したかったら、その関数に-1をかけて符号を反転させた関数を用意してください。

サクッと一つやってみましょう。正解がわかりやすいよう、2次関数でも例にとりましょうか。次の関数を使います。

$$x^2-6x+5 = (x-3)^2 -4 $$

平方完成から分かる通り、$x=3$で最小ですね。

import numpy as np
from scipy.optimize import minimize

# 目的関数の定義
def sample_function(x):
    return x**2 - 6*x + 5

# 初期値
x0 = 0

# 最適化の実行
result = minimize(sample_function, x0, method='L-BFGS-B')

# 結果の表示
print("最適化の結果:", result)
print("最小値をとる点 x:", result.x)
print("最小値 f(x):", result.fun)
"""
最適化の結果:   message: CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL
  success: True
   status: 0
      fun: -3.9999999999999982
        x: [ 3.000e+00]
      nit: 2
      jac: [ 1.776e-07]
     nfev: 6
     njev: 3
 hess_inv: <1x1 LbfgsInvHessProduct with dtype=float64>
最小値をとる点 x: [3.00000004]
最小値 f(x): -3.9999999999999982
"""

上のサンプルコードのように、最初の引数で最適化したい関数、2個目の引数で初期値、そしてオプションで利用するアルゴリズムなどの各種設定を行います。

ここで注意しないといけないのは、渡した関数の第1引数だけが最適化の対象ということです。なので、元の関数が複数の引数をとる多変数関数の場合、最適化したい引数等を一個の配列にまとめて渡す関数でラップしてあげる必要があります。また、最適化対象外の値を固定する引数については、argsで固定します。

例えば、$f(x, y, a, b) = x^2 + ax + y^2 + by$ みたいな関数があって、a, bは固定したとき、これを最小にする $x, y$を求めたい場合次のようにします。

# 元の関数(オリジナルの目的関数)
def original_function(x, y, a, b):
    return x**2 + a*x + y**2 + b*y


# 最適化用にラップする関数
# 元の関数のx, y を xという配列で渡してx[0], x[1]として内部で使っている
def optimization_target_function(x, *params):
    return original_function(x[0], x[1], *params)


# 最適化の実行
x0 = [0, 0]
result = minimize(optimization_target_function, x0, args=(4, -6), method='L-BFGS-B')

# 結果の表示
print("最適化の結果:", result)
print("最小値をとる点 x:", result.x[0])
print("最小値をとる点 y:", result.x[1])
print("最小値 f(x, y):", result.fun)

"""
最適化の結果:   message: CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL
  success: True
   status: 0
      fun: -12.999999999999925
        x: [-2.000e+00  3.000e+00]
      nit: 2
      jac: [-1.776e-07 -7.105e-07]
     nfev: 9
     njev: 3
 hess_inv: <2x2 LbfgsInvHessProduct with dtype=float64>
最小値をとる点 x: -2.0000000804663682
最小値をとる点 y: 2.9999997404723753
最小値 f(x, y): -12.999999999999925
"""

最適化したい引数が何変数なのか?ってのを一見どこでも指定してないように見えて不安になるのですが、これは初期値の配列の要素数から自動的に判断してくれています。

結果は result.x に入っているのでここから自分で取得します。

本当に基本的な使い方は以上になります。

少し発展的な内容なのですが、最適化のアルゴリズムは何種類もある中から選べますが、中には導関数を利用するものもあります。このminimizeに導関数を渡さない場合は、数値微分でいい感じにやってくれるのですが、明示的に導関数を指定することも可能です。

やり方は簡単で、元の関数と同じ形で引数を受け取るように導関数を定義して、jac引数に渡すだけです。1変数の場合は簡単すぎるので2変数の例を出しますが、2変数の場合は第1引数による微分と第2引数による微分の2つがあるので、その結果を配列で返す関数を定義してそれを渡します。

例えば、$f(x,y)=x^2+y^2+4x+6y+13$ みたいなのを考えてみましょう。$x, y$での微分はそれぞれ、$2x+4$, $2y+6$ですね。

# 目的関数
def objective_function(vars):
    """
    2変数関数 f(x, y) = x^2 + y^2 + 4x + 6y + 13
    vars: [x, y]
    """
    x, y = vars
    return x**2 + y**2 + 4*x + 6*y + 13

# 勾配(偏微分値)
def gradient_function(vars):
    """
    勾配 ∇f(x, y) = [∂f/∂x, ∂f/∂y]
    vars: [x, y]
    """
    x, y = vars
    grad_x = 2 * x + 4
    grad_y = 2 * y + 6
    return np.array([grad_x, grad_y])  # 勾配ベクトルを返す

# 初期値
x0 = [0, 0]

# 最適化の実行 (勾配を指定)
result = minimize(objective_function, x0, jac=gradient_function, method='L-BFGS-B')

# 結果の表示
print("最適化の結果:", result)
print("最小値をとる点 [x, y]:", result.x)
print("最小値 f(x, y):", result.fun)

"""
最適化の結果:   message: CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL
  success: True
   status: 0
      fun: 0.0
        x: [-2.000e+00 -3.000e+00]
      nit: 2
      jac: [ 0.000e+00  0.000e+00]
     nfev: 3
     njev: 3
 hess_inv: <2x2 LbfgsInvHessProduct with dtype=float64>
最小値をとる点 [x, y]: [-2. -3.]
最小値 f(x, y): 0.0
"""

関数が簡単で計算量も小さい場合は数値微分でも特に問題ないのですが、そうでない場合は明示的に導関数を渡すことでリソースを節約することもできます。

長くなってきたので今回の記事はここまでにしようと思います。境界条件とうの制約をつける方法とかを今後紹介したいですね。

最後にminimizeを使う時の注意点ですが、これ、極小値が複数あるような関数の場合、最小値ではなく極小値の一つを返してくることがあります。結果が初期値に依存してしまうのです。複雑な形の関数を最適化する場合は注意してください。

Pythonコードでimportに失敗するライブラリのバージョンを確認する

とある特殊な環境でPythonを書いていて、いくつかのライブラリがimportに失敗するという事態に遭遇しました。自分のローカルPC上であれば、pip freeze とかしてライブラリのバージョンを調べて原因を調査するのですが、その環境ではOSコマンドが打てず、同様の調査が不可能でした。importさえできれば {ライブラリ名}.__version__ みたいなプロパティから取得することもできたのですがimport自体が失敗するとあって調査に苦戦していました。

ところがどうやらimportを行わずにライブラリのメタデータにアクセスする方法がちゃんとあるようだったのでこの記事にまとめておきます。(正直、普通の環境であればpipで調べれれば済む話なので、ほとんどの人にとっては不要な知識だと思います。)

pkg_resources を使う方法

importせずにライブラリのバージョンを取得する方法の1個目は pkg_resources を使うものです。

次の例は、 arviz というライブラリのバージョンを調べたものです。

from pkg_resources import get_distribution


try:
    version = get_distribution("arviz").version
    print(f"version: {version}")
except pkg_resources.DistributionNotFound:
    print("ライブラリが見つかりません")

# version: 0.16.1

僕は上記の方法で一旦解決しました。

ただ、これはsetuptoolsに依存したライブラリなのですが、その公式ドキュメントを見ると、もう廃止されたからimportlib.metadataを使えと書いてあるのですよね。

ということで合わせてそちらを紹介します。

importlib を使う方法

importlib はPython3.8から標準ライブラリに含まれたライブラリです。

これはインストール済みのライブラリのメタデータにアクセスする機能を持っています。標準ライブラリになったわけなので、最近のPythonであればおそらくこちらを使うのが適切なんだと思います。

参考: importlib.metadata — パッケージメタデータへのアクセス — Python 3.13.0 ドキュメント

サンプルコードでもバージョンを取得していますね。それにならってやってみましょう。

import importlib.metadata


try:
    version = importlib.metadata.version("arviz")
    print(f"version: {version}")
except importlib.metadata.PackageNotFoundError:
    print("ライブラリが見つかりません")

# version: 0.16.1

同じような結果が得られました。

余談ですが、公式ドキュメントのサンプルコードでは、
from importlib.metadata import version
として versionという関数をインポートしています。僕も最初それに倣ってやったのですが、versionを結果の変数名で使いたいな、と思ったので上記のインポート方法にしました。
ただ、これはこれでイマイチな気もします。

pandasのwhereとmask

前回の記事で、np.whereという関数の紹介をしたのですが、pandasにも同名のwhereっていうメソッドがあるので紹介します。また、非常に似た挙動のmaskというメソッドもありますので合わせて書きます。

pandasのwhereとmaskはDataFrameやSeriesが持っているメソッドです。
参考:
pandas.DataFrame.where — pandas 2.2.3 documentation
pandas.DataFrame.mask — pandas 2.2.3 documentation

挙動はnumpyのwhereと似ている部分があり、条件に応じて要素を置き換えます。ただし、使い方が少しだけ異なっており、np.where()のようにnumpy自体が持っていた関数ではなくDataFrameやSeriesなどのメソッドなので元の値が利用される分、np.whereより引数が一つ少なくなります。

1個目の引数に条件、2個目の引数に置き換える値(省略すればNoneになります。また関数を指定することもできます。)を入れて使用します。

そしてこの条件の扱いがwhereとmaskで異なります。
whereは条件がFalseの場合に値を置き換えmaskは条件がTrueの場合に値を置き換えます。

例えば、0~9の値を並べたデータフレームで、3の倍数かどうかという条件で置き換え対象を負の数にするような書き方をすると、whereは3の倍数以外の数がマイナスになり、maskの方は3の倍数がマイナスになります。

import pandas as pd
import numpy as np


df = pd.DataFrame(np.array(range(10)).reshape(5, 2))
print(df)
"""
   0  1
0  0  1
1  2  3
2  4  5
3  6  7
4  8  9
"""

# 3の倍数が条件を満たすのでそのまま残り、それ以外がマイナスになる。
print(df.where(df%3==0, -df))
"""
   0  1
0  0 -1
1 -2  3
2 -4 -5
3  6 -7
4 -8  9
"""

# 3の倍数が条件を満たすのでマイナスになる。
print(df.mask(df%3==0, -df))
"""
   0  1
0  0  1
1  2 -3
2  4  5
3 -6  7
4  8 -9
"""

上記の例は、2個目の引数に元のデータと同じ形のデータフレームが渡されていますが、2個目の引数は定数を渡すこともできるし、関数を渡すこともできます。

例えば奇数を定数-1に置き換えたり、奇数を3倍して1足すようなメソッドは次のようになるでしょう。

print(df.mask(df%2==0, -1))
"""
   0  1
0 -1  1
1 -1  3
2 -1  5
3 -1  7
4 -1  9
"""

print(df.mask(df%2==1, lambda x: 3*x+1))
"""
   0   1
0  0   4
1  2  10
2  4  16
3  6  22
4  8  28
"""

「こういう値に対してこうしたい」っていう日本語の説明に対して直感的に書けるのはmaskの方ですね。fillna()の汎用版みたいなイメージで使いやすいです。

whereの方は、「こういう条件を満たす値はそのままでいいんだ、そうでは無いのを置き換えたい」っていうイメージでしょうか。

np.whereで効率的に値を出し分ける

今回もnumpyのテクニックの紹介です。np.whereというメソッドを解説します。

参考: numpy.where — NumPy v2.0 Manual

これは何かというと、第1引数にTrue/Falseで評価できるデータの配列を渡すとその評価に応じてTrueなら第2引数、Falseなら第3引数の値を返す、というものです。

第2, 第3引数に渡すのは第1引数に渡した配列と同じ長さ(多次元なら全て同じ)でも良いし、定数であったり、ブロードキャストすれば同じ形にできるものなら何でも良いです。

一番シンプルな例としては、条件を満たすかどうかでそれぞれ異なる定数を返すようなものでしょうか。

import numpy as np


scores = np.array([45, 85, 72, 50, 90])
results = np.where(scores >= 60, '合格', '不合格')
print(results)
# ['不合格' '合格' '合格' '不合格' '合格']

説明いらないと思いますが、60点以上なら合格、と判定するメソッドですね。

上記の例のように、事前にTrue/False の配列を作っておくのではなく、何かしらの条件式を代1引数に渡すような使い方になると思います。条件に応じて何かしらの演算を行いたい場合は、第2, 第3引数に計算式を入れて結果を渡すような形になります。例えば、偶数なら1/2, 奇数なら 3倍して1を足す、みたいな処理をするならこうです。

np.where(scores%2 == 0, scores/2, 3*scores+1)
# array([136., 256.,  36.,  25.,  45.])

1次元配列の場合は、内包表記でもほぼ同じことができるのでありがたみが薄いですが、np.whereは多次元配列で便利なことがあります。(単純に、内包表記の方が不便になるだけという見方もできますが。)

自分が最近使った例としては、欠損値がある行列Aと別の行列Bがあった時に、欠損値以外は元の行列Aの値、欠損してる部分はBの値、で埋めたいというものでした。

これが次のようにして簡単に行えます。

A = np.array(
        [[1, 2, 3,], [np.nan, 5, 6], [7, np.nan, 9]]
    )
B = np.array(
        [[11, 12, 13,], [14, 15, 16], [17, 18, 19]]
    )

print(np.where(np.isnan(A), B, A))
"""
[[ 1.  2.  3.]
 [14.  5.  6.]
 [ 7. 18.  9.]]
"""

コードがすっきり書けること以外にもベクトル処理が行えることによるパフォーマンス面のメリットなど、利点があるので機会があれば使ってみてください。

Nanを含むnumpy配列のデータを専用メソッドで手軽に集計する

numpyのちょっとしたテクニックの話です。僕は最近まで知らなかったのですが、numpyには np.nansum など nan + 集計関数名 という命名規則のメソッド群が用意されています。これの紹介をします。

前提として、 numpy配列の値を合計したり平均を取ったりする時、データ中にnanがあると結果もnanになります。pandasのSeriesの場合と挙動が違うのですね。例えば以下のような感じです。(Seriesと挙動違うよという話は以前どこかの記事で書いた覚えがあります)

import numpy as np
import pandas as pd


# nanを含むデータを作る
ary = np.array([1, 1, 2, 3, np.nan, 8])
print(ary)
# [ 1.  1.  2.  3. nan  8.]

# 合計するとnanになる
print(ary.sum())
# nan

# 平均も同様
print(ary.mean())
# nan

# Series はnanを無視してくれる
print(pd.Series(ary).sum())
# 15.0
print(pd.Series(ary).mean())
# 3.0

欠損値の存在に気づくきっかけになったりしてありがたいこともありますし、仕様としてどうあるべきかを考えたらnullの伝播が実装されているこの作りが正しいと思えるのですが、この挙動が不便なことが多いのも事実です。

僕はこういう時大体Seriesに変換してしまって集計していました。

ただ、実は numpyにもNanに対応したメソッドがちゃんとあり、それが冒頭に書いたnansumです。maxにはnanmax, stdにはnanstd のように多くのメソッドに対して実装されています。

dir()で探すと一覧額作成できます。

for m in dir(np):
    if m.startswith("nan"):  # メソッド名がnanで始まるか
        if m.replace("nan", "") in dir(np):  # nanの部分を除外した場合に同じ名前のメソッドがあるか
            print(m)
"""
nanargmax
nanargmin
nancumprod
nancumsum
nanmax
nanmean
nanmedian
nanmin
nanpercentile
nanprod
nanquantile
nanstd
nansum
nanvar
"""

これらを使うと、エラーが起きずにnanを無視して無視して残りの要素について集計してくれます。

print(np.nansum(ary))
# 15.0
print(np.nanmean(ary))
# 3

1次元配列の場合は内包表記での対応とか色々やり方もあるのですが多次元になってくると面倒だし集計のために補完するのも面倒なのでありがたいですね。使い方がnp.nansum(ary)であって、ary.nansum() では無いので注意してください。

もう一点、 np.nan ではなく、Noneを含めてるとこれは数値の欠損値では無いので相わらずエラーになります。ここも注意です。

ary2 = np.array([1, 1, 2, 3, None, 8])

try:
    np.nansum(ary2)

except Exception as e:
    print(e)
# unsupported operand type(s) for +: 'int' and 'NoneType'

Pythonでマルチプロセス処理

前回の記事がマルチスレッドだったので今回はマルチプロセスを紹介します。

Pythonにおけるマルチプロセスの1番のメリットはGILの制約を回避できることでしょうね。

ただ、先に書いておきますが、この記事で書いている方法はJupyter notebookのセルに直接書くと正常に動作せずエラーになることがあります。.pyファイルを作成してそこに記入して使うようにしましょう。

マルチプロセスを実装するには、最近はconcurrent.futuresのProcessPoolExecutorを使います。
参考: concurrent.futures — 並列タスク実行 — Python 3.12.6 ドキュメント

ドキュメントのサンプルコードを参考に動かしてみましょう!
例として取り上げられているのは素数判定ですね。Pythonで処理が完結するのですが、GIL制約のためマルチスレッドだと高速化の恩恵が受けられないものです。

from concurrent.futures import ProcessPoolExecutor
import math


PRIMES = [
    112272535095293,
    112582705942171,
    112272535095293,
    115280095190773,
    115797848077099,
    1099726899285419
    ]

def is_prime(n):
    print(f"整数 {n} を素数判定します")
    if n < 2:
        return False
    if n == 2:
        return True
    if n % 2 == 0:
        return False

    sqrt_n = int(math.floor(math.sqrt(n)))
    for i in range(3, sqrt_n + 1, 2):
        if n % i == 0:
            return False
    return True

def main():
    with ProcessPoolExecutor() as executor:
        for number, prime in zip(PRIMES, executor.map(is_prime, PRIMES)):
            print('%d は素数か: %s' % (number, prime))

if __name__ == '__main__':
    main()

# 以下実行結果
"""
整数 112272535095293 を素数判定します
整数 112582705942171 を素数判定します
整数 112272535095293 を素数判定します
整数 115280095190773 を素数判定します
整数 115797848077099 を素数判定します
整数 1099726899285419 を素数判定します
112272535095293 は素数か: True
112582705942171 は素数か: True
112272535095293 は素数か: True
115280095190773 は素数か: True
115797848077099 は素数か: True
1099726899285419 は素数か: False
"""

最初にそれぞれの値の素数判定が始まってる旨のメッセージが出てその後に結果が順番に出てきたので、並行して処理されているのが確認できました。

is_prime(n)が並行して実行している処理です。

ProcessPoolExecutor() でエクゼキューターを作成して、今回は submit()ではなく、mapで適用していますね。map()には第一引数で並列実行したい関数を渡し、次の引数でその関数に渡す引数のリストを渡します。

submit と map はどちらもProcessPoolExecutor や ThreadPoolExecutor の継承元の抽象クラスのExecutor に実装されているメソッドなので、実はマルチプロセスとマルチスレッドのどちらでも両方使うことができます。お好みの方で書いたらよさそうです。

細かい挙動は異なっていて、前回のsubmit()ではas_completed()を使って終わった順番に処理を取り出していましたが、map()を使う場合は、処理自体は並列して同時に行われて順不同で完了しますが、結果の取り出しは渡した引数の順番になります。

Pythonでマルチスレッド処理

とっくの昔に、threadingを使ったマルチスレッド処理について記事を書いていたつもりだったのに、まだ書いてないことに気づきました。(そして、マルチプロセスの処理についてもまだ書いてませんでした。)

それでは気づいたこのタイミングで記事にしようと思ったのですが、改めてドキュメントを見てみると、concurrent.futures というより高レベルなモジュールがあるとのことでしたので、こちらを利用したマルチスレッド処理について紹介します。

先に言っておきますが、PythonにはGIL (Global Interpreter Lock) という制約があって、マルチスレッドにしたとしても、Pythonインタープリタは一度に1つのスレッドしか実行できません。なので、Pythonで完結するプログラムはマルチスレッドしても高速化の恩恵はありません。では、いつマルチスレッドは使うのかというと、Python外部のリソース(ストレージとかOSの処理とかWebアクセスとか)の待ち時間が発生する場合になります。

前置きが長くなってきましたが、実際に、concurrent.futuresを使ったマルチスレッドの並列処理のサンプルコードを紹介します。concurrent.futures.ThreadPoolExecutor というのを使います。
参考: concurrent.futures.ThreadPoolExecutor

5つのサイトへのアクセスを並列でやってみましょう。

import concurrent.futures
import requests
import time


# 取得するURLのリスト
URLS = [
    'http://www.example.com',
    'http://www.python.org',
    'http://www.openai.com',
    'http://www.wikipedia.org',
    'http://www.github.com'
]


# URLからコンテンツを取得する関数
def fetch_url(url):
    print(f"実行開始: {url}")
    response = requests.get(url)
    print(f"実行完了: {url}")
    return url, response.status_code, len(response.content)


# マルチスレッドでURLを並列取得する
start_time = time.time()

with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
    # 各URLに対してfetch_url関数を並列実行
    futures = {executor.submit(fetch_url, url): url for url in URLS}

    for future in concurrent.futures.as_completed(futures):
        url = futures[future]
        try:
            url, status, content_length = future.result()
            print(f"URL: {url}, Status: {status}, length: {content_length}")
        except Exception as e:
            print(f"{url}でエラーが発生しました: {e}")

print(f"処理時間: {time.time() - start_time}秒")

# 以下結果
"""
実行開始: http://www.example.com
実行開始: http://www.python.org
実行開始: http://www.openai.com
実行開始: http://www.wikipedia.org
実行開始: http://www.github.com
実行完了: http://www.python.org
URL: http://www.python.org, Status: 200, length: 50928
実行完了: http://www.github.com
URL: http://www.github.com, Status: 200, length: 254186
実行完了: http://www.openai.com
URL: http://www.openai.com, Status: 403, length: 14186
実行完了: http://www.example.com
URL: http://www.example.com, Status: 200, length: 1256
実行完了: http://www.wikipedia.org
URL: http://www.wikipedia.org, Status: 200, length: 78458
処理時間: 0.49734020233154297秒
"""

ドキュメントのコードをもとにしていますが、fetch_url()メソッドの最初と最後にprit文を差し込んで5つのURLについて同時に処理が進んでいるのが分かるようにしました。開始と終了が異なる順番で結果がprintされていて、並列で動いてた感がありますね。

さて、上記コードの fetch_url() がマルチスレッドで実行されていた関数本体ですが、 肝心のThreadPoolExecutorはかなり使い方にクセがあります。

oncurrent.futures.ThreadPoolExecutor(max_workers=5) でエグゼキューターを作って、submit()や、as_completed()というメソッドを使っていますね。

submit() は実行キューへタスクを送信するメソッドです。

そして、もう一つ、oncurrent.futures.as_completed() というのを使っています。
こちらは、送信された非同期タスクが完了した順にFutureオブジェクトを返すジェネレータ関数です。これを使うことで、並列で動いていたメソッドが完了した順に、後続の処理を行うことができます。
上の例では、future.result() でメソッドの戻り値を受け取って、順次printしています。

使い所は慎重に選ばないと高速化等の効果は得られないですし、書き方にクセがあるので、慣れないと少々戸惑うのですが、ハードウェアアクセスの待ち時間が長い時や外部リソースへのアクセスを伴う処理の高速化では非常に役に立つものなので機会があったら使ってみてください。

Streamlitでアニメーション

今回はStreamlitでアニメーションを作成します。

といっても、やることは以前紹介したプレースホルダーの中身を順次更新し続けるだけ、という実装です。
参考: Streamlitのコンテナを使って動的にページを表示する

アニメーションさせるためには一つの枠を連続的に書き換えて画像を表示するので、st.empty() を使います。

とりあえず一個やってみましょう。画像の描写はmatplotlibを使ってみました。お試しなのでアニメーションの内容は線分をぐるぐる回すだけです。(両端を三角関数で実装します。)

import streamlit as st
import matplotlib.pyplot as plt
import numpy as np


# 描画エリアを設定
fig = plt.figure()
ax = fig.add_subplot(111)

ax.set_xlim(-1.2, 1.2)
ax.set_ylim(-1.2, 1.2)

# アニメを描写するプレースホルダーを作成
placeholder = st.empty()

# Streamlitのアニメーション表示
for i in range(100):
    ax.clear()
    ax.set_xlim(-1.2, 1.2)
    ax.set_ylim(-1.2, 1.2)
    ax.plot(
        [np.cos(i*0.1), -np.cos(i*0.1)],
        [np.sin(i*0.1), -np.sin(i*0.1)]
    )
    
    # プレースホルダーを更新
    placeholder.pyplot(fig)

これで線がぐるぐる回るアニメーションが表示できます。

あれ、time.sleep(0.01)とかウェイトを入れておかないとこのfor文が一瞬で終わってしまうんじゃないの?と思われるかもしれませんが、実験してみたところちょうど良い感じにアニメーションになりました。

どうもstreamlitの仕様として一枚一枚の画像の表示(pyplot)にウェイトがかかっているようです。

これは結構大きなメリットで、あまり表示時間とか気にせずにいい感じのアニメーションが作れます。

一方で、time.sleep(0.01) で0.01秒間隔の表示で1000フレーム使ってピッタリ10秒の動画を作ろう!みたいな調整は困難です。まぁ、これはstreamlitは動画作成を念頭に置いたものではないと思うので仕方ないですね。

ただし、デフォルトだと動作が早すぎるという場合はtime.sleep()を使ってウェイトを増やしましょう。

StreamlitでPyGWalkerを動かす

2連続のPyGWalker関係の話です。そして、相変わらずStreamlit関係の記事です。

発展著しい両ライブラリですが、StreamlitにPyGWalkerを埋め込んで動かすこともできます。そして、使用感としてはJupyterよりStreamlitに埋め込んだ方が使い勝手がいいですね。

使用するメソッドですが、streamlitがPyGWalkerを埋め込むメソッドを持っているわけではなく、PyGWalker側がStreamlit上で動作するメソッドを持っているので注意が必要です。Streamlit側のドキュメントを読み込んでも本機能についての記述は出てきません。(少なくとも今日時点では。)

こちらを読みます。
参考: PyGWalkerとStreamlitを使ったデータの探索と情報共有 – Kanaries

StreamlitRenderer というのを使えば良いのですね。そして、設定を保存するspec引数もあります。

1点、Jupyterで動かす場合との違いなのですがセキュリティ上の理由なのかわかりませんがデフォルトではspecで指定したjsonファイルへの書き込み、要するに保存ができません。すでにどこかで保存されたダッシュボードの読み込みだけが可能という挙動になります。

これは、spec_io_mode 引数がデフォルトで”r” (読み込みモード) になっているためです。Streamlit上で作ったビューをそのまま保存したい場合は、”rw” を指定する必要があります。

注意点はこれだけなので、早速やってみましょう。データは何でもいいのでまたワインです。

from pygwalker.api.streamlit import StreamlitRenderer
import pandas as pd
import streamlit as st
from sklearn.datasets import load_wine

st.set_page_config(layout="wide")

# データ読み込み
wine = load_wine()
# 特徴量名を列名としてDataFrame作成
df = pd.DataFrame(
    wine.data,
    columns=wine.feature_names,
)

# target列も作成する。
df["target"] = wine.target
df["class"] = df["target"].apply(lambda x: wine["target_names"][x])

pyg_app = StreamlitRenderer(
    df,
    spec="./st_config.json",
    spec_io_mode="rw"
)
pyg_app.explorer()

一番最後の、explorer()を忘れないように注意してくださいね。

これでStreamlit上でもTableau風のUIでグラフを描けるようになりました。

PyGWalkerのダッシュボード設定を保存する

Tableau public でローカルファイルセーブが実装されたのでやや存在感が薄れているのですが、TableauライクなダッシュボードをPythonで作れるPyGWalkerの記事2本目です。

前回書いたのがこのライブラリが登場した直後だったので、当時は今と比べるとまだ基本的な機能も揃っていなかったのですが、現時点では待望のダッシュボードの保存機能が実装されているのでその紹介です。

参考: PyGWalkerでデータフレームを可視化してみる

これ、使い方はすごく簡単で、walk メソッドで起動する時に、jsonファイルのパスをspec引数へ渡し、ダッシュボードを作ったら保存ボタンを押すだけです。(自動保存は今日時点ではサポートされていないらしい。

ReadMe にも記載がありますね。

前回の記事と同じようにワインのデータでやってみましょう。

import pandas as pd
from sklearn.datasets import load_wine
import pygwalker as pyg


# データ読み込み
wine = load_wine()
# 特徴量名を列名としてDataFrame作成
df = pd.DataFrame(
    wine.data,
    columns=wine.feature_names,
)

# target列も作成する。
df["target"] = wine.target
df["class"] = df["target"].apply(lambda x: wine["target_names"][x])

walker = pyg.walk(df, spec="./config.json"). # 設定の保存先をspecで指定 

こうすると、spec で指定したファイルが存在しなければ自動的に作成されます。そしてそこに設定が保存されます。起動時点で、specで指定したファイルが存在していたらそれが読み込まれて前回の続きから作業ができます。

繰り返しですが、「保存」そのものは自動ではやってくれないので、Saveのアイコンを確実に押しましょう。かなりわかりにくいですがこいつです。

Saveの文字はマウオーバーして出てきてくだけなので、その下の歯車付きテキストファイルのようなアイコンを探してください。