Image classification with TensorFlow transfer learning
To deepen my understanding of neural networks, I created my first image classifier using TensorFlow, an open-source ML framework with a range of tools and datasets that help you train models.
The goal of this exercise was to train a model to classify an image into one of the categories from the Oxford Flowers 102 dataset, which contains sample images of 102 flower species commonly occurring in the UK.
This blog provides a summary of the steps I took to create my first neural network and highlights some of the new concepts I learnt along the way.
Firstly, to load the dataset, I had to import tensorflow_datasets; then, instead of creating my own train, validation and test splits, I opted to use the splits that already exist in the TensorFlow dataset.
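Here is a minimal sketch of that loading step, assuming the oxford_flowers102 dataset name in tensorflow_datasets and the as_supervised option so that each element is an (image, label) pair:

```python
import tensorflow as tf
import tensorflow_datasets as tfds

# Load the existing train/validation/test splits plus the dataset metadata
(training_set, validation_set, test_set), dataset_info = tfds.load(
    'oxford_flowers102',
    split=['train', 'validation', 'test'],
    as_supervised=True,  # yield (image, label) pairs instead of dicts
    with_info=True,      # also return a tfds.core.DatasetInfo object
)
```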
I loaded the dataset_info by adding with_info=True above, so that I could easily access information about the dataset throughout the process.
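For example, the split sizes and the number of classes can be read off the DatasetInfo object like this:

```python
# Query the metadata returned by with_info=True
num_classes = dataset_info.features['label'].num_classes
num_training_examples = dataset_info.splits['train'].num_examples
num_validation_examples = dataset_info.splits['validation'].num_examples
num_test_examples = dataset_info.splits['test'].num_examples

print('Number of classes:', num_classes)
print('Training examples:', num_training_examples)
print('Validation examples:', num_validation_examples)
print('Test examples:', num_test_examples)
```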
The dataset information shows us the number of samples in the train, validation and test splits, as well as num_classes, which corresponds to the number of outputs we will need from the output layer of the neural network.
To access some images from any of the splits, we use the take() method.
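Something along these lines, with matplotlib assumed for displaying the images:

```python
import matplotlib.pyplot as plt

# Inspect the first three examples from the training split
for image, label in training_set.take(3):
    print('Image shape:', image.shape, '| Label:', label.numpy())
    plt.imshow(image)
    plt.title(f'Label: {label.numpy()}')
    plt.show()
```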
The dataset is ready for training a model, as it doesn’t have any missing values, so I moved swiftly on to the next step: selecting a model for the classification task and building a pipeline for the model training.
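A sketch of the kind of input pipeline this involves; the batch size of 32 is my assumption, and the 224×224 image size matches what the MobileNet model introduced below expects:

```python
IMAGE_SIZE = 224  # input size required by MobileNet
BATCH_SIZE = 32   # assumed batch size

def format_image(image, label):
    # Resize each image and scale pixel values to the [0, 1] range
    image = tf.image.resize(image, (IMAGE_SIZE, IMAGE_SIZE)) / 255.0
    return image, label

training_batches = training_set.shuffle(1024).map(format_image).batch(BATCH_SIZE).prefetch(1)
validation_batches = validation_set.map(format_image).batch(BATCH_SIZE).prefetch(1)
test_batches = test_set.map(format_image).batch(BATCH_SIZE).prefetch(1)
```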
One of the benefits of using TensorFlow is that you can save a model and reuse it as a starting point for building a model for similar tasks, a practice commonly known as transfer learning.
As paraphrased from the TensorFlow site,
“The intuition behind transfer learning for image classification is that if a model is trained on a large and general enough dataset, this model will effectively serve as a generic model of the visual world. You can then take advantage of these learned feature maps without having to start from scratch by training a large model on a large dataset.”
TensorFlow Hub has a variety of pre-trained models that have already been designed to maximize accuracy whilst also being efficient to run, such as the MobileNet model, which I opted to use for this exercise.
I loaded the model as follows, specifying the input_shape with an image size of 224, which is required by the MobileNet model. The feature-extraction step tells the model to take an input and use previously learned representations of the visual world to extract meaningful features from the sample, and trainable is set to False because I don’t want the model to update the weights and biases that were learned in earlier, far more extensive training.
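A sketch of that step; the exact TF Hub module URL is my assumption (the MobileNet v2 feature-vector module is a common choice):

```python
import tensorflow_hub as hub

URL = 'https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4'

# The feature extractor maps each 224x224x3 image to a vector of learned features
feature_extractor = hub.KerasLayer(URL, input_shape=(224, 224, 3))

# Freeze the pre-trained weights so only the new classifier head is trained
feature_extractor.trainable = False
```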
The last layer in the model object is the output layer, which specifies the number of categories the model should output; in our case this matches the 102 flower species in the dataset.
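Putting the two layers together, the model might be assembled like this:

```python
num_classes = dataset_info.features['label'].num_classes  # 102 flower species

model = tf.keras.Sequential([
    feature_extractor,                                         # frozen MobileNet features
    tf.keras.layers.Dense(num_classes, activation='softmax'),  # output layer
])

model.summary()
```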
The next step is to compile the model by specifying an optimizer, which is used to improve speed and performance while training a model; a loss function, which is how the model computes the deviation between true labels and predicted labels; and the metric the model should maximize. The definitions for all the options are available on the TensorFlow sites linked throughout this article.
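A plausible compile call; the Adam optimizer and sparse categorical cross-entropy (appropriate because the labels are integers rather than one-hot vectors) are my assumptions:

```python
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'],
)
```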
To start training, I call model.fit() with the training and validation batches, a number of epochs (training iterations over the dataset), and callbacks, which signal to the model when it should stop training. In this example, I created an early-stopping callback, which tells the model to monitor the val_loss and stop training once the val_loss has failed to improve for 5 consecutive epochs, which is the value I assigned to patience. The val_loss is a measure of how much the model is penalized for inaccurate predictions on the validation set.
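In code, roughly; the upper bound of 30 epochs is my assumption, since early stopping usually halts training sooner:

```python
EPOCHS = 30  # assumed upper bound; early stopping usually ends training sooner

early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)

history = model.fit(
    training_batches,
    epochs=EPOCHS,
    validation_data=validation_batches,
    callbacks=[early_stopping],
)
```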
After training the model, I plotted the loss and accuracy values.
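A sketch of the plotting code, reading the curves from the History object returned by model.fit():

```python
# Training vs. validation loss
plt.plot(history.history['loss'], label='training loss')
plt.plot(history.history['val_loss'], label='validation loss')
plt.legend()
plt.show()

# Training vs. validation accuracy
plt.plot(history.history['accuracy'], label='training accuracy')
plt.plot(history.history['val_accuracy'], label='validation accuracy')
plt.legend()
plt.show()
```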
I then tested the model on the test set, using model.evaluate(), to see whether it generalizes well to unseen data.
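For example (the variable names are mine):

```python
# Evaluate on the held-out test set
test_loss, test_accuracy = model.evaluate(test_batches)
print(f'Test loss: {test_loss:.3f}')
print(f'Test accuracy: {test_accuracy:.2%}')
```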
The model can be saved and retrieved for performing inference on new data.
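One way to do this, assuming the Keras HDF5 format; reloading a model that contains a hub.KerasLayer requires passing it via custom_objects:

```python
saved_model_path = './flower_classifier.h5'  # hypothetical file name
model.save(saved_model_path)

reloaded_model = tf.keras.models.load_model(
    saved_model_path,
    custom_objects={'KerasLayer': hub.KerasLayer},
)
```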
Prediction function
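A sketch of what such a function might look like; process_image and predict are hypothetical names, and PIL is assumed for loading images from disk:

```python
import numpy as np
from PIL import Image

def process_image(image):
    # Resize and normalize a raw image to match the model's expected input
    image = tf.convert_to_tensor(image, dtype=tf.float32)
    image = tf.image.resize(image, (IMAGE_SIZE, IMAGE_SIZE)) / 255.0
    return image.numpy()

def predict(image_path, model, top_k=5):
    # Return the top_k class probabilities and indices for a single image
    image = np.asarray(Image.open(image_path))
    processed = np.expand_dims(process_image(image), axis=0)
    probabilities = model.predict(processed)[0]
    top_indices = probabilities.argsort()[-top_k:][::-1]
    return probabilities[top_indices], top_indices
```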
Sanity check with one example image
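And a quick check on a single image; the file path below is a hypothetical placeholder:

```python
# Substitute any local flower image for this hypothetical path
probs, classes = predict('./test_images/example_flower.jpg', reloaded_model, top_k=5)
print('Top probabilities:', probs)
print('Top class indices:', classes)
```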