Fashion MNIST database
Objective: Train a basic Convolutional Neural Network (CNN) to classify fashion products.
The Fashion MNIST database is a dataset of 28x28 grayscale images of fashion products, designed to serve as a direct drop-in replacement for the original MNIST database for benchmarking machine learning algorithms. It consists of 70,000 images, with 60,000 images in the training set and 10,000 images in the test set. Each image is labeled with one of 10 categories:
- T-shirt/top: 0
- Trouser: 1
- Pullover: 2
- Dress: 3
- Coat: 4
- Sandal: 5
- Shirt: 6
- Sneaker: 7
- Bag: 8
- Ankle boot: 9
The Fashion MNIST database was created to provide a more challenging classification task than the original MNIST handwritten-digit database. It is freely available and can be downloaded directly through common machine learning libraries (this notebook uses torchvision).
Import libraries
In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets
from torchvision.transforms import ToTensor
from sklearn.metrics import classification_report, ConfusionMatrixDisplay
from datetime import datetime
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('whitegrid')
Load the dataset
In [2]:
# Download the Fashion MNIST training and test splits (stored under /tmp)
dataset_train = datasets.FashionMNIST(root='/tmp', download=True, transform=ToTensor())
# The train=False split; it is divided below into a validation half and a test half
dataset_right = datasets.FashionMNIST(root='/tmp', train=False, download=True, transform=ToTensor())
# Floor/remainder lengths always sum to the dataset size, so an odd-sized
# dataset can be split as well
number_validation = len(dataset_right) // 2
dataset_validation, dataset_test = random_split(
    dataset_right, [number_validation, len(dataset_right) - number_validation])
# Human-readable class names, indexed by integer label
classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
# Create data loaders
# Reshuffle the training samples every epoch so the mini-batches differ between epochs
dataloader_train = DataLoader(dataset_train, batch_size=32, shuffle=True)
dataloader_validation = DataLoader(dataset_validation, batch_size=32)
# One single batch that holds the whole test set
dataloader_test = DataLoader(dataset_test, batch_size=len(dataset_test))
print("Number of training samples:", len(dataset_train))
print("Number of validation samples:", len(dataset_validation))
print("Number of test samples:", len(dataset_test))
print("Image size:", tuple(dataset_train[0][0].shape))
Number of training samples: 60000
Number of validation samples: 5000
Number of test samples: 5000
Image size: (1, 28, 28)
Visualize the dataset
In [3]:
# Draw 16 distinct training examples at random and show them in a 4x4 grid
indexes = np.random.choice(range(len(dataset_train)), size=16, replace=False)
# Each dataset item is an (image, label) pair
samples = [dataset_train[index] for index in indexes]
fig, axs = plt.subplots(nrows=4, ncols=4, figsize=(8, 8))
fig.suptitle('Random samples')
for ax, (image, label) in zip(axs.ravel(), samples):
    # image[0] selects the first channel, which is what gets drawn
    ax.imshow(image[0], cmap="gray")
    ax.set_title(classes[label])
    ax.axis("off")
plt.tight_layout()
plt.show()
Visualize the class distribution
In [4]:
# Read the labels from the dataset's label tensor rather than iterating over
# the dataset: iteration would apply the image transform to every sample and
# keep all resulting image tensors alive only to throw them away.
y_train = dataset_train.targets.tolist()
plt.figure()
# Percentage of training samples per class label
sns.countplot(x=y_train, hue=y_train, palette="tab10", stat="percent", legend=False)
plt.title("Class")
plt.show()
Build a CNN
In [5]:
# Select the compute backend: CUDA first, then Apple MPS, otherwise the CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device!\n")
# Define model
class ConvolutionalNeuralNetwork(nn.Module):
    """Small CNN classifier: two conv/ReLU/max-pool stages and one linear layer.

    Args:
        number_classes: Number of output logits (one per class).
        in_channels: Number of channels of the input images (default 1).
        image_size: Height and width of the square input images (default 28).

    Raises:
        ValueError: If ``image_size`` is too small to survive both pooling steps.
    """

    def __init__(self, number_classes, in_channels=1, image_size=28):
        super().__init__()
        # The 5x5 convolutions use padding=2 and keep the spatial size; each
        # 2x2 max-pool halves it (rounding down): 28 -> 14 -> 7 by default.
        pooled_size = image_size // 2 // 2
        if pooled_size < 1:
            raise ValueError(f"image_size must be at least 4, got {image_size}")
        # The attribute name is kept as-is so the printed layout and the
        # parameter names of the model do not change.
        self.linear_relu_stack = nn.Sequential(
            # Spatial size after this layer: image_size (28 by default)
            nn.Conv2d(in_channels=in_channels, out_channels=16, kernel_size=5, padding=2),
            nn.ReLU(),
            # Spatial size after this layer: image_size // 2 (14 by default)
            nn.MaxPool2d(kernel_size=2),
            # Spatial size unchanged by this layer (14 by default)
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, padding=2),
            nn.ReLU(),
            # Spatial size after this layer: pooled_size (7 by default)
            nn.MaxPool2d(kernel_size=2),
            # Flatten the output of the convolutional layers
            nn.Flatten(),
            # 32 channels, each pooled_size * pooled_size in size
            nn.Linear(32 * pooled_size * pooled_size, number_classes))

    def forward(self, x):
        """Return the raw (unnormalized) class logits for the batch ``x``."""
        return self.linear_relu_stack(x)
