mplfinanceの株価チャートに指標を追加する

前回紹介した、mplfinanceの使い方の続編です。前回はただ単純に四本値でローソクチャートを書きましたが、今回はそれに各種テクニカル指標等を追加する方法を紹介します。
参考: mplfinanceで株価チャートを描く

免責事項
本記事では株式や為替などの金融商品の価格に関するデータをサンプルとして利用しますが、効果や正当性を保証するものではありません。本ブログを利用して損失を被った場合でも一切の責任を負いません。そもそも、この記事ではライブラリを使って図に線や点を追加する方法を紹介しているだけであり、著者自身がこの記事で登場している指標の投資における有用性を検証していませんし、投資にも利用していません。あくまでもこういうコードを書いたらこう動くという例の紹介です。

おきまりの文章が終わったのでやってきましょう。前回同様、こんな感じのデータがあるとします。

print("データ件数:", len(price_df))
# データ件数: 182

print(price_df.head(5))
"""
              open    high     low   close   volume
date                                               
2022-01-04  3330.0  3340.0  3295.0  3335.0  75300.0
2022-01-05  3355.0  3370.0  3320.0  3360.0  62200.0
2022-01-06  3345.0  3370.0  3305.0  3305.0  67200.0
2022-01-07  3315.0  3330.0  3265.0  3295.0  84500.0
2022-01-11  3300.0  3300.0  3215.0  3220.0  87400.0
"""

前回はこれをただそのまま表示しましたが、今回はそれに追加して以下の情報を追加で書いていこうと思います。追記する場所として、四本値のローソク足のパネル、出来高のグラフのパネル、新しいパネルを用意してそこに書き込む、の3種類ができるということ、何かしらの指標を線で表示することや特定の点にマーカーをプロットできる、ということを例示するために以下のような例を考えました。

  • 20日間の高値、安値の線 (四本値のパネル)
  • 高値、安値を更新した点のプロット(四本値のパネル)
  • 過去3日間の出来高の合計(出来高のパネル)
  • 前日の終値と当日の終値の差分(新規のパネル)

まずプロットするデータを作ります。データは四本値のデータと同じ長さで、インデックスの日付が共通のDataFrameかSeriesである必要があります。高値線や安値線などのテクニカル指標の場合は大丈夫雨だと思うのですが、1~2ヶ所点をプロットしたいだけ、といった場合であっても、点をプロットない場所は全部NaN値を入れる形で、同じ長さのデータを作らなければなりません。それだけ気をつければ特に詰まるところはないと思います。
まぁ、元のデータのDataFrameに列を追加する形で作っていけば確実でしょう。

高値安値とその更新した点は次のように作りました。更新点はちょっと上下にずれた位置にプロットしたかったので、それぞれ1.02倍/0.98倍しています。

import pandas as pd
import numpy as np


# 高値線、安値線
price_df["h_line"] = price_df.rolling(20, min_periods=1).max()["high"]
price_df["l_line"] = price_df.rolling(20, min_periods=1).min()["low"]

# 高値、安値を更新した日付を算出
h_breakout_flg = price_df["h_line"].shift(1) < price_df["high"]
l_breakout_flg = price_df["l_line"].shift(1) > price_df["low"]

# 更新日にプロットする値を用意する。
price_df["h_breakout"] = np.nan
price_df["l_breakout"] = np.nan
price_df.loc[h_breakout_flg, "h_breakout"] = price_df.loc[h_breakout_flg, "high"] * 1.02
price_df.loc[l_breakout_flg, "l_breakout"] = price_df.loc[l_breakout_flg, "low"] * 0.98

出来高の3日の和と終値の前日との差分はそれぞれ次のように作れます。

price_df["volume_sum"] = price_df.rolling(3)["volume"].sum()
price_df["close_diff"] = price_df["close"].diff()

これで、サンプルデータが出揃ったので可視化していきます。ドキュメントは前回同様Githubのサンプルコードを参照します。今回見るのはこれです。
参考: mplfinance/addplot.ipynb at master · matplotlib/mplfinance

ドキュメントは頼りないので必要に応じてソースコードもみましょう。

チャートに指標を追加するには、plotメソッドを呼び出すときに、addplot 引数に追加したい指標の情報をdictで渡します。複数追加したい時は追加したい指標の数だけのdictをlistにまとめて渡します。(今回やるのはこちら)。

addplot に渡す辞書というのは以下の例のような結構大掛かりな辞書です。

{'data': date
 2022-01-04       NaN
 2022-01-05    3437.4
 2022-01-06       NaN
 2022-01-07       NaN
 2022-01-11       NaN
                ...  
 2022-09-26       NaN
 2022-09-27       NaN
 2022-09-28       NaN
 2022-09-29       NaN
 2022-09-30       NaN
 Name: h_breakout, Length: 182, dtype: float64,
 'scatter': False,
 'type': 'scatter',
 'mav': None,
 'panel': 0,
 'marker': '^',
 'markersize': 18,
 'color': None,
 'linestyle': None,
 'linewidths': None,
 'edgecolors': None,
 'width': None,
 'bottom': 0,
 'alpha': 1,
 'secondary_y': 'auto',
 'y_on_right': None,
 'ylabel': None,
 'ylim': None,
 'title': None,
 'ax': None,
 'yscale': None,
 'stepwhere': 'pre',
 'marketcolors': None,
 'fill_between': None}

このdictデータを自分で作るのは大変です。そこで、mplfinance が専用のメソッド、make_addplot というのを持っているのでこれを使います。これを使って、追加する指標のデータと、どのグラフに書き込むのか(panel)可視化の方法(type, ‘line’, ‘bar’, ‘scatter’, ‘step’から選択)、マーカーや線のスタイル、大きや色、などの情報と合わせて渡すことでデータを作ってくれます。 (メインの四本値のデータよりそこに加筆する指標を先にライブラリに渡すのって微妙に直感的で無くて使いにくいですね。)

例えば、以下のようにすることで加筆データを生成できます。make_addplotに指定できる引数は上のサンプル辞書のkeyを見るのが早いと思います。だいたいイメージ通り動きます。

import mplfinance as mpf


adp = [
    # 高値安値線。 panel=0 と指定し、四本値と同じパネルを指定
    mpf.make_addplot(price_df[["h_line", "l_line"]], type='step',panel=0),
    # 高値更新位置。 scatter を指定し、markerで上向三角形を指定
    mpf.make_addplot(price_df["h_breakout"], type='scatter',panel=0, marker="^"),
    # 安値更新位置。 scatter を指定し、markerで下向三角形を指定
    mpf.make_addplot(price_df["l_breakout"], type='scatter',panel=0, marker="v"),
    # 出来高の和。 panel=1 とすると 出来高のパネルを指定したことになる。 line を指定して折れ線グラフに。
    mpf.make_addplot(price_df["volume_sum"], type='line',panel=1, linestyle="--", color="g"),
    # 終値の前日差分。棒グラフのサンプルも欲しかったのでbarを指定。panel=2とすると新規のパネルが追加される
    mpf.make_addplot(price_df.close.diff(), type='bar',panel=2),
]

