24 Logistic Regression in Tidymodels
“In God we trust, all others bring data.” - W. Edwards Deming
24.1 Multiple Logistic Regression
In the last chapter, we introduced logistic regression using a simple (one predictor) model. The simple logistic regression model is easily extended to more than one predictor variable.
In fact, several predictor variables are usually required with logistic regression to obtain adequate description and useful predictions.
In matrix notation, the logistic response function becomes \[ E\{Y\} = \frac{\exp{\left(\textbf{X}^\prime\boldsymbol{\beta}\right)}}{1+\exp{\left(\textbf{X}^\prime\boldsymbol{\beta}\right)}} \]
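Here \(\textbf{X}^\prime\boldsymbol{\beta}\) denotes the usual linear predictor; written out for \(p-1\) predictor variables, \[ \textbf{X}^\prime\boldsymbol{\beta} = \beta_0 + \beta_1 X_1 + \beta_2 X_2 + \cdots + \beta_{p-1} X_{p-1} \]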
Like the simple logistic response function, the multiple logistic response function is monotonic and sigmoidal in shape with respect to \(\textbf{X}^\prime\boldsymbol{\beta}\) and is almost linear when \(\pi = E\{Y\}\) is between .2 and .8.
The \(X\) variables may be different predictor variables, or some may represent curvature and/or interaction effects.
Also, the predictor variables may be quantitative, or they may be qualitative and represented by indicator variables.
This flexibility makes the multiple logistic regression model very useful.
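To see this shape numerically, here is a minimal base-R sketch (not from the original text) that evaluates the logistic response over a grid of linear predictor values:

eta = seq(-4, 4, by = 1)               # grid of X'beta values
pi_hat = exp(eta) / (1 + exp(eta))     # logistic response; equivalent to plogis(eta)
round(cbind(eta, pi_hat), 3)           # monotone and sigmoidal; near-linear for .2 < pi < .8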
24.2 Example Using Tidymodels
In a health study to investigate an epidemic outbreak of a disease that is spread by mosquitoes, individuals were randomly sampled within two sectors in a city to determine if the person had recently contracted the disease under study.
This was ascertained by the interviewer, who asked pertinent questions to assess whether certain specific symptoms associated with the disease were present during the specified period.
The response variable \(y\) was coded 1 if this disease was determined to have been present, and 0 if not.
Three predictor variables were included in the study, representing known or potential risk factors.
They are age, socioeconomic status of household, and sector within city.
Age (\(x_1\)) is a quantitative variable. Socioeconomic status is a categorical variable with three levels. It is represented by two indicator variables (\(x_2\) and \(x_3\)), as follows: \[ \begin{array}{lcc} \text{Class} & x_2 & x_3\\ \text{Upper} & 0 & 0\\ \text{Middle} & 1 & 0\\ \text{Lower} & 0 & 1 \end{array} \]
City sector is also a categorical variable. Since there were only two sectors in the study, one indicator variable (\(x_4\)) was used, defined so that \(x_4 = 0\) for sector 1 and \(x_4 = 1\) for sector 2.
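In this data set the indicators \(x_2\) and \(x_3\) arrive pre-coded as 0/1 columns. As an aside, the same coding is what R generates from a factor; a small illustration with a hypothetical ses variable (not part of the study data):

ses = factor(c("Upper", "Middle", "Lower"), levels = c("Upper", "Middle", "Lower"))
model.matrix(~ ses)   # columns sesMiddle and sesLower reproduce x2 and x3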
We will first convert the response variable to a factor and then look at the scatterplot matrix.
library(tidyverse)
library(tidymodels)
library(GGally)

dat = read.table("http://users.stat.ufl.edu/~rrandles/sta4210/Rclassnotes/data/textdatasets/KutnerData/Chapter%2014%20Data%20Sets/CH14TA03.txt")
names(dat) = c("id", "X1", "X2", "X3", "X4", "Y")
dat = dat |> mutate(Y = relevel(as.factor(Y), ref = "1"))

ggpairs(dat[, -1])
We can now set up the recipe, model, and workflow to fit the model.
dat_recipe = recipe(Y ~ X1 + X2 + X3 + X4, data = dat)

model = logistic_reg() |>
  set_engine("glm")

wf = workflow() |>
  add_recipe(dat_recipe) |>
  add_model(model)

fit = wf |> fit(data = dat)
fit |> tidy()
# A tibble: 5 × 5
term estimate std.error statistic p.value
<chr> <dbl> <dbl> <dbl> <dbl>
1 (Intercept) 2.31 0.643 3.60 0.000319
2 X1 -0.0298 0.0135 -2.20 0.0276
3 X2 -0.409 0.599 -0.682 0.495
4 X3 0.305 0.604 0.505 0.613
5 X4 -1.57 0.502 -3.14 0.00169
The tidy() function returns the estimated coefficients on the log-odds scale. Recall that we want to see the exponentiated coefficients, which are interpreted as odds ratios.
fit |> tidy(exponentiate = TRUE)
# A tibble: 5 × 5
term estimate std.error statistic p.value
<chr> <dbl> <dbl> <dbl> <dbl>
1 (Intercept) 10.1 0.643 3.60 0.000319
2 X1 0.971 0.0135 -2.20 0.0276
3 X2 0.664 0.599 -0.682 0.495
4 X3 1.36 0.604 0.505 0.613
5 X4 0.207 0.502 -3.14 0.00169
One caution before interpreting: the glm engine models the probability of the second level of the response factor, and because we releveled Y so that "1" comes first, the fitted coefficients describe the odds that the disease is absent (Y = 0). The odds ratios for the disease being present are therefore the reciprocals of the estimates above.
When we increase X1 (age) by one year, we expect about a 3% increase in the odds of the disease being present (1/0.971 ≈ 1.03), given the other predictors are held constant.
For X2 being a 1 (middle class), the odds of the disease being present are about 51% higher than for the upper class (1/0.664 ≈ 1.51), given the other predictors are held constant.
For X3 being a 1 (lower class), the odds of the disease being present are about 26% lower than for the upper class (1/1.36 ≈ 0.74), given the other predictors are held constant.
For X4 being a 1 (sector 2), the odds of the disease being present are nearly five times those in sector 1 (1/0.207 ≈ 4.83), given the other predictors are held constant.
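We can compute the odds ratios for the disease being present directly by negating the coefficients before exponentiating; a small sketch (not in the original text):

fit |>
  tidy() |>
  mutate(or_present = exp(-estimate))   # odds ratios for Y = 1 (disease present)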
With the predict() function, we can obtain predicted probabilities from the fitted model.
24.3 Predictions
pred_prob = predict(fit,
                    new_data = dat,
                    type = "prob")
pred_prob
# A tibble: 98 × 2
.pred_1 .pred_0
<dbl> <dbl>
1 0.209 0.791
2 0.219 0.781
3 0.106 0.894
4 0.371 0.629
5 0.111 0.889
6 0.136 0.864
7 0.0802 0.920
8 0.273 0.727
9 0.244 0.756
10 0.309 0.691
# ℹ 88 more rows
If we classify each observation according to its most likely class, we can obtain the predicted outcomes.
pred_class = predict(fit,
                     new_data = dat,
                     type = "class")
pred_class
# A tibble: 98 × 1
.pred_class
<fct>
1 0
2 0
3 0
4 0
5 0
6 0
7 0
8 0
9 0
10 0
# ℹ 88 more rows
The value of the linear predictor can be obtained from the predict() function with type = "raw". This value is what is obtained when the values of the predictor variables are substituted into the linear part of the model \[
\hat{\beta}_0+\hat{\beta}_1 x_1 + \hat{\beta}_2 x_2 + \hat{\beta}_3 x_3 + \hat{\beta}_4 x_4
\] Because of the factor ordering noted earlier, this is the logit of \(P(Y=0)\).
pred_raw = predict(fit,
                   new_data = dat,
                   type = "raw")
pred_raw
1 2 3 4 5 6
1.331181765 1.271681581 2.134434261 0.527929270 2.082687707 1.844686968
7 8 9 10 11 12
2.439688816 0.981891712 1.130642174 0.803391157 1.628682690 1.628682690
13 14 15 16 17 18
1.509682320 2.045183984 -0.362567837 0.083683549 -1.255070610 0.500184843
19 20 21 22 23 24
0.559685028 0.291934196 -0.295356597 -0.622607613 0.262184104 0.002144328
25 26 27 28 29 30
0.626938844 0.061644512 -0.206106319 2.558689186 0.803433733 2.023187522
31 32 33 34 35 36
2.142187892 1.636393745 0.862891342 2.193934446 -0.265563928 0.716189122
37 38 39 40 41 42
0.240145067 0.864939584 -0.176356227 1.013690046 -1.247359554 0.351434381
43 44 45 46 47 48
0.024183365 -1.344320887 0.567438660 -0.027605765 0.448438290 -0.890315869
49 50 51 52 53 54
-0.860608353 -0.801108168 -1.693610940 -1.463321256 0.321684289 -0.065066913
55 56 57 58 59 60
-0.184067282 0.202683919 -0.422068022 -0.749319038 -0.184067282 -1.076570055
61 62 63 64 65 66
2.082687707 2.469438909 2.558689186 2.142187892 0.862933918 1.963687338
67 68 69 70 71 72
1.598932597 1.420432043 0.944430564 1.479932227 1.509682320 1.509682320
73 74 75 76 77 78
1.479932227 0.765930009 2.290938354 1.725644022 1.249685119 1.309142728
79 80 81 82 83 84
2.223684538 1.368642913 1.160392266 2.439688816 0.684433363 1.100934657
85 86 87 88 89 90
0.743890972 2.074934076 1.666143837 2.201688077 2.439688816 2.439688816
91 92 93 94 95 96
2.409938724 2.499189001 2.380188631 1.636393745 1.666186413 2.052937615
97 98
2.290938354 1.576936136
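As a sanity check (a sketch, not part of the original text), applying the inverse logit to the raw values recovers predicted probabilities; because the glm engine models the second factor level, they match the .pred_0 column:

head(plogis(pred_raw))    # inverse logit of the linear predictor
head(pred_prob$.pred_0)   # the same values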
These raw values can be used to visualize the predicted probabilities together with the actual values of the response variable.
results = dat |>
  select(Y) |>
  bind_cols(pred_raw = pred_raw,   # name the vector so the column is called pred_raw
            pred_prob, pred_class)
results |>
  ggplot(aes(x = pred_raw, y = .pred_1, col = Y)) +
  geom_point() +
  geom_hline(yintercept = 0.5)
The horizontal line at 0.5 represents the cutoff at which we would classify the predicted probabilities: values of .pred_1 above 0.5 are classified as a 1 (disease present) and values below 0.5 are classified as a 0 (disease not present). The color of each dot is the actual value of the response variable. If the model predicted the response perfectly, all of the points above the line would have Y = 1 and all of the points below it would have Y = 0. Since observations of both classes appear on each side of the line, we have misclassification.
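To make the cutoff explicit in code, here is a small sketch (not in the original text) that classifies at 0.5 by hand and checks agreement with the class predictions from predict():

results |>
  mutate(by_hand = if_else(.pred_1 > 0.5, "1", "0")) |>   # ties at exactly 0.5 aside
  summarize(agree = mean(by_hand == as.character(.pred_class)))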
24.4 Confusion Matrix
To get a better idea of what is being misclassified, we can look at the confusion matrix.
conf_mat(results, truth = Y,
estimate = .pred_class)
Truth
Prediction 1 0
1 12 9
0 19 58
We see from the confusion matrix that there are 9 observations that are predicted to have the disease but really did not (false positives). There were also 19 observations that were predicted to not have the disease but really did (false negatives).
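The confusion matrix also gives the overall accuracy, which yardstick computes with the same arguments (from the table above, (12 + 58)/98 ≈ 0.71):

accuracy(results, truth = Y,
         estimate = .pred_class)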
From this confusion matrix, we can calculate the probability of predicting a 1 given the observation is a 1 (a true positive). This is known as sensitivity or the True Positive Rate (TPR).
Likewise, we can calculate the probability of predicting a 0 given the observation is a 0 (a true negative). This is known as specificity or the True Negative Rate (TNR).
sens(results, truth = Y,
estimate = .pred_class)
# A tibble: 1 × 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 sens binary 0.387
spec(results, truth = Y,
estimate = .pred_class)
# A tibble: 1 × 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 spec binary 0.866
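Rather than calling each metric separately, the same results can be collected in one call with a yardstick metric set; a brief sketch (not in the original text):

cls_metrics = metric_set(sens, spec, accuracy)
cls_metrics(results, truth = Y, estimate = .pred_class)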
24.5 ROC Curves
A Receiver Operating Characteristic (ROC) curve is a graphical representation of the diagnostic ability of a binary classifier. Logistic regression is such a classifier.
The ROC curve plots the TPR versus the False Positive Rate (1-specificity).
The ROC curve is created by calculating TPR and FPR at various threshold values, ranging from 0 to 1.
The diagonal line represents a random classifier (e.g., flipping a coin). A good classifier has a curve that bows significantly above this diagonal.
The upper-left corner of the plot (TPR = 1, FPR = 0) represents a perfect classifier, achieving both 100% sensitivity and 100% specificity.
Area Under the Curve (AUC)
The Area Under the ROC Curve (AUC) provides a single number summary of the model’s performance:
AUC ranges from 0 to 1:
- AUC = 1: Perfect model.
- AUC = 0.5: Random guessing.
- AUC < 0.5: Worse than random guessing (usually indicates a problem with the model).
An ROC curve helps assess how well logistic regression separates the two classes, independent of the specific threshold chosen. It provides insights into the trade-offs between sensitivity and specificity as you adjust the threshold. Comparing ROC curves or AUC values of different models can guide model selection and threshold optimization.
To find the ROC curve:
results |>
  roc_curve(truth = Y, .pred_1) |>
  autoplot()
roc_auc(results, truth = Y, .pred_1)
# A tibble: 1 × 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 roc_auc binary 0.776
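As noted above, the ROC curve can also guide the choice of threshold. One common heuristic, sketched here (not in the original text), is to pick the threshold that maximizes Youden's J statistic, \(J = \text{sensitivity} + \text{specificity} - 1\):

results |>
  roc_curve(truth = Y, .pred_1) |>
  mutate(j = sensitivity + specificity - 1) |>   # roc_curve() returns both columns
  slice_max(j, n = 1)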