Numpyだけで重回帰分析

興味本位でNumPyの多項式回帰(polyfit)のソースコードを読んでいたのですが、
その中でNumPyにも重回帰分析のメソッドが用意されていて使われているのを見つけました。
てっきり重回帰分析は、scikit-learnかstatsmodelsを使うか、もしくはNumpyでやるならスクラッチ実装しないといけないと思い込んでいたので意外でした。
使ってみるとかなり手軽に使えたのでこの記事で紹介します。

ちなみに多項式回帰については既に記事を書いているのでご参照ください。
参考記事: NumPyで多項式回帰

紹介する関数はこちらです。
numpy.linalg.lstsq
(正直この名前はドキュメントでは探しにくい。おそらく、least-squaresの略語です。)

とりあえず、使ってみましょう。
ダミーデータとして、
$$
y = 3x_0 -2 x_1 + 5x_2 + \varepsilon
$$
のデータを作っておきます。$\varepsilon$はノイズです。


import numpy as np
X = np.random.randn(100, 3)
y = X@np.array([[3], [-2], [5]]) + np.random.randn(100).reshape(100, 1)

さて、早速lstsqを使ってみます。
使い方は簡単で、先ほどのXとyを渡してあげて、あと、rcondという引数を指定するだけです。
rcondは指定しないとwarningが出ますが、指定しなくても動きます。
小さい特異値を切り捨てる割合を指定する方法で、Noneか-1か指定しておけば良さそうです。


print(np.linalg.lstsq(X, y, rcond=None))
"""
(array([[ 2.97422787],
       [-2.01082975],
       [ 4.81883873]]),
       array([112.27261964]),
       3,
       array([10.54157065,  9.62381814,  8.55906245]))
"""

さて、ご覧の通り、結構いろいろな値がタプルで戻ってきました。
最初のArrayの
array([[ 2.97422787],
[-2.01082975],
[ 4.81883873]]),
の部分が推定した係数です。正解の 3, -2, 5 に近い値になっているのがわかります。
次の、[112.27261964]の値は残差の平方和です。
そして、 3 は Xのrank、次の array([10.54157065, 9.62381814, 8.55906245]) は Xの特異値です。

残渣平方和が 112.27261964 になるのは計算してみておきましょう。


coef, rss, rank, s = np.linalg.lstsq(X, y, rcond=None)
# 予測値を計算
p = X@coef
# 残差の平方和を計算
print(((y-p)**2).sum())
# 112.2726196447206

バッチリですね。

ここまでの流れで、気付いた方もいらっしゃると思いますが、lstsqで重回帰分析すると、定数項が出てきません。
定数項を含めて重回帰分析するには、Xに値が全部1になる列を追加して、それを渡す必要があります。
Numpyだけでもできますが、 statsmodels の add_constant あたりを使ってもいいでしょう。

例えば、
$$
y = 3x_0 -2 x_1 + 5x_2 + 4 + \varepsilon
$$
をダミーデータを作って、回帰分析するとこまでやると次のようになります。


import numpy as np
import statsmodels.api as sm

# ダミーデータ生成
X = np.random.randn(100, 3)
y = X@np.array([[3], [-2], [5]]) + 4 + np.random.randn(100).reshape(100, 1)

# Xに定数項を追加したデータを生成
X_add_const = sm.add_constant(X)

# 回帰分析
coef, rss, rank, s = np.linalg.lstsq(X_add_const, y, rcond=None)

# 推定した係数を表示
print(coef)
"""
[[ 4.01847308]
 [ 3.02763718]
 [-1.98363017]
 [ 4.99985177]]
"""

最初の 4.018…が定数項で残りが係数です。

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です