Faster AutoAugment: Learning Augmentation Strategies using Backpropagation

Automatic data augmentation is nowadays standard, but searching for policies can be very slow. A new search strategy greatly speeds up the selection of augmentation policies.

If you work with image classifiers, chances are you have augmented your data. Hand-crafted pipelines were the norm until automated search over augmentation choices took the field by storm around 2018 (for those with the computational resources for it). [Cub19A], perhaps the most popular of these methods, discretises the whole search space of image transformations, including their parameters, and uses policy gradient to select “augmentation policies” based on the validation score of the target neural network over an image batch.
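To make the discretised search space concrete, here is a toy sketch of an AutoAugment-style policy space. The operation names and bin counts are illustrative (AutoAugment uses roughly this shape, with sub-policies made of pairs of transformations), not an exact reproduction of the paper's configuration:

```python
# Toy discretised search space, AutoAugment-style (names and bin counts illustrative).
OPS = ["ShearX", "Rotate", "Solarize", "AutoContrast"]
PROB_BINS = [i / 10 for i in range(11)]   # application probability, 11 bins
MAG_BINS = list(range(10))                # transformation magnitude, 10 bins

# A sub-policy is a sequence of (operation, probability, magnitude) triples;
# the controller searches over these discrete choices with policy gradient.
subpolicy = [("Rotate", 0.7, 5), ("Solarize", 0.4, 2)]
```

Even this tiny space has thousands of sub-policies, which is why black-box search over it is expensive.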

Alas, policy gradient (and other black-box approaches like Bayesian optimisation) can be extremely costly. Inspired by ideas from differentiable neural architecture search, [Hat20F] offers an alternative which can be significantly faster while achieving similar performance.

The key idea is to replace the search over, and sampling of, categorical variables with differentiable operations during training, falling back to discrete operations at inference time. First, sampling a Bernoulli variable to decide whether to apply a transformation is replaced by sampling from a relaxed continuous distribution, building on [Jan17C]. Second, they provide approximations to the gradients of several non-differentiable image transformations and re-implement the transformations in pytorch (as opposed to using PIL, as AutoAugment does).
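The relaxation for the binary "apply or not" decision can be sketched in a few lines. This is the binary case of the Gumbel-softmax / Concrete trick from [Jan17C]: add logistic noise to the logit of the application probability and squash through a temperature-controlled sigmoid, yielding a sample in (0, 1) that is differentiable with respect to the probability. The function name and temperature value are illustrative, not taken from the paper's code:

```python
import math
import random

def relaxed_bernoulli(prob, temperature=0.3):
    """Differentiable relaxation of a Bernoulli(prob) sample.

    prob must lie strictly in (0, 1). As temperature -> 0 the samples
    concentrate near {0, 1}, recovering the hard Bernoulli decision.
    """
    u = random.random()
    logistic_noise = math.log(u) - math.log(1.0 - u)   # sample from Logistic(0, 1)
    logit = math.log(prob) - math.log(1.0 - prob)
    # Sigmoid of the perturbed logit, sharpened by the temperature.
    return 1.0 / (1.0 + math.exp(-(logit + logistic_noise) / temperature))
```

In the actual method this sample gates the transformed image (a convex combination of original and transformed pixels during training), so gradients can flow back into the policy parameters; at inference time one simply thresholds or samples the hard Bernoulli.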

Finally, they do not directly optimise final model performance; instead they minimise the distance between the distribution of augmented images and the original data distribution, to encourage “filling the holes in the dataset”. A classification loss is added as well, to avoid transformations that flip labels.

There are of course lots of details under the hood, like Wasserstein GANs, network architecture choices and DARTS. Follow the links in the bibliography to read the full paper or experiment with the code.