Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 103 additions & 11 deletions src/maxtext/layers/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@
AxisNames,
BATCH,
BATCH_NO_EXP,
CACHE_BATCH,
CACHE_BATCH_PREFILL,
CACHE_SEQUENCE,
CACHE_HEADS_NONE,
CACHE_KV,
Config,
DECODE_BATCH,
DECODE_LENGTH,
Expand Down Expand Up @@ -75,6 +80,9 @@
from maxtext.utils.sharding import create_sharding


PLACEHOLDER_SEQ_LEN = 1


class Indexer(nnx.Module):
"""Indexer for DeepSeek Sparse Attention (DSA).

Expand Down Expand Up @@ -108,6 +116,7 @@ def __init__(
self.rngs = rngs
self.dtype = config.dtype
self.weight_dtype = config.weight_dtype
self.max_target_length = config.max_target_length

self.n_heads = config.index_n_heads
self.head_dim = config.index_head_dim
Expand Down Expand Up @@ -167,6 +176,31 @@ def __init__(
rngs=self.rngs,
)

def update_indexer_cache(self, kv_cache, k, decoder_segment_ids, model_mode, previous_chunk):
  """Updates Indexer buffers by processing KV cache results.

  The indexer only needs to store keys, but the generic KVCache API stores
  key/value pairs, so ``k`` is passed as both key and value.
  TODO: a dedicated sparse K-only cache class for the Indexer could drop the
  duplicated value storage (halving the HBM footprint) and the head dimension
  entirely. ``k`` is temporarily expanded to 4D to satisfy the generic
  [batch, seq, heads, head_dim] cache layout and squeezed back on return.

  Args:
    kv_cache: KVCache module used to store and retrieve indexer keys.
    k: Indexer keys of shape [b, s, d].
    decoder_segment_ids: Segment IDs for decoder masking.
    model_mode: "prefill" or "autoregressive".
    previous_chunk: Previous chunk info for chunked prefill.

  Returns:
    A tuple of (cached keys [b, s_cache, d], cached segment IDs [b, s_cache]),
    or (None, None) if the cache produced no results.
  """
  # Add a singleton head axis: [b, s, d] -> [b, s, 1, d].
  k_expanded = k[:, :, jnp.newaxis, :]
  p_res, a_res = kv_cache(
      key=k_expanded,
      value=k_expanded,
      decoder_segment_ids=decoder_segment_ids,
      model_mode=model_mode,
      use_ragged_attention=self.config.use_ragged_attention,
      previous_chunk=previous_chunk,
  )

  # Filter out None values to handle PREFILL vs AR modes uniformly
  active_results = [res for res in [p_res, a_res] if res is not None]

  if not active_results:
    return None, None

  # Extract keys (index 0) and segment IDs (index 2)
  keys = jnp.concatenate([res[0] for res in active_results], axis=1)
  segs = jnp.concatenate([res[2] for res in active_results], axis=1)

  # squeeze(2) removes the jnp.newaxis added above
  return keys.squeeze(2), segs

def apply_partial_rope(
self,
inputs: Array,
Expand Down Expand Up @@ -220,6 +254,10 @@ def __call__(
inputs_kv: Array,
inputs_positions: Optional[Array | None] = None,
attention_mask: Optional[Array | None] = None,
decoder_segment_ids: Optional[Array | None] = None,
previous_chunk: Any = None,
kv_cache: Any = None,
model_mode: str = MODEL_MODE_TRAIN,
):
"""Computes the index score to determine the top-k relevant tokens.

