Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1057,15 +1057,27 @@ sa_block_kv_dkv: 512
sa_block_kv_dkv_compute: 512
sa_block_q_dq: 512
sa_block_kv_dq: 512
local_sa_block_q: 512
local_sa_block_kv: 512
local_sa_block_kv_compute: 512
local_sa_block_q_dkv: 512
local_sa_block_kv_dkv: 512
local_sa_block_kv_dkv_compute: 512
local_sa_block_q_dq: 512
local_sa_block_kv_dq: 512
sa_use_fused_bwd_kernel: false
sa_q_layout: "HEAD_DIM_MINOR"
sa_k_layout: "HEAD_DIM_MINOR"
sa_v_layout: "HEAD_DIM_MINOR"
local_sa_q_layout: "HEAD_DIM_MINOR"
local_sa_k_layout: "HEAD_DIM_MINOR"
local_sa_v_layout: "HEAD_DIM_MINOR"
use_max_logit_estimate: -1 # -1 means no estimate, any > 0 value will be used as max logit estimate
cost_estimate_flops_fwd: -1 # -1 means using splash default cost estimation, any >= 0 value will be used as cost estimation for splash to overlap for communication (forward)
cost_estimate_flops_bwd: -1 # -1 means using splash default cost estimation, any >= 0 value will be used as cost estimation for splash to overlap for communication (backward)
dq_reduction_steps: 0 # the number of dq reduction steps; a value <= 0 leaves it unset (None is passed to the kernel). For now, only 3 or all the kv steps are supported.
use_splash_scheduler: false # to use tokamax splash attention scheduler.
local_use_splash_scheduler: false # to use experimental local splash attention scheduler.
### Determine if we want to use load balance for context parallelism
context_parallel_load_balance: true
context_parallel_strategy: "all_gather" # "all_gather" or "ring"
Expand Down
14 changes: 13 additions & 1 deletion src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,10 +674,21 @@ class SplashAttention(BaseModel):
sa_block_kv_dkv_compute: int = Field(512, description="Block size for KV_dkv compute in splash attention.")
sa_block_q_dq: int = Field(512, description="Block size for Q_dq in splash attention.")
sa_block_kv_dq: int = Field(512, description="Block size for KV_dq in splash attention.")
local_sa_block_q: int = Field(512, description="Block size for Q in local splash attention.")
local_sa_block_kv: int = Field(512, description="Block size for KV in local splash attention.")
local_sa_block_kv_compute: int = Field(512, description="Block size for KV compute in local splash attention.")
local_sa_block_q_dkv: int = Field(512, description="Block size for Q_dkv in local splash attention.")
local_sa_block_kv_dkv: int = Field(512, description="Block size for KV_dkv in local splash attention.")
local_sa_block_kv_dkv_compute: int = Field(512, description="Block size for KV_dkv compute in local splash attention.")
local_sa_block_q_dq: int = Field(512, description="Block size for Q_dq in local splash attention.")
local_sa_block_kv_dq: int = Field(512, description="Block size for KV_dq in local splash attention.")
sa_use_fused_bwd_kernel: bool = Field(False, description="Use fused backward kernel in splash attention.")
sa_q_layout: str = Field("HEAD_DIM_MINOR", description="Layout for Q in splash attention.")
sa_k_layout: str = Field("HEAD_DIM_MINOR", description="Layout for K in splash attention.")
sa_v_layout: str = Field("HEAD_DIM_MINOR", description="Layout for V in splash attention.")
local_sa_q_layout: str = Field("HEAD_DIM_MINOR", description="Layout for Q in local splash attention.")
local_sa_k_layout: str = Field("HEAD_DIM_MINOR", description="Layout for K in local splash attention.")
local_sa_v_layout: str = Field("HEAD_DIM_MINOR", description="Layout for V in local splash attention.")
use_max_logit_estimate: int = Field(
-1,
description="-1 means no estimate, any > 0 value will be used as max logit estimate",
Expand All @@ -697,6 +708,7 @@ class SplashAttention(BaseModel):
description="the number of reduction steps. For now, only 3 or all the kv steps are supported.",
)
use_splash_scheduler: bool = Field(False, description="Use experimental splash attention scheduler.")
local_use_splash_scheduler: bool = Field(False, description="Use experimental local splash attention scheduler.")