さて、これで加筆データができました。 注意する点としてはpanel ですね。panel=0は四本値のグラフで固定ですが、panel=1は次のチャート描写のメソッドで、出来高の表示をするかどうか指定するvolume引数の値で挙動が変わります。Trueなら、出来高のグラフがpanel=1で、panel=2以降が新規のグラフです。一方Falseなら、panel=1から新規のグラフです。番号を飛ばすことができず、panel=1がないとpanel=2は指定できないので注意してください。

では、前回のグラフに追加して、上記のadpの値も渡してみます。今回、パネルが3個になりますが、panel_ratios でそれぞれの幅の調整ができます。メインの四本値のパネルを大きめにしておきました。

mpf.plot(
    price_df,
    volume=True,  # 出来高も表示
    mav=[10, 20],  # 移動平均線
    addplot=adp,  # 追加指標
    figratio=(3, 2),  # 図全体の縦横比
    panel_ratios=(2, 1, 1),  # パネルの縦幅の比率
)

出来上がった図形がこちらです。

いい感じですね。細かい話ですが、高値安値線はtypeをstepにしてカクカクした線にしていて、高値の方に引いた線はtypeをlineにして斜めにつながる線にしています。好みの問題ですがこういう微調整ができるのが良いですね。

最初はデータの準備等に戸惑ったり、思うような微調整に苦戦したりするかもしれませんが、一回作ってしまうと、あとは銘柄を入れ替えたり期間を変えたりしながらパラパラといろんな検証ができます。scatterで自分の仕掛けと手仕舞いのポイント等を入れて検証したりってことも可能ですね。

mplfinanceで株価チャートを描く

以前、Pythonでローソク足のチャートを描く方法を紹介しました。
参考: pythonでローソク足を描く

この時は、mpl_financeという既にメンテナンスされていないライブラリを使っていました。まぁ、個人的にちょっと使う分には問題ないのですがやはりちょっと不安になりますよね。

その一方で、実は株価チャートを描く別のライブラリがあることがわかりました。それが、
mplfinance です。 名前が非常に似ていますが、これはアンダーバー(_)がありません。
さっき見たらGithubに15時間前のコミットがあり、しっかりメンテナンス継続中のようです。

ドキュメントとなるのは以下の二つでしょうか。Githubのチュートリアルや、サンプルコードが比較的充実しているのでこちらが使いやすそうですね。
参考1 (Github): matplotlib/mplfinance: Financial Markets Data Visualization using Matplotlib
参考2(PyPI): mplfinance · PyPI

今回はこれのチュートリアルを見ながら一番基本的なローソク足と、出来高、あとついでに移動平均線でも書いてみましょう。

まずはデータの準備です。例えば次のような4本値データがあったとしましょう。データ件数とサンプルとして最初の10行表示しています。チャートを描くので、日付と、始値、高値、安値、終値が必須で、出来高がオプションです。

print("データ件数:", len(price_df))
# データ件数: 81
print(price_df.head(10))
"""
         date    open    high     low   close    volume
0  2022-04-01  3710.0  3725.0  3670.0  3715.0  113600.0
1  2022-04-04  3685.0  3730.0  3675.0  3725.0   94200.0
2  2022-04-05  3705.0  3725.0  3675.0  3675.0  113900.0
3  2022-04-06  3650.0  3680.0  3625.0  3635.0  163800.0
4  2022-04-07  3620.0  3660.0  3615.0  3630.0   98900.0
5  2022-04-08  3700.0  3730.0  3660.0  3700.0  135100.0
6  2022-04-11  3840.0  4060.0  3810.0  4045.0  494900.0
7  2022-04-12  4030.0  4210.0  4015.0  4120.0  403000.0
8  2022-04-13  4115.0  4165.0  4080.0  4155.0  184600.0
9  2022-04-14  4085.0  4120.0  4020.0  4065.0  240900.0
"""

これをライブラリが要求する形にする必要があります。まず、日付(date)がデータフレームのインデックスに指定されていないといけません。また、データ型も文字列ではダメで、DatetimeIndexである必要があります。その変形をします。

price_df["date"] = pd.to_datetime(price_df["date"])
price_df.set_index("date", inplace=True)

print(price_df.head(5))
"""
              open    high     low   close    volume
date                                                
2022-04-01  3710.0  3725.0  3670.0  3715.0  113600.0
2022-04-04  3685.0  3730.0  3675.0  3725.0   94200.0
2022-04-05  3705.0  3725.0  3675.0  3675.0  113900.0
2022-04-06  3650.0  3680.0  3625.0  3635.0  163800.0
2022-04-07  3620.0  3660.0  3615.0  3630.0   98900.0
"""

また、データの各列の名前は[“Open”, “High”, “Low”, “Close”, “Volume”] (“Volume”は無くても可)にするよう指定されています。日本語で”始値”とか入っている場合は変換しましょう。
先頭を大文字にしないといけないのかな?と思ったのですが、全部小文字でも動くことが確認できているので、今回はこのまま使います。

なんか、ソースコードを見ると多くの人が要求したから小文字にも対応してくれたそうです。
ここ

    # We will not be fully case-insensitive (since Pandas columns as NOT case-insensitive)
    # but because so many people have requested it, for the default column names we will
    # try both Capitalized and lower case:
    columns = config['columns']
    if columns is None:
        columns =  ('Open', 'High', 'Low', 'Close', 'Volume')
        if all([c.lower() in data for c in columns[0:4]]):
            columns =  ('open', 'high', 'low', 'close', 'volume')

さて、これでデータが揃いました。

ここからの使い方はとても簡単で、こだわりがなければライブラリをimportして、plotってメソッドに渡すだけです。

import mplfinance as mpf

mpf.plot(price_df)

いわゆるOHLCチャートが出力されました。

ここから少しカスタマイズしていきます。

まず、チャートをローソク足にするのは、type=”candle”です。”candle”の他にはデフォルトの”ohlc”や、”line”, “renko”, “pnf” が指定できます。それぞれイメージ通りの出力が得られるので興味がある方は試してみてください。

出来高の追加は volume=True を指定します。

また、移動平均線は、mav 引数で指定します。整数を一つ指定すればその日数の移動平均が1本、配列等で複数の整数を渡して数本引くこともできます。

チャート全体をもう少し横長にしたいなど、縦横比率を変えたい場合は、figratioで指定します。デフォルトは(8.00,5.75) です。どうも名前の通り、数値の比だけが重要のようです。(4, 2)と(2, 1)の結果が一緒でした。

以上、やってみます。

mpf.plot(
    price_df,
    type="candle",
    volume=True,
    mav=[5, 12],
    figratio=(2, 1),
)

出力がこちらです。

いい感じですね。

