Alpaca recognition¶
Objective: Use transfer learning with a pre-trained MobileNetV2 to build an Alpaca/Not Alpaca classifier.
Import libraries¶
In [1]:
from keras import utils, applications, Sequential, layers, Input, Model
import matplotlib.pyplot as plt
import seaborn as sns
import os
import numpy as np
2024-06-25 13:47:20.681957: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Download the dataset¶
In [2]:
%%bash
mkdir data
gdown -q 13JG3cmbMT2bzCxP4mqs_2KwsqPRYyfpc -O ./data/
unzip -q ./data/Alpaca_NotAlpaca.zip -d ./data
rm ./data/Alpaca_NotAlpaca.zip
Load the dataset¶
In [3]:
train_ds, validation_ds = utils.image_dataset_from_directory(directory="./data/Alpaca_NotAlpaca",
                                                             batch_size=32,
                                                             image_size=(160, 160),
                                                             seed=0,
                                                             validation_split=0.2,
                                                             subset="both",
                                                             label_mode='binary')
classes = train_ds.class_names
Found 327 files belonging to 2 classes.
Using 262 files for training.
Using 65 files for validation.
2024-06-25 13:47:42.401838: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1928] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 1985 MB memory: -> device: 0, name: NVIDIA GeForce GTX 1650, pci bus id: 0000:01:00.0, compute capability: 7.5
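An optional input-pipeline tweak (not applied in this notebook) is to cache decoded images and prefetch batches so preprocessing overlaps with training. A minimal sketch using TensorFlow's `tf.data` API, demonstrated on a tiny in-memory dataset; the same two calls would apply verbatim to `train_ds` and `validation_ds`:

```python
import tensorflow as tf

# Stand-in dataset of 8 zero-images batched by 4, just to show the calls.
ds = tf.data.Dataset.from_tensor_slices(tf.zeros((8, 160, 160, 3))).batch(4)

# cache() keeps decoded tensors in memory after the first epoch;
# prefetch() prepares the next batch while the current one trains.
ds = ds.cache().prefetch(tf.data.AUTOTUNE)

batches = [b.shape for b in ds]
print(batches)  # two batches of shape (4, 160, 160, 3)
```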
Visualize the dataset¶
In [4]:
fig, axs = plt.subplots(4, 4, figsize=(8, 8))
for images, labels in train_ds.take(1):
    for i, ax in enumerate(axs.flat):
        ax.imshow(images[i] / 255)
        ax.axis("off")
        ax.set_title(classes[int(labels[i].numpy().squeeze())])
plt.tight_layout()
plt.show()
2024-06-25 13:47:43.914449: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
In [5]:
alpaca = len([file for file in os.listdir("./data/Alpaca_NotAlpaca/alpaca") if file.endswith(".jpg")])
not_alpaca = len([file for file in os.listdir("./data/Alpaca_NotAlpaca/not alpaca") if file.endswith(".jpg")])
sizes = [alpaca, not_alpaca]
fig, ax = plt.subplots()
ax.pie(sizes, textprops={'color': "w", 'fontsize': '12'}, autopct=lambda pct: "{:.2f}%\n({:d})".format(pct, round(pct/100 * sum(sizes))))
ax.legend(classes)
plt.title("Class distribution")
plt.show()
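If the pie chart shows a noticeable imbalance, class weights can compensate for it during training via `model.fit(..., class_weight=...)`. A minimal sketch of the standard inverse-frequency formula; the counts below are placeholder numbers for illustration only — in the notebook you would pass `[alpaca, not_alpaca]` as computed above:

```python
def inverse_frequency_weights(counts):
    """weight_c = total / (n_classes * count_c): the rarer a class,
    the more each of its samples contributes to the loss.
    Returns a {class_index: weight} dict as model.fit expects."""
    total = sum(counts)
    return {i: total / (len(counts) * c) for i, c in enumerate(counts)}

# Placeholder counts, NOT the real dataset tallies:
weights = inverse_frequency_weights([120, 200])
print(weights)  # the rarer class 0 gets the larger weight
```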
Instantiate the MobileNetV2 architecture¶
In [6]:
MobileNetV2 = applications.MobileNetV2(include_top=False, input_shape=(160, 160, 3))
MobileNetV2.trainable = False
Create the model¶
In [7]:
data_augmentation = Sequential([layers.RandomFlip(mode="horizontal"),
                                layers.RandomTranslation(height_factor=0.2, width_factor=0.2, fill_mode="nearest"),
                                layers.RandomRotation(factor=0.2, fill_mode="nearest"),
                                layers.RandomZoom(height_factor=0.2, width_factor=0.2, fill_mode="nearest")])
inputs = Input(shape=(160, 160, 3))
x = data_augmentation(inputs)
x = applications.mobilenet_v2.preprocess_input(x)
x = MobileNetV2(x, training=False)
x = layers.GlobalAveragePooling2D()(x)
outputs = layers.Dense(1, activation="sigmoid")(x)
model = Model(inputs, outputs)
model.summary()
Model: "functional_2"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape          ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ input_layer_1 (InputLayer)      │ (None, 160, 160, 3)   │             0 │
├─────────────────────────────────┼───────────────────────┼───────────────┤
│ sequential (Sequential)         │ (None, 160, 160, 3)   │             0 │
├─────────────────────────────────┼───────────────────────┼───────────────┤
│ true_divide (TrueDivide)        │ (None, 160, 160, 3)   │             0 │
├─────────────────────────────────┼───────────────────────┼───────────────┤
│ subtract (Subtract)             │ (None, 160, 160, 3)   │             0 │
├─────────────────────────────────┼───────────────────────┼───────────────┤
│ mobilenetv2_1.00_160            │ (None, 5, 5, 1280)    │     2,257,984 │
│ (Functional)                    │                       │               │
├─────────────────────────────────┼───────────────────────┼───────────────┤
│ global_average_pooling2d        │ (None, 1280)          │             0 │
│ (GlobalAveragePooling2D)        │                       │               │
├─────────────────────────────────┼───────────────────────┼───────────────┤
│ dense (Dense)                   │ (None, 1)             │         1,281 │
└─────────────────────────────────┴───────────────────────┴───────────────┘
Total params: 2,259,265 (8.62 MB)
Trainable params: 1,281 (5.00 KB)
Non-trainable params: 2,257,984 (8.61 MB)
Compile the model¶
In [8]:
model.compile(optimizer="adam", loss='binary_crossentropy', metrics=['accuracy'])
Fit the model¶
In [9]:
history1 = model.fit(train_ds, epochs=200, verbose=0, validation_data=validation_ds)
2024-06-25 13:47:51.917566: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:465] Loaded cuDNN version 8907
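Running a fixed 200 epochs can overfit or waste compute; Keras's `EarlyStopping` callback stops once `val_loss` stalls and can restore the best weights seen. A sketch on a toy model with random stand-in data (in the notebook, the real change would just be adding `callbacks=[early_stop]` to the `model.fit` call above; `patience=3` is an arbitrary example value):

```python
import numpy as np
from keras import Sequential, layers, callbacks

# Random stand-in features/labels, only to exercise the callback.
rng = np.random.default_rng(0)
X = rng.normal(size=(64, 8))
y = rng.integers(0, 2, size=(64, 1)).astype("float32")

toy = Sequential([layers.Dense(4, activation="relu"),
                  layers.Dense(1, activation="sigmoid")])
toy.compile(optimizer="adam", loss="binary_crossentropy")

# Stop once val_loss has not improved for 3 consecutive epochs,
# and roll back to the best weights observed.
early_stop = callbacks.EarlyStopping(monitor="val_loss", patience=3,
                                     restore_best_weights=True)
history = toy.fit(X, y, validation_split=0.25, epochs=30,
                  callbacks=[early_stop], verbose=0)
print(len(history.history["loss"]), "epochs actually run")
```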
In [10]:
sns.set_style("whitegrid")
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
ax1.plot(history1.history['accuracy'])
ax1.plot(history1.history['val_accuracy'])
ax1.set_xlabel("Epochs")
ax1.set_ylabel("Accuracy")
ax1.legend(["Training", "Validation"])
ax2.plot(history1.history['loss'])
ax2.plot(history1.history['val_loss'])
ax2.set_xlabel("Epochs")
ax2.set_ylabel("Loss")
ax2.legend(["Training", "Validation"])
plt.show()
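A common follow-up once the new head has converged (not performed in this notebook) is fine-tuning: unfreeze the top of the base network and continue training at a much lower learning rate so the pre-trained features are only gently adjusted. A sketch of the freezing mechanics; `weights=None` is used here only to keep the example download-free (in practice you would reuse the pretrained `MobileNetV2` base above), and the cut-off index 100 and the learning rate are illustrative choices:

```python
from keras import applications, optimizers

# Download-free stand-in base; reuse the pretrained one in practice.
base = applications.MobileNetV2(include_top=False,
                                input_shape=(160, 160, 3), weights=None)
base.trainable = True
fine_tune_at = 100  # illustrative cut-off: freeze everything below it
for layer in base.layers[:fine_tune_at]:
    layer.trainable = False

trainable = sum(1 for layer in base.layers if layer.trainable)
print(f"{trainable} of {len(base.layers)} layers left trainable")

# Then recompile with a much lower learning rate before fitting again:
# model.compile(optimizer=optimizers.Adam(learning_rate=1e-5),
#               loss="binary_crossentropy", metrics=["accuracy"])
```

Note that the model above already calls the base with `training=False`, which keeps the BatchNormalization layers in inference mode during fine-tuning — without that, their running statistics would be disturbed.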
Evaluate the model¶
In [11]:
data_augmentation = Sequential([layers.RandomFlip(mode="horizontal"),
                                layers.RandomTranslation(height_factor=0.2, width_factor=0.2, fill_mode="nearest"),
                                layers.RandomRotation(factor=0.2, fill_mode="nearest"),
                                layers.RandomZoom(height_factor=0.2, width_factor=0.2, fill_mode="nearest")])
fig, axs = plt.subplots(4, 4, figsize=(8, 8))
# Skip to a random full batch; the last batch may hold fewer than 16 images.
batch_index = np.random.randint(validation_ds.cardinality().numpy() - 1)
for images, labels in validation_ds.skip(batch_index).take(1).map(lambda x, y: (data_augmentation(x), y)):
    for i, ax in enumerate(axs.flat):
        prediction_proba = model.predict(np.expand_dims(images[i], axis=0), verbose=0)
        ax.imshow(images[i] / 255)
        ax.axis("off")
        ax.set_title("Prediction: " + classes[int(prediction_proba.squeeze().round())])
plt.tight_layout()
plt.show()
2024-06-25 13:50:05.512182: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
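The `squeeze().round()` in the cell above implements a 0.5 decision threshold on the sigmoid output: probabilities below 0.5 map to class 0, the rest to class 1. A numpy-free illustration, assuming the alphabetical class ordering `class_names` returns for this directory layout:

```python
# Mirrors train_ds.class_names (directories sorted alphabetically).
classes = ["alpaca", "not alpaca"]

def to_label(proba, threshold=0.5):
    """Map a sigmoid probability to a class name; proba >= threshold -> class 1."""
    return classes[int(proba >= threshold)]

print([to_label(p) for p in (0.12, 0.49, 0.51, 0.97)])
# -> ['alpaca', 'alpaca', 'not alpaca', 'not alpaca']
```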
In [12]:
!rm -rf ./data