boto3のAPIエラー発生時のリトライ回数の設定を変更する

先日、boto3でAmazon Comprehendを使っていたら、大量のデータを処理する中で次のエラーが頻発する様になってしまいました。

ClientError: An error occurred (ThrottlingException) when calling the BatchDetectSentiment operation (reached max retries: 4): Rate exceeded

短時間に実行しすぎてレートの上限に引っかかり、リトライも規定回数失敗してしまった様です。

Comprehend の使い方自体は過去の記事で書いてます。
参考: Amazon Comprehend でテキストのセンチメント分析

軽く紹介しておくとこんな感じですね。

import boto3


comprehend = boto3.client("comprehend")
comprehend_result = comprehend.batch_detect_sentiment(
        TextList=text_list,  # ここで判定したいテキストを渡す。
        LanguageCode="ja"
    )

応急処置的な対応としては、エラーをキャッチして自分でリトライを実装するという手もあります。
参考: 失敗しやすい処理にリトライをスクラッチで実装する

ただ、boto3はどうやら標準機能でリトライ設定の回数を設定できる様なのです。batch_detect_sentiment メソッドのドキュメントをいくら読んでもそれらしい引数がないのでできないと思ってました。
この設定は、その一歩手前、クライアントインスタンスを生成する時点で設定しておく必要があります。

ドキュメントはこちら。
参考: Retries – Boto3 1.26.142 documentation

ドキュメントでは ec2 のAPIに適用していますが、comprehendでも同じ様に使えます。以下の様にするとリトライ回数の上限を10回に上げられる様です。

import boto3
from botocore.config import Config

config = Config(
   retries = {
      'max_attempts': 10,
      'mode': 'standard'
   }
)

comprehend = boto3.client('comprehend', config=config)

mode は、 legacy, standard, adaptive の3種類があります。とりあえず、standardを選んだら良いのではないでしょうか。対応しているエラーがlegacyより多く、以下のエラーに対してリトライしてくれます。

# Transient errors/exceptions
RequestTimeout
RequestTimeoutException
PriorRequestNotComplete
ConnectionError
HTTPClientError

# Service-side throttling/limit errors and exceptions
Throttling
ThrottlingException
ThrottledException
RequestThrottledException
TooManyRequestsException
ProvisionedThroughputExceededException
TransactionInProgressException
RequestLimitExceeded
BandwidthLimitExceeded
LimitExceededException
RequestThrottled
SlowDown
EC2ThrottledException

max_attempts の方が、リトライ回数の設定ですが、正直、何回に設定したら成功する様になるのかは状況による部分が多く、僕もまだ良い策定方法がわかっていません。(というより、これまでこんなに失敗することがなかった)。一応、responseのメタデータの中に、RetryAttemptsって項目があって、何回目の再実行で成功したかとか見れるのですが、ほとんど成功するので大体これ0なんですよね。ここを監視してたら参考指標になるのかもしれませんが、これを監視できる実装にするのは面倒です。基本的には成功するものですから。

もし、上に列挙したエラーのどれかが発生してboto3の実行が止まってしまうよ、という状況が発生したら、リトライ回数の設定変更を検討に入れてみてください。

ただ、リトライというのはあくまでも同じ処理を繰り返すだけなので、百発百中で失敗する様な処理をリトライしても無意味です。ほとんど確実に成功するのに超低確率で失敗する事象が起きてしまう場合に、その失敗確率をもっと下げるという用途でのみ有効です。

Louvain法のmodularityの式について検証してみる

前回に引き続き、Louvain法の話です。前回はPythonで使う方法を紹介しましたが、今回はこの手法が利用しているモジュラリティ(modularity)という指標について検証していきます。そもそもなぜこの式でグラフのノードのクラスタを評価できるのかってのを確認しましょう。
(ちなみに、モジュラリティを最大化するアルゴリズムについては今回の記事では取り上げません。あくまでもモジュラリティの定義について見ていきます。)

前回の記事でも書いてるのですが今回の記事では主役なので定義式を再掲します。

$$Q=\frac{1}{2m}\sum_{ij}\left[A_{ij}\ -\ \frac{k_ik_j}{2m} \right]\delta(c_i, c_j).$$

$A_{ij}$: ノード$i$, ノード$j$間のエッジの重み。
$k_i, k_j$: ノード$i$, ノード$j$それぞれに接続されたエッジの重みの合計。
$m$: グラフの全てのエッジの重みの合計。
$c_i, c_j$: ノード$i$, ノード$j$それぞれが所属しているコミュニティ。
$\delta$: クロネッカーのデルタ。二つの引数が等しければ1でそれ以外は0を返す関数。

一見しただけだと、これでクラスタを評価できるってわかりにくいですよね。

順番に見ていきましょう。

まず、$\delta(c_i, c_j)$の部分です。これ、$c_i$と$c_j$が別のクラスタだったら値が$0$なので、$\sum$のそれらの項は消滅し、さらに同じクラスタだったら値が$1$になるので、それらの項はそのままの値で残ります。

なのでこの式はこうなります。

$$Q=\frac{1}{2m}\sum_{c_iとc_jが同じクラスタとなるi,j}\left[A_{ij}\ -\ \frac{k_ik_j}{2m} \right].$$

そして、$A_{ij}$はエッジの重みなので、和の中身の前半部分だけに着目すると、次の様に言えます。

$$\frac{1}{2m}\sum_{c_iとc_jが同じクラスタとなるi,j}A_{ij} = \frac{1}{2m}\{クラスタ内部のエッジの重みの総和\}$$

$m$はクラスタに関係なく、グラフによって定まってる定数なので無視して良いですから、この部分は確かに同じクラスタの中にたくさんエッジがあると値が大きくなり評価指標としての意味を持っている様に見えます。

ただし、この項しかないと、全部のノードを1個のクラスタとした場合に$Q$が最大となっしまうので、それに対応する必要があります。そこで登場するのが残りの項です。

もし、ノード$i$とノード$j$の間にエッジがなかったりあってもウェイトが非常に小さかったりすると、負の数を足すことによってモジュラリティが下がる様になってるのです。たとえば、エッジがない場合、$A_{ij}$が0ですから、$A_{ij}\ -\ \frac{k_ik_j}{2m} = -\frac{k_ik_j}{2m} $です。

そして、$k_i$、$k_j$はそれぞれのノードに接続されたエッジの重みの和ですから、重みの大きなエッジをたくさん持っているノード間についてはこの値の絶対値は大きくなります。

つまり、エッジがたくさんあるノード同士にも関わらずそれらのノード間にエッジがなかったら強くペナルティを課すよ、ってのがこの項の意味なのです。
分母の$2m$とかにもちゃんと意味があるのですが細かい話になるので割愛します。これらの細かい工夫により、-0.5から1の値を取る指標となっています。計算量とのバランスもとりながら非常によく考えられた指標だと思いました。

最後に、前回の記事でも少しだけ触れているのですが、この$\sum_{ij}$が取る和の$ij$部分について確認したのでその話書いておきます。

