np.whereで効率的に値を出し分ける

今回もnumpyのテクニックの紹介です。np.whereというメソッドを解説します。

参考: numpy.where — NumPy v2.0 Manual

これは何かというと、第1引数にTrue/Falseで評価できるデータの配列を渡すとその評価に応じてTrueなら第2引数、Falseなら第3引数の値を返す、というものです。

第2, 第3引数に渡すのは第1引数に渡した配列と同じ長さ(多次元なら全て同じ)でも良いし、定数であったり、ブロードキャストすれば同じ形にできるものなら何でも良いです。

一番シンプルな例としては、条件を満たすかどうかでそれぞれ異なる定数を返すようなものでしょうか。

import numpy as np


scores = np.array([45, 85, 72, 50, 90])
results = np.where(scores >= 60, '合格', '不合格')
print(results)
# ['不合格' '合格' '合格' '不合格' '合格']

説明いらないと思いますが、60点以上なら合格、と判定するメソッドですね。

上記の例のように、事前にTrue/False の配列を作っておくのではなく、何かしらの条件式を代1引数に渡すような使い方になると思います。条件に応じて何かしらの演算を行いたい場合は、第2, 第3引数に計算式を入れて結果を渡すような形になります。例えば、偶数なら1/2, 奇数なら 3倍して1を足す、みたいな処理をするならこうです。

np.where(scores%2 == 0, scores/2, 3*scores+1)
# array([136., 256.,  36.,  25.,  45.])

1次元配列の場合は、内包表記でもほぼ同じことができるのでありがたみが薄いですが、np.whereは多次元配列で便利なことがあります。(単純に、内包表記の方が不便になるだけという見方もできますが。)

自分が最近使った例としては、欠損値がある行列Aと別の行列Bがあった時に、欠損値以外は元の行列Aの値、欠損してる部分はBの値、で埋めたいというものでした。

これが次のようにして簡単に行えます。

A = np.array(
        [[1, 2, 3,], [np.nan, 5, 6], [7, np.nan, 9]]
    )
B = np.array(
        [[11, 12, 13,], [14, 15, 16], [17, 18, 19]]
    )

print(np.where(np.isnan(A), B, A))
"""
[[ 1.  2.  3.]
 [14.  5.  6.]
 [ 7. 18.  9.]]
"""

コードがすっきり書けること以外にもベクトル処理が行えることによるパフォーマンス面のメリットなど、利点があるので機会があれば使ってみてください。

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です