Neural networks are very “data-hungry” algorithms: in the quest to reach ever higher accuracies, engineers around the world must feed them ever larger amounts of data, with rapidly diminishing returns. For example, a 1% drop in error for a large language model can require 10 times more data (keeping model size constant). This type of relationship, in which the test error falls only as a small power of the dataset size, so that each further improvement demands a multiplicatively larger dataset, is known as power law scaling, and it could seriously strain our ambitions to develop ever larger and smarter models.
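To get a feel for the numbers, here is a back-of-the-envelope sketch of what a power law implies; the exponent alpha = 0.1 is a hypothetical value chosen only for illustration, not a figure from the paper.

```python
# Power law scaling sketch: test error E(N) = a * N**(-alpha).
# alpha = 0.1 is a hypothetical exponent, used only for illustration.
alpha = 0.1

def data_multiplier(error_now: float, error_target: float, alpha: float) -> float:
    """How much larger the dataset must be to go from error_now to error_target."""
    return (error_now / error_target) ** (1.0 / alpha)

# Going from 5% to 4% test error under this power law needs ~9x more data.
print(f"{data_multiplier(0.05, 0.04, alpha):.1f}x more data")
```

The smaller the exponent, the harsher the tax: with alpha = 0.05 the same one-point improvement would require roughly 87 times more data.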
To mitigate this issue, researchers have in recent years found ways to prune their datasets efficiently, retaining only the most valuable examples and thereby reducing training cost and time. However, much remains to be understood, and so far data pruning has not seen widespread adoption.
One of the outstanding papers of NeurIPS 2022, [Sor22N], takes a substantial step forward in both the theoretical understanding and the practical benchmarking of data pruning techniques.
In the first part, the paper uses a student-teacher model, a classical setup from the statistical mechanics of learning, to study theoretically how a neural network’s test error depends on the amount of training data. After a few epochs of training on the whole dataset, the data points are sorted by a “difficulty metric” (e.g. the discrepancy between the true label and the model’s prediction); training is then resumed with only a fraction of the data, keeping either the examples with the lowest difficulty or those with the highest. In the second part of the paper, a similar approach is pursued with ResNets on the CIFAR-10 and ImageNet datasets, using several different difficulty metrics.
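To make the procedure concrete, here is a minimal, self-contained sketch of the prune-then-retrain loop in a toy student-teacher setting. This is my own illustration, not the authors’ code: the logistic “student”, the margin-based difficulty score, and the keep_fraction / keep_hard choices are assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy student-teacher setup: the teacher is a random hyperplane, the student
# a logistic probe trained by full-batch gradient descent (illustrative only).
d, n = 50, 10_000
teacher = rng.standard_normal(d)
X, X_test = rng.standard_normal((n, d)), rng.standard_normal((2_000, d))
y, y_test = np.sign(X @ teacher), np.sign(X_test @ teacher)

def train(w, X, y, epochs, lr=0.1):
    """A few epochs of full-batch gradient descent on the logistic loss."""
    for _ in range(epochs):
        margins = y * (X @ w)
        grad = -(X * (y / (1.0 + np.exp(margins)))[:, None]).mean(axis=0)
        w = w - lr * grad
    return w

# 1) Short initial training run on the full dataset.
w = train(np.zeros(d), X, y, epochs=5)

# 2) Score every example with a difficulty metric; here, the (absolute) margin
#    of the partially trained student -- a small margin counts as "hard".
difficulty = -np.abs(X @ w)

# 3) Keep only a fraction of the data (the hardest or the easiest examples)
#    and resume training on the pruned subset.
keep_fraction, keep_hard = 0.2, True   # illustrative choices
order = np.argsort(difficulty)         # easiest first
k = int(keep_fraction * n)
kept = order[-k:] if keep_hard else order[:k]
w = train(w, X[kept], y[kept], epochs=50)

print("test accuracy:", np.mean(np.sign(X_test @ w) == y_test))
```

The paper evaluates a whole family of such difficulty scores, but the loop stays the same: a cheap preliminary training run produces per-example scores, and only the selected subset is used for the expensive full training.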
The first surprising result is that the optimal pruning strategy changes depending on the total amount of data (see Figure 1). If the initial training set is big enough, it pays to retain only the hard samples, but if the dataset is small, keeping the easiest ones works best! The intuition is that when data is abundant and overfitting is less of a concern, the hard samples carry the most new information, since the simple (and typically most common) cases are already amply covered; when data is scarce, there is no hope of learning the outliers, and it is better to focus on the basic cases.
The second important result is that it is possible to break the power law scaling. Figure 2 shows how the test error of the model drops as a function of the number of training samples (per parameter; the model architecture is the same in each plot). The error falls increasingly rapidly as the pruning becomes more aggressive. This means that efficient data pruning makes it possible to reach a fixed test error with exponentially less training data (compare the Pareto frontier with the 100% line).
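Schematically, and with placeholder constants rather than values from the paper, the contrast is between the usual power-law decay of the test error in the dataset size and the roughly exponential decay in the pruned dataset size that the theory suggests is achievable with a large initial pool and a well-chosen difficulty metric:

$$
E_{\text{no pruning}}(N) \sim c\, N^{-\nu}
\qquad\text{vs.}\qquad
E_{\text{optimal pruning}}(N_{\text{kept}}) \lesssim c'\, e^{-b\, N_{\text{kept}}},
$$

where $N$ is the full dataset size, $N_{\text{kept}}$ the size of the retained subset, and $c$, $c'$, $\nu$, $b$ are placeholder constants.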
Nevertheless, the authors also stress that the choice of the difficulty metric is of key importance. Figure 3 shows how the accuracy on ImageNet degrades when pruning with different metrics; notice that some of them perform worse than random pruning! Overall, the authors report that the metric that performed best in their experiments was memorization, which we discussed in a recent paper pill.
What I find most interesting about this paper is that a relatively simple theoretical model has given us so much insight into important, real-world datasets. This could lead to massive savings in both time and energy when training large neural network models, but before data pruning sees widespread adoption a few practical issues need to be addressed, namely finding a stable difficulty metric and speeding up its computation.