Deep learning classifiers have a known tendency to misclassify with high confidence. For this reason, detecting test-time samples whose class is not in the training set is a crucial step for model deployment. This can be considered a binary classification problem, which is often tackled with metric learning: after training a distance under which samples from known classes are lumped together, a threshold is chosen which sets the boundary for samples from unseen classes. But this idea relies on two assumptions: that classes are indeed lumped into convex (and simply connected) regions, and that these regions are far apart from each other.
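To make the generic recipe concrete, here is a minimal sketch (our illustration, not code from the paper) of the thresholding step: a sample is flagged as OOD if it is too far from every known class in the learned metric. The `centroids` array and the `threshold` value are assumed to be given.

```python
import numpy as np

def is_ood(embedding: np.ndarray, centroids: np.ndarray, threshold: float) -> bool:
    """Flag a sample as OOD when its embedding is farther than `threshold`
    from every class centroid in the learned metric space."""
    distances = np.linalg.norm(centroids - embedding, axis=1)
    return bool(distances.min() > threshold)
```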
Because both assumptions are typically violated in practice, [Din22V], presented this week at UAI 2022, proposes a method which tries to enforce them. In the new metric space, the samples of each class follow a distinct, well-separated isotropic Gaussian distribution, as seen in the figure.
Crucially, the method does not require OOD data for training, yet achieves SOTA even when compared to some techniques which do (the authors do not, however, compare against most methods that use OOD data for training, since this is a different regime).
This is achieved with a composite loss comprising: a triplet loss to bring similar samples together and push different ones apart, a KL term pursuing class-wise Gaussianity in the representation space, and a distancing term to improve inter-class separation beyond that provided by the triplet loss.
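The paper's exact formulation is not reproduced here, but the following PyTorch sketch shows how such a composite objective might be assembled. The KL term uses the closed form for the divergence between a diagonal Gaussian and an isotropic unit-variance one with the same mean; all margins, weights, and the hinge-based distancing term are illustrative assumptions, not the paper's choices.

```python
import torch
import torch.nn.functional as F

def composite_loss(anchor, positive, negative, emb, labels,
                   triplet_margin=1.0, sep_margin=5.0,
                   kl_weight=1.0, sep_weight=1.0):
    # 1. Triplet term: pre-mined (anchor, positive, negative) embeddings
    #    pull same-class pairs together and push other classes apart.
    l_triplet = F.triplet_margin_loss(anchor, positive, negative,
                                      margin=triplet_margin)

    # 2. Gaussianity term: per class, KL(N(mu, diag(var)) || N(mu, I))
    #    reduces to 0.5 * sum(var - log(var) - 1).
    classes = labels.unique()
    l_kl = emb.new_zeros(())
    for c in classes:
        var = emb[labels == c].var(dim=0, unbiased=False).clamp_min(1e-6)
        l_kl = l_kl + 0.5 * (var - var.log() - 1.0).sum()
    l_kl = l_kl / len(classes)

    # 3. Distancing term: hinge penalty whenever two class means are
    #    closer than sep_margin (assumes >= 2 classes per batch).
    means = torch.stack([emb[labels == c].mean(dim=0) for c in classes])
    dists = torch.cdist(means, means)
    off_diag = ~torch.eye(len(classes), dtype=torch.bool, device=emb.device)
    l_sep = F.relu(sep_margin - dists[off_diag]).mean()

    return l_triplet + kl_weight * l_kl + sep_weight * l_sep
```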
After training, a forward pass over all the data is done, and the activations at three intermediate layers of the network, plus the final ones, are collected for all samples in each class. Then one anisotropic Gaussian per class is fitted to these features. At test time, the negative log-likelihood (NLL) of a sample is computed w.r.t. each Gaussian, and a threshold is used to determine whether the sample is OOD. Training an additional VAE and using the distances from its intermediate representations to those of the original network to weight the likelihoods further decreases the false positive rate by reducing in-class NLL.
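A minimal sketch of the fitting and scoring stage, assuming the multi-layer activations have already been concatenated into a single feature vector per sample; the small ridge `reg` keeping the covariances invertible is our addition, and the VAE reweighting step is omitted.

```python
import numpy as np
from scipy.stats import multivariate_normal

def fit_class_gaussians(features, labels, reg=1e-4):
    """Fit one anisotropic (full-covariance) Gaussian per class."""
    gaussians = {}
    for c in np.unique(labels):
        x = features[labels == c]
        cov = np.cov(x, rowvar=False) + reg * np.eye(x.shape[1])
        gaussians[c] = multivariate_normal(mean=x.mean(axis=0), cov=cov)
    return gaussians

def ood_score(x, gaussians):
    """NLL under the best-fitting class Gaussian; higher means more OOD."""
    return min(-g.logpdf(x) for g in gaussians.values())
```

A sample would then be flagged as OOD when `ood_score(x, gaussians)` exceeds a threshold chosen, e.g., on held-in validation data.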
This is a promising method for applications in which OOD data is not readily available during model development.