NumPyの多次元配列から値が大きいn個のインデックスを取得する

NumPyの多次元配列(といってもここで扱うのは2次元の行列ですが)の要素の中から、値が大きい順に数個取り出して、そのインデックスが欲しいことがありました。
これが多次元ではなく、1次元の配列だったら、argsort使えば一発です。
参考: numpyのarrayを並び替えた結果をインデックスで取得する

多次元になるとargsordだけではうまく動きません。
まともに実装すると、やや手間だったのですが、unravel_index というメソッドを使うとうまく書けたので紹介します。

とりあえず参考に使うデータを用意しておきます。実際はもっと巨大なデータでやったのですが、この記事では結果を確認しやすいように 5*5の行列の25個の要素から上位5要素のインデックスを取得することを目指します。
まずデータを作っておきます。

import numpy as np


np.random.seed(4)
ary = np.random.randint(10, 100, size=(5, 5))
print(ary)
"""
[[56 65 79 11 97]
 [82 60 19 68 65]
 [65 67 46 60 54]
 [48 62 13 10 65]
 [31 31 83 48 66]]
"""

このデータから値が大きい順に5このインデックスを取得することを目指します。もし、欲しいのがインデックスではなく値であればこれは簡単です。ravel() か flatten()、reshape()あたりで1次元に変形してソートするだけです。こんな感じに。

np.sort(ary.ravel())[-5:]
# array([68, 79, 82, 83, 97])

今回の用件では欲しいのは、これらの値が入ってたインデックスです。つまり、
[[1, 3], [0, 2], [1, 0], [4, 2], [0, 4]] を求めたいのです。

単純に argsortすると、行ごとにその行内でソートして結果を返してくれます。
また、axis引数にNoneを指定すると行列全体でソートしてくれるのですが結果が、ravel等で1次元に直した後にargsortしたようなイメージで帰ってきます。

print(np.argsort(ary))
"""
[[3 0 1 2 4]
 [2 1 4 3 0]
 [2 4 3 0 1]
 [3 2 0 1 4]
 [0 1 3 4 2]]
"""

print(np.argsort(ary, axis=None))
"""
array([18,  3, 17,  7, 21, 20, 12, 15, 23, 14,  0, 13,  6, 16,  9, 19,  1,
       10, 24, 11,  8,  2,  5, 22,  4])
"""

axis=Noneの方の結果(小さい順なので最後の5つ、[8, 2, 5, 22, 4]が求める5つの数のインデックス)を元の(5*5)の配列に読み替えると欲しい結果が得られることになります。
1次元でインデックスが8ってことは9番目の要素で、5*5行列の9個目の要素は[1, 4]だ、同様に2は[0, 2]、5は[1, 0] と読み替えていくわけですね。割り算と余りを使って実装できそうです。やってみると次のようになります。

for i in np.argsort(ary, axis=None)[-5:]:
    print([i//5, i%5])

"""
[1, 3]
[0, 2]
[1, 0]
[4, 2]
[0, 4]
"""

上記の変換をあらかじめ用意されたメソッドで行うのが、この記事冒頭で名前を出した unravel_index です。
ドキュメント: numpy.unravel_index — NumPy v1.22 Manual

動かしてみます。

# 以下の引数を渡していることに注意して見てください。
# np.argsort(ary, axis=None)[-5:] = [ 8  2  5 22  4]
# ary.shape = (5, 5)

print(np.unravel_index(np.argsort(ary, axis=None)[-5:], ary.shape))
# (array([1, 0, 1, 4, 0]), array([3, 2, 0, 2, 4]))

かなり近いところまで来ましたね。最終的な結果が二つのarrayからなるタプルになっていますが、それぞれから1個ずつ値を取り出してペアにすれば欲しい結果になっています。

人間が見るには微妙に扱いにくい形に見えるのですが、次のように値を取り出すのに使えます。

index_list = np.unravel_index(np.argsort(ary, axis=None)[-5:], ary.shape)
print(ary[index_list])
# [68 79 82 83 97]

今回はインデックスの方が欲しかったので、これを人が見ても見やすい形にするため、vstackして、さらに転置(.T)しても良いでしょう。

print(np.vstack(index_list).T)
"""
[[1 3]
 [0 2]
 [1 0]
 [4 2]
 [0 4]]
"""
# 以下のように書けば1行で可能。
print(np.vstack(np.unravel_index(np.argsort(ary, axis=None)[-5:], ary.shape)).T)

以上で欲しかった結果が得られました。unravel_indexがちょっと馴染みが薄かったのでまだ慣れないのですが、個人的にはこれが一番いいのではないかと思います。

大き方から順ではなく、小さい方からとたい、って時は、argsortに対するスライスの部分を少し変えるだけで対応できます。([-5:]ではなく[:5]に。)

ちなみに、argsort/ unravel_index/ vstack など少々マイナーなメソッドをいくつも使うのが嫌な場合は次のような方法もあります。多次元の配列を、インデックスと値のペアのようなデータ構造に変換してソートするやり方です。

# インデックスと値を組にしたデータにする。
lil_ary = [(i, j, ary[i, j]) for i in range(ary.shape[0]) for j in range(ary.shape[1])]
print(lil_ary[:5])  # 中身を一部確認
# [(0, 0, 56), (0, 1, 65), (0, 2, 79), (0, 3, 11), (0, 4, 97)]

# 3番目の要素(indexは2)をキーにしてソート
lil_ary.sort(key=lambda x: x[2])
# 最後の5項目を取る
print(lil_ary[-5:])
# [(1, 3, 68), (0, 2, 79), (1, 0, 82), (4, 2, 83), (0, 4, 97)]

あとはこれの、タプルのそれぞれの先頭2要素を取り出せば欲しかったインデックスです。
何をやっているのかはこちらの方がわかりやすいかも、と思っていますがこのやり方はメモリ効率などの観点でデメリットもあるので巨大なデータへの適用はお勧めしません。

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です