Current SSM decoding updates the recurrent state and writes it back to HBM at every step. ReplaySSM instead caches recent inputs, reconstructs the state only when it is needed, and otherwise computes the output directly from the cache.
Decoding is the bottleneck in RL post-training and inference serving because of long chain-of-thought traces. Agent workloads stretch it further, since every task adds rounds of tool calls and extra reasoning. Transformers remain the industry standard, but their KV cache makes long contexts expensive: because the full token history is stored, memory traffic and memory capacity requirements grow linearly with sequence length, which hurts both tail latency and throughput.
State space models (SSMs) were designed to address this bottleneck. Rather than storing the full token history, SSMs such as Mamba-2
SSM layers are cheap but struggle with exact recall, while attention layers in Transformers achieve better recall yet are expensive. Hybrid models aim to find the sweet spot in between by interleaving a majority of SSM layers with a few attention layers. Many production-level models (e.g., Nemotron-3
SSM decoding looks like a clean win: memory drops from \(O(N)\) to \(O(1)\). But the same mechanism that gives SSMs constant memory is exactly what introduces new challenges in practice: it summarizes all of history into one fixed-size state, updates it recurrently, and throws the inputs away.
ReplaySSM addresses all three by caching recent inputs instead of writing the state back at every step, without changing the output. That alone speeds up standard decoding by up to 1.48x (1.43x on large MoE models). The larger gain is in speculative decoding: vLLM's existing implementation falls below standard decoding at serving batch sizes, while ReplaySSM delivers a 1.87–1.96x speedup over standard decoding. We cover the three challenges next, then the method.
State Space Models (SSMs)
At each decode step, an SSM does two things: it updates the state (the summary) with the new inputs, and it reads the output from it. In Mamba-2, the recurrent form is:
\[S_t = a_t\,S_{t-1} + \Delta_t\,(v_t\,k_t^\top) \quad\text{(update)}, \qquad y_t = S_t\,q_t \quad\text{(readout)} .\]Here, \(S\) is the state and \(y\) is the output. \(v_t\) and \(k_t\) are the new inputs at time \(t\) that update the state from \(S_{t-1}\) to \(S_t\), and \(q_t\) is the input that reads out the output from the current state. \(a_t = e^{A\Delta_t}\) and \(\Delta_t\) are the per-step scalars. (If you are more familiar with Mamba-2’s notation, \(x_t, B_t, C_t\) are our \(v_t, k_t, q_t\).)
An SSM variant: Delta-rule family
The delta-rule family, such as Gated DeltaNet (GDN), trades more complex computation for an additional capability: it can erase content from the state, whereas Mamba-2 can only add to it. In GDN, the recurrent form is:
\[S_t = e^{g_t}\,S_{t-1}\,(I - \beta_t\,k_t k_t^\top) + \beta_t\,(v_t\,k_t^\top) \quad\text{(update)}, \qquad y_t = S_t\,q_t \quad\text{(readout)} .\]Here, \(I - \beta_t\,k_t k_t^\top\), called the Householder term, is what erases information from the current state. Similarly, \(g_t\) and \(\beta_t\) are per-step scalars.
Shapes
For each sequence, \(v_t\) is a per-head vector with head dimension \(d\), while \(k_t\) and \(q_t\) have state dimension \(n\) and are shared across a group of heads (ngroups groups in total). The per-head state is a matrix of shape \((d, n)\). Typically, \(n\) and \(d\) are \(64\) or \(128\), and ngroups ranges from \(1\) to \(16\).
The state update is a rank-1 outer product of two vectors, \(v\) and \(k\), followed by an accumulation. Output generation is a matrix-vector multiplication between the fixed-size state and \(q\). Both operations are lightweight, and neither maps efficiently to modern matrix-multiplication accelerators such as Tensor Cores. Each decoding step is instead dominated by memory I/O.
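To put numbers on the memory-bound claim, the sketch below counts FLOPs and bytes for one baseline decode step of a single head. It is a back-of-envelope estimate under stated assumptions: a 4-byte state and 2-byte inputs, every scalar multiply or add counted as one FLOP, and an assumed H100 HBM bandwidth of 3.35 TB/s (not a figure from this post) paired with the 67 TFLOP/s FP32 CUDA-core figure cited elsewhere in this post. The function names are illustrative.

```python
# Per-head FLOPs and bytes for one baseline Mamba-2 decode step (back-of-envelope).
STATE_BYTES = 4  # FP32 state
ACT_BYTES = 2    # BF16 inputs


def step_flops(d: int, n: int) -> int:
    scale_v = d          # dt * v, once per step
    update = 3 * d * n   # a * S, outer product (dt * v) k^T, and the accumulate
    readout = 2 * d * n  # y = S q: one multiply and one add per state element
    return scale_v + update + readout


def step_bytes(d: int, n: int) -> int:
    state = 2 * STATE_BYTES * d * n       # load S, then store it back
    inputs = ACT_BYTES * (d + 2 * n + 1)  # v, k, q, and dt
    return state + inputs


def intensity(d: int, n: int) -> float:
    """Arithmetic intensity in FLOPs per byte."""
    return step_flops(d, n) / step_bytes(d, n)


# Ridge point = peak FLOP/s / bandwidth: 67 TFLOP/s FP32 on CUDA cores and an
# assumed 3.35 TB/s of HBM bandwidth for an H100.
RIDGE_FP32 = 67e12 / 3.35e12  # 20 FLOPs per byte

if __name__ == "__main__":
    print(f"{intensity(128, 128):.2f} FLOPs/byte vs. ridge {RIDGE_FP32:.0f}")  # 0.62 vs. 20
```

For \(d = n = 128\) this gives roughly 0.62 FLOPs per byte, more than 30x below the ridge point of about 20 FLOPs per byte, so the step spends its time moving the state rather than computing on it.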
In hybrid models, SSM matters
One might expect the attention layers in a hybrid model to dominate latency because their cost grows with context length. Figure 2 shows otherwise. The SSM state update kernel remains the primary bottleneck up to 100K tokens, which covers a large fraction of real inference workloads.
Three factors lead to this result. First, the \(O(N)\) attention cost remains modest at short to middle context lengths. Second, SSM layers typically outnumber attention layers by a factor of three to six.
An SSM summarizes token history into a fixed-size state at each step. A Transformer stores the entire history explicitly in the KV cache. This fixed-size summary gives SSMs their efficiency, but the summary is lossy and irreversible. Once the state is updated, the model cannot recover the exact tokens that produced it.
This becomes a problem when inference needs to rewind. Speculative decoding, widely used across production-level models such as Nemotron-3, Gemma4
Attention handles rollback naturally. It moves the sequence pointer in the KV cache backward, and the rejected keys and values are no longer used. An SSM does not have an explicit token history to point into. The history is compressed into the recurrent state, and the raw inputs are gone. Figure 3 contrasts the two.
The common workaround, used for example in vLLM, is to store a separate SSM state for every speculative token. On rejection, the system restores the state that corresponds to the last accepted token. With a speculative window of \(T\), each decoding step therefore stores \(T\) full states instead of one, on a path that is already memory-bound. Because of this overhead, speculative decoding, a reliable speedup for Transformers, barely helps SSMs at serving batch sizes and can even make them slower than standard decoding, which is a major disadvantage.
The recurrent SSM states are sequentially dependent. Each state depends on the state before it. This dependency makes SSMs harder to parallelize.
Speculative decoding increases parallelism for Transformers because verification over \(T\) draft tokens can be batched. Many small matrix-vector operations become one larger GEMM over the speculative window.
SSMs do not get the same form of batching. Verification over \(T\) draft tokens needs the output at every speculative position, and each output is read from its own intermediate state. The verifier therefore cannot replace the window with one combined state transition, because that would produce only the final state, not the intermediate outputs. The computation remains a length-\(T\) loop, not a batched GEMM.
Step back from all three problems and ask something that sounds too simple to be useful:
Why do we store the state at all?
An SSM decoding step consists of four stages: loading the state, updating it with the new inputs, generating the output, and storing the updated state back to memory. But the only consumer of that written-back state is the next step, which loads it just to do the same thing again.
We store the state only so that the next step can update it again. So do we need to store it at all?
We don’t, and the reason lies in the definition of the SSM recurrent state update.
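\[S_t = a_t\, S_{t-1} + \Delta_t\,(v_t\,k_t^\top) = \sum_{i \le t} \Big( \textstyle\prod_{i < j \le t} a_j \Big)\, \Delta_i\,(v_i\,k_i^\top) .\]Starting from a zero state, these two expressions describe the same state, but they suggest two different ways to compute it: the left form folds each new input into a running state at every step, while the right form (the history route) keeps the inputs and rebuilds the state from them when it is needed.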
An SSM recurrence has the flexibility to take either route.
An SSM can eagerly summarize each step into a state, or keep the recent inputs and reconstruct it. Current decoding always picks eager summarization. We don’t have to.
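A small numerical check of the identity above, assuming a zero initial state as the unrolled sum does. `recurrent_state` takes the eager route and `history_state` rebuilds the state from all kept inputs; both names, the toy sizes, and the random inputs are illustrative.

```python
import math
import random


def recurrent_state(a, dt, vs, ks):
    """Left form: fold each input into a running d x n state as it arrives."""
    d, n = len(vs[0]), len(ks[0])
    S = [[0.0] * n for _ in range(d)]
    for a_t, dt_t, v, k in zip(a, dt, vs, ks):
        S = [[a_t * S[i][j] + dt_t * v[i] * k[j] for j in range(n)] for i in range(d)]
    return S


def history_state(a, dt, vs, ks):
    """Right form: keep every input and rebuild the state from them in one pass."""
    d, n = len(vs[0]), len(ks[0])
    S = [[0.0] * n for _ in range(d)]
    for i in range(len(a)):
        decay = math.prod(a[i + 1:])  # product of a_j over later steps (1.0 if none)
        for r in range(d):
            for c in range(n):
                S[r][c] += decay * dt[i] * vs[i][r] * ks[i][c]
    return S


def max_gap(seed=0, t=6, d=3, n=2, A=-0.5):
    """Largest elementwise difference between the two forms on random inputs."""
    rng = random.Random(seed)
    dt = [rng.uniform(0.1, 1.0) for _ in range(t)]
    a = [math.exp(A * x) for x in dt]  # a_t = exp(A * dt_t)
    vs = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(t)]
    ks = [[rng.gauss(0, 1) for _ in range(n)] for _ in range(t)]
    S1, S2 = recurrent_state(a, dt, vs, ks), history_state(a, dt, vs, ks)
    return max(abs(x - y) for r1, r2 in zip(S1, S2) for x, y in zip(r1, r2))


if __name__ == "__main__":
    print(max_gap())  # floating-point noise only
```

The history route touches every kept input on each rebuild, which is why ReplaySSM keeps only a bounded window of recent inputs on top of a checkpoint rather than the full history.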
We use Mamba-2 as the running example. GDN adds a Householder term \(I - \beta_t k_t k_t^{\top}\) to erase content from the state, which makes the history route more complex, but the same idea still holds because the update is still a recurrence. The full GDN algorithm is in the Appendix.
ReplaySSM changes what is written to memory at each step. Instead of storing the recurrent state, ReplaySSM caches the recent inputs in a small buffer. For Mamba-2, the buffer stores the per-step \((v, k)\) pairs and the decay factors needed to replay them.
When the model needs the state, ReplaySSM selects the history route to reconstruct the state from the buffered inputs. The state is no longer something we update and write back to memory at each step. It is something we recompute when needed.
Once the buffer has grown large enough that loading it would cost more than writing the state back, ReplaySSM flushes the buffer: it summarizes the buffered inputs into the state, clears the buffer, and starts caching inputs again. The state is written back only at flush steps; most decoding steps just append small SSM inputs to the buffer. Notably, ReplaySSM is mathematically equivalent to the original decoding up to floating-point error. It changes how the state is computed, but not the output.
On most steps, ReplaySSM does not write the recurrent state back to memory. It still loads the recurrent state, but replaces the full state store with a small buffer load and an append of the two vectors \((v, k)\) plus the scalar \(\Delta\). This roughly halves the dominant state traffic.
In baseline decoding, we load the state and the inputs, and we store the state.
Assuming 4-byte states and 2-byte activations, the memory traffic per head is:
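\[8dn + 2(d + 2n + 1)\]The dominant term is the state traffic \(8dn\): \(4dn\) to load the state and \(4dn\) to store it. The remaining \(2(d + 2n + 1)\) covers the inputs \(v\), \(k\), \(q\), and \(\Delta\).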
ReplaySSM caches recent inputs instead of storing the state. Assume the buffer already caches the most recent \(h\) inputs.
The total memory traffic per head is:
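\[4dn + 2h(d + n + 1) + 2(d + 2n + 1) + 2(d + n + 1)\]The four terms are the state load, the buffer load, the current inputs, and the append of the current \((v, k, \Delta)\) to the buffer. ReplaySSM halves the dominant state traffic from \(8dn\) to \(4dn\).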
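The two traffic formulas, written out as a quick calculator. This is a sketch: the function names and the example values \(d = n = 128\), \(h = 8\) are our choices, and it covers a non-flush ReplaySSM step only (a flush step additionally writes one full state).

```python
def baseline_bytes(d: int, n: int) -> int:
    """Per-head bytes for a baseline step: 4-byte state loaded and stored, 2-byte inputs."""
    return 8 * d * n + 2 * (d + 2 * n + 1)


def replay_bytes(d: int, n: int, h: int) -> int:
    """Per-head bytes for a ReplaySSM non-flush step with h cached entries."""
    state_load = 4 * d * n             # the checkpoint is still read
    buffer_load = 2 * h * (d + n + 1)  # h cached (v, k, dt) entries
    inputs = 2 * (d + 2 * n + 1)       # current v, k, q, dt
    append = 2 * (d + n + 1)           # write the current (v, k, dt) to the buffer
    return state_load + buffer_load + inputs + append


def buffer_vs_store_break_even(d: int, n: int) -> float:
    """h beyond which loading the buffer costs more than one full state store (4dn)."""
    return 4 * d * n / (2 * (d + n + 1))


if __name__ == "__main__":
    d = n = 128
    print(baseline_bytes(d, n), replay_bytes(d, n, h=8))  # 131842 70932
    print(round(buffer_vs_store_break_even(d, n), 1))     # 127.5
```

With these values a non-flush step moves about 1.86x fewer bytes than a baseline step, and the buffer would need roughly 128 cached entries before loading it cost more than one state store.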
Since SSM decoding is memory-bound, reducing memory traffic directly improves latency. The flush path is more expensive because it summarizes a chunk of recent inputs into the state and writes the full state back once, but this cost is amortized across all the steps between flushes.
ReplaySSM explicitly caches the recent SSM inputs, including those of the draft tokens. It doesn't perform the irreversible summarization at every step, so rolling back rejected draft tokens only requires removing their buffer entries.
Speculative decoding triggers flushes (state updates when the buffer is full) more frequently because verification appends multiple proposed tokens at once, causing the buffer to fill faster for a fixed buffer size.
However, ReplaySSM still avoids writing the full recurrent state in most steps. Even under speculation, most steps cache the inputs rather than writing the full state back at every speculative position. Rollback becomes cheap, and the amortized state traffic with speculation is even lower than in baseline standard decoding.
In the baseline decoder, every token must produce two things: the updated recurrent state and the output. These two are tied together because the next token needs the state. The baseline decoder follows the recurrence directly: it materializes \(S_t\), reads \(y_t\) from it, and writes \(S_t\) back to memory.
ReplaySSM changes what must be produced
Between flushes, the checkpoint state does not change. The recent history lives in the buffer. That means most decode steps only need the output for the current token. The updated state is needed only when the buffer fills and we summarize the cached inputs into the checkpoint state.
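This weaker requirement lets ReplaySSM use a different algorithm on each of two paths: a flush path that materializes the updated state together with the output, and an output-only path that produces just the output for the current token.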
Bypass the sequential state dependence
For standard decoding, the second path is optional. For speculative decoding, it is what bypasses the sequential state dependence: it lets ReplaySSM compute multiple draft outputs in parallel instead of reconstructing a new state at each draft token in turn.
Intuition
To build intuition, first assume a zero initial state and a single step with \(\Delta = 1\). Then the state is \(S = v\,k^\top\), and the output (the only value ReplaySSM needs in most steps) is:
\[y = S q = (v k^\top) q .\]This three-way product can be bracketed in two ways:
\[(v\,k^\top)\,q \qquad\text{or}\qquad v\,(k^\top q).\]The left route builds the full state with an outer product, \(v k^\top\), then reads from it. It gives both the state and the output. This route is useful when ReplaySSM needs to flush the buffer and update the checkpoint state.
The right route never materializes the state. It first computes the inner product \(k^\top q\), then scales \(v\) with that scalar. It gives the same output, but not the state. Figure 5a contrasts the two routes.
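The two bracketings in a few lines of plain Python, using small exact numbers so both routes produce identical outputs. The left route builds all \(d \times n\) state entries; the right route needs only one length-\(n\) dot product and a length-\(d\) scale. Function names are illustrative.

```python
def outer_then_read(v, k, q):
    """(v k^T) q: build the d x n state, then read it out with q."""
    S = [[vi * kj for kj in k] for vi in v]
    y = [sum(row[j] * q[j] for j in range(len(q))) for row in S]
    return S, y


def inner_then_scale(v, k, q):
    """v (k^T q): one dot product, then scale v. No state is built."""
    score = sum(kj * qj for kj, qj in zip(k, q))
    return [vi * score for vi in v]


if __name__ == "__main__":
    v, k, q = [1.0, 2.0, 3.0], [4.0, 5.0], [6.0, 7.0]
    _, y_left = outer_then_read(v, k, q)
    y_right = inner_then_scale(v, k, q)
    print(y_left, y_right)  # [59.0, 118.0, 177.0] [59.0, 118.0, 177.0]
```

Only the left route hands back `S`, which is why it is the one used when the checkpoint has to be updated.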
ReplaySSM can choose either route.
Most of the decoding steps only need an output, so ReplaySSM uses the output-only route. If the buffer is full and the checkpoint state must be updated, ReplaySSM selects the state-and-output route.
With a nonzero checkpoint state \(S_0\) and a buffer of recent inputs, the same idea applies. Suppose the buffer covers positions \(1,\dots,t-1\) after the last checkpoint. For Mamba-2,
\[S_t = \bar a_t S_0 + \sum_{j=1}^{t} s_{j,t} (v_j k_j^\top),\]with \(\bar a_t = e^{A\,\mathrm{pre}_t}\), \(\;s_{j,t} = \Delta_j\,e^{A(\mathrm{pre}_t - \mathrm{pre}_j)}\), and \(\mathrm{pre}_j = \sum_{i\le j}\Delta_i\).
Reading the output from this state gives
\[y_t = \bar a_t (S_0 q_t) + \sum_{j=1}^{t} s_{j,t} v_j (k_j^\top q_t).\]This is the output-only form. It still reads the checkpoint state and the recent input buffer, but it does not materialize a \(d \times n\) state per head. Notably, since \(k\) and \(q\) are shared across a group of heads, the products \(k_j^\top q_t\) can be computed once per group and reused by every head in the group. Figure 5b shows both routes with their matrix shapes.
In the following, we use Mamba-2 as the example. The algorithms for Gated DeltaNet are placed in the Appendix.
What about FLOPs?
Switching from the outer-product to the inner-product form also changes the FLOPs spent per decoded token. FLOPs don’t affect the latency of a memory-bound kernel, but the count is still worth a look to understand what output-only decode actually computes.
Assume the buffer holds the most recent \(h\) inputs (with current token, there are \(h+1\) \((v, k)\) pairs). Per head, the state-and-output (outer product) route costs:
The output-only (inner product) route costs:
The outer-product route pays \(2(h+1)dn\) to materialize the state. The inner-product route replaces that with \(2(h+1)(d+n)\), about 64x smaller for \(d = n = 128\). And since \(k\) and \(q\) are shared within a head group, the \(k_j^\top q_t\) products are computed once per group rather than once per head. Output-only decode needs fewer FLOPs.
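The FLOP terms compared above, as small helper functions. This is a sketch that counts only the state-materialization and dot-product terms named in the comparison (the checkpoint readout shared by both routes is left out); `heads` is a hypothetical heads-per-group value used to show the effect of computing \(k_j^\top q_t\) once per group.

```python
def state_route_flops(d: int, n: int, h: int) -> int:
    """Outer-product route: h+1 rank-1 terms accumulated into a d x n state."""
    return 2 * (h + 1) * d * n


def output_only_flops(d: int, n: int, h: int) -> int:
    """Inner-product route: h+1 length-n dot products plus h+1 scaled length-d vectors."""
    return 2 * (h + 1) * (d + n)


def output_only_flops_per_group(d: int, n: int, h: int, heads: int) -> int:
    """Inner-product route for a whole group when the k_j^T q products are shared."""
    dots = 2 * (h + 1) * n                # computed once per group
    scaled_v = 2 * (h + 1) * d * heads    # still done per head
    return dots + scaled_v


if __name__ == "__main__":
    d = n = 128
    print(state_route_flops(d, n, 8) / output_only_flops(d, n, 8))  # 64.0
```

For \(d = n = 128\) the ratio is exactly 64 for any \(h\), since it reduces to \(dn/(d+n)\). With four heads per group and \(h = 8\), sharing the dot products cuts the group's output-only count from 18432 to 11520.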
The FLOPs count gets messier beyond Mamba-2 standard decode. Speculative decoding adds a quadratic \(T(T+h)\) term to the inner-product route (\(T\) draft queries against \(h+T\) cached keys, where \(T\) is the speculative window). This is also where the \(k^\top q\) products become a real GEMM. Mamba-3
Tensor Cores
Notably, even if we enter the compute-bound regime (e.g., ReplaySSM applied to Mamba-3 speculative decoding), the FLOPs count isn’t everything. Tensor Cores have far higher throughput than CUDA cores (roughly 989 vs. 67 TFLOP/s on an H100). Take Mamba-2 standard decode as an example. In the outer-product route, the state-construction term is a GEMM with inner dimension \(h+1\). Given enough cached tokens, it can map onto Tensor Cores and overlap with the matrix–vector terms running on CUDA cores. The inner-product route has no such GEMM in standard decode; every term is a matrix–vector product or a dot product, all on CUDA cores. So whether the FLOP reduction translates into faster compute requires deeper exploration.
Baseline Mamba-2 decoder
The baseline Mamba-2 decoder eagerly updates and stores the recurrent state back to memory at each decoding step.
Algorithm 1: Baseline (recurrent state update)
State: recurrent state \(S \in \mathbb{R}^{d \times n}\) in HBM
Input: token inputs \((v, \Delta, k, q)\)
- \(a \gets e^{A\Delta}\)
- Load \(S\) from HBM
- \(S \gets a\,S + \Delta\,(v\,k^\top)\)
- \(y \gets S\,q\)
- Store \(S\) to HBM
- Return \(y\;[+\,\text{skip, gate}]\)
ReplaySSM
ReplaySSM keeps a checkpoint state \(S_0\) and a buffer of recent inputs. For each token, it appends the current inputs to the buffer, computes the output from the checkpoint plus the cached inputs, and materializes the state only when the buffer must be flushed.
Algorithm 2: ReplaySSM output-only decode
State: checkpoint \(S_0\), buffer \(\mathcal{B} = \{(v_j,\Delta_j,k_j)\}_{j=1}^{h}\) with capacity \(L\) in HBM
Input: token inputs \((v, \Delta, k, q)\)
- Append \((v, \Delta, k)\) to the buffer
- Compute the decay weights \(\bar a\) and \(s_j\)
- \(y \gets \bar a\,(S_0\,q) + \sum_{j=1}^{h+1} s_j\,(k_j^\top q)\,v_j\) // output only, no state materialized
- If buffer full:
- \(\quad S_0 \gets \bar a\,S_0 + \sum_{j=1}^{h+1} s_j\,(v_j\,k_j^\top)\), clear buffer // flush: full-state store
- Return \(y\;[+\,\text{skip, gate}]\)
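Algorithms 1 and 2 side by side as a runnable pure-Python sketch on toy sizes. It follows the pseudocode directly: `capacity` plays the role of \(L\), and "buffer full" is taken to mean the buffer holds \(L\) entries after the append. HBM is modeled as plain Python lists, the skip and gate terms are omitted, and `ReplayDecoder` and `compare` are illustrative names rather than the vLLM implementation.

```python
import math
import random


def baseline_decode(A, S, steps):
    """Algorithm 1: update the state, read the output, and keep the new state, every token."""
    ys = []
    for v, dt, k, q in steps:
        a = math.exp(A * dt)
        S = [[a * S[i][j] + dt * v[i] * k[j] for j in range(len(k))] for i in range(len(v))]
        ys.append([sum(row[j] * q[j] for j in range(len(q))) for row in S])
    return ys, S


class ReplayDecoder:
    """Algorithm 2: a checkpoint S0 plus a buffer of (v, dt, k), output-only between flushes."""

    def __init__(self, A, S0, capacity):
        self.A, self.S0, self.capacity, self.buf = A, S0, capacity, []

    def step(self, v, dt, k, q):
        self.buf.append((v, dt, k))
        pre, acc = [], 0.0
        for _, dt_j, _ in self.buf:          # pre_j: cumulative dt since the checkpoint
            acc += dt_j
            pre.append(acc)
        a_bar = math.exp(self.A * pre[-1])
        s = [e[1] * math.exp(self.A * (pre[-1] - p)) for e, p in zip(self.buf, pre)]
        d, n = len(v), len(q)
        y = [a_bar * sum(row[j] * q[j] for j in range(n)) for row in self.S0]  # a_bar (S0 q)
        for s_j, (v_j, _, k_j) in zip(s, self.buf):
            score = sum(kc * qc for kc, qc in zip(k_j, q))                     # k_j^T q
            y = [yi + s_j * score * vi for yi, vi in zip(y, v_j)]
        if len(self.buf) == self.capacity:   # flush: materialize and store the state once
            self.S0 = [[a_bar * self.S0[i][j]
                        + sum(s_j * v_j[i] * k_j[j] for s_j, (v_j, _, k_j) in zip(s, self.buf))
                        for j in range(n)] for i in range(d)]
            self.buf = []
        return y


def compare(seed=0, steps=20, d=3, n=2, A=-0.7, capacity=4):
    """Max output gap, max final-state gap, and leftover buffer length vs. the baseline.

    The state gap is only meaningful when the run ends on a flush step.
    """
    rng = random.Random(seed)
    S0 = [[rng.gauss(0, 1) for _ in range(n)] for _ in range(d)]
    seq = [([rng.gauss(0, 1) for _ in range(d)], rng.uniform(0.1, 1.0),
            [rng.gauss(0, 1) for _ in range(n)], [rng.gauss(0, 1) for _ in range(n)])
           for _ in range(steps)]
    ys_ref, S_ref = baseline_decode(A, [row[:] for row in S0], seq)
    dec = ReplayDecoder(A, [row[:] for row in S0], capacity)
    ys = [dec.step(*x) for x in seq]
    out_gap = max(abs(a - b) for y1, y2 in zip(ys_ref, ys) for a, b in zip(y1, y2))
    state_gap = max(abs(a - b) for r1, r2 in zip(S_ref, dec.S0) for a, b in zip(r1, r2))
    return out_gap, state_gap, len(dec.buf)
```

`compare()` runs 20 tokens with \(L = 4\), so the final token triggers a flush; both the per-token outputs and the final states agree with the baseline up to floating-point error.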
Baseline in vLLM’s implementation
The baseline must store a full state snapshot for every draft position so it can roll back after rejection. It also computes the draft outputs through a length-\(T\) recurrence loop.
Algorithm 3: Standard speculative decode (serial scan)
State: state snapshot for the last accepted token, in HBM
Input: draft inputs \(\{(v_s, \Delta_s, k_s, q_s)\}_{s=1}^{T}\)
- Load \(S\) from the last accepted snapshot
- For \(s = 1,\dots,T\):
- \(\quad S \gets e^{A\Delta_s}\,S + \Delta_s\,(v_s\,k_s^\top)\)
- \(\quad y_s \gets S\,q_s\)
- \(\quad\)Store \(S\) as the snapshot for draft token \(s\)
- Return \(\{y_s\}\;[+\,\text{skip, gate}]\)
ReplaySSM
ReplaySSM keeps draft inputs in the same buffer used during standard decoding. On rollback, it simply moves the pointer in the buffer to keep accepted entries and discard the rest.
The output-only form also removes the serial state update from verification. Each draft query reads from the same checkpoint state and the same buffer window. The only difference across draft positions is the causal mask: draft output \(s\) can use cached entries up to its own position, but not later drafts.
Algorithm 4: ReplaySSM cached speculative decode
State: checkpoint \(S_0\), buffer \(\mathcal{B} = \{(v_j,\Delta_j,k_j)\}_{j=1}^{h}\) with capacity \(L\) in HBM
Input: draft inputs \(\{(v_s,\Delta_s,k_s,q_s)\}_{s=1}^{T}\)
- Append the draft inputs to the buffer; draft \(s\) sits at position \(p_s = h + s\)
- For each draft \(s\), compute the decay weights \(\bar a_s\) and \(w_{j,s}\) at position \(p_s\)
- \(H_{:,s} \gets \bar a_s\,(S_0\,q_s)\) // checkpoint readout for every draft
- \(M_{j,s} \gets k_j^\top q_s\), masked to \(j \le p_s\) // GEMM
- \(Y_{:,s} \gets H_{:,s} + \sum_{j \le p_s} w_{j,s}\,M_{j,s}\,v_j\) // GEMM
- If this step is a flush step:
- \(\quad S_0 \gets \bar a_{\mathcal{B}}\,S_0 + \sum_{j=1}^{h} w_j\,(v_j\,k_j^\top)\)
- Return \(Y_{:,s}\;[+\,\text{skip, gate}]\)
Compared with standard speculative decode, which iterates through draft tokens and materializes each intermediate state, ReplaySSM directly computes the outputs through inner products between \(k\) and \(q\). That is a better shape for the hardware. The key-query products become a matrix multiplication over cached keys and draft queries. The weighted sum over values is another matrix multiplication under a causal mask.
This also changes rollback cost. Baseline speculative decode keeps one state snapshot per draft token, while ReplaySSM keeps recent inputs. During commit, ReplaySSM advances the pointer by the number of accepted draft tokens and discards the rest, so no full-state restore is needed. Notably, in a flush step, only the committed cached inputs are summarized into the state. The current step's speculative tokens are not summarized, so rejected ones can still be rolled back.
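A pure-Python sketch of Algorithms 3 and 4 for one verification window, followed by pointer-based rollback. The explicit loops over drafts and cached entries stand in for the two GEMMs a real kernel would issue. `replay` folds entries into a state with the recurrence and is used only as a reference and to materialize the state after rollback; all names and sizes are illustrative.

```python
import math
import random


def matvec(S, x):
    return [sum(row[j] * x[j] for j in range(len(x))) for row in S]


def replay(A, S, entries):
    """Reference: fold cached (v, dt, k) entries into S with the recurrence."""
    for v, dt, k in entries:
        a = math.exp(A * dt)
        S = [[a * S[i][j] + dt * v[i] * k[j] for j in range(len(k))] for i in range(len(v))]
    return S


def serial_verify(A, S, drafts):
    """Algorithm 3: a length-T loop that stores one full-state snapshot per draft."""
    ys, snapshots = [], []
    for v, dt, k, q in drafts:
        S = replay(A, S, [(v, dt, k)])
        ys.append(matvec(S, q))
        snapshots.append(S)
    return ys, snapshots


def replay_verify(A, S0, buf, drafts):
    """Algorithm 4: every draft reads S0 and the cached inputs under a causal mask."""
    entries = buf + [(v, dt, k) for v, dt, k, _ in drafts]
    pre, acc = [], 0.0
    for _, dt, _ in entries:
        acc += dt
        pre.append(acc)                            # cumulative dt since the checkpoint
    ys = []
    for s, (_, _, _, q) in enumerate(drafts):
        p = len(buf) + s                           # 0-based position of draft s
        y = [math.exp(A * pre[p]) * x for x in matvec(S0, q)]  # H[:, s]
        for j in range(p + 1):                     # causal mask: entries j <= p only
            v_j, dt_j, k_j = entries[j]
            w = dt_j * math.exp(A * (pre[p] - pre[j]))
            m = sum(kc * qc for kc, qc in zip(k_j, q))          # M[j, s] = k_j^T q_s
            y = [yi + w * m * vi for yi, vi in zip(y, v_j)]
        ys.append(y)
    return ys, entries


def check(seed=0, h=3, T=4, d=3, n=2, A=-0.6, accepted=2):
    rng = random.Random(seed)

    def vec(m):
        return [rng.gauss(0, 1) for _ in range(m)]

    S0 = [vec(n) for _ in range(d)]
    buf = [(vec(d), rng.uniform(0.1, 1.0), vec(n)) for _ in range(h)]
    drafts = [(vec(d), rng.uniform(0.1, 1.0), vec(n), vec(n)) for _ in range(T)]
    ys_ref, snaps = serial_verify(A, replay(A, S0, buf), drafts)
    ys, entries = replay_verify(A, S0, buf, drafts)
    out_gap = max(abs(a - b) for y1, y2 in zip(ys_ref, ys) for a, b in zip(y1, y2))
    S_kept = replay(A, S0, entries[: h + accepted])    # rollback: keep accepted entries
    snap = snaps[accepted - 1]
    roll_gap = max(abs(a - b) for r1, r2 in zip(S_kept, snap) for a, b in zip(r1, r2))
    return out_gap, roll_gap
```

In `check()`, 2 of 4 drafts are accepted. Keeping the first `h + 2` input entries reproduces the baseline's second snapshot, so rollback only moves a pointer instead of restoring a stored state.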
Let \(h\) be the number of cached tokens currently in the buffer, \(T\) the speculative window, and \(L\) the buffer capacity, as in Algorithm 4. ReplaySSM flushes one window early. It summarizes the cached tokens into the checkpoint when
\[h + 2T > L,\]rather than the natural condition \(h + T > L\).
The natural condition can silently shrink the speculative window. Suppose a step lands at \(h + T = L - 1\). No flush fires, and all \(T\) drafts happen to be accepted. The next step then starts at \(h = L - 1\). The flush fires now, but only one free slot remains for the fresh draft window in that step, so the window is truncated to a single draft and the number of tokens that step can accept collapses. Flushing one window early guarantees at least \(T\) free slots on every step.
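A toy simulation of the two flush conditions, written as a sketch rather than vLLM code. It assumes what the text describes: a step needs free slots for its drafts, a flush summarizes only the committed cached entries, and accepted drafts stay cached for the next step. `draft_windows` and `never_truncates` are illustrative names, and the guarantee relies on \(L \ge 2T\).

```python
from itertools import product


def draft_windows(L, T, accepts, h=0, early=True):
    """Draft window available at each step under one of the two flush rules.

    h is the number of committed cached entries at the start of a step. The
    drafts need free slots, so the window is min(T, L - h). On a flush step the
    cached entries are summarized and only the accepted drafts stay cached.
    """
    windows = []
    for a in accepts:
        window = min(T, L - h)
        flush = (h + 2 * T > L) if early else (h + T > L)
        accepted = min(a, window)
        windows.append(window)
        h = accepted if flush else h + accepted
    return windows


def never_truncates(L, T, steps=6):
    """Exhaustively check that the early rule always offers the full window."""
    return all(
        w == T
        for accepts in product(range(T + 1), repeat=steps)
        for w in draft_windows(L, T, list(accepts))
    )


if __name__ == "__main__":
    # The scenario in the text: L = 8, T = 3, and a step starts at h + T = L - 1.
    print(draft_windows(8, 3, [3, 3], h=4, early=False))  # [3, 1]
    print(draft_windows(8, 3, [3, 3], h=4, early=True))   # [3, 3]
    print(never_truncates(8, 3))                          # True
```

The exhaustive check holds because, under the early rule, a non-flush step starts with \(h + 2T \le L\) and so still leaves at least \(T\) free slots after accepting up to \(T\) drafts, while a flush step leaves at most \(T\) cached entries.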
Here, we highlight two key kernel design choices. Feel free to also check out the Appendix for details on how we integrate our approach into vLLM.
Precomputing shared inner products (Mamba-2)
In the output-only form, all heads in a group need the same inner products \(k_j^\top q\). Computing them inside the main SSM update kernel would repeat the work across the head-dimension grid and add register pressure. Instead, ReplaySSM computes them in a small precompute kernel that runs once per group and writes them to a scratch buffer that the main SSM update kernel reads.
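A plain-Python sketch of that split, with hypothetical function names: the scores \(k_j^\top q\) depend only on the group's keys and query, so they are computed once, while everything that depends on a head's own decay weights, values, and checkpoint readout stays per head.

```python
def group_scores(keys, q):
    """Once per head group: k_j^T q for every cached key (the scratch buffer)."""
    return [sum(kc * qc for kc, qc in zip(k, q)) for k in keys]


def head_output(a_bar, S0_q, weights, scores, values):
    """Once per head: a_bar * (S0 q) + sum_j w_j * score_j * v_j, reusing the group's scores."""
    y = [a_bar * x for x in S0_q]
    for w, m, v in zip(weights, scores, values):
        y = [yi + w * m * vi for yi, vi in zip(y, v)]
    return y


if __name__ == "__main__":
    keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # cached k_j, shared by the group
    scores = group_scores(keys, [2.0, 3.0])      # [2.0, 3.0, 5.0], computed once
    values = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    # Two heads reuse the same scores with their own weights
    # (they share `values` here only to keep the example short).
    print(head_output(1.0, [0.0, 0.0], [1.0, 1.0, 1.0], scores, values))  # [7.0, 8.0]
    print(head_output(1.0, [0.0, 0.0], [0.5, 0.5, 0.5], scores, values))  # [3.5, 4.0]
```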
A ring buffer avoids data-dependent copies
In speculative decoding, a flush only summarizes the committed cached inputs into the state. The current step’s speculative tokens are not summarized, since rejected ones must still be rolled back. The accepted tokens remain in the buffer as the cached inputs for the next decoding step. To avoid relocating these tokens back to the front of the buffer
We evaluate ReplaySSM on two hybrid families with different SSM layers: Nemotron-3 (Mamba-2) and Qwen3.5 (GDN).
| Model | Params | Precision | Hardware |
|---|---|---|---|
| Nemotron-3-Nano-4B | 4B dense | BF16 | 1×H100 |
| Nemotron-3-Super-120B | A12B MoE | NVFP4 | 1×B300 |
| Nemotron-3-Ultra-550B | A55B MoE | NVFP4 | 2×B300, TP2 |
| Qwen3.5-4B | 4B dense | BF16 | 1×H100 |
| Qwen3.5-122B | A10B MoE | NVFP4 | 1×B300 |
We implemented ReplaySSM on top of vLLM. All results run in vLLM with CUDA Graph enabled. SSM states are in FP32 and the vectors cached in the buffer are in BF16. For speculative decoding, both families use their MTP heads as the drafter.
Across both families and sizes from 4B to 550B, ReplaySSM speeds up vLLM’s standard decoding by up to 1.48x (1.43x on large MoE models) end-to-end and speculative decoding by 1.87–1.96x over vLLM’s standard decoding. It also supports 3.0–3.3x more concurrent requests than vLLM’s speculative path under a fixed memory budget.
Figure 6 reports SSM-kernel and end-to-end per-step speedup at batch size 256 over 1K decoding steps, with buffer size 8 for Nemotron-3 and 16 for Qwen3.5 (the best settings from Figure 7).
ReplaySSM makes SSM decoding faster, and the kernel speedup translates into end-to-end speedup on hybrid models across different SSM families and model sizes (from 4B to 550B). On Nemotron-3, ReplaySSM reaches 1.43x to 1.84x kernel and 1.20x to 1.48x end-to-end speedup. On Qwen3.5, ReplaySSM reaches 1.43x to 1.64x kernel and 1.20x to 1.27x end-to-end speedup. The end-to-end speedup is smaller because ReplaySSM targets only the SSM kernel, while attention, GEMMs, and the rest are unchanged.
Trade-offs of different buffer capacities
The buffer size in ReplaySSM introduces a trade-off. A shorter buffer flushes more often and so spends more on writing the updated state back to HBM. A longer buffer reduces flush frequency, but each step reads more from the buffer, and the extra work eventually makes the kernel compute-bound. Figure 7 shows the resulting bell shape, where a medium buffer (8 for Nemotron-3, 16 for Qwen3.5) balances the two costs.
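A memory-only toy model of this trade-off, built from the per-step traffic formula given earlier for a non-flush step. It assumes the buffer holds \(0, 1, \dots, L-1\) entries across one flush cycle and that each cycle adds one \(4dn\)-byte state store. It ignores compute, which the text identifies as the limit for long buffers, so it illustrates the shape of the memory side rather than predicting the measured optima; the function names are illustrative.

```python
def amortized_bytes(d: int, n: int, L: int) -> float:
    """Average per-head bytes per step over one flush cycle of L steps (toy model)."""
    fixed = 4 * d * n + 2 * (d + 2 * n + 1) + 2 * (d + n + 1)  # state load, inputs, append
    avg_buffer = (L - 1) * (d + n + 1)  # mean of 2h(d + n + 1) over h = 0, ..., L - 1
    flush = 4 * d * n / L               # one full state store per cycle
    return fixed + avg_buffer + flush


def best_capacity(d: int, n: int, max_L: int = 64) -> int:
    """Buffer capacity with the fewest amortized bytes under this toy model."""
    return min(range(1, max_L + 1), key=lambda L: amortized_bytes(d, n, L))


if __name__ == "__main__":
    for L in (1, 4, 8, 16, 32, 64):
        print(L, round(amortized_bytes(128, 128, L)))
    print("bytes-only optimum:", best_capacity(128, 128))  # 16
```

For \(d = n = 128\) the bytes-only curve bottoms out at \(L = 16\): below that the amortized flush term \(4dn/L\) dominates, and above it the buffer loads do.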
End-to-end throughput
We test end-to-end throughput (tokens/s) on vLLM using prompts from the GSM8K dataset
Breakdown 1: faster decode steps
The baseline’s verification cost grows almost linearly with the speculative window, because it stores a full SSM state per draft token on an already memory-bound path. Figure 9 shows this on Qwen3.5-122B. At \(T=6\) the baseline kernel costs 4.85× the standard decoding kernel. ReplaySSM’s state traffic is one checkpoint load plus an occasional flush, so its cost stays near flat, between 1.27× and 1.72× at \(T=6\) depending on how many drafts are accepted (more acceptance advances the buffer faster and flushes more often).
Figure 10 further shows how kernel speedup propagates to full decode step speedup. The 2.28–3.33× kernel speedup translates to 1.23–1.69× on the verify forward pass, then 1.20–1.58× on the full decode step once draft-model and preprocessing overheads are included.
Breakdown 2: higher maximum concurrency
The per-draft snapshots also cost capacity. Under a fixed HBM budget (window = 4), the baseline's preallocated states cut the maximum decode batch by roughly 4× relative to standard serving. ReplaySSM caches small input vectors instead of full states and supports 3.0–3.3× more concurrent requests than the baseline speculative path (Figure 11). For a throughput-oriented deployment, maximum concurrency matters as much as per-step latency, because it determines how many requests the speculative path can serve at once.
Together, these two effects explain the trends in Figure 8. Cheaper verification lifts the entire curve, while the smaller memory footprint allows ReplaySSM to continue scaling with batch size where the baseline flattens out.
ReplaySSM makes a simple change: instead of storing the state, we cache recent inputs. This simple change reduces memory traffic, enables low-cost rollback, and unlocks output-only decoding.
ReplaySSM is not limited to Mamba-2; it also applies to delta-rule models such as GDN. We implemented ReplaySSM in vLLM, where it speeds up standard decoding and removes key obstacles that have long hindered speculative decoding for SSMs.
Looking ahead, we plan to bring the ideas behind ReplaySSM to more SSM architectures, such as Mamba-3 and GDN2
Baseline GDN decoder
The GDN baseline mirrors Algorithm 1, with one extra correction term that lets the model erase stale content.
Algorithm 5: GDN baseline (recurrent state update)
Input: state \(S \in \mathbb{R}^{d\times n}\) from HBM, step \((q, k, v, g, \beta)\)
- \(\alpha \gets e^{g}\)
- \(S \gets \alpha\,S\) // gated decay
- \(u \gets \beta\,(v - S\,k)\) // correction: subtract the state’s readout at \(k\)
- \(S \gets S + u\,k^\top\) // rank-1 write at \(k\)
- \(y \gets S\,q\) // readout at \(q\)
- Store \(S\) to HBM
- Return \(y\)
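A quick check, in plain Python with illustrative names, that the step order in Algorithm 5 (decay, correction, rank-1 write) produces the Householder form of the GDN update given earlier, \(e^{g}\,S\,(I - \beta\,k k^\top) + \beta\,v k^\top\), for one token.

```python
import math
import random


def gdn_step(S, q, k, v, g, beta):
    """Algorithm 5 for one token: returns the new state and the output."""
    alpha = math.exp(g)
    S = [[alpha * x for x in row] for row in S]                     # gated decay
    Sk = [sum(row[j] * k[j] for j in range(len(k))) for row in S]
    u = [beta * (vi - ski) for vi, ski in zip(v, Sk)]               # correction
    S = [[S[i][j] + u[i] * k[j] for j in range(len(k))] for i in range(len(S))]  # rank-1 write
    y = [sum(row[j] * q[j] for j in range(len(q))) for row in S]    # readout at q
    return S, y


def gdn_closed_form(S, k, v, g, beta):
    """e^g * S (I - beta k k^T) + beta v k^T, with the Householder factor built explicitly."""
    n = len(k)
    P = [[(1.0 if i == j else 0.0) - beta * k[i] * k[j] for j in range(n)] for i in range(n)]
    return [[math.exp(g) * sum(S[r][m] * P[m][c] for m in range(n)) + beta * v[r] * k[c]
             for c in range(n)] for r in range(len(S))]


def gap(seed=0, d=3, n=4, g=-0.3, beta=0.5):
    """Largest elementwise difference between Algorithm 5 and the closed form."""
    rng = random.Random(seed)
    S = [[rng.gauss(0, 1) for _ in range(n)] for _ in range(d)]
    q, k = [rng.gauss(0, 1) for _ in range(n)], [rng.gauss(0, 1) for _ in range(n)]
    v = [rng.gauss(0, 1) for _ in range(d)]
    S_alg, _ = gdn_step(S, q, k, v, g, beta)
    S_ref = gdn_closed_form(S, k, v, g, beta)
    return max(abs(a - b) for r1, r2 in zip(S_alg, S_ref) for a, b in zip(r1, r2))
```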
The key difference between Mamba-2 and GDN is what we cache. In Mamba-2, we simply append \((v, k)\) to the buffer. This does not work for GDN because of the correction term
\[u = \beta\,(v - S\,k).\]Computing \(u_t\) requires the state \(S_{t-1}\). If we cached the raw \(v_t\), replaying the buffer would still need every intermediate state, reintroducing the serial dependency we are trying to eliminate.
The fix is to cache \(u\) instead of \(v\). Once \(u_t\) is known, the GDN update becomes \(S_t = \alpha_t\,S_{t-1} + u_t\,k_t^\top\). The state then unrolls into a decayed checkpoint plus a weighted sum of \(u_j\,k_j^\top\), same as the Mamba-2 history route.
ReplaySSM
Algorithm 6: ReplaySSM GDN standard decode (state reconstruction)
State: checkpoint \(S_0\), buffer \(\{(u_j, k_j, g_j)\}_{j=1}^{h}\), with \(\alpha_j = e^{g_j}\)
Input: step \((q, k, v, g, \beta)\)
- \(S_h \gets \big(\textstyle\prod_j \alpha_j\big)\,S_0 + \sum_j \big(\textstyle\prod_{i>j}\alpha_i\big)\,u_j\,k_j^\top\) // rebuild from cache
- \(\alpha \gets e^{g}\)
- \(u \gets \beta\,(v - \alpha\,(S_h\,k))\) // correction at the current key
- \(y \gets \alpha\,(S_h\,q) + u\,(k\!\cdot\! q)\) // output for this token
- Append \((u, k, g)\) to buffer
- If buffer full:
- \(\quad\) \(S_0 \gets \alpha\,S_h + u\,k^\top\), clear buffer // flush: one state store
- Return \(y\)
Unlike Mamba-2, a GDN step needs the state’s readout at two vectors, \(k\) and \(q\). ReplaySSM therefore takes the state-and-output route. It rebuilds the state from the decayed checkpoint and the outer products \(u_j\,k_j^\top\), then reads the output from it.
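Algorithm 6 as a runnable pure-Python sketch, checked against token-by-token Algorithm 5. `rebuild` accumulates the same sum as step 1 of Algorithm 6 in recurrent order, which is mathematically identical. Keys and queries are L2-normalized here, an assumption that keeps the toy recurrence well behaved; `ReplayGDN`, `gdn_baseline`, and `compare` are illustrative names.

```python
import math
import random


def gdn_baseline(S, steps):
    """Algorithm 5 applied token by token."""
    ys = []
    for q, k, v, g, beta in steps:
        alpha = math.exp(g)
        S = [[alpha * x for x in row] for row in S]                    # gated decay
        u = [beta * (vi - sum(r[j] * k[j] for j in range(len(k)))) for vi, r in zip(v, S)]
        S = [[S[i][j] + u[i] * k[j] for j in range(len(k))] for i in range(len(S))]
        ys.append([sum(r[j] * q[j] for j in range(len(q))) for r in S])
    return ys, S


class ReplayGDN:
    """Algorithm 6: cache (u, k, g) instead of v, rebuild the state before each token."""

    def __init__(self, S0, capacity):
        self.S0, self.capacity, self.buf = S0, capacity, []

    def rebuild(self):
        # Same sum as step 1 of Algorithm 6, accumulated in recurrent order.
        S = [row[:] for row in self.S0]
        for u, k, g in self.buf:
            a = math.exp(g)
            S = [[a * S[i][j] + u[i] * k[j] for j in range(len(k))] for i in range(len(S))]
        return S

    def step(self, q, k, v, g, beta):
        S_h = self.rebuild()
        alpha = math.exp(g)
        S_h_k = [sum(r[j] * k[j] for j in range(len(k))) for r in S_h]
        S_h_q = [sum(r[j] * q[j] for j in range(len(q))) for r in S_h]
        u = [beta * (vi - alpha * x) for vi, x in zip(v, S_h_k)]   # correction at k
        kq = sum(a * b for a, b in zip(k, q))
        y = [alpha * x + ui * kq for x, ui in zip(S_h_q, u)]       # output for this token
        self.buf.append((u, k, g))
        if len(self.buf) == self.capacity:                          # flush: one state store
            self.S0 = [[alpha * S_h[i][j] + u[i] * k[j] for j in range(len(k))]
                       for i in range(len(S_h))]
            self.buf = []
        return y


def compare(seed=0, steps=12, d=3, n=4, capacity=4):
    rng = random.Random(seed)

    def unit(m):  # assumption: L2-normalized keys and queries
        x = [rng.gauss(0, 1) for _ in range(m)]
        norm = math.sqrt(sum(t * t for t in x))
        return [t / norm for t in x]

    S0 = [[rng.gauss(0, 1) for _ in range(n)] for _ in range(d)]
    seq = [(unit(n), unit(n), [rng.gauss(0, 1) for _ in range(d)],
            rng.uniform(-0.5, -0.05), rng.uniform(0.1, 0.9)) for _ in range(steps)]
    ys_ref, S_ref = gdn_baseline([r[:] for r in S0], seq)
    dec = ReplayGDN([r[:] for r in S0], capacity)
    ys = [dec.step(*x) for x in seq]
    out_gap = max(abs(a - b) for y1, y2 in zip(ys_ref, ys) for a, b in zip(y1, y2))
    state_gap = max(abs(a - b) for r1, r2 in zip(S_ref, dec.S0) for a, b in zip(r1, r2))
    return out_gap, state_gap
```

With 12 tokens and a capacity of 4, the run ends on a flush, so both the outputs and the final checkpoint match the baseline up to floating-point error.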
Baseline in vLLM’s implementation
GDN verification suffers from the same serial loop and per-draft state snapshots as Mamba-2. Moreover, each correction \(u_s\) depends on the state after the previous draft, introducing sequential dependencies between drafts.
Algorithm 7: GDN standard speculative decode (serial delta-rule scan)
Input: drafts \(\{(q_s, k_s, v_s, g_s, \beta_s)\}_{s=1}^{T}\), state snapshots
- \(S \gets\) snapshot at the last accepted token // roll back
- For \(s = 1, \dots, T\): // serial
- \(\quad S \gets e^{g_s}\,S\)
- \(\quad u_s \gets \beta_s\,(v_s - S\,k_s)\)
- \(\quad S \gets S + u_s\,k_s^\top\)
- \(\quad y_s \gets S\,q_s\)
- \(\quad\) Store \(S\) to snapshot \(s\) // full-state store per draft
- Return \(\{y_s\}\)
ReplaySSM removes the serial loop by applying the chunk-wise parallel form that GDN uses in training. Expanding the recurrence from the reconstructed state \(S_h\) gives \(u_s = R_s - \sum_{s'<s} A_{s,s'}\,u_{s'}\), where \(R_s\) depends only on the drafts and \(S_h\), and \(A\) is strictly lower triangular. The \(T\) corrections can therefore be computed through a single triangular solve rather than \(T\) sequential state updates.
Algorithm 8: ReplaySSM GDN speculative decode (chunked delta-rule)
State: checkpoint \(S_0\), buffer \(\{(u_j, k_j, g_j)\}\)
Input: drafts \(\{(q_s, k_s, v_s, g_s, \beta_s)\}_{s=1}^{T}\), with cumulative gates \(G_s = \sum_{i \le s} g_i\)
- \(S_h \gets\) rebuild from \(S_0\) and the buffer (Algorithm 6, step 1)
- \(hq_s \gets S_h\,q_s\), \(\;hk_s \gets S_h\,k_s\) // history into each draft (GEMMs)
- \(A_{s,s'} \gets \beta_s\,e^{G_s - G_{s'}}\,(k_s\!\cdot\! k_{s'})\) for \(s' < s\) // strictly lower triangular
- \(R_s \gets \beta_s\,(v_s - e^{G_s}\,hk_s)\)
- \(W \gets (I + A)^{-1}\) // one \(T \times T\) inverse, all corrections at once
- \(U_s \gets \sum_{s' \le s} W_{s,s'}\,R_{s'}\)
- \(y_s \gets e^{G_s}\,hq_s + \sum_{s' \le s} e^{G_s - G_{s'}}\,(k_{s'}\!\cdot\! q_s)\,U_{s'}\)
- Append \((U_s, k_s, g_s)\) to the buffer
- If this step is a flush step:
- \(\quad S_0 \gets S_h\)
- Return \(\{y_s\}\)
ReplaySSM removes both per-draft state snapshots and the serial loop. As a result, the entire verification step becomes parallelizable. Besides one \(T \times T\) triangular solve, all operations are GEMMs.
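A pure-Python check that the chunked form of Algorithm 8 matches the serial scan of Algorithm 7 over one verification window starting from a reconstructed state \(S_h\). Buffer rebuild, snapshots, and flushing are omitted. The triangular solve is written as forward substitution, which computes the same \(U = (I + A)^{-1} R\) without forming the inverse. Keys and queries are L2-normalized as an assumption for numerical tameness, and all names are illustrative.

```python
import math
import random


def matvec(S, x):
    return [sum(row[j] * x[j] for j in range(len(x))) for row in S]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def serial_scan(S, drafts):
    """Algorithm 7 from a given state (snapshot stores omitted)."""
    ys, us = [], []
    for q, k, v, g, beta in drafts:
        S = [[math.exp(g) * x for x in row] for row in S]
        u = [beta * (vi - ski) for vi, ski in zip(v, matvec(S, k))]
        S = [[S[i][j] + u[i] * k[j] for j in range(len(k))] for i in range(len(S))]
        ys.append(matvec(S, q))
        us.append(u)
    return ys, us


def chunked(S_h, drafts):
    """Algorithm 8: all corrections from one triangular solve, no loop-carried state."""
    G, acc = [], 0.0
    for _, _, _, g, _ in drafts:
        acc += g
        G.append(acc)                               # G_s: cumulative gate over drafts
    hq = [matvec(S_h, q) for q, *_ in drafts]       # history read at each q_s
    hk = [matvec(S_h, k) for _, k, *_ in drafts]    # history read at each k_s
    R = [[beta * (vi - math.exp(G[s]) * x) for vi, x in zip(v, hk[s])]
         for s, (_, _, v, _, beta) in enumerate(drafts)]
    # Forward substitution solves (I + A) U = R, i.e. U = W R with W = (I + A)^{-1}.
    U = []
    for s, (_, k_s, _, _, beta_s) in enumerate(drafts):
        u = R[s][:]
        for t in range(s):                          # A is strictly lower triangular
            a_st = beta_s * math.exp(G[s] - G[t]) * dot(k_s, drafts[t][1])
            u = [ui - a_st * ut for ui, ut in zip(u, U[t])]
        U.append(u)
    ys = []
    for s, (q, *_) in enumerate(drafts):
        y = [math.exp(G[s]) * x for x in hq[s]]
        for t in range(s + 1):
            c = math.exp(G[s] - G[t]) * dot(drafts[t][1], q)
            y = [yi + c * ut for yi, ut in zip(y, U[t])]
        ys.append(y)
    return ys, U


def compare(seed=0, T=5, d=3, n=4):
    rng = random.Random(seed)

    def unit(m):  # assumption: L2-normalized keys and queries
        x = [rng.gauss(0, 1) for _ in range(m)]
        norm = math.sqrt(dot(x, x))
        return [t / norm for t in x]

    def gap(X, Y):
        return max(abs(a - b) for x, y in zip(X, Y) for a, b in zip(x, y))

    S_h = [[rng.gauss(0, 1) for _ in range(n)] for _ in range(d)]
    drafts = [(unit(n), unit(n), [rng.gauss(0, 1) for _ in range(d)],
               rng.uniform(-0.5, -0.05), rng.uniform(0.1, 0.9)) for _ in range(T)]
    ys_ref, us_ref = serial_scan(S_h, drafts)
    ys, us = chunked(S_h, drafts)
    return gap(ys_ref, ys), gap(us_ref, us)
```

Only the forward substitution still walks the drafts in order, and it operates on \(T \times T\) scalars rather than \(d \times n\) states; the remaining loops are the matrix products a kernel would batch.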
CUDA Graph is important for inference performance because it removes per-step CPU launch overhead. However, enabling CUDA Graph support for ReplaySSM inside vLLM is not straightforward, due to two issues.
The first issue is batch divergence. With continuous batching, different sequences reach their flush steps at different times. Speculative decoding amplifies this effect because each sequence may accept a different number of draft tokens. As a result, at a given decode step, some sequences need to flush while others do not. Since the same captured graph has to handle the whole batch, ReplaySSM treats the flush decision as per-sequence data that the kernel reads and branches on at runtime, rather than as a compile-time constant.
The second issue is commit and rollback for speculative decoding. The number of accepted tokens is only known after sampling, and it can differ across sequences. Sending those counts back to the host would add a host-device synchronization point on every step, which would stall the pipeline. Instead, ReplaySSM uses a small commit kernel to update buffer pointers directly on the device. For each sequence, the kernel advances the relevant pointers by that sequence’s accepted-token count, allowing one captured graph per batch size to cover the speculative path without host synchronization.
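A host-side Python sketch of the per-sequence bookkeeping this implies. The text describes the real mechanism as a small device kernel; `SeqBuffer`, `plan_step`, and `commit` are hypothetical names, and the ring-buffer layout (a start pointer plus a count of committed entries) is our assumption. Flush flags are computed per sequence as data, and after a flush the accepted drafts become the new cached entries in place, so only pointers move.

```python
from dataclasses import dataclass


@dataclass
class SeqBuffer:
    start: int = 0  # ring slot of the oldest committed cached entry
    h: int = 0      # number of committed cached entries


def plan_step(seqs, T, L):
    """Per-sequence flush flags (h + 2T > L) and the ring slots for T drafts, as data."""
    flags = [s.h + 2 * T > L for s in seqs]
    slots = [[(s.start + s.h + i) % L for i in range(T)] for s in seqs]
    return flags, slots


def commit(seqs, flags, accepted, L):
    """Advance pointers by each sequence's accepted count; no buffer entries are copied."""
    for s, flush, a in zip(seqs, flags, accepted):
        if flush:
            s.start = (s.start + s.h) % L  # old cached entries were folded into the checkpoint
            s.h = a                        # accepted drafts stay in place as the new cache
        else:
            s.h += a                       # rejected drafts are simply never counted


if __name__ == "__main__":
    seqs = [SeqBuffer(start=0, h=4), SeqBuffer(start=0, h=1)]
    flags, slots = plan_step(seqs, T=3, L=8)
    print(flags, slots)   # [True, False] [[4, 5, 6], [1, 2, 3]]
    commit(seqs, flags, accepted=[2, 3], L=8)
    print(seqs)           # [SeqBuffer(start=4, h=2), SeqBuffer(start=0, h=4)]
    print(plan_step(seqs, T=3, L=8)[1][0])  # [6, 7, 0]: the window wraps around the ring
```

In this example the first sequence flushes, and its two accepted drafts stay in slots 4 and 5 as its new cached entries; its next draft window then wraps around the ring to slots 6, 7, and 0.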
Prior SSM speculative decoding methods already avoid keeping a separate recurrent state for each draft token (e.g., Mamba-in-the-Llama