A concrete problem with batch normalization on small datasets

Embedding images and executable files

Embedding images into a lower-dimensional representation is a thriving research field in deep learning. Once each image is represented by a small vector, many downstream tasks, such as zero-shot learning and image similarity, become easy to perform.

For example, in the image similarity task, the embeddings of similar images are close to each other. Finding the images most similar to a given one therefore only requires computing distances between their embeddings.


In this figure, we assume we have a set of images on which we ran the neural network to compute their embeddings. Then, when we have a new image and want to retrieve the most similar one in our set, we simply run the neural network on it to get its embedding and compute the distance between this vector representation and the embeddings previously computed!
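The retrieval step described above can be sketched in a few lines of plain Python. This is a minimal illustration, not our actual pipeline. The embeddings are hypothetical hard-coded vectors standing in for the network's outputs, and the helper names `euclidean_distance` and `most_similar` are made up for this example. Euclidean distance is also an assumption; cosine distance is another common choice.

```python
import math


def euclidean_distance(a, b):
    """Euclidean distance between two embedding vectors of equal length."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def most_similar(query, embeddings):
    """Return the name of the stored embedding closest to `query`."""
    return min(embeddings, key=lambda name: euclidean_distance(query, embeddings[name]))


# Hypothetical precomputed embeddings (in practice, outputs of the network).
stored = {
    "cat_1.png": [0.9, 0.1, 0.0],
    "dog_1.png": [0.1, 0.8, 0.2],
    "car_1.png": [0.0, 0.1, 0.95],
}

# Hypothetical embedding of the new image, computed by the same network.
query_embedding = [0.85, 0.15, 0.05]
print(most_similar(query_embedding, stored))  # cat_1.png
```

With a large collection, the linear scan in `most_similar` would typically be replaced by an approximate nearest-neighbor index. The principle stays the same.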

We recently did a lot of research on embedding executable files with Convolutional Neural Networks (CNNs) to identify similarities between them. During this work, we stumbled upon surprising gaps between the predictions made during training and those made at inference.

To start off our research, we simplified the task at hand as much as possible. We began with a small dataset: files from only three different families, 5000 files in total. With so little data, training was very fast.

Abnormal predictions when testing the model

Unsurprisingly, the predictions made at inference on the test dataset were erroneous. This was expected, because with so little data the model overfitted on the training dataset. An overfitted model has learned the noise specific to the training dataset, which makes it unable to generalize well to unseen data.

However, when we ran inference on the very files seen during training, we also got strongly inaccurate predictions. On exactly the same files, the model's predictions during training and at inference were very different!

This is typically a sign that the data are handled or preprocessed differently. If the model receives differently processed data in training and in inference, it will output different predictions. However, we were sure the preprocessing pipeline was exactly the same.

After further investigation, we realized that in the frameworks commonly used for deep learning (TensorFlow and PyTorch), batch normalization layers in CNNs behave differently in training and in inference. This raises the following questions: Why does batch normalization behave differently in training and in inference? Why is it used? How can we solve our issue?

Dive into Batch Normalization

What is Batch Normalization?

Batch Normalization is often used to improve the performance and stability of neural networks, and it’s a great asset to have in your toolbox. Let’s take a closer look at what batch normalization is, how it works, and why it’s so useful.

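So, what is batch normalization? In short, it is a technique that normalizes the input of a layer in a neural network. For each batch, it subtracts the mean and divides by the standard deviation of that input, which here is the output of the previous convolutional layer. In a CNN, these statistics are computed separately for each channel. A learnable scale and shift are then applied to the normalized values, so the layer does not lose expressive power.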
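The following is a rough sketch of the training-mode computation, with some assumptions. It uses flat feature vectors, as in a fully connected layer, rather than convolutional feature maps. It also adds a small constant `eps` (1e-5 here, an assumed value) to the variance for numerical stability. The function name `batch_norm_train` is hypothetical, and this is not how TensorFlow or PyTorch implement the layer internally.

```python
import math


def batch_norm_train(batch, gamma, beta, eps=1e-5):
    """Normalize each feature across the batch using the batch's own statistics.

    batch: list of samples, each a list of feature values (shape N x C).
    gamma, beta: learnable per-feature scale and shift (length C).
    Returns the normalized batch plus the per-feature mean and variance.
    """
    n = len(batch)
    num_features = len(batch[0])
    # Per-feature mean over the batch dimension.
    means = [sum(sample[c] for sample in batch) / n for c in range(num_features)]
    # Per-feature (biased) variance over the batch dimension.
    variances = [
        sum((sample[c] - means[c]) ** 2 for sample in batch) / n
        for c in range(num_features)
    ]
    # Normalize, then apply the learnable scale (gamma) and shift (beta).
    normalized = [
        [
            gamma[c] * (sample[c] - means[c]) / math.sqrt(variances[c] + eps) + beta[c]
            for c in range(num_features)
        ]
        for sample in batch
    ]
    return normalized, means, variances


batch = [[1.0, 10.0], [3.0, 20.0], [5.0, 30.0]]
out, mu, var = batch_norm_train(batch, gamma=[1.0, 1.0], beta=[0.0, 0.0])
print(mu)  # [3.0, 20.0]
```

With `gamma` set to 1 and `beta` set to 0, each output column has a mean of about 0 and a variance of about 1, whatever the scale of the original feature.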

One of the biggest benefits of batch normalization is that it enables much higher learning rates during training and makes it possible to train more complex architectures. Normalizing the inputs keeps them from becoming too large or too small. This directly helps prevent the exploding and vanishing gradient issues often encountered with high learning rates and complex architectures, as detailed in the original paper.

Another benefit of batch normalization is that it can make deep learning models more robust to variations in the data, which can lead to better generalization. Finally, batch normalization has a slight regularizing effect. Each example is normalized with statistics that depend on the other examples in its batch, which adds some noise during training and can help reduce overfitting.

Overall, batch normalization is widely used in deep learning and can improve the performance and stability of neural networks. It has some drawbacks (a less interpretable training process, computational overhead …), but the benefits often outweigh them, and it is definitely worth trying.

Why does Batch Normalization behave differently in train and inference?

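Inference is usually done on a single element rather than on a batch of data. In this context, computing a mean and variance over the input makes no sense, so batch normalization behaves differently at inference than during training. During training, the layer maintains a moving mean and a moving variance of the batch statistics. These saved values are then used at inference in place of the statistics of the single element. They are updated after each batch as an exponential moving average controlled by a predefined momentum between 0 and 1. In the convention used by Keras (TensorFlow), moving_mean = momentum * moving_mean + (1 - momentum) * batch_mean, and likewise for the variance. PyTorch defines momentum the other way round, as the weight given to the new batch statistic.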

At inference, the moving mean and moving variance are frozen at their end-of-training values and are used instead of batch statistics to normalize the layer's input.
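The following minimal single-feature sketch in plain Python makes the difference between the two modes concrete. It is not the TensorFlow or PyTorch implementation. The class `SimpleBatchNorm` is hypothetical and omits the learnable scale and shift. It assumes the Keras-style momentum convention with Keras's default values (momentum 0.99, epsilon 1e-3), and it uses the biased batch variance throughout; real frameworks differ in such details.

```python
import math


class SimpleBatchNorm:
    """Hypothetical single-feature batch norm with training and inference modes.

    Keras-style update: moving = momentum * moving + (1 - momentum) * batch_stat.
    The learnable scale and shift are omitted to keep the sketch short.
    """

    def __init__(self, momentum=0.99, eps=1e-3):
        self.momentum = momentum
        self.eps = eps
        self.moving_mean = 0.0  # initial value of the moving mean
        self.moving_var = 1.0   # initial value of the moving variance

    def __call__(self, values, training):
        if training:
            # Training: normalize with the statistics of the current batch...
            n = len(values)
            mean = sum(values) / n
            var = sum((v - mean) ** 2 for v in values) / n
            # ...and update the moving statistics that inference will use.
            self.moving_mean = self.momentum * self.moving_mean + (1 - self.momentum) * mean
            self.moving_var = self.momentum * self.moving_var + (1 - self.momentum) * var
        else:
            # Inference: ignore the batch and use the frozen moving statistics.
            mean, var = self.moving_mean, self.moving_var
        return [(v - mean) / math.sqrt(var + self.eps) for v in values]


bn = SimpleBatchNorm()
batch = [4.0, 6.0]  # batch mean 5, batch variance 1
train_out = bn(batch, training=True)
infer_out = bn(batch, training=False)
print([round(v, 2) for v in train_out])  # [-1.0, 1.0]
print([round(v, 2) for v in infer_out])  # [3.95, 5.95]
```

After a single training step, the moving mean has only moved from 0 to 0.05. The same two values are therefore normalized very differently in the two modes, which is exactly the kind of gap between training and inference predictions described above.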

Why did our issue arise?

So, why did our model, after quickly overfitting on very few samples, make completely different predictions on the same samples?

Since our dataset is very small and the model overfits very quickly, only a few batches were run in total. As a result, the moving mean and moving variance had not converged by the end of training and were still very close to their initial values (0 for the mean and 1 for the variance). At inference, the layers therefore normalized the activations with statistics that did not match the data, which completely distorted the model's predictions.
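The effect of too few batches can be quantified. Suppose every batch had the same mean μ and the moving mean started at 0. After n updates with the Keras-style rule, the moving mean would be μ(1 - momentum^n). The hypothetical helper below simulates this, assuming constant, made-up batch statistics (mean 5, variance 4) and a momentum of 0.99.

```python
def moving_stats_after(num_batches, batch_mean, batch_var, momentum=0.99):
    """Apply the Keras-style moving-average update `num_batches` times.

    Assumes (hypothetically) that every batch has the same mean and variance.
    """
    moving_mean, moving_var = 0.0, 1.0  # usual initial values
    for _ in range(num_batches):
        moving_mean = momentum * moving_mean + (1 - momentum) * batch_mean
        moving_var = momentum * moving_var + (1 - momentum) * batch_var
    return moving_mean, moving_var


for n in (10, 100, 1000):
    mean, var = moving_stats_after(n, batch_mean=5.0, batch_var=4.0)
    print(f"{n} batches: moving_mean={mean:.3f}, moving_var={var:.3f}")
# 10 batches: moving_mean=0.478, moving_var=1.287
# 100 batches: moving_mean=3.170, moving_var=2.902
# 1000 batches: moving_mean=5.000, moving_var=4.000
```

After 10 batches, the moving mean is still below 0.5 instead of 5. The layer would therefore normalize inference inputs with badly wrong statistics. After 1000 batches, the moving statistics have essentially converged.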

Conclusion

The easiest solution is to simply remove batch normalization. In our case, on the small dataset, this lets us iterate quickly and improve our solution. Once the first prototype of the neural architecture and methodology is validated, we can scale it up and apply it to a much larger dataset. Many batches will then be run, giving the moving statistics enough iterations to converge, so we can add batch normalization back. It shows its efficiency in this scenario:

Accuracy scores over training epochs. Classification is done on embedded images. Blue: with batch normalization; pink: without.

As in many other applications, batch normalization proved its efficiency in our use case as well. Hopefully, this article has helped you better understand this deep learning technique!

To delve deeper into batch normalization, we are considering experimenting with a new method, 'average' batch normalization, instead of the classical 'moving average' method discussed here. Charles Gaillard and Rémy Brossard discuss this new method in their article. We'll keep you posted on the results!