# Load MASS before tidyverse so that dplyr::select() masks MASS::select()
library(MASS)
library(class)     # for knn()
library(ISLR)      # for the Default dataset
library(tidyverse)
Supervised learning: classification
Introduction
In this practical, we will learn to use three different classification methods: K-nearest neighbours, logistic regression, and linear discriminant analysis.
One of the packages we are going to use is class. For this, you will probably need to run install.packages("class") before running the library() functions.
Make sure to load MASS before tidyverse; otherwise the function MASS::select() will mask dplyr::select().
Default dataset
The Default dataset contains credit card loan data for 10,000 people. The goal is to classify credit card cases as yes or no based on whether they will default on their loan.
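Before modelling, it helps to inspect the data and create a train/test split. A minimal sketch: default_train and default_test are the names used later in this practical, the 8000/2000 split matches the 2000 test cases in the confusion matrix shown below, and the seed value is arbitrary:

# Inspect the Default dataset from the ISLR package
head(Default)
summary(Default)

# Shuffle the rows, then take an 80/20 train/test split
set.seed(45)
shuffled      <- Default[sample(nrow(Default)), ]
default_train <- shuffled[1:8000, ]
default_test  <- shuffled[8001:10000, ]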
K-Nearest Neighbours
Now that we have explored the dataset, we can start on the task of classification. We can imagine a credit card company wanting to predict whether a customer will default on the loan so they can take steps to prevent this from happening.
The first method we will be using is k-nearest neighbours (KNN). It classifies datapoints based on a majority vote of the k points closest to it. In R, the class package contains a knn() function to perform KNN.
Confusion matrix
The confusion matrix is an insightful summary of the plots we have made and the correct and incorrect classifications therein. A confusion matrix can be made in R with the table() function by entering two factor or character vectors:
table(
observed = default_test$default,
predicted = knn_2_pred
)
        predicted
observed   No  Yes
     No  1899   31
     Yes   55   15
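From this table we can compute summary measures. For example, the accuracy is the proportion of correct classifications:

# (1899 + 15) correct out of 2000 test cases = 0.957
mean(knn_2_pred == default_test$default)

Note that although the overall accuracy is high, only 15 of the 70 true defaulters are correctly identified, so accuracy alone can be misleading on imbalanced data like this.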
Logistic regression
KNN directly predicts the class of a new observation using a majority vote of the existing observations closest to it. In contrast to this, logistic regression predicts the log-odds of belonging to category 1. These log-odds can then be transformed to probabilities by performing an inverse logit transform:
\[ p = \frac{1}{1+e^{-\alpha}} \]
where \(\alpha\) indicates the log-odds of being in class 1 and \(p\) is the probability.
Therefore, logistic regression is a probabilistic classifier as opposed to a direct classifier such as KNN: instead of assigning a class outright, it outputs a probability, which can then be used in conjunction with a cutoff (usually 0.5) to classify new observations.
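As a quick check of this transform, the inverse logit is easy to compute by hand, and base R's plogis() implements exactly the same function:

# Inverse logit by hand, matching the formula above
ilogit <- function(alpha) 1 / (1 + exp(-alpha))
ilogit(0)    # log-odds of 0 give p = 0.5, the usual cutoff
ilogit(-4)   # strongly negative log-odds give p of about 0.018
plogis(-4)   # base R's built-in gives the identical result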
Logistic regression in R happens with the glm() function, which stands for generalised linear model. Here we have to indicate that the residuals are modelled not as a Gaussian (normal distribution), but as a binomial distribution.
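A minimal sketch, assuming the default_train set from before and using all available predictors (the name lr_mod is ours):

# family = binomial turns glm() into a logistic regression
lr_mod <- glm(default ~ ., family = binomial, data = default_train)
summary(lr_mod)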
Now that we have generated a model, we can use the predict() method to output the estimated probabilities for each point in the training dataset. By default, predict() outputs the log-odds, but we can transform them back using the inverse logit function from before, or by setting the argument type = "response" within the predict() function.
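Continuing with the hypothetical lr_mod:

# By default, predict() returns log-odds (type = "link")
head(predict(lr_mod))

# Transform them with the inverse logit...
head(1 / (1 + exp(-predict(lr_mod))))

# ...or ask for probabilities directly
head(predict(lr_mod, type = "response"))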
Another advantage of logistic regression is that we get coefficients we can interpret.
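For example, with the hypothetical lr_mod, each coefficient is the expected change in the log-odds of default per unit increase in its predictor:

coef(lr_mod)
# In the Default data the balance coefficient is positive: higher balances
# raise the log-odds, and therefore the probability, of default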
Visualising the effect of the balance variable
In two steps, we will visualise the effect balance has on the predicted default probability.
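A sketch of what those two steps could look like, assuming the hypothetical lr_mod; the grid range and the fixed values for the other predictors are illustrative choices:

# Step 1: predicted probabilities over a grid of balance values, holding
# student and income fixed
balance_df <- tibble(
  student = factor("No", levels = levels(Default$student)),
  balance = seq(0, 3000, length.out = 500),
  income  = mean(default_train$income)
)
balance_df$pred_prob <- predict(lr_mod, newdata = balance_df,
                                type = "response")

# Step 2: plot the resulting probability curve
ggplot(balance_df, aes(x = balance, y = pred_prob)) +
  geom_line() +
  labs(x = "Credit card balance",
       y = "Predicted probability of default")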
Linear discriminant analysis
The last method we will use is linear discriminant analysis (LDA), using the lda() function from the MASS package.
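A minimal sketch along the same lines as before (lda_mod and lda_pred are our names):

# Fit LDA on the training set
lda_mod <- lda(default ~ ., data = default_train)

# predict() for lda objects returns a list; $class holds the predicted labels
lda_pred <- predict(lda_mod, newdata = default_test)$class

# Confusion matrix, as before
table(observed = default_test$default, predicted = lda_pred)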