Modeling Long Sequences with Structured State Spaces

Building on the use of state spaces and linear ODEs in optimal control, a line of papers introduces and refines a new layer for sequence processing. The theory behind this layer overlaps with that of LSTMs and convolutional layers and, in a sense, contains them as special cases. The new layers are shown to handle very long-range dependencies in sequences of over 10,000 tokens while being much faster than transformers.

In this paper pill we describe a line of work by Albert Gu et al. that combines ideas from state-space dynamics (which originate in the classical theory of optimal control) with deep-learning sequence models à la LSTMs or CNNs. The most recent development of these ideas resulted in a versatile algorithm presented in [Gu22E], which received an honorable mention at ICLR 2022.

The essence of the idea is the following: model a sequence-to-sequence task that maps $u_i$ to $y_i$ as a stack of discretizations of time-invariant linear state-space processes, with $u(t)$ as the control function and $y(t)$ as the output (the $u_i$ and $y_i$ being regarded as $u(t)$ and $y(t)$ evaluated at different timesteps). The state variable $x(t)$ plays a role similar to the hidden state in RNN architectures. Thus, the fundamental model of a single layer is described by the ODE $\dot{x}(t) = A x(t) + B u(t) \ , \ y(t) = C x(t)$, where the learnable parameters are the matrices $A$, $B$ and $C$. In [Gu22C] it was shown that the state-space formulation can be viewed both as an RNN and as a convolutional architecture and also, interestingly, that standard approximation schemes for non-linear ODEs are equivalent to the stacking of recurrent layers and the gating used in deep LSTMs. Thus, models made of (infinitely many) stacked linear state space layers (LSSLs) are universal function approximators. Note that the details of the proofs are omitted in the version submitted to NeurIPS, but they can be found in the arXiv version.
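To make the discretization step concrete, here is a minimal NumPy sketch of a single layer unrolled as a recurrence. It assumes the bilinear (Tustin) rule with a fixed step size, which is one common way to discretize the ODE; the function names, the step size and the random matrices are our own illustrative choices and are not taken from the authors' code.

```python
import numpy as np

def discretize_bilinear(A, B, step):
    """Bilinear (Tustin) discretization of x' = A x + B u with step size `step`."""
    I = np.eye(A.shape[0])
    inv = np.linalg.inv(I - (step / 2.0) * A)
    A_bar = inv @ (I + (step / 2.0) * A)
    B_bar = inv @ (step * B)
    return A_bar, B_bar

def run_recurrence(A_bar, B_bar, C, u):
    """Unroll x_k = A_bar x_{k-1} + B_bar u_k, y_k = C x_k, starting from x = 0."""
    x = np.zeros(A_bar.shape[0])
    ys = []
    for u_k in u:
        x = A_bar @ x + B_bar[:, 0] * u_k   # state update (single input channel)
        ys.append(float(C[0] @ x))          # scalar readout
    return np.array(ys)

# Illustrative (assumed) parameters: a small, stable random system.
rng = np.random.default_rng(0)
N = 4                                        # state dimension
A = -np.eye(N) + 0.1 * rng.standard_normal((N, N))
B = rng.standard_normal((N, 1))
C = rng.standard_normal((1, N))

u = np.sin(np.linspace(0.0, 3.0, 16))        # input sequence u_i
A_bar, B_bar = discretize_bilinear(A, B, step=0.1)
y = run_recurrence(A_bar, B_bar, C, u)       # output sequence y_i
print(y.shape)
```

Here the state `x` is exactly the analogue of an RNN hidden state: it is the only thing carried from one timestep to the next.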

Naively discretizing this ODE and unrolling it over input-output sequences immediately results in two problems:

  1. the usual exploding/vanishing gradients problem of RNNs carries over to learning $A$, and
  2. the final expression for the output sequence contains many expensive matrix multiplications.
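Both points can be seen in a small NumPy sketch of the naive unrolling. Writing $\bar{A}$, $\bar{B}$ for the discretized matrices (notation we introduce here), the output is a causal convolution of the input with the kernel $K_k = C \bar{A}^k \bar{B}$, and computing this kernel directly requires repeated multiplication by $\bar{A}$. The matrices below are random placeholders, and all function names are hypothetical.

```python
import numpy as np

def naive_kernel(A_bar, B_bar, C, L):
    """K[k] = C A_bar^k B_bar for k = 0..L-1, via L successive products with A_bar."""
    K = np.empty(L)
    v = B_bar[:, 0].copy()          # holds A_bar^k B_bar
    for k in range(L):
        K[k] = C[0] @ v
        v = A_bar @ v               # one O(N^2) product per timestep
    return K

def causal_conv(K, u):
    """y_k = sum_{j <= k} K[j] * u[k - j]."""
    return np.convolve(K, u)[: len(u)]

def recurrence(A_bar, B_bar, C, u):
    """Reference: x_k = A_bar x_{k-1} + B_bar u_k, y_k = C x_k, with x = 0 initially."""
    x = np.zeros(A_bar.shape[0])
    y = np.empty(len(u))
    for k, u_k in enumerate(u):
        x = A_bar @ x + B_bar[:, 0] * u_k
        y[k] = C[0] @ x
    return y

# Assumed toy system; A_bar is rescaled so that its spectral norm is 0.9.
rng = np.random.default_rng(0)
N, L = 4, 32
M = rng.standard_normal((N, N))
A_bar = 0.9 * M / np.linalg.norm(M, 2)
B_bar = rng.standard_normal((N, 1))
C = rng.standard_normal((1, N))
u = rng.standard_normal(L)

K = naive_kernel(A_bar, B_bar, C, L)
same = np.allclose(causal_conv(K, u), recurrence(A_bar, B_bar, C, u))
print("convolution matches recurrence:", same)

# Scalar illustration of the stability issue: K[k] = rho**k decays or blows up.
for rho in (0.9, 1.1):
    K_scalar = naive_kernel(np.array([[rho]]), np.ones((1, 1)), np.ones((1, 1)), 200)
    print(rho, K_scalar[-1])
```

The scalar loop at the end shows why powers of $\bar{A}$ are delicate: depending on whether its eigenvalues lie inside or outside the unit circle, late kernel entries shrink towards zero or grow geometrically.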

The first problem was addressed by developing a particular way of storing memory: functions are projected onto a low-dimensional space, and the time evolution of these projections is computed explicitly (instead of first computing the full evolution and then recomputing the projections). This procedure was dubbed HiPPO (see [Gu20H]) and served as inspiration for much of the follow-up work. In particular, it gave rise to a class of matrices for $A$ that can be learned without running into vanishing gradients.
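As an illustration of what such a matrix looks like, the sketch below builds one member of this class, the so-called HiPPO-LegS matrix. The formula is not stated in this pill; we reproduce it as we understand it from [Gu20H], namely $A_{nk} = -\sqrt{(2n+1)(2k+1)}$ for $n > k$, $A_{nn} = -(n+1)$ and $A_{nk} = 0$ for $n < k$. The function name is our own.

```python
import numpy as np

def hippo_legs_matrix(N):
    """Lower-triangular N x N HiPPO-LegS matrix (formula assumed from [Gu20H])."""
    n = np.arange(N)
    scale = np.sqrt(2 * n + 1)
    # Lower triangle (incl. diagonal) of sqrt(2n+1) * sqrt(2k+1); the diagonal
    # is (2n+1), so subtracting n leaves (n+1) there.
    A = np.tril(np.outer(scale, scale)) - np.diag(n)
    return -A

A = hippo_legs_matrix(4)
print(A)

# The matrix is triangular, so its eigenvalues are the diagonal entries -(n+1):
# all strictly negative, i.e. the continuous-time system is stable.
print(np.diag(A))
```

In the papers this matrix serves as a structured starting point for $A$ rather than a random initialization; the sketch only constructs it and does not reproduce the derivation.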

The second problem was a matter not only of computational efficiency but also of numerical stability. Both aspects were addressed in the papers mentioned above, resulting in the stable and fast algorithm presented at ICLR 2022. The main ideas underlying the solutions are rather technical, but in the end a suitable class of matrices $A$ is found that can be learned through gradient descent and on which the relevant calculations can be performed efficiently using a number of mathematical and numerical tricks.

This final algorithm, apart from offering an interesting new point of view on sequence modelling, is capable of capturing very long dependencies without relying on the transformer architecture (while being much faster than transformers in these cases). It improved the state of the art on multiple benchmarks, can handle missing data (because of its relation to ODEs) and generalizes to continuous sequences. The exact results can be found in the experiments sections of the papers mentioned in this pill.

All algorithms from this line of work have been implemented in PyTorch by the authors in the open-source state-spaces library. It seems to be well maintained and documented, and we expect further extensions of the state-space approach to appear there in the future. Moreover, there is an annotated version of the paper that explains the ideas and the implementation in detail (by the author of the well-known Annotated Transformer blog post). State-space modelling of sequences should be worth a look for everybody working with sequential data, especially if the sequences are long.