この他にも見栄えを整えるオプションは多数用意さてれいるのですが、いくつかの組み合わせがstyleとして用意されています。style=”yahoo”なんてのもあります。
こちらのページにサンプルがまとまっているので、気に入ったのがあったら使ってみるのも良いと思います。(僕は一旦上の、style=”default”でいいかな。)

今回はデータを渡すだけでサクッとチャートを描いてくれるmplfinanceの基本的な使い方を紹介しました。近いうちの記事で、これに別のテクニカル指標を表示したり線や点を追加するなどのカスタマイズ方法を紹介していきたいと思います。

Jupyter notebookのファイルをコマンドラインで実行する

Jupyter notebookのファイル (.ipynbファイル)をそのまま実行したい、って場面は結構あります。notebookファイルから通常のPythonファイル(.pyファイル)に変換しておけばいいじゃないか、という意見もあると思いますし、それはそれでごもっともです。ただ、僕個人の事例で言うと、個人的に開発してるツールの中に土日に触る時はちょっとずつ編集して改良して実行し、平日はそのまま全セルを実行するだけってnotebookファイルなどもあります。そのようなファイルについて、逐一上から順番にnotebookのセルを実行していくのはやや面倒です。

と言うことで、.ipynbファイルをコマンドラインからバッチのように実行できると便利、ってことでその方法を紹介していきます。

Google等で検索するとよく出てくる方法と、もう一つ、ドキュメントを読んでいて見つけた方法があるのでそれぞれ紹介します。後者の方法の方が手軽なので、まずそちらを書きます。

jupyter execute コマンドを使う方法

一つ目に紹介する方法は、jupyter execute コマンドです。
ドキュメントはこちら。
参考: Executing notebooks — nbclient – Using a command-line interface

これはすごく簡単で、以下のコマンドで実行するだけです。

$ jupyter execute {ファイル名}.ipynb
# 以下出力
[NbClientApp] Executing {ファイル名}.ipynb
[NbClientApp] Executing notebook with kernel: python3

コマンド名は直感的でわかりやすくて記述量も少なくて僕は気に入っています。

ただし、注意点があってこの方法でnotebookを実行しても元のnotebookファイルは更新されません。つまりどう言うことかと言うと、notebook内の出力領域に表示されるはずの情報は残らないと言うことです。printしたテキストとか、matplotlib等で表示した画像などは見れず、ただプログラムが走るだけと言う状態になります。

そのため、この方法でnotebookを実行する場合は必要な出力はnotebookの外部に保存するように作っておく必要があります。必要な結果はファイルに書き出すとかDBに保存するような実装にしておきましょう。

次に紹介する方法(ググるとよく出てくる方法)では、実行結果の出力を残せるので、このexexuteコマンドでも何かオプションを指定したら実行結果を残せるだろうと思って探してたんですが、どうも今日時点ではそのような機能は実装されていなさそうです。今後に期待したいところです。

全体的にオプションも少なく、その中でも実際使えるものというと実質的に次の二つだけかなと思います。

# $ jupyter execute --help の出力結果から抜粋
--allow-errors
    Errors are ignored and execution is continued until the end of the notebook.
    Equivalent to: [--NbClientApp.allow_errors=True]
--timeout=<Int>
    The time to wait (in seconds) for output from executions. If a cell
    execution takes longer, a TimeoutError is raised. ``-1`` will disable the
    timeout.
    Default: None
    Equivalent to: [--NbClientApp.timeout]

–allow-errors をつけると、エラーが発生してもそれ以降のセルも実行されるようになります。これをつけてない場合は、エラーになったセルがあればそれ以降のセルは実行されません。
試してみたのですが、–allow-errorsをつけていると、エラーになったセルがあってもそのエラー文等は表示されないので、リスクを伴うオプションだと思います。エラーになったらその旨を外部のログに残す実装になっていないと自分で気づく手段がありません。なお、–allow-errorsをつけてない場合、エラーになるセルがあったらそこで標準エラー出力にエラーを表示して止まるので気付けます。

–timeout の方はデフォルトでタイムアウト無しになっているのであまり気にしなくても良いかと思うのですが、異常に長く時間がかかるリスクがある場合などは設定しても良いでしょう。

jupyter nbconvert コマンドを使う方法

次に紹介するのは、 jupyter nbconvert コマンドを使う方法です。jupyter notebookをコマンドライン(CUI)で使う方法として検索するとよく出てくるのはこちらの方法です。

nbconvert 自体は、notebookを実行するコマンドじゃなくて、別の形式に変換するコマンドなので、正直これをnotebookの実行に使うのって抵抗あるのですが、どういうわけかこちらの方がいろんなオプションが充実していて、実行専用と思われる先ほどの jupyter execute コマンドよりも柔軟な設定が必要です。詳細は不明ですが歴史的な経緯か何かによるものでしょうか。

ドキュメントはこちら
参考: Executing notebooks — nbconvert 7.1.0.dev0 documentation

基本的な使い方は次のようになります。–to でファイルの変換先のタイプを指定するのですが、そこでnotebookを指定して、さらに–execute をつけると実行されます。

$ jupyter nbconvert --to notebook --execute {ファイル名}.ipynb
# 以下出力
[NbConvertApp] Converting notebook {ファイル名}.ipynb to notebook
[NbConvertApp] Writing {ファイルサイズ} bytes to {ファイル名}.nbconvert.ipynb

上記の出力をみていただくと分かる通り、実行した結果を、{ファイル名}.nbconvert.ipynb という新しいファイルに書き出してくれています。これの内容がセルを(空のセルを飛ばしながら)上から順番に実行した結果になっていて、こちらの方法であればnotebookの出力領域にprintした文字列やmatplotlibの画像なども残すことができます。

細かいオプションについては、 jupyter nbconvert –help で確認可能ですが、 先ほども書きましたがexecuteよりもたくさんあります。

–allow-errors は同じように指定できますし、 –output {ファイル名} で、書き込み先のファイル名を変更することも可能です。
ちなみにデフォルトだと、上記の実行例の通り{ファイル名}.nbconvert.ipynbに書き込みますが、既に同名のファイルが存在した場合は上書きしてしまいます。そのため、毎回の実行履歴を残しておきたいならば出来上がったファイルを退避しておくか、–outputオプションで別の名前をつける必要があるでしょう。
–inplace をつけて、別ファイルに書き出すのではなくて、元のファイルを置き換えるなども可能です。この辺の細かい調整を行えるのがnbconvertの方を使える利点ですね。executeの方にも実装していただきたいものです。

まとめ

以上で、jupyter notebookファイルをコマンドラインで実行する方法を二つ紹介してきました。それぞれメリットデメリットあるので用途に応じて便利な方を使っていただけたらと思います。

Pythonで線形和割り当て問題を解く

昔、あるアルゴリズムを実装する中で使ったことがある、 linear_sum_assignment っていうscipyのメソッドを久々に使うと思ったら使い方を忘れていたのでその復習を兼ねた記事です。

