Introduction to State Space Models (SSM)

Community Article Published April 8, 2024
A French version of this article is available on my blog.

Foreword

I'd like to extend my warmest thanks to Boris ALBAR, Pierre BEDU and Nicolas PREVOT for agreeing to set up a working group on the subject of SSMs and thus accompanying me in my discovery of this type of model. A special thanks to the former for taking the time to proofread this blog post.


Introduction

State space models are traditionally used in control theory to model a dynamic system via state variables.
In the context of deep learning, when we speak of SSMs, we are referring to a subset of these representations, namely linear time-invariant (or stationary) systems.
These models showed impressive performance as early as October 2021 with the paper Efficiently Modeling Long Sequences with Structured State Spaces by Albert GU et al., to the point of positioning themselves as an alternative to transformers.
In this article, we will define the basics of a deep learning SSM based on S4. Like the paper Attention is all you need by Ashish VASWANI et al. (2017) for transformers, S4 is the foundation of a new type of neural network architecture that is worth knowing, even though it is not a model used as such in practice (other SSMs with better performance, or that are easier to implement, are now available). Released a week earlier than S4, LSSL, by the same authors, is also an important source of information on the subject. We'll take a look at the various developments arising from S4 in a future blog post. Before that, let's delve into the basics of SSMs.


Definition of an SSM in deep learning

Let's use the image below to define an SSM:

Figure 1: View of a continuous, time-invariant SSM (Source: https://en.wikipedia.org/wiki/State-space_representation)

It can be seen that an SSM is based on three variables that depend on time $t$:

  • $x(t) \in \mathbb{C}^{n}$ represents the $n$ state variables,
  • $u(t) \in \mathbb{C}^{m}$ represents the $m$ inputs,
  • $y(t) \in \mathbb{C}^{p}$ represents the $p$ outputs.

We can also see that it's made up of four learnable matrices: $\mathbf{A}$, $\mathbf{B}$, $\mathbf{C}$ and $\mathbf{D}$.

  • $\mathbf{A} \in \mathbb{C}^{n \times n}$ is the state matrix (controlling the latent state $x$),
  • $\mathbf{B} \in \mathbb{C}^{n \times m}$ is the control (input) matrix,
  • $\mathbf{C} \in \mathbb{C}^{p \times n}$ is the output matrix,
  • $\mathbf{D} \in \mathbb{C}^{p \times m}$ is the feedthrough matrix.

The above picture can be reduced to the following system of equations:

$$\begin{aligned} x'(t) &= \mathbf{A}x(t) + \mathbf{B}u(t) \\ y(t) &= \mathbf{C}x(t) + \mathbf{D}u(t) \end{aligned}$$

Note: here we use the notation $x'$ to designate the derivative of $x$. You may also encounter the notation $\dot{x}$ in the literature instead.

Similarly, since it is implicit that the variables depend on time, the preceding equation is generally written in the following form for the sake of simplicity:

$$\begin{aligned} x' &= \mathbf{A}x + \mathbf{B}u \\ y &= \mathbf{C}x + \mathbf{D}u \end{aligned}$$

This system can be lightened further: in deep learning SSMs, the term $\mathbf{D}u$ is treated as an easily computable skip connection, so we set $\mathbf{D} = 0$.

$$\begin{aligned} x' &= \mathbf{A}x + \mathbf{B}u \\ y &= \mathbf{C}x \end{aligned}$$

This system is continuous. It must therefore first be discretized before it can be supplied to a computer.
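Before moving on, here is a minimal toy sketch of the continuous view (my own example, not code from any paper): it simulates $x' = \mathbf{A}x + \mathbf{B}u$, $y = \mathbf{C}x$ for an arbitrary 2-dimensional state with SciPy's ODE solver, which of course performs its own discretization under the hood; that is exactly what we will make explicit in the next section.

```python
import numpy as np
from scipy.integrate import solve_ivp

# Toy continuous SSM: 2 state variables, 1 input, 1 output (arbitrary illustrative values).
A = np.array([[-0.5, 1.0],
              [-1.0, -0.5]])
B = np.array([[1.0],
              [0.0]])
C = np.array([[1.0, 0.0]])

def u(t):
    # Continuous input signal u(t); any function of time would do.
    return np.sin(2 * np.pi * t)

def dynamics(t, x):
    # x'(t) = A x(t) + B u(t)
    return A @ x + (B * u(t)).ravel()

sol = solve_ivp(dynamics, t_span=(0.0, 10.0), y0=np.zeros(2), dense_output=True)

t = np.linspace(0.0, 10.0, 200)
x = sol.sol(t)              # state trajectory x(t), shape (2, 200)
y = (C @ x).ravel()         # output y(t) = C x(t)
print(y[:5])
```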


Discretization

Discretization is one of the most important points of SSMs, if not the most important. The whole efficiency of this architecture lies in this step, since it lets us pass from the continuous view of the SSM to its two other views: the recurrent view and the convolutional view.
If there's one thing to remember from this article, it's this.

Figure 2: Image from blog post « Structured State Spaces: Combining Continuous-Time, Recurrent, and Convolutional Models » by Albert GU et al. (2022)

We'll see in later articles that there are several possible discretizations. This is one of the main differences between the various existing SSM architectures.
For this first article, let's apply the "original" discretization proposed in S4 to illustrate the two additional views of an SSM.


Recurrent view of an SSM

To discretize the continuous case, let's use the trapezoid method, whose principle is to approximate the region under the curve of a function $f$ defined on a segment $[t_n, t_{n+1}]$ by a trapezoid and compute its area $T$: $T = (t_{n+1} - t_n)\frac{f(t_n) + f(t_{n+1})}{2}$.

We then have: $x_{n+1} - x_n = \frac{1}{2}\Delta\big(f(t_n) + f(t_{n+1})\big)$ with $\Delta = t_{n+1} - t_n$.
Since $x'_n = \mathbf{A}x_n + \mathbf{B}u_n$ (the first line of the SSM system) plays the role of $f$, we get:

$$\begin{aligned} x_{n+1} &= x_n + \frac{\Delta}{2} (\mathbf{A}x_n + \mathbf{B} u_n + \mathbf{A}x_{n+1} + \mathbf{B} u_{n+1}) \\ \Longleftrightarrow x_{n+1} - \frac{\Delta}{2}\mathbf{A}x_{n+1} &= x_n + \frac{\Delta}{2}\mathbf{A}x_{n} + \frac{\Delta}{2}\mathbf{B}(u_{n+1} + u_n) \\ (*) \Longleftrightarrow \left(\mathbf{I} - \frac{\Delta}{2} \mathbf{A}\right) x_{n+1} &= \left(\mathbf{I} + \frac{\Delta}{2} \mathbf{A}\right) x_{n} + \Delta \mathbf{B} u_{n+1}\\ \Longleftrightarrow x_{n+1} &= \left(\mathbf{I} - \frac{\Delta}{2} \mathbf{A}\right)^{-1} \left(\mathbf{I} + \frac{\Delta}{2} \mathbf{A}\right) x_n + \left(\mathbf{I} - \frac{\Delta}{2} \mathbf{A}\right)^{-1} \Delta \mathbf{B} u_{n+1} \end{aligned}$$

(*) using $u_{n+1} \simeq u_n$ (the control vector is assumed to be constant over a small $\Delta$).

We've just obtained our discretized SSM!
To make this completely explicit, let's define:

$$\begin{aligned} \mathbf{\bar{A}} &= \left(\mathbf{I} - \frac{\Delta}{2} \mathbf{A}\right)^{-1}\left(\mathbf{I} + \frac{\Delta}{2} \mathbf{A}\right) \\ \mathbf{\bar{B}} &= \left(\mathbf{I} - \frac{\Delta}{2} \mathbf{A}\right)^{-1} \Delta \mathbf{B} \\ \mathbf{\bar{C}} &= \mathbf{C} \end{aligned}$$

We then have

$$\begin{aligned} x_k &= \mathbf{\bar{A}}x_{k-1} + \mathbf{\bar{B}}u_k \\ y_k &= \mathbf{\bar{C}}x_k \end{aligned}$$

The bar notation for matrices was introduced in S4 to designate the matrices of the discrete case, and has since become a convention in SSMs applied to deep learning.
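As an illustration, here is a minimal NumPy sketch (my own, not code from the S4 repository) of this bilinear discretization and of the resulting recurrence; the state size, matrices and input sequence are arbitrary toy values.

```python
import numpy as np

def discretize_bilinear(A, B, C, delta):
    """Bilinear (trapezoid) discretization: returns (A_bar, B_bar, C_bar)."""
    n = A.shape[0]
    I = np.eye(n)
    inv = np.linalg.inv(I - delta / 2 * A)
    A_bar = inv @ (I + delta / 2 * A)
    B_bar = inv @ (delta * B)
    return A_bar, B_bar, C                     # C_bar = C

def run_recurrence(A_bar, B_bar, C_bar, u):
    """x_k = A_bar x_{k-1} + B_bar u_k ; y_k = C_bar x_k (sequential scan)."""
    x = np.zeros((A_bar.shape[0], 1))
    ys = []
    for u_k in u:
        x = A_bar @ x + B_bar * u_k
        ys.append((C_bar @ x).item())
    return np.array(ys)

# Toy example: 4-dimensional state, scalar input and output.
rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4))
B = rng.normal(size=(4, 1))
C = rng.normal(size=(1, 4))

A_bar, B_bar, C_bar = discretize_bilinear(A, B, C, delta=0.1)
u = rng.normal(size=16)                        # input sequence u_0, ..., u_15
y = run_recurrence(A_bar, B_bar, C_bar, u)
print(y.shape)                                 # (16,)
```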


Convolutional view of an SSM

This recurrence can be written as a convolution. To see this, simply iterate the equations of the system:

$$\begin{aligned} x_k &= \mathbf{\bar{A}}x_{k-1} + \mathbf{\bar{B}}u_k \\ y_k &= \mathbf{\bar{C}}x_k \end{aligned}$$

Let's start with the first line of the system:
Step 0: $x_0 = \mathbf{\bar{B}}u_0$
Step 1: $x_1 = \mathbf{\bar{A}}x_0 + \mathbf{\bar{B}}u_1 = \mathbf{\bar{A}}\mathbf{\bar{B}}u_0 + \mathbf{\bar{B}}u_1$
Step 2: $x_2 = \mathbf{\bar{A}}x_1 + \mathbf{\bar{B}}u_2 = \mathbf{\bar{A}}(\mathbf{\bar{A}}\mathbf{\bar{B}}u_0 + \mathbf{\bar{B}}u_1) + \mathbf{\bar{B}}u_2 = \mathbf{\bar{A}}^2\mathbf{\bar{B}}u_0 + \mathbf{\bar{A}}\mathbf{\bar{B}}u_1 + \mathbf{\bar{B}}u_2$
We see that $x_k$ can be written as a function of $(u_0, u_1, \dots, u_k)$.

Let's move on to the second line of the system, into which we can now inject the values of $x_k$ computed above:
Step 0: $y_0 = \mathbf{\bar{C}}x_0 = \mathbf{\bar{C}}\mathbf{\bar{B}}u_0$
Step 1: $y_1 = \mathbf{\bar{C}}x_1 = \mathbf{\bar{C}}(\mathbf{\bar{A}}\mathbf{\bar{B}}u_0 + \mathbf{\bar{B}}u_1) = \mathbf{\bar{C}}\mathbf{\bar{A}}\mathbf{\bar{B}}u_0 + \mathbf{\bar{C}}\mathbf{\bar{B}}u_1$
Step 2: $y_2 = \mathbf{\bar{C}}x_2 = \mathbf{\bar{C}}(\mathbf{\bar{A}}^2\mathbf{\bar{B}}u_0 + \mathbf{\bar{A}}\mathbf{\bar{B}}u_1 + \mathbf{\bar{B}}u_2) = \mathbf{\bar{C}}\mathbf{\bar{A}}^2\mathbf{\bar{B}}u_0 + \mathbf{\bar{C}}\mathbf{\bar{A}}\mathbf{\bar{B}}u_1 + \mathbf{\bar{C}}\mathbf{\bar{B}}u_2$

We can recognize the convolution kernel $\mathbf{\bar{K}}_k = (\mathbf{\bar{C}}\mathbf{\bar{B}}, \mathbf{\bar{C}}\mathbf{\bar{A}}\mathbf{\bar{B}}, \dots, \mathbf{\bar{C}}\mathbf{\bar{A}}^k\mathbf{\bar{B}})$ applied to the input $u$, hence $y = \mathbf{\bar{K}} \ast u$.

As with the matrices, we put a bar on $\mathbf{\bar{K}}$ to indicate that it is the convolution kernel obtained after discretization. It is generally referred to as the SSM convolution kernel in the literature, and its length equals that of the entire input sequence.
This convolution kernel is computed efficiently using the Fast Fourier Transform (FFT), which will be explained in future articles (do you like the Flash Attention of transformers? You'll love Flash FFT Convolution, which we'll look at in the third blog post).
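To make the equivalence between the two discrete views concrete, here is a hedged sketch that reuses the toy matrices and the `discretize_bilinear` / `run_recurrence` helpers from the previous snippet: it materializes the kernel $\mathbf{\bar{K}}$ naively and applies it with an FFT-based causal convolution (a real S4 implementation never builds the kernel term by term like this).

```python
import numpy as np

def ssm_kernel(A_bar, B_bar, C_bar, length):
    """Naive kernel K_bar = (C_bar B_bar, C_bar A_bar B_bar, ..., C_bar A_bar^{L-1} B_bar)."""
    K = []
    A_power_B = B_bar                          # A_bar^0 B_bar
    for _ in range(length):
        K.append((C_bar @ A_power_B).item())
        A_power_B = A_bar @ A_power_B
    return np.array(K)

def causal_conv_fft(K, u):
    """Causal convolution y = K * u via FFT (zero-padded to avoid circular wrap-around)."""
    L = len(u)
    K_f = np.fft.rfft(K, n=2 * L)
    u_f = np.fft.rfft(u, n=2 * L)
    return np.fft.irfft(K_f * u_f, n=2 * L)[:L]

# A_bar, B_bar, C_bar, u and y come from the previous (recurrent) sketch.
K = ssm_kernel(A_bar, B_bar, C_bar, length=len(u))
y_conv = causal_conv_fft(K, u)
print(np.allclose(y_conv, y))                  # True: both views produce the same output
```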


Advantages and limitations of each of the three views

Figure 3: Image from the paper Combining Recurrent, Convolutional, and Continuous-time Models with Linear State-Space Layers by Albert GU et al., released a week before S4

The different views of SSM each have their advantages and disadvantages - let's take a closer look.

For the continuous view, the advantages and disadvantages are as follows:
✓ Automatically handles continuous data (audio signals, time series, for example). This represents a huge practical advantage when processing data with irregular or time-shifted sampling.
✓ Mathematically tractable analysis, e.g. computing exact trajectories or building memory systems (HiPPO).
✗ Extremely slow for both training and inference.

For the recurrent view, these are the well-known advantages and disadvantages of recurrent neural networks, namely:
✓ Natural inductive bias for sequential data, and in principle unbounded context.
✓ Efficient inference (constant-time state updates).
✗ Slow learning (lack of parallelism).
✗ Vanishing or exploding gradients when training on long sequences.

For the convolutional view, these are the well-known advantages and disadvantages of convolutional neural networks (here in their one-dimensional version), namely:
✓ Local, interpretable features.
✓ Efficient (parallelizable) training.
✗ Slow in online or autoregressive contexts (the entire input must be reprocessed for each new data point).
✗ Fixed context size.

So, depending on the stage of the process (training or inference) or the type of data at our disposal, it is possible to switch from one view to another in order to fall back on a favorable framework for getting the most out of the model.
In practice, we favor the convolutional view for fast, parallelized training, the recurrent view for efficient inference, and the continuous view for handling continuous data.



Learning matrices

In the convolution kernel developed above, $\mathbf{\bar{C}}$ and $\mathbf{\bar{B}}$ are learnable parameters.
Concerning $\mathbf{\bar{A}}$, we've seen that in our convolution kernel it is raised to the power $k$ at step $k$. Computing these matrix powers can be very time-consuming, so we look for a form of $\mathbf{\bar{A}}$ that makes them cheap. The best option is for it to be diagonal:

$$\mathbf{A} = \begin{bmatrix} \lambda_{1} & 0 & \cdots & 0 \\ 0 & \lambda_{2} & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & \lambda_{n} \end{bmatrix} \Rightarrow \mathbf{A}^k = \begin{bmatrix} \lambda_{1}^k & 0 & \cdots & 0 \\ 0 & \lambda_{2}^k & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & \lambda_{n}^k \end{bmatrix}$$

By the spectral theorem of linear algebra, the matrices that can be diagonalized in an orthonormal basis are exactly the normal matrices.
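A quick numerical illustration of why a diagonal $\mathbf{A}$ helps (again a toy example of mine): powers of a diagonal matrix reduce to element-wise powers of its eigenvalues, so the kernel entries $\mathbf{\bar{C}}\mathbf{\bar{A}}^k\mathbf{\bar{B}}$ can be obtained without any matrix-matrix product.

```python
import numpy as np

lam = np.array([0.9, 0.5 + 0.4j, -0.3])        # diagonal entries (eigenvalues), arbitrary values
A = np.diag(lam)

k = 5
dense_power = np.linalg.matrix_power(A, k)     # generic matrix power
diag_power = np.diag(lam ** k)                 # element-wise power of the eigenvalues
print(np.allclose(dense_power, diag_power))    # True

# Kernel entries C A^i B become simple weighted sums over the eigenvalues:
B = np.ones((3, 1))
C = np.ones((1, 3))
L = 8
K = np.array([((C * lam ** i) @ B).item() for i in range(L)])   # K_i = sum_j C_j * lam_j^i * B_j
print(K.shape)                                  # (8,)
```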
In addition to the choice of discretization mentioned above, the way in which $\mathbf{\bar{A}}$ is defined and initialized is one of the points that differentiates the various SSM architectures developed in the literature, which we'll cover in the next blog post. Indeed, empirically, it appears that an SSM initialized with a random $\mathbf{A}$ matrix leads to poor results, whereas an initialization based on the HiPPO matrix (High-Order Polynomial Projection Operator) gives very good results (from 60% to 98% on the sequential MNIST benchmark).

The HiPPO matrix was introduced by the S4 authors in a previous paper (2020). It is included in the LSSL paper (2021), also by the S4 authors, as well as in the S4 appendix. Its formula is as follows:

$$\mathbf{A} = \begin{bmatrix} 1 \\ -1 & 2 \\ 1 & -3 & 3 \\ -1 & 3 & -5 & 4 \\ 1 & -3 & 5 & -7 & 5 \\ -1 & 3 & -5 & 7 & -9 & 6 \\ 1 & -3 & 5 & -7 & 9 & -11 & 7 \\ -1 & 3 & -5 & 7 & -9 & 11 & -13 & 8 \\ \vdots & & & & & & & & \ddots \end{bmatrix} \Rightarrow \mathbf{A}_{nk} = \begin{cases} (-1)^{n-k} (2k+1) & n > k \\ k+1 & n=k \\ 0 & n<k \end{cases}$$
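As a sanity check, here is a small snippet (mine) that builds the matrix directly from the formula above, with 0-based indices $n$ and $k$:

```python
import numpy as np

def hippo_matrix(N):
    """Build the N x N HiPPO matrix from the formula above (0-based indices n, k)."""
    A = np.zeros((N, N))
    for n in range(N):
        for k in range(N):
            if n > k:
                A[n, k] = (-1) ** (n - k) * (2 * k + 1)
            elif n == k:
                A[n, k] = k + 1
            # n < k: the entry stays 0
    return A

print(hippo_matrix(4))
# [[ 1.  0.  0.  0.]
#  [-1.  2.  0.  0.]
#  [ 1. -3.  3.  0.]
#  [-1.  3. -5.  4.]]
```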

This matrix is not normal, but it can be decomposed as a normal matrix plus a low-rank matrix (summarized in the paper as NPLR, for Normal Plus Low-Rank). The authors prove in their paper that this type of matrix can be computed efficiently via three techniques (see Algorithm 1 in the paper): truncated generating series, Cauchy kernels and the Woodbury identity.

Details of the proof showing that an NPLR matrix can be computed almost as efficiently as a diagonal matrix can be found in the appendix of the paper (see parts B and C).
The authors of S4 subsequently made modifications to the HiPPO matrix (on how to initialize it) in their paper How to Train Your HiPPO (2022). The model resulting from this paper is generally referred to as "S4 V2" or "S4 updated" in the literature, as opposed to the "original S4" or "S4 V1".
In the next article, we'll see that other authors (notably Ankit GUPTA) have proposed using a diagonal matrix instead of an NPLR matrix, an approach that is now preferred as it is simpler to implement.



Experimental results

Let's end this blog post by analyzing a selection of the S4's results on various tasks and benchmarks to get a feel for the potential of SSMs.

Let's start with an audio task and the benchmark Speech Commands by WARDEN (2018).

Figure 4: Image from the paper On the Parameterization and Initialization of Diagonal State Space Models by Albert GU et al. (2022), also known as S4D, published after S4 but which reproduces the S4 results for this benchmark in a more structured form (the S4D results have been removed from the image so as not to spoil the next article ;))

Several things can be observed in this table.
Firstly, for a more or less equivalent number of parameters, the S4 performs much better (at least +13%) than the other models, here of the ConvNet type.
Secondly, to achieve equivalent performance, a ConvNet requires 85 times more parameters.
Thirdly, a ConvNet trained on 16kHz audio gives very poor results when then applied to 8kHz data. In contrast, S4 retains 95% of its performance under this resampling. This is explained by the continuous view of the SSM: it is enough to double the value of $\Delta$ at test time, since the samples are now spaced twice as far apart.
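In code, under the assumptions of the earlier discretization sketch (the `discretize_bilinear` helper and the toy continuous parameters $\mathbf{A}$, $\mathbf{B}$, $\mathbf{C}$), this zero-shot change of sampling rate simply amounts to re-discretizing with a new step size, with no retraining:

```python
# Hypothetical continuation of the earlier sketch: same continuous (A, B, C),
# only the step size Delta changes with the sampling rate.
delta_16k = 1.0 / 16000                                                  # training data sampled at 16 kHz
delta_8k = 2 * delta_16k                                                 # 8 kHz samples are twice as far apart

A_bar_16k, B_bar_16k, C_bar = discretize_bilinear(A, B, C, delta_16k)    # used during training
A_bar_8k, B_bar_8k, _ = discretize_bilinear(A, B, C, delta_8k)           # used as-is at test time
```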


Let's continue with a time series task (introduced in a revision of S4).

Figure 5: Image from the S4 appendix

The authors follow the methodology of the Informer model by ZHOU et al. (2020) and show that their model outperforms this transformer on 40 of the 50 configurations. The results in the table are for the univariate setting, but the same can be observed in the multivariate setting (table 14 in the appendix).


Let's continue with a vision task and the benchmark sCIFAR-10 by KRIZHEVSKY (2009).

Figure 6: Image from the S4 appendix

S4 establishes SoTA on sCIFAR-10 with just 100,000 parameters (the authors don't specify the number for the other methods).


Let's conclude with a textual task and the benchmark Long Range Arena (LRA) by TAY et al. (2020).

Figure 7: Image from the S4 appendix

The LRA consists of 6 tasks, including Path-X with sequences of 16K tokens, which S4 was the first model to solve, demonstrating its performance on very long sequences.
It would take more than 2 years before AMOS et al. showed in their paper Never Train from Scratch: Fair Comparison of Long-Sequence Models Requires Data-Driven Priors (2023) that transformers (not hybridized with an SSM) could also solve this task. However, unlike SSMs, they remain unable to solve the 65K-token PathX-256 task.

Note, however, a negative point for S4 on text: it obtains a higher perplexity than a standard transformer on WikiText-103 by MERITY et al. (2016) (more optimized transformer variants achieve an even lower perplexity).

Figure 8: Image from the S4 appendix

This is probably due to the non-continuous nature of text (it has not been sampled from an underlying physical process, unlike speech or time series). We'll see in the article devoted to SSM developments in 2023 that this point has been the subject of a great deal of work, and that SSMs have since succeeded in bridging this gap.


Conclusion

SSMs are models with three views: a continuous view and, once discretized, a recurrent view as well as a convolutional view.
The challenge with this type of architecture is to know when to favor one view over another, depending on the stage of the process (training or inference) and the type of data being processed.
This type of model is highly versatile, since it can be applied to text, vision, audio and time-series tasks (or even graphs).
One of its strengths is its ability to handle very long sequences, generally with a lower number of parameters than other models (ConvNet or transformers), while still being very fast.
As we'll see in later articles, the main differences between the various existing SSM architectures lie in the way the basic SSM equation is discretized, or in the definition of the $\mathbf{A}$ matrix.


To dig deeper

For S4, please consult the following resources:

For more information on the HiPPO matrix, please consult the following resources:

To find out more about SSM, take a look at :