これは$i$と$j$がそれぞれ独立に全てのパターンを取ります。$(i, j)$と$(j, i)$は個別に足されるし、$(i, i)$も考慮されるってことですね。ライブラリの結果と、自分で実装した結果を並べてみて確認しました。とりあえず復習兼ねて前回のサンプルのグラフでやります。まずライブラリ利用。

import numpy as np
import networkx as nx
import community

# 前回の記事で生成したのと同じグラフを再現する。
# 30個のノード
node_list = list(range(30))

# 同じクラスタ内は0.5, 別クラスタは0.02の確率でエッジを生成する
edge_list = []
np.random.seed(1)  # シード固定
for i in range(30):
    for j in range(i+1, 30):
        if i // 10 == j // 10:
            if np.random.rand() < 0.5:
                edge_list.append((i, j))
        else:
            if np.random.rand() < 0.02:
                edge_list.append((i, j))

# グラフ生成
G = nx.Graph()
G.add_nodes_from(node_list)
G.add_edges_from(edge_list)

# コミュニティーの検出
partition = community.best_partition(G)

# ライブラリによるモジュラリティの計算
print(community.modularity(partition, G))
# 0.5316751700680272

同じ値を自分で計算して出してみましょう。エッジのウェイトが全部1なので、$m$はただのエッジの本数になるし、$k_i$, $k_j$はそれぞれのノードの次数に等しいことに注意してください。

Q = 0
m = G.size()
# iと jは全部の組み合わせをとる。
for i in range(len(G)):
    for j in range(len(G)):
        if partition[i] == partition[j]:
            # ノードi と ノードjの間にエッジがあったら1(ウェイト)を足す
            if (i, j) in G.edges:
                Q += 1 
            
            # Σの後半部分
            Q -= G.degree(i) * G.degree(j) / (2*m)
# 全体を2mで割る
Q/=2*m
print(Q)
# 0.5316751700680273

一致してますね。

ここから余談ですが、Qは -0.5から 1の値を取りますが、どんなグラフについてもQが-0.5を取ったり1を取ったりする分割が存在するわけではなさそうです。上記のサンプルでも0.53程度に留まっていますしね。

-0.5 の方は、完全2部グラフを構築して、さらにその2部をそれぞれコミュニティとすることで、同じクラスタ内部には1本もエッジがないのに別クラスタのノード間には必ずエッジが存在するという状態にすると発生します。

また、逆にQが大きくなるのは非連結な多数の完全グラフからなるグラフをそれぞれの完全グラフを一つのクラスタとすると、非連結なグラフが増えるにつれて1に近づいていくのを確認できています。ただ、近づくだけで1になることはないんじゃないでしょうか。(証明はできていませんが。)

便利なのでなんとなく使っていた手法でしたが、改めて指標を検証し複数のグラフや分割について自分で計算してみたことで理解を深めることができました。

モジュラリティは定義中に記号多いですしぱっと見難しそうに見えるのですが、落ち着いて見れば四則演算だけで構成されている非常にシンプルな指標なので、興味があれば皆さんもいろいろ試して遊んでみてください。

Louvain法を用いてグラフのコミュニティーを検出する

昔、グラフ関係の記事を書いてた頃に書いてたと思ってたら、まだ書いてなかったので、グラフのコミュニティーを検出する方法を書きます。

グラフのコミュニティ検出というのは、グラフのノードを複数のグループに分け、同じグループ内のノード同士ではエッジが密で、異なるグループのノード間ではエッジが疎になる様にすることを言います。人の集まりを仲良しグループでいくつかの組に分けるのに似ていますね。この記事の最後にも結果の例の画像をつけていますのでそれをみていただくとイメージがつきやすいと思います。

複数のアルゴリズムが存在するのですが、僕が一番気に入っていてよく使っているのはLouvain法による方法なのでこれを紹介していきます。
参考: Louvain method – Wikipedia

このLouvain法では、グラフとコミュニティに対してモジュラリティ(modularity)という指標を定義し、このモジュラリティが最大になる様なコミュニティの分け方を探すことでグラフのノードを分割します。モジュラリティの定義式はWikipediaにもありますが以下の通りです。

$$Q=\frac{1}{2m}\sum_{ij}\left[A_{ij}\ -\ \frac{k_ik_j}{2m} \right]\delta(c_i, c_j).$$

ここで、各記号の意味は以下の通りです。
$A_{ij}$: ノード$i$, ノード$j$間のエッジの重み。
$k_i, k_j$: ノード$i$, ノード$j$それぞれに接続されたエッジの重みの合計。
$m$: グラフの全てのエッジの重みの合計。
$c_i, c_j$: ノード$i$, ノード$j$それぞれが所属しているコミュニティ。
$\delta$: クロネッカーのデルタ。二つの引数が等しければ1でそれ以外は0を返す関数。

元論文を読めていないのですが、ライブラリの挙動を確認した結果によると、和の$ij$は$i$, $j$の全てのペアを渡る和で$i,j$の組み合わせとそれを入れ替えた$j, i$の組み合わせを両方取り、さらに$i=j$となる自己ループなエッジも考慮している様です。

さて、実際にやってみましょう。Pythonでは、communityというライブラリがあり、Louvain法が実装されています。(networkx自体にも他のコミュニティ検出のアルゴリズムが実装されているのですが、Louvain法は別ライブラリになっています。)

pipyでの登録名がpython-louvainなので、インストール時のコマンドが以下であることに注意してください。Pythonでimportするときと名前が違います。

$ pip install python-louvain

ドキュメントはこちら
参考: Community detection for NetworkX’s documentation — Community detection for NetworkX 2 documentation

早速使ってみましょう。例としてコミュニティがわかりやすいグラフを乱数で生成します。
0~29のノードを持ち、0~9, 10~19, 20〜29が同じコミュニティで、同コミュニティー内では50%の確率でエッジが貼られていて異なるミュにティだと2%の確率でしかエッジがないという設置です。

import numpy as np
import networkx as nx
import community

# 30個のノード
node_list = list(range(30))

# 同じクラスタ内は0.5, 別クラスタは0.02の確率でエッジを生成する
edge_list = []
np.random.seed(1)  # シード固定
for i in range(30):
    for j in range(i+1, 30):
        if i // 10 == j // 10:
            if np.random.rand() < 0.5:
                edge_list.append((i, j))
        else:
            if np.random.rand() < 0.02:
                edge_list.append((i, j))

# グラフ生成
G = nx.Graph()
G.add_nodes_from(node_list)
G.add_edges_from(edge_list)

さて、グラフができたのでコミュニティ検出をやっていきます。これものすごく簡単で、best_partitionというメソッドにnetworkxのグラフオブジェクトを渡すとノードと所属するグループの辞書として結果が帰ってきます。

注意点ですが、アルゴリズムの高速化のための工夫の弊害により絶対にベストな分割が得られるというわけではなく少々確率的な振る舞いをします。今回のサンプルコードでは非常にわかりやすい例を用意したので一発で成功していますが、通常の利用では何度か試して結果を比較してみるのが良いでしょう。

