Nearest Centroid Classifier

Eric Bridgeford

2020-06-25

require(lolR)
require(ggplot2)
require(MASS)
n=400
d=2

Nearest Centroid Classifier

In this notebook, we will investigate how to use the nearest centroid classifier.

We simulate 400 examples of 30 dimensional points:

testdat <- lol.sims.mean_diff(n, d)
X <- testdat$X
Y <- testdat$Y

data <- data.frame(x1=X[,1], x2=X[,2], y=Y)
data$y <- factor(data$y)
ggplot(data, aes(x=x1, y=x2, color=y)) +
  geom_point() +
  xlab("x1") +
  ylab("x2") +
  ggtitle("Simulated Data") +
  xlim(-4, 6) +
  ylim(-4, 4)

We estimate the centers with the nearestCentroid classifier:

classifier <- lol.classify.nearestCentroid(X, Y)

data <- cbind(data, data.frame(size=1))
data <- rbind(data, data.frame(x1=classifier$centroids[,1], x2=classifier$centroids[,2], y="center", size=5))
ggplot(data, aes(x=x1, y=x2, color=y, size=size)) +
  geom_point() +
  xlab("x1") +
  ylab("x2") +
  ggtitle("Data with estimated Centers") +
  guides(size=FALSE) +
  xlim(-4, 6) +
  ylim(-4, 4)

Yhat <- predict(classifier, X)
data$y[1:(length(data$y) - 2)] <- Yhat
ggplot(data, aes(x=x1, y=x2, color=y, size=size)) +
  geom_point() +
  xlab("x1") +
  ylab("x2") +
  ggtitle("Data with Predictions") +
  guides(size=FALSE) +
  xlim(-4, 6) +
  ylim(-4, 4)