From 1714cfc84a54b75d72ffc2a138693e61d5d160a9 Mon Sep 17 00:00:00 2001 From: aireenmei Date: Tue, 9 Jun 2026 23:57:34 +0000 Subject: [PATCH] Add local_sa_* flags for local attention tuning --- src/maxtext/configs/base.yml | 12 ++++ src/maxtext/configs/types.py | 14 +++- src/maxtext/layers/attention_op.py | 102 ++++++++++++++--------------- 3 files changed, 75 insertions(+), 53 deletions(-) diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index f62c1d1997..b4c5b25495 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -1057,15 +1057,27 @@ sa_block_kv_dkv: 512 sa_block_kv_dkv_compute: 512 sa_block_q_dq: 512 sa_block_kv_dq: 512 +local_sa_block_q: 512 +local_sa_block_kv: 512 +local_sa_block_kv_compute: 512 +local_sa_block_q_dkv: 512 +local_sa_block_kv_dkv: 512 +local_sa_block_kv_dkv_compute: 512 +local_sa_block_q_dq: 512 +local_sa_block_kv_dq: 512 sa_use_fused_bwd_kernel: false sa_q_layout: "HEAD_DIM_MINOR" sa_k_layout: "HEAD_DIM_MINOR" sa_v_layout: "HEAD_DIM_MINOR" +local_sa_q_layout: "HEAD_DIM_MINOR" +local_sa_k_layout: "HEAD_DIM_MINOR" +local_sa_v_layout: "HEAD_DIM_MINOR" use_max_logit_estimate: -1 # -1 means no estimate, any > 0 value will be used as max logit estimate cost_estimate_flops_fwd: -1 # -1 means using splash default cost estmiation, any >= 0 value will be used as cost estmiation for splash to overlap for communication (forward) cost_estimate_flops_bwd: -1 # -1 means using splash default cost estmiation, any >= 0 value will be used as cost estmiation for splash to overlap for communication (backward) dq_reduction_steps: 0 #the number of reduction steps. For now, only 3 or all the kv steps are supported. use_splash_scheduler: false # to use tokamax splash attention scheduler. +local_use_splash_scheduler: false # to use experimental local splash attention scheduler. ### Determine if we want to use load balance for context parallelism context_parallel_load_balance: true context_parallel_strategy: "all_gather" # "all_gather" or "ring" diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index b26e30ac76..3ce7127bb7 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -674,10 +674,21 @@ class SplashAttention(BaseModel): sa_block_kv_dkv_compute: int = Field(512, description="Block size for KV_dkv compute in splash attention.") sa_block_q_dq: int = Field(512, description="Block size for Q_dq in splash attention.") sa_block_kv_dq: int = Field(512, description="Block size for KV_dq in splash attention.") + local_sa_block_q: int = Field(512, description="Block size for Q in local splash attention.") + local_sa_block_kv: int = Field(512, description="Block size for KV in local splash attention.") + local_sa_block_kv_compute: int = Field(512, description="Block size for KV compute in local splash attention.") + local_sa_block_q_dkv: int = Field(512, description="Block size for Q_dkv in local splash attention.") + local_sa_block_kv_dkv: int = Field(512, description="Block size for KV_dkv in local splash attention.") + local_sa_block_kv_dkv_compute: int = Field(512, description="Block size for KV_dkv compute in local splash attention.") + local_sa_block_q_dq: int = Field(512, description="Block size for Q_dq in local splash attention.") + local_sa_block_kv_dq: int = Field(512, description="Block size for KV_dq in local splash attention.") sa_use_fused_bwd_kernel: bool = Field(False, description="Use fused backward kernel in splash attention.") sa_q_layout: str = Field("HEAD_DIM_MINOR", description="Layout for Q in splash attention.") sa_k_layout: str = Field("HEAD_DIM_MINOR", description="Layout for K in splash attention.") sa_v_layout: str = Field("HEAD_DIM_MINOR", description="Layout for V in splash attention.") + local_sa_q_layout: str = Field("HEAD_DIM_MINOR", description="Layout for Q in local splash attention.") + local_sa_k_layout: str = Field("HEAD_DIM_MINOR", description="Layout for K in local splash attention.") + local_sa_v_layout: str = Field("HEAD_DIM_MINOR", description="Layout for V in local splash attention.") use_max_logit_estimate: int = Field( -1, description="-1 means no estimate, any > 0 value will be used as max logit estimate", @@ -697,6 +708,7 @@ class SplashAttention(BaseModel): description="the number of reduction steps. For now, only 3 or all the kv steps are supported.", ) use_splash_scheduler: bool = Field(False, description="Use experimental splash attention scheduler.") + local_use_splash_scheduler: bool = Field(False, description="Use experimental local splash attention scheduler.") class MoEGeneral(BaseModel): @@ -2441,7 +2453,7 @@ def set_derived_and_validate_values(self) -> "MaxTextConfig": # If the tokenizer path is a relative name without a directory, resolve it against the assets root. # This maintains backward compatibility for configs that just specify e.g., "tokenizer.llama2". tokenizer_path = getattr(self, "tokenizer_path", "") - if tokenizer_path and not os.path.exists(tokenizer_path) and not tokenizer_path.startswith("gs://"): + if tokenizer_path and not os.path.exists(tokenizer_path) and not tokenizer_path.startswith("gs://") and not self.colocated_python_data_input: tokenizer_path = next( filter( os.path.exists, diff --git a/src/maxtext/layers/attention_op.py b/src/maxtext/layers/attention_op.py index 2f46179dd9..59696d2d5d 100644 --- a/src/maxtext/layers/attention_op.py +++ b/src/maxtext/layers/attention_op.py @@ -79,19 +79,6 @@ # pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes # pytype: disable=attribute-error -# Used to pass in splash attention block sizes from config. -global_block_q = 0 -global_block_kv = 0 -global_block_kv_compute = 0 -global_block_q_dkv = 0 -global_block_kv_dkv = 0 -global_block_kv_dkv_compute = 0 -global_block_q_dq = 0 -global_block_kv_dq = 0 -global_use_fused_bwd_kernel = False -global_q_layout = "" -global_k_layout = "" -global_v_layout = "" dynamic_vector_slice_in_dim = jax.vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)) @@ -492,6 +479,32 @@ def __init__( self.quant = quant self.kv_quant = kv_quant self.attention_type = attention_type + if self.attention_type == AttentionType.LOCAL_SLIDING: + self.block_q = self.config.local_sa_block_q + self.block_kv = self.config.local_sa_block_kv + self.block_kv_compute = self.config.local_sa_block_kv_compute + self.block_q_dkv = self.config.local_sa_block_q_dkv + self.block_kv_dkv = self.config.local_sa_block_kv_dkv + self.block_kv_dkv_compute = self.config.local_sa_block_kv_dkv_compute + self.block_q_dq = self.config.local_sa_block_q_dq + self.block_kv_dq = self.config.local_sa_block_kv_dq + self.q_layout = self.config.local_sa_q_layout + self.k_layout = self.config.local_sa_k_layout + self.v_layout = self.config.local_sa_v_layout + self.use_splash_scheduler = self.config.local_use_splash_scheduler + else: + self.block_q = self.config.sa_block_q + self.block_kv = self.config.sa_block_kv + self.block_kv_compute = self.config.sa_block_kv_compute + self.block_q_dkv = self.config.sa_block_q_dkv + self.block_kv_dkv = self.config.sa_block_kv_dkv + self.block_kv_dkv_compute = self.config.sa_block_kv_dkv_compute + self.block_q_dq = self.config.sa_block_q_dq + self.block_kv_dq = self.config.sa_block_kv_dq + self.q_layout = self.config.sa_q_layout + self.k_layout = self.config.sa_k_layout + self.v_layout = self.config.sa_v_layout + self.use_splash_scheduler = self.config.use_splash_scheduler self.attn_logits_soft_cap = attn_logits_soft_cap self.sliding_window_size = sliding_window_size self.chunk_attn_window_size = chunk_attn_window_size @@ -1151,21 +1164,6 @@ def tpu_flash_attention( axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv) indexer_mask_axis_names = self._logical_to_mesh_axes((BATCH_ATTN, Q_LENGTH, KV_LENGTH)) - global global_block_q, global_block_kv, global_block_kv_compute, global_block_q_dkv, global_block_kv_dkv - global global_block_kv_dkv_compute, global_block_q_dq, global_block_kv_dq, global_use_fused_bwd_kernel - global global_q_layout, global_k_layout, global_v_layout - global_block_q = self.config.sa_block_q - global_block_kv = self.config.sa_block_kv - global_block_kv_compute = self.config.sa_block_kv_compute - global_block_q_dkv = self.config.sa_block_q_dkv - global_block_kv_dkv = self.config.sa_block_kv_dkv - global_block_kv_dkv_compute = self.config.sa_block_kv_dkv_compute - global_block_q_dq = self.config.sa_block_q_dq - global_block_kv_dq = self.config.sa_block_kv_dq - global_use_fused_bwd_kernel = self.config.sa_use_fused_bwd_kernel - global_q_layout = self.config.sa_q_layout - global_k_layout = self.config.sa_k_layout - global_v_layout = self.config.sa_v_layout devices_in_data_fsdp = self.mesh.shape.get("data", 1) * self.mesh.shape.get("fsdp", 1) assert (query.shape[0] / devices_in_data_fsdp).is_integer(), ( @@ -1178,16 +1176,16 @@ def tpu_flash_attention( def create_sa_config(config, query, key, attn_logits_soft_cap): if config.use_tokamax_splash: sa_config = tokamax_splash_kernel.SplashConfig( - block_q=min(global_block_q, query.shape[2]), - block_kv=min(global_block_kv, key.shape[2]), - block_kv_compute=min(global_block_kv_compute, key.shape[2]), - block_q_dkv=min(global_block_q_dkv, query.shape[2]), - block_kv_dkv=min(global_block_kv_dkv, key.shape[2]), - block_kv_dkv_compute=min(global_block_kv_dkv_compute, query.shape[2]), + block_q=min(self.block_q, query.shape[2]), + block_kv=min(self.block_kv, key.shape[2]), + block_kv_compute=min(self.block_kv_compute, key.shape[2]), + block_q_dkv=min(self.block_q_dkv, query.shape[2]), + block_kv_dkv=min(self.block_kv_dkv, key.shape[2]), + block_kv_dkv_compute=min(self.block_kv_dkv_compute, query.shape[2]), use_fused_bwd_kernel=True, # tokamax only supports fused bwd kernel - q_layout=tokamax_splash_kernel.QKVLayout[global_q_layout], - k_layout=tokamax_splash_kernel.QKVLayout[global_k_layout], - v_layout=tokamax_splash_kernel.QKVLayout[global_v_layout], + q_layout=tokamax_splash_kernel.QKVLayout[self.q_layout], + k_layout=tokamax_splash_kernel.QKVLayout[self.k_layout], + v_layout=tokamax_splash_kernel.QKVLayout[self.v_layout], attn_logits_soft_cap=attn_logits_soft_cap, residual_checkpoint_name="context", fwd_cost_estimate=pl.CostEstimate( @@ -1205,22 +1203,22 @@ def create_sa_config(config, query, key, attn_logits_soft_cap): if config.cost_estimate_flops_bwd >= 0 else None, dq_reduction_steps=config.dq_reduction_steps if config.dq_reduction_steps > 0 else None, - use_experimental_scheduler=config.use_splash_scheduler, + use_experimental_scheduler=self.use_splash_scheduler, ) else: sa_config = splash_attention_kernel.BlockSizes( - block_q=min(global_block_q, query.shape[2]), - block_kv=min(global_block_kv, key.shape[2]), - block_kv_compute=min(global_block_kv_compute, key.shape[2]), - block_q_dkv=min(global_block_q_dkv, query.shape[2]), - block_kv_dkv=min(global_block_kv_dkv, key.shape[2]), - block_kv_dkv_compute=min(global_block_kv_dkv_compute, query.shape[2]), - block_q_dq=None if global_use_fused_bwd_kernel else min(global_block_q_dq, query.shape[2]), - block_kv_dq=None if global_use_fused_bwd_kernel else min(global_block_kv_dq, query.shape[2]), - use_fused_bwd_kernel=global_use_fused_bwd_kernel, - q_layout=splash_attention_kernel.QKVLayout[global_q_layout], - k_layout=splash_attention_kernel.QKVLayout[global_k_layout], - v_layout=splash_attention_kernel.QKVLayout[global_v_layout], + block_q=min(self.block_q, query.shape[2]), + block_kv=min(self.block_kv, key.shape[2]), + block_kv_compute=min(self.block_kv_compute, key.shape[2]), + block_q_dkv=min(self.block_q_dkv, query.shape[2]), + block_kv_dkv=min(self.block_kv_dkv, key.shape[2]), + block_kv_dkv_compute=min(self.block_kv_dkv_compute, query.shape[2]), + block_q_dq=None if config.sa_use_fused_bwd_kernel else min(self.block_q_dq, query.shape[2]), + block_kv_dq=None if config.sa_use_fused_bwd_kernel else min(self.block_kv_dq, query.shape[2]), + use_fused_bwd_kernel=config.sa_use_fused_bwd_kernel, + q_layout=splash_attention_kernel.QKVLayout[self.q_layout], + k_layout=splash_attention_kernel.QKVLayout[self.k_layout], + v_layout=splash_attention_kernel.QKVLayout[self.v_layout], ) return sa_config @@ -1453,8 +1451,8 @@ def kernel_fn(q, k, v, d, s): key, value, decoder_segment_ids_tuple, - block_kv=self.config.sa_block_kv, - block_q=self.config.sa_block_q, + block_kv=self.block_kv, + block_q=self.block_q, mask=materialized_mask, mask_value=DEFAULT_MASK_VALUE, )