Pythonのリストをn個に分割する

めったに使わないのですが、前回の記事がリストをn個ずつに分割するだったので今回はリストをn個のリストに分割する方法を紹介します。
ちなみに、目的が機械学習のクロスバリデーションであれば、scikit-learnに専用のメソッドがあるのでそちらを使いましょう。
今回の記事はそれ以外の用途で、何かしらの事情があってリストをn分割する必要が発生した時に使います。

さて、まず簡単に思いつくのは前回の記事同様にリストのスライスを使う方法です。
元のデータのサイズをnで割って区切り位置を決め、その位置で区切ります。
コードにすると次のようになりますね。
例として、サイズが23のデータを5分割しています。
途中、スライスする位置をintで整数に丸めているのは、単にスライスの表記が整数しか受け付けないからです。

# サンプルのデータ生成
data = list(range(23))
print(data)
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22]
print(len(data))
# 23

# 5分割する
n = 5
size = len(data)

for i in range(n):
    start = int((size*i/n))
    end = int((size*(i+1)/n))
    print(data[start:end])
"""
[0, 1, 2, 3]
[4, 5, 6, 7, 8]
[9, 10, 11, 12]
[13, 14, 15, 16, 17]
[18, 19, 20, 21, 22]
"""

特に何も変哲のないコードですし、無事にリストが5分割されました。

ただ、一点気持悪いというか少なくとも僕の好みには合わない点があります。
それが分割結果の各リストのサイズです。数えてみると、4個、5個、4個、5個、5個、となっています。23が5で割り切れないので、数が不揃いになるのは仕方ないのですが、個人的には、4/4/5/5/5 か、 5/5/5/4/4 のどちらかで切りたいです。

しかし、これを実装するのはそこそこ手間がかかります。元のデータ長を分けたいグループ数で整数除算し、商とを余を求めて分割後の各グループに属する要素数を求め、その要素数から区切り位置を決め、その位置で切る手順をコードに起こす必要があるからです。
やってみたのが次のコードです。(確認用のprint文や説明のコメントのせいで余計に面倒なコードに見えてしまっていますね。)

import numpy as np


# data は上のコード例と同じものを使う。
data = list(range(23))
n = 5
size = len(data)

# データの件数を分けたいグループ数で割って商と余りを求める
quotient, remainder = divmod(size, n)
print("商:", quotient)
# 商: 4
print("余り:", remainder)
# 余り: 3

# [0] に続けて各グループの要素数を指定するリストを作る
section_sizes = ([0] + remainder * [quotient+1] + (n-remainder) * [quotient])
print(section_sizes)
# [0, 5, 5, 5, 4, 4]

# 累積和をとって、スライスする点のリストにする
slice_points = list(np.cumsum(section_sizes))
print(slice_points)
# [0, 5, 10, 15, 19, 23]

# 作成したスライス位置を使ってリストを切る
for i in range(n):
    start = slice_points[i]
    end = slice_points[i+1]
    print(data[start:end])

"""
[0, 1, 2, 3, 4]
[5, 6, 7, 8, 9]
[10, 11, 12, 13, 14]
[15, 16, 17, 18]
[19, 20, 21, 22]
"""

はい、これで、5個、5個、5個、4個、4個、に区切れましたね。

途中累積和を取るためにnumpyをインポートしてcumsumまで使っています。
ただ、どうせnumpyを使うことになるのであれば、実はnumpyに専用のメソッドが用意されているので断然そちらがお勧めです。

参考: numpy.array_split

numpyのarray用に実装されたメソッドだと思いますが、ただのlistに対しても動作してくれます。これを使うと、たったこれだけのコードになります。

data = list(range(23))
n = 5
print(np.array_split(data, n))
"""
[array([0, 1, 2, 3, 4]),
 array([5, 6, 7, 8, 9]),
 array([10, 11, 12, 13, 14]),
 array([15, 16, 17, 18]),
 array([19, 20, 21, 22])]
"""

めっちゃ簡単ですね。メソッドの戻り値はn分割した各グループのリストになります。
分割された各グループは array 型に変換されるのでその点だけ注意してください。
元のデータがarray型でなくても結果はarray型になります。

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です