# load MASS before tidyverse, so that dplyr::select() masks MASS::select()
library(ISLR)
library(MASS)
library(tidyverse)
Supervised learning: regression in R
Introduction
In this practical, you will learn how to perform regression analysis in R, how to create predictions, how to plot with confidence and prediction intervals, how to calculate the MSE, how to perform train-test splits, and how to write a function for cross-validation.
Just like in the practical at the end of chapter 3 of the ISLR book, we will use the Boston dataset, which is in the MASS package that comes with R.
Make sure to load MASS before tidyverse, otherwise the function MASS::select() will mask dplyr::select().
Regression in R
Regression is performed through the lm() function. It requires two arguments: a formula and data. A formula is a specific type of object that can be constructed like so:
some_formula <- outcome ~ predictor_1 + predictor_2
You can read it as “the outcome variable is a function of predictors 1 and 2”. As with other objects, you can check its class and even convert it to other classes, such as a character vector:
class(some_formula)
as.character(some_formula)
You can estimate a linear model using lm() by specifying the outcome variable and the predictors in a formula, and by inputting the dataset these variables should be taken from.
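For example, a model with medv as the outcome and lstat as the predictor can be fit as follows (a sketch of the intended solution; we store it as lm_ses, the object name used below):
# fit a simple linear regression: median home value on socio-economic status
lm_ses <- lm(formula = medv ~ lstat, data = Boston)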
You have now trained a regression model with medv (housing value) as the outcome/dependent variable and lstat (socio-economic status) as the predictor/independent variable.
Remember that a regression estimates \(\beta_0\) (the intercept) and \(\beta_1\) (the slope) in the following equation:
\[\boldsymbol{y} = \beta_0 + \beta_1\cdot \boldsymbol{x}_1 + \boldsymbol{\epsilon}\]
We now have a model object lm_ses that represents the formula
\[\text{medv}_i = 34.55 - 0.95 \cdot \text{lstat}_i + \epsilon_i\]
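You can verify these estimates yourself, for instance:
# inspect the estimated intercept and slope
coef(lm_ses)
# or view the full model summary
summary(lm_ses)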
With this object, we can predict a new medv value by inputting its lstat value. The predict() method enables us to do this for the lstat values in the original dataset.
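For instance (y_pred is an illustrative name, not prescribed by the practical):
# predicted medv for every observation in the original data
y_pred <- predict(lm_ses)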
We can also generate predictions from new data using the newdata argument in the predict() method. For that, we need to prepare a data frame with new values for the original predictors.
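A sketch, assuming a grid of hypothetical lstat values (the names pred_dat and y_pred_new are our own):
# new data frame: the column name must match the predictor in the formula
pred_dat <- data.frame(lstat = seq(0, 40, length.out = 1000))
# predicted medv for the new lstat values
y_pred_new <- predict(lm_ses, newdata = pred_dat)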
Plotting lm() in ggplot
A good way of understanding your model is by visualizing it. We are going to walk through the construction of a plot with a fit line and prediction/confidence intervals from an lm object.
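As a starting point, a scatter plot of the raw data might look like this (a sketch; the theming is our own choice):
ggplot(Boston, aes(x = lstat, y = medv)) +
  geom_point() +
  theme_minimal()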
Now we’re going to add a prediction line to this plot.
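One way to do that, assuming the pred_dat grid and y_pred_new predictions from the previous section:
# attach the predictions to the grid so ggplot can map them to the y aesthetic
pred_dat$medv <- y_pred_new
# plot the observations with the model's prediction line on top
ggplot(Boston, aes(x = lstat, y = medv)) +
  geom_point() +
  geom_line(data = pred_dat, colour = "blue") +
  theme_minimal()
For confidence or prediction intervals, predict() additionally accepts interval = "confidence" or interval = "prediction", which returns lower and upper bounds alongside the fitted values.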
Mean squared error
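At this point the practical asks you to compute the MSE of the model's in-sample predictions. A minimal sketch (the helper name mse is our own choice):
# mean squared error of a vector of predictions
mse <- function(y_true, y_pred) {
  mean((y_true - y_pred)^2)
}
# in-sample MSE of the lm_ses model
mse(Boston$medv, predict(lm_ses))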
You have now calculated the mean squared length of the dashed vertical lines that a residual plot would draw between each observation and the regression line: the average squared distance between the observed and predicted medv values.
Train-validation-test split
Now we will use the sample() function to randomly select observations from the Boston dataset to go into a training, validation, and test set. The training set will be used to fit our model, the validation set will be used to calculate the out-of-sample prediction error during model building (this can also be called hyperparameter tuning), and the test set will be used to estimate the true out-of-sample MSE.
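A sketch of one possible split (the seed, the set sizes, and the names boston_train and boston_valid are our own choices; Boston has 506 rows, here divided 253/152/101):
# reproducible random assignment of each row to one of the three sets
set.seed(45)
splits <- sample(rep(c("train", "validation", "test"), times = c(253, 152, 101)))
boston_train <- Boston[splits == "train", ]
boston_valid <- Boston[splits == "validation", ]
boston_test  <- Boston[splits == "test", ]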
We will set aside the boston_test dataset for now.
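Using the split and the mse() helper from above, the estimate might be computed like this (model_1 is an illustrative name):
# fit on the training data only
model_1 <- lm(medv ~ lstat, data = boston_train)
# MSE of its predictions on the unseen validation data
mse(boston_valid$medv, predict(model_1, newdata = boston_valid))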
This is the estimated out-of-sample mean squared error.
OPTIONAL Programming exercise: cross-validation
This is an advanced exercise. Some components we have seen before in this and previous practicals, but some things will be completely new. Try to complete it by yourself, but don’t worry if you get stuck. If you don’t know about for loops in R, read up on those before you start the exercise.
Use help in this order:
- R help files
- Internet search & Stack Exchange
- Your peers
- The answer, which shows one solution
You may also just read the answer and try to understand what happens in each step.
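For orientation, here is a minimal k-fold cross-validation sketch, one of many possible solutions (the function name cv_lm and its arguments are our own, not the official answer):
# k-fold cross-validation estimate of the MSE for a linear model formula
cv_lm <- function(formula, dataset, k) {
  # randomly assign each row to one of k folds
  n <- nrow(dataset)
  folds <- sample(rep(1:k, length.out = n))
  mses <- numeric(k)
  for (i in 1:k) {
    # fold i is the validation set; the remaining folds form the training data
    train <- dataset[folds != i, ]
    valid <- dataset[folds == i, ]
    fit <- lm(formula, data = train)
    y_hat <- predict(fit, newdata = valid)
    # the outcome column name is the left-hand side of the formula
    y_true <- valid[[as.character(formula)[2]]]
    mses[i] <- mean((y_true - y_hat)^2)
  }
  mean(mses)
}
# example: 9-fold cross-validation for the medv ~ lstat model
cv_lm(medv ~ lstat, Boston, 9)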