Cascade: Token-Sharded Private LLM Inference
Rahul Thomas, Louai Zahran, Erica Choi, Micah Goldblum, Arka Pal
Further details on this blog post can be found in our arXiv paper, as well as in our ICML 2025 submission.
In our previous blog post, we introduced a novel reconstruction attack that nearly perfectly recovered input tokens from the hidden states of autoregressive LLMs. The success of our attack relied on the non-collision of these LLM hidden states – when performing forward passes on two different prompts, the resulting hidden states are essentially never the same.
In practice, we saw few collisions and concluded that hidden states leak significant information about the input tokens. Building on this insight, we recognized that reordering hidden state rows, or the elements within them, does not remove this non-collision property. In other words, we could modify our attack to successfully decode certain permutations of hidden states. This ultimately broke the security of three permutation-based MPC protocols which exposed permuted hidden states to the party performing inference: PermLLM, STIP, and Centaur.
Since even a rearranged full hidden state matrix allows recovery of the full input prompt through vocab-matching, it might seem futile to search for obfuscation protocols that do not alter hidden state element values. However, by relaxing the assumption that the full hidden state matrix is revealed, we can formulate such a scheme that protects against vocab-matching: Cascade.
Cascade is based on the idea of token-sharding. The hidden state matrix at each layer has shape N×d, where N is the token count and d is the hidden dimension; revealing all N of these rows to an adversary, even when each row is individually permuted and the rows are then rearranged, falls prey to our attack. But what if we only reveal n out of these N token rows – do we remain secure against a similar attack? The answer is not clear-cut: it depends on which n rows are revealed (the sharding scheme), not just on the value of n. For some sparse sharding schemes, we will see that an adversary with reasonable computational power can no longer carry out our attack on the n revealed rows. Cascade therefore aims to reveal only some of the N rows to different nodes throughout inference, to prevent vocab-matching.
To motivate Cascade, we will first show how to split transformer blocks into stages compatible with token-sharding.
Transformer Inference
For transformer architectures, the layer l transformer block can be considered a nonlinear mapping R^(N×d) → R^(N×d) from the layer l hidden states to the layer l+1 hidden states, where N is the token count and d is the hidden dimension.
There is a nice way to break down single-block inference into three stages. In the first stage, we perform query, key, and value projection on the normalized N×d hidden states. In the second stage, we compute the output of multi-headed attention on the previously derived query, key, and value matrices, and flatten it to shape N×d_attn. Finally, in the third stage, we project the attention output back to shape N×d, add it to the layer l hiddens, and apply the MLP, which often has its own residual connection and normalization.
# Stage 1: Normalization and query, key, value projections.
hidden_norm = norm(hidden) # RMSNorm, LayerNorm, or others.
query = q_proj(hidden_norm) # Apply rotary or positional embedding here if needed.
key = k_proj(hidden_norm)
value = v_proj(hidden_norm)
# Stage 2: Attention logit computation and flattening.
attn_logits = query @ key.T + attn_mask # Attention mask initialized previously.
attn_output = softmax(attn_logits) @ value
attn_output = flatten(attn_output)
# Stage 3: Post-projection, residual connection, and MLP.
hidden = hidden + o_proj(attn_output) # Residual connection back to the layer l hiddens.
hidden = mlp(hidden) # MLP can have normalization and residual.
Our key observation is that the majority of operations in these stages are per-token. For example, in Stage 1, the query, key, and value projection maps, as well as normalization, operate independently on the N rows of the N×d hidden states. Also, in Stage 3, the O-projection, residual connection, and MLP operate independently on the N rows of the flattened attention output (and corresponding rows in the hidden states). In fact, Stage 2 is really the only place where operations aren’t per-token.
This allows us to integrate token-sharding quite easily into Stages 1 and 3, as well as Stage 2 with a bit of additional work.
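As a quick sanity check of this per-token property, here is a small illustrative PyTorch snippet (not from the Cascade codebase; the simplified rms_norm and random weights are stand-ins) showing that normalizing and projecting a subset of rows gives exactly the rows one would get by slicing the full computation.

```python
import torch

torch.manual_seed(0)
N, d = 10, 16
h = torch.randn(N, d)                         # toy hidden states
W_q = torch.randn(d, d)                       # toy projection weights

def rms_norm(x, eps=1e-6):
    # Simplified RMSNorm without a learned scale, for illustration only.
    return x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

rows = [0, 1, 6, 7]                           # one CompNode's shard (0-indexed)
full = rms_norm(h) @ W_q                      # computation on the full matrix
sharded = rms_norm(h[rows]) @ W_q             # computation on the shard only
print(torch.allclose(full[rows], sharded))    # True: per-token ops commute with sharding
```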
Cascade: Token-Sharded Transformer Inference
Pre-Pass
We first integrate token-sharding into Stage 1 to form the pre-pass. We would like to split the N rows of the N×d hidden states h = [h1, …, hN] at a given layer among α different nodes, so that no node accesses too many rows. We call these nodes CompNodes, and allocate rows through R-sharding, a partition {R1, …, Rα} of the row indices {1, …, N}. This means CompNodei starts the transformer block with the rows of h whose indices are in Ri.
Because all operations in Stage 1 are performed independently on the N rows, then CompNodei can carry out Stage 1 on its own shard of h, and obtain the rows of query, key, and value matrices whose indices are in Ri. This completes the pre-pass.
# Pre-pass performed by CompNode_i.
hidden_norm_Ri = norm(hidden_Ri)
query_Ri = q_proj(hidden_norm_Ri)
key_Ri = k_proj(hidden_norm_Ri)
value_Ri = v_proj(hidden_norm_Ri)
Attention-Pass
At the end of the pre-pass, each CompNodei has the query, key, and value rows indexed by Ri. Moving on to Stage 2, we aim to compute attention logits through query-key multiplication. However, the CompNodes cannot compute all of these logits in isolation! For i≠j, there is no way for a single node to compute the dot product of a query row in Ri and a key row in Rj, since CompNodei is missing the latter, CompNodej is missing the former, and all other CompNodes are missing both.
Not all hope is lost, though – when we arrive at a point where nodes cannot proceed further in a protocol without sharing information, a general rule of thumb is to introduce new nodes. To this end, we define α² AttnNodes, indexed as AttnNodeij for 1 ≤ i, j ≤ α. These will execute the next stage of the protocol, called the attention-pass.
Now, AttnNodeij acts as a sort of mediator between CompNodei and CompNodej: it receives the Ri query rows from the former, and the Rj key and value rows from the latter, so it can compute cross-attention scores on these. In particular, it exactly carries out Stage 2, but with the query matrix replaced by the Ri-sharded query matrix, and key and value matrices replaced by the Rj-sharded key and value matrices. Finally, it stores the row-wise maximum and expsum (sums of exponentials of elements) of its computed attention logits in this process – the importance of this step will become clear in the next stage.
# Attention-pass performed by AttnNode_ij.
attn_logits_Ri_Rj = query_Ri @ key_Rj.T + attn_mask_Ri_Rj # Sliced attention mask.
attn_output_Ri_Rj = softmax(attn_logits_Ri_Rj) @ value_Rj
attn_output_Ri_Rj = flatten(attn_output_Ri_Rj)
# Additional shards to store.
max_logits_Ri_Rj = rowwise_max(attn_logits_Ri_Rj)
expsum_logits_Ri_Rj = rowwise_expsum(attn_logits_Ri_Rj)
Post-Pass
Arriving finally at Stage 3, we notice a subtle discrepancy: we have not actually completed Stage 2 in the attention-pass! Even though the AttnNodes computed results of the form softmax(QK^T + mask)V, with Q, K, V, and mask replaced by each AttnNode's particular shards, these partial results cannot simply be concatenated to obtain the true value of softmax(QK^T + mask)V. The root of this obstacle is that Stage 2 is not per-token.
Thankfully, it turns out that a simple extension of concatenation can derive the true result from the AttnNodes' partial results: linear weighting. Formally, CompNodei receives partial results from AttnNodeij for every j, including their max and expsum shards and partial attention outputs. Using the max and expsum shards to form a weight for each AttnNodeij, CompNodei can then take a weighted average of the partial attention outputs to recover the true attention output softmax(QK^T + mask)V. This step is called attention compilation, and motivates our naming convention – CompNodes perform compilation, while AttnNodes focus on attention.
# Finish Stage 2 through linear weighting.
max_logits_Ri = elementwise_max(max_logits_Ri_R1, max_logits_Ri_R2, …)
for j in range(alpha):
    weight_Ri_Rj = exp(max_logits_Ri_Rj - max_logits_Ri) * expsum_logits_Ri_Rj
attn_output_Ri = weighted_average(
    vectors=[attn_output_Ri_R1, attn_output_Ri_R2, …],
    weights=[weight_Ri_R1, weight_Ri_R2, …]
)
# Rest of post-pass performed by CompNode_i.
hidden_Ri = hidden_Ri + o_proj(attn_output_Ri) # Had hidden_Ri from pre-pass.
hidden_Ri = mlp(hidden_Ri)
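To see why this linear weighting recovers the exact attention output, here is a small numerical check, illustrative only and with ad hoc names, that splits the key and value rows into two shards and compiles the partial results back into full softmax attention.

```python
import torch

torch.manual_seed(0)
N, d = 6, 4
q, k, v = torch.randn(N, d), torch.randn(N, d), torch.randn(N, d)

# Reference: ordinary softmax attention (single head, no mask or scaling).
ref = torch.softmax(q @ k.T, dim=-1) @ v

# Sharded: split key/value rows into two shards and compile the partial outputs.
shards = [slice(0, 3), slice(3, 6)]
partials, maxes, expsums = [], [], []
for s in shards:
    logits = q @ k[s].T                                  # logits against this shard only
    m = logits.max(dim=-1, keepdim=True).values          # row-wise max shard
    e = torch.exp(logits - m).sum(dim=-1, keepdim=True)  # row-wise expsum shard
    partials.append(torch.softmax(logits, dim=-1) @ v[s])
    maxes.append(m)
    expsums.append(e)

m_all = torch.maximum(maxes[0], maxes[1])                # global row-wise max
weights = [torch.exp(m - m_all) * e for m, e in zip(maxes, expsums)]
compiled = sum(w * p for w, p in zip(weights, partials)) / sum(weights)

print(torch.allclose(compiled, ref, atol=1e-6))          # True
```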
Our final workflow for Cascade is shown below. In each layer, the α CompNodes begin with R-sharded hidden states and perform the pre-pass. They send the relevant query, key, and value rows to the α² AttnNodes, who perform the attention-pass to get partial attention outputs and related shards. Finally, the AttnNodes send this information back to the CompNodes, who perform a linear weighting of the partial results, followed by the remaining per-token operations, to arrive at the R-sharded hidden states of the next layer. There are other details in the diagram, such as the β×β grid of AttnNodes, that we will explain in the later security considerations.
High-level representation of Cascade.
Do We Defend Against Our Attack?
In the beginning of the series, we mentioned our goal was an obfuscation scheme that defends against the vocab-matching attack, whilst retaining the exact results of normal inference. To make this notion well-defined, we need to explicitly adapt vocab-matching to the token-sharded setting. The attack we outlined in the previous blog post assumes that the adversary has access to the full hidden state matrix, so it does not work out-of-the-box for token-sharded hidden states.
Generalizing Our Attack to Token-Sharding
Suppose an adversary does not have access to all N rows of the hidden states h = [h_1, …, h_N], but only to k of them, say h_{i_1}, …, h_{i_k}, with indices i_1 < i_2 < ⋯ < i_k.
Our modification of the original attack again uses the autoregressive property to reduce the search space. It proceeds iteratively, where we initialize j = 0 and set i_0 = 0.
- Assume that at this point we have deciphered tokens 1, …, i_j. We iterate through all possible combinations of the tokens at positions i_j + 1, …, i_{j+1} and perform forward passes, until we find a hidden state matrix whose i_{j+1}-th row matches the observed h_{i_{j+1}}. This takes at most V^(i_{j+1} − i_j) forward passes.
- We set the tokens at positions i_j + 1, …, i_{j+1} to the combination that gave a match in Step 1. Then we increment j, repeat Step 1 if j < k, and terminate if j = k.
Note that this attack only allows recovery of tokens up to and including the last revealed index i_k, but not of the tokens at positions i_k + 1 onwards.
Below is an example of this generalized attack when the adversary has access to h1, h3, h6, and h8, meaning i_1 = 1, i_2 = 3, i_3 = 6, i_4 = 8. First, a search over candidate first tokens e1, matched against h1, recovers e1 in at most V^1 forward passes. Once e1 is known, a search over candidate second and third tokens e2, e3, matched against h3, recovers e2 and e3 in at most V^2 forward passes. Next, with e1, e2, e3 known, matching against h6 recovers e4, e5, e6. Finally, matching against h8 recovers e7 and e8. Recovery of e9 and e10 is not possible with this attack, although these tokens could sometimes be inferred from the previous tokens by other means, e.g. if they are punctuation.
Demonstration of generalized attack for token-sharded hidden states. Colors represent different iterations of the procedure, with the maximum forward pass cost highlighted underneath.
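To make the procedure concrete, below is a hypothetical sketch of the generalized attack. model_hidden and observed are assumed stand-ins for the target LLM's forward pass and the adversary's leaked rows, and the inner loop makes the V^gap cost of each iteration explicit.

```python
from itertools import product

import torch

def generalized_vocab_match(observed, leaked_indices, vocab_size, model_hidden, atol=1e-4):
    """Recover tokens up to the last leaked index; each gap costs up to V^gap forward passes.

    observed: dict mapping a leaked (1-indexed) row index to the leaked hidden state row.
    leaked_indices: sorted list of leaked row indices i_1 < i_2 < ... < i_k.
    model_hidden: callable returning the layer's hidden state matrix for a token prefix.
    """
    recovered = []                        # deciphered tokens so far (positions 1..i_j)
    prev_idx = 0
    for idx in leaked_indices:
        gap = idx - prev_idx              # number of unknown tokens to brute-force
        for candidate in product(range(vocab_size), repeat=gap):   # up to V^gap tries
            tokens = recovered + list(candidate)
            hidden = model_hidden(tokens)                 # forward pass on candidate prefix
            if torch.allclose(hidden[idx - 1], observed[idx], atol=atol):
                recovered = tokens                        # match found: lock in these tokens
                break
        prev_idx = idx
    return recovered
```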
Since we placed no restrictions on the choice of i1,…,ik, it seems at first glance that this attack works on any token-sharded hidden state matrix, and can decipher the full input sequence if ik=N. This would appear to immediately render Cascade insecure, as the CompNode with the last hidden state row could obtain the full input sequence.
What’s the catch here? The attack cost.
In our attack, we see that the jth iteration could require up to V^(i_{j+1} − i_j) forward passes in the worst case. For a typical LLM, the vocabulary size V is on the order of hundreds of thousands. In Gemma-2-2B-IT, with V = 256000, a gap of i_{j+1} − i_j = 5 entails ≈ 10^27 forward passes. Here, even if each forward pass took one nanosecond, the worst-case runtime would be over 30 billion years, more than double the current age of the universe!
In other words, this attack is infeasible when gaps are large enough.
What does “large enough” mean here? This is not clear-cut, as it can depend on the use case, the adversary's computational power, and many other security parameters. So, to formalize our security analysis, we introduce the vocab-matching threshold ρ, defined as the maximum value of t for which V^(t−1) forward passes can be performed by the adversary. In the context of the attack, this means that once a gap i_{j+1} − i_j is at least ρ, the attack times out at the jth iteration. This is the key condition that ensures Cascade is secure against vocab-matching.
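To get a feel for these numbers, the short script below (illustrative only) tabulates the worst-case cost of a single gap, assuming, generously for the adversary, one forward pass per nanosecond.

```python
V = 256_000                     # Gemma-2-2B-IT vocabulary size
SECONDS_PER_YEAR = 3.154e7

for gap in range(1, 7):
    passes = V ** gap           # worst-case forward passes to brute-force this gap
    years = passes * 1e-9 / SECONDS_PER_YEAR   # at an (optimistic) 1 ns per pass
    print(f"gap={gap}: ~{passes:.2e} passes, ~{years:.2e} years")
```

For a gap of 5, this gives roughly 10^27 passes and tens of billions of years, matching the estimate above.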
(c,δ)-Sharding
Although we have generalized our attack to arbitrary choices of token (row) indices i_1, …, i_k, it will be useful to focus on a particular sharding setup called (c,δ)-sharding. Formally, this is a form of R-sharding in which each Ri consists of clusters of c consecutive indices, with the starts of consecutive clusters in the same shard separated by δ. For instance, below we highlight the (3,9), (2,6), and (1,4) sharding schemes for N=10. The (2,6) case has R1 = {1,2,7,8} in green, R2 = {3,4,9,10} in blue, and R3 = {5,6} in orange.
Visual representation of (c,δ)-sharding, where boxes of the same color represent a single shard Ri.
The motivation for this kind of sharding is that it spreads gaps uniformly across the prompt. This will simplify our analysis of gap sizes, although many other similar schemes could be considered.
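For concreteness, here is a minimal sketch, not taken from our codebase, of how a (c,δ)-sharding assignment might be generated; it reproduces the (2,6) example above.

```python
def cd_sharding(n_rows, c, delta):
    """Assign 1-indexed rows to shards: clusters of c rows, same-shard clusters delta apart."""
    n_shards = delta // c                  # e.g. (2, 6)-sharding yields 3 shards
    shards = [[] for _ in range(n_shards)]
    for row in range(1, n_rows + 1):
        cluster = (row - 1) // c           # which cluster of c consecutive rows this is
        shards[cluster % n_shards].append(row)
    return shards

print(cd_sharding(10, 2, 6))
# [[1, 2, 7, 8], [3, 4, 9, 10], [5, 6]]  -- matches the (2, 6) example above
```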
Large Gaps Ensure CompNode Security
We can now explain why, when only considering leakage from R-shards of hidden states to CompNodes, (c,δ)-sharding with a large enough gap size is secure to vocab-matching.
For simplicity, consider the first CompNode, which holds hidden state rows at some layer with indices 1, 2, …, c, δ+1, δ+2, …, δ+c, and so on. It already knows tokens 1, 2, …, c from the Layer 0 token embeddings, so the next gap in its search is δ − c + 1, which requires on the order of V^(δ−c+1) forward passes. Recalling the vocab-matching threshold, this means δ − c + 1 ≥ ρ is likely to make the attack infeasible.
A concrete example is shown below, for the (2,6)-sharding setup and the first (green) CompNode. While the CompNode can recover e1 and e2 with only V^1 work each, the next gap of size 6 − 2 + 1 = 5 to h7 requires up to V^5 forward passes. If the vocab-matching threshold were set to some ρ ≤ 5, we would consider this infeasible.
Visual representation of attack failure for large token gaps in (c,δ)-sharding.
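The same reasoning can be packaged into a tiny helper (again illustrative; compnode_gaps is our own name) that lists the gaps a shard forces the attack to brute-force, ready to compare against ρ.

```python
def compnode_gaps(rows):
    """Gap sizes i_{j+1} - i_j that the generalized attack must brute-force for this shard."""
    gaps, prev = [], 0
    for r in sorted(rows):
        gaps.append(r - prev)
        prev = r
    return gaps

R1 = [1, 2, 7, 8]                     # first shard of the (2, 6)-sharding
print(compnode_gaps(R1))              # [1, 1, 5, 1]: the size-5 gap is the bottleneck
rho = 5                               # example vocab-matching threshold
print(max(compnode_gaps(R1)) >= rho)  # True: the attack is deemed infeasible here
```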
We finally note that this is not a comprehensive security analysis, since there are other important shards leaked to CompNodes (e.g. max and expsum shards, and partial attention outputs), which could potentially reveal different information than R-sharded hidden states. Still, it turns out that these additional shards do not allow vocab-matching if the gap size is again sufficiently large. For further details, see our paper.
m-Splitting Improves AttnNode Security
Now that we have explained when vocab-matching is infeasible for CompNodes, we turn to AttnNodes. In our scheme, AttnNodeij receives all query rows with indices in Ri and all key rows with indices in Rj. At first glance, assuming the shards Ri were chosen to ensure CompNode security, one might expect AttnNodes to be secure as well. However, this is not the case: AttnNodeij gets access to query and key rows spanning the indices Ri∪Rj, which, for i≠j, contains double the number of indices that any single CompNode accesses!
Concretely, we consider potential security risks at Layer 0, where the query, key, and value rows are linear projections of (normalized) token embeddings, and thus could reveal the same amount of information as their corresponding embeddings. Essentially, CompNodei has access to tokens in Ri and AttnNodeij has access to tokens in Ri∪Rj.
How do we prevent AttnNodes from accessing this many tokens of the prompt? We consider an approach that makes sharding at the AttnNode level more granular than R-sharding. Formally, we split each Ri into m subsets in an alternating fashion, forming a new partition {S1, …, Sβ} of the tokens {1, …, N} with β = mα. Now, AttnNodeij receives the query rows in Si and the key rows in Sj, so its token (row) access has been decreased by a factor of m. This approach is called m-splitting.
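Below is a minimal sketch of one reading of m-splitting (not the reference implementation): each shard's rows are dealt round-robin into m finer subsets, shrinking the portion of the prompt any single AttnNode sees.

```python
def m_split(shards, m):
    """Deal each shard's rows round-robin into m finer subsets (beta = m * alpha in total)."""
    fine = []
    for rows in shards:
        fine.extend([rows[offset::m] for offset in range(m)])
    return fine

R = [[1, 2, 7, 8], [3, 4, 9, 10], [5, 6]]   # the (2, 6)-sharding from earlier
print(m_split(R, 2))
# [[1, 7], [2, 8], [3, 9], [4, 10], [5], [6]]
```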
While our motivation for m-splitting came from reducing direct token access, the question remains: does it prevent vocab-matching? The answer is affirmative, by the same reasoning as our CompNode analysis. If m is large enough, there are still sufficiently large gaps between consecutive clusters in Si∪Sj, and as long as a gap is ≥ ρ, the attack times out.
What Do Learning-Based Attacks Reveal?
While we have enumerated security considerations from our attack, a comprehensive analysis must consider other forms of reversal on hidden states. Most existing attacks in the literature are learning-based, which we break into the cases of Layer 0 (textual) and later layers.
Layer 0
At Layer 0, CompNodes individually receive some tokens of the prompt. While we would expect little security risk from the revelation of a few scattered tokens, there are serious issues when the gaps between the revealed tokens are not too large, due to the risk of token infilling. To test this, we use ModernBERT-large to estimate the prior distribution over input tokens and thus infill the missing ones. Below, we see that ROUGE scores for reconstructed tokens are quite low for larger c (cluster size) and α (CompNode count), with scores below 0.1 for c, α ≥ 8 indicating nearly random reconstruction and good security.
ROUGE-L scores for Layer 0 token infilling in Gemma-2-2B-IT with ModernBERT-Large tend to decrease as c, the cluster size, and α, the CompNode count, increase.
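For intuition on what such an infilling adversary does, here is a hedged sketch using the Hugging Face fill-mask pipeline; the model id is our assumption, and any masked LM would illustrate the idea of placing a prior over hidden positions.

```python
from transformers import pipeline

# Model id below is an assumption; substitute any fill-mask checkpoint.
fill = pipeline("fill-mask", model="answerdotai/ModernBERT-large")

# A CompNode's view of a prompt with one position hidden by the sharding gap.
mask = fill.tokenizer.mask_token
print(fill(f"The quick brown fox {mask} over the lazy dog.", top_k=3))
```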
The analysis for AttnNodes is similar. Indeed, we remarked earlier that at Layer 0, AttnNodeij essentially has access to tokens with indices in Si∪Sj, so they essentially have the same information as a CompNode with shard Si∪Sj. Thus, for sufficiently large c,α, we ensure AttnNodes cannot perform vocab-matching, by referencing CompNode security.
Later Layers
To assess the capability of learning-based attacks that aim to infill at later layers, we follow the approaches of Wan et al. and Morris et al. We fine-tune Gemma-2-2B-IT on (c,δ)-masked inputs to CompNodes at Layer 1, with a bidirectional mask to match the infilling task, and then evaluate on Layer 1 representations. As we see below, the resulting ROUGE evaluation scores decrease as c and α increase, and c=8, α≥8 achieves scores of roughly 0.2 or below, indicating little reconstruction success by the learning-based attack.
| | α=4 | α=8 | α=12 |
|---|---|---|---|
| c=1 | 0.701 | 0.467 | 0.349 |
| c=4 | 0.427 | 0.290 | 0.230 |
| c=8 | 0.355 | 0.222 | 0.191 |
ROUGE-L scores for Layer 1 infilling in Gemma-2-2B-IT tend to decrease as c, the cluster size, and α, the CompNode count, increase.
For AttnNodes, the use of m-splitting gets ROUGE scores near those of CompNodes for c=8,α≥8. Following the same approach as above for c=8,α=8, but now using m-split (c,δ)-masked inputs to AttnNodes at Layer 1 for various values of m, we see below that m=4 achieves comparable reconstructibility to CompNodes.
| m | ROUGE-L |
|---|---|
| 2 | 0.3057 |
| 3 | 0.2643 |
| 4 | 0.2376 |
ROUGE-L scores for Layer 1 infilling in Gemma-2-2B-IT, using m-splitting, tend to decrease as m increases, with m=4 achieving comparable performance to the CompNode score.
Thus, measures like (c,δ)-sharding and m-splitting protect Cascade against existing learning-based attacks, for particular choices of those security parameters.
Cascade vs. SMPC: Scalability vs. Security
We have shown that Cascade with a large-gap (c,δ)-sharding setup defends well against learning-based attacks, as well as against our generalized vocab-matching attack. However, we emphasize that Cascade does not make any claims of provable security, as cryptographic schemes like SMPC do. We can only offer statistical evidence of security, such as the low ROUGE-L scores demonstrated above for existing attacks on hidden state shards.
Given the gap in security strength between Cascade and SMPC, why should we choose Cascade? The main benefits are practical deployment and scalability, as SMPC schemes are often infeasible for larger models. For instance, one recent state-of-the-art SMPC scheme, Puma, takes around 5 minutes for a full forward pass from Llama-2-7B on a 128-token prompt. This is far too slow to be used in any real-time inference service.
We compare Cascade runtimes for various values of α to existing SMPC schemes MPCFormer and Puma on Bert-Base and Bert-Large:
| Scheme | Bert-Base (s) | Bert-Large (s) |
|---|---|---|
| MPCFormer | 55.32 | 141.22 |
| Puma | 33.91 | 73.72 |
| Cascade (α=1) | 0.32 [0.31, 0.36] | 1.01 [0.97, 1.09] |
| Cascade (α=4) | 0.59 [0.51, 0.69] | 1.57 [1.44, 1.73] |
| Cascade (α=8) | 0.74 [0.62, 0.96] | 1.58 [1.27, 1.97] |
| Vanilla | 0.09 [0.08, 0.12] | 0.27 [0.20, 0.99] |
Total runtime means and 95% confidence intervals in seconds over 100 trials, for a single 128-token prompt forward pass on Bert-Base and Bert-Large, for MPCFormer, Puma, and various settings of Cascade. Higher α corresponds to increased node counts and security.
We run the above across machines on a WAN; for Cascade α=4, we use 6 machines, and for α=8, we use 18 machines. We see that Cascade is up to 100x faster than MPCFormer and Puma. We also note the sublinear growth in runtime as α increases, so that scaling to large numbers of nodes (which offers the most security) does not significantly compromise runtime.
We also measure communicated bytes for each of the methods; these are shown in the table below. Cascade requires several orders of magnitude fewer communicated bytes than existing SMPC methods, supporting operation even in poor bandwidth network conditions.
| Scheme | Bert-Base (GB) | Bert-Large (GB) |
|---|---|---|
| MPCFormer | 12.089 | 32.577 |
| Puma | 10.773 | 27.246 |
| Cascade (α=1) | 0.009 | 0.025 |
| Cascade (α=4) | 0.038 | 0.101 |
| Cascade (α=8) | 0.076 | 0.203 |
Total gigabytes (GB) communicated for a single forward pass on Bert-Base and Bert-Large, for MPCFormer, Puma, and Cascade with various settings of α. A prompt length of 128 is used.
Finally, we show the scalability of Cascade to larger model sizes below:
| Model Name | Model Size (Parameters) | Mean Runtime (s) |
|---|---|---|
| Bert-Base | 110M | 0.70 |
| Bert-Large | 335M | 1.33 |
| Llama-3.2-1B-Instruct | 1B | 2.67 |
| Llama-2-7B | 7B | 12.71 |
| Llama-2-13B | 13B | 22.72 |
Cascade with α=2 and no m-splits scales well to larger models. Runtimes here are averages over 100 trials for a 128-token prompt.
Cascade thus scales well to larger models like Llama-2-13B, and its runtime generally appears to grow sublinearly with model size. This is because, for all per-token operations in transformer blocks, Cascade's sharding setup adds no computational overhead relative to vanilla inference. All overhead comes from attention, and it is ultimately negligible compared to the heavy matrix multiplication costs already present in vanilla inference.
Thus, even though Cascade is not cryptographically secure, its efficiency and scalability give it immediate promise as a viable protocol for private large-scale LLM inference.
Summary
Our work on Cascade stems from the following fundamental question: how much information about input tokens do partial LLM hidden states leak? We motivated this inquiry with our work on the vocab-matching attack in the previous blog post, which showed that full hidden states, even when shuffled, leak the full input sequence. Our results show that, by choosing a sharding scheme which places sufficient gaps between revealed tokens, the resulting sharded hidden states reveal little to no information about the input, from the perspective of both vocab-matching and existing attacks in the literature. As Cascade reveals such sharded states to nodes only in isolation, this provides strong statistical evidence for its security. While Cascade cannot offer the same strict privacy guarantees as SMPC schemes, it is a highly efficient protocol that offers the potential for faster private LLM inference on larger hosted models.
To cite this blog post, please use:
@misc{cascade-token-sharded-inference,
  title={Cascade: Token-Sharded Private LLM Inference},
  author={Rahul Thomas and Louai Zahran and Erica Choi and Micah Goldblum and Arka Pal},
  year={2025},
  howpublished={\url{ritual.net/blog/cascade}}
}
Disclaimer: This post is for general information purposes only. It does not constitute investment advice or a recommendation, offer or solicitation to buy or sell any investment and should not be used in the evaluation of the merits of making any investment decision. It should not be relied upon for accounting, legal or tax advice or investment recommendations. The information in this post should not be construed as a promise or guarantee in connection with the release or development of any future products, services or digital assets. This post reflects the current opinions of the authors and is not made on behalf of Ritual or its affiliates and does not necessarily reflect the opinions of Ritual, its affiliates or individuals associated with Ritual. All information in this post is provided without any representation or warranty of any kind. The opinions reflected herein are subject to change without being updated.