なんとなくドキュメントを眺めていたら、groupby().transform()っていう便利そうな関数を見つけたのでその紹介です。
DataFrameのgroupbyといえば、指定した列をキーとしてグループごとの合計や平均、分散、個数などの集計を行うことができる関数です。
通常は、集計したキーの数=グループの数の行数のDataFrameを戻り値として返してきます。
import pandas as pd
df = pd.DataFrame(
{
"category": ["A", "A", "A", "B", "B"],
"amount": [100, 300, 100, 200, 200],
}
)
print(df)
"""
category amount
0 A 100
1 A 300
2 A 100
3 B 200
4 B 200
"""
print(df.groupby("category").sum())
"""
category
A 500
B 400
"""
ここで、この groupby して得られた集計値を、元のDataFrameの各業に展開したいことがあります。
そのような場合、僕はpd.mergeでデータフレームを結合するか、辞書形式に変換して結合することが多かったです。
例えば以下のようなコードになります。
# mergeで結合する場合
group_df = df.groupby("category").sum()
group_df.reset_index(inplace=True)
group_df.rename(columns={"amount": "category_amount"}, inplace=True)
print(pd.merge(df, group_df, on="category", how="left"))
"""
category amount category_amount
0 A 100 500
1 A 300 500
2 A 100 500
3 B 200 400
4 B 200 400
"""
# 辞書を作ってマッピングする場合
group_df = df.groupby("category").sum()
sum_dict = group_df.to_dict()["amount"]
print(sum_dict)
# {'A': 500, 'B': 400}
df["category_amount"] = df["category"].apply(sum_dict.get)
print(df)
"""
category amount category_amount
0 A 100 500
1 A 300 500
2 A 100 500
3 B 200 400
4 B 200 400
"""
書いてみるとこれらの手順を踏んでもそんなに複雑ではないのですが、やっぱり一発でできるともっと便利です。
そこで使えるのが、冒頭で紹介した、transformです。
参考: pandas.core.groupby.DataFrameGroupBy.transform
これは元のデータフレームと同じインデックスを持つデータフレームとして、GroupByの結果を返してくれます。ちょっとやってみます。
df = pd.DataFrame(
{
"category": ["A", "A", "B", "B", "B"],
"amount": [100, 300, 100, 200, 200],
}
)
# 元のDataFrameと同じ行数で、対応する行の"category"列の値が含まれるグループの合計を返す
print(df.groupby("category").transform("sum"))
"""
amount
0 400
1 400
2 500
3 500
4 500
"""
# 元のDataFrameに合計値を付与したい場合は次のようにできる
df["category_amount"] = df.groupby("category").transform("sum")["amount"]
print(df)
"""
category amount category_amount
0 A 100 400
1 A 300 400
2 B 100 500
3 B 200 500
4 B 200 500
"""
1行で済みましたね。
この新しく作った列を使えば、一定件数以下しか存在しないカテゴリの行を削除するとか、カテゴリごとにそれぞれの要素のカテゴリ内で占めてる割合を計算するとか、それぞれの要素のカテゴリごとの平均との差異を求めるとかそういった計算が非常に容易にできるようになります。
そしてさらに、このtransform とlambda関数を組み合わせて使うと、カテゴリの平均との差を一発で出す、といったこともできます。
df = pd.DataFrame(
{
"category": ["A", "A", "B", "B", "B"],
"amount": [100, 300, 100, 200, 200],
}
)
print(df.groupby("category").transform(lambda x: x-x.mean()))
"""
amount
0 -100.000000
1 100.000000
2 -66.666667
3 33.333333
4 33.333333
"""
lambda 関数に渡されている x はそれぞれの行の値のように振る舞ってくれるにもかかわらず、同時に x.mean() でグループごとの平均を出すこともでき、その差分を元のDataFrameとインデックスを揃えて返してくれています。
これは使いこなせば相当便利なメソッドになりそうです。