ではやっていきます。

partition = community.best_partition(G)

print(partition)
"""
{0: 2, 1: 2, 2: 2, 3: 2, 4: 2, 5: 2, 6: 2, 7: 2, 8: 2, 9: 2,
10: 0, 11: 0, 12: 0, 13: 0, 14: 0, 15: 0, 16: 0, 17: 0, 18: 0, 19: 0,
20: 1, 21: 1, 22: 1, 23: 1, 24: 1, 25: 1, 26: 1, 27: 1, 28: 1, 29: 1}
"""

0〜9は2で、10〜19は0で、20〜29は1と、綺麗に分類できてますね。

画像に表示してみましょう。ドキュメントのサンプルコードだとlist(partition.values())と辞書の値をそのまま可視化メソッドに渡していますが、辞書型のデータなので順序は保証されておらずこれはリスキーな気がします。実はそれでもうまくいくのですが念のため、ノードと順番が揃う様に配列として取り出してそれを使う様にしました。

# ノードの並びと同じ順でグループのリストを配列として取得する。
partition_list = [partition[n] for n in G.nodes]
# 配置決め
pos = nx.spring_layout(G, seed=3)
# 描写
nx.draw_networkx(G, node_color=partition_list, cmap="tab10", pos=pos)

結果がこちらです。

綺麗に3グループに分かれていますね。

ちなみに、このグループ分けのモジュラリティを取得するメソッドもあります。グループ分けとグラフ本体のデータをそれぞれ渡すことで使えます。

print(community.modularity(partition, G))
# 0.5316751700680272

結構高いですね。

自分のChromeブラウザの履歴をPythonで集計する方法

以前、SQLite3の使い方を紹介しましたが、その応用です。

自分がどんなサイトにアクセスしているかを自分で集計したくなったのでその方法を調べました。(firefoxなどもそうらしいのですが) Chromeはアクセス履歴をSQLiteのDB(ファイル)に保存しています。そのため、該当ファイルに接続すればSQLで集計ができます。

Mac の場合、以下のパスに履歴ファイルがあります。拡張子がないので紛らわしいですが、Historyがファイル名で、Defaultまでがディレクトリパスです。
/Users/{ユーザー名}/Library/Application\ Support/Google/Chrome/Default/History

人によってはDefault 以外のディレクトリにあることもあるそうなので、もしそれらしいファイルがなかったら付近のディレクトリの中を探してみてください。例えば、Default/History ではなく、System Profile/Historyや、Guest Profile/Historyに存在する環境も見たことがあります。

ロックがかかったりするとよくないので、このファイルに直接繋ぐのではなく、どこかにコピーを取って、その中を確認しましょう。

# カレントディレクトリにコピーする
$ cp /Users/{ユーザー名}/Library/Application\ Support/Google/Chrome/Default/History .

早速接続してみましょう。とりあえずどんなテーブルがあるのか見てみます。

import sqlite3
import pandas as pd


# データベースに接続
con = sqlite3.connect('History')
con.row_factory = sqlite3.Row
cursor = con.cursor()

cursor.execute("SELECT name FROM sqlite_master")
pd.DataFrame(
    cursor.fetchall(),
    columns=[row[0] for row in cursor.description],
)

"""
name
0	meta
1	sqlite_autoindex_meta_1
2	urls
3	sqlite_sequence
4	visits
5	visit_source
6	visits_url_index
7	visits_from_index
8	visits_time_index
9	visits_originator_id_index
10	keyword_search_terms
11	keyword_search_terms_index1
12	keyword_search_terms_index2
13	keyword_search_terms_index3
14	downloads
15	downloads_url_chains
# 多いので以下略。自分の環境では合計35テーブルありました。
"""

目当てのアクセス履歴が保存されているテーブルは、visitsです。ただ、このテーブルはアクセスしたURLのパスそのものは含んでおらず、URLはurlsテーブルに入っていて、visitsテーブルにはurlsテーブルのidが入ってます。

urlsテーブルから1行だけ取り出してみます。(転置してprintしました。)

cursor.execute("""
    SELECT
        *
    FROM
        urls
    WHERE
        url = 'https://analytics-note.xyz/'
""")

print(pd.DataFrame(
    cursor.fetchall(),
    columns=[row[0] for row in cursor.description],
).T)
"""
                                           0
id                                        53
url              https://analytics-note.xyz/
title                                  分析ノート
visit_count                               43
typed_count                                0
last_visit_time            13327924159724978
hidden                                     0
"""

idが53らしいですね。URLとページタイトルが入っています。あと、visit_countとして訪問回数入ってます。(思ったより少ないです。他のPCやスマホで見てることが多いからでしょうか。)

次は、visitsテーブルも見てみます。urlという列にさっきのurlテーブルのidが入ってる(列名をurls_idとかにして欲しかった)ので、それで絞り込んでいます。

cursor.execute("""
    SELECT
        *
    FROM
        visits
    WHERE
        url = 53
    ORDER BY
        visit_time DESC
    LIMIT
        1
""")

print(pd.DataFrame(
    cursor.fetchall(),
    columns=[row[0] for row in cursor.description],
).T)
"""
                                                 0
id                                            8774
url                                             53
visit_time                       13327924159724978
from_visit                                       0
transition                               805306370
segment_id                                       9
visit_duration                                   0
incremented_omnibox_typed_score                  0
opener_visit                                     0
originator_cache_guid                             
originator_visit_id                              0
originator_from_visit                            0
originator_opener_visit                          0
is_known_to_sync                                 0
"""

visit_timeでアクセスした時刻がわかる様ですね。

ちなみに、UNIXタイムスタンプ等に比べて非常にでっかい数字が入っていますが、これはChromeの仕様で、1601年1月1日(UTC)からの経過時間をマイクロ秒(μ秒、100万分の1秒)単位で計測したものだそうです。(この情報は他サイトで入手したのですが、できればChromeの公式ドキュメントか何かで確認したいところです。概ね正しそうなのですが、自分の環境だと数分ずれていて違和感あります。)

さて、テーブルの構造がわかったので具体的なSQL構築していきます。たとえば、直近1000アクセス分の情報を見る場合は以下の様なクエリで良いのではないでしょうか。

sql = """
    SELECT
        visits.visit_time,
        urls.url,
        urls.title
    FROM
        visits
    LEFT OUTER JOIN
        urls
    ON
        visits.url = urls.id
    ORDER BY
        visits.visit_time DESC
    LIMIT 1000;
"""

cursor.execute(sql)

df = pd.DataFrame(
    cursor.fetchall(),
    columns=[row[0] for row in cursor.description],
)

仕上げに、visit_time がこのままだと使いにくすぎるので、これを人が読めるように変換します。手順は以下の通りです。

  1. 10^6で割ってマイクロ秒を秒に変換する。
  2. 1601年1月1日のタイムスタンプを足して現在時刻のタイムスタンプに変換する。
  3. 時刻型に戻す。
  4. 9時間足して日本時間にする。