これは、2部グラフの最小重みマッチングとも呼ばれている問題で、要するに、二つのグループの要素からそれぞれ1個ずつ選んだペアにコストが定義されていて、どのように組み合わせてペアを選んでいったらコストの和を最小にできるかという問題です。

この説明はわかりにくいですね。もう少し具体的なのがいいと思うので、Scipyのドキュメントで使われている例を使いましょう。

Scipyのドキュメントではworker(作業者)とjob(仕事)を例に解説されています。
参考: scipy.optimize.linear_sum_assignment — SciPy v1.9.1 マニュアル

例えば、4人の作業者がいて4つの仕事があったとします。そして、その4人がそれぞれの仕事をした場合に、かかる時間(=コスト)が次のように与えられていたとします。行列形式ですが、i行j列の値が、作業者iが仕事jを実行した場合にかかるコストです。(例を乱数で作りました。)

import numpy as np


np.random.seed(0)
cost = np.random.randint(1, 10, size=(4, 4))
print(cost)

"""
[[6 1 4 4]
 [8 4 6 3]
 [5 8 7 9]
 [9 2 7 8]]
"""

cost[1, 2] = 6 ですが、これは作業者1が仕事2を行った場合のコストが6ということです。
(インデックスが0始まりであることに注意してください。cost[1, 2]は2行3列目の要素です。)

さて、上の図を見ての通り、作業者ごとに仕事の得手不得手があり、コストが違うようです。そこで、これらの仕事をそれぞれ誰が担当したらコストの総和を最小にできるでしょうか、というのが線形和割り当て問題です。

これが、先ほどの linear_sum_assignment を使うと一発で解けます。

ドキュメントにある通り、戻り値が行のインデックス、列のインデックスと帰ってくるので注意してください。

from scipy.optimize import linear_sum_assignment


row_ind, col_ind = linear_sum_assignment(cost)
print("行:", row_ind)
print("列:", col_ind)
"""
行: [0 1 2 3]
列: [2 3 0 1]
"""

二つのarray(プリントしてるのでlistに見えますがnumpyのArrayです)が戻ってきます。
これが、worker0がjob2を担当し、worker1がjob3を担当し、、、と読んでいきます。
これがコストを最小にする組み合わせです。簡単でしたね。

さて、値の戻ってき方がちょっと独特だったのでプログラムでこれを使うにはコツが要ります。こう使うと便利だよ、ってところまでドキュメントに書いてあると嬉しいのですが、書いてないので自分で考えないといけません。

インデックスとして返ってきているので、次のようにコスト行列のインデックスにこの値を入れると、最適化された組み合わせのコストが得られます。そして、sum()すると合計が得られます。以下の通り、14が最小ということがわかります。

print(cost[row_ind, col_ind].sum())
# 14

Scipyの実装を疑うわけではないのですが、念の為、本当にこの組み合わせが最適で14が最小なのか、全組み合わせ見ておきましょう。itertools.permutationsを使います。

from itertools import permutations


for perm in permutations(range(4)):
    print(list(perm), "=>", cost[range(4), perm].sum())
"""
[0, 1, 2, 3] => 25
[0, 1, 3, 2] => 26
[0, 2, 1, 3] => 28
[0, 2, 3, 1] => 23
[0, 3, 1, 2] => 24
[0, 3, 2, 1] => 18
[1, 0, 2, 3] => 24
[1, 0, 3, 2] => 25
[1, 2, 0, 3] => 20
[1, 2, 3, 0] => 25
[1, 3, 0, 2] => 16
[1, 3, 2, 0] => 20
[2, 0, 1, 3] => 28
[2, 0, 3, 1] => 23
[2, 1, 0, 3] => 21
[2, 1, 3, 0] => 26
[2, 3, 0, 1] => 14
[2, 3, 1, 0] => 24
[3, 0, 1, 2] => 27
[3, 0, 2, 1] => 21
[3, 1, 0, 2] => 20
[3, 1, 2, 0] => 24
[3, 2, 0, 1] => 17
[3, 2, 1, 0] => 27
"""

どうやらあってそうですね。

col_ind の方を使って、行列を並び替えることもできます。i行目のworkerがi列目のjobを担当する直感的に見やすい行列が次のようにして得られます。

print(cost[:, col_ind])
"""
[[4 4 6 1]
 [6 3 8 4]
 [7 9 5 8]
 [7 8 9 2]]
"""

また、解きたい問題や実装によっては、この行と列の対応を辞書にしたほうが使いやすいこともあるでしょう。そのような時はdictとzipで変換します。

print(dict(zip(row_ind, col_ind)))
# {0: 2, 1: 3, 2: 0, 3: 1}

ここまでの例では、与えられた行列はコストの行列でこれを最小化したい、という問題設定でやってきました。ただ、場合によっては利益やスコアの行列が与えられて、最大化する組み合わせを探したいという場合もあると思います。行列にマイナス掛けて同じことすればいいのですが、linear_sum_assignment自体にもそれに対応した引数があります。

それが、maximize で、 デフォルトはFalseですが、Trueにすると最大化を目指すようになります。同じ行列でやってみます。さっき全パターン列挙しているので正解はわかっていて、[0, 2, 1, 3]か[2, 0, 1, 3]のどちらかが得られるはずです。

print(linear_sum_assignment(cost, maximize=True))
# (array([0, 1, 2, 3]), array([0, 2, 1, 3]))

[0, 2, 1, 3]の方が出ててきましたね。

ここまで、正方行列を取り上げてきましたが、linear_sum_assignment は、一般行列についても実行できます。行と列の数が違う場合は、行と列のうち数が小さい方に揃えて、実行されます。

まず、行が多い(workerが多い)場合をやってみましょう。7行4列で、7人のworkerがいて、jobが4つあって、コストがそれぞれ定義されていた場合に、どの4人を選抜してそれぞれにどの4つのタスクをやってもらうのが最適か、という問題を解くのと対応します。

np.random.seed(0)
cost = np.random.randint(1, 10, size=(7, 4))
print(cost)
"""
[[6 1 4 4]
 [8 4 6 3]
 [5 8 7 9]
 [9 2 7 8]
 [8 9 2 6]
 [9 5 4 1]
 [4 6 1 3]]
"""

row_ind, col_ind = linear_sum_assignment(cost)
print("行:", row_ind)
print("列:", col_ind)
"""
行: [0 2 5 6]
列: [1 0 3 2]
"""

次に同様に横長の行列の場合です。例えば4人のworkerがいて7つのjobがあったときに、どの4つのjobを選んで実行したら利益を最大化できるか、って問題がこれに相当します。(最小化でいい例が思いつかなかったのでこれは最大化でやります。)

np.random.seed(0)
score = np.random.randint(1, 10, size=(4, 7))
print(score)
"""
[[6 1 4 4 8 4 6]
 [3 5 8 7 9 9 2]
 [7 8 8 9 2 6 9]
 [5 4 1 4 6 1 3]]
"""

