-
Notifications
You must be signed in to change notification settings - Fork 481
Enable Indexer cache for DS v3.2 decoding #3195
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -38,6 +38,11 @@ | |
| AxisNames, | ||
| BATCH, | ||
| BATCH_NO_EXP, | ||
| CACHE_BATCH, | ||
| CACHE_BATCH_PREFILL, | ||
| CACHE_SEQUENCE, | ||
| CACHE_HEADS_NONE, | ||
| CACHE_KV, | ||
| Config, | ||
| DECODE_BATCH, | ||
| DECODE_LENGTH, | ||
|
|
@@ -75,6 +80,9 @@ | |
| from maxtext.utils.sharding import create_sharding | ||
|
|
||
|
|
||
| PLACEHOLDER_SEQ_LEN = 1 | ||
|
|
||
|
|
||
| class Indexer(nnx.Module): | ||
| """Indexer for DeepSeek Sparse Attention (DSA). | ||
|
|
||
|
|
@@ -108,6 +116,7 @@ def __init__( | |
| self.rngs = rngs | ||
| self.dtype = config.dtype | ||
| self.weight_dtype = config.weight_dtype | ||
| self.max_target_length = config.max_target_length | ||
|
|
||
| self.n_heads = config.index_n_heads | ||
| self.head_dim = config.index_head_dim | ||
|
|
@@ -167,6 +176,31 @@ def __init__( | |
| rngs=self.rngs, | ||
| ) | ||
|
|
||
| def update_indexer_cache(self, kv_cache, k, decoder_segment_ids, model_mode, previous_chunk): | ||
| """Updates Indexer buffers by processing KV cache results.""" | ||
| k_expanded = k[:, :, jnp.newaxis, :] | ||
| p_res, a_res = kv_cache( | ||
| key=k_expanded, | ||
| value=k_expanded, | ||
| decoder_segment_ids=decoder_segment_ids, | ||
| model_mode=model_mode, | ||
| use_ragged_attention=self.config.use_ragged_attention, | ||
| previous_chunk=previous_chunk, | ||
| ) | ||
|
|
||
| # Filter out None values to handle PREFILL vs AR modes uniformly | ||
| active_results = [res for res in [p_res, a_res] if res is not None] | ||
|
|
||
| if not active_results: | ||
| return None, None | ||
|
|
||
| # Extract keys (index 0) and segment IDs (index 2) | ||
| keys = jnp.concatenate([res[0] for res in active_results], axis=1) | ||
| segs = jnp.concatenate([res[2] for res in active_results], axis=1) | ||
|
|
||
| # squeeze(2) removes the jnp.newaxis added above | ||
| return keys.squeeze(2), segs | ||
|
|
||
| def apply_partial_rope( | ||
| self, | ||
| inputs: Array, | ||
|
|
@@ -220,6 +254,10 @@ def __call__( | |
| inputs_kv: Array, | ||
| inputs_positions: Optional[Array | None] = None, | ||
| attention_mask: Optional[Array | None] = None, | ||
| decoder_segment_ids: Optional[Array | None] = None, | ||
| previous_chunk: Any = None, | ||
| kv_cache: Any = None, | ||
| model_mode: str = MODEL_MODE_TRAIN, | ||
| ): | ||
| """Computes the index score to determine the top-k relevant tokens. | ||
|
|
||
|
|
@@ -244,6 +282,10 @@ def __call__( | |
| `DEFAULT_MASK_VALUE` (a large negative number) prevent it. | ||
| Returns `None` if no masking is determined to be necessary based on | ||
| the inputs and configuration. | ||
| decoder_segment_ids: Segment IDs for decoder masking. | ||
| previous_chunk: Previous chunk info for prefill. | ||
| kv_cache: Key-value cache used when serving models. | ||
| model_mode: "train", "prefill", or "autoregressive". | ||
|
|
||
| Returns: | ||
| index_mask: A sparse mask [b, t, s] with 0.0 for top-k selected tokens | ||
|
|
@@ -258,10 +300,6 @@ def __call__( | |
| h: Number of Indexer Heads (index_n_heads) | ||
| d: Indexer Head Dimension (index_head_dim) | ||
| """ | ||
| # NOTE: If sequence length <= topk, indexer always selects all tokens. | ||
| if self.config.max_target_length <= self.index_topk: | ||
| return None, None, None | ||
|
|
||
| bsz, seqlen, _ = inputs_q.shape # s = t = seqlen | ||
|
|
||
| # Query Processing: Project from Latent low_rank_q | ||
|
|
@@ -276,6 +314,16 @@ def __call__( | |
| k = self.apply_partial_rope(k, inputs_positions=inputs_positions) | ||
| k = k.squeeze(2) # [b, s, 1, d] -> [b, s, d] | ||
|
|
||
| # Update and retrieve from cache if not training | ||
| cached_s = None | ||
| if model_mode != MODEL_MODE_TRAIN: | ||
| k_cached, cached_s = self.update_indexer_cache(kv_cache, k, decoder_segment_ids, model_mode, previous_chunk) | ||
| k = k_cached if k_cached is not None else k | ||
|
|
||
| # NOTE: If the total available sequence length <= topk, indexer always selects all tokens. | ||
| if k.shape[1] <= self.index_topk: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. For this check, doesn't k.shape[1] always equal max_prefill_predict_length? Since JAX uses static shapes, unlike PyTorch, this condition would evaluate to False whenever max_prefill_predict_length > the top_k value. However, the ideal behavior would be that, as decoding runs, the condition checks against the current token count, which increases by 1 every decode step. I asked Gemini to generate JAX code for this check, and it gave this:
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We could focus on functionality now and optimize as a follow-up. Currently, the logic seems correct to me; e.g., when I tried it, the log shows: |
||
| return None, None, None | ||
|
|
||
| # Compute Index Scores | ||
| # QK product: relu(q @ k.T), [b, t, s, h] | ||
| # Similar to MQA, each key is shared by h query head | ||
|
|
@@ -289,6 +337,12 @@ def __call__( | |
| # Aggregate head-wise logits: logits @ weights | ||
| index_score = jnp.einsum("btsh, bth -> bts", logits, weights, precision=self.config.matmul_precision) # [b, t, s] | ||
|
|
||
| internal_padding_mask = None | ||
| if cached_s is not None: | ||
| # cached_s marks valid tokens from the original prefill step and all subsequent AR steps | ||
| internal_padding_mask = jnp.where(cached_s > 0, 0.0, DEFAULT_MASK_VALUE) | ||
| index_score += internal_padding_mask[:, None, :] | ||
|
|
||
| # Apply attention mask before TopK | ||
| if attention_mask is not None: | ||
| index_score += attention_mask | ||
|
|
@@ -297,12 +351,15 @@ def __call__( | |
| _, topk_indices = jax.lax.top_k(index_score, k=self.index_topk) # topk_indices [b, t, k] | ||
|
|
||
| # Create Sparse Index Mask: 0 and large negatives | ||
| index_mask = self.generate_mask(topk_indices, seqlen) # [b, t, s] | ||
| index_mask = self.generate_mask(topk_indices, k.shape[1]) # [b, t, s] | ||
|
|
||
| # Re-apply attention mask after TopK: in case number of unmasked tokens < TopK | ||
| if attention_mask is not None: | ||
| index_mask += attention_mask | ||
|
|
||
| if internal_padding_mask is not None: | ||
| index_mask += internal_padding_mask[:, None, :] | ||
|
|
||
| return index_mask, topk_indices, index_score | ||
|
|
||
|
|
||
|
|
@@ -615,16 +672,47 @@ def __init__( | |
| indexer_rope.interleave = False | ||
| self.indexer = Indexer( | ||
| config, | ||
| rngs=rngs, | ||
| rotary_embedding=indexer_rope, | ||
| kernel_init=kernel_init, | ||
| quant=quant, | ||
| model_mode=model_mode, | ||
| rngs=rngs, | ||
| ) | ||
| self.IndexerKVCache_0 = self.init_indexer_cache(inputs_kv_shape) if model_mode != MODEL_MODE_TRAIN else None | ||
| else: | ||
| self.indexer = None | ||
| self.IndexerKVCache_0 = None | ||
|
|
||
| # Module attribute names must match names previously passed to Linen for checkpointing | ||
| self.MlaKVCache_0 = self.init_mla_kv_caches(inputs_kv_shape) if model_mode != MODEL_MODE_TRAIN else None | ||
|
|
||
| def init_indexer_cache(self, inputs_kv_shape: Tuple): | ||
| """Initializes Indexer Cache.""" | ||
| batch_size, _, _ = inputs_kv_shape | ||
| # Use standard KVCache to store keys. Values are unused but required by KVCache API. | ||
| # KVCache expects key_heads and value_heads. Since k is shared (MQA-like for Indexer), | ||
| # we use key_heads=1, value_heads=1. | ||
| return kvcache.KVCache( | ||
|
Comment on lines
+692
to
+695
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looking at We should add a comment that paged attention is not supported with Sparse Attention. |
||
| max_prefill_length=self.max_prefill_predict_length, | ||
| max_target_length=self.max_target_length, | ||
| batch=batch_size, | ||
| key_seq_len=PLACEHOLDER_SEQ_LEN, | ||
| value_seq_len=PLACEHOLDER_SEQ_LEN, | ||
| key_heads=1, | ||
| value_heads=1, | ||
| key_head_size=self.config.index_head_dim, | ||
| value_head_size=self.config.index_head_dim, | ||
| dtype=self.dtype, | ||
| kv_quant=None, # Quantization is not yet supported by the indexer. | ||
| prefill_cache_logical_axis_names=(CACHE_BATCH_PREFILL, CACHE_SEQUENCE, CACHE_HEADS_NONE, CACHE_KV), | ||
| cache_logical_axis_names=(CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS_NONE, CACHE_KV), | ||
| prefill_cache_axis_order=(1, 2, 0, 3), | ||
| ar_cache_axis_order=(1, 2, 0, 3), | ||
| use_chunked_prefill=self.config.use_chunked_prefill, | ||
| model_mode=self.model_mode, | ||
| rngs=self.rngs, | ||
| ) | ||
|
|
||
| def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None: | ||
| """Initializes the MLA-specific projections.""" | ||
| # Assert required configuration parameters for MLA attention. | ||
|
|
@@ -856,14 +944,13 @@ def init_mla_kv_caches(self, inputs_kv_shape: Tuple): | |
| # and max_target_length, not the passed seq_len. | ||
| # We can use a placeholder value. The correct fix might involve refactoring | ||
| # MlaKVCache. | ||
| placeholder_seq_len = 1 | ||
|
|
||
| return kvcache.MlaKVCache( | ||
| max_prefill_length=self.max_prefill_predict_length, | ||
| max_target_length=self.max_target_length, | ||
| batch=batch_size, | ||
| key_seq_len=placeholder_seq_len, | ||
| value_seq_len=placeholder_seq_len, | ||
| key_seq_len=PLACEHOLDER_SEQ_LEN, | ||
| value_seq_len=PLACEHOLDER_SEQ_LEN, | ||
| key_head_size=self.kv_lora_rank, | ||
| value_head_size=self.qk_rope_head_dim, | ||
| dtype=self.dtype, | ||
|
|
@@ -1002,6 +1089,9 @@ def __call__( | |
| inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.input_axis_names) | ||
| out_logical_name = (BATCH, LENGTH_NO_EXP, HEAD, D_KV) | ||
|
|
||
| if model_mode != MODEL_MODE_TRAIN and decoder_segment_ids is None: | ||
| decoder_segment_ids = jnp.ones(inputs_q.shape[:2], dtype=jnp.int32) | ||
|
|
||
| query, low_rank_q = self.mla_query_projection(inputs_q, inputs_positions, model_mode) | ||
| if self.config.force_q_layout: | ||
| query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1))) | ||
|
|
@@ -1015,8 +1105,6 @@ def __call__( | |
| # Indexer Logic | ||
| index_mask = None | ||
| if self.use_sparse_indexer: | ||
| if model_mode != MODEL_MODE_TRAIN: | ||
| raise NotImplementedError("Sparse indexer has not implemented for inference yet.") | ||
| # generate mask: with 0 and large negative, [b, 1, 1, q_len, kv_len] -> [b, q_len, kv_len] | ||
| attention_mask = self.attention_op.generate_attention_mask( | ||
| query, key, decoder_segment_ids, model_mode, previous_chunk, bidirectional_mask | ||
|
|
@@ -1028,6 +1116,10 @@ def __call__( | |
| inputs_kv=inputs_kv, | ||
| inputs_positions=inputs_positions, | ||
| attention_mask=attention_mask, | ||
| decoder_segment_ids=decoder_segment_ids, | ||
| previous_chunk=previous_chunk, | ||
| kv_cache=self.IndexerKVCache_0, | ||
| model_mode=model_mode, | ||
| ) | ||
|
|
||
| # Check if we need QK Clip stats | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC, a property of the indexer is that it only requires storage for the keys.
However, we are initializing the indexer cache via KVCache, which stores keys and values. Passing
k_expanded as the value doubles the HBM footprint. Is it possible to pass in a dummy value tensor, maybe of shape [B, S, 1, 1], for the value? Moreover, to accommodate the usage of KVCache for the indexer class, we artificially unsqueeze
k to 4D to satisfy the generic API, only to squeeze it back to 3D upon retrieval. This operation is redundant but would still work functionally. Maybe a comment we can put here as a TODO: if a dedicated SparseKVCache class were written explicitly for the Indexer, we could drop the head dimension entirely and simplify the computation graph.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could focus on functionality now, and optimize as a followup, either via maxtext or vLLM.
Similar KV cache optimization: In the other PR which set value = key, we could similarly reduce KV cache to K cache only. Their solution is to defer the decoding optimization to vLLM: #3283 (comment)