順番にやっていくとこんなコードになります。

from datetime import datetime
from datetime import timedelta


df["アクセス時刻"] = df["visit_time"].apply(
    lambda x: x/10**6
).apply(
    lambda x: x+datetime(1601, 1, 1).timestamp()
).apply(
    datetime.fromtimestamp
).apply(
    lambda x: x+timedelta(hours=9)
)

もっといいやり方はいくらでもある気がする(1601/1/1のタイムスタンプとか9時間分の秒数とか事前に定数として持っておいてマイクロを取った時に足してしまうなど)のですが、可読性重視でいけばこの書き方で良いと思います。

dfコマンドとduコマンドを用いたディスク容量使用状況の調査方法

はじめに

何度か必要になったことがあるのですがその度に使い方を忘れてしまっているdfコマンドとduコマンドの使い方のメモです。

それぞれ一言で言ってしまえば次の様になります。
– dfコマンドはディスクの使用済み容量と空き容量を確認できるコマンドです。
– duコマンドは指定したディレクトリの使用容量を調べるコマンドです。

どっちがどっちだったかよく忘れるのでメモっておこうという記事の目的がこれで終わってしまったのですが、あんまりなのでもう少し続けます。

このブログが動いてるサーバーもそうですし、投資や各種技術検証で使っているEC2インスタンスなども費用が自己負担なのでディスクサイズが最低限の状態で動いており、時々ディスクの空き具合を気にしてあげる必要があります。

また、職場で利用しているデータ分析関係のワークフローを構築しているサーバーも本番サービスのサーバーと違って最低限構成なので何度かディスクのパンクによる障害を経験しています。

そんな時に調査するのにdfコマンドとdfコマンドを使います。

dfコマンドについて

まず、dfコマンドからです。

先ほどざっと書いた通りで、このコマンドを打つとマウントされているディスクの一覧が表示されてそれぞれの容量と使用量、そして丁寧なことに使用率なども表示してくれます。

ディスク容量のパンクが疑われる場合、まずこのコマンドで状況を確認にすることになると思います。(本来は、パンクが疑われなくても時々監視しておくべきなのだと思いますが。)

自分が持ってるとあるEC2インスタンスで実行した結果が以下です。(EC2で実行したらヘッダーが日本語になりました。)

$ df
ファイルシス   1K-ブロック    使用  使用可 使用% マウント位置
devtmpfs            485340       0  485340    0% /dev
tmpfs               494336       0  494336    0% /dev/shm
tmpfs               494336     352  493984    1% /run
tmpfs               494336       0  494336    0% /sys/fs/cgroup
/dev/xvda1         8376300 3768544 4607756   45% /

-h オプション(“Human-readable”の頭文字らしい)をつけると、人間が読みやすいキロメガギガ単位で表示してくれます。

$ df -H
ファイルシス   サイズ  使用  残り 使用% マウント位置
devtmpfs         497M     0  497M    0% /dev
tmpfs            507M     0  507M    0% /dev/shm
tmpfs            507M  361k  506M    1% /run
tmpfs            507M     0  507M    0% /sys/fs/cgroup
/dev/xvda1       8.6G  3.9G  4.8G   45% /

一番下の /dev/xvda1がメインのファイルシステムですね。現時点で45%使っていることがわかります。まだ余裕。

Mac (zsh)だと出力の形がちょっと違います。

% df -h
Filesystem       Size   Used  Avail Capacity iused      ifree %iused  Mounted on
/dev/disk3s1s1  228Gi  8.5Gi  192Gi     5%  356093 2018439760    0%   /
devfs           199Ki  199Ki    0Bi   100%     690          0  100%   /dev
/dev/disk3s6    228Gi   20Ki  192Gi     1%       0 2018439760    0%   /System/Volumes/VM
/dev/disk3s2    228Gi  4.4Gi  192Gi     3%     844 2018439760    0%   /System/Volumes/Preboot
/dev/disk3s4    228Gi  4.8Mi  192Gi     1%      43 2018439760    0%   /System/Volumes/Update
/dev/disk1s2    500Mi  6.0Mi  482Mi     2%       1    4936000    0%   /System/Volumes/xarts
/dev/disk1s1    500Mi  6.2Mi  482Mi     2%      30    4936000    0%   /System/Volumes/iSCPreboot
/dev/disk1s3    500Mi  1.0Mi  482Mi     1%      62    4936000    0%   /System/Volumes/Hardware
/dev/disk3s5    228Gi   22Gi  192Gi    11%  460541 2018439760    0%   /System/Volumes/Data
map auto_home     0Bi    0Bi    0Bi   100%       0          0  100%   /System/Volumes/Data/home

iノードの利用状況も一緒に出してくれている様です。この端末にはそんな大きなファイルを溜めてないのでスカスカですね。

duコマンドについて

この記事用にdfコマンドを試したサーバーや端末は問題なかったのですが、もしディスク容量が逼迫していることが判明したら原因を探すことになります。

その時は当然ディレクトリごとの使用量を調べて、原因となっているディレクトリを探すのですが、その際に使うのがduコマンドです。

du -sh {対象ディレクトリ}
の構文で使うことが多いです。-hオプションはdfの時と同じで人間が読みやすい単位で表示するものです。-sは対象のディレクトリの合計だけを表示するものです。これを指定しないと配下のディレクトリを再起的に全部表示するので場所によっては出力が膨大になります。

対象ディレクトリを省略するとカレントディレクトリが対象になります。

$ du -sh
1.2G    .

上記の様に、1.2G使ってるって結果が得られます。もう少しだけ内訳みたい、て場合は、-sの代わりに、-d {深さ} オプションで何層下まで表示するかを決められます。 -s は -d 0 と同じです。

$ du -h -d 0
1.2G    .
$ du -h -d 1
4.0K    ./.ssh
973M    ./.pyenv
180M    ./.cache
3.0M    ./.local
9.0M    ./.ipython
104K    ./.jupyter
0       ./.config
1.2G    .