row_ind, col_ind = linear_sum_assignment(score, maximize=True)
print("行:", row_ind)
print("列:", col_ind)
"""
行: [0 1 2 3]
列: [4 5 3 0]
"""

以上が linear_sum_assignment の基本的な使い方になります。

Pythonでファイルの更新時刻やファイルサイズの情報を取得する

パソコン(ここではMacを想定)内のファイルを整理していて、古いファイルなどをリストアップしようとしたときのメモです。
更新時刻を取得するのはBashコマンドでもできますしファインダーでも見れて、僕も普段はそうしているのですが、一旦気合入れて整理しようと思ったときにこれらの方法がやや使いにくかったのでPythonでやることを検討しました。

結論から言うと、Pythonのosモジュールを使うと実装できます。
os.stat ってのがファイルの情報を取得する関数で、結果はstat_result というオブジェクトで帰ってきます。

ドキュメントはこちら。
参考: os — 雑多なオペレーティングシステムインターフェース — Python 3.10.6 ドキュメント

サンプルとしてこんなファイルがあったとしましょう。

$ ls -la sample.txt
-rw-r--r--  1 {user} {group}  7  9  5 01:01 sample.txt

これの情報を取得するには次のようにします。

import os


file_path = "./sample.txt"
file_info = os.stat(file_path)
print(file_info)
"""
os.stat_result(st_mode=33188, st_ino=10433402, st_dev=16777220, st_nlink=1,
st_uid=501, st_gid=20, st_size=7,
st_atime=1662307286, st_mtime=1662307285, st_ctime=1662307285)
"""

st_atimeが最終アクセス時刻、st_mtimeが最終更新時刻です。
printすると出てきませんが、st_birthtimeなんてのもあってこれがファイルの作成時刻です。

これらの値は普通に属性なので、.(ドット)で繋いでアクセスできます。

注意しないといけないのは、実行しているOSによって取得できる値に違いがあり、取得できなかったり取得できるけど意味が違ったりするものがあることです。

詳しくはドキュメントに書いてあります。
class os.stat_result

st_ctime はUNIXではメタデータの最終更新時刻で、Windows では作成時刻、単位は秒など色々違いますね。
なんとなく使わずにきちんと動作を確認して使うことが重要でしょう。

また、元々の目的が更新時刻の取得だったのですが、ついでにst_size でファイルサイズも取得できています。
上の例で見ていただくと、 st_size=7 となっていて、その上のlsの結果と一致します。

さて、以上でファイルの更新時刻やサイズが取得できたのですが、更新時刻(を含む事故国関係の情報一式)はUNIX時間で得られます。
人間にとって使いにくいので、以前紹介した方法で変換しましょう。
参考: Pythonで時刻をUNIX時間に変換する方法やPandasのデータを使う時の注意点

from datetime import datetime


# ファイル作成時刻
print(datetime.fromtimestamp(file_info.st_birthtime))
# 2022-09-05 01:01:13.879805

# 最終内容更新時刻
print(datetime.fromtimestamp(file_info.st_mtime))
# 2022-09-05 01:01:25.663676

# 最終アクセス時刻
print(datetime.fromtimestamp(file_info.st_atime))
# 2022-09-05 01:01:26.286258

非常に簡単ですね。
あとは globか何かでファイルパスの一覧を作成してDataFrame化して、applyでさっと処理して仕舞えば少々ファイルが多くてもすぐリスト化できそうです。

Pythonのdataclassを使ってみた

Pythonの標準ライブラリにdataclassというのがあるの見つけたので使ってみました。
参考: dataclasses — データクラス — Python 3.10.6 ドキュメント

名前から、オリジナルのデータ型を定義するためのモジュールなのかなとも思ったのですが実際は少し違いそうです。もちろん、オリジナルのデータ型を定義するためにも使えるのですが、その実態は、クラスに対して__init__()__repr__()といった特殊メソッドを自動的に生成してくれるデコレーターという解釈が正確のようです。

お試しに、証券コードと会社名と説明を属性として持ったCompanyクラスを作ってみましょう。

import dataclasses


@dataclasses.dataclass
class Company:
    code: int
    name: str
    description: str = None


# __init__() メソッドが自動的に定義されているためこれでインスタンスを作成できる
toyota = Company(code=7203, name="トヨタ自動車", description="自動車メーカー")
suzuki = Company(code=7269, name="スズキ", description="自動車メーカー")


# __eq__() メソッドが自動的に定義されているため、比較ができる
toyota == suzuki
# False

これは属性がたった3個だけで、メソッドも持ってないようなクラスなのですが、__init__()が入らないってだけでものすごくシンプルに描けるようになりましたね。

また、比較用のメソッドを自分で作らなくても、各属性が全て一致しているかどうかを基準に一致不一致を判定してくれるのも便利です。属性が3個だけだとそこまでありがたみがないですが、もっと大規模なクラスで、全属性の一致を判定するのは無駄にコードが長くなりますから。

int とか str と書いて型ヒントをつけられたりするのも今風な感じがします。ただ、この型ヒントはどうやらただのアノテーションで、代入する値に対する強制力などはないようです。
codeを整数、nameを文字列としていますがそうでない値も入ります。

dummy_company = Company(code="文字列", name=1234)
print(dummy_company)
# Company(code='文字列', name=1234, description=None)

この記事の冒頭で書いていますが、このdataclassはclassに対するデコレーターなので、ただのオリジナルデータ型ではなく、普通にメソッド等を持っているクラスを作成することもできます。その場合__init__()などを自分で書かなくて良くなるので、特に凝った__init__()が不要な場合はバンバン使って良さそうです。例えば、二つの値を持ち、合計値を返せるクラスは次のようになります。

@dataclasses.dataclass
class two_number:
    a: int
    b: int

    def sum(self):
        return self.a + self.b


tn = two_number(5, 8)
print(tn.sum())
# 13

自分が実装したいメソッドの部分に専念できるのはいいですね。

このdataclassのデコレーターですが、デコレーター自体も引数を取って、色々設定することができます。詳細はドキュメントに譲りますが、例えばinit=Falseやrepr=False, eq=Falseを指定すると、デフォルトで生成されると言っていた__init__()や__repr__()、__eq__()などが生成されなくなります。自分で実装したいものがあったらそれだけ自分で実装するようにしましょう。

frozen (デフォルトはFalse) を Trueに指定すると、値への代入が禁止されます。これのメリットしては辞書のキーとして使えるようになることでしょうか。ちょっとやってみます。
1個目の例は上で作ったCompanyなので、frozenはFalseです。その次がfrozen=True。

# frozen = False だと属性に値を代入できる
toyota.description  = "日本の自動車メーカー"
print(toyota)


# frozen=Trueを指定してみる
@dataclasses.dataclass(frozen=True)
class Frozen_Company:
    code: int
    name: str
    description: str = None