class MoEGeneral(BaseModel):
Expand Down Expand Up @@ -2441,7 +2453,7 @@ def set_derived_and_validate_values(self) -> "MaxTextConfig":
# If the tokenizer path is a relative name without a directory, resolve it against the assets root.
# This maintains backward compatibility for configs that just specify e.g., "tokenizer.llama2".
tokenizer_path = getattr(self, "tokenizer_path", "")
if tokenizer_path and not os.path.exists(tokenizer_path) and not tokenizer_path.startswith("gs://"):
if tokenizer_path and not os.path.exists(tokenizer_path) and not tokenizer_path.startswith("gs://") and not self.colocated_python_data_input:
tokenizer_path = next(
filter(
os.path.exists,
Expand Down
102 changes: 50 additions & 52 deletions src/maxtext/layers/attention_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,19 +79,6 @@
# pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes
# pytype: disable=attribute-error

# Used to pass in splash attention block sizes from config.
global_block_q = 0
global_block_kv = 0
global_block_kv_compute = 0
global_block_q_dkv = 0
global_block_kv_dkv = 0
global_block_kv_dkv_compute = 0
global_block_q_dq = 0
global_block_kv_dq = 0
global_use_fused_bwd_kernel = False
global_q_layout = ""
global_k_layout = ""
global_v_layout = ""

dynamic_vector_slice_in_dim = jax.vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))

