This vignette visualizes classification results from discriminant analysis, using tools from the classmap package.
library("classmap")
library("ggplot2")
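library("gridExtra")
oldpar <- par(no.readonly = TRUE) # saved so that par(oldpar) can restore the plot settings later on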
As a first small example, we consider the Iris data. We first load the data and inspect it.
data(iris)
X <- iris[, 1:4]
y <- iris[, 5]
is.factor(y)
## [1] TRUE
table(y)
## y
##     setosa versicolor  virginica 
##         50         50         50
pairs(X, col = as.numeric(y) + 1, pch = 19)
Now we carry out quadratic discriminant analysis (QDA) and inspect the output. Linear discriminant analysis can be done instead by setting rule = "LDA".
vcr.train <- vcr.da.train(X, y, rule = "QDA")
names(vcr.train)
##  [1] "yint"      "y"         "levels"    "predint"   "pred"      "altint"   
##  [7] "altlab"    "PAC"       "figparams" "fig"       "farness"   "ofarness" 
## [13] "classMS"   "lCurrent"  "lPred"     "lAlt"
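vcr.da.train() fits the QDA model internally. For readers who want to see the classical QDA rule written out, here is a rough sketch in base R. The helper qdaSketch() is hypothetical and is not the classmap code. It assumes classical class means and sample covariance matrices as estimates, and class proportions as prior probabilities.

```r
# Hypothetical sketch of classical QDA; qdaSketch() is not a classmap function.
# Assumptions: sample mean and covariance per class, priors = class proportions.
qdaSketch <- function(X, y) {
  X <- as.matrix(X)
  lev <- levels(y)
  # n x G matrix of log(prior) + log Gaussian density, up to a common constant
  logScore <- sapply(lev, function(g) {
    Xg <- X[y == g, , drop = FALSE]
    S <- cov(Xg)
    logDetS <- as.numeric(determinant(S, logarithm = TRUE)$modulus)
    d2 <- mahalanobis(X, center = colMeans(Xg), cov = S)  # squared distances
    log(mean(y == g)) - 0.5 * logDetS - 0.5 * d2
  })
  # Posterior probabilities; subtracting each row maximum avoids underflow in exp()
  post <- exp(logScore - apply(logScore, 1, max))
  post <- post / rowSums(post)
  pred <- factor(lev[max.col(post, ties.method = "first")], levels = lev)
  list(post = post, pred = pred)
}

qda.out <- qdaSketch(iris[, 1:4], iris[, 5])
table(given = iris[, 5], predicted = qda.out$pred)
```

The predictions in qda.out$pred can be compared with vcr.train$pred. We have not checked that the two implementations agree case by case.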
We now inspect the output in detail. First we look at the predicted class as an integer, the predicted label, the alternative class as an integer, and the alternative label:
vcr.train$predint 
##   [1] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
##  [38] 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 3 2 2 2
##  [75] 2 2 2 2 2 2 2 2 2 3 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 3 3 3
## [112] 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 2 3 3 3 3 3 3 3 3 3 3 3 3 3 3
## [149] 3 3
vcr.train$pred[c(1:10, 51:60, 101:110)]
##  [1] "setosa"     "setosa"     "setosa"     "setosa"     "setosa"    
##  [6] "setosa"     "setosa"     "setosa"     "setosa"     "setosa"    
## [11] "versicolor" "versicolor" "versicolor" "versicolor" "versicolor"
## [16] "versicolor" "versicolor" "versicolor" "versicolor" "versicolor"
## [21] "virginica"  "virginica"  "virginica"  "virginica"  "virginica" 
## [26] "virginica"  "virginica"  "virginica"  "virginica"  "virginica"
vcr.train$altint  
##   [1] 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
##  [38] 2 2 2 2 2 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
##  [75] 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 2 2 2 2 2 2 2 2 2 2 2
## [112] 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
## [149] 2 2
vcr.train$altlab[c(1:10, 51:60, 101:110)]
##  [1] "versicolor" "versicolor" "versicolor" "versicolor" "versicolor"
##  [6] "versicolor" "versicolor" "versicolor" "versicolor" "versicolor"
## [11] "virginica"  "virginica"  "virginica"  "virginica"  "virginica" 
## [16] "virginica"  "virginica"  "virginica"  "virginica"  "virginica" 
## [21] "versicolor" "versicolor" "versicolor" "versicolor" "versicolor"
## [26] "versicolor" "versicolor" "versicolor" "versicolor" "versicolor"
The Probability of Alternative Class (PAC) of each object is found in the $PAC element of the output:
vcr.train$PAC[1:3] 
## [1] 4.918517e-26 7.655808e-19 1.552279e-21
summary(vcr.train$PAC)
##      Min.   1st Qu.    Median      Mean   3rd Qu.      Max. 
## 0.0000000 0.0000000 0.0000081 0.0237098 0.0010938 0.8456517
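The PAC is computed inside vcr.da.train(). The sketch below assumes the definition from the accompanying paper, PAC(i) = p(i, alt) / (p(i, given) + p(i, alt)). Here p denotes posterior probabilities, and alt is the most probable class other than the given one. The helper pacFromPosterior() is hypothetical and works on a small made-up posterior matrix.

```r
# Hypothetical helper (not part of classmap). Assumes
# PAC(i) = p(i, alt) / (p(i, given) + p(i, alt)),
# where alt is the most probable class other than the given class.
pacFromPosterior <- function(post, yint) {
  n <- nrow(post)
  pGiven <- post[cbind(seq_len(n), yint)]
  others <- post
  others[cbind(seq_len(n), yint)] <- -1   # exclude the given class (probabilities are >= 0)
  altint <- max.col(others, ties.method = "first")
  pAlt <- post[cbind(seq_len(n), altint)]
  list(altint = altint, PAC = pAlt / (pGiven + pAlt))
}

# Made-up posterior probabilities for 3 cases and 3 classes
post.toy <- rbind(c(0.90, 0.08, 0.02),
                  c(0.30, 0.60, 0.10),
                  c(0.05, 0.15, 0.80))
pac.toy <- pacFromPosterior(post.toy, yint = c(1, 1, 3))
pac.toy$altint
pac.toy$PAC
```

A PAC above 0.5 means that the alternative class is more probable than the given one. This is the case for the second toy case, which would therefore be misclassified.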
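The $fig element of the output contains the farness f(i, g) of each case i from each class g. This is a value between 0 and 1 that is based on the distance of case i to class g, where values near 1 mean far away. Let’s look at it for the first 5 objects: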
vcr.train$fig[1:5, ]
##            [,1] [,2] [,3]
## [1,] 0.02675535    1    1
## [2,] 0.33639794    1    1
## [3,] 0.16134074    1    1
## [4,] 0.25293196    1    1
## [5,] 0.06600114    1    1
The farness of object i is f(i, g) for its own (given) class g, so it can be read off from fig:
vcr.train$farness[1:5]
## [1] 0.02675535 0.33639794 0.16134074 0.25293196 0.06600114
summary(vcr.train$farness)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##  0.0153  0.2396  0.5159  0.4996  0.7617  0.9862
The “overall farness” of an object is defined as the lowest f(i, g) it has to any class g (including its own):
summary(vcr.train$ofarness)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##  0.0153  0.2396  0.5145  0.4957  0.7543  0.9862
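Both quantities are simple functions of fig. The farness picks the entry of the given class, and the overall farness takes the minimum of each row. Here is a base-R sketch on a made-up 4 x 3 matrix. The names toyFig and toyY are hypothetical, and the fourth case is unlabeled.

```r
# Made-up farness matrix f(i, g) for 4 cases and 3 classes
toyFig <- rbind(c(0.20, 0.99,  1.00),
                c(0.70, 0.40,  0.95),
                c(0.99, 0.985, 0.97),
                c(0.50, 0.30,  0.60))
toyY <- c(1, 1, 3, NA)   # given classes; NA = unlabeled case

n <- nrow(toyFig)
# farness: f(i, g) for the given class g (an NA row index gives NA)
toyFarness <- toyFig[cbind(seq_len(n), toyY)]
# overall farness: the smallest f(i, g) over all classes g
toyOfarness <- apply(toyFig, 1, min)
toyFarness
toyOfarness
```

The NA for the unlabeled case mirrors the new-data example further on. There, cases with a missing ynew have an NA farness but still receive an overall farness.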
Objects with ofarness > cutoff are flagged as “outliers”. They can be shown in a separate column of the confusion matrix, which is computed by confmat.vcr(). This function also reports the accuracy.
To illustrate this we choose a rather low cutoff:
confmat.vcr(vcr.train, cutoff = 0.98)
## 
## Confusion matrix:
##             predicted
## given        setosa versicolor virginica outl
##   setosa         48          0         0    2
##   versicolor      0         48         2    0
##   virginica       0          1        48    1
## 
## The accuracy is 98%.
With the default cutoff = 0.99 no objects are flagged in this example:
confmat.vcr(vcr.train)
## 
## Confusion matrix:
##             predicted
## given        setosa versicolor virginica
##   setosa         50          0         0
##   versicolor      0         48         2
##   virginica       0          1        49
## 
## The accuracy is 98%.
Note that the accuracy is computed before any objects are flagged, so it does not depend on the cutoff.
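The flagging rule and the accuracy computation described above can be mimicked in base R. The helper confmatSketch() below is hypothetical and is not how confmat.vcr() is implemented. It assumes that flagged objects move from their predicted column to an extra outl column. It also assumes that unlabeled (NA) cases are left out, as happens for the test data further on.

```r
# Hypothetical helper; not the classmap implementation.
confmatSketch <- function(y, pred, ofarness, cutoff = 0.99) {
  lev <- levels(y)
  isOutl <- ofarness > cutoff
  predCol <- ifelse(isOutl, "outl", as.character(pred))
  # only add an "outl" column when at least one object is flagged
  colLev <- if (any(isOutl)) c(lev, "outl") else lev
  # table() drops cases with NA in y by default
  tab <- table(given = y, predicted = factor(predCol, levels = colLev))
  # accuracy uses all predictions, flagged or not, of the labeled cases
  labeled <- !is.na(y)
  acc <- mean(as.character(pred[labeled]) == as.character(y[labeled]))
  list(table = tab, accuracy = acc)
}

y.toy    <- factor(c("a", "a", "b", "b", NA), levels = c("a", "b"))
pred.toy <- c("a", "b", "b", "b", "a")
ofar.toy <- c(0.10, 0.50, 0.995, 0.20, 0.30)
res99 <- confmatSketch(y.toy, pred.toy, ofar.toy, cutoff = 0.99)
res99$table
res99$accuracy
```

With cutoff = 1 nothing can be flagged, so the outl column disappears. The accuracy stays at 0.75 either way.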
The confusion matrix can also be constructed showing class numbers instead of labels. This option can be useful for long level names.
confmat.vcr(vcr.train, showClassNumbers = TRUE)
## 
## Confusion matrix:
##      predicted
## given  1  2  3
##     1 50  0  0
##     2  0 48  2
##     3  0  1 49
## 
## The accuracy is 98%.
A stacked mosaic plot made with the stackedplot() function can be used to visualize the confusion matrix. The outliers, if there are any, appear as grey areas on top.
cols <- c("red", "darkgreen", "blue")
stackedplot(vcr.train, classCols = cols, separSize = 1.5,
            minSize = 1, showLegend = TRUE)
stackedplot(vcr.train, classCols = cols, separSize = 1.5,
            minSize = 1, showLegend = TRUE, cutoff = 0.98)
The default stacked mosaic plot has no legend:
stplot <- stackedplot(vcr.train, classCols = cols, 
                     separSize = 1.5, minSize = 1,
                     main = "QDA on iris data")
stplot
We also make the silhouette plot using the silplot() function:
# pdf("Iris_QDA_silhouettes.pdf", width=5.0, height=4.6)
silplot(vcr.train, classCols = cols, 
        main = "Silhouette plot of QDA on iris data")      
##  classNumber classLabel classSize classAveSi
##            1     setosa        50       1.00
##            2 versicolor        50       0.91
##            3  virginica        50       0.95
# dev.off()
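The silhouette plot shows one silhouette width per object. The sketch below assumes the definition s(i) = 1 - 2 PAC(i) from the accompanying paper. Under that definition s(i) equals 1 when the alternative class has zero probability, and s(i) is negative when PAC(i) exceeds 0.5. The class averages in the classAveSi column could then be obtained as follows. All names are hypothetical and the PAC values are made up.

```r
# Hypothetical sketch, assuming s(i) = 1 - 2 * PAC(i); not the silplot() code.
PAC.toy   <- c(0.00, 0.10, 0.60, 0.20)
class.toy <- factor(c("a", "a", "b", "b"))
si <- 1 - 2 * PAC.toy                      # silhouette width per object
classAveSi <- tapply(si, class.toy, mean)  # average per given class
round(classAveSi, 2)
```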
We now make the class maps based on the vcr object, using the classmap() function, with a separate class map for each of the three classes. Class 1 is a very tight class (low PAC, no high farness). Class 2 is less tight, and two of its points are predicted as virginica. Class 3 has one point predicted as versicolor.
classmap(vcr.train, 1, classCols = cols)
classmap(vcr.train, 2, classCols = cols)
classmap(vcr.train, 3, classCols = cols)
# With the default cutoff no farness values stand out.
# With a lower cutoff:
classmap(vcr.train, 3, classCols = cols, cutoff = 0.98)
# Now one point is to the right of the vertical line.
# It also has a black border, meaning that it is flagged
# as an outlier, in the sense that its farness to _all_
# classes is above 0.98.
To illustrate the use of new data, we create an artificial dataset. It is a subset of the training data in which not all classes occur and some entries of ynew are NA.
Xnew <- X[c(1:50, 101:150), ]
ynew <- y[c(1:50, 101:150)]
ynew[c(1:10, 51:60)] <- NA
pairs(X, col = as.numeric(y) + 1, pch = 19) # 3 colors
pairs(Xnew, col = as.numeric(ynew) + 1, pch = 19) # only red and blue
Now we build the vcr object for the new data, using the model fitted on the training data:
vcr.test <- vcr.da.newdata(Xnew, ynew, vcr.train)
Inspect some of the output to confirm that it corresponds with what we would expect:
plot(vcr.test$predint, vcr.train$predint[c(1:50, 101:150)]); abline(0, 1)
plot(vcr.test$altint, vcr.train$altint[c(1:50, 101:150)]); abline(0, 1)
plot(vcr.test$PAC, vcr.train$PAC[c(1:50, 101:150)]); abline(0, 1)
vcr.test$farness 
##   [1]         NA         NA         NA         NA         NA         NA
##   [7]         NA         NA         NA         NA 0.29421328 0.32178116
##  [13] 0.51150351 0.89366298 0.96650511 0.91516067 0.82782724 0.04831270
##  [19] 0.78801603 0.23207165 0.80057706 0.46966655 0.97498614 0.90086633
##  [25] 0.96031572 0.63996116 0.43078605 0.07648892 0.16940167 0.35680444
##  [31] 0.31726541 0.76311424 0.91656430 0.79289643 0.15775639 0.57139387
##  [37] 0.82646629 0.53574220 0.56635148 0.04234576 0.24816964 0.98405396
##  [43] 0.69347810 0.98395599 0.93996829 0.36120046 0.47605905 0.20490701
##  [49] 0.15482827 0.03145084         NA         NA         NA         NA
##  [55]         NA         NA         NA         NA         NA         NA
##  [61] 0.44870413 0.07632847 0.07073258 0.58646621 0.79167886 0.26259281
##  [67] 0.11731858 0.94491559 0.98620090 0.84267322 0.14000705 0.41260605
##  [73] 0.85509019 0.51571738 0.14623030 0.51605016 0.53499547 0.40202972
##  [79] 0.10552473 0.74488824 0.53885697 0.96493110 0.23632013 0.66887310
##  [85] 0.93026914 0.83628519 0.68987941 0.32345397 0.49557670 0.35256862
##  [91] 0.32403877 0.91811245 0.26632690 0.15321938 0.50795409 0.69766299
##  [97] 0.63534800 0.11247730 0.62691061 0.42737442
plot(vcr.test$farness, vcr.train$farness[c(1:50, 101:150)]); abline(0, 1)
plot(vcr.test$fig, vcr.train$fig[c(1:50, 101:150), ]); abline(0, 1)
vcr.test$ofarness 
##   [1] 0.02675535 0.33639794 0.16134074 0.25293196 0.06600114 0.63210603
##   [7] 0.59041424 0.01732745 0.52024594 0.55494759 0.29421328 0.32178116
##  [13] 0.51150351 0.89366298 0.96650511 0.91516067 0.82782724 0.04831270
##  [19] 0.78801603 0.23207165 0.80057706 0.46966655 0.97498614 0.90086633
##  [25] 0.96031572 0.63996116 0.43078605 0.07648892 0.16940167 0.35680444
##  [31] 0.31726541 0.76311424 0.91656430 0.79289643 0.15775639 0.57139387
##  [37] 0.82646629 0.53574220 0.56635148 0.04234576 0.24816964 0.98405396
##  [43] 0.69347810 0.98395599 0.93996829 0.36120046 0.47605905 0.20490701
##  [49] 0.15482827 0.03145084 0.93068594 0.26632690 0.06831897 0.31105631
##  [55] 0.20258388 0.65479346 0.90867003 0.58975324 0.56295693 0.68728265
##  [61] 0.44870413 0.07632847 0.07073258 0.58646621 0.79167886 0.26259281
##  [67] 0.11731858 0.94491559 0.98620090 0.84267322 0.14000705 0.41260605
##  [73] 0.85509019 0.51571738 0.14623030 0.51605016 0.53499547 0.40202972
##  [79] 0.10552473 0.74488824 0.53885697 0.96493110 0.23632013 0.66887310
##  [85] 0.93026914 0.83628519 0.68987941 0.32345397 0.49557670 0.35256862
##  [91] 0.32403877 0.91811245 0.26632690 0.15321938 0.50795409 0.69766299
##  [97] 0.63534800 0.11247730 0.62691061 0.42737442
plot(vcr.test$ofarness, vcr.train$ofarness[c(1:50, 101:150)]); abline(0, 1)
The confusion matrix for the test data, as for the training data, can be constructed by the confmat.vcr() function. A cutoff of 0.98 flags three outliers in this example.
confmat.vcr(vcr.test)
## 
## Confusion matrix:
##            predicted
## given       setosa versicolor virginica
##   setosa        40          0         0
##   virginica      0          1        39
## 
## The accuracy is 98.75%.
confmat.vcr(vcr.test, cutoff = 0.98)
## 
## Confusion matrix:
##            predicted
## given       setosa versicolor virginica outl
##   setosa        38          0         0    2
##   virginica      0          1        38    1
## 
## The accuracy is 98.75%.
The stacked mosaic plot can also be constructed on the test data:
stplot # to compare with:
stackedplot(vcr.test, classCols = cols, separSize = 1.5, minSize = 1)
## 
## Not all classes occur in these data. The classes to plot are:
## [1] 1 3
We now make the silhouette plot on the test data:
#pdf("Iris_test_QDA_silhouettes.pdf", width=5.0, height=4.3)
silplot(vcr.test, classCols = cols, 
        main = "Silhouette plot of QDA on iris subset") 
##  classNumber classLabel classSize classAveSi
##            1     setosa        40       1.00
##            3  virginica        40       0.94
#dev.off()
Finally, we construct the class maps for the test data, comparing the class map of the training data with that of the test data for each class. Since class 2 (versicolor) does not occur in the test data, classmap() returns an error for it.
classmap(vcr.train, 1, classCols = cols)
classmap(vcr.test, 1, classCols = cols) 
classmap(vcr.train, 2, classCols = cols)
classmap(vcr.test, 2, classCols = cols)
## Error in classmap(vcr.test, 2, classCols = cols): Class number 2 with label versicolor has no objects to visualize.
classmap(vcr.train, 3, classCols = cols)
classmap(vcr.test, 3, classCols = cols) 
We now analyze the floral buds data, which was also used as an illustration in the paper. First load and inspect the data.
data(data_floralbuds)
X <- as.matrix(data_floralbuds[, 1:6])
y <- data_floralbuds$y
dim(X) # 550  6
## [1] 550   6
length(y) # 550
## [1] 550
table(y)
## y
##  branch     bud  scales support 
##      49     363      94      44
# Pairs plot
cols <- c("saddlebrown", "orange", "olivedrab4", "royalblue3")
pairs(X, gap = 0, col = cols[as.numeric(y)]) # hard to separate visually
Now we perform quadratic discriminant analysis:
vcr.obj <- vcr.da.train(X, y)
Construct the confusion matrix without and with outliers shown:
confmat <- confmat.vcr(vcr.obj, showOutliers = FALSE)
## 
## Confusion matrix:
##          predicted
## given     branch bud scales support
##   branch      45   1      1       2
##   bud          0 358      1       4
##   scales       2   0     90       2
##   support      6   3      0      35
## 
## The accuracy is 96%.
confmat.vcr(vcr.obj) 
## 
## Confusion matrix:
##          predicted
## given     branch bud scales support outl
##   branch      45   1      1       2    0
##   bud          0 353      1       4    5
##   scales       2   0     86       2    4
##   support      6   3      0      35    0
## 
## The accuracy is 96%.
Construct the stacked mosaic plot:
stackedplot(vcr.obj, classCols = cols, separSize = 0.6,
            minSize = 1.5,  main = "stacked plot of QDA on floral buds")
# Version in paper:
# pdf("Floralbuds_QDA_stackplot_without_outliers.pdf",
#     width=5, height=4.3)
# stackedplot(vcr.obj, classCols = cols, separSize = 0.6,
#             minSize = 1.5, showOutliers = FALSE,
#             htitle = "given class", vtitle = "predicted class")
# dev.off()
Now make the silhouette plot:
#pdf("Floralbuds_QDA_silhouettes.pdf", width=5.0, height=4.3)
silplot(vcr.obj, classCols = cols,
        main = "Silhouette plot of QDA on floral bud data")      
##  classNumber classLabel classSize classAveSi
##            1     branch        49       0.75
##            2        bud       363       0.96
##            3     scales        94       0.93
##            4    support        44       0.57
#dev.off()
The quasi residual plot can be made with the qresplot() function. We illustrate this below by making the quasi residual plot against the sum of the variables. A Spearman correlation test confirms that images with higher sums are significantly easier to classify, as they tend to have a lower PAC:
PAC <- vcr.obj$PAC
feat <- rowSums(X)
xlab <- "rowSums(X)"
# pdf("Floralbuds_QDA_quasi_residual_plot.pdf", width=5, height=4.8)
qresplot(PAC, feat, xlab = xlab, plotErrorBars = TRUE, fac = 2, 
         main = "Floral buds: quasi residual plot")
# dev.off()
cor.test(feat, PAC, method = "spearman") 
## 
## 	Spearman's rank correlation rho
## 
## data:  feat and PAC
## S = 39255896, p-value < 2.2e-16
## alternative hypothesis: true rho is not equal to 0
## sample estimates:
##        rho 
## -0.4156944
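qresplot() shows how the PAC varies with a feature. A rough numerical counterpart averages the PAC within quantile bins of the feature. The equal-count bins are our own assumption and not necessarily how qresplot() smooths or bins the data. The hypothetical helper binnedPAC() below also returns a standard error per bin, similar in spirit to error bars. The toy data are made up.

```r
# Hypothetical helper: mean PAC (with standard error) within quantile bins
# of a feature. Not the qresplot() implementation.
binnedPAC <- function(PAC, feat, nbins = 4) {
  breaks <- unique(quantile(feat, probs = seq(0, 1, length.out = nbins + 1)))
  bin <- cut(feat, breaks = breaks, include.lowest = TRUE)
  data.frame(bin      = levels(bin),
             meanFeat = as.numeric(tapply(feat, bin, mean)),
             meanPAC  = as.numeric(tapply(PAC, bin, mean)),
             se       = as.numeric(tapply(PAC, bin,
                                          function(p) sd(p) / sqrt(length(p)))))
}

featToy <- 1:8                                          # made-up feature values
pacToy  <- c(0.90, 0.70, 0.60, 0.50, 0.30, 0.20, 0.10, 0.05)
binnedPAC(pacToy, featToy)
cor(featToy, pacToy, method = "spearman")  # -1: PAC decreases with the feature
```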
Construct the class maps, as shown in the paper:
labels <- c("branch", "bud", "scale", "support")
# classmap of class "bud"
#
# To identify the points that stand out:
# classmap(vcr.obj, 2, classCols = cols, identify = TRUE)
# Press "Esc" to get out.
#
# pdf("Floralbuds_QDA_classmap_bud.pdf", width=7, height=7)
par(mar = c(3.6, 3.5, 2.4, 3.5))
coords <- classmap(vcr.obj, 2, classCols = cols,
         main = "predictions of buds",
         cex = 1.5, cex.lab = 1.5, cex.axis = 1.5,
         cex.main = 1.5) 
# For marking points:
indstomark <- c(294, 70, 69, 152, 204) # from identify = TRUE above
labs  <- letters[seq_len(5)]
xvals <- coords[indstomark, 1] +
  c(0, 0.10, 0.14, 0.10, 0.08) # visual finetuning
yvals <- coords[indstomark, 2] +
  c(0.04, 0.04, 0, -0.03, +0.04)
text(x = xvals, y = yvals, labels = labs, cex = 1.5)
legend("topleft", fill = cols[1:4], legend = labels, 
       cex = 1, ncol = 1, bg = "white")
# dev.off()
par(oldpar)
All class maps:
#
# pdf(file = "Floralbuds_all_class_maps.pdf", width = 7, height = 7)
par(mfrow = c(2, 2))
par(mar = c(3.3, 3.2, 2.7, 1.0))
classmap(vcr.obj, 1, classCols = cols,
         main = "predictions of branches")
legend("topright", fill = cols, legend = labels,
       cex = 1, ncol = 1, bg = "white")
#
par(mar = c(3.3, 0.5, 2.7, 0.3))
classmap(vcr.obj, 2, classCols = cols,
         main = "predictions of buds")
labs  <- letters[seq_len(5)]
xvals <- coords[indstomark, 1] +
  c(0, 0.10, 0.14, 0.10, 0.08) # visual finetuning
yvals <- coords[indstomark, 2] +
  c(0.04, 0.04, 0, -0.03, 0.04)
# xvals <- c( 1.75, 1.68, 1.25, 3.25, 4.00)
# yvals <- c(0.045, 0.92, 0.54, 0.97, 0.045)
text(x = xvals, y = yvals, labels = labs, cex = 1.0)
legend("topleft", fill = cols, legend = labels,
       cex = 1, ncol = 1, bg = "white")
#
par(mar = c(3.3, 3.2, 2.7, 1.0))
classmap(vcr.obj, 3, classCols = cols,
         main = "predictions of scales")
legend("left", fill = cols, legend = labels,
       cex = 1, ncol = 1, bg = "white")
# 
par(mar = c(3.3, 0.5, 2.7, 0.3))
classmap(vcr.obj, 4, classCols = cols,
         main = "predictions of supports")
legend("topright", fill = cols, legend = labels,
       cex = 1, ncol = 1, bg = "white")
# dev.off()
par(oldpar)
We now analyze the MNIST data, originally from the website of Yann LeCun. As the link on his website is currently down, we use a different source. Note that downloading the data may take a minute or two, depending on the speed of the internet connection.
mnist_url <- "https://wis.kuleuven.be/statdatascience/robust/data/mnist-rdata"
# Check whether the data source can be reached before loading from it:
con <- url(mnist_url)
url.ok <- !inherits(suppressWarnings(try(open(con, open = "rb"), silent = TRUE)),
                    "try-error")
close(con)
if (url.ok) {
  load(url(mnist_url)) # creates the list 'mnist' used below
} else {
  print(paste("The data source", mnist_url, "is not active at the moment. The example can nevertheless be reproduced by downloading the mnist data from another source, formatting the training data to dimensions 60000 x 28 x 28, and running the code below."))
}
X_train <- mnist$train$x
y_train <- as.factor(mnist$train$y)
head(y_train)
## [1] 5 0 4 1 9 2
## Levels: 0 1 2 3 4 5 6 7 8 9
dim(X_train) # 60000    28    28
## [1] 60000    28    28
length(y_train) # 60000
## [1] 60000
We now inspect the data by plotting a few images:
plotImage <- function(tempImage) {
  # reverse the row order so that the image is not drawn upside down,
  # then convert the matrix to long format (columns Var1, Var2, value)
  tdm <- reshape2::melt(apply(tempImage, 2, rev))
  p <- ggplot(tdm, aes(x = Var2, y = Var1, fill = value)) +
    geom_raster() +
    guides(color = "none", size = "none", fill = "none") +
    theme(axis.title.x = element_blank(),
          axis.title.y = element_blank(),
          axis.text.x = element_blank(),
          axis.text.y = element_blank(),
          axis.ticks.x = element_blank(),
          axis.ticks.y = element_blank()) +
    scale_fill_gradient(low = "white", high = "black")
  p
}
plotImage(X_train[1, , ])
plotImage(X_train[2, , ])
plotImage(X_train[3, , ])
We now unfold the array containing the data to a matrix, and inspect some sample images as well as the average image per digit:
# Change the dimensions of X for the sequel:
dim(X_train) <- c(60000, 28 * 28)
dim(X_train) # 60000    784
## [1] 60000   784
# Sampled digit images:
set.seed(123)
sampledigits <- list()
for (i in 0:9) {
  digit <- i
  idx <- sample(which(y_train == digit), size = 1)
  tempImage <- matrix(unlist(X_train[idx, ]), 28, 28)
  sampledigits[[i + 1]] <- plotImage(tempImage) 
}
psampledigits <- grid.arrange(grobs = sampledigits, ncol = 5)
# ggsave("MNIST_sampled_images.pdf", plot = psampledigits,
#        width = 10, height = 1)
# Averaged digit images:
meanPlots <- list()
for (j in 0:9) {
  m.out <- colMeans(X_train[which(y_train == j), ])
  dim(m.out) <- c(28, 28)
  meanPlots[[j + 1]] <- plotImage(m.out) 
}
meanplot <- grid.arrange(grobs = meanPlots, ncol = 5)
# ggsave("MNIST_averaged_images.pdf", plot = meanplot,
#        width = 10, height = 1)
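Unfolding with dim(X_train) <- c(60000, 28 * 28) works because R stores arrays in column-major order. After this step, row i holds image i with its 28 columns stacked, so matrix(X_train[i, ], 28, 28) recovers the original 28 x 28 image. The small check below uses a made-up 3 x 2 x 2 array, that is, three toy 2 x 2 "images".

```r
imgs <- array(1:12, dim = c(3, 2, 2))   # 3 toy images of size 2 x 2
unfolded <- imgs
dim(unfolded) <- c(3, 2 * 2)            # same trick as for X_train
# Row 2 of the unfolded matrix, refolded, equals the second image:
identical(matrix(unfolded[2, ], 2, 2), imgs[2, , ])   # TRUE
```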
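Before performing discriminant analysis, we reduce the dimension of the data from 784 to 50. We do this with a truncated singular value decomposition by svd::propack.svd(), followed by projection onto the first 50 right singular vectors. This serves as a PCA-type reduction; note that the data are not centered first.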
library(svd)
ptm <- proc.time()
svd.out <- svd::propack.svd(X_train, neig = 50)
(proc.time() - ptm)[3]
## elapsed 
##   22.12
loadings <- svd.out$v
rm(svd.out)
dataProj <- as.matrix(X_train %*% loadings)
dim(dataProj)
## [1] 60000    50
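The same kind of projection can be sketched with base R's svd() on a small synthetic matrix, used here as a stand-in for X_train; all names below are hypothetical. Projecting onto the first k right singular vectors gives U_k diag(d_k), which equals the scores of prcomp() without centering up to column signs.

```r
set.seed(1)
Xtoy <- matrix(runif(100 * 20), nrow = 100)   # synthetic stand-in for X_train
k <- 5
s <- svd(Xtoy, nu = k, nv = k)                # keep the first k singular vectors
loadingsToy <- s$v                            # 20 x k, plays the role of 'loadings'
projToy <- Xtoy %*% loadingsToy               # 100 x k, plays the role of 'dataProj'
# Projection onto the leading right singular vectors equals U_k diag(d_k):
all.equal(projToy, s$u %*% diag(s$d[1:k]))
# prcomp() without centering gives the same scores (up to column signs):
pc <- prcomp(Xtoy, center = FALSE, rank. = k)
all.equal(abs(unname(pc$x)), abs(projToy))
# New data are projected with the same loadings, as is done for X_test later on:
newToy <- matrix(runif(3 * 20), nrow = 3)
newToy %*% loadingsToy
```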
Now we perform discriminant analysis, which takes roughly 5 seconds.
vcr.train <- vcr.da.train(X = dataProj, y_train)
We compute the confusion matrix and make the stacked mosaic plot:
confmat.vcr(vcr.train, showOutliers = FALSE)
## 
## Confusion matrix:
##      predicted
## given    0    1    2    3    4    5    6    7    8    9
##     0 5833    0   22    6    1   14    2    0   42    3
##     1    0 6436  104   14   32    0    2   13  138    3
##     2   14    1 5807   26   14    0    9   12   69    6
##     3    3    1   88 5821    4   52    0   18  120   24
##     4    6    1   21    3 5704    1   12   14   30   50
##     5   14    0    4   71    2 5222   17    0   78   13
##     6   27    2    6    2    8  114 5703    0   56    0
##     7   13    8   94   14   34   14    0 5936   54   98
##     8   10   24   40   72    8   40    2    4 5625   26
##     9   17    2   23   65   59   14    1   77   93 5598
## 
## The accuracy is 96.14%.
cols <- c("red3", "darkorange", "gold2", "darkolivegreen3",
         "darkolivegreen4", "cadetblue3", "deepskyblue4", 
         "darkslateblue", "darkorchid3", "deeppink4")
# stacked plot in paper:
# pdf("MNIST_stackplot_with_outliers.pdf", width=5, height=4.3)
stackedplot(vcr.train, classCols = cols, separSize = 0.6,
            minSize = 1.5, htitle = "given class",
            main = "Stacked plot of QDA on MNIST training data", vtitle = "predicted class")
# dev.off()
The silhouette plot:
# pdf("MNIST_QDA_silhouettes.pdf", width=5.0, height=4.6)
silplot(vcr.train, classCols = cols,
        main = "Silhouette plot of QDA on MNIST training data")      
##  classNumber classLabel classSize classAveSi
##            1          0      5923       0.97
##            2          1      6742       0.91
##            3          2      5958       0.95
##            4          3      6131       0.90
##            5          4      5842       0.95
##            6          5      5421       0.92
##            7          6      5918       0.93
##            8          7      6265       0.89
##            9          8      5851       0.92
##           10          9      5949       0.88
# dev.off()
Now we make the class maps.
wnq <- function(string, qwrite=TRUE) { # auxiliary function
  # writes a line without quotes
  if (qwrite) write(noquote(string), file = "", ncolumns = 100)
}
showdigit <- function(digit, i, plotIt = TRUE) {
  # i indexes the cases of the given digit; map it to a row of X_train
  idx <- which(y_train == digit)[i]
  # wnq(paste("Estimated digit: ", as.numeric(vcr.train$pred[idx]), sep=""))
  tempImage <- matrix(unlist(X_train[idx, ]), 28, 28)
  p <- plotImage(tempImage)
  if (plotIt) plot(p)
  return(p)
}
Class map of digit 0, shown in paper:
digit <- 0
#
# To identify outliers:
# classmap(vcr.train, digit+1, classCols = cols, identify = TRUE)
# Press "Esc" to get out.
#
# pdf(paste0("MNIST_classmap_digit", digit, ".pdf"), width = 7, height = 7)
par(mar = c(3.6, 3.5, 2.4, 3.5))
coords <- classmap(vcr.train, digit + 1, classCols = cols,
         main = paste0("predictions of digit ",digit),
         cex = 1.5, cex.lab = 1.5, cex.axis = 1.5, 
         cex.main = 1.5)
indstomark <- c(4000, 3964, 5891, 2485, 822, 
               2280, 2504, 3906, 5869, 1034) # from identify = TRUE
labs  <- letters[1:length(indstomark)]
xvals <- coords[indstomark, 1] +
  c(-0.04, -0.01, 0, -0.11, 0.06,
    0.07, 0.06, 0.10, 0.06, 0.09)
yvals <- coords[indstomark, 2] +
  c(-0.03, -0.03, -0.03, 0.022, -0.025, 
    -0.025, -0.035, -0.025, 0.03, 0.03)
text(x = xvals, y = yvals, labels = labs, cex = 1.5)
legend("left", fill = cols,
       legend = 0:9, cex = 1, ncol = 2, bg = "white")
# dev.off()
par(oldpar)
pred <- vcr.train$pred # needed for discussion plots
tempPreds <- (pred[which(y_train == digit)])[indstomark]
discussionPlots <- list()
for (i in 1:length(indstomark)) {
  idx <- indstomark[i]
  tempplot <- showdigit(digit, idx, plotIt = FALSE)
  tempplot <- arrangeGrob(tempplot, 
    bottom = paste0("(", labs[i], ") \"", tempPreds[i], "\""))
  discussionPlots[[i]] = tempplot
}
discussionPlot <- grid.arrange(grobs = discussionPlots, ncol = 5)
# ggsave(paste0("MNIST_discussionplot_digit", digit, ".pdf"),
#        plot = discussionPlot, width = 5,
#        height = (length(indstomark) %/% 5 +
#                    (length(indstomark) %% 5 > 0)))
Class map of digit 1, shown in paper:
digit <- 1
# pdf(paste0("MNIST_classmap_digit", digit, ".pdf"), width = 7, height = 7)
par(mar = c(3.6, 3.5, 2.4, 3.5))
classmap(vcr.train, digit + 1, classCols = cols,
         main = paste0("predictions of digit ", digit),
         cex = 1.5, cex.lab = 1.5, cex.axis = 1.5, 
         cex.main = 1.5)
legend("left", fill = cols,
       legend = 0:9, cex = 1, ncol = 2, bg = "white")
# dev.off()
par(oldpar)
# indices of the 1s predicted as 2 (takes a while):
#
indstomark <- which(vcr.train$predint[which(y_train == digit)] == 3)
length(indstomark) # 104
## [1] 104
labs  <- letters[1:length(indstomark)]
pred <- vcr.train$pred # needed for discussion plots
tempPreds    <- (pred[which(y_train == digit)])[indstomark]
discussionPlots <- list()
for (i in 1:length(indstomark)) {
  idx <- indstomark[i]
  tempplot <- showdigit(digit, idx, FALSE)
  tempplot <- arrangeGrob(tempplot, 
        bottom = paste0("\"", tempPreds[i], "\""))
  discussionPlots[[i]] <- tempplot
}
discussionPlot <- grid.arrange(grobs = discussionPlots, 
                                         ncol = 8)
# ggsave(paste0("MNIST_discussionplot_digit", digit, "predictedAs2b.pdf"),
#        plot = discussionPlot, width = 10,
#        height = (length(indstomark) %/% 10 +
#                    (length(indstomark) %% 10 > 0)))
                 
# The digits 1 predicted as a 2 are mostly ones written with
# a horizontal line at the bottom.
Class map of digit 2:
digit <- 2
# To identify outliers:
# classmap(vcr.train, digit + 1, classCols = cols, identify = TRUE)
# Press "Esc" to get out.
#
# pdf(paste0("MNIST_classmap_digit", digit,".pdf"), width = 7, height = 7)
par(mar = c(3.6, 3.5, 2.4, 3.5))
coords <- classmap(vcr.train, digit + 1, classCols = cols,
         main = paste0("predictions of digit ", digit),
         cex = 1.5, cex.lab = 1.5, cex.axis = 1.5,
         cex.main = 1.5)
indstomark <- c(3164, 5434, 2319 , 4224, 3682, 
               2642, 4920, 1233, 3741, 3993) # from identify = TRUE
labs  <- letters[1:length(indstomark)]
xvals <- coords[indstomark, 1] +
  c(0, 0.08, 0, 0, 0, 0, 0, 0, 0, 0)
yvals <- coords[indstomark, 2] +
  c(-0.03, -0.03, -0.03, -0.03, -0.03, 
    -0.03, -0.03, -0.03, 0.03, 0.03)  
text(x = xvals, y = yvals, labels = labs, cex = 1.5)
legend("right", fill = cols,
       legend = 0:9, cex = 1, ncol = 2, bg = "white")
# dev.off()
par(oldpar)
pred <- vcr.train$pred # needed for discussion plots
tempPreds    <- (pred[which(y_train == digit)])[indstomark]
discussionPlots <- list()
for (i in 1:length(indstomark)) {
  idx <- indstomark[i]
  tempplot <- showdigit(digit, idx, FALSE)
  tempplot <- arrangeGrob(tempplot, 
        bottom = paste0("(", labs[i], ") \"", tempPreds[i], "\""))
  discussionPlots[[i]] <- tempplot
}
discussionPlot <- grid.arrange(grobs = discussionPlots, 
                                         ncol = 5)
# ggsave(paste0("MNIST_discussionplot_digit", digit, ".pdf"),
#        plot = discussionPlot, width = 5,
#        height = (length(indstomark) %/% 5 +
#                    (length(indstomark) %% 5 > 0)))
Now we analyze the MNIST test data. First load and inspect the data, and project it onto the PCA subspace extracted from the training data.
X_test <- mnist$test$x
y_test <- as.factor(mnist$test$y)
head(y_test)
## [1] 7 2 1 0 4 1
## Levels: 0 1 2 3 4 5 6 7 8 9
#
dim(X_test) # 10000    28    28
## [1] 10000    28    28
length(y_test) # 10000
## [1] 10000
plotImage(X_test[1, , ])
plotImage(X_test[2, , ])
plotImage(X_test[3, , ])
dim(X_test) <- c(10000, 28 * 28)
dim(X_test) # 10000  784
## [1] 10000   784
dataProj_test <- as.matrix(X_test %*% loadings)
Now we compute the vcr object for the test data:
vcr.test <- vcr.da.newdata(Xnew = dataProj_test,
                           ynew = y_test,
                           vcr.da.train.out = vcr.train)
Build the confusion matrix and plot a stacked mosaic plot of the classification performance on the test data:
confmat.vcr(vcr.test, showOutliers = FALSE, showClassNumbers = TRUE)
## 
## Confusion matrix:
##      predicted
## given    1    2    3    4    5    6    7    8    9   10
##    1   970    0    1    0    0    2    1    1    5    0
##    2     0 1097   11    3    2    1    1    0   20    0
##    3     2    0 1002    3    3    0    2    1   19    0
##    4     1    0    9  972    0    5    0    2   17    4
##    5     0    0    4    0  965    0    3    2    2    6
##    6     2    0    1   18    0  859    1    1   10    0
##    7     8    1    2    0    4   12  924    0    7    0
##    8     1    2   28    1    3    2    0  958   14   19
##    9     3    0    9   12    1    5    1    2  935    6
##    10    5    1   11    6   10    2    0    6   18  950
## 
## The accuracy is 96.32%.
# In supplementary material:
# pdf("MNISTtest_stackplot_with_outliers.pdf", width = 5, height = 4.3)
stackedplot(vcr.test, classCols = cols, separSize = 0.6,
            main = "Stacked plot of QDA on MNIST test data",
            minSize = 1.5)
# dev.off()
Silhouette plot:
#pdf("MNIST_test_QDA_silhouettes.pdf", width = 5.0, height = 4.6)
silplot(vcr.test, classCols = cols,
        main = "Silhouette plot of QDA on MNIST test data")      
##  classNumber classLabel classSize classAveSi
##            1          0       980       0.98
##            2          1      1135       0.93
##            3          2      1032       0.94
##            4          3      1010       0.92
##            5          4       982       0.96
##            6          5       892       0.92
##            7          6       958       0.93
##            8          7      1028       0.86
##            9          8       974       0.92
##           10          9      1009       0.88
#dev.off()
Now we can construct the class maps on the test data. First for digit 0:
showdigit_test <- function(digit, i, plotIt = TRUE) {
  # i indexes the cases of the given digit; map it to a row of X_test
  idx <- which(y_test == digit)[i]
  # wnq(paste("Estimated digit: ", as.numeric(vcr.test$pred[idx]), sep = ""))
  tempImage <- matrix(unlist(X_test[idx, ]), 28, 28)
  p <- plotImage(tempImage)
  if (plotIt) plot(p)
  return(p)
}
digit <- 0
# classmap(vcr.test, digit+1, classCols = cols, identify = TRUE)
# pdf(paste0("MNISTtest_classmap_digit", digit,".pdf"))
par(mar = c(3.6, 3.5, 2.4, 3.5))
coords <- classmap(vcr.test, digit + 1, classCols = cols,
         main = paste0("predictions of digit ", digit),
         cex = 1.5, cex.lab = 1.5, cex.axis = 1.5, 
         cex.main = 1.5)
indstomark <- c(140, 630, 241, 967, 189,
               377, 78, 943, 64, 354)
labs  <- letters[1:length(indstomark)]
xvals <- coords[indstomark, 1] +
  c(0.08, 0.07, -0.07, 0.06, 0,
    0.04, 0.05, 0.09, -0.04, 0.09)
yvals <- coords[indstomark, 2] +
  c(-0.025, -0.03, -0.024, -0.025, -0.03, 
    -0.03, -0.03, 0.022, 0.035, 0.03)
text(x = xvals, y = yvals, labels = labs, cex = 1.5)
legend("left", fill = cols,
       legend = 0:9, cex = 1, ncol = 2, bg = "white")
# dev.off()
par(oldpar)
pred <- vcr.test$pred # needed for discussion plots
tempPreds <- (pred[which(y_test == digit)])[indstomark]
discussionPlots <- list()
for (i in 1:length(indstomark)) {
  idx <- indstomark[i]
  tempplot <- showdigit_test(digit, idx, FALSE)
  tempplot <- arrangeGrob(tempplot, 
      bottom = paste0("(", labs[i], ") \"", tempPreds[i], "\""))
  discussionPlots[[i]] <- tempplot
}
discussionPlot <- grid.arrange(grobs = discussionPlots, 
                                         ncol = 5)
# ggsave(paste0("MNISTtest_discussionplot_digit", digit, ".pdf"),
#        plot = discussionPlot, width = 5,
#        height = (length(indstomark) %/% 5 +
#                    (length(indstomark) %% 5 > 0)))
Now for digit 3:
digit <- 3
# classmap(vcr.test, digit + 1, classCols = cols, identify = TRUE)
# pdf(paste0("MNISTtest_classmap_digit", digit, ".pdf"))
par(mar = c(3.6, 3.5, 2.4, 3.5))
coords <- classmap(vcr.test, digit + 1, classCols = cols,
         main = paste0("predictions of digit ", digit),
         cex = 1.5, cex.lab = 1.5, cex.axis = 1.5, 
         cex.main = 1.5)
indstomark <- c(883, 659, 262, 60, 310,
               832, 223, 784, 835, 289)
labs  <- letters[1:length(indstomark)]
xvals <- coords[indstomark, 1] +
  c(-0.01, 0.08, -0.10, 0.06, 0.07, 
    0.06, 0.03, 0.11, 0.02, 0.06)
yvals <- coords[indstomark, 2] +
  c(0.035, 0.033, -0.017, -0.022, -0.025, 
    -0.025, -0.033, -0.022, 0.035, 0.038)
text(x = xvals, y = yvals, labels = labs, cex = 1.5)
legend("right", fill = cols,
       legend = 0:9, cex = 1, ncol = 2, bg = "white")
# dev.off()
par(oldpar)
pred <- vcr.test$pred # needed for discussion plots
tempPreds    <- (pred[which(y_test == digit)])[indstomark]
discussionPlots <- list()
for (i in 1:length(indstomark)) {
  idx <- indstomark[i]
  tempplot <- showdigit_test(digit, idx, FALSE)
  tempplot <- arrangeGrob(tempplot, 
    bottom = paste0("(", labs[i], ") \"", tempPreds[i], "\""))
  discussionPlots[[i]] <- tempplot
}
discussionPlot <- grid.arrange(grobs = discussionPlots, 
                                         ncol = 5)
# ggsave(paste0("MNISTtest_discussionplot_digit", digit, ".pdf"),
#        plot = discussionPlot, width = 5,
#        height = (length(indstomark) %/% 5 +
#                    (length(indstomark) %% 5 > 0)))