Symmetry Teleportation for Accelerated Optimization

Symmetry teleportation is a novel approach that speeds up convergence in gradient-based optimization by exploiting symmetries in the loss landscape to move parameters large distances along a level set of the loss.

Symmetry teleportation is an optimization technique that leverages symmetries in the loss landscape to accelerate convergence of gradient-based algorithms such as stochastic gradient descent (SGD) and AdaGrad. It does so by allowing parameters to traverse large distances on a level set of the loss. More precisely, if $G$ is a symmetry group of the loss function, meaning that

\begin{equation*} \mathcal{L} (g \cdot (w, X)) = \mathcal{L} (w, X), \quad \forall g \in G \end{equation*}

then the teleportation operation moves the current parameter configuration $w_t$ along its orbit under the action of $G$. Usually, teleportation is used to maximize the gradient norm within the orbit, i.e.,

\begin{align*} g_t &\leftarrow \operatorname{argmax}_{g \in G} \| (\nabla \mathcal{L})|_{g \cdot w_t} \|^2, \\ w_t &\leftarrow g_t \cdot w_t \end{align*}
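As a minimal sketch of this update, consider a hypothetical toy loss $\mathcal{L}(w_1, w_2) = (w_1 w_2 - 1)^2$, which is invariant under the one-parameter scaling group $g_a \cdot (w_1, w_2) = (a w_1, w_2 / a)$ for $a > 0$. The argmax over $G$ is approximated here by a simple grid search over the group parameter; this is an illustration, not the authors' implementation:

```python
import numpy as np

# Toy loss L(w1, w2) = (w1 * w2 - 1)^2, invariant under the scaling
# group action g_a . (w1, w2) = (a * w1, w2 / a) for a > 0.
def loss(w):
    return (w[0] * w[1] - 1.0) ** 2

def grad(w):
    r = 2.0 * (w[0] * w[1] - 1.0)
    return np.array([r * w[1], r * w[0]])

def teleport(w, scales):
    """Pick the group element (scale a) maximizing the gradient norm."""
    candidates = [np.array([a * w[0], w[1] / a]) for a in scales]
    return max(candidates, key=lambda v: np.dot(grad(v), grad(v)))

w = np.array([10.0, 0.5])                  # loss = (10 * 0.5 - 1)^2 = 16
scales = np.exp(np.linspace(-3, 3, 201))   # grid over the 1-parameter group
w_tel = teleport(w, scales)

assert np.isclose(loss(w), loss(w_tel))    # loss is unchanged on the orbit
assert np.dot(grad(w_tel), grad(w_tel)) >= np.dot(grad(w), grad(w))
```

After the teleport, a standard gradient step is taken from `w_tel` instead of `w`; the loss value is identical, but the gradient is larger.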

A simple example of a symmetry group is the positive scale invariance of fully connected ReLU networks. In this case, one can scale the incoming weights of a hidden layer by a positive factor and the outgoing weights by the inverse of that factor without changing the network output.
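This invariance is easy to verify numerically. The following sketch builds a small two-layer ReLU network with hypothetical random weights and checks that scaling the incoming weights by a positive factor and the outgoing weights by its inverse leaves the output unchanged:

```python
import numpy as np

rng = np.random.default_rng(0)

# Two-layer ReLU network f(x) = W2 @ relu(W1 @ x) with arbitrary weights.
W1 = rng.standard_normal((4, 3))
W2 = rng.standard_normal((2, 4))
x = rng.standard_normal(3)

def forward(W1, W2, x):
    return W2 @ np.maximum(W1 @ x, 0.0)

a = 7.5  # any positive factor
out = forward(W1, W2, x)
# Scale incoming weights by a, outgoing weights by 1/a; since
# relu(a * z) = a * relu(z) for a > 0, the factors cancel.
out_scaled = forward(a * W1, W2 / a, x)

assert np.allclose(out, out_scaled)  # network output is unchanged
```

Note that while the output (and hence the loss) is invariant, the gradients with respect to `W1` and `W2` are rescaled, which is exactly what teleportation exploits.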

Figure 1: Illustration of a teleport operation along the orbit of a symmetry group action. While the loss remains unchanged, the gradient norm of the loss function varies along the orbit. See Figure 1 of [Zha24I].

# Convergence Analysis

The key to understanding the improved convergence is the observation that, for a quadratic loss function, if $w$ has maximal gradient norm within a level set of the loss, then the gradient at $w$ is an eigenvector of the Hessian. Hence, if the symmetry group $G$ acts transitively on the level set, teleportation combined with a first-order optimization step is equivalent to a Newton step.

Moreover, in the case of a quadratic function, the group $O(n)$, where $n$ is the number of network parameters, is a symmetry group of the loss function that acts transitively on the level sets. Since the loss function is well approximated by a quadratic close to a minimum, teleportation can be seen as a way to approximate a Newton step in the presence of symmetries. For more details, see Section 5 of [Zha23S].
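The argument can be made concrete for $\mathcal{L}(w) = \frac{1}{2} w^\top A w$ with symmetric positive definite $A$: on the level set $\{w : \frac{1}{2} w^\top A w = c\}$, the gradient norm $\|Aw\|$ is maximized along the top eigenvector of $A$, where $\nabla \mathcal{L}$ is an eigenvector of the Hessian, and a single gradient step with step size $1/\lambda_{\max}$ then reaches the minimum, just like a Newton step. A numerical sketch of this (with a hypothetical random $A$, not taken from the paper):

```python
import numpy as np

rng = np.random.default_rng(1)

# Quadratic loss L(w) = 0.5 * w @ A @ w with SPD Hessian A; minimum at w = 0.
M = rng.standard_normal((5, 5))
A = M @ M.T + 5 * np.eye(5)

def loss(w): return 0.5 * w @ A @ w
def grad(w): return A @ w

w0 = rng.standard_normal(5)
c = loss(w0)  # the level set we teleport within

# On {w : 0.5 w^T A w = c}, the gradient norm ||A w|| is maximized along
# the eigenvector of A with the largest eigenvalue.
eigvals, eigvecs = np.linalg.eigh(A)      # ascending eigenvalue order
lam_max, v_max = eigvals[-1], eigvecs[:, -1]
w_tel = np.sqrt(2 * c / lam_max) * v_max  # teleported point, same loss

assert np.isclose(loss(w_tel), c)                              # still on the level set
assert np.linalg.norm(grad(w_tel)) >= np.linalg.norm(grad(w0)) # larger gradient

# At w_tel the gradient A @ w_tel = lam_max * w_tel is an eigenvector of
# the Hessian, so one gradient step with rate 1/lam_max coincides with the
# Newton step w - inv(A) @ grad(w) and jumps straight to the minimum.
w_next = w_tel - (1.0 / lam_max) * grad(w_tel)
assert np.allclose(w_next, np.zeros(5))
```

From a generic point $w_0$ on the same level set, a single gradient step would not reach the minimum; the teleport is what makes the first-order step Newton-like.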

# Influence on Generalization

A potentially harmful effect of teleportation is that it can lead to early overfitting in non-convex settings. By teleporting the parameters to a point with a high gradient norm, the optimization algorithm might move the optimization trajectory to a sharper minimum. Since flat minima tend to generalize better, this can result in a decrease in generalization performance. On the other hand, using teleportation to minimize the gradient norm can lead to a flatter minimum and improve generalization, which is further studied in the follow-up paper [Zha24I].

# Experiment

The authors conduct a series of experiments to showcase the improved convergence properties of teleportation. For this, they use a three-layer feed-forward network with a regression loss. The corresponding symmetries used for teleportation are described in Prop. 4.3 of [Zha23S].

Figure 2 compares the training loss and gradient norms of SGD and AdaGrad with and without teleportation. Both algorithms show a faster decay of the training loss when teleportation is added. Moreover, at the same loss value, the teleportation variants produce gradients with larger norms, which explains the improved convergence.

In a second step, the authors investigate the behavior of teleportation on the MNIST classification problem, using a three-layer feed-forward neural network with LeakyReLU activations and cross-entropy loss. In contrast to the first experiment, the focus here is on the evolution of the validation loss, see Figure 3. While teleportation again improves convergence of the training loss, this comes at the cost of very early overfitting and a slightly larger validation loss.

# Discussion

Although the improvements in convergence might seem impressive, the slightly worse generalization behavior potentially outweighs the benefit of the speed-up. With this in mind, the authors further analyze how teleportation can be used to improve generalization in the follow-up paper [Zha24I].