f_toyota = Frozen_Company(code=7203, name="トヨタ自動車", description="自動車メーカー")

# frozen = True だと属性に値を代入できないため、例外が上がる
try:
    f_toyota.description  = "日本の自動車メーカー"
except Exception as e:
    print(type(e), ":", e)
#  <class 'dataclasses.FrozenInstanceError'> : cannot assign to field 'description'

frozenにするメリットとしては、タプルと同様に辞書のキーにできる、という点があります。

# frozenではない、つまりハッシュ化不可能なので辞書のキーにできない
try:
    {toyota: 1}
except Exception as e:
    print(type(e), ":", e)
# <class 'TypeError'> : unhashable type: 'Company'

# frozen=Trueだとハッシュ化可能なので辞書のキーにできる
{f_toyota: 1}
# {Frozen_Company(code=7203, name='トヨタ自動車', description='自動車メーカー'): 1}

また、order という引数(デフォルトFalse)にTrueを渡すと、__le__()等々の不等号を実装する特殊メソッドたち4種も自動的に生成してくれるようになります。どうも要素を順番に比較して最初に上下がついたもので決まるようです。これも属性が多い時は便利なのではないでしょうか。

@dataclasses.dataclass(order=True)
class two_number:
    a: int
    b: int

tn1 = two_number(12, 5)
tn2 = two_number(6, 20)
tn1 > tn2
# True

さて、以上でdataclass自体の基本的な説明はおしまいです。

あとは偶然気づいた豆知識なのですが、Pandasとの連携について紹介します。
dataclassの配列は、簡単にPandasのDataFrameに変換できます。実質的にdictみたいに振る舞ってくれるようです。

import pandas as pd


df = pd.DataFrame([toyota, suzuki])
print(df)
"""
   code    name description
0  7203  トヨタ自動車  日本の自動車メーカー
1  7269     スズキ     自動車メーカー
"""

便利ですね。

逆に、DataFrameに入った値たちをdataclassで定義したクラスのインスタンスに変換したいな、と思って方法探しました。専用のメソッドなどは見つかってないのですが、lamda関数を使ってこのようにするのが良いでしょう。

df.apply(lambda row: Company(**row), axis=1)
"""
0    Company(code=7203, name='トヨタ自動車', description=...
1    Company(code=7269, name='スズキ', description='自動...
dtype: object
"""

事前にDataFrameの列名と、dataclassの属性名を揃えておく必要はあるのでそこは注意してください。

PythonでUUIDを生成する

ある作業をやっているときに、データの塊ごとにユニークなidを振りたいことがありました。
通常は0から順番に番号振っていけば十分なのですが、今回は分散処理していて個別の処理でバッティングしないようにidを振りたかったのです。これでも番号のレンジを分けておけば十分なのですが、世の中にはUUIDって仕組みがあるのでこれを試すことにしました。

UUIDの詳細はWikipediaをご参照ください。要するにユニークなIDを発行する仕組みです。
参考: UUID – Wikipedia

UUIDってバージョン1〜5があって仕組みが違っていたんですね。自分はMACアドレスと時刻を使ってIDを生成する方法だと認識していたのですが、それはバージョン1のことだったようです。

PythonでUUIDを利用したい場合は、uuidという標準ライブラリが使えます。
参考: uuid — UUID objects according to RFC 4122

バージョン1, 3, 4, 5 の4種類のUUIDが実装されているようです。
とりあえず動かしてみますか。バージョン3と5は名前空間と名前に対してIDを振り分けるので、とりあえずこのブログのドメイン名を渡しています。

import uuid


uuid.uuid1()
# UUID('9f15c3a4-2173-11ed-bb9d-dca9048ad673')

uuid.uuid3(uuid.NAMESPACE_DNS, "analytics-note.xyz")
# UUID('30b2104c-d522-3f77-9fe8-6863bd4a6cda')

uuid.uuid4()
# UUID('8e489abb-a18a-4f0c-80d1-30ae0ea83d81')

uuid.uuid5(uuid.NAMESPACE_DNS, "analytics-note.xyz")
# UUID('feeaf7d1-3987-5451-9a28-afd72801a03b')

それぞれ、UUIDが生成できましたね。ハイフンで区切られた3つめの塊の1文字目の数字がバージョン番号なのですが、ちゃんと1,3,4,5となっています。

uuid3とuuid5 は渡した名前に対してIDを割り当てているので、引数が同じなら結果はずっと同じです。
uuid1は時刻を、uuid4は乱数を用いているので実行するたびに生成されるIDが変わります。

この記事冒頭の目的にではuuid4を使えば良いでしょう。

上記の結果を見ても分かる通り、結果はUUIDというクラスのオブジェクトとして返ってきます。一応typeを見ておきましょう。

sample_id = uuid.uuid4()
print(type(sample_id))
# <class 'uuid.UUID'>

str で文字型にキャストすることもできますし、プロパティのhexやintで16進法表示や整数表示も得られます。(なぜstrはプロパティではないのか不思議です)

print(str(sample_id))
# 61d67c09-4a78-4d03-a6f0-8f1843673083

print(sample_id.hex)
# 61d67c094a784d03a6f08f1843673083

print(sample_id.int)
# 130048782873754933734408394572189610115

文字列からUUIDオブジェクトを作ることもできます。ただ、これはいつ使うものなのか不明です。

uuid.UUID("61d67c09-4a78-4d03-a6f0-8f1843673083")
# UUID('61d67c09-4a78-4d03-a6f0-8f1843673083')

あとは気になるのは処理時間ですね。ID作成に時間がかかるようだと困るので。
ただ、ちょっと実験したところそこまで大きな問題もなさそうでした。
100万回実行するのに3秒未満で完了しています。

%%time
for i in range(1000000):
    uuid.uuid4()

"""
CPU times: user 2.37 s, sys: 407 ms, total: 2.77 s
Wall time: 2.79 s
"""

ipysankeywidgetでサンキーダイアグラム

詳細はWikipediaに譲りますが、サンキーダイアグラムっていうデータの可視化手法があります。
参考: サンキー ダイアグラム – Wikipedia

Webサービスにおけるユーザーの動線可視化とか、コンバージョンに至るまでの途中の離脱状況とか分析するのに便利なのですが、適したツールがなかなか見つからずあまり利用することがありませんでした。Tableauでやるには非常に面倒な手順を踏む必要がありましたし、matplotlibで作ると見た目がカッコ悪いものになりがちです。

何か良いツールはないかと思っていて、JavaScriptのライブラリなども含めて調べていたのですが、jupyterで使える ipysankeywidget というのが良さそうだったので紹介します。

まず、導入方法はこちらです。pip だけではなく jupyter notebook / jupyter lab それぞれ対応したコマンドが必要なので注意してください。
参考: https://github.com/ricklupton/ipysankeywidget

# pip インストールの場合、以下のコマンドでイストールした後に以降の設定コマンド実行
$ pip install ipysankeywidget

