Fashion MNIST database

Run in Google Colab

Objective: Train a basic Convolutional Neural Network (CNN) to classify fashion products.

The Fashion MNIST database is a dataset of 28x28 grayscale images of fashion products, designed to serve as a direct drop-in replacement for the original MNIST database for benchmarking machine learning algorithms. It consists of 70,000 images, with 60,000 images in the training set and 10,000 images in the test set. Each image is labeled with one of 10 categories:

  • T-shirt/top: 0
  • Trouser: 1
  • Pullover: 2
  • Dress: 3
  • Coat: 4
  • Sandal: 5
  • Shirt: 6
  • Sneaker: 7
  • Bag: 8
  • Ankle boot: 9

The Fashion MNIST database was created to provide a more challenging classification task than the original MNIST handwritten digits while keeping the same image format and dataset size. It is freely available and bundled with common machine learning libraries such as torchvision and Keras.

Import libraries

In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets
from torchvision.transforms import ToTensor
from sklearn.metrics import classification_report, ConfusionMatrixDisplay
from datetime import datetime
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('whitegrid')

Load the dataset

In [2]:
# Download training and test data from open datasets
dataset_train = datasets.FashionMNIST(root='/tmp', download=True, transform=ToTensor())
dataset_holdout = datasets.FashionMNIST(root='/tmp', train=False, download=True, transform=ToTensor())
dataset_validation, dataset_test = random_split(dataset_holdout, [len(dataset_holdout) // 2] * 2)
classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]

# Create data loaders
dataloader_train = DataLoader(dataset_train, batch_size=32)
dataloader_validation = DataLoader(dataset_validation, batch_size=32)
dataloader_test = DataLoader(dataset_test, batch_size=len(dataset_test))

print("Number of training samples:", len(dataset_train))
print("Number of validation samples:", len(dataset_validation))
print("Number of test samples:", len(dataset_test))
print("Image size:", tuple(dataset_train[0][0].shape))
Number of training samples: 60000
Number of validation samples: 5000
Number of test samples: 5000
Image size: (1, 28, 28)
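Note that random_split draws a fresh random split on every run, and the training DataLoader above iterates in a fixed order. If a reproducible split or per-epoch shuffling of the training data is wanted, a minimal sketch of an alternative setup (the seed value 42 is an arbitrary choice and is not used in this notebook):

# Hypothetical variant: seeded split and a shuffled training loader
generator = torch.Generator().manual_seed(42)
dataset_validation, dataset_test = random_split(dataset_holdout, [len(dataset_holdout) // 2] * 2, generator=generator)
dataloader_train = DataLoader(dataset_train, batch_size=32, shuffle=True)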

Visualize the dataset

In [3]:
indexes = np.random.choice(range(0, len(dataset_train)), size=16, replace=False)
samples = [(dataset_train[index][0], dataset_train[index][1]) for index in indexes]

fig, axs = plt.subplots(4, 4, figsize=(8, 8))
fig.suptitle('Random samples')

for ax, sample in zip(axs.flatten(), samples):
    ax.imshow(sample[0][0], cmap="gray")
    ax.set_title(classes[sample[1]])
    ax.axis("off")

plt.tight_layout()
plt.show()

Visualize the class distribution

In [4]:
_, y_train = zip(*dataset_train)

plt.figure()
sns.countplot(x=y_train, hue=y_train, palette="tab10", stat="percent", legend=False)
plt.title("Class")
plt.show()

Build a CNN

In [5]:
# Get cpu, gpu or mps device for training
device = ("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using {device} device!\n")

# Define model
class ConvolutionalNeuralNetwork(nn.Module):
    def __init__(self, number_classes):
        super().__init__()
        self.linear_relu_stack = nn.Sequential(
            # Feature maps after this layer: 16 x 28 x 28 (padding=2 preserves the 28x28 size)
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, padding=2),
            nn.ReLU(),
            # Feature maps after this layer: 16 x 14 x 14
            nn.MaxPool2d(kernel_size=2),
            # Feature maps after this layer: 32 x 14 x 14
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, padding=2),
            nn.ReLU(),
            # Feature maps after this layer: 32 x 7 x 7
            nn.MaxPool2d(kernel_size=2),
            # Flatten the feature maps into a single vector per sample
            nn.Flatten(),
            # 32 channels of size 7 x 7 give 32 * 7 * 7 = 1568 input features
            nn.Linear(32 * 7 * 7, number_classes))

    def forward(self, x):
        logits = self.linear_relu_stack(x)

        return logits

number_classes = np.unique(y_train).size

model = ConvolutionalNeuralNetwork(number_classes).to(device)
print(model)
Using cuda device!

ConvolutionalNeuralNetwork(
  (linear_relu_stack): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=1568, out_features=10, bias=True)
  )
)
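The in_features=1568 of the final linear layer follows from the layer comments above: padding=2 keeps the 28x28 size through each 5x5 convolution, and the two 2x2 max-pools halve it twice (28 → 14 → 7), leaving 32 feature maps of 7x7 each, i.e. 32 * 7 * 7 = 1568. A quick shape check with a dummy batch confirms the output dimensions (illustrative only, not part of training):

with torch.no_grad():
    dummy = torch.randn(1, 1, 28, 28, device=device)  # one fake grayscale image
    print(model(dummy).shape)  # expected: torch.Size([1, 10])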

Compile and train the CNN

In [6]:
def train(dataloader, model, loss_function, optimizer):
    size = len(dataloader.dataset)
    number_batches = len(dataloader)
    model.train()
    train_loss, train_accuracy = 0, 0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        y_pred = model(X)
        loss = loss_function(y_pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        train_loss += loss.item()
        train_accuracy += (y_pred.argmax(1) == y).type(torch.float).sum().item()
    
    train_loss /= number_batches
    train_accuracy /= size

    return train_accuracy, train_loss

def validation(dataloader, model, loss_function):
    size = len(dataloader.dataset)
    number_batches = len(dataloader)
    model.eval()
    validation_loss, validation_accuracy = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            y_pred = model(X)
            loss = loss_function(y_pred, y)
            validation_loss += loss.item()
            validation_accuracy += (y_pred.argmax(1) == y).type(torch.float).sum().item()

    validation_loss /= number_batches
    validation_accuracy /= size

    return validation_accuracy, validation_loss
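Both helpers are agnostic to which DataLoader they receive, so the same validation function could, for example, also score the held-out test split after training (illustrative, not executed in this notebook):

test_accuracy, test_loss = validation(dataloader_test, model, loss_function)
print(f"test accuracy: {test_accuracy:.4f} - test loss: {test_loss:.4f}")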
In [7]:
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

epochs = 10
history = {"accuracy": [], "loss": [], "val_accuracy": [], "val_loss": []}

for epoch in range(epochs):
    start_time = datetime.now()
    accuracy, loss = train(dataloader_train, model, loss_function, optimizer)
    val_accuracy, val_loss = validation(dataloader_validation, model, loss_function)
    stop_time = datetime.now()

    history["accuracy"].append(accuracy)
    history["loss"].append(loss)
    history["val_accuracy"].append(val_accuracy)
    history["val_loss"].append(val_loss)

    print(f"Epoch {epoch + 1}/{epochs}")
    print(f"\telapsed time: {(stop_time - start_time).total_seconds():.3f}s - accuracy: {accuracy:.4f} - loss: {loss:.4f} - val_accuracy: {val_accuracy:.4f} - val_loss: {val_loss:.4f}")
Epoch 1/10
	elapsed time: 7.771s - accuracy: 0.8376 - loss: 0.4534 - val_accuracy: 0.8646 - val_loss: 0.3788
Epoch 2/10
	elapsed time: 7.629s - accuracy: 0.8911 - loss: 0.3037 - val_accuracy: 0.8892 - val_loss: 0.3256
Epoch 3/10
	elapsed time: 7.449s - accuracy: 0.9057 - loss: 0.2633 - val_accuracy: 0.8926 - val_loss: 0.3043
Epoch 4/10
	elapsed time: 7.832s - accuracy: 0.9142 - loss: 0.2377 - val_accuracy: 0.8984 - val_loss: 0.2900
Epoch 5/10
	elapsed time: 7.638s - accuracy: 0.9212 - loss: 0.2179 - val_accuracy: 0.9006 - val_loss: 0.2847
Epoch 6/10
	elapsed time: 7.648s - accuracy: 0.9271 - loss: 0.2014 - val_accuracy: 0.9026 - val_loss: 0.2827
Epoch 7/10
	elapsed time: 7.444s - accuracy: 0.9323 - loss: 0.1874 - val_accuracy: 0.9022 - val_loss: 0.2806
Epoch 8/10
	elapsed time: 7.475s - accuracy: 0.9372 - loss: 0.1745 - val_accuracy: 0.9024 - val_loss: 0.2833
Epoch 9/10
	elapsed time: 7.715s - accuracy: 0.9407 - loss: 0.1638 - val_accuracy: 0.9032 - val_loss: 0.2890
Epoch 10/10
	elapsed time: 7.586s - accuracy: 0.9443 - loss: 0.1538 - val_accuracy: 0.9064 - val_loss: 0.2990
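The validation loss reaches its minimum around epoch 7 and then creeps back up while the training loss keeps falling, an early sign of overfitting. A common remedy is to keep the weights from the best validation epoch; a minimal sketch of how the loop above could be extended (hypothetical, not what was run here):

import copy

best_val_loss = float("inf")
best_state = None

for epoch in range(epochs):
    accuracy, loss = train(dataloader_train, model, loss_function, optimizer)
    val_accuracy, val_loss = validation(dataloader_validation, model, loss_function)

    # Remember the weights whenever the validation loss improves
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_state = copy.deepcopy(model.state_dict())

# Restore the best weights before the final evaluation
model.load_state_dict(best_state)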
In [8]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

ax1.plot(history['accuracy'])
ax1.plot(history['val_accuracy'])
ax1.set_xlabel("Epochs")
ax1.set_ylabel("Accuracy")
ax1.legend(["Training", "Validation"])

ax2.plot(history['loss'])
ax2.plot(history['val_loss'])
ax2.set_xlabel("Epochs")
ax2.set_ylabel("Loss")
ax2.legend(["Training", "Validation"])

plt.show()

Evaluate the CNN

In [9]:
indexes = np.random.choice(range(0, len(dataset_test)), size=16, replace=False)
images = [dataset_test[index][0] for index in indexes]

fig, axs = plt.subplots(4, 4, figsize=(8, 8))
fig.suptitle('Random samples')

with torch.no_grad():
    for image, ax in zip(images, axs.flatten()):
        X = torch.unsqueeze(image, 0).to(device)
        prediction_logits = model(X)
        ax.imshow(image[0], cmap="gray")
        ax.set_title("Prediction: " + classes[prediction_logits.argmax(1).item()])
        ax.axis("off")

plt.tight_layout()
plt.show()
In [10]:
with torch.no_grad():
    for X_test, y_test in dataloader_test:
        X_test = X_test.to(device)
        y_pred = model(X_test).argmax(1).cpu().numpy()

print(classification_report(y_test, y_pred, digits=4))

ConfusionMatrixDisplay.from_predictions(y_test, y_pred)
plt.grid(False)
plt.show()
              precision    recall  f1-score   support

           0     0.8655    0.8690    0.8672       496
           1     0.9941    0.9845    0.9893       517
           2     0.8966    0.7924    0.8413       525
           3     0.9094    0.9347    0.9219       505
           4     0.7828    0.8773    0.8273       497
           5     0.9918    0.9737    0.9826       494
           6     0.7510    0.7312    0.7410       491
           7     0.9433    0.9746    0.9587       512
           8     0.9765    0.9807    0.9786       467
           9     0.9693    0.9556    0.9624       496

    accuracy                         0.9070      5000
   macro avg     0.9080    0.9074    0.9070      5000
weighted avg     0.9081    0.9070    0.9069      5000

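To reuse the trained network later without retraining, its parameters can be serialized with torch.save; a minimal sketch (the file path is an arbitrary choice):

# Save only the learned parameters (the usual practice for PyTorch models)
torch.save(model.state_dict(), "/tmp/fashion_mnist_cnn.pth")

# Later: rebuild the architecture and load the weights back
model_restored = ConvolutionalNeuralNetwork(number_classes).to(device)
model_restored.load_state_dict(torch.load("/tmp/fashion_mnist_cnn.pth", map_location=device))
model_restored.eval()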