あとは、 -d 1 はオプションでやらず、ディレクトリの指定を ./* みたいにワイルドカードで指定しても似たことができますね。(-s は付けてください)

注意ですが、 ルートディレクトリなどで(/) でこのコマンド打つとかなり重い処理になる可能性があります。注意して実行しましょう。

現実的には、いろんなメディアファイルが配置されそうなパスとかログファイルが出力されている先とかそういうパスに当たりをつけてその範囲で調査するコマンドと考えた方が安全です。

os.walk()メソッドでファイル走査

最近、ちょっとしたツールや簡易的なスクリプトの構築でChatGPTを使っているのですが、ファイルの整理系のPythonスクリプトを書いてもらうときにしばしos.walkというメソッドを使っていて、それを自分が知らなかったのでドキュメントを読んでみたのでそのまとめです。

基本的には、あるディレクトリ配下のファイルやディレクトリを一通りリストアップするメソッドになります。
参考: os — 雑多なオペレーティングシステムインターフェース — Python 3.11.3 ドキュメント

同様の目的だと僕はこれまでglobを使ってました。以下の記事にしています。
参考: globでサブフォルダを含めて再帰的にファイルを探索する

まだ、自分で書くときはglob使ってますがChatGPTの出力を見てるとwalkの方が使い勝手が良さそうな気がしています。

使い方試すためにこんな感じて適当にファイルやディレクトリ(いくつかは空ディレクトリ)を用意しました。

$ find targetdir | sort
targetdir
targetdir/001.txt
targetdir/002.txt
targetdir/subdir1
targetdir/subdir1/101.txt
targetdir/subdir1/102.txt
targetdir/subdir1/subdir1_1
targetdir/subdir1/subdir1_1/111.txt
targetdir/subdir1/subdir1_1/subdir1_1_1
targetdir/subdir2
targetdir/subdir2/201.txt
targetdir/subdir3

使いたですが、上にリンクを貼ったドキュメントの様に、走査したいトップディレクトリを指定して渡してあげると、その配下の各ディレクトリごとに、(dirpath, dirnames, filenames) というタプルを返すジェネレーターを作ってくれます。

dirpath に最初に指定したディレクトリを含む配下のディレクトリたちが順に格納され、difnamesにその直下のディレクトリのディレクトリ名のリストが入り、filenamesにdirpathのディレクトリ直下のファイル名がリストで入ります。

ジェネレーターなのでfor文で回して結果をprintしてみます。

for dirpath, dirnames, filenames in os.walk("targetdir"):
    print("dirpath:", dirpath)
    print("dirnames:", dirnames)
    print("filenames:", filenames)
    print("- "*10)

"""
dirpath: targetdir
dirnames: ['subdir3', 'subdir2', 'subdir1']
filenames: ['002.txt', '001.txt']
- - - - - - - - - -
dirpath: targetdir/subdir3
dirnames: []
filenames: []
- - - - - - - - - -
dirpath: targetdir/subdir2
dirnames: []
filenames: ['201.txt']
- - - - - - - - - -
dirpath: targetdir/subdir1
dirnames: ['subdir1_1']
filenames: ['101.txt', '102.txt']
- - - - - - - - - -
dirpath: targetdir/subdir1/subdir1_1
dirnames: ['subdir1_1_1']
filenames: ['111.txt']
- - - - - - - - - -
dirpath: targetdir/subdir1/subdir1_1/subdir1_1_1
dirnames: []
filenames: []
- - - - - - - - - -
"""

簡単ですね。

globでファイルを探索した時は最初からファイル名ではなくファイルパスの形で得られていましたが、os.walkではディレクトリパスとファイル名が個別に得られるのでファイルパスが欲しい場合は連結する必要があります。せっかくosライブラリをimportしてるので専用メソッド使ってやりましょう。次がos.walkを使って配下のファイル一覧を得る方法です。

for dirpath, dirnames, filenames in os.walk("targetdir"):
    for filename in filenames:
        print(os.path.join(dirpath, filename))
"""
targetdir/002.txt
targetdir/001.txt
targetdir/subdir2/201.txt
targetdir/subdir1/101.txt
targetdir/subdir1/102.txt
targetdir/subdir1/subdir1_1/111.txt
"""

簡単ですね。ディレクトリでも同じ様にできると思いますし、連結してしまえばただの文字列なので、拡張子(末尾の文字列)で絞り込むと言った応用も簡単にできると思います。

J-Quants API の基本的な使い方

J-Quants APIの紹介

個人的に株式投資を行なっており、その他にも時系列データ分析の学習などの用途でも株価データを頻繁に利用しています。元々、それらのデータをどうやって用意するかが課題で証券会社のページ等々から引っ張ってきていたのですが、この度、日本取引所グループが公式にJ-Quants API という非常に便利なAPIを提供してくれるようになりました。

参考: J-Quants

ベータ版が登場した時から即座に登録して使っていたのですが、この度正式にリリースされましたので使い方とか紹介していこうと思います。

会員登録しないと使えないので、興味のある人は上記のリンクから登録してください。プラン表にある通り、Freeプランもあるので使い勝手を見るには無料で試せます。(Freeプランは12週間遅延のデータなので運用には使いにくいですが、時系列データ分析のサンプルデータ取得等の用途では十分活用できます。)

ちなみに僕はライトプランを使っているでこの記事ではそれを前提とします。(将来的にスタンダード以上のプランに上げるかも。)

肝心の取得できる情報なのですが、東証各市場の上場銘柄の一覧や、株価の四本値、財務情報や投資部門別情報などの情報を取得することができます。物によっては他の手段での収集がかなり手間なものもあり、とても貴重な情報源になると思います。

基本的な使い方

APIは認証周り以外はかなりシンプルな設計になっていて、API仕様にそって、認証情報とパラメーターを指定して規定のURLを叩けばJSONで結果を得られます。
参考: API仕様 J-Quants API

認証だけちょっと珍しい形式をとっていて、以下の手順を踏みます。

  1. メールアドレスとパスワードを用いてWeb画面かAPIでリフレッシュトークンを取得する。
  2. リフレッシュトークンを用いて専用APIでIDトークンを取得する。
  3. IDトークンを用いて、各種情報を取得するAPIを利用する。

ベータ版の初期はリフレッシュトークンの取得がWebでログインしてコピーする以外になかったのでそこもAPI化されたことで大変便利になりました。

なお、リフレッシュトークンの有効期限は1週間、IDトークンの有効期限は24時間です。

Pythonでトークンの取得

ではここからやっていきましょう。実は専用クライアントライブラリ等も開発が進んでるのですが、requests でも十分使っていけるのでrequestsで説明します。

まず、リフレッシュトークンの取得です。
参考: リフレッシュトークン取得(/token/auth_user) – J-Quants API

指定のURLにメールアドレスとパスワードをPOSTし、戻ってきたJSONから取り出します。

import json
import requests
import pandas as pd


# POSTするデータを作る。
email = "{メールアドレス}"
password = "{パスワード}"
account_data = json.dumps({
        "mailaddress": email,
        "password": password,
    })

auth_user_url = "https://api.jquants.com/v1/token/auth_user"

auth_result = requests.post(auth_user_url, data=account_data)
refresh_token = auth_result.json()["refreshToken"]

これでリフレッシュトークンが取得できたので、次は IDトークンを取得します。
参考: IDトークン取得(/token/auth_refresh) – J-Quants API

これはクエリパラメーターでURL内にリフレッシュトークンを指定して使います。(ここは少し設計の意図が不明ですね。何もPOSTしてないのでGETが自然な気がしますが、GETだとエラーになります。)

auth_refresh_url=f"https://api.jquants.com/v1/token/auth_refresh?refreshtoken={refresh_token}"
refresh_result = requests.post(auth_refresh_url)
id_token = refresh_result.json()["idToken"]

これでIDトークンも取得できました。

ちなみに両トークンとも1000文字以上の結構長い文字列です。

各種情報の取得

IDトークンが手に入ったらいよいよこれを使って必要な情報を取得します。APIの仕様書はかなりわかりやすいので、あまり迷わずに使えると思います。
たとえば、株価4本値が必要だとしましょう。

参考: 株価四本値(/prices/daily_quotes) – J-Quants API

仕様書にある通り、証券コード(code) や日付(dateもしくはfrom/to)をURLで指定し、ヘッダーにさっきの認証情報をつけて、GETメソッドでAPIを叩きます。

たとえば、トヨタ自動車(7203)の昨年1年間の株価を取得してみましょう。必須ではありませんが、得られたデータはDataFrameに変換しやすい形で帰ってくるのでDataFrameにしました。

import pandas as pd


code = "7203" # 4桁のコードでも5桁のコード72030でもよい。
from_ = "2022-01-01"
to_ = "2022-12-31"

daily_quotes_url = f"https://api.jquants.com/v1/prices/daily_quotes?code={code}&from={from_}&to={to_}"
# idトークンはヘッダーにセットする
headers = {"Authorization": f"Bearer {id_token}"}
daily_quotes_result = requests.get(daily_quotes_url, headers=headers)
# DataFrameにすると使いやすい
daily_quotes_df = pd.DataFrame(daily_quotes_result.json()["daily_quotes"])

取れたデータを表示してみましょう。ブログレイアウトの都合ですが3レコードほどを縦横転置してprintしたのが以下です。

print(daily_quotes_df.head(3).T)
"""
                              0               1              2
Date                 2022-01-04      2022-01-05     2022-01-06
Code                      72030           72030          72030
Open                     2158.0          2330.0         2284.5
High                     2237.5          2341.0         2327.0
Low                      2154.5          2259.0         2277.0
Close                    2234.5          2292.0         2284.5
Volume               43072600.0      53536300.0     37579600.0
TurnoverValue     94795110250.0  123423876550.0  86533079050.0
AdjustmentFactor            1.0             1.0            1.0
AdjustmentOpen           2158.0          2330.0         2284.5
AdjustmentHigh           2237.5          2341.0         2327.0
AdjustmentLow            2154.5          2259.0         2277.0
AdjustmentClose          2234.5          2292.0         2284.5
AdjustmentVolume     43072600.0      53536300.0     37579600.0
"""

バッチリデータ取れてますね。

もう一つ、例として財務情報も取ってみましょう。

参考: 財務情報(/fins/statements) – J-Quants API

取れる情報の種類がめっちゃ多いので結果は一部絞ってprintしてます。

code = "7203"

statements_url = f"https://api.jquants.com/v1/fins/statements?code={code}"

headers = {"Authorization": f"Bearer {id_token}"}
statements_result = requests.get(statements_url, headers=headers)
statements_df = pd.DataFrame(statements_result.json()["statements"])

# 100列以上取れるので、1レコードを20列だけprint
print(statements_df.head(1).T.head(20))
"""
DisclosedDate                                          2018-05-09
DisclosedTime                                            13:25:00
LocalCode                                                   72030
DisclosureNumber                                   20180312488206
TypeOfDocument              FYFinancialStatements_Consolidated_US
TypeOfCurrentPeriod                                            FY
CurrentPeriodStartDate                                 2017-04-01
CurrentPeriodEndDate                                   2018-03-31
CurrentFiscalYearStartDate                             2017-04-01
CurrentFiscalYearEndDate                               2018-03-31
NextFiscalYearStartDate                                2018-04-01
NextFiscalYearEndDate                                  2019-03-31
NetSales                                           29379510000000
OperatingProfit                                     2399862000000
OrdinaryProfit                                                   
Profit                                              2493983000000
EarningsPerShare                                            842.0
DilutedEarningsPerShare                                    832.78
TotalAssets                                        50308249000000
Equity                                             19922076000000
"""

以上が、J-Quants APIの基本的な使い方になります。

クライアントライブラリについての紹介

ここまで、requestsライブラリを使ってAPIを直接叩いてましたが、実はもっと手軽に使える様にするための専用ライブラリの開発も進んでおりすでにリリースされています。

参考: jquants-api-client · PyPI

ライブラリも有志の方々が頑張って開発していただけていますが、なにせAPI本体もまだリリースされたばかりですし、まだバグフィックスなどもされている最中の様なので今時点では僕はこちらのライブラリは使っていません。

API本体のアップデートとかも今後多そうですしね。

ただ、将来的にはこちらのライブラリが非常に便利な物になると期待しているので、API/ライブラリ双方の開発が落ち着いてきたら移行したいなと思ってます。

PythonでSQLite3を使う

前回の記事に引き続き、SQLite3の話です。

今回はPythonのコードの中でSQLite3を使っていきます。Pythonでは標準ライブラリにSQLite3用のインターフェースがあり、特に何か追加でインストールしなくてもSQLite3を使えます。MySQL等と比較したときの大きなアドバンテージですね。
参考: sqlite3 — DB-API 2.0 interface for SQLite databases — Python 3.11.3 documentation

使い方はドキュメントにあるので早速使っていきましょう。

DB(SQLiter3のDBはただのファイル)は前回の記事で作ったのがあるのでそれに接続します。存在しなかったら勝手に生成されるのでここはあまり難しいところではありません。

前回の記事を少しおさらいしておくと、DBファイル名は、mydatabase.dbで、userってテーブルができています。これの全レコードをSELECTするコードは次の様になります。

import sqlite3


# データベースに接続
con = sqlite3.connect('mydatabase.db')
cursor = con.cursor()
# SQLの実行と結果の表示
cursor.execute("SELECT * FROM user")
row = cursor.fetchall()
print(row)
# 以下結果
# [(1, 'Alice', 20), (2, 'Bob', 30), (3, 'Charlie', 40)]

簡単ですね。

結果はタプルのリストで得られました。列名(SELECT文ではアスタリスク指定にしている)の一覧が欲しいと思われますが、それは、coursor.descriptionという属性に入ってます。ただ、入り方がちょっとクセがあって、DB API の仕様との整合性のために、不要なNone が入ってます。

# 列名 + 6個の None という値が入ってる
print(cursor.description)
"""
(
    ('id', None, None, None, None, None, None),
    ('name', None, None, None, None, None, None),
    ('age', None, None, None, None, None, None)
)
"""

# 列名の一覧を得るコードの例
print([row[0] for row in cursor.description])
# ['id', 'name', 'age']

これだとちょっと使いにくいということもあるともうのですが、row_factory というのを使うと、列名: 値 の辞書で結果を得ることができます。ドキュメントに「How to create and use row factories」ってしょうがあるのでそれを参考にやってみます。

con = sqlite3.connect('mydatabase.db')
# row_factoryを指定
con.row_factory = sqlite3.Row
cursor = con.cursor()
cursor.execute("SELECT * FROM user")
row = cursor.fetchall()
print(row)
# [<sqlite3.Row object at 0x109b4de70>, <sqlite3.Row object at 0x109b4dae0>, <sqlite3.Row object at 0x109b4fc70>]

# これらはそれぞれ辞書形でデータが入っているので次の様にして値を取り出せる。
print(row[0]["name"])
# Alice
print(row[0]["age"])
# 20


# 上記の状態だと、DataFrameへの変換が簡単
import pandas as pd


print(pd.DataFrame(row))
"""
   0        1   2
0  1    Alice  20
1  2      Bob  30
2  3  Charlie  40
"""

これでSELECTはできましたね。

次はINSERT等のデータの編集の話です。INSERTだけ取り上げますが、UPDATEもDELETEも同様です。

基本的にexecuteでクエリを実行できるのですが、端末でSQLite3を使う時と違って、明示的にコミットしないと、セッションを切って繋ぎ直したらデータが消えてしまいます。commit()を忘れない様にしましょう。

con = sqlite3.connect('mydatabase.db')
cursor = con.cursor()
# 1行挿入
cursor.execute("INSERT INTO user (name, age) VALUES (?, ?)", ["Daniel", 35])
# コミット
con.commit()
# レコードが1レコード増えてる
cursor.execute("SELECT * FROM user")
row = cursor.fetchall()
print(row)
# [(1, 'Alice', 20), (2, 'Bob', 30), (3, 'Charlie', 40), (4, 'Daniel', 35)]

この他 CREATE TABLE等々も実行できますので以上の知識で通常の使用はできると思います。

ただ、Pythonで使う場合と端末で直接使う場合で異なる点もあります。それが、.(ドット)で始まるSQLite3の固有コマンドたちが使えないってことです。

.table が使えないのでテーブルの一覧を得るには代わりに次の様にします。

# テーブル名の一覧を取得するSQL文を実行
cursor.execute("SELECT name FROM sqlite_master")
print(cursor.fetchall())
# [('user',)]

そして、特定のテーブルのスキーマを確認するクエリはこちらです。userテーブルしか作ってないので、userテーブルで実験します。

# スキーマを取得するSQL
cursor.execute("SELECT sql FROM sqlite_master WHERE name = 'user'")
print(cursor.fetchone()[0])
"""
CREATE TABLE user (
    id INTEGER PRIMARY KEY,
    name TEXT,
    age INTEGER
)
"""

どちらも sqlite_master ってテーブルからデータを引っ張ってきてることからわかる通り、メタデータ的なテーブルが存在してその中に情報があるということです。

ちなみに、この sqlite_master に含まれる列名の一覧は次の様になってます。 sqlite_master の中に sqlite_master の情報はなさそうなので、*で全列SELECTして列名の一覧を取得しています。

cursor.execute("SELECT * FROM sqlite_master WHERE name = 'user'")
print([row[0] for row in cursor.description])
# ['type', 'name', 'tbl_name', 'rootpage', 'sql']

以上で、PythonでSQLite3を使える様になると思います。

SQLite3入門

SQLite3の概要

SQLite3というのは軽量なRDBMSです。C言語で書かれてます。特徴としては、データベースをファイルとして扱い、サーバーが不要って点が挙げられます。

リレーショナルデータベースをちょっと使いたいけど、わざわざMySQLを導入するほどではないとか既存のDB環境は汚したくないってときに重宝します。

ただ、当然ですがMySQLではないので、SQL以外のコマンド群、MySQLで言えばSHOWで始まる各種コマンドなどが無く、その代わりにSQLite3専用のコマンドが用意されていたりして、最初は少し戸惑ったので記事としてまとめておきます。今回はとりあえず端末で直接使う場面を想定して書きます。次回の記事で、Pythonライブラリから使う方法を書く予定です。

インストール方法

macの場合。最初からインストールされています。
バージョンを上げたい場合は、Homebrewでもインストールできるようです。

$ brew install sqlite3

Linux (EC2のAmazon Linux2 を想定) の場合、こちらも最初から入っていますが、yumで入れることもできます。

$ sudo yum update
$ sudo yum install sqlite

ディストリビューションによってはyumでは無く、apt-get等別のコマンドのことがあるので確認して入れたほうが良いでしょう。

Windowsの場合、公式サイトのダウンロードページでバイナリファイルが配布されているのでダンロードして使います。

SQLite3の起動と終了

MySQLとかだと接続情報(ユーザー名やパスワード、外部であればホスト名など)を含む長いコマンドを打ち込んで接続しますが、SQLite3はDBファイル名を指定して実行するだけでそのファイルに接続してDBとして使えるようになります。1ファイルが1データベースで、その中にテーブルを作っていくイメージです。MySQLは一つのRDBMSに複数のデータベース(スキーマ)を作れるのでこの点が違いますね。

mydatabase.db というファイルで起動する場合のコマンドは下記です。既存ファイルを指定しても良いし、無ければ新規に作成されます。
sqlite3 を起動するとプロンプトが sqlite> に変わります。

$ sqlite3 mydatabase.db
SQLite version 3.39.5 2022-10-14 20:58:05
Enter ".help" for usage hints.
sqlite>

終了するには、 .exit か .quit を入力します。 .exitの方はリターンコードを一緒に指定することもできます。先頭に.(ドット)が必要なので注意してください。

sqlite> -- 終了コマンド2種類。どちらかで閉じる。
sqlite> .quit
sqlite> .exit

個人的なぼやきですが、sqlite3の入門記事を書く人はSELECT文とかより先に終了方法を書くべきだと思っています。MySQLのノリでquitとかexitとか入れると syntax error が起きます。vimの終了方法並みに重要な知識です。

SQLite3におけるSQLの実行

SQLの実行は特に変わったことがありません。MySQLとほぼ同じように実行できます。

テーブルを作って、データを入れて、それを選択してみましょう。

-- テーブルの作成
sqlite> CREATE TABLE user (
    id INTEGER PRIMARY KEY,
    name TEXT,
    age INTEGER
);

-- データの挿入
sqlite> INSERT INTO user (name, age) VALUES ('Alice', 20);
sqlite> INSERT INTO user (name, age) VALUES ('Bob', 30);
sqlite> INSERT INTO user (name, age) VALUES ('Charlie', 40);

-- データの検索
sqlite> SELECT * FROM user;
-- 以下結果
1|Alice|20
2|Bob|30
3|Charlie|40

データ挿入時に id は指定しませんでしたが勝手に連番入れてくれてます。これはsqlite3の特徴の一つで、すべてのテーブルにはrowidっていう一意の行idが振られていて、 INTEGER PRIMARY KEY の列を作るとその列はrowid列の別名となり、自動的に振られたidが格納されるようになるのです。

DELETE や UPDATEなども他のRDBMSと同様の構文で行えるのですがこの記事では省略します。

SQLite3のコマンド

冒頭に少し書いた通り、SQLite3では、SQL以外にも独自のコマンドがあります。その代わり他のRDBMS固有のコマンドは使えません。SHOWなんとか、と打ちたくなって動かないことに気づいたら思い出しましょう。

そのコマンド群は . (ドット) 始まりになっています。代表的なものをいつくか挙げておきます。

  • .help: コマンドの一覧と使用方法を表示
  • .tables: データベース内のテーブル一覧を表示
  • .schema: テーブルのスキーマを表示 (指定したテーブルのCREATE TABLE文が見れる)
  • .backup: データベースをバックアップする
  • .import: CSVファイルなどの外部ファイルからデータを読み込み、テーブルに挿入する
  • .quitまたは.exit: SQLite3コマンドラインツールを終了する
  • .headers on|off: Selectした結果のヘッダーとして列名を表示するかどうか切り替える。初期設定はoff

一番最初に書いた .help で コマンドを一覧表示できるので一度一通りみておくと良いでしょう。(helpの起動方法が独自コマンドってこれもなかなか初心者に不親切な設計ですよ。)

まとめ

以上の内容を把握しておけば、最低限の用途でSQLite3を利用できるではと思います。

そして、最低限の用途で済まなくなったらSQLite3について詳しくなるよりも他のRDBMSへの移行を検討していくのが良いかなと考えています。簡易版RDBMSですからね。

matplotlibでオリジナルのカラーマップを作る

matplotlibでオリジナルの配色を使いたかったのでその方法を調べました。基本的にグラフを書くときにそれぞれのメソッドの引数で逐一RGBの値を指定していっても独自カラーを使うことはできるのですが、カラーマップとして作成すると、既存の配色と同じような雰囲気で使うことができるので便利です。

既存のカラーマップでも利用されている、以下の二つのクラスを使って作成できます。
matplotlib.colors.ListedColormap (N色の色を用意して選択して使う)
matplotlib.colors.LinearSegmentedColormap (線形に連続的に変化する色を使う)

上記のリンクがそれぞれのクラスのドキュメントですが、別途カラーマップ作成のページもあります。
Creating Colormaps in Matplotlib — Matplotlib 3.7.1 documentation

まず、ListedColormapのほうから試してみましょう。これはもう非常に単純で、色のリストを渡してインスタンスを作れば完成です。また、name引数で名前をつけておけます。省略するとname=’from_list’となり、わかりにくいので何かつけましょう。名前をつけておくとregisterってメソッドでカラーマップの一覧に登録することができ、ライブラリの既存のカラーマップと同じように名前で呼び出せるようになります。

それでは適当に、赤、緑、青、黄色、金、銀の6色で作ってみます。色の指定はRGBコードの文字列、0~1の実数値3個のタプル、matplotlibで定義されている色であれば色の名前で定義できます。例として色々試すためにそれぞれやってみます。

from matplotlib.colors import ListedColormap
import matplotlib as mpl
import numpy as np

colors = [
    "#FF0000",  # 赤
    (0, 1, 0),  # 緑
    (0, 0, 1, 1),  # 青(不透明度付き)
    "yellow",
    "gold",
    "silver",
]
pokemon_cmap = ListedColormap(colors, name="pokemon")

# 配色リストに登録するとnameで使えるようになる。(任意)
mpl.colormaps.register(pokemon_cmap)

これで新しいカラーマップができました。試しに使ってみましょう。文字列でつけた名前を渡してもいいし、先ほど作ったカラーマップのインスタンスを渡しても良いです。

# データを作成
np.random.seed(0)
data = np.random.randn(10, 10)

# 作図
fig = plt.figure(facecolor="w")
ax = fig.add_subplot(1, 1, 1)
psm = ax.pcolormesh(data, cmap="pokemon")
# psm = ax.pcolormesh(data, cmap=pokemon_cmap)  # 名前でななくカラーマップインスタンスを渡す場合
fig.colorbar(psm, ax=ax)
plt.show()

このような図ができます。色が適当なので見やすくはありませんが、確かに先ほど指定した色たちが使われていますね。

次は連続的に値が変化するLinearSegmentedColormapです。これは色の指定がものすごくややこしいですね。segmentdataっていう、”red”, “green”, “blue”というRGBの3色のキーを持つ辞書データで、それぞのチャネルの値をどのように変化させていくのかを指定することになります。サンプルで作られているデータがこれ。

cdict = {'red':   [(0.0,  0.0, 0.0),
                   (0.5,  1.0, 1.0),
                   (1.0,  1.0, 1.0)],

         'green': [(0.0,  0.0, 0.0),
                   (0.25, 0.0, 0.0),
                   (0.75, 1.0, 1.0),
                   (1.0,  1.0, 1.0)],

         'blue':  [(0.0,  0.0, 0.0),
                   (0.5,  0.0, 0.0),
                   (1.0,  1.0, 1.0)]}

各キーの値ですが3値のタプルのリストになっています。

リスト内のi番目のタプルの値を(x[i], y0[i], y1[i]) とすると、x[i]は0から1までの単調に増加する数列になっている必要があります。上の例で言えば、x[0]=0でx[1]=0.5でx[2]=1なので、0~0.5と0.5~1の範囲のredチャンネルの値の変化がy0,y1で定義されています。

値がx[i]からx[i+1]に変化するに当たって、そのチャンネルの値がy1[i]からy0[i+1]へと変化します。上記の”red”の例で言えば値が0〜0.5に推移するに合わせて0から1(255段階で言えば
255)へと変化し、値が0.5~1と推移する間はずっと1です。

y0[i]とy1[i]に異なる値をセットするとx[i]のところで不連続な変化も起こせるみたいですね。

あと、連続値とは言え実装上はN段階の色になる(そのNが大きく、デフォルトで256段階なので滑らかに変化してるように見える)ので、そのNも指定できます。

とりあえず上のcdictデータでやってみましょう。

newcmp = LinearSegmentedColormap('testCmap', segmentdata=cdict, N=256)

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)

pms = ax.pcolormesh(data, cmap=newcmp)
fig.colorbar(pms, ax=ax)
plt.show()

結果がこちら。

値が小さい時はRGBが全部0なので黒くて、徐々に赤みを増し、途中から緑の値も増え始めるので黄色に向かって、さらに途中から青も増え始めるので最大値付近では白になっていってます。わかりやすいとは言い難いですがなんとなく挙動が理解できてきました。

流石にこのやり方で細かい挙動を作っていくのは大変なので、LinearSegmentedColormap.from_listっていうstaticメソッドがあり、こちらを使うとただ色のリスト渡すだけで等間隔にカラーマップ作ってくれるようです。

また、その時zipを使って区間の調整もできます。使用例の図は省略しますが次のように書きます。(ドキュメントからそのまま引用。)

colors = ["darkorange", "gold", "lawngreen", "lightseagreen"]
cmap1 = LinearSegmentedColormap.from_list("mycmap", colors)

nodes = [0.0, 0.4, 0.8, 1.0]
cmap2 = LinearSegmentedColormap.from_list("mycmap", list(zip(nodes, colors)))

これのほうがずっと楽ですね。