Starting with TensorFlow Datasets - Part 3: An End-to-End Example of Building a Flower Classifier

Virajdatt Kohir
5 min read · Jan 17, 2022


Photo by John-Mark Smith on Unsplash

After discussing tf.data pipelines in part 1 of the series and introducing TensorFlow Datasets in part 2, in this article we will walk through an end-to-end example of building an image classifier using the tf_flowers dataset.

In this article, we will build an image classifier with TensorFlow and Keras to classify images of flowers. We will use:

1. TensorFlow Datasets to load the images,
2. Keras and tf.data pipeline methods to preprocess and augment the images, and
3. finally, a couple of Keras models for the image classification task.

We will also experiment to see how data augmentation improves the generalization ability of the deep learning model.

Table of Contents

  1. Load the data using TensorFlow Datasets
  2. Take a look at the metadata of the dataset using TensorFlow Datasets
  3. Preprocess the images and visualize
  4. Augment the images and visualize
  5. Build 2 models to experiment with augmented data
  6. Results of our experiment
1. Load the data using TensorFlow Datasets

Let us start by loading the flowers dataset and splitting it into train, validation, and test sets. As we learned in part 2, TensorFlow Datasets makes data splitting easy, so we will use the API explained in the previous post to split our data (it should be intuitive even if you have not read that post). We also load metadata about the dataset for later use.
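The full code for this step is in the linked notebook; here is a minimal sketch of what the loading call looks like, assuming an 80/10/10 percent split (the exact slice boundaries in the notebook may differ):

import tensorflow as tf
import tensorflow_datasets as tfds

# Load tf_flowers and split it into train/validation/test.
# with_info=True also returns the dataset metadata we use later.
(train_ds, val_ds, test_ds), metadata = tfds.load(
    'tf_flowers',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    with_info=True,
    as_supervised=True,  # yield (image, label) pairs instead of dicts
)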

2. Take a look at the metadata of the dataset using TensorFlow Datasets

This metadata object was returned because we set with_info=True when loading, which tells TensorFlow Datasets to return information about the dataset along with the data itself.

The returned metadata object exposes quite a few interesting things; in this article, we will look at the FeaturesDict that TensorFlow Datasets implements. We can access it through the features attribute, which helps our task in the following ways:

1. It lets us take a quick peek at the raw data and the labels.
2. Gives access to the label names (string) and their numeric encodings.

The following are some important methods we will be using.

# Look at the string label names
metadata.features['label'].names
# Convert a string label to its corresponding integer encoding
metadata.features['label'].str2int(<label as string>)
# Convert an integer encoding back to its corresponding string label
metadata.features['label'].int2str(<encoded integer>)

Here is a code snippet that demonstrates what the above methods do:
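A minimal sketch of such a snippet (the class names and integer encodings printed come from the tf_flowers dataset itself, so the exact ordering may differ from what you expect):

# Print all class names, then round-trip one label through the encoders
print(metadata.features['label'].names)          # list of the 5 flower class names
print(metadata.features['label'].str2int('roses'))  # integer encoding of 'roses'
print(metadata.features['label'].int2str(0))        # class name for encoding 0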

From this, we can conclude that we have five classes of flower images on our hands.

Next, let's take a look at the shapes of the images and the number of images in each set.
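The snippet for this inspection is in the notebook; a rough sketch of the idea, assuming the splits from the loading step above:

# Number of elements in each split
print(tf.data.experimental.cardinality(train_ds).numpy())
print(tf.data.experimental.cardinality(val_ds).numpy())
print(tf.data.experimental.cardinality(test_ds).numpy())

# Shapes of a few raw images (heights and widths vary from image to image)
for image, label in train_ds.take(3):
    print(image.shape)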

So we see that we have
- 92 images for training data
- 12 images for validation data
- 12 images for testing data

We also observe that the images do not have a standard size. We will take care of this in the next step, where we preprocess our data.

3. Preprocess the images and visualize them.

In this step, we preprocess our images so that they are all the same size. A convolutional neural network requires that every image passed to it during training/testing have the same dimensions, and neural networks also work best when the input data is normalized, so we will rescale our data as well. Since the images are not of any standard size, we will use TensorFlow and Keras image preprocessing (see the sketch after the list below) to:

1. Resize each image.
2. Rescale each image (pixel values between 0–1)
3. Additionally, we will also batch our dataset.
4. We can also shuffle the data if required.
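Here is a minimal sketch of such a pipeline. The target image size and batch size below are assumptions for illustration; the notebook may use different values:

IMG_SIZE = 180    # assumed target size
BATCH_SIZE = 32   # assumed batch size

def preprocess(image, label):
    # Resize to a common size and rescale pixel values to [0, 1]
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    image = tf.cast(image, tf.float32) / 255.0
    return image, label

train_ds = train_ds.map(preprocess).shuffle(500).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
val_ds = val_ds.map(preprocess).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
test_ds = test_ds.map(preprocess).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)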

Finally, we have a preprocessed dataset; let us visualize what our images look like. In the code below, we use the metadata object from the previous section to label the image.

import matplotlib.pyplot as plt

plt.figure(figsize=(5, 5))
for images, labels in train_ds.take(1):
    # images is a batch; show the first image with its string label as the title
    plt.imshow(images[0])
    plt.title(metadata.features['label'].int2str(labels[0]))
    plt.axis("off")

4. Augment the images and visualize

Augmentation of data in deep learning has time and again proved to be a powerful tool for improving model performance and generalization, so we will use the simple data augmentation layers that Keras offers.
For our task, we randomly flip the images horizontally and randomly zoom into them.

# here we are randomly horizontally flipping the image / randomly zooming into the image
data_aug = tf.keras.Sequential([
    tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal"),
    tf.keras.layers.experimental.preprocessing.RandomZoom(0.5),
])

Let's take a look at a few augmentations applied to an image, using the augmentation pipeline defined above.
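A small sketch of how one might generate and plot a few augmented versions of a single training image (passing training=True forces the random layers to apply their transformations outside of model training):

plt.figure(figsize=(8, 8))
for images, labels in train_ds.take(1):
    for i in range(4):
        augmented = data_aug(images, training=True)  # apply the random flip/zoom
        plt.subplot(2, 2, i + 1)
        plt.imshow(augmented[0])
        plt.axis("off")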

You can see that the image has been flipped and zoomed in the examples above. If you want to try more augmentations, you can find them in the documentation, stack them up, and run your own experiments.

5. Build 2 models to experiment (one with the vanilla dataset and the other with an augmented dataset)

We will use a simple CNN model defined in Keras to run our experiments. We will build 2 models, one which will be trained using the vanilla dataset and the other model with the augmented dataset to see if the augmentation helps the model to generalize better.
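The full model code is in the linked notebook; the sketch below is one plausible version of such a CNN, with an augment flag that prepends the data_aug pipeline from the previous section. The exact layer sizes and training settings here are assumptions:

def build_model(augment=False):
    layers = [tf.keras.layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3))]
    if augment:
        layers.append(data_aug)  # random layers are only active during training
    layers += [
        tf.keras.layers.Conv2D(32, 3, activation='relu'),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D(64, 3, activation='relu'),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(5),  # five flower classes, logits output
    ]
    model = tf.keras.Sequential(layers)
    model.compile(
        optimizer='adam',
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
    )
    return model

vanilla_model = build_model(augment=False)
augmented_model = build_model(augment=True)

# Train both on the same data and compare validation/test accuracy
vanilla_model.fit(train_ds, validation_data=val_ds, epochs=10)
augmented_model.fit(train_ds, validation_data=val_ds, epochs=10)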

6. Results of our experiment

We can see that the model trained with the augmented dataset performed better than the model trained on the vanilla dataset: accuracy jumps from 70% to 75%. One can improve the model further by stacking more augmentation layers or increasing the number of epochs.

We have seen how TensorFlow Datasets provides an API for loading and managing data, and how, together with the tf.data API, it helps us accelerate our experiments and maintain clean code.

And that concludes this TensorFlow Datasets series. In this three-part series, we started with the tf.data API and looked at how its methods help us define and chain data transformations for our deep learning experiments. Next, in part 2, we looked at TensorFlow Datasets, which offers a simple yet powerful API for downloading and preparing datasets for deep learning experiments. In this finale, we used everything we learned in the previous two articles to build a deep learning pipeline and experiment on the flowers dataset. We saw how the tf.data API and TFDS make the coding easier.

Click on the following links for the code presented:

  1. Github

OR

2. Google-Colab

That's it for this week, I hope you had a good read. I will be back next week with another exciting article; until then, keep learning, keep building, and participate in #27DaysOfKeras.

Please clap if the content was helpful to you. You can reach me on the following platforms if you have any questions, or mention your questions in the comments below.

Linkedin:- https://www.linkedin.com/in/virajdatt-kohir/
Twitter:- https://twitter.com/kvirajdatt
GitHub:- https://github.com/Virajdatt
GoodReads:- https://www.goodreads.com/user/show/114768501-virajdatt-kohir
