Home Blog S4 の数式を全部理解したい
S4 の数式を全部理解したい
Cancel

S4 の数式を全部理解したい

今回読んだ論文

まず、簡単に論文の立ち位置を説明する。

  • HiPPO: ある時点までの入力を圧縮する操作のこと。LSTM の記憶セルとして使用することができる。
  • LSSL: 前論文では HiPPO は RNN の補助パーツのような扱いだった。この論文では状態空間モデルを採用して HiPPO を中心としたアーキテクチャを提案している。
  • S4: 前論文で LSSL を並列かつ高速に計算するアルゴリズムが提案されたが数値的に不安定だった。この論文ではそれをより安定なアルゴリズムに改善した。

S4 の関連論文の数式を理解するためには arXiv の付録が必須。

参考(神)資料

State Space Model (1)

\[\begin{aligned} x'(t) &= Ax(t) + Bu(t) \\ y(t) &= Cx(t) + Du(t) \end{aligned} \tag{1}\]

これはまさに状態空間モデル (wiki)。

HiPPO 論文では、$x(t), u(t)$ がそれぞれ $c(t), f(t)$ と表記されていることに注意が必要。

係数の時間発展

まず、State Space の 1 つ目の式がなぜ登場したのかを理解したい。 実は、HiPPO 論文では上側の式だけしか出てこなかった:

\[\begin{aligned} c'(t) = Ac(t) + Bf(t). \end{aligned}\]

これだけだと状態空間モデルとは言えない。実際、HiPPO 論文には「State Space」というワードは出てこない。この式は、人間が現象をモデル化したのではなく、手法から自然に導かれた方程式である、という機械学習系の論文だとあまり見ない経緯がある。

そもそも HiPPO の正式名称は high-order polynomial projection operators で、直訳すると時刻 $t$ までの関数 $f$ を高次多項式 $\lbrace g_n \rbrace_{n<N}$ で表される空間に射影する操作のことである (cf. HiPPO Appendix C.2, D.3):

\[\begin{aligned} f(x) &\approx \sum_{n=0}^{N-1} c_n^{(t)} g_n^{(t)}(x) \\ where ~ c_n^{(t)} &= \int f(x) g_n^{(t)}(x) \omega^{(t)}(x) dx. \end{aligned}\]

テイラー展開を用いて、任意の関数 $f$ を $n$ 次多項式 $g_n$ で近似するのと同じノリ。

(備考)深層学習モデルの隠れ次元 $N$ はエイヤで決めがちだが、HiPPO の場合は誤差項から逆算で決められる。ただし、真の隠れ状態と近似した隠れ状態の誤差に対する良し悪しの規準を設けられないので結局エイヤになりそう。

このときの係数 $c_n(t)$ を微分を、$g$ と $\omega$ の定義を利用して整理すると:

\[\begin{aligned} \frac{d}{dt} c_n(t) &= \int f(x) \left( \frac{\partial}{\partial t} g_n^{(t)}(x) \right) \omega^{(t)}(x) dx \\ &\quad+ \int f(x) g_n^{(t)}(x) \left( \frac{\partial}{\partial t} \omega^{(t)}(x) \right) dx \\ &= \cdots = Ac(t) + Bf(t) \end{aligned}\]

先程の微分方程式が得られる。これにより、時刻 $t$ に応じて係数をどう変化させればいいかが分かる。すなわち、新しい時系列データ $f(t+1)$ が入力されたときに、最良の近似を与える係数 $c(t+1)$ が得られる(かなりすごいこと)。

State Space Model (SSM) の提案

HiPPO 論文では、HiPPO を RNN に組み込んで実験することで有効性の検証をしていた。個人的には、せっかくの精密な理論に基づく手法を理論的背景が乏しい RNN に組み込んでしまうのはもったいないのではないかと感じた。

LSSL 論文では、SSM と Linear+非線形変換の積み重ねから成るアーキテクチャが提案された。SSM の第一式が HiPPO 論文の係数更新式に完全に一致していることから、HiPPO は補助パーツからコアパーツに格上げされたように見える(Attention みたい)。

既に HiPPO は RNN の補助パーツとして完成されていたのに、なぜ SSM に進化したのか。ただの勘だが、著者は HiPPO の式が SSM の第一式と一致することに気づいたことで、RNN を捨てて HiPPO を中心とした SMM を提案することを思い付いたのではないか(流石に SMM から HiPPO を思い付くという逆の経緯はありえない気がする)。

HiPPO Matrix (2)

\[A_{nk} = - \begin{cases} (2n+1)^{1/2} (2k+1)^{1/2} & \text{if} ~ n>k \\ n+1 & \text{if} ~ n=k \\ 0 & \text{if} ~ n<k \end{cases} \tag{2}\]

前セクションで「$g$ と $\omega$ の定義を利用して整理する」と雑に書いたが、HiPPO 論文の Appendix D.3 に厳密な導出過程が載っている。 行列がこういう形をしていると、LegS という直交多項系で入力系列を近似しているんだなと思えば良い。

微分方程式の離散化 (3)

微分方程式は $t$ が連続値なので、数値計算時は離散化した方程式にしなければならない。

\[\begin{aligned} x_k &= \overline{A}x_{k-1} + \overline{B}u_k \\ y_k &= Cx_{k} \\ where ~ \overline{A} &= (I-\Delta/2 \cdot A)^{-1} (I+\Delta/2 \cdot A) \\ \overline{B} &= (I-\Delta/2 \cdot A)^{-1} \Delta B \end{aligned} \tag{3}\]

この式は $x$ の漸化式になっているので、RNN のようなモデルとして機能する。

離散化は、関数の積分を台形を敷き詰めて近似するという直感的な手順に従っている(HiPPO Appendix B.3):

1° $x’(t)$ を $[t,t+\Delta]$ で積分する:

\[\int_t^{t+\Delta} x'(s) ds = x(t+\Delta) - x(t)\]

2° $x’(t)$ の微分方程式を代入すると:

\[x(t+\Delta) - x(t) = A \int_t^{t+\Delta} x(s) ds + B \int_t^{t+\Delta} u(s) ds\]

3° 1 項目は、上底 $x(t)$ - 下底 $x(t+\Delta)$ - 高さ $\Delta$ の台形で近似する:

\[\int_t^{t+\Delta} x(s) ds = \Delta/2 \cdot (x(t) + x(t+\Delta))\]

4° $u$ は曲線ではなく階段関数のような形をしていることにするらしいので、2 項目は長方形で近似する (wiki):

\[\int_t^{t+\Delta} u(s) ds = \Delta u(t)\]

5° 合わせると:

\[x(t+\Delta) - x(t) = \Delta/2 \cdot A (x(t) + x(t+\Delta)) + \Delta B u(t)\]

6° 離散風に書き換える:

\[x_{k+1} - x_k = \Delta/2 \cdot A (x_k + x_{k+1}) + \Delta B u_k\]

7° 各変数について整理すると:

\[(I - \Delta/2 \cdot A) x_{k+1} = (I + \Delta/2 \cdot A) x_k + \Delta B u_k\]

これはこのセクションの一番上の式に等しい。

数値解析の観点で $\overline{A}$ の定義を見ると「深層学習モデルの各層で入力ごとに逆行列の計算をするって不安定すぎないか?」と感じる。この疑問については、S4 の手法を取り入れると解決するので後ほど説明する。

畳み込み表現 (4) (5)

式 (3) の $y_k=Cx_k$ に $x_k = \dots$ を代入し続けてみる:

Markdown でも LaTeX みたいに式番号を \ref する方法はないものか…

\[\begin{aligned} y_k &= Cx_k = C(\overline{A} x_{k-1} + \overline{B}u_k) \\ &= C\overline{A} x_{k-1} + C\overline{B}u_k = C\overline{A} (\overline{A} x_{k-2} + \overline{B}u_{k-1}) + C\overline{B}u_k \\ &= C\overline{A}^2 x_{k-2} + C\overline{A}\overline{B}u_{k-1} + C\overline{B}u_k = C\overline{A}^2 (\overline{A} x_{k-3} + \overline{B}u_{k-2}) + C\overline{A}\overline{B}u_{k-1} + C\overline{B}u_k \\ &= C\overline{A}^3 x_{k-3} + C\overline{A}^2\overline{B}u_{k-2} + C\overline{A}\overline{B}u_{k-1} + C\overline{B}u_k \end{aligned}\]

これを $k=0$ まで続けると、$y_k$ は各 $u$ とカーネル $\overline{K}$ の畳み込みとして書ける:

\[\begin{align} &y_k = \overline{K} * u \tag{4} \\ &where ~ \overline{K} := (\overline{C}\overline{A}^i\overline{B})_{i=0,1,\dots,L-1} \tag{5} \end{align}\]

ただし、$x_{-1}=0$ とした。

つまり、RNN は $k$ 番目の出力 $y_k$ を計算するためには 1 番目から漸化式で順々に更新しなければならないのに対して、S4 は $k$ までの入力 $u$ が与えられれば畳み込み演算一発で計算できる。

WIP

でも $A^{L-1}$ の計算ってやばいよね。でもこれが Z 変換を使うと効率的に計算できちゃう。めっちゃ暇だったら続き書く。(でも、S4 の後に行列 $A$ にそもそも対角行列を使うという話が出てきて、じゃあ難解な議論考えなくてもいいじゃんってなってる)