Expand All @@ -244,6 +282,10 @@ def __call__(
`DEFAULT_MASK_VALUE` (a large negative number) prevent it.
Returns `None` if no masking is determined to be necessary based on
the inputs and configuration.
decoder_segment_ids: Segment IDs for decoder masking.
previous_chunk: Previous chunk info for prefill.
kv_cache: Key-value cache used when serving models.
model_mode: "train", "prefill", or "autoregressive".

Returns:
index_mask: A sparse mask [b, t, s] with 0.0 for top-k selected tokens
Expand All @@ -258,10 +300,6 @@ def __call__(
h: Number of Indexer Heads (index_n_heads)
d: Indexer Head Dimension (index_head_dim)
"""
# NOTE: If sequence length <= topk, indexer always selects all tokens.
if self.config.max_target_length <= self.index_topk:
return None, None, None

bsz, seqlen, _ = inputs_q.shape # s = t = seqlen

# Query Processing: Project from Latent low_rank_q
Expand All @@ -276,6 +314,16 @@ def __call__(
k = self.apply_partial_rope(k, inputs_positions=inputs_positions)
k = k.squeeze(2) # [b, s, 1, d] -> [b, s, d]

# Update and retrieve from cache if not training
cached_s = None
if model_mode != MODEL_MODE_TRAIN:
k_cached, cached_s = self.update_indexer_cache(kv_cache, k, decoder_segment_ids, model_mode, previous_chunk)
k = k_cached if k_cached is not None else k

# NOTE: If the total available sequence length <= topk, indexer always selects all tokens.
if k.shape[1] <= self.index_topk:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this check, doesn't k.shape[1] always equal max_prefill_predict_length? Since JAX uses static shapes unlike PyTorch, this condition would evaluate to False anytime max_prefill_predict_length > the top_k value.

However, the ideal behavior would be as decoding runs, the condition should check against the current token number that increases by 1 every decode step.

I asked gemini to generate jax code for this check and gave this:

# 1. Dynamically count active tokens
# cached_s has > 0 for valid tokens, 0 for padding
active_token_count = jnp.sum(cached_s > 0, axis=-1) 

# 2. Define the heavy compute path
def compute_sparse_mask():
    logits = jnp.einsum("bthd, bsd -> btsh", q, k)
    # ... (rest of the heavy einsum and top_k math) ...
    return index_mask, topk_indices, index_score

# 3. Define the bypass path (Dense Attention)
def bypass_sparse_mask():
    # If we have fewer tokens than top-k, we just want standard dense attention.
    # Dense attention is achieved by simply returning the padding mask!
    index_mask = jnp.where(cached_s > 0, 0.0, DEFAULT_MASK_VALUE)
    
    # Return dummy tensors for the unused outputs to keep JAX happy
    dummy_topk = jnp.zeros((bsz, seqlen, self.index_topk), dtype=jnp.int32)
    dummy_score = jnp.zeros((bsz, seqlen, k.shape[1]), dtype=self.dtype)
    return index_mask, dummy_topk, dummy_score

# 4. Use jax.lax.cond to execute only ONE of the paths
index_mask, topk_indices, index_score = jax.lax.cond(
    (active_token_count > self.index_topk).any(), # Condition
    compute_sparse_mask,                          # Runs if True
    bypass_sparse_mask                            # Runs if False
)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could focus on functionality now, and optimize as a followup. Currently, the logic seems correct to me

train, k.shape[1] = max_target_length
prefill, k.shape[1] = max_prefill_length
autoregressive, k.shape[1] = max_target_length

e.g., I tried

python3 -m maxtext.inference.decode src/maxtext/configs/base.yml \
model_name=deepseek3.2-671b \
override_model_config=true base_num_decoder_layers=4 first_num_dense_layers=2 \
attention=dot_product scan_layers=false sparse_matmul=false \
dtype=bfloat16 weight_dtype=bfloat16 \
per_device_batch_size=1 max_prefill_predict_length=512 max_target_length=1024 mla_naive_kvcache=false 'prompt=An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and outputs are all vectors. The output is '

log shows train: 1024, prefill: 512, autoregressive: 1024

return None, None, None

# Compute Index Scores
# QK product: relu(q @ k.T), [b, t, s, h]
# Similar to MQA, each key is shared by h query head
Expand All @@ -289,6 +337,12 @@ def __call__(
# Aggregate head-wise logits: logits @ weights
index_score = jnp.einsum("btsh, bth -> bts", logits, weights, precision=self.config.matmul_precision) # [b, t, s]

internal_padding_mask = None
if cached_s is not None:
# cached_s marks valid tokens from the original prefill step and all subsequent AR steps
internal_padding_mask = jnp.where(cached_s > 0, 0.0, DEFAULT_MASK_VALUE)
index_score += internal_padding_mask[:, None, :]

# Apply attention mask before TopK
if attention_mask is not None:
index_score += attention_mask
Expand All @@ -297,12 +351,15 @@ def __call__(
_, topk_indices = jax.lax.top_k(index_score, k=self.index_topk) # topk_indices [b, t, k]

# Create Sparse Index Mask: 0 and large negatives
index_mask = self.generate_mask(topk_indices, seqlen) # [b, t, s]
index_mask = self.generate_mask(topk_indices, k.shape[1]) # [b, t, s]

# Re-apply attention mask after TopK: in case number of unmasked tokens < TopK
if attention_mask is not None:
index_mask += attention_mask

if internal_padding_mask is not None:
index_mask += internal_padding_mask[:, None, :]

return index_mask, topk_indices, index_score


Expand Down Expand Up @@ -615,16 +672,47 @@ def __init__(
indexer_rope.interleave = False
self.indexer = Indexer(
config,
rngs=rngs,
rotary_embedding=indexer_rope,
kernel_init=kernel_init,
quant=quant,
model_mode=model_mode,
rngs=rngs,
)
self.IndexerKVCache_0 = self.init_indexer_cache(inputs_kv_shape) if model_mode != MODEL_MODE_TRAIN else None
else:
self.indexer = None
self.IndexerKVCache_0 = None

# Module attribute names must match names previously passed to Linen for checkpointing
self.MlaKVCache_0 = self.init_mla_kv_caches(inputs_kv_shape) if model_mode != MODEL_MODE_TRAIN else None

def init_indexer_cache(self, inputs_kv_shape: Tuple):
  """Initializes the Indexer KV cache.

  Uses the standard (dense) KVCache to store indexer keys. Values are unused
  by the indexer but required by the KVCache API, so the value configuration
  mirrors the key configuration. Since the indexer key is shared across all
  query heads (MQA-like), key_heads=value_heads=1.

  NOTE(review): this relies on the dense ``kvcache.KVCache`` and is therefore
  not compatible with paged attention.

  Args:
    inputs_kv_shape: Shape of the KV inputs; only the batch dimension is used.

  Returns:
    A configured KVCache module for indexer keys.
  """
  batch_size, _, _ = inputs_kv_shape
  return kvcache.KVCache(
      max_prefill_length=self.max_prefill_predict_length,
      max_target_length=self.max_target_length,
      batch=batch_size,
      # Sequence lengths are derived internally from max_prefill_length /
      # max_target_length; the passed values are placeholders.
      key_seq_len=PLACEHOLDER_SEQ_LEN,
      value_seq_len=PLACEHOLDER_SEQ_LEN,
      key_heads=1,
      value_heads=1,
      key_head_size=self.config.index_head_dim,
      value_head_size=self.config.index_head_dim,
      dtype=self.dtype,
      kv_quant=None,  # Quantization is not yet supported by the indexer.
      prefill_cache_logical_axis_names=(CACHE_BATCH_PREFILL, CACHE_SEQUENCE, CACHE_HEADS_NONE, CACHE_KV),
      cache_logical_axis_names=(CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS_NONE, CACHE_KV),
      prefill_cache_axis_order=(1, 2, 0, 3),
      ar_cache_axis_order=(1, 2, 0, 3),
      use_chunked_prefill=self.config.use_chunked_prefill,
      model_mode=self.model_mode,
      rngs=self.rngs,
  )

def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None:
"""Initializes the MLA-specific projections."""
# Assert required configuration parameters for MLA attention.
Expand Down Expand Up @@ -856,14 +944,13 @@ def init_mla_kv_caches(self, inputs_kv_shape: Tuple):
# and max_target_length, not the passed seq_len.
# We can use a placeholder value. The correct fix might involve refactoring
# MlaKVCache.
placeholder_seq_len = 1

return kvcache.MlaKVCache(
max_prefill_length=self.max_prefill_predict_length,
max_target_length=self.max_target_length,
batch=batch_size,
key_seq_len=placeholder_seq_len,
value_seq_len=placeholder_seq_len,
key_seq_len=PLACEHOLDER_SEQ_LEN,
value_seq_len=PLACEHOLDER_SEQ_LEN,
key_head_size=self.kv_lora_rank,
value_head_size=self.qk_rope_head_dim,
dtype=self.dtype,
Expand Down Expand Up @@ -1002,6 +1089,9 @@ def __call__(
inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.input_axis_names)
out_logical_name = (BATCH, LENGTH_NO_EXP, HEAD, D_KV)

if model_mode != MODEL_MODE_TRAIN and decoder_segment_ids is None:
decoder_segment_ids = jnp.ones(inputs_q.shape[:2], dtype=jnp.int32)

query, low_rank_q = self.mla_query_projection(inputs_q, inputs_positions, model_mode)
if self.config.force_q_layout:
query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1)))
Expand All @@ -1015,8 +1105,6 @@ def __call__(
# Indexer Logic
index_mask = None
if self.use_sparse_indexer:
if model_mode != MODEL_MODE_TRAIN:
raise NotImplementedError("Sparse indexer has not implemented for inference yet.")
# generate mask: with 0 and large negative, [b, 1, 1, q_len, kv_len] -> [b, q_len, kv_len]
attention_mask = self.attention_op.generate_attention_mask(
query, key, decoder_segment_ids, model_mode, previous_chunk, bidirectional_mask
Expand All @@ -1028,6 +1116,10 @@ def __call__(
inputs_kv=inputs_kv,
inputs_positions=inputs_positions,
attention_mask=attention_mask,
decoder_segment_ids=decoder_segment_ids,
previous_chunk=previous_chunk,
kv_cache=self.IndexerKVCache_0,
model_mode=model_mode,
)

# Check if we need QK Clip stats
Expand Down
Loading
Loading