diff --git a/diffusion_lm/README.md b/diffusion_lm/README.md new file mode 100644 index 000000000..9e4f9ff57 --- /dev/null +++ b/diffusion_lm/README.md @@ -0,0 +1,167 @@ +# Diffusion Language Model (LLaDA-style) for MLX + +An MLX implementation of **masked diffusion language models**, following the approach from: + +> **LLaDA: Large Language Diffusion with mAsking** +> Liao et al., 2025 — [arXiv:2502.09992](https://arxiv.org/abs/2502.09992) +> GitHub: [ML-GSAI/LLaDA](https://github.com/ML-GSAI/LLaDA) + +## What is a masked diffusion LM? + +Instead of predicting the next token autoregressively, a masked diffusion LM +works bidirectionally: + +- **Training (forward process)**: randomly mask a fraction of tokens in the input + sequence with a special `[MASK]` token. The fraction is sampled per-sequence + from U(ε, 1). The model (a bidirectional transformer) learns to predict the + original token at every masked position simultaneously. +- **Inference (reverse process)**: start with a fully-masked response, then + iteratively unmask the most-confident predictions until the sequence is + complete. + +The training objective is a weighted cross-entropy loss that is provably an +upper bound on the negative log-likelihood, making the model a proper generative +model. + +## Files + +| File | Description | +|---|---| +| `model.py` | Bidirectional transformer (mask predictor). LLaMA-style blocks without the causal mask. | +| `train.py` | Pre-training and supervised fine-tuning (SFT) training script. | +| `generate.py` | Iterative-unmasking generation with Gumbel-noise sampling, semi-autoregressive blocks, and classifier-free guidance. | +| `convert.py` | Convert a HuggingFace LLaDA checkpoint to MLX safetensors format. | + +## Setup + +```bash +pip install mlx transformers huggingface_hub +``` + +## Quick demo (no weights required) + +```bash +python generate.py --demo +``` + +This runs a tiny 2-layer model with random weights to verify the generation +loop works end-to-end. + +## Using a pre-trained LLaDA model + +### Step 1 — Convert from HuggingFace + +```bash +python convert.py \ + --hf-path GSAI-ML/LLaDA-8B-Instruct \ + --mlx-path mlx_llada_instruct +``` + +To save memory with 4-bit quantization: + +```bash +python convert.py \ + --hf-path GSAI-ML/LLaDA-8B-Instruct \ + --mlx-path mlx_llada_instruct_4bit \ + -q +``` + +### Step 2 — Generate text + +```bash +python generate.py \ + --model-path mlx_llada_instruct \ + --prompt "What is the capital of France?" \ + --gen-length 128 \ + --steps 128 \ + --chat +``` + +Useful flags: + +| Flag | Default | Description | +|---|---|---| +| `--gen-length` | 64 | Tokens to generate | +| `--steps` | gen_length | Denoising steps (more = better quality, slower) | +| `--block-length` | gen_length | Block size for semi-autoregressive generation | +| `--temperature` | 0.0 | Gumbel noise temperature (0 = greedy) | +| `--cfg-scale` | 0.0 | Classifier-free guidance strength | +| `--remasking` | low_confidence | Remasking strategy: `low_confidence` or `random` | +| `--chat` | False | Apply the instruct chat template | + +## Training from scratch + +### Pre-training + +Prepare a JSONL file where each line is `{"text": "..."}`: + +```bash +python train.py \ + --data data/train.jsonl \ + --tokenizer meta-llama/Meta-Llama-3-8B \ + --d_model 512 --n_layers 8 --n_heads 8 --mlp_hidden_size 2048 \ + --batch_size 16 --iters 50000 --lr 3e-4 +``` + +### Supervised fine-tuning (SFT) + +Prepare a JSONL file where each line is `{"prompt": "...", "response": "..."}`: + +```bash +python train.py \ + --sft \ + --data data/sft.jsonl \ + --tokenizer meta-llama/Meta-Llama-3-8B \ + --d_model 512 --n_layers 8 --n_heads 8 --mlp_hidden_size 2048 \ + --batch_size 8 --iters 10000 --lr 1e-4 +``` + +For a quick smoke-test with a tiny model on random data: + +```bash +python train.py --d_model 128 --n_layers 2 --n_heads 4 --mlp_hidden_size 256 \ + --batch_size 4 --iters 100 --log_every 10 +``` + +## Architecture + +The model is a standard transformer with one key change: **no causal mask** in +self-attention. Every token can attend to every other token, making the model +bidirectional. + +| Component | Detail | +|---|---| +| Attention | Multi-head, bidirectional (no causal mask), RoPE position encoding | +| FFN | SwiGLU: `silu(gate_proj(x)) * up_proj(x)` → `down_proj` | +| Normalisation | RMSNorm (pre-norm, before attention and FFN) | +| Special token | `[MASK]` token replaces corrupted positions during both training and inference | + +LLaDA-8B uses the same hyper-parameters as LLaMA-3-8B (d_model=4096, 32 layers, +32 heads, mlp_hidden_size=14336, rope_theta=500000) with the LLaMA-3 tokenizer +(vocab_size=126464, mask_token_id=126336). + +## Generation algorithm + +``` +Input: prompt tokens [p₁ p₂ … pₙ] +Init: x = [p₁ … pₙ | M M … M] (M = [MASK], gen_length masks) + +For each block b in 1..num_blocks: + Compute token schedule: how many tokens to unmask at each step + For step t in 1..steps_per_block: + logits ← model(x) # bidirectional forward pass + x₀ ← argmax(logits + Gumbel) # sample predictions + conf ← softmax_prob(x₀) # confidence scores + k ← num_transfer[b, t] # tokens to commit this step + Commit the k highest-confidence masked positions in x +Return x[n+1:] +``` + +The **linear noise schedule** distributes unmaskings evenly: if there are `N` +masked tokens and `T` steps, approximately `N/T` tokens are committed per step. + +## References + +- Liao et al. (2025). *LLaDA: Large Language Diffusion with mAsking*. arXiv:2502.09992. +- Austin et al. (2021). *Structured Denoising Diffusion Models in Discrete State-Spaces*. +- Shi et al. (2024). *Simplified and Generalized Masked Diffusion for Discrete Data*. arXiv:2406.04329. diff --git a/diffusion_lm/convert.py b/diffusion_lm/convert.py new file mode 100644 index 000000000..b5f56e8e7 --- /dev/null +++ b/diffusion_lm/convert.py @@ -0,0 +1,339 @@ +"""Convert a HuggingFace LLaDA model to MLX format. + +Downloads (or reads from a local path) the LLaDA-8B-Base or LLaDA-8B-Instruct +model, remaps the weight names to match this MLX implementation, and saves the +result as a directory of safetensors files together with a config.json. + +Usage: + # Convert from HuggingFace Hub (requires huggingface_hub): + python convert.py --hf-path GSAI-ML/LLaDA-8B-Instruct --mlx-path mlx_llada_instruct + + # Convert from a local HuggingFace checkpoint: + python convert.py --hf-path /path/to/llada --mlx-path mlx_llada + + # Quantize weights to 4-bit while converting: + python convert.py --hf-path GSAI-ML/LLaDA-8B-Instruct --mlx-path mlx_llada_4bit -q + +Weight name mapping +------------------- +The LLaDA HuggingFace model uses the OLMo-derived naming convention under a +``transformer.*`` top-level namespace, with LLaMA-style block internals: + + transformer.wte.weight → embed_tokens.weight + transformer.ln_f.weight → norm.weight + transformer.ff_out.weight → lm_head.weight + transformer.blocks.{i}.attn_norm.weight → layers.{i}.input_layernorm.weight + transformer.blocks.{i}.ff_norm.weight → layers.{i}.post_attention_layernorm.weight + transformer.blocks.{i}.q_proj.weight → layers.{i}.self_attn.q_proj.weight + transformer.blocks.{i}.k_proj.weight → layers.{i}.self_attn.k_proj.weight + transformer.blocks.{i}.v_proj.weight → layers.{i}.self_attn.v_proj.weight + transformer.blocks.{i}.attn_out.weight → layers.{i}.self_attn.o_proj.weight + transformer.blocks.{i}.ff_proj.weight → layers.{i}.mlp.gate_proj.weight + transformer.blocks.{i}.up_proj.weight → layers.{i}.mlp.up_proj.weight + transformer.blocks.{i}.ff_out.weight → layers.{i}.mlp.down_proj.weight +""" + +import argparse +import json +import re +import shutil +from pathlib import Path + +import mlx.core as mx +import mlx.nn as nn +import numpy as np +from mlx.utils import tree_flatten + + +# --------------------------------------------------------------------------- +# Weight name mapping +# --------------------------------------------------------------------------- + +# Simple prefix renames +_PREFIX_MAP = { + "transformer.wte.weight": "embed_tokens.weight", + "transformer.ln_f.weight": "norm.weight", + "transformer.ff_out.weight": "lm_head.weight", +} + +# Per-layer renames expressed as (source_suffix, dest_suffix) +_LAYER_MAP = [ + ("attn_norm.weight", "input_layernorm.weight"), + ("ff_norm.weight", "post_attention_layernorm.weight"), + ("q_proj.weight", "self_attn.q_proj.weight"), + ("k_proj.weight", "self_attn.k_proj.weight"), + ("v_proj.weight", "self_attn.v_proj.weight"), + ("attn_out.weight", "self_attn.o_proj.weight"), + ("ff_proj.weight", "mlp.gate_proj.weight"), + ("up_proj.weight", "mlp.up_proj.weight"), + ("ff_out.weight", "mlp.down_proj.weight"), +] + +_LAYER_RE = re.compile(r"^transformer\.blocks\.(\d+)\.(.+)$") + + +def _remap_key(key: str) -> str | None: + """Return the MLX weight name for a HuggingFace weight name, or None to skip.""" + # Top-level renames + if key in _PREFIX_MAP: + return _PREFIX_MAP[key] + + # Per-layer renames + m = _LAYER_RE.match(key) + if m: + layer_idx, suffix = m.group(1), m.group(2) + for src_suffix, dst_suffix in _LAYER_MAP: + if suffix == src_suffix: + return f"layers.{layer_idx}.{dst_suffix}" + + # Drop everything else (rotary buffers, etc.) + return None + + +# --------------------------------------------------------------------------- +# Config mapping +# --------------------------------------------------------------------------- + + +def _build_mlx_config(hf_config: dict) -> dict: + """Build MLX ModelArgs-compatible config from a HuggingFace config dict.""" + return { + "d_model": hf_config.get("d_model", 4096), + "n_layers": hf_config.get("n_layers", 32), + "n_heads": hf_config.get("n_heads", 32), + "n_kv_heads": hf_config.get("n_kv_heads", hf_config.get("n_heads", 32)), + "mlp_hidden_size": hf_config.get("mlp_hidden_size", 14336), + "vocab_size": hf_config.get("vocab_size", hf_config.get("embedding_size", 126464)), + "mask_token_id": hf_config.get("mask_token_id", 126336), + "rms_norm_eps": hf_config.get("rms_norm_eps", 1e-5), + "rope_theta": hf_config.get("rope_theta", 500000.0), + } + + +# --------------------------------------------------------------------------- +# Conversion +# --------------------------------------------------------------------------- + + +def convert( + hf_path: str, + mlx_path: str, + quantize: bool = False, + q_group_size: int = 64, + q_bits: int = 4, + dtype: str = "bfloat16", + upload_repo: str | None = None, +): + """Convert a HuggingFace LLaDA model to MLX format. + + Args: + hf_path: HuggingFace Hub repo id (e.g. ``"GSAI-ML/LLaDA-8B-Instruct"``) + or local directory containing the HuggingFace model files. + mlx_path: Output directory for the converted model. + quantize: If True, quantize weights with ``mlx.nn.quantize``. + q_group_size: Quantization group size. + q_bits: Quantization bit-width (4 or 8). + dtype: Target float dtype (``"float16"`` or ``"bfloat16"``). + upload_repo: If set, upload the converted model to this HuggingFace repo. + """ + from transformers import AutoTokenizer + + hf_path = Path(hf_path) + mlx_path = Path(mlx_path) + mlx_path.mkdir(parents=True, exist_ok=True) + + dtype_map = {"float16": mx.float16, "bfloat16": mx.bfloat16, "float32": mx.float32} + target_dtype = dtype_map.get(dtype, mx.bfloat16) + + # ------------------------------------------------------------------ config + config_path = hf_path / "config.json" + if not config_path.exists(): + # Try downloading from Hub + try: + from huggingface_hub import snapshot_download + hf_path = Path(snapshot_download(str(hf_path))) + config_path = hf_path / "config.json" + except ImportError: + raise RuntimeError( + "huggingface_hub not installed. " + "Install it with: pip install huggingface_hub" + ) + + with open(config_path) as f: + hf_config = json.load(f) + + mlx_config = _build_mlx_config(hf_config) + with open(mlx_path / "config.json", "w") as f: + json.dump(mlx_config, f, indent=2) + print(f"Saved config → {mlx_path / 'config.json'}") + + # ----------------------------------------------------------------- weights + import glob + + weight_files = sorted(glob.glob(str(hf_path / "*.safetensors"))) + if not weight_files: + # Fall back to PyTorch bin files + weight_files = sorted(glob.glob(str(hf_path / "*.bin"))) + + if not weight_files: + raise FileNotFoundError(f"No weight files found in {hf_path}") + + print(f"Loading weights from {len(weight_files)} file(s) …") + raw_weights: dict[str, mx.array] = {} + for wf in weight_files: + if wf.endswith(".safetensors"): + raw_weights.update(mx.load(wf).items()) + else: + # PyTorch bin: load via numpy + import torch + state_dict = torch.load(wf, map_location="cpu") + for k, v in state_dict.items(): + raw_weights[k] = mx.array(v.numpy()) + + print(f"Loaded {len(raw_weights)} tensors. Remapping …") + + mlx_weights: dict[str, mx.array] = {} + skipped: list[str] = [] + for src_key, tensor in raw_weights.items(): + dst_key = _remap_key(src_key) + if dst_key is None: + skipped.append(src_key) + continue + mlx_weights[dst_key] = tensor.astype(target_dtype) + + if skipped: + print(f"Skipped {len(skipped)} tensors (rotary buffers, unused, …):") + for k in skipped[:10]: + print(f" {k}") + if len(skipped) > 10: + print(f" … and {len(skipped) - 10} more") + + print(f"Remapped {len(mlx_weights)} tensors.") + + # ---------------------------------------------------------------- quantize + if quantize: + from model import Model, ModelArgs + + model_args = ModelArgs(**mlx_config) + model = Model(model_args) + model.load_weights(list(mlx_weights.items())) + mx.eval(model.parameters()) + + nn.quantize(model, group_size=q_group_size, bits=q_bits) + mx.eval(model.parameters()) + + mlx_weights = dict(tree_flatten(model.parameters())) + + # Record quantization in config + mlx_config["quantization"] = {"group_size": q_group_size, "bits": q_bits} + with open(mlx_path / "config.json", "w") as f: + json.dump(mlx_config, f, indent=2) + print(f"Quantized to {q_bits}-bit (group_size={q_group_size}).") + + # ------------------------------------------------------------------- save + # Split into ≤5 GB shards + max_shard_bytes = 5 * 1024**3 + shard_weights: list[dict[str, mx.array]] = [{}] + shard_bytes = 0 + + for key, tensor in mlx_weights.items(): + nbytes = tensor.size * tensor.itemsize + if shard_bytes + nbytes > max_shard_bytes and shard_weights[-1]: + shard_weights.append({}) + shard_bytes = 0 + shard_weights[-1][key] = tensor + shard_bytes += nbytes + + n_shards = len(shard_weights) + for i, shard in enumerate(shard_weights): + if n_shards == 1: + out_file = mlx_path / "weights.safetensors" + else: + out_file = mlx_path / f"weights-{i + 1:05d}-of-{n_shards:05d}.safetensors" + mx.save_safetensors(str(out_file), shard) + print(f"Saved {out_file.name} ({len(shard)} tensors)") + + # ---------------------------------------------------------------- tokenizer + print("Copying tokenizer files …") + tok_files = [ + "tokenizer.json", + "tokenizer_config.json", + "tokenizer.model", + "special_tokens_map.json", + "added_tokens.json", + ] + for fname in tok_files: + src = hf_path / fname + if src.exists(): + shutil.copy2(src, mlx_path / fname) + print(f" {fname}") + + print(f"\nConversion complete → {mlx_path}") + + # ----------------------------------------------------------------- upload + if upload_repo: + from huggingface_hub import HfApi + + api = HfApi() + api.create_repo(upload_repo, exist_ok=True) + api.upload_folder( + folder_path=str(mlx_path), + repo_id=upload_repo, + repo_type="model", + ) + print(f"Uploaded to https://huggingface.co/{upload_repo}") + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +def main(): + parser = argparse.ArgumentParser( + description="Convert a HuggingFace LLaDA model to MLX format" + ) + parser.add_argument( + "--hf-path", type=str, required=True, + help="HuggingFace repo id or local path (e.g. GSAI-ML/LLaDA-8B-Instruct)", + ) + parser.add_argument( + "--mlx-path", type=str, required=True, + help="Output directory for the converted MLX model", + ) + parser.add_argument( + "-q", "--quantize", action="store_true", + help="Quantize model weights after conversion", + ) + parser.add_argument( + "--q-group-size", type=int, default=64, + help="Quantization group size (default: 64)", + ) + parser.add_argument( + "--q-bits", type=int, default=4, choices=[4, 8], + help="Quantization bits (default: 4)", + ) + parser.add_argument( + "--dtype", type=str, default="bfloat16", + choices=["float16", "bfloat16", "float32"], + help="Weight dtype (default: bfloat16)", + ) + parser.add_argument( + "--upload-repo", type=str, default=None, + help="Upload the converted model to this HuggingFace repo id", + ) + + args = parser.parse_args() + convert( + hf_path=args.hf_path, + mlx_path=args.mlx_path, + quantize=args.quantize, + q_group_size=args.q_group_size, + q_bits=args.q_bits, + dtype=args.dtype, + upload_repo=args.upload_repo, + ) + + +if __name__ == "__main__": + main() diff --git a/diffusion_lm/generate.py b/diffusion_lm/generate.py new file mode 100644 index 000000000..4421e214d --- /dev/null +++ b/diffusion_lm/generate.py @@ -0,0 +1,429 @@ +"""Masked-diffusion language model generation (inference). + +Implements the iterative unmasking algorithm from LLaDA: + "LLaDA: Large Language Diffusion with mAsking" + https://arxiv.org/abs/2502.09992 + +Starting from a fully masked response region, the model predicts token +probabilities at every masked position simultaneously. At each step, the most +confident predictions are "unmasked" (committed), while the rest are either kept +masked or re-masked (depending on the remasking strategy). This is repeated +until all positions are unmasked. + +Usage: + # Load a pre-trained LLaDA-8B-Instruct model converted to MLX format: + python generate.py \\ + --model-path path/to/mlx_llada \\ + --prompt "What is the capital of France?" \\ + --gen-length 128 \\ + --steps 128 + + # Quick test with a tiny random model (no real weights): + python generate.py --demo +""" + +import argparse +import json +import time +from pathlib import Path + +import mlx.core as mx +import numpy as np + +from model import Model, ModelArgs + + +# --------------------------------------------------------------------------- +# Gumbel-noise categorical sampling +# --------------------------------------------------------------------------- + + +def add_gumbel_noise(logits: mx.array, temperature: float) -> mx.array: + """Apply Gumbel noise for stochastic categorical sampling. + + At temperature=0 this reduces to argmax (greedy decoding). + Using higher temperature introduces diversity at the cost of coherence. + + Note: The LLaDA paper recommends float32 precision here; low-precision + Gumbel noise slightly improves perplexity but hurts generation quality. + + Args: + logits: Unnormalised log-probabilities of shape (..., vocab_size). + temperature: Gumbel noise scale; 0 = greedy. + + Returns: + Noisy logits in the same shape (for use with argmax). + """ + if temperature == 0.0: + return logits + # Sample Gumbel noise: -log(-log(U)) ≡ log(exp(logit)) / (-log(U))^T + noise = mx.random.uniform(shape=logits.shape) + # Clamp to avoid log(0) + noise = mx.clip(noise, 1e-10, 1.0) + gumbel = (-mx.log(noise)) ** temperature + return mx.exp(logits) / gumbel + + +# --------------------------------------------------------------------------- +# Step-size schedule +# --------------------------------------------------------------------------- + + +def get_num_transfer_tokens( + mask_index: mx.array, + steps: int, +) -> np.ndarray: + """Compute how many masked tokens to unmask at each denoising step. + + Uses a linear (uniform) schedule: distributes the total number of masked + tokens as evenly as possible across the given number of steps. + + Args: + mask_index: Boolean array of shape (B, L); True where tokens are masked. + steps: Number of denoising steps. + + Returns: + Integer numpy array of shape (B, steps) giving per-step token counts. + """ + mask_num = np.array(mask_index.sum(axis=1)) # (B,) + base = mask_num // steps + remainder = mask_num % steps + + num_transfer = np.zeros((mask_num.shape[0], steps), dtype=np.int64) + base[:, None] + for i, rem in enumerate(remainder): + num_transfer[i, :rem] += 1 + + return num_transfer + + +# --------------------------------------------------------------------------- +# Core generation function +# --------------------------------------------------------------------------- + + +def generate( + model: Model, + prompt: mx.array, + attention_mask: mx.array | None = None, + steps: int = 128, + gen_length: int = 128, + block_length: int = 128, + temperature: float = 0.0, + cfg_scale: float = 0.0, + remasking: str = "low_confidence", + verbose: bool = False, +) -> mx.array: + """Generate tokens using iterative masked diffusion. + + The algorithm: + 1. Append ``gen_length`` [MASK] tokens to the prompt. + 2. For each block (to support semi-autoregressive generation): + a. Pre-compute how many tokens to unmask at each step. + b. At each step: + - Run the model to predict all token positions. + - (Optional) apply classifier-free guidance. + - Score each masked position by predicted confidence. + - Commit the top-k most confident predictions; keep the rest masked. + 3. Return the completed sequence. + + Args: + model: A trained ``Model`` (mask predictor). + prompt: Token ids of shape (1, L_prompt) or (B, L_prompt). + attention_mask: Optional padding mask for the prompt of shape (B, L_prompt). + steps: Total denoising steps (≤ gen_length). More steps → better quality. + gen_length: Number of new tokens to generate. + block_length: Generate this many tokens per semi-autoregressive block. + Must divide gen_length exactly. Set equal to gen_length for + standard (non-semi-autoregressive) generation. + temperature: Gumbel noise temperature; 0 = greedy. + cfg_scale: Classifier-free guidance scale. 0 = disabled. + When >0, runs an extra unconditional forward pass at each step. + remasking: Strategy for deciding which tokens to re-mask between steps: + - ``"low_confidence"`` (default): re-mask predictions with the + lowest softmax probability (works best). + - ``"random"``: re-mask a random subset. + verbose: Print progress information. + + Returns: + Completed token ids of shape (B, L_prompt + gen_length). + """ + mask_token_id = model.args.mask_token_id + B, L_prompt = prompt.shape + + assert gen_length % block_length == 0, "gen_length must be divisible by block_length" + assert steps % (gen_length // block_length) == 0, ( + "steps must be divisible by the number of blocks" + ) + + # Build the working sequence: [prompt | MASK … MASK] + mask_fill = mx.full((B, gen_length), mask_token_id, dtype=mx.int32) + x = mx.concatenate([prompt, mask_fill], axis=1) # (B, L_prompt + gen_length) + + # Extend attention mask to cover the generation region + if attention_mask is not None: + gen_attn = mx.ones((B, gen_length), dtype=attention_mask.dtype) + attention_mask = mx.concatenate([attention_mask, gen_attn], axis=1) + + # Remember which positions are part of the prompt (never unmask these) + prompt_mask = x != mask_token_id # (B, L_total) + + num_blocks = gen_length // block_length + steps_per_block = steps // num_blocks + + if verbose: + print(f"Generating {gen_length} tokens in {num_blocks} block(s), " + f"{steps_per_block} step(s) each …") + + for block_idx in range(num_blocks): + block_start = L_prompt + block_idx * block_length + block_end = L_prompt + (block_idx + 1) * block_length + + # Compute the per-step token-transfer schedule for this block + block_mask_index = x[:, block_start:block_end] == mask_token_id + mx.eval(block_mask_index) + num_transfer = get_num_transfer_tokens( + np.array(block_mask_index), steps_per_block + ) # numpy (B, steps_per_block) + + for step in range(steps_per_block): + # ---- forward pass ------------------------------------------------ + if cfg_scale > 0.0: + # Classifier-free guidance: run conditional and unconditional + # (unconditional = prompt tokens also replaced by [MASK]) + un_x = mx.where(prompt_mask, mask_token_id, x) + x_cat = mx.concatenate([x, un_x], axis=0) # (2B, L) + + if attention_mask is not None: + attn_cat = mx.concatenate([attention_mask, attention_mask], axis=0) + logits_cat = model(x_cat, attn_cat) + else: + logits_cat = model(x_cat) + + logits_cond = logits_cat[:B] + logits_uncond = logits_cat[B:] + logits = logits_uncond + (cfg_scale + 1.0) * ( + logits_cond - logits_uncond + ) + else: + logits = model(x, attention_mask) # (B, L, V) + + # ---- sample predictions ----------------------------------------- + logits_noisy = add_gumbel_noise(logits, temperature) + x0 = mx.argmax(logits_noisy, axis=-1) # (B, L) + + # ---- confidence scoring ----------------------------------------- + mask_index = x == mask_token_id # (B, L) + + if remasking == "low_confidence": + p = mx.softmax(logits.astype(mx.float32), axis=-1) # (B, L, V) + # Gather probability of the predicted token at each position + x0_p = p[ + mx.arange(B)[:, None], + mx.arange(x.shape[1])[None, :], + x0, + ] # (B, L) + elif remasking == "random": + x0_p = mx.random.uniform(shape=(B, x.shape[1])) + else: + raise ValueError(f"Unknown remasking strategy: {remasking!r}") + + # Mask confidence scores for positions beyond the current block + # and for already-unmasked (prompt / previously committed) positions + beyond_block = mx.arange(x.shape[1])[None, :] >= block_end + x0_p = mx.where(beyond_block | ~mask_index, float("-inf"), x0_p) + + # Best prediction for each position (use committed token if not masked) + x0 = mx.where(mask_index, x0, x) + + # ---- commit top-k tokens ---------------------------------------- + # For each batch element, select the num_transfer[b, step] positions + # with the highest confidence and commit them. + # + # Implementation: double-argsort gives rank (0 = most confident). + # A position is transferred if its rank < num_to_unmask AND + # it was masked. + k_vals = mx.array(num_transfer[:, step]) # (B,) + + # argsort twice: first gives sorted indices, second gives rank + sorted_idx = mx.argsort(-x0_p, axis=-1) # (B, L) descending + ranks = mx.argsort(sorted_idx, axis=-1) # (B, L) rank of each pos + + transfer_mask = ranks < k_vals[:, None] # (B, L) + + x = mx.where(transfer_mask, x0, x) + mx.eval(x) + + if verbose and (step + 1) % max(1, steps_per_block // 4) == 0: + n_remaining = int((x == mask_token_id).sum()) + print(f" block {block_idx + 1}/{num_blocks}, " + f"step {step + 1}/{steps_per_block}: " + f"{n_remaining} mask tokens remaining") + + return x + + +# --------------------------------------------------------------------------- +# Model loading +# --------------------------------------------------------------------------- + + +def load_model(model_path: str) -> tuple[Model, object]: + """Load an MLX diffusion LM model and its tokenizer. + + The model directory should contain: + - ``config.json``: ModelArgs fields (as saved by ``convert.py``). + - ``weights.safetensors`` (or sharded ``weights-00001-of-XXXX.safetensors``). + - A HuggingFace tokenizer (``tokenizer.json``, ``tokenizer_config.json``, …). + + Args: + model_path: Path to the model directory. + + Returns: + (model, tokenizer) tuple ready for inference. + """ + import glob + + from transformers import AutoTokenizer + + path = Path(model_path) + + # Load config + with open(path / "config.json") as f: + cfg = json.load(f) + args = ModelArgs.from_dict(cfg) + model = Model(args) + + # Load weights (support sharding) + weight_files = sorted(glob.glob(str(path / "*.safetensors"))) + if not weight_files: + raise FileNotFoundError(f"No .safetensors files found in {path}") + + weights = {} + for wf in weight_files: + weights.update(mx.load(wf).items()) + + model.load_weights(list(weights.items())) + mx.eval(model.parameters()) + model.eval() + + tokenizer = AutoTokenizer.from_pretrained(model_path) + if tokenizer.padding_side != "left": + tokenizer.padding_side = "left" + + return model, tokenizer + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +def main(): + parser = argparse.ArgumentParser(description="Generate text with a diffusion LM") + + parser.add_argument("--model-path", type=str, default=None, + help="Path to MLX model directory") + parser.add_argument("--prompt", type=str, + default="The quick brown fox") + parser.add_argument("--gen-length", type=int, default=64, + help="Number of tokens to generate") + parser.add_argument("--steps", type=int, default=None, + help="Denoising steps (default: gen_length)") + parser.add_argument("--block-length", type=int, default=None, + help="Semi-autoregressive block size (default: gen_length)") + parser.add_argument("--temperature", type=float, default=0.0, + help="Gumbel noise temperature (0 = greedy)") + parser.add_argument("--cfg-scale", type=float, default=0.0, + help="Classifier-free guidance scale (0 = disabled)") + parser.add_argument("--remasking", type=str, default="low_confidence", + choices=["low_confidence", "random"]) + parser.add_argument("--chat", action="store_true", + help="Apply instruct chat template to the prompt") + parser.add_argument("--demo", action="store_true", + help="Run a quick demo with a tiny random model (no weights needed)") + parser.add_argument("--verbose", action="store_true") + + args = parser.parse_args() + + gen_length = args.gen_length + steps = args.steps or gen_length + block_length = args.block_length or gen_length + + if args.demo: + # ---- tiny synthetic demo (no real weights) -------------------------- + print("Running demo with a tiny random model …\n") + demo_args = ModelArgs( + d_model=128, n_layers=2, n_heads=4, n_kv_heads=4, + mlp_hidden_size=256, vocab_size=1001, mask_token_id=1000, + ) + model = Model(demo_args) + mx.eval(model.parameters()) + + prompt_ids = mx.array([[1, 2, 3, 4, 5]]) # fake prompt + t0 = time.perf_counter() + out = generate( + model, prompt_ids, + steps=8, gen_length=16, block_length=16, + temperature=1.0, verbose=args.verbose, + ) + mx.eval(out) + elapsed = time.perf_counter() - t0 + print(f"Output token ids: {out[0].tolist()}") + print(f"Generated in {elapsed:.2f}s") + return + + if args.model_path is None: + parser.error("--model-path is required (or use --demo)") + + model, tokenizer = load_model(args.model_path) + + prompt_text = args.prompt + if args.chat: + messages = [{"role": "user", "content": prompt_text}] + prompt_text = tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=False + ) + + enc = tokenizer( + [prompt_text], + return_tensors="np", + add_special_tokens=not args.chat, + padding=True, + ) + input_ids = mx.array(enc["input_ids"]) + attention_mask = mx.array(enc["attention_mask"]) + + print(f"Prompt ({input_ids.shape[1]} tokens): {prompt_text!r}") + print(f"Generating {gen_length} tokens with {steps} steps …\n") + + t0 = time.perf_counter() + out = generate( + model, + input_ids, + attention_mask=attention_mask, + steps=steps, + gen_length=gen_length, + block_length=block_length, + temperature=args.temperature, + cfg_scale=args.cfg_scale, + remasking=args.remasking, + verbose=args.verbose, + ) + mx.eval(out) + elapsed = time.perf_counter() - t0 + + response_ids = out[:, input_ids.shape[1]:] + response = tokenizer.batch_decode( + np.array(response_ids), skip_special_tokens=True + ) + + for i, text in enumerate(response): + print(f"[{i}] {text}") + print("-" * 60) + + print(f"\nGenerated {gen_length} tokens in {elapsed:.2f}s " + f"({gen_length / elapsed:.1f} tok/s)") + + +if __name__ == "__main__": + main() diff --git a/diffusion_lm/model.py b/diffusion_lm/model.py new file mode 100644 index 000000000..8d3b20013 --- /dev/null +++ b/diffusion_lm/model.py @@ -0,0 +1,170 @@ +# Copyright © 2024 Apple Inc. +# +# Diffusion Language Model (LLaDA-style) for MLX. +# Based on "LLaDA: Large Language Diffusion with mAsking" +# https://arxiv.org/abs/2502.09992 + +import inspect +from dataclasses import dataclass +from typing import Optional + +import mlx.core as mx +import mlx.nn as nn + + +@dataclass +class ModelArgs: + # Model dimensions + d_model: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: int = 32 + mlp_hidden_size: int = 14336 + + # Vocabulary + vocab_size: int = 126464 + mask_token_id: int = 126336 + + # Normalization + rms_norm_eps: float = 1e-5 + + # RoPE + rope_theta: float = 500000.0 + + @classmethod + def from_dict(cls, params: dict) -> "ModelArgs": + valid = inspect.signature(cls).parameters + return cls(**{k: v for k, v in params.items() if k in valid}) + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.n_heads = args.n_heads + self.n_kv_heads = args.n_kv_heads + self.head_dim = args.d_model // args.n_heads + self.scale = self.head_dim**-0.5 + + self.q_proj = nn.Linear(args.d_model, args.n_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear( + args.d_model, args.n_kv_heads * self.head_dim, bias=False + ) + self.v_proj = nn.Linear( + args.d_model, args.n_kv_heads * self.head_dim, bias=False + ) + self.o_proj = nn.Linear(args.n_heads * self.head_dim, args.d_model, bias=False) + + self.rope = nn.RoPE(self.head_dim, traditional=False, base=args.rope_theta) + + def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array: + B, L, _ = x.shape + + queries = self.q_proj(x) + keys = self.k_proj(x) + values = self.v_proj(x) + + # Reshape to (B, n_heads, L, head_dim) + queries = queries.reshape(B, L, self.n_heads, self.head_dim).transpose( + 0, 2, 1, 3 + ) + keys = keys.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, self.head_dim).transpose( + 0, 2, 1, 3 + ) + + # Apply rotary embeddings + queries = self.rope(queries) + keys = self.rope(keys) + + # Expand KV heads for grouped query attention + if self.n_kv_heads != self.n_heads: + n_repeat = self.n_heads // self.n_kv_heads + keys = mx.repeat(keys, n_repeat, axis=1) + values = mx.repeat(values, n_repeat, axis=1) + + # Bidirectional (non-causal) scaled dot-product attention + scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) + if mask is not None: + scores = scores + mask + scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(queries.dtype) + output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class MLP(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + # SwiGLU: silu(gate_proj(x)) * up_proj(x) -> down_proj + self.gate_proj = nn.Linear(args.d_model, args.mlp_hidden_size, bias=False) + self.up_proj = nn.Linear(args.d_model, args.mlp_hidden_size, bias=False) + self.down_proj = nn.Linear(args.mlp_hidden_size, args.d_model, bias=False) + + def __call__(self, x: mx.array) -> mx.array: + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.self_attn = Attention(args) + self.mlp = MLP(args) + self.input_layernorm = nn.RMSNorm(args.d_model, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm(args.d_model, eps=args.rms_norm_eps) + + def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array: + # Pre-norm attention (no causal mask — bidirectional) + r = self.self_attn(self.input_layernorm(x), mask=mask) + h = x + r + # Pre-norm FFN + r = self.mlp(self.post_attention_layernorm(h)) + return h + r + + +class Model(nn.Module): + """Bidirectional transformer mask predictor for masked diffusion LM. + + Identical to a decoder-only LLM (LLaMA-style) except that attention is + *non-causal* (bidirectional), allowing each position to attend to all + other positions. No KV-cache is used because every forward pass sees the + full (partially masked) sequence. + """ + + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.embed_tokens = nn.Embedding(args.vocab_size, args.d_model) + self.layers = [TransformerBlock(args) for _ in range(args.n_layers)] + self.norm = nn.RMSNorm(args.d_model, eps=args.rms_norm_eps) + self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False) + + def __call__( + self, + input_ids: mx.array, + attention_mask: Optional[mx.array] = None, + ) -> mx.array: + """Forward pass. + + Args: + input_ids: Token ids of shape (B, L). Masked positions should + contain ``args.mask_token_id``. + attention_mask: Optional boolean/float padding mask of shape (B, L). + 1 = attend, 0 = ignore (padding). + + Returns: + Logits of shape (B, L, vocab_size). + """ + x = self.embed_tokens(input_ids) + + # Build additive padding mask: (B, 1, 1, L) → broadcast over heads/query + mask = None + if attention_mask is not None: + # attention_mask: 1 where real token, 0 where padding + # Convert to large negative for softmax + pad = (1.0 - attention_mask.astype(mx.float32)) * -1e9 + mask = pad[:, None, None, :] # (B, 1, 1, L) + + for layer in self.layers: + x = layer(x, mask=mask) + + x = self.norm(x) + return self.lm_head(x) diff --git a/diffusion_lm/train.py b/diffusion_lm/train.py new file mode 100644 index 000000000..545f7555e --- /dev/null +++ b/diffusion_lm/train.py @@ -0,0 +1,420 @@ +"""Training script for a masked diffusion language model (LLaDA-style). + +Usage (pre-training from scratch on WikiText-103): + python train.py + +Usage (fine-tuning / SFT on instruction data): + python train.py --sft --data path/to/sft_data.jsonl + +The training data file should be a JSONL where each line is: + {"text": "..."} for pre-training + {"prompt": "...", "response": "..."} for SFT + +For a quick smoke-test with a tiny model: + python train.py --d_model 256 --n_layers 4 --n_heads 4 --mlp_hidden_size 512 \ + --batch_size 2 --iters 100 +""" + +import argparse +import json +import math +import time +from functools import partial +from pathlib import Path + +import mlx.core as mx +import mlx.nn as nn +import mlx.optimizers as optim +import numpy as np +from mlx.utils import tree_flatten, tree_map + +from model import Model, ModelArgs + + +# --------------------------------------------------------------------------- +# Masking / forward process +# --------------------------------------------------------------------------- + + +def forward_process( + input_ids: mx.array, + mask_token_id: int, + eps: float = 1e-3, +) -> tuple[mx.array, mx.array, mx.array]: + """Apply random token masking for masked diffusion training. + + For each sequence in the batch, samples a masking probability t ~ U(eps, 1) + and independently masks each token with probability t. + + Args: + input_ids: Integer tensor of shape (B, L). + mask_token_id: Token ID used as the [MASK] placeholder. + eps: Minimum masking probability (avoids t=0). + + Returns: + noisy_ids: Masked input ids of shape (B, L). + masked: Boolean mask of shape (B, L); True where tokens were replaced. + p_mask: Masking probability per token of shape (B, L). + """ + B, L = input_ids.shape + # Sample one masking probability per sequence + t = mx.random.uniform(shape=(B,), low=eps, high=1.0) + p_mask = mx.broadcast_to(t[:, None], (B, L)) + + # Create binary mask by comparing uniform samples to p_mask + u = mx.random.uniform(shape=(B, L)) + masked = u < p_mask + + noisy_ids = mx.where(masked, mask_token_id, input_ids) + return noisy_ids, masked, p_mask + + +# --------------------------------------------------------------------------- +# Loss functions +# --------------------------------------------------------------------------- + + +def pretrain_loss(model: Model, input_ids: mx.array) -> mx.array: + """Masked diffusion pre-training loss. + + Cross-entropy at masked positions, divided by the masking probability, + then averaged over all token positions. This is an upper bound on the + negative log-likelihood of the model distribution. + + Args: + model: The mask-predictor network. + input_ids: Clean token ids of shape (B, L). + + Returns: + Scalar loss. + """ + mask_token_id = model.args.mask_token_id + noisy_ids, masked, p_mask = forward_process(input_ids, mask_token_id) + + logits = model(noisy_ids) # (B, L, V) + + # Cross-entropy loss for every token + token_ce = nn.losses.cross_entropy( + logits.reshape(-1, logits.shape[-1]), + input_ids.reshape(-1), + reduction="none", + ) # (B*L,) + token_ce = token_ce.reshape(input_ids.shape) # (B, L) + + # Weight each token's loss by 1/p_mask (inverse masking probability) + # Only sum at masked positions; un-masked positions have zero gradient + weighted = token_ce * masked / p_mask + + # Normalise by total tokens (masked + unmasked), as in the paper + B, L = input_ids.shape + return weighted.sum() / (B * L) + + +def sft_loss( + model: Model, + input_ids: mx.array, + prompt_lengths: mx.array, +) -> mx.array: + """Supervised fine-tuning loss for conditional masked diffusion. + + Masking is applied only to the *response* portion of each sequence; + the prompt tokens remain intact. Loss is normalised by the length of + the response (not the full sequence) as described in the LLaDA paper. + + Args: + model: The mask-predictor network. + input_ids: Token ids of shape (B, L) containing both prompt and response. + prompt_lengths: Integer tensor of shape (B,) with each prompt's length. + + Returns: + Scalar loss. + """ + mask_token_id = model.args.mask_token_id + B, L = input_ids.shape + + noisy_ids, masked, p_mask = forward_process(input_ids, mask_token_id) + + # Restore prompt tokens — do not mask the conditioning input + positions = mx.broadcast_to(mx.arange(L)[None, :], (B, L)) + prompt_region = positions < prompt_lengths[:, None] + noisy_ids = mx.where(prompt_region, input_ids, noisy_ids) + + logits = model(noisy_ids) # (B, L, V) + + token_ce = nn.losses.cross_entropy( + logits.reshape(-1, logits.shape[-1]), + input_ids.reshape(-1), + reduction="none", + ).reshape(input_ids.shape) + + # Only apply loss where tokens were masked (response region) + masked = (noisy_ids == mask_token_id) + weighted = token_ce * masked / p_mask + + # Normalise by answer (response) length per sequence + response_len = (L - prompt_lengths).astype(mx.float32) # (B,) + per_seq = weighted.sum(axis=1) / mx.maximum(response_len, 1.0) + return per_seq.mean() + + +# --------------------------------------------------------------------------- +# Data loading +# --------------------------------------------------------------------------- + + +def tokenize_file(path: str, tokenizer, max_length: int) -> list[list[int]]: + """Tokenise a JSONL file into fixed-length token id lists.""" + samples = [] + with open(path) as f: + for line in f: + obj = json.loads(line) + text = obj.get("text", obj.get("prompt", "") + obj.get("response", "")) + ids = tokenizer.encode(text) + # Chunk into max_length windows + for start in range(0, len(ids), max_length): + chunk = ids[start : start + max_length] + if len(chunk) == max_length: + samples.append(chunk) + return samples + + +def iterate_batches( + data: list[list[int]], + batch_size: int, + shuffle: bool = True, +): + """Yield batches of token id arrays.""" + indices = np.arange(len(data)) + while True: + if shuffle: + np.random.shuffle(indices) + for start in range(0, len(indices) - batch_size + 1, batch_size): + batch_idx = indices[start : start + batch_size] + batch = mx.array([data[i] for i in batch_idx]) + yield batch + + +def iterate_sft_batches( + data: list[dict], + tokenizer, + batch_size: int, + max_length: int, + shuffle: bool = True, +): + """Yield batches for SFT, returning (input_ids, prompt_lengths).""" + records = [] + for obj in data: + prompt_ids = tokenizer.encode(obj["prompt"]) + response_ids = tokenizer.encode(obj["response"]) + ids = prompt_ids + response_ids + if len(ids) > max_length: + ids = ids[:max_length] + padded = ids + [tokenizer.eos_token_id] * (max_length - len(ids)) + records.append({"ids": padded, "prompt_len": len(prompt_ids)}) + + indices = np.arange(len(records)) + while True: + if shuffle: + np.random.shuffle(indices) + for start in range(0, len(indices) - batch_size + 1, batch_size): + batch_idx = indices[start : start + batch_size] + batch = [records[i] for i in batch_idx] + input_ids = mx.array([r["ids"] for r in batch]) + prompt_lengths = mx.array([r["prompt_len"] for r in batch]) + yield input_ids, prompt_lengths + + +# --------------------------------------------------------------------------- +# Training +# --------------------------------------------------------------------------- + + +def cosine_schedule( + step: int, + total_steps: int, + warmup_steps: int, + max_lr: float, + min_lr: float, +) -> float: + if step < warmup_steps: + return max_lr * step / max(1, warmup_steps) + progress = (step - warmup_steps) / max(1, total_steps - warmup_steps) + return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress)) + + +def main(): + parser = argparse.ArgumentParser(description="Train a masked diffusion LM") + + # Model + parser.add_argument("--d_model", type=int, default=512) + parser.add_argument("--n_layers", type=int, default=8) + parser.add_argument("--n_heads", type=int, default=8) + parser.add_argument("--n_kv_heads", type=int, default=8) + parser.add_argument("--mlp_hidden_size", type=int, default=2048) + parser.add_argument("--vocab_size", type=int, default=32000) + parser.add_argument("--mask_token_id", type=int, default=32000) + + # Training + parser.add_argument("--sft", action="store_true", help="SFT mode") + parser.add_argument("--data", type=str, default=None, help="Path to JSONL data") + parser.add_argument("--tokenizer", type=str, default=None) + parser.add_argument("--max_length", type=int, default=512) + parser.add_argument("--batch_size", type=int, default=8) + parser.add_argument("--iters", type=int, default=10000) + parser.add_argument("--lr", type=float, default=3e-4) + parser.add_argument("--min_lr", type=float, default=3e-5) + parser.add_argument("--warmup", type=int, default=500) + parser.add_argument("--grad_clip", type=float, default=1.0) + + # Checkpointing + parser.add_argument("--save_every", type=int, default=1000) + parser.add_argument("--save_dir", type=str, default="checkpoints") + parser.add_argument("--resume", type=str, default=None) + + # Logging + parser.add_argument("--log_every", type=int, default=10) + parser.add_argument("--seed", type=int, default=42) + + args = parser.parse_args() + mx.random.seed(args.seed) + + save_dir = Path(args.save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + + # ------------------------------------------------------------------ model + model_args = ModelArgs( + d_model=args.d_model, + n_layers=args.n_layers, + n_heads=args.n_heads, + n_kv_heads=args.n_kv_heads, + mlp_hidden_size=args.mlp_hidden_size, + vocab_size=args.vocab_size + 1, # +1 for mask token + mask_token_id=args.mask_token_id, + ) + model = Model(model_args) + mx.eval(model.parameters()) + + n_params = sum(p.size for _, p in tree_flatten(model.parameters())) + print(f"Model parameters: {n_params / 1e6:.1f}M") + + # Save config alongside checkpoints + config_path = save_dir / "config.json" + with open(config_path, "w") as f: + json.dump(vars(model_args), f, indent=2) + + # --------------------------------------------------------------- resume + start_step = 0 + if args.resume: + weights = mx.load(args.resume) + model.load_weights(list(weights.items())) + mx.eval(model.parameters()) + # Infer step from filename if possible + try: + start_step = int(Path(args.resume).stem.split("_")[-1]) + except ValueError: + pass + print(f"Resumed from {args.resume} at step {start_step}") + + # --------------------------------------------------------------- data + if args.data: + # Load tokenizer if data path provided + from transformers import AutoTokenizer + + tok_name = args.tokenizer or "meta-llama/Meta-Llama-3-8B" + tokenizer = AutoTokenizer.from_pretrained(tok_name) + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + + raw = [json.loads(l) for l in open(args.data)] + if args.sft: + data_iter = iterate_sft_batches( + raw, tokenizer, args.batch_size, args.max_length + ) + else: + samples = tokenize_file(args.data, tokenizer, args.max_length) + data_iter = iterate_batches(samples, args.batch_size) + else: + # Synthetic random data for quick testing + print("No data file specified — using random synthetic token ids.") + rng = np.random.default_rng(args.seed) + + def _random_iter(): + while True: + ids = rng.integers( + 0, + args.vocab_size, + size=(args.batch_size, args.max_length), + ) + yield mx.array(ids) + + data_iter = _random_iter() + args.sft = False # Can't do SFT without real data + + # ------------------------------------------------------------ optimizer + optimizer = optim.AdamW(learning_rate=args.lr, weight_decay=0.1) + + # --------------------------------------------------------- training step + def loss_and_grad(model, batch, sft=False): + if sft: + input_ids, prompt_lengths = batch + return sft_loss(model, input_ids, prompt_lengths) + else: + return pretrain_loss(model, batch) + + loss_and_grad_fn = nn.value_and_grad(model, loss_and_grad) + + max_norm = args.grad_clip + + def clip_gradients(grads): + leaves = [g for _, g in tree_flatten(grads) if isinstance(g, mx.array)] + total_norm = mx.sqrt(sum(mx.sum(g**2) for g in leaves)) + scale = mx.minimum(max_norm / (total_norm + 1e-6), 1.0) + return tree_map(lambda g: g * scale if isinstance(g, mx.array) else g, grads) + + state = [model.state, optimizer.state] + + @partial(mx.compile, inputs=state, outputs=state) + def train_step(batch, sft=False): + loss, grads = loss_and_grad_fn(model, batch, sft) + grads = clip_gradients(grads) + optimizer.update(model, grads) + return loss + + # ----------------------------------------------------------------- loop + losses = [] + t_start = time.perf_counter() + + for step in range(start_step, args.iters): + # Update learning rate + lr = cosine_schedule(step, args.iters, args.warmup, args.lr, args.min_lr) + optimizer.learning_rate = lr + + batch = next(data_iter) + loss = train_step(batch, args.sft) + mx.eval(state) + losses.append(loss.item()) + + if (step + 1) % args.log_every == 0: + elapsed = time.perf_counter() - t_start + avg_loss = sum(losses[-args.log_every :]) / args.log_every + tps = args.log_every / elapsed + print( + f"step {step + 1:6d} | loss {avg_loss:.4f} | lr {lr:.2e} | {tps:.1f} it/s" + ) + t_start = time.perf_counter() + + if (step + 1) % args.save_every == 0: + ckpt = save_dir / f"weights_{step + 1:07d}.safetensors" + flat = dict(tree_flatten(model.parameters())) + mx.save_safetensors(str(ckpt), flat) + print(f"Saved checkpoint → {ckpt}") + + # Final checkpoint + ckpt = save_dir / f"weights_{args.iters:07d}.safetensors" + flat = dict(tree_flatten(model.parameters())) + mx.save_safetensors(str(ckpt), flat) + print(f"Training complete. Final checkpoint → {ckpt}") + + +if __name__ == "__main__": + main()