筋肉で解決しないために。

日々出会うモノに対する考察をしたり、主に以下のテーマに関して書いています。 データサイエンス/人工知能/AI/機械学習/DeepLearning/Python//数学/統計学/統計処理

多重回帰分析の実装、そして儚さ。

こんにちは、ワタルです。

はじめに

今回は、多重回帰分析について、理論的な仕組みを理解できるよう解説しつつ、実装してみようと思います。 さらに、その実装がseabornであれば1行で終わってしまう儚さを体験しつつ、seabornにおける多重回帰分析のメソッドを学びたいと思います。

数学的に厳密には正しくない説明になっている箇所もあるかと思います。説明に違和感を覚える方がいましたら是非勉強させていただきますので、コメントを頂けましたらと思います。

目的

今回は、以下のデータ群に対して、基底関数として3次項までの多項式基底を用い、推定モデルを算出することが目的になります。

f:id:watarumon:20180628202613p:plain

線形回帰とは

計算に入る前に簡単に線形回帰について解説してみたいと思います。

まず、回帰とは、与えられたデータに適した関数を求める手法のことで、線形回帰とは、回帰分析の中の1つの手法のことを指します。

線形回帰の中でも、基底関数に多項式基底を用いた手法を多重回帰分析と呼びます。

ちなみに、基底関数に1次元項までの場合、単回帰分析と呼びます。

多重回帰分析の実装

さあ実装していきます。

まずはおまじないです。

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns

データ群を確認します。

X = np.array([0.02, 0.12, 0.19, 0.27, 0.42, 0.51, 0.64, 0.84, 0.88, 0.99])
t = np.array([0.05, 0.87, 0.94, 0.92, 0.54, -0.11, -0.78, -0.89, -0.79, -0.04])
plt.scatter(X,t)

f:id:watarumon:20180628202613p:plain

天下り的ですが、ここで、3次関数っぽいなという感覚から今回は、3次多項式基底を用いることを思いつきます。 簡単のため、正規化項や過学習については、他のwebページに譲ることとします。

思いついたモデルは、係数w_0,w_1,w_2,w_3が4×1の行列Wとすると、その解tの行列Tを用いて、このように記述することができます。

XW = T

ここで、Xの要素を求めるphi(x)を求めます。

def phi(x):
    return [1, x, x**2, x**3]
PHI = np.array([phi(x) for x in X])

次に、二乗誤差を算出するために、Xの転置行列で左から積を取り、 さらに左からその逆行列の積を取ることで、Wを求めます。

W = (X^{T}X)^{-1}X^{T}T

w = np.linalg.solve(np.dot(PHI.T, PHI), np.dot(PHI.T, t))

最後にプロットします。

def f(w, x):
    return np.dot(w, phi(x))
xlist = np.arange(0, 1, 0.01)
ylist = [f(w, x) for x in xlist]
plt.plot(xlist, ylist)
plt.plot(X, t, 'o')

f:id:watarumon:20180628202719p:plain

ここまでで、よくできた重回帰モデルができました。

seabornを用いて多重回帰分析の実装

次にseabornを用いて、ここまでの計算を1から行ってみます。

おまじないは同じなので省略します。

X = np.array([0.02, 0.12, 0.19, 0.27, 0.42, 0.51, 0.64, 0.84, 0.88, 0.99])
t = np.array([0.05, 0.87, 0.94, 0.92, 0.54, -0.11, -0.78, -0.89, -0.79, -0.04])
sns.regplot(X,t,order=3)

f:id:watarumon:20180628202733p:plain

終わってしまいました。。。

このように、regplotに引数(order=3)をしたことで、3次項までの重回帰分析を行うことができます。

終わりに

今回は、多重回帰分析について、理論的な仕組みを理解しつつ、seabornにおける多重回帰分析のメソッドを学びました。

この様に、seabornやscikit-learnを用いると、あっという間に計算出来てしまいますが、一度くらいはせこせこと実装してみると勉強になると思います。