# notebookの場合は次の2つ
$ jupyter nbextension enable --py --sys-prefix ipysankeywidget
$ jupyter nbextension enable --py --sys-prefix widgetsnbextension

# lab の場合は次の1つ
$ jupyter labextension install jupyter-sankey-widget @jupyter-widgets/jupyterlab-manager

# condaの場合はインストールだけで設定まで完了する。
$ conda install -c conda-forge ipysankeywidget 

ドキュメントは、d3-sankey-diagram のドキュメントを見ろって書いてありますね。これはJavaScriptのライブラリです。 ipysankeywidget はそれをラップしたもののようです。

このライブラリ自体のドキュメントがかなり貧弱ですが、サンプルのnoteboookが4つ用意されています。こちらを一通り試しておけば細かい設定のを除いて問題なく使えるようになるでしょう。より詳しくはd3のドキュメントを見に行ったほうがよさそうです。
参考: ipysankeywidget/examples/README.md

基本的な使い方は、 どこから(source)、どこへ(target)、どのくらいの量の(value)流れを描写するかのdictの配列を用意し、それを渡すだけです。auto_save_png というメソッドを使うと画像に書き出すこともできます。

from ipysankeywidget import SankeyWidget


links = [
    {'source': 'A', 'target': 'B', 'value': 2},
    {'source': 'B', 'target': 'C', 'value': 2},
    {'source': 'D', 'target': 'B', 'value': 2},
    {'source': 'B', 'target': 'D', 'value': 2},
]

sanky = SankeyWidget(links=links)
sanky  # jupyter notebook上に表示される
# sanky.auto_save_png("simple-sanky.png")  # ファイルに保存する場合

出力がこちらです。

簡単でしたね。

あとは order 引数でnodeの順番を指定したり、groupでまとめたり、 linkの各dictにtype属性を付与して色を塗り分けたりと、細かい補正が色々できます。

結構面白いので是非試してみてください。

tqdmを使ってプログレスバーを表示する

以前、こんな記事を書きました。
参考: printでお手軽プログレスバー

そして、自分ではどこかでライブラリを使ってプログレスバーを表示する普通の方法も紹介済みだったつもりだったのですが、探すとまだ書いてなかったので書いておきます。

ipywidgets にも実はプログレスバー用のウィジェット(IntProgress/ FloatProgress)が存在し、これを使うこともできますが、プログレスバーには専用のライブラリも存在し、そちらの方が手軽に使えるので今回はそれを紹介します。

それがタイトルに書いてるtqdmです。
ドキュメント: tqdm documentation

使い方は簡単で、for文等で逐次処理するリストやイテレーターを tqdm.tqdm でラップするだけです。 tqdmを2回書くの嫌なので「from tqdm import tqdm」とインポートすると良いでしょう。
time.sleepをダミー処理としてやってみましょう。

from tqdm import tqdm
import time


for i in tqdm(range(1000)):
    time.sleep(0.01)

# 以下が出力されるプログレスバー。
100%|██████████| 1000/1000 [00:10<00:00, 97.60it/s]

内包表記でも使えます。

square_numbers = [i**2 for i in tqdm(range(100))]
# 以下出力
100%|██████████| 100/100 [00:00<00:00, 242445.32it/s]

単純にリストを周回させるだけでなく、enumerate や zip を使うこともあると思います。
この場合、これらの関数の外側にtqdmをつけると、一応進捗を数字で出してはくれるのですが、進捗のバーが出ません。

list_a = list("abcdefghij")
for i, s in tqdm(enumerate(list_a)):
    print(i, s)
​
# 以下出力
10it [00:00, 16757.11it/s]
0 a
1 b
2 c
3 d
4 e
5 f
6 g
7 h
8 i
9 j

この場合、少しコツがあって、enumerateやzipの内側にtqdmを使うと良いです。

for i, s in enumerate(tqdm(list_a)):
    print(i, s)
# 以下出力
100%|██████████| 10/10 [00:00<00:00, 20301.57it/s]
0 a
1 b
2 c
3 d
4 e
5 f
6 g
7 h
8 i
9 j

# zipの場合は内側のリストの一方を囲む。
for a, b in zip(tqdm(range(10)), range(10, 20)):
    print(a, b)
​
# 以下出力
100%|██████████| 10/10 [00:00<00:00, 8607.23it/s]
0 10
1 11
2 12
3 13
4 14
5 15
6 16
7 17
8 18
9 19

どうやら、enumerate や zipの外側にtqdm を置くと、事前にイテレーションの回数が取得できず、進捗率が計算できないのが原因みたいです。

僕はPandasのデータフレームを扱うことが多いので、進捗を表示したくなるのももっぱらPandasのデータを処理しているときです。データフレームの iterrows() メソッドもこの enumerate や zipと同様に事前にデータ件数を取得できないらしく、普通に使うとバーや進捗率が出ません。

import numpy as np
import pandas as pd


df = pd.DataFrame(
    {
        "col1": np.random.randint(10, size=10000),
        "col2": np.random.randn(10000),
    }
)

for i, row in tqdm(df.iterrows()):
    pass

# 以下出力
10000it [00:00, 28429.70it/s]

この場合は、total引数で明示的にデータ件数を渡すのが有効です。ちなみにこれはenumerateなどでも使えるテクニックです。

for i, row in tqdm(df.iterrows(), total=len(df)):
    pass

# 以下出力
100%|██████████| 10000/10000 [00:00<00:00, 26819.34it/s]

さらに、pandasの applyでもプログレスバーを使うことができます。
テキストの前処理など地味に時間のかかる処理をapplyでやることが多いのでこれはありがたいですね。使い方は少しトリッキーで、まず、 「tqdm.pandas()」を実行します。

すると、pandasのデータが、progress_apply や、 progress_applymap などのメソッドを持つようになるので、これを実行します。

tqdm.pandas()
df["col2"].progress_apply(np.sin)
# 以下出力
100%|██████████| 10000/10000 [00:00<00:00, 241031.18it/s]
# 結果略

groupbyに対応したprogress_aggregateもありますよ。

df.groupby("col1").progress_aggregate(sum)
# 以下出力
100%|██████████| 10/10 [00:00<00:00, 803.51it/s]
# 結果略

あとは滅多に使わないのですが、for文ではなく、事前にループ回数が決まっていないループで使う方法を書いておきます。
下のように、 with でインスタンスを作って、明示的にupdateとしていきます。あらかじめのループ回数を渡していないのでバーや進捗率は出ませんが現在の実行回数が観れるので一応進捗が確認できます。

i = 1
with tqdm() as pbar:
    while True:
        pbar.update(1)
        i += 1
        # 無限ループ防止
        if i>100:
            break

# 以下出力
100it [00:00, 226229.99it/s]

これ以外にも、 leave=False を指定して、処理が終わったらプログレスバーを消すとか、
desc/ postfix で前後に説明文を書くとか、数字の単位や進捗の更新頻度など細かい設定がたくさんできます。
必要に応じてドキュメントを参照して使ってみてください。