Expand Down Expand Up @@ -492,6 +479,32 @@ def __init__(
self.quant = quant
self.kv_quant = kv_quant
self.attention_type = attention_type
if self.attention_type == AttentionType.LOCAL_SLIDING:
self.block_q = self.config.local_sa_block_q
self.block_kv = self.config.local_sa_block_kv
self.block_kv_compute = self.config.local_sa_block_kv_compute
self.block_q_dkv = self.config.local_sa_block_q_dkv
self.block_kv_dkv = self.config.local_sa_block_kv_dkv
self.block_kv_dkv_compute = self.config.local_sa_block_kv_dkv_compute
self.block_q_dq = self.config.local_sa_block_q_dq
self.block_kv_dq = self.config.local_sa_block_kv_dq
self.q_layout = self.config.local_sa_q_layout
self.k_layout = self.config.local_sa_k_layout
self.v_layout = self.config.local_sa_v_layout
self.use_splash_scheduler = self.config.local_use_splash_scheduler
else:
self.block_q = self.config.sa_block_q
self.block_kv = self.config.sa_block_kv
self.block_kv_compute = self.config.sa_block_kv_compute
self.block_q_dkv = self.config.sa_block_q_dkv
self.block_kv_dkv = self.config.sa_block_kv_dkv
self.block_kv_dkv_compute = self.config.sa_block_kv_dkv_compute
self.block_q_dq = self.config.sa_block_q_dq
self.block_kv_dq = self.config.sa_block_kv_dq
self.q_layout = self.config.sa_q_layout
self.k_layout = self.config.sa_k_layout
self.v_layout = self.config.sa_v_layout
self.use_splash_scheduler = self.config.use_splash_scheduler
self.attn_logits_soft_cap = attn_logits_soft_cap
self.sliding_window_size = sliding_window_size
self.chunk_attn_window_size = chunk_attn_window_size
Expand Down Expand Up @@ -1151,21 +1164,6 @@ def tpu_flash_attention(
axis_names_kv = self._logical_to_mesh_axes(self.flash_axis_names_kv)
indexer_mask_axis_names = self._logical_to_mesh_axes((BATCH_ATTN, Q_LENGTH, KV_LENGTH))

global global_block_q, global_block_kv, global_block_kv_compute, global_block_q_dkv, global_block_kv_dkv
global global_block_kv_dkv_compute, global_block_q_dq, global_block_kv_dq, global_use_fused_bwd_kernel
global global_q_layout, global_k_layout, global_v_layout
global_block_q = self.config.sa_block_q
global_block_kv = self.config.sa_block_kv
global_block_kv_compute = self.config.sa_block_kv_compute
global_block_q_dkv = self.config.sa_block_q_dkv
global_block_kv_dkv = self.config.sa_block_kv_dkv
global_block_kv_dkv_compute = self.config.sa_block_kv_dkv_compute
global_block_q_dq = self.config.sa_block_q_dq
global_block_kv_dq = self.config.sa_block_kv_dq
global_use_fused_bwd_kernel = self.config.sa_use_fused_bwd_kernel
global_q_layout = self.config.sa_q_layout
global_k_layout = self.config.sa_k_layout
global_v_layout = self.config.sa_v_layout

devices_in_data_fsdp = self.mesh.shape.get("data", 1) * self.mesh.shape.get("fsdp", 1)
assert (query.shape[0] / devices_in_data_fsdp).is_integer(), (
Expand All @@ -1178,16 +1176,16 @@ def tpu_flash_attention(
def create_sa_config(config, query, key, attn_logits_soft_cap):
if config.use_tokamax_splash:
sa_config = tokamax_splash_kernel.SplashConfig(
block_q=min(global_block_q, query.shape[2]),
block_kv=min(global_block_kv, key.shape[2]),
block_kv_compute=min(global_block_kv_compute, key.shape[2]),
block_q_dkv=min(global_block_q_dkv, query.shape[2]),
block_kv_dkv=min(global_block_kv_dkv, key.shape[2]),
block_kv_dkv_compute=min(global_block_kv_dkv_compute, query.shape[2]),
block_q=min(self.block_q, query.shape[2]),
block_kv=min(self.block_kv, key.shape[2]),
block_kv_compute=min(self.block_kv_compute, key.shape[2]),
block_q_dkv=min(self.block_q_dkv, query.shape[2]),
block_kv_dkv=min(self.block_kv_dkv, key.shape[2]),
block_kv_dkv_compute=min(self.block_kv_dkv_compute, query.shape[2]),
use_fused_bwd_kernel=True, # tokamax only supports fused bwd kernel
q_layout=tokamax_splash_kernel.QKVLayout[global_q_layout],
k_layout=tokamax_splash_kernel.QKVLayout[global_k_layout],
v_layout=tokamax_splash_kernel.QKVLayout[global_v_layout],
q_layout=tokamax_splash_kernel.QKVLayout[self.q_layout],
k_layout=tokamax_splash_kernel.QKVLayout[self.k_layout],
v_layout=tokamax_splash_kernel.QKVLayout[self.v_layout],
attn_logits_soft_cap=attn_logits_soft_cap,
residual_checkpoint_name="context",
fwd_cost_estimate=pl.CostEstimate(
Expand All @@ -1205,22 +1203,22 @@ def create_sa_config(config, query, key, attn_logits_soft_cap):
if config.cost_estimate_flops_bwd >= 0
else None,
dq_reduction_steps=config.dq_reduction_steps if config.dq_reduction_steps > 0 else None,
use_experimental_scheduler=config.use_splash_scheduler,
use_experimental_scheduler=self.use_splash_scheduler,
)
else:
sa_config = splash_attention_kernel.BlockSizes(
block_q=min(global_block_q, query.shape[2]),
block_kv=min(global_block_kv, key.shape[2]),
block_kv_compute=min(global_block_kv_compute, key.shape[2]),
block_q_dkv=min(global_block_q_dkv, query.shape[2]),
block_kv_dkv=min(global_block_kv_dkv, key.shape[2]),
block_kv_dkv_compute=min(global_block_kv_dkv_compute, query.shape[2]),
block_q_dq=None if global_use_fused_bwd_kernel else min(global_block_q_dq, query.shape[2]),
block_kv_dq=None if global_use_fused_bwd_kernel else min(global_block_kv_dq, query.shape[2]),
use_fused_bwd_kernel=global_use_fused_bwd_kernel,
q_layout=splash_attention_kernel.QKVLayout[global_q_layout],
k_layout=splash_attention_kernel.QKVLayout[global_k_layout],
v_layout=splash_attention_kernel.QKVLayout[global_v_layout],
block_q=min(self.block_q, query.shape[2]),
block_kv=min(self.block_kv, key.shape[2]),
block_kv_compute=min(self.block_kv_compute, key.shape[2]),
block_q_dkv=min(self.block_q_dkv, query.shape[2]),
block_kv_dkv=min(self.block_kv_dkv, key.shape[2]),
block_kv_dkv_compute=min(self.block_kv_dkv_compute, query.shape[2]),
block_q_dq=None if config.sa_use_fused_bwd_kernel else min(self.block_q_dq, query.shape[2]),
block_kv_dq=None if config.sa_use_fused_bwd_kernel else min(self.block_kv_dq, query.shape[2]),
use_fused_bwd_kernel=config.sa_use_fused_bwd_kernel,
q_layout=splash_attention_kernel.QKVLayout[self.q_layout],
k_layout=splash_attention_kernel.QKVLayout[self.k_layout],
v_layout=splash_attention_kernel.QKVLayout[self.v_layout],
)
return sa_config

Expand Down Expand Up @@ -1453,8 +1451,8 @@ def kernel_fn(q, k, v, d, s):
key,
value,
decoder_segment_ids_tuple,
block_kv=self.config.sa_block_kv,
block_q=self.config.sa_block_q,
block_kv=self.block_kv,
block_q=self.block_q,
mask=materialized_mask,
mask_value=DEFAULT_MASK_VALUE,
)
Expand Down
Loading