24  Logistic Regression in Tidymodels

“In God we trust, all others bring data.” - W. Edwards Deming

24.1 Multiple Logistic Regression

In the last chapter, we introduced logistic regression using a simple (one predictor) model. The simple logistic regression model is easily extended to more than one predictor variable.

In fact, several predictor variables are usually required with logistic regression to obtain adequate description and useful predictions.

In matrix notation, the logistic response function becomes \[ E\{Y\} = \frac{\exp{\left(\textbf{X}^\prime\boldsymbol{\beta}\right)}}{1+\exp{\left(\textbf{X}^\prime\boldsymbol{\beta}\right)}} \]

Like the simple logistic response function, the multiple logistic response function is monotonic and sigmoidal in shape with respect to \(\textbf{X}^\prime\boldsymbol{\beta}\). It is almost linear when \(\pi = E\{Y\}\) is between .2 and .8.
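Equivalently, the model is linear on the log-odds (logit) scale. This standard derivation, written for a model with \(p-1\) predictor variables and \(\pi = E\{Y\}\), shows why exponentiating a coefficient gives an odds ratio. Increasing \(X_k\) by one unit while holding the other predictors fixed multiplies the odds by \(\exp(\beta_k)\).

```latex
\begin{align*}
\log\left(\frac{\pi}{1-\pi}\right) &= \textbf{X}^\prime\boldsymbol{\beta} = \beta_0 + \beta_1 X_1 + \cdots + \beta_{p-1} X_{p-1}\\
\text{odds}(X_k) = \frac{\pi}{1-\pi} &= \exp\left(\beta_0 + \cdots + \beta_k X_k + \cdots + \beta_{p-1} X_{p-1}\right)\\
\frac{\text{odds}(X_k + 1)}{\text{odds}(X_k)} &= \frac{\exp\left(\beta_0 + \cdots + \beta_k (X_k + 1) + \cdots + \beta_{p-1} X_{p-1}\right)}{\exp\left(\beta_0 + \cdots + \beta_k X_k + \cdots + \beta_{p-1} X_{p-1}\right)} = \exp(\beta_k)
\end{align*}
```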

The \(X\) variables may be different predictor variables, or some may represent curvature and/or interaction effects.

Also, the predictor variables may be quantitative, or they may be qualitative and represented by indicator variables.

This flexibility makes the multiple logistic regression model very useful.

24.2 Example Using Tidymodels

In a health study to investigate an epidemic outbreak of a disease that is spread by mosquitoes, individuals were randomly sampled within two sectors in a city to determine if the person had recently contracted the disease under study.

This was ascertained by the interviewer, who asked pertinent questions to assess whether certain specific symptoms associated with the disease were present during the specified period.

The response variable \(y\) was coded 1 if this disease was determined to have been present, and 0 if not.

Three predictor variables were included in the study, representing known or potential risk factors.

They are age, socioeconomic status of household, and sector within city.

Age (\(x_1\)) is a quantitative variable. Socioeconomic status is a categorical variable with three levels. It is represented by two indicator variables (\(x_2\) and \(x_3\)), as follows: \[ \begin{array}{lcc} \text{Class} & x_2 & x_3\\ \text{Upper} & 0 & 0\\ \text{Middle} & 1 & 0\\ \text{Lower} & 0 & 1 \end{array} \]
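For reference, R's default treatment contrasts produce exactly this coding when the first factor level is the reference. Below is a minimal sketch with a hypothetical three-level factor `ses`. The data file used below already contains the indicator columns, so this step is not needed for the example itself.

```r
# Hypothetical socioeconomic-status factor with "Upper" as the reference (first) level
ses = factor(c("Upper", "Middle", "Lower"),
             levels = c("Upper", "Middle", "Lower"))

# Treatment (dummy) coding: one indicator column per non-reference level
mm = model.matrix(~ ses)
mm
# Column sesMiddle plays the role of x2 and sesLower plays the role of x3
```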

City sector is also a categorical variable. Since there were only two sectors in the study, one indicator variable (\(x_4\)) was used, defined so that \(x_4 = 0\) for sector 1 and \(x_4 = 1\) for sector 2.

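We will first convert the response variable to a factor with "1" (disease present) as the first level, since yardstick treats the first factor level as the event of interest by default. Then we look at the scatterplot matrix.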

library(tidyverse)
library(tidymodels)
library(GGally)

dat = read.table("http://users.stat.ufl.edu/~rrandles/sta4210/Rclassnotes/data/textdatasets/KutnerData/Chapter%2014%20Data%20Sets/CH14TA03.txt")

names(dat) = c("id", "X1", "X2", "X3", "X4", "Y")

dat = dat |> mutate(Y = relevel(as.factor(Y), ref="1"))

ggpairs(dat[,-1])

We can now set up the recipe, model, and workflow to fit the model.

dat_recipe = recipe(Y~X1+X2+X3+X4, data = dat) 

model = logistic_reg() |> 
  set_engine("glm")

wf = workflow() |> 
  add_recipe(dat_recipe) |> 
  add_model(model)

fit = wf |> fit(data=dat)

fit |> tidy()
# A tibble: 5 × 5
  term        estimate std.error statistic  p.value
  <chr>          <dbl>     <dbl>     <dbl>    <dbl>
1 (Intercept)   2.31      0.643      3.60  0.000319
2 X1           -0.0298    0.0135    -2.20  0.0276  
3 X2           -0.409     0.599     -0.682 0.495   
4 X3            0.305     0.604      0.505 0.613   
5 X4           -1.57      0.502     -3.14  0.00169 

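The tidy() function returns the estimated coefficients on the log-odds scale. Recall that we want the exponentiated coefficients, which are odds ratios. With exponentiate = TRUE, only the estimate column is transformed; std.error, statistic, and p.value remain on the log-odds scale.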

fit |> tidy(exponentiate = TRUE)
# A tibble: 5 × 5
  term        estimate std.error statistic  p.value
  <chr>          <dbl>     <dbl>     <dbl>    <dbl>
1 (Intercept)   10.1      0.643      3.60  0.000319
2 X1             0.971    0.0135    -2.20  0.0276  
3 X2             0.664    0.599     -0.682 0.495   
4 X3             1.36     0.604      0.505 0.613   
5 X4             0.207    0.502     -3.14  0.00169 
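The level ordering of Y determines which outcome glm() models. With a factor response, glm() treats the first level as "failure" and models the probability of the second level. The sketch below uses simulated toy data, not the outbreak data; `x`, `y01`, and the fit names are hypothetical. It shows that flipping the reference level negates every coefficient, which turns the exponentiated estimates into reciprocals.

```r
# Simulated toy data (hypothetical): one predictor and a 0/1 outcome
set.seed(1)
x = rnorm(200)
y01 = rbinom(200, 1, plogis(-0.5 + 1.2 * x))

# "1" as the first (reference) level: glm() models P(Y = "0")
y_ref1 = relevel(factor(y01), ref = "1")
fit_ref1 = glm(y_ref1 ~ x, family = binomial)

# Default ordering, "0" first: glm() models P(Y = "1")
y_ref0 = factor(y01)
fit_ref0 = glm(y_ref0 ~ x, family = binomial)

levels(y_ref1)   # "1" "0"
levels(y_ref0)   # "0" "1"

# The two sets of coefficients differ only in sign ...
cbind(ref1 = coef(fit_ref1), ref0 = coef(fit_ref0))
# ... so the odds ratios are reciprocals of each other
cbind(ref1 = exp(coef(fit_ref1)), ref0 = exp(coef(fit_ref0)))
```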

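Because Y was releveled so that "1" is its first level, glm() models the probability of the second level, "0" (disease absent). The odds ratios above are therefore for the disease being absent. The odds ratios for the disease being present are their reciprocals.

When X1 (age) increases by one year, the odds of the disease being absent are multiplied by 0.971. Equivalently, we expect about a 3% increase in the odds of the disease being present (1/0.971 ≈ 1.03), given the other predictors are held constant.

For X2 being a 1 (middle class, compared with upper class), the odds of the disease being present are about 1/0.664 ≈ 1.51 times as large, a 51% increase, given the other predictors are held constant. However, this effect is not statistically significant (p = 0.495).

For X3 being a 1 (lower class, compared with upper class), the odds of the disease being present are about 1/1.36 ≈ 0.74 times as large, roughly a 26% decrease, given the other predictors are held constant. This effect is also not statistically significant (p = 0.613).

For X4 being a 1 (sector 2, compared with sector 1), the odds of the disease being present are about 1/0.207 ≈ 4.83 times as large, given the other predictors are held constant.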
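The odds ratios for the disease being present can be reproduced from the rounded coefficients printed by tidy(). This small arithmetic sketch uses a vector `b` that simply copies those rounded estimates, so its results differ slightly from the full-precision fit. For example, it gives about 4.81 rather than 4.83 for X4.

```r
# Rounded log-odds coefficients from tidy() above; with "1" as the first
# level of Y these are on the log-odds scale of Y = "0" (disease absent)
b = c(X1 = -0.0298, X2 = -0.409, X3 = 0.305, X4 = -1.57)

or_absent  = exp(b)    # odds ratios for the disease being absent
or_present = exp(-b)   # odds ratios for the disease being present (reciprocals)

# Percent change in the odds of the disease being present
pct_present = 100 * (or_present - 1)

round(cbind(or_absent, or_present, pct_present), 3)
```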

24.3 Predictions

With the predict() function, we can obtain the predicted probabilities from the fitted model.

pred_prob = predict(fit,
                    new_data = dat,
                    type = "prob")

pred_prob
# A tibble: 98 × 2
   .pred_1 .pred_0
     <dbl>   <dbl>
 1  0.209    0.791
 2  0.219    0.781
 3  0.106    0.894
 4  0.371    0.629
 5  0.111    0.889
 6  0.136    0.864
 7  0.0802   0.920
 8  0.273    0.727
 9  0.244    0.756
10  0.309    0.691
# ℹ 88 more rows

If we classify each observation into the class with the higher predicted probability, we obtain the predicted outcomes. With two classes, this means predicting "1" when .pred_1 exceeds 0.5.

pred_class = predict(fit,
                    new_data = dat,
                    type = "class")

pred_class
# A tibble: 98 × 1
   .pred_class
   <fct>      
 1 0          
 2 0          
 3 0          
 4 0          
 5 0          
 6 0          
 7 0          
 8 0          
 9 0          
10 0          
# ℹ 88 more rows
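For two classes, choosing the more likely class is the same as thresholding the predicted probability of disease at 0.5. Below is a minimal sketch using the first five .pred_1 values printed above. The vector `p1` is copied by hand.

```r
# First five predicted probabilities of disease (.pred_1) from the output above
p1 = c(0.209, 0.219, 0.106, 0.371, 0.111)

# Predict "1" when P(Y = 1) > 0.5, otherwise "0"; keep Y's level order ("1" first)
cls = factor(ifelse(p1 > 0.5, "1", "0"), levels = c("1", "0"))
cls
```

All five are classified as 0, matching the first five rows of pred_class.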

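The value of the linear predictor can also be obtained from the predict() function with type = "raw". This is the value obtained when the values of the predictor variables are substituted into the linear part of the model \[ \hat{\beta}_0+\hat{\beta}_1 x_1 + \hat{\beta}_2 x_2 + \hat{\beta}_3 x_3 + \hat{\beta}_4 x_4 \] Because glm() models the probability of the second factor level, this is the estimated log odds that Y = 0 (disease absent). Larger values therefore correspond to a lower predicted probability of disease.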

pred_raw = predict(fit,
                    new_data = dat,
                    type = "raw")

pred_raw
           1            2            3            4            5            6 
 1.331181765  1.271681581  2.134434261  0.527929270  2.082687707  1.844686968 
           7            8            9           10           11           12 
 2.439688816  0.981891712  1.130642174  0.803391157  1.628682690  1.628682690 
          13           14           15           16           17           18 
 1.509682320  2.045183984 -0.362567837  0.083683549 -1.255070610  0.500184843 
          19           20           21           22           23           24 
 0.559685028  0.291934196 -0.295356597 -0.622607613  0.262184104  0.002144328 
          25           26           27           28           29           30 
 0.626938844  0.061644512 -0.206106319  2.558689186  0.803433733  2.023187522 
          31           32           33           34           35           36 
 2.142187892  1.636393745  0.862891342  2.193934446 -0.265563928  0.716189122 
          37           38           39           40           41           42 
 0.240145067  0.864939584 -0.176356227  1.013690046 -1.247359554  0.351434381 
          43           44           45           46           47           48 
 0.024183365 -1.344320887  0.567438660 -0.027605765  0.448438290 -0.890315869 
          49           50           51           52           53           54 
-0.860608353 -0.801108168 -1.693610940 -1.463321256  0.321684289 -0.065066913 
          55           56           57           58           59           60 
-0.184067282  0.202683919 -0.422068022 -0.749319038 -0.184067282 -1.076570055 
          61           62           63           64           65           66 
 2.082687707  2.469438909  2.558689186  2.142187892  0.862933918  1.963687338 
          67           68           69           70           71           72 
 1.598932597  1.420432043  0.944430564  1.479932227  1.509682320  1.509682320 
          73           74           75           76           77           78 
 1.479932227  0.765930009  2.290938354  1.725644022  1.249685119  1.309142728 
          79           80           81           82           83           84 
 2.223684538  1.368642913  1.160392266  2.439688816  0.684433363  1.100934657 
          85           86           87           88           89           90 
 0.743890972  2.074934076  1.666143837  2.201688077  2.439688816  2.439688816 
          91           92           93           94           95           96 
 2.409938724  2.499189001  2.380188631  1.636393745  1.666186413  2.052937615 
          97           98 
 2.290938354  1.576936136 
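Applying the inverse logit to a raw value recovers the predicted probabilities. Since glm() models the second level of Y ("0") here, the inverse logit of the raw value gives .pred_0, and .pred_1 is its complement. Here is a quick check for observation 1, using the value printed above:

```r
# Raw linear predictor for observation 1 (log odds of Y = "0"), copied from above
eta1 = 1.331181765

p0 = plogis(eta1)   # inverse logit: exp(eta) / (1 + exp(eta)) = P(Y = "0")
p1 = 1 - p0         # P(Y = "1"), the predicted probability of disease

round(c(pred_0 = p0, pred_1 = p1), 3)   # 0.791 and 0.209, matching row 1 of pred_prob
```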

These raw values can be used to visualize the predicted probabilities against the actual values of the response variable.

results = dat |> 
  select(Y) |> 
  mutate(pred_raw = unname(pred_raw)) |>   # store the linear predictor as a named column
  bind_cols(pred_prob, pred_class)

results |> 
  ggplot(aes(x = pred_raw, y = .pred_1, col = Y))+
  geom_point()+
  geom_hline(yintercept = 0.5)

The horizontal line at 0.5 is the classification threshold. Observations with a predicted probability above 0.5 are classified as 1 (disease present), and those below 0.5 are classified as 0 (disease not present). The color of each dot shows the actual value of the response variable. Because "1" is the first level of Y, ggplot2's default palette shows Y = 1 in red and Y = 0 in blue. If the model predicted the response perfectly, every point above the line would be red and every point below it would be blue. Since some blue points lie above the line and some red points lie below it, some observations are misclassified.

24.4 Confusion Matrix

To get a better idea of which observations are being misclassified, we can look at the confusion matrix.

conf_mat(results, truth = Y,
         estimate =  .pred_class)
          Truth
Prediction  1  0
         1 12  9
         0 19 58

We see from the confusion matrix that there are 9 observations that are predicted to have the disease but really did not (false positives). There were also 19 observations that were predicted to not have the disease but really did (false negatives).

From this confusion matrix, we can estimate the probability of predicting a 1 given the observation is a 1 (a true positive). This is known as sensitivity or the True Positive Rate (TPR).

Likewise, we can estimate the probability of predicting a 0 given the observation is a 0 (a true negative). This is known as specificity or the True Negative Rate (TNR).
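Both rates can be computed directly from the counts in the confusion matrix above. This arithmetic sketch should agree with the sens() and spec() output that follows.

```r
# Counts from the confusion matrix above
tp = 12   # predicted 1, truth 1
fn = 19   # predicted 0, truth 1
fp = 9    # predicted 1, truth 0
tn = 58   # predicted 0, truth 0

sensitivity = tp / (tp + fn)   # TPR = 12 / 31
specificity = tn / (tn + fp)   # TNR = 58 / 67
round(c(sensitivity = sensitivity, specificity = specificity), 3)   # 0.387 and 0.866
```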

sens(results, truth = Y,
     estimate = .pred_class)
# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 sens    binary         0.387
spec(results, truth = Y,
     estimate = .pred_class)
# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 spec    binary         0.866

24.5 ROC Curves

A Receiver Operating Characteristic (ROC) curve is a graphical representation of the diagnostic ability of a binary classifier. Logistic regression is such a classifier.

The ROC curve plots the TPR versus the False Positive Rate (1-specificity).

The ROC curve is created by calculating TPR and FPR at various threshold values, ranging from 0 to 1.
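Below is a minimal sketch of this procedure in base R on hypothetical toy data; `truth` and `p1` are made up for illustration. For each threshold, we predict "1" whenever the predicted probability is at least the threshold, then record the TPR and FPR. yardstick's roc_curve() automates the same kind of calculation.

```r
# Hypothetical toy data: true class (1 = event) and predicted P(Y = 1)
truth = c(1, 1, 1, 0, 1, 0, 0, 1, 0, 0)
p1    = c(0.9, 0.8, 0.7, 0.6, 0.55, 0.5, 0.4, 0.3, 0.2, 0.1)

# TPR and FPR when predicting "1" for every p1 >= thr
roc_point = function(thr) {
  pred = p1 >= thr
  c(threshold = thr,
    tpr = sum(pred & truth == 1) / sum(truth == 1),
    fpr = sum(pred & truth == 0) / sum(truth == 0))
}

# Start at Inf (nothing predicted positive) and lower the threshold step by step
thresholds = c(Inf, sort(unique(p1), decreasing = TRUE))
roc = t(sapply(thresholds, roc_point))
roc

# Plot TPR against FPR, with the diagonal of a random classifier
plot(roc[, "fpr"], roc[, "tpr"], type = "b",
     xlab = "FPR (1 - specificity)", ylab = "TPR (sensitivity)")
abline(0, 1, lty = 2)
```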

The diagonal line represents a random classifier (e.g., flipping a coin). A good classifier has a curve that bows well above this diagonal.

The upper-left corner of the plot (TPR = 1, FPR = 0) represents a perfect classifier, achieving both 100% sensitivity and 100% specificity.

Area Under the Curve (AUC)

The Area Under the ROC Curve (AUC) provides a single number summary of the model’s performance:

AUC ranges from 0 to 1.

  • AUC = 1: Perfect model.
  • AUC = 0.5: Random guessing.
  • AUC < 0.5: Worse than random guessing (usually indicates a problem with the model).
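The AUC also has a useful probabilistic reading. It equals the probability that a randomly chosen observation with Y = 1 receives a higher predicted probability than a randomly chosen observation with Y = 0, with ties counted as one half. Here is a short sketch on hypothetical toy data; `truth` and `p1` are made up.

```r
# Hypothetical toy data: true class (1 = event) and predicted P(Y = 1)
truth = c(1, 1, 1, 0, 1, 0, 0, 1, 0, 0)
p1    = c(0.9, 0.8, 0.7, 0.6, 0.55, 0.5, 0.4, 0.3, 0.2, 0.1)

pos = p1[truth == 1]   # predicted probabilities for the events
neg = p1[truth == 0]   # predicted probabilities for the non-events

# Compare every (event, non-event) pair: 1 if the event scores higher, 0.5 for a tie
d = outer(pos, neg, "-")
auc = mean((d > 0) + 0.5 * (d == 0))
auc   # 0.84: 21 of the 25 pairs are ordered correctly
```

The same value is obtained as the trapezoidal area under the ROC points for these data.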

An ROC curve helps assess how well logistic regression separates the two classes, independent of the specific threshold chosen. It provides insights into the trade-offs between sensitivity and specificity as you adjust the threshold. Comparing ROC curves or AUC values of different models can guide model selection and threshold optimization.

To plot the ROC curve and compute the AUC:

results |> 
  roc_curve(truth = Y, .pred_1) |> 
  autoplot()

roc_auc(results, truth = Y, .pred_1)
# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 roc_auc binary         0.776