sshtunnel を使って踏み台サーバー経由でDB接続

以前、PyMySQLを使って、Amazon RDSを利用する方法を記事にしました。
参考: PythonでAuroraを操作する(PyMySQLを使う方法)

DBに直接接続できる場合はこれで良いのですが、場合によっては踏み台となるサーバーを経由して接続しなければならないことがあります。

僕の場合は職場ではセキュリティ上の理由から分析用のDBの一つがローカルから直接接続できないようになっていますし、プライベートではAurora Serverless v1使っているので、これはAWS内のリソース経由でしか接続できません。

ということで、Pythonで踏み台経由してAWSに接続する方法を書いていきます。
実はこれまで人からもらったコードをそのまま使っていたのですが、この記事書くために改めてsshtunnel のドキュメントを読んで仕組みを理解しました。

参考: Welcome to sshtunnel’s documentation! — sshtunnel 0.4.0 documentation

さて、さっそくやっていきましょう。セキュリティ的に接続情報はブログに書くわけにいかないので、以下の変数に入ってるものとします。

あと、サンプルなので実行したいSQL文も sql って変数に入ってるものとします。

サーバーのネットワーク設定ですが、踏み台はSSHのポート(通常は22番)、RDSはDBの接続ポート(通常は3306番)を開けておいてください。以降のコードで出てくる9999番ポートは、ローカル端末のポートなので踏み台やDBのサーバーでは開けておかなくて良いです。

# DBの接続情報 (RDSを想定)
db_host = "{DBのエンドポイント}"  # xxxx.rds.amazon.com みたいなの
db_port = 3306  # DBのポート(デフォルトから変更している場合は要修正)
db_name = "{データベース名}"
db_user = "{DBに接続するユーザー名}"
db_pass = "{DBに接続するユーザーのパスワード}"

# 踏み台サーバーの接続情報 (EC2を想定)
ssh_ip = "{サーバーのIPアドレス}"
ssh_user = "{SSH接続するユーザー名}"  # EC2のデフォルトであれば ec2-user
ssh_port = 22  # SS接続するポート(デオフォルトから変更している場合は要修正)
ssh_pkey = "{秘密鍵ファイルの配置パス}"  # .pem ファイルのパス

sql = "{実行したいSQL文}"

さて、さっそく行ってみましょう。単発で1個だけSQLを打って結果を取得したい、という場合、以下のコードで実行できます。
ローカル(手元のPCやMac)の 9999 番ポート (これは他で使ってなければ何番でもいい)への通信が、踏み台サーバーを経由してRDSに届くようになります。

from sshtunnel import SSHTunnelForwarder
from pymysql import cursors
from pymysql import connect


with SSHTunnelForwarder(
    ssh_address_or_host=(ssh_ip, ssh_port),  # 踏み台にするサーバーのIP/SSHポート
    ssh_username=ssh_user,  # SSHでログインするユーザー
    ssh_pkey=ssh_pkey,  # SSHの認証に使う秘密鍵
    remote_bind_address=(db_host, db_port),  # 踏み台を経由して接続したいDBのホスト名とポート
    local_bind_address=("localhost", 9999),  # バインドするローカル端末のホスト名とポート
) as tunnel:
    with connect(
            host="localhost",  # DBのエンドポイントではなく、ローカルの端末を指定する
            port=9999,  # これもDBのポートでは無く、バインドしたポート番号を指定する
            user=db_user,  # これ以下は普通にDB接続する場合と同じ引数
            password=db_pass,
            database=db_name,
            charset="utf8mb4",
            autocommit=True,
            cursorclass=cursors.DictCursor,
    ).cursor() as cursor:
        cursor.execute(sql)  # これでSQL実行
        rows = cursor.fetchall()  # 結果の取り出し

これで、通常はローカルからはアクセスできないDBへSQLを発行し、結果を変数rowsに取得することができました。SELECT文を打ったのであればpandasのDataFrame等に変換して使いましょう。

with文で変数をたくさん呼び出すインスタンスを使うのはコードの見栄えが非常に悪くなりますが、以下のように変数を事前に辞書にまとめておくと少しマシになります。

ssh_args = {
    "ssh_address_or_host": (ssh_ip, ssh_port),
    "ssh_username": ssh_user,
    "ssh_pkey": ssh_pkey,
    "remote_bind_address": (db_host, db_port),
    "local_bind_address": ("localhost", 9999),
}

db_args = {
    "host": "localhost",
    "port":  9999,
    "user":  db_user,
    "password":  db_pass,
    "database":  db_name,
    "charset":  "utf8mb4",
    "autocommit":  True,
    "cursorclass":  cursors.DictCursor,
}

with SSHTunnelForwarder(**ssh_args) as tunnel:
    with connect(**db_args).cursor() as cursor:
        cursor.execute(sql)
        rows = cursor.fetchall()

以上で、一応やりたいことはできると思いますが、発行したいSQLが複数ある場合かつ途中に別の処理も含むような場合一回ごとにポートフォワードとDBの接続をやり直していたらリソースの無駄です。(といっても、最近のコンピューター環境ならこれがストレスになる程時間かかるってことはないと思いますが。)

DBヘ接続しっぱなしにしておく場合は、withを使わずに次のように書きます。引数は上のコード例で作った、ssh_args, db_args をそのまま使います。

server = SSHTunnelForwarder(**ssh_args)
server.start()  # ポートフォワード開始

connection = connect(**db_args)  # DB接続
# 以上の3行で DBに接続した状態になる。

# 以下のようにして接続を使ってSQLを実行する。
with connection.cursor() as cursor:
    cursor.execute(sql)
    rows = cursor.fetchall()

# サンプルコードなのでSQLを1回しかやってないけど、続けて複数実行できる。

# 終わったらDB接続とポートフォワードをそれぞれ閉じる
connection.close()
server.stop()

これで、一つの接続を使い回すこともできるようになりました。

ちなみにですが、このsshtunnelで作ったポートフォワードの設定は端末単位で有効です。どういうことかというと、複数のPythonプロセス(例えば別々のJupyter notebook)間で、共有することができます。というより、Pythonに限らず他のプログラムからも使えます。
コンソールで、以下のコマンド使ってポートの動きを見ながら試すとよくわかります。

# 9999番ポートの利用状況を確認する
$ sudo lsof -i:9999

普段は何も結果が返ってこないか、ここまでのプログラムを実行してたらいろんな情報と共に(CLOSED)が返ってくると思いますが、ポートフォワードしている最中はESTABLISHEDになっていて、pythonが使っていることが確認できます。

特にPythonでDB操作したいという場合に限って言えば、別々のnotebookで操作するメリットなんて無いのですが、全く別の用途でポートフォワードだけPythonでやっておきたい、ということはあるかもしれないので、覚えておくと使う機会があるかもしれません。