# One output logit per distinct label value found in the training labels
number_classes = len(np.unique(y_train))
model = ConvolutionalNeuralNetwork(number_classes)
# Move the parameters to the selected device before printing the layout
model = model.to(device)
print(model)
Using cuda device!
ConvolutionalNeuralNetwork(
(linear_relu_stack): Sequential(
(0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(1): ReLU()
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(3): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(4): ReLU()
(5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(6): Flatten(start_dim=1, end_dim=-1)
(7): Linear(in_features=1568, out_features=10, bias=True)
)
)
Compile and train the CNN
In [6]:
def train(dataloader, model, loss_function, optimizer):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        dataloader: Iterable of ``(X, y)`` batches that exposes ``.dataset``.
        model: Module to optimize; it is switched to training mode.
        loss_function: Callable mapping ``(predictions, targets)`` to a scalar loss.
        optimizer: Optimizer that updates the parameters of ``model``.

    Returns:
        Tuple ``(accuracy, loss)``: the fraction of samples whose arg-max
        prediction equals the label, and the mean of the per-batch losses.
    """
    size = len(dataloader.dataset)
    number_batches = len(dataloader)
    # Send every batch to the device that holds the model's parameters, so
    # the function does not depend on a module-level ``device`` variable.
    first_parameter = next(model.parameters(), None)
    target_device = None if first_parameter is None else first_parameter.device
    model.train()
    train_loss, train_accuracy = 0, 0
    for X, y in dataloader:
        if target_device is not None:
            X, y = X.to(target_device), y.to(target_device)
        # Compute prediction error
        y_pred = model(X)
        loss = loss_function(y_pred, y)
        # Backpropagation; gradients are cleared first so that values left
        # over from before this call cannot leak into the first update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        train_accuracy += (y_pred.argmax(1) == y).sum().item()
    train_loss /= number_batches
    train_accuracy /= size
    return train_accuracy, train_loss
def validation(dataloader, model, loss_function):
    """Evaluate ``model`` on ``dataloader`` without updating its parameters.

    Args:
        dataloader: Iterable of ``(X, y)`` batches that exposes ``.dataset``.
        model: Module to evaluate; it is switched to evaluation mode.
        loss_function: Callable mapping ``(predictions, targets)`` to a scalar loss.

    Returns:
        Tuple ``(accuracy, loss)``: the fraction of samples whose arg-max
        prediction equals the label, and the mean of the per-batch losses.
    """
    size = len(dataloader.dataset)
    number_batches = len(dataloader)
    # Send every batch to the device that holds the model's parameters, so
    # the function does not depend on a module-level ``device`` variable.
    first_parameter = next(model.parameters(), None)
    target_device = None if first_parameter is None else first_parameter.device
    model.eval()
    validation_loss, validation_accuracy = 0, 0
    # No gradients are needed for evaluation
    with torch.no_grad():
        for X, y in dataloader:
            if target_device is not None:
                X, y = X.to(target_device), y.to(target_device)
            y_pred = model(X)
            validation_loss += loss_function(y_pred, y).item()
            validation_accuracy += (y_pred.argmax(1) == y).sum().item()
    validation_loss /= number_batches
    validation_accuracy /= size
    return validation_accuracy, validation_loss
In [7]:
# Cross-entropy loss optimized with Adam (library default settings) for ten epochs
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
epochs = 10
# Per-epoch metrics, filled in as training progresses
history = {"accuracy": [], "loss": [], "val_accuracy": [], "val_loss": []}
for epoch in range(epochs):
    # Time one full training pass plus one validation pass
    start_time = datetime.now()
    accuracy, loss = train(dataloader_train, model, loss_function, optimizer)
    val_accuracy, val_loss = validation(dataloader_validation, model, loss_function)
    stop_time = datetime.now()
    # The dict keeps insertion order, so keys and values line up
    for key, value in zip(history, (accuracy, loss, val_accuracy, val_loss)):
        history[key].append(value)
    print(f"Epoch {epoch + 1}/{epochs}")
    print(f"\telapsed time: {(stop_time - start_time).total_seconds():.3f}s - accuracy: {accuracy:.4f} - loss: {loss:.4f} - val_accuracy: {val_accuracy:.4f} - val_loss: {val_loss:.4f}")
Epoch 1/10
    elapsed time: 7.771s - accuracy: 0.8376 - loss: 0.4534 - val_accuracy: 0.8646 - val_loss: 0.3788
Epoch 2/10
    elapsed time: 7.629s - accuracy: 0.8911 - loss: 0.3037 - val_accuracy: 0.8892 - val_loss: 0.3256
Epoch 3/10
    elapsed time: 7.449s - accuracy: 0.9057 - loss: 0.2633 - val_accuracy: 0.8926 - val_loss: 0.3043
Epoch 4/10
    elapsed time: 7.832s - accuracy: 0.9142 - loss: 0.2377 - val_accuracy: 0.8984 - val_loss: 0.2900
Epoch 5/10
    elapsed time: 7.638s - accuracy: 0.9212 - loss: 0.2179 - val_accuracy: 0.9006 - val_loss: 0.2847
Epoch 6/10
    elapsed time: 7.648s - accuracy: 0.9271 - loss: 0.2014 - val_accuracy: 0.9026 - val_loss: 0.2827
Epoch 7/10
    elapsed time: 7.444s - accuracy: 0.9323 - loss: 0.1874 - val_accuracy: 0.9022 - val_loss: 0.2806
Epoch 8/10
    elapsed time: 7.475s - accuracy: 0.9372 - loss: 0.1745 - val_accuracy: 0.9024 - val_loss: 0.2833
Epoch 9/10
    elapsed time: 7.715s - accuracy: 0.9407 - loss: 0.1638 - val_accuracy: 0.9032 - val_loss: 0.2890
Epoch 10/10
    elapsed time: 7.586s - accuracy: 0.9443 - loss: 0.1538 - val_accuracy: 0.9064 - val_loss: 0.2990
In [8]:
# Learning curves: accuracy on the left panel, loss on the right panel
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))
panels = (
    (ax1, "accuracy", "val_accuracy", "Accuracy"),
    (ax2, "loss", "val_loss", "Loss"),
)
for ax, train_key, validation_key, y_label in panels:
    # Training curve first, validation second, matching the legend order
    ax.plot(history[train_key])
    ax.plot(history[validation_key])
    ax.set_xlabel("Epochs")
    ax.set_ylabel(y_label)
    ax.legend(["Training", "Validation"])
plt.show()
Evaluate the CNN
In [9]:
# Show the model's predictions for 16 distinct, randomly chosen test images
indexes = np.random.choice(range(0, len(dataset_test)), size=16, replace=False)
images = [dataset_test[index][0] for index in indexes]
fig, axs = plt.subplots(4, 4, figsize=(8, 8))
fig.suptitle('Random samples')
# Set evaluation mode explicitly instead of relying on an earlier cell
# having left the model in that state
model.eval()
with torch.no_grad():
    # Classify all selected images in a single forward pass
    predictions = model(torch.stack(images).to(device)).argmax(1).tolist()
for image, prediction, ax in zip(images, predictions, axs.flatten()):
    ax.imshow(image[0], cmap="gray")
    ax.set_title("Prediction: " + classes[prediction])
    ax.axis("off")
plt.tight_layout()
plt.show()
In [10]:
# Evaluate on the test split. Labels and predictions are collected from every
# batch, so the metrics stay correct if the loader yields more than one batch.
model.eval()
batch_targets, batch_predictions = [], []
with torch.no_grad():
    for X_test, y_test in dataloader_test:
        X_test = X_test.to(device)
        batch_predictions.append(model(X_test).argmax(1).cpu())
        batch_targets.append(y_test)
y_test = torch.cat(batch_targets).numpy()
y_pred = torch.cat(batch_predictions).numpy()
print(classification_report(y_test, y_pred, digits=4))
ConfusionMatrixDisplay.from_predictions(y_test, y_pred)
# Turn off the grid lines for the confusion-matrix figure
plt.grid(False)
plt.show()
precision recall f1-score support
0 0.8655 0.8690 0.8672 496
1 0.9941 0.9845 0.9893 517
2 0.8966 0.7924 0.8413 525
3 0.9094 0.9347 0.9219 505
4 0.7828 0.8773 0.8273 497
5 0.9918 0.9737 0.9826 494
6 0.7510 0.7312 0.7410 491
7 0.9433 0.9746 0.9587 512
8 0.9765 0.9807 0.9786 467
9 0.9693 0.9556 0.9624 496
accuracy 0.9070 5000
macro avg 0.9080 0.9074 0.9070 5000
weighted avg 0.9081 0.9070 0.9069 5000
Run in Google Colab