The Normal Equation for Linear Regression in Matrix Form

In this tutorial I will go through a simple example implementing the normal equation for linear regression in matrix form. The IPython notebook I used to generate this post can be found on GitHub.

\text{Normal Equation: } \theta = (X^T X)^{-1}X^T \vec{y}
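For reference, here is one standard way to arrive at this formula. It assumes the usual sum-of-squared-errors cost (any constant factor such as 1/2m does not change the minimizer) and assumes that X^T X is invertible. Setting the gradient of the cost to zero gives:

```latex
\begin{aligned}
J(\theta) &= (X\theta - \vec{y})^T (X\theta - \vec{y}) \\
\nabla_\theta J(\theta) &= 2X^T X\theta - 2X^T \vec{y} = 0 \\
X^T X\theta &= X^T \vec{y} \\
\theta &= (X^T X)^{-1} X^T \vec{y}
\end{aligned}
```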

The primary focus of this post is to illustrate how to implement the normal equation without getting bogged down with a complex data set. To that end I have chosen a simple, albeit contrived, dataset.

Prologue

Before we get our hands dirty we must first import numpy, pandas, seaborn, and matplotlib.

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

Setting Up and Prepping our Data

First we will create a pandas DataFrame and populate it with our data. This sample data will have only 3 examples, each consisting of one feature x1 and a corresponding target y. Note: pandas supports IO for a wide collection of file types, so you are free to read any data set you’d like from an external source.

trainingData = pd.DataFrame(data=[[1,1], [2,2], [4,4]], columns=['x1', 'y'])
   x1  y
0   1  1
1   2  2
2   4  4
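As an illustration of reading the data from an external source instead, here is a minimal sketch using pd.read_csv. The CSV text and the name trainingData_csv are hypothetical stand-ins; with a real file you would pass its path (for example an assumed "training.csv") instead of the io.StringIO wrapper.

```python
import io

import pandas as pd

# Hypothetical CSV contents standing in for an external file
# (with a real file you would pass a path such as "training.csv" instead).
csv_text = "x1,y\n1,1\n2,2\n4,4\n"

# pd.read_csv accepts any file-like object; io.StringIO wraps the string as one
trainingData_csv = pd.read_csv(io.StringIO(csv_text))
print(trainingData_csv)
```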

Next we will add the column of 1’s (x0) to the front of the matrix, using NumPy’s ones function and the pandas DataFrame.insert method.

trainingData.insert(0, 'x0', np.ones(3))
   x0  x1  y
0   1   1  1
1   1   2  2
2   1   4  4
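If you would rather work with plain NumPy arrays, the same design matrix can be built by stacking a column of ones next to the feature column. This is a sketch; x1_np, y_np and X_np are illustrative names, not from the original notebook.

```python
import numpy as np

# Feature values and targets from the toy data set
x1_np = np.array([1.0, 2.0, 4.0])
y_np = np.array([1.0, 2.0, 4.0])

# Prepend the bias column x0 = 1 so that theta_0 acts as the intercept
X_np = np.column_stack([np.ones_like(x1_np), x1_np])
print(X_np)
```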

Exploratory Plot

Now that we have the data, let’s plot it (using seaborn) to get some intuition for what the hypothesis function might be.

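with sns.axes_style("darkgrid"):
    # pass x, y and data as keywords; newer seaborn releases no longer accept x and y positionally
    g = sns.lmplot(x='x1', y='y', data=trainingData[['x1', 'y']], markers='o', fit_reg=False)
    g.set(ylim=(0, None))
    g.set(xlim=(0, None))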

(Figure: scatter plot of the training data, y against x1)

From the above plot it is fairly easy to see that the hypothesis should be linear, and of the form:
h_\theta(x) = \theta_0 + \theta_1x_1

Furthermore, it’s easy to see that the hypothesis is actually:
h_\theta(x) = x_1

And our thetas are:
\theta_0 = 0, \theta_1 = 1
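As a quick sanity check before any matrix algebra, we can plug the guessed parameters into the hypothesis for every example at once. Writing h_\theta(x) = \theta_0 x_0 + \theta_1 x_1 for each row means the predictions for all examples are the matrix-vector product X\theta. A minimal NumPy sketch (X_np, y_np and theta_guess are illustrative names):

```python
import numpy as np

# Design matrix with the bias column x0 = 1, and the targets
X_np = np.array([[1.0, 1.0],
                 [1.0, 2.0],
                 [1.0, 4.0]])
y_np = np.array([1.0, 2.0, 4.0])

# Guessed parameters: theta_0 = 0, theta_1 = 1
theta_guess = np.array([0.0, 1.0])

# h_theta(x) = theta_0 * x0 + theta_1 * x1 for every row, i.e. X @ theta
predictions = X_np @ theta_guess
residuals = predictions - y_np
print(predictions, residuals)
```

If the guess is right, every residual is zero, which is exactly what the normal equation should reproduce.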

Now let’s use the normal equation to confirm our belief. To begin, we construct the design matrix X and the target vector y.

X = trainingData[['x0', 'x1']]
y = trainingData[['y']]

Next we transpose X, using the .T shorthand for the pandas transpose() method. Since we are transposing a 3×2 matrix, we can expect a 2×3 matrix as the result.

X^T

X.T
    0  1  2
x0  1  1  1
x1  1  2  4
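To make the row/column swap concrete, here is a small check that .T turns the 3×2 design matrix into a 2×3 one, and that entry (i, j) of X^T is entry (j, i) of X. It is a sketch that rebuilds the matrix as X_demo (an illustrative name) so it runs on its own.

```python
import pandas as pd

# Same design matrix as above, rebuilt so this snippet stands alone
X_demo = pd.DataFrame({'x0': [1, 1, 1], 'x1': [1, 2, 4]})

X_demo_T = X_demo.T
print(X_demo.shape, X_demo_T.shape)

# (X^T)_{ij} == X_{ji} for every position
for i in range(X_demo_T.shape[0]):
    for j in range(X_demo_T.shape[1]):
        assert X_demo_T.iloc[i, j] == X_demo.iloc[j, i]
```

Note that the column labels x0 and x1 become the row labels of the transpose, matching the output above.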

Applying the Normal Equation

Next we calculate X transpose multiplied by X. Since we are doing matrix multiplication, as opposed to element-wise multiplication, we need to use the pandas DataFrame.dot() method.
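The distinction matters because * on two DataFrames multiplies matching cells element by element, while .dot() (or the @ operator, which pandas DataFrames also support) performs a true matrix product. DataFrame.dot also aligns the left operand’s column labels with the right operand’s index labels; X.T.dot(X) works because the columns of X.T are exactly the row labels of X. A sketch with an illustrative X_demo:

```python
import pandas as pd

X_demo = pd.DataFrame({'x0': [1, 1, 1], 'x1': [1, 2, 4]})

# Element-wise: each cell is squared, the shape stays 3x2
elementwise = X_demo * X_demo

# Matrix product: (2x3) times (3x2) gives 2x2
gram = X_demo.T.dot(X_demo)
gram_at = X_demo.T @ X_demo  # same result with the @ operator

print(elementwise)
print(gram)
```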

(X^TX)

xTx = X.T.dot(X)
    x0  x1
x0   3   7
x1   7  21
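Written out entry by entry, X^T X for a single-feature problem collects the number of examples m and simple sums of the feature values x_1^{(i)} (superscript notation introduced here for illustration), which makes the output above easy to check by hand:

```latex
X^T X =
\begin{bmatrix} 1 & 1 & 1 \\ 1 & 2 & 4 \end{bmatrix}
\begin{bmatrix} 1 & 1 \\ 1 & 2 \\ 1 & 4 \end{bmatrix}
=
\begin{bmatrix} m & \sum_i x_1^{(i)} \\ \sum_i x_1^{(i)} & \sum_i \left(x_1^{(i)}\right)^2 \end{bmatrix}
=
\begin{bmatrix} 3 & 1+2+4 \\ 1+2+4 & 1+4+16 \end{bmatrix}
=
\begin{bmatrix} 3 & 7 \\ 7 & 21 \end{bmatrix}
```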

We then take the inverse of that product using NumPy’s np.linalg.inv function.
(X^TX)^{-1}

XtX = np.linalg.inv(xTx)  # note: XtX holds the inverse (X^T X)^{-1}
array([[ 1.5       , -0.5       ],
       [-0.5       ,  0.21428571]])
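For a 2×2 matrix the inverse has a closed form, so the NumPy result can be cross-checked by hand: the determinant is 3·21 − 7·7 = 14, which gives 21/14 = 1.5, −7/14 = −0.5 and 3/14 ≈ 0.2142857. The sketch below (A, det and A_inv_closed are illustrative names) performs that check and confirms that multiplying by the inverse gives the identity.

```python
import numpy as np

# X^T X from the previous step
A = np.array([[3.0, 7.0],
              [7.0, 21.0]])

# Closed-form 2x2 inverse: (1 / det) * [[d, -b], [-c, a]]
a, b = A[0]
c, d = A[1]
det = a * d - b * c  # 3*21 - 7*7 = 14
A_inv_closed = np.array([[d, -b], [-c, a]]) / det

A_inv = np.linalg.inv(A)
print(A_inv_closed)

# Both routes agree, and A @ A^{-1} is (numerically) the identity
assert np.allclose(A_inv, A_inv_closed)
assert np.allclose(A @ A_inv, np.eye(2))
```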


Next we multiply the inverse by the transpose of X, which we previously calculated.
(X^TX)^{-1} X^T

XtX_xT = XtX.dot(X.T)
array([[ 1.        ,  0.5       , -0.5       ],
       [-0.28571429, -0.07142857,  0.35714286]])
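The 2×3 matrix (X^T X)^{-1} X^T is the Moore-Penrose pseudoinverse of X whenever X has full column rank, as it does here. NumPy’s np.linalg.pinv computes the pseudoinverse directly (via an SVD), so it makes a handy cross-check. A sketch with illustrative names X_np and factor:

```python
import numpy as np

X_np = np.array([[1.0, 1.0],
                 [1.0, 2.0],
                 [1.0, 4.0]])

# The explicit normal-equation factor (X^T X)^{-1} X^T
factor = np.linalg.inv(X_np.T @ X_np) @ X_np.T

# The pseudoinverse computed by NumPy via SVD
pinv = np.linalg.pinv(X_np)

print(factor)
assert np.allclose(factor, pinv)
```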


Finally we multiply the previous result by our target vector.
(X^TX)^{-1} X^T\vec{y}

theta = XtX_xT.dot(y)
array([[ 0.],
       [ 1.]])

\theta = (0, 1)
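Forming an explicit inverse is fine for a 2×2 toy problem, but for larger or ill-conditioned problems it is generally preferable to solve the linear system X^T X\theta = X^T\vec{y} directly, or to hand the whole least-squares problem to a dedicated solver. This swaps in np.linalg.solve and np.linalg.lstsq in place of the explicit inverse used above; the variable names are illustrative.

```python
import numpy as np

X_np = np.array([[1.0, 1.0],
                 [1.0, 2.0],
                 [1.0, 4.0]])
y_np = np.array([1.0, 2.0, 4.0])

# Solve X^T X theta = X^T y without forming the inverse
theta_solve = np.linalg.solve(X_np.T @ X_np, X_np.T @ y_np)

# SVD-based least-squares solve on X itself;
# returns (solution, residual sums, rank, singular values)
theta_lstsq, residuals, rank, sv = np.linalg.lstsq(X_np, y_np, rcond=None)

print(theta_solve, theta_lstsq)
```

Both should recover \theta = (0, 1), matching the normal-equation result.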

Final Confirmation Plot

The normal equation has confirmed our initial guess that the hypothesis is h_\theta(x) = \theta_0 + \theta_1 x_1 = x_1. Finally, we visualize the hypothesis with a confirmation plot.

# evaluate the hypothesis h(x) = theta_0 + theta_1 * x over x = 0..5
theta0, theta1 = np.ravel(theta)
xs = np.arange(6)
hypothesis = theta0 + theta1 * xs

with sns.axes_style("darkgrid"):
    fig, ax = plt.subplots()
    ax.set_title('Linear Regression with the Normal Equation')
    ax.plot(trainingData['x1'], trainingData['y'], 'o', label='Data')
    ax.plot(xs, hypothesis, 'k-', label='Hypothesis')
    ax.legend(loc='best')
    ax.set(ylim=(0, 5))
    ax.set(xlim=(0, 5))


(Figure: the training data with the fitted hypothesis line)

And we’re done!


1 Response

  1. Satvik Tiwari says:

    A shorter version:

    from numpy.linalg import inv

    def NormalEquation(theta, X, y):
        theta = inv(X.T @ X) @ X.T @ y
        return theta
