Essential TensorFlow and Keras Callbacks for your Neural Networks
In this article, we briefly go over the motivations for using callbacks in your Keras code. Next, we look at how to define and use some important built-in callbacks during TensorFlow and Keras model training. Finally, we walk through a simple example that demonstrates the benefits of using callbacks.
Some advantages of using these callbacks:
- They can help you fight overfitting,
- They save you frustration by saving your best model seen so far before Colab or Kaggle gives up on training (trust me, I have been there!!),
- They take away the need to explicitly define the number of epochs for your training.
Table of Contents
- Overview of the need for Callbacks.
- Important Keras Callbacks for your Deep Learning Project (with code).
- A demonstration of these callbacks with a simple example.
1. Overview of the need for Callbacks.
Training a Neural Network is a tricky business. There are a lot of things that one needs to take care of when training a neural network.
- Getting quality data and cleaning the data.
- Converting the data to a representation that a Deep Learning Model can consume (example: represent text data as Word Vectors for text-related tasks).
- Deciding the right type of neural network based on the task at hand (example: Transformers or RNNs for NLP tasks, CNNs for image classification).
- Deciding the number of neurons and layers.
- Defining the right metrics and loss functions for effective learning on the task, among other details.
- Training the network for the right number of epochs to avoid overfitting, saving the best model, and tracking the metrics you need.
However, these days there are many open-sourced pre-trained models available online for ready use on platforms like TFHub, Hugging Face, ModelZoo, GitHub, etc. These models can be readily used and fine-tuned for your specific tasks (example: using a ResNet for your image classification task).
Whether you are using a pre-trained network or building a neural architecture from scratch, you need to save your best model, keep it from overfitting your data, and track your metrics. Now we look at the important callbacks that help us in this direction.
2. Important Keras Callbacks for your Deep Learning Project
NOTE: all these callbacks are passed to your model.fit method. This step will be elaborated shortly.
a. ModelCheckpoint:
Keras comes with a callback named ModelCheckpoint, which saves your model at regular intervals based on criteria you specify, so you always keep the best model seen so far. I have personally benefited from this callback whenever I run model training on Colab or Kaggle, which can hang midway through the training epochs.
Here is how you define ModelCheckpoint:
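Below is a minimal sketch; the filename best_model.h5 is just a placeholder for wherever you want the model saved:

```python
from tensorflow.keras.callbacks import ModelCheckpoint

# "best_model.h5" is a placeholder path; use any filename you like.
checkpoint_cb = ModelCheckpoint(
    "best_model.h5",      # first argument: the file the model is saved to
    monitor="val_loss",   # metric to watch across epochs
    save_best_only=True,  # overwrite only when the monitored score improves
)
```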
- The first argument is the path of the file to which the model is saved.
- The monitor keyword specifies the metric to track; it can be validation accuracy, validation loss, or any other metric. Keras understands that you want minimum loss or maximum accuracy, so you don’t have to worry. The metric you monitor depends on how you evaluate your model, so use the one best suited for your task.
- save_best_only, if set to True, makes sure only the best model is saved. The “best model”, loosely defined, is the model that has achieved the best score on the monitor metric passed in the previous line.
Also, there are other parameters for fine-grained control over the saving of your model; you can learn more about them at the official docs.
b. EarlyStopping:
This is one of my favorite callbacks. It stops training the model based on criteria you define and restores the weights from the best epoch. This means you need not worry about your model overfitting, since the callback terminates training early, and you also need not worry about the number of epochs to define.
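Here is how you might define EarlyStopping, a minimal sketch using the parameter values discussed below:

```python
from tensorflow.keras.callbacks import EarlyStopping

early_stopping_cb = EarlyStopping(
    monitor="val_loss",         # metric to watch, same idea as ModelCheckpoint
    patience=5,                 # wait 5 epochs with no improvement before stopping
    min_delta=0.1,              # changes smaller than 0.1 don't count as improvement
    restore_best_weights=True,  # roll back to the best epoch's weights
)
```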
- The monitor parameter is the same as for the above callback.
- The patience parameter is the number of epochs the model keeps training after it last saw an improvement in the monitored score. In our example, training will stop once there have been 5 consecutive epochs with no improvement in validation loss (i.e., the validation loss didn’t go any lower).
- restore_best_weights makes sure the weights from the best epoch seen so far are restored into the model after training has stopped.
- min_delta is the minimum change in the monitored score between two epochs that counts as an improvement; below this threshold, the change is not considered an improvement. Example: in our definition above, if val_loss drops from 0.92 to 0.72, that counts as an improvement, but if it drops from 0.99 to 0.92, it does not, per our min_delta of 0.1. Please be careful when setting this hyperparameter; if you leave it undefined, it defaults to 0.
Note: There are times when model training is stuck in a local minimum and may need a few more epochs to get out of it, or, better, a decay or increase in the learning rate. In such cases EarlyStopping can terminate the training too early; to avoid this you can set a higher patience value, but that increases the compute time.
Also, there are other parameters for fine-grained control over early stopping; you can learn more about them at the official docs.
c. TensorBoard:
This Keras callback helps you visualize your metrics on TensorBoard. Since I started using TensorBoard, I have felt no need to explicitly write code to plot my metrics; the Keras TensorBoard callback does this for you. If you are not aware of what TensorBoard is: in short, it is a visualization platform that TensorFlow provides. For more details please visit TensorBoard.
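A minimal sketch (the log directory logs/fit is just a placeholder):

```python
from tensorflow.keras.callbacks import TensorBoard

# Logs are written under log_dir; point TensorBoard at this folder later.
tensorboard_cb = TensorBoard(log_dir="logs/fit")
```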
And yes, that is all the code you need to write to visualize your metrics in TensorBoard.
In the next section, we will see how to use these callbacks in a simple classification task and also visualize the metrics in TensorBoard.
3. A demonstration of these callbacks with a simple example.
We will build a simple Convolutional Neural Network for classifying digits using the famous MNIST dataset and demonstrate how the callbacks can be used, as sketched below. We will use TensorFlow Datasets to load the data, and TensorFlow Datasets and Keras preprocessing methods to preprocess our images. If you are new to TensorFlow Datasets, please refer to my previous articles on the same, but you won’t need to know them to start using these callbacks.
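Here is a sketch of that setup. Treat the layer choices and batch size below as illustrative rather than the exact architecture from the original notebook:

```python
import tensorflow as tf
import tensorflow_datasets as tfds

# Load MNIST as (image, label) pairs, split into train and test.
(ds_train, ds_test) = tfds.load(
    "mnist", split=["train", "test"], as_supervised=True
)

# Scale pixel values to [0, 1], then batch.
def normalize(image, label):
    return tf.cast(image, tf.float32) / 255.0, label

ds_train = ds_train.map(normalize).batch(128)
ds_test = ds_test.map(normalize).batch(128)

# A small CNN for 10-digit classification.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation="relu", input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation="softmax"),
])

model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
```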
We define these callbacks and pass them to the Keras fit method.
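Putting it together, reusing the callbacks defined in the sketches above (the epoch count of 50 is just an upper bound, since EarlyStopping will cut training short):

```python
callbacks = [checkpoint_cb, early_stopping_cb, tensorboard_cb]

history = model.fit(
    ds_train,
    validation_data=ds_test,
    epochs=50,            # an upper bound; EarlyStopping will stop sooner
    callbacks=callbacks,  # all three callbacks go in as a list
)
```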
Yes, as seen above, passing your callbacks is very easy.
Here are the results from our experiment:-
We can see that the EarlyStopping callback stopped the training after 10 epochs, since the improvement between epoch 9 and epoch 10 was less than our min_delta of 0.1.
We can next check our saved model and load TensorBoard to visualize our metrics:-
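In a Colab or Jupyter notebook, that might look like this (assuming the placeholder paths from the sketches above):

```python
import tensorflow as tf

# Reload the best saved model from the checkpoint file.
best_model = tf.keras.models.load_model("best_model.h5")

# Launch TensorBoard pointed at the log directory (notebook magics).
%load_ext tensorboard
%tensorboard --logdir logs/fit
```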
You can access the code above either on Google Colab or GitHub here:
Thank you for reading, and I hope this article helps you get started on your journey with callbacks in Keras. These callbacks can have a significant impact on how you work, and my sincere hope is that they transform your working style too.
LinkedIn:- https://www.linkedin.com/in/virajdatt-kohir/
Twitter:- https://twitter.com/kvirajdatt
GitHub:- https://github.com/Virajdatt
GoodReads:- https://www.goodreads.com/user/show/114768501-virajdatt-kohir