
Transformers v5.0.0rc0

Pre-release


@LysandreJik released this on 01 Dec 18:14

Transformers v5 release notes

  • Highlights
  • Significant API changes: dynamic weight loading, tokenization
  • Backwards Incompatible Changes
  • Bugfixes and improvements

Highlights

We are excited to announce the initial release of Transformers v5. This is the first major release in five years, and it is a significant one: 800 commits have been pushed to main since the latest minor release. This release removes a lot of long-overdue deprecations, introduces several refactors that significantly simplify our APIs and internals, and comes with a large number of bug fixes.

We give an overview of our focus for this release in the following blogpost. In these release notes, we'll focus directly on the refactors and new APIs coming with v5.

This release is a release candidate (RC). It is not the final v5 release, and we will push it to PyPI as a pre-release. This means that the current release is purely opt-in: installing transformers without specifying this exact release will install the latest stable version instead (v4.57.3 as of writing).

To install this release candidate, run:

pip install transformers --pre

For us to deliver the best package possible, it is imperative that we get feedback on how the toolkit is currently working for you. Please try it out, and open an issue if you run into an inconsistency or a bug.

Transformers version 5 is a community endeavor, and this is the last mile. Let's ship this together!

Significant API changes

Note

👀 Nothing is final and things are still actively in motion. We have a section dedicated to what is planned for future release candidates but is known not to work in RC0. Look for "Disclaimers for the RC0".

We'll be eagerly awaiting your feedback in our GitHub issues!

Tokenization

Just as we moved towards a single backend library for model definition, we want our tokenizers, and the Tokenizer object, to be a lot more intuitive. With v5, tokenizer definition is much simpler: you can now initialize an empty LlamaTokenizer and train it directly on your corpus.

Defining a new tokenizer object should be as simple as this:

from transformers import TokenizersBackend, generate_merges
from tokenizers import pre_tokenizers, Tokenizer
from tokenizers.models import BPE


class Llama5Tokenizer(TokenizersBackend):
    def __init__(self, unk_token="<unk>", bos_token="<s>", eos_token="</s>", vocab=None, merges=None):
        # Fall back to a minimal vocabulary containing only the special tokens
        if vocab is None:
            self._vocab = {
                str(unk_token): 0,
                str(bos_token): 1,
                str(eos_token): 2,
            }
        else:
            self._vocab = vocab

        # Derive merges from the vocabulary when they are not provided
        if merges is not None:
            self._merges = merges
        else:
            self._merges = generate_merges(self._vocab)

        self._tokenizer = Tokenizer(
            BPE(vocab=self._vocab, merges=self._merges, fuse_unk=True)
        )
        self._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(
            replacement="▁", prepend_scheme="first", split=False  # literal scheme stands in for the internal helper
        )
        super().__init__(
            tokenizer_object=self._tokenizer,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
        )

Once the tokenizer is defined as above, you can instantiate it with Llama5Tokenizer(). Doing so returns an empty, trainable tokenizer that follows the definition of the authors of Llama5 (it does not exist yet 😉).

The above is the main motivation behind the tokenization refactor: we want tokenizers to behave like models, trained or empty, and with exactly what is defined in their class definition.

Backend Architecture Changes: moving away from the slow/fast tokenizer separation

Up to now, transformers maintained two parallel implementations for many tokenizers:

  • "Slow" tokenizers (tokenization_<model>.py) - Python-based implementations, often using SentencePiece as the backend.
  • "Fast" tokenizers (tokenization_<model>_fast.py) - Rust-based implementations using the 🤗 tokenizers library.

In v5, we consolidate to a single tokenizer file per model: tokenization_<model>.py. This file will use the most appropriate backend available:

  1. TokenizersBackend (preferred): Rust-based tokenizers from the 🤗 tokenizers library. It generally provides optimal performance, and it also offers many more features that are commonly adopted across the ecosystem:
  • handling additional tokens
  • a full Python API for setting and updating
  • automatic parallelization
  • automatic offsets
  • customization
  • training
  2. SentencePieceBackend: for tokenizers requiring the sentencepiece library. It inherits from PythonBackend.
  3. PythonBackend: a Python implementation of the features provided by tokenizers. Basically, it allows adding tokens.
  4. MistralCommonBackend: relies on the mistral-common tokenization library (previously known as the MistralCommonTokenizer).

The AutoTokenizer automatically selects the appropriate backend based on available files and dependencies. This is transparent: you continue to use AutoTokenizer.from_pretrained() as before. It also keeps transformers future-proof and modular, making it easy to support additional backends in the future.

Defining a tokenizer outside of the existing backends

We enable users and tokenizer builders to define their own tokenizers from top to bottom. Tokenizers are usually defined using a backend such as tokenizers, sentencepiece or mistral-common, but we offer the possibility to design a tokenizer at a higher level, without relying on those backends.

To do so, you can import the PythonBackend (which was previously known as PreTrainedTokenizer). This class encapsulates all the logic related to added tokens, encoding, and decoding.

If you want something even higher up the stack, then PreTrainedTokenizerBase is what PythonBackend inherits from. It contains the very basic tokenizer API features:

  • encode
  • decode
  • vocab_size
  • get_vocab
  • convert_tokens_to_ids
  • convert_ids_to_tokens
  • from_pretrained
  • save_pretrained
  • among a few others
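
To make this concrete, here is a minimal sketch of a toy whitespace tokenizer built on PythonBackend. It is a hedged illustration: it assumes the subclass hooks from the former PreTrainedTokenizer (_tokenize, _convert_token_to_id, _convert_id_to_token, get_vocab, vocab_size) carry over unchanged to PythonBackend, and the class and vocabulary are illustrative only.

from transformers import PythonBackend  # formerly PreTrainedTokenizer


class WhitespaceTokenizer(PythonBackend):
    """Toy tokenizer that splits on whitespace (illustrative sketch only)."""

    def __init__(self, vocab=None, unk_token="<unk>", **kwargs):
        # Minimal vocabulary: the unk token plus whatever the caller provides
        self._vocab = dict(vocab) if vocab is not None else {str(unk_token): 0}
        self._ids_to_tokens = {i: t for t, i in self._vocab.items()}
        super().__init__(unk_token=unk_token, **kwargs)

    @property
    def vocab_size(self):
        return len(self._vocab)

    def get_vocab(self):
        return dict(self._vocab)

    def _tokenize(self, text):
        return text.split()

    def _convert_token_to_id(self, token):
        return self._vocab.get(token, self._vocab[str(self.unk_token)])

    def _convert_id_to_token(self, index):
        return self._ids_to_tokens.get(index, str(self.unk_token))

A subclass like this would inherit the added-token handling, encoding/decoding, and from_pretrained/save_pretrained machinery from PythonBackend and its base class.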

API Changes

1. Direct tokenizer initialization with vocab and merges

Starting with v5, you can initialize blank, untrained, tokenizers-backed tokenizers:

from transformers import LlamaTokenizer

tokenizer = LlamaTokenizer()

This tokenizer will therefore follow the LlamaTokenizer class definition. It can then be trained on a corpus, as described in the 🤗 tokenizers documentation and shown in the sketch below.
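
For illustration, here is a hedged sketch of that workflow. It assumes the backend_tokenizer accessor known from v4 fast tokenizers is still available on the v5 backend classes; the corpus and trainer settings are placeholders.

from transformers import LlamaTokenizer
from tokenizers.trainers import BpeTrainer

corpus = ["hello world", "hello there, how are you?", "the quick brown fox"]

tokenizer = LlamaTokenizer()  # empty, untrained tokenizer following the class definition
trainer = BpeTrainer(vocab_size=256, special_tokens=["<unk>", "<s>", "</s>"])

# Train the underlying 🤗 tokenizers object in place
tokenizer.backend_tokenizer.train_from_iterator(corpus, trainer=trainer)

print(tokenizer.tokenize("hello world"))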

These tokenizers can also be initialized from vocab and merges (if necessary), like the previous "slow" tokenizers:

from transformers import LlamaTokenizer

vocab = {"<unk>": 0, "<s>": 1, "</s>": 2, "hello": 3, "world": 4}
merges = [("h", "e"), ("l", "l"), ("o", " ")]

tokenizer = LlamaTokenizer(vocab=vocab, merges=merges)

This tokenizer will behave as a Llama-like tokenizer with an updated vocabulary. This allows comparing different tokenizer classes with the same vocab, thereby enabling the comparison of different pre-tokenizers, normalizers, etc.

⚠️ The vocab_file (as in, a path towards a file containing the vocabulary) cannot be used to initialize the LlamaTokenizer, as loading from files is reserved for the from_pretrained method.

2. Simplified decoding API

The batch_decode and decode methods have been unified to mirror the behavior of the encode method. Both single and batch decoding now use the same decode method. See an example of the new behavior below:

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("t5-small") 
inputs = ["hey how are you?", "fine"]
tokenizer.decode(tokenizer.encode(inputs))

Gives:

- 'hey how are you?</s> fine</s>'
+ ['hey how are you?</s>', 'fine</s>']

We expect encode and decode to behave as two sides of the same coin: encode, process, decode should just work.

Note

A common use-case would be: encode, model.generate, decode. However, using generate would return list[list[int]], which would then be incompatible with decode.
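
If the unified decode accepts the batched output of generate as the change above suggests, the typical flow would look like this hedged sketch (checkpoint and generation settings are placeholders; t5-small is reused from the example above):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

inputs = tokenizer(
    ["translate English to German: hey how are you?", "translate English to German: fine"],
    return_tensors="pt",
    padding=True,
)
generated = model.generate(**inputs, max_new_tokens=20)

# One decode call for the whole batch, mirroring the batched encode above
print(tokenizer.decode(generated))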

3. Unified encoding API

The encode_plus method is deprecated in favor of the single __call__ method.
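
A before/after sketch of the migration, which is typically a drop-in change:

# v4
encoded = tokenizer.encode_plus("hello world", return_tensors="pt")

# v5
encoded = tokenizer("hello world", return_tensors="pt")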

4. apply_chat_template returns BatchEncoding

Previously, apply_chat_template returned input_ids for backward compatibility. Starting with v5, it now consistently returns a BatchEncoding dict like other tokenizer methods.

# v5
messages = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi there!"}
]

# Now returns BatchEncoding with input_ids, attention_mask, etc.
outputs = tokenizer.apply_chat_template(messages, return_tensors="pt")
print(outputs.keys())  # dict_keys(['input_ids', 'attention_mask'])

5. Removed legacy configuration file saving:

We simplify the serialization of tokenization attributes:

  • special_tokens_map.json - special tokens are now stored in tokenizer_config.json.
  • added_tokens.json - added tokens are now stored in tokenizer.json.
  • added_tokens_decoder is only stored when there is no tokenizer.json.

When loading older tokenizers, these files are still read for backward compatibility, but new saves use the consolidated format. We're gradually moving towards consolidating attributes to fewer files so that other libraries and implementations may depend on them more reliably.

6. Model-Specific Changes

Several models that had identical tokenizers now import from their base implementation:

  • LayoutLM → uses BertTokenizer
  • LED → uses BartTokenizer
  • Longformer → uses RobertaTokenizer
  • LXMert → uses BertTokenizer
  • MT5 → uses T5Tokenizer
  • MVP → uses BartTokenizer

These modules will eventually be removed altogether.

Removed T5-specific workarounds

The internal _eventually_correct_t5_max_length method has been removed. T5 tokenizers now handle max length consistently with other models.

Testing Changes

A few testing changes specific to tokenizers have been applied:

  • Model-specific tokenization test files now focus on integration tests.
  • Common tokenization API tests (e.g., add_tokens, encode, decode) are now centralized and automatically applied across all tokenizers. This reduces test duplication and ensures consistent behavior.

For legacy implementations, the original BERT Python tokenizer code (including WhitespaceTokenizer, BasicTokenizer, etc.) is preserved in bert_legacy.py for reference purposes.

7. Deprecated / Modified Features

Special Tokens Structure:

  • SpecialTokensMixin: Merged into PreTrainedTokenizerBase to simplify the tokenizer architecture.
  • special_tokens_map: Now only stores named special token attributes (e.g., bos_token, eos_token). Use extra_special_tokens for additional special tokens (formerly additional_special_tokens). all_special_tokens includes both named and extra tokens.
# v4
tokenizer.special_tokens_map  # Included 'additional_special_tokens'

# v5
tokenizer.special_tokens_map  # Only named tokens
tokenizer.extra_special_tokens  # Additional tokens
  • special_tokens_map_extended and all_special_tokens_extended: Removed. Access AddedToken objects directly from _special_tokens_map or _extra_special_tokens if needed.
  • additional_special_tokens: Still accepted for backward compatibility but is automatically converted to extra_special_tokens.

Deprecated Methods:

  • sanitize_special_tokens(): Already deprecated in v4, removed in v5.
  • prepare_seq2seq_batch(): Deprecated; use __call__() with text_target parameter instead.
# v4
model_inputs = tokenizer.prepare_seq2seq_batch(src_texts, tgt_texts, max_length=128)

# v5
model_inputs = tokenizer(src_texts, text_target=tgt_texts, max_length=128, return_tensors="pt")
model_inputs["labels"] = model_inputs.pop("input_ids_target")
  • BatchEncoding.words(): Deprecated; use word_ids() instead.

Removed Methods:

  • create_token_type_ids_from_sequences(): Removed from base class. Subclasses that need custom token type ID creation should implement this method directly.
  • clean_up_tokenization(): Removed from base class. Now defined at model class level for models that need it (e.g., PLBart, CLVP, Wav2Vec2).
  • prepare_for_model(), build_inputs_with_special_tokens(), truncate_sequences(): Moved from tokenization_utils_base.py to tokenization_python.py for PythonBackend tokenizers. TokenizersBackend provides model-ready input via tokenize() and encode(), so these methods are no longer needed in the base class.
  • _switch_to_input_mode(), _switch_to_target_mode(), as_target_tokenizer(): Removed from base class. Use __call__() with text_target parameter instead.
# v4
with tokenizer.as_target_tokenizer():
    labels = tokenizer(tgt_texts, ...)

# v5
labels = tokenizer(text_target=tgt_texts, ...)
  • parse_response(): Removed from base class.

Disclaimers for the RC0

PEFT + MoE:

Because we are switching away from the naive MoE implementation (an nn.ModuleList of experts), we currently have an issue with MoE models that have adapters. For more details see #42491 (comment).

We aim for this to be fixed and released in a following release candidate in the week that follows RC0.

Tensor parallel and Expert parallel + MoE

We are streamlining the MoE support with vLLM; while this is being implemented, tensor parallelism and expert parallelism aren't working as expected.
This is known and actively being worked on.

We aim for this to be fixed and released in a following release candidate in the week that follows RC0.

Custom pretrained models:

For anyone inheriting from a transformers PreTrainedModel, the weights are automatically initialized with the common scheme:

    @torch.no_grad()
    def _init_weights(self, module):
        """
        Initialize the weights. This is quite general on purpose, in the spirit of what we usually do. For more complex
        initialization scheme, it should be overridden by the derived `PreTrainedModel` class. In case a model adds an explicit
        `nn.Parameter`, this method should also be overridden in order to initialize it correctly.
        """
        if hasattr(self.config, "initializer_range"):
            std = self.config.initializer_range or 0.02
        elif hasattr(self.config, "init_std"):
            std = self.config.init_std
        elif hasattr(self.config, "initializer_factor"):
            std = self.config.initializer_factor
        else:
            # 0.02 is the standard default value across the library
            std = getattr(self.config.get_text_config(), "initializer_range", 0.02)

        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)):
            if getattr(module, "weight", None) is not None:
                init.normal_(module.weight, mean=0.0, std=std)
            if getattr(module, "bias", None) is not None:
                init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            if getattr(module, "weight", None) is not None:
                init.normal_(module.weight, mean=0.0, std=std)
                # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag
                if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
                    init.zeros_(module.weight[module.padding_idx])
        elif isinstance(module, nn.MultiheadAttention):
            # This uses torch's original init
            module._reset_parameters()
        # We cannot use `isinstance` on the RMSNorms or LayerNorms, as they usually are custom modules which change names
        # between modelings (because they are prefixed with the model name)
        elif (
            isinstance(module, (nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))
            or "LayerNorm" in module.__class__.__name__
            or "RMSNorm" in module.__class__.__name__
        ):
            # Norms can exist without weights (in which case they are None from torch primitives)
            if hasattr(module, "weight") and module.weight is not None:
                init.ones_(module.weight)
            if hasattr(module, "bias") and module.bias is not None:
                init.zeros_(module.bias)

If you want to avoid this behavior, for now you can simply override _init_weights with a no-op:

import torch
from torch import nn
from transformers import Qwen3VLForConditionalGeneration


class CustomModel(Qwen3VLForConditionalGeneration):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.action_head = nn.Linear(1024, 7)
        self.positional_embedding = nn.Parameter(torch.randn(16, 1152))
        self.post_init()

    def _init_weights(self, module):
        # No-op: skip the common initialization scheme for the custom modules above
        pass

There is a tracker for that here: #42418.

Library-wide changes with lesser impact

use_auth_token

The use_auth_token argument is deprecated in favor of token everywhere.
You should be able to search and replace use_auth_token with token and get the same behavior.
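
For example (the checkpoint name and token value are placeholders):

from transformers import AutoModel

# v4
model = AutoModel.from_pretrained("org/private-model", use_auth_token="hf_...")

# v5
model = AutoModel.from_pretrained("org/private-model", token="hf_...")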

Linked PR: #41666

Attention-related features

We decided to remove some features for v5 as they are only supported in a few older models and are no longer integrated in new model additions. It's recommended to stick to v4.x in case you need them. The following features are affected:

  • No more head masking, see #41076. This feature allowed turning off certain heads during the attention calculation and only worked with eager attention.
  • No more relative positional biases in Bert-like models, see #41170. This feature was introduced to allow relative position scores within attention calculations (similar to T5). However, it is barely used in official models and adds a lot of complexity. It also only worked with eager attention.
  • No more head pruning, see #41417 by @gante. As the name suggests, it allowed pruning heads within your attention layers.

Updates to supported torch APIs

We dropped support for two torch APIs that had been deprecated by the PyTorch team; we're instead focusing on the supported dynamo and export APIs.

Quantization changes

We clean up the quantization API in transformers, and significantly refactor the weight loading as highlighted
above.

We drop support for two quantization arguments that have been deprecated for some time:

  • load_in_4bit
  • load_in_8bit

We remove them in favor of the quantization_config argument which is much more complete. As an example, here is how
you would load a 4-bit bitsandbytes model using this argument:

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_4bit=True)

model_4bit = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B",
    device_map="auto",
    quantization_config=quantization_config
)

Configuration

  • Methods to initialize a nested config, such as from_xxx_config, are deleted. Configs can be initialized from the __init__ method in the same way. See #41314.
  • It is no longer possible to load a config class from a URL file. Configs must be loaded from either a local path or a repo on the Hub. See #42383.
  • All parameters configuring a model's rotary embedding are now stored under config.rope_parameters, including rope_theta and rope_type. The model's config.rope_parameters is a simple dictionary in most cases, and can also be a nested dict in special cases (e.g. Gemma3 and ModernBert) with a different rope parameterization for each layer type. Trying to get config.rope_theta will throw an attribute error from now on (see the sketch after this list). See #39847 and #42255.
  • The Qwen-VL family configuration is in a nested format, and trying to access keys directly will throw an error (e.g. config.vocab_size). Users are expected to access keys from their respective sub-configs (config.text_config.vocab_size).
  • Configurations of non-generative models (any model that doesn't call model.generate()) will no longer have a generation_config, and model.config.generation_config will throw an attribute error.
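
A sketch of the new RoPE access pattern, assuming a checkpoint whose config uses the flat (non-nested) layout; the checkpoint name is reused from the quantization example above:

from transformers import AutoConfig

config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-3B")

# v4
# theta = config.rope_theta  # now raises AttributeError

# v5
theta = config.rope_parameters["rope_theta"]
rope_type = config.rope_parameters["rope_type"]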

Processing

Tokenization

  • Slow tokenizer files (tokenization_<model>.py) are removed in favor of fast tokenizer files (tokenization_<model>_fast.py), which are renamed to tokenization_<model>.py. As fast tokenizers are backed by 🤗 tokenizers, they include a wider range of features that are maintainable and reliable.
  • Other backends (sentencepiece, etc.) will be supported with a light layer if loading a fast tokenizer fails
  • Remove legacy files like special_tokens_map.json and added_tokens.json
  • Remove _eventually_correct_t5_max_length
  • encode_plus --> __call__
  • batch_decode --> decode

apply_chat_template used to return naked input_ids by default rather than a BatchEncoding dict.
This was inconvenient: it should return a BatchEncoding dict like tokenizer.__call__(), but we were stuck with
the old behavior for backward compatibility. The method now returns a BatchEncoding.

Linked PRs:

Processing classes

  • In processing classes each attribute will be serialized under processor_config.json as a nested dict, instead of serializing attributes in their own config files. Loading will be supported for all old format processors (#41474)
  • XXXFeatureExtractors classes are completely removed in favor of XXXImageProcessor class for all vision models (#41174)
  • Minor change: XXXFastImageProcessorKwargs is removed in favor of XXXImageProcessorKwargs which will be shared between fast and slow processors (#40931)

Modeling

  • Some RotaryEmbeddings layers will start returning a dict of tuples, in case the model uses several RoPE configurations (Gemma2, ModernBert). Each value will be a tuple of "cos, sin" per RoPE type.
  • Config attribute for RotaryEmbeddings layer will be unified and accessed via config.rope_parameters. Config attr for rope_theta might not be accessible anymore for some models, and instead will be in config.rope_parameters['rope_theta']. BC will be supported for a while as much as possible, and in the near future we'll gradually move to the new RoPE format (#39847)
  • Vision-language models no longer have shortcut access to their language and vision components from the generative model via model.language_model. It is recommended to access the module with model.model.language_model or model.get_decoder(), as shown in the sketch below. See #42156
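
A minimal sketch of the new access pattern for a vision-language model:

# v4
language_model = model.language_model

# v5
language_model = model.model.language_model
# or, more generically:
language_model = model.get_decoder()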

Generate

  • Old, deprecated output type aliases were removed (e.g. GreedySearchEncoderDecoderOutput). We now only have 4 output classes built from the following matrix: decoder-only vs encoder-decoder, uses beams vs doesn't use beams (#40998)
  • Removed deprecated classes regarding decoding methods that were moved to the Hub due to low usage (constraints and beam scores) (#41223)
  • If generate doesn't receive any KV Cache argument, the default cache class used is now defined by the model (as opposed to always being DynamicCache) (#41505)
  • Generation parameters are no longer accessible via the model's config. If generation parameters are serialized in config.json for an older model, they will be loaded back into the model's generation config. Users are expected to access or modify generation parameters only through model.generation_config, e.g. model.generation_config.do_sample = True (see the sketch below).
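
For example:

# v4 (no longer supported)
# model.config.do_sample = True

# v5
model.generation_config.do_sample = True
model.generation_config.temperature = 0.7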

Trainer

New Features

  • ALST/Ulysses Sequence Parallelism Integration
    • Added sequence parallelism support via HF Accelerate for training with longer sequences. Enables splitting sequences across devices using the ALST (Arctic Long Sequence Training) and Ulysses algorithms with DeepSpeed.
  • Improved compute_loss_func Handling
    • compute_loss_func now always takes priority over the model's built-in loss computation, giving users consistent control over custom loss functions.
  • num_items_in_batch in Prediction Step
    • The num_items_in_batch argument is now passed to compute_loss during prediction_step, enabling proper loss scaling during evaluation.

Breaking Changes

  • report_to now defaults to "none"
    • Logging integrations are no longer auto-detected by default; users must explicitly specify which reporting backends to use (see the example below).
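
For example, to keep reporting to Weights & Biases you now have to opt in explicitly:

from transformers import TrainingArguments

# v5: integrations are no longer auto-detected; request them explicitly
args = TrainingArguments(output_dir="out", report_to=["wandb"])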

Removing arguments without deprecation cycle in TrainingArguments due to low usage

  • mp_parameters -> legacy parameter that was later added to the SageMaker trainer
  • _n_gpu -> not intended for users to set; we will initialize it correctly instead of exposing it in TrainingArguments
  • overwrite_output_dir -> replaced by resume_from_checkpoint; it was only used in the example scripts and has no impact on Trainer
  • logging_dir -> only used for tensorboard; set the TENSORBOARD_LOGGING_DIR env var instead
  • jit_mode_eval -> use use_torch_compile instead, as torchscript is no longer recommended
  • tpu_num_cores -> it is not recommended to set the number of cores; by default, all TPU cores are used. Set the TPU_NUM_CORES env var instead
  • past_index -> it was only used by a very small number of models with special architectures, like Transformer-XL, and it was never documented how to train those models
  • ray_scope -> only a minor arg for the Ray integration; set the RAY_SCOPE env var instead
  • warmup_ratio -> use warmup_steps instead. We combined both args by allowing float values in warmup_steps (see the sketch after this list).
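
A hedged sketch of the warmup change, assuming the surviving argument keeps its v4 name warmup_steps and interprets float values as a ratio of total training steps:

from transformers import TrainingArguments

# v4
# args = TrainingArguments(output_dir="out", warmup_ratio=0.1)

# v5: pass a float to warmup_steps instead (interpreted as a ratio)
args = TrainingArguments(output_dir="out", warmup_steps=0.1)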

Removing deprecated arguments in TrainingArguments

  • fsdp_min_num_params and fsdp_transformer_layer_cls_to_wrap -> use fsdp_config
  • tpu_metrics_debug -> debug
  • push_to_hub_token -> hub_token
  • push_to_hub_model_id and push_to_hub_organization -> hub_model_id
  • include_inputs_for_metrics -> include_for_metrics
  • per_gpu_train_batch_size -> per_device_train_batch_size
  • per_gpu_eval_batch_size -> per_device_eval_batch_size
  • use_mps_device -> mps will be used by default if detected
  • fp16_backend and half_precision_backend -> we will only rely on torch.amp as everything has been upstreamed to torch
  • no_cuda -> use_cpu
  • include_tokens_per_second -> include_num_input_tokens_seen
  • use_legacy_prediction_loop -> we only use evaluation_loop function from now on

Removing deprecated arguments in Trainer

  • tokenizer in initialization -> processing_class
  • model_path in train() -> resume_from_checkpoint

Removed features for Trainer

  • sigopt integration for hp search was removed, as the library was archived and the API stopped working
  • drop support for sagemaker API <1.10
  • bump accelerate minimum version to 1.1.0
  • bump peft minimum version to 0.18.0
  • bump bitsandbytes minimum version to 0.46.1

New defaults for Trainer

  • use_cache in the model config will be set to False. You can still change the cache value through the TrainingArguments use_cache argument if needed.

Pipeline

  • Image-text-to-text pipelines will no longer accept images as a separate argument along with conversation chats. Image data has to be embedded in the chat's "content" field (see the sketch below). See #42359
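
A hedged sketch of the expected usage (the checkpoint and image URL are placeholders; the content entries follow the standard chat-template format):

from transformers import pipeline

pipe = pipeline("image-text-to-text", model="Qwen/Qwen2.5-VL-3B-Instruct")

# v5: the image lives inside the chat's "content" field, not in a separate `images=` argument
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://example.com/cat.png"},
            {"type": "text", "text": "Describe this image."},
        ],
    }
]

print(pipe(text=messages, max_new_tokens=32))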

PushToHubMixin

  • removed deprecated organization and repo_url from PushToHubMixin. You must pass a repo_id instead.
  • removed ignore_metadata_errors from PushToHubMixin. In practice, if we ignore errors while loading the model card, we won't be able to push the card back to the Hub, so it's better to fail early and not provide the option to fail later.
  • push_to_hub does not accept **kwargs anymore. All accepted parameters are explicitly documented.
  • arguments of push_to_hub are now keyword-only to avoid confusion. Only repo_id can be positional since it's the main arg (see the example below).
  • removed use_temp_dir argument from push_to_hub. We now use a tmp dir in all cases.
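
For example, repo_id stays positional while everything else becomes keyword-only:

# v5: only repo_id may be positional; all other arguments are keyword-only
model.push_to_hub("username/my-model", private=True, commit_message="Upload v5 checkpoint")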

Linked PR: #42391.

CLI

The deprecated transformers-cli ... command has been removed; transformers ... is now the only CLI entry point.

The transformers CLI has been migrated to Typer, making it easier to maintain and adding some nice features out of
the box (an improved --help section, autocompletion).

The biggest breaking change is in transformers chat. This command starts a terminal UI to interact with a chat model.
It used to also be able to start a Chat Completion server powered by transformers and chat with it. In this revamped
version, that feature has been removed in favor of transformers serve. The goal of splitting transformers chat
and transformers serve is to define clear boundaries between client and server code. It helps with maintenance
and also makes the commands less bloated. The new signature of transformers chat is:

Usage: transformers chat [OPTIONS] BASE_URL MODEL_ID [GENERATE_FLAGS]...

Chat with a model from the command line.

It works hand in hand with transformers serve, which means that if transformers serve is running on its default endpoint, transformers chat can be launched as follows:

transformers chat HuggingFaceTB/SmolLM3-3B

It can however use any OpenAI API compatible HTTP endpoint:

transformers chat HuggingFaceTB/SmolLM3-3B https://router.huggingface.co/v1

Linked PRs:

Removal of the run method

The transformers run command (previously transformers-cli run) is an artifact of the past: it was neither documented
nor tested. We're removing it for now; please let us know if you were using it, in which case we'll bring it back
with better support.

Linked PR: #42447

Environment variables

  • Legacy environment variables like TRANSFORMERS_CACHE, PYTORCH_TRANSFORMERS_CACHE, and PYTORCH_PRETRAINED_BERT_CACHE have been removed. Please use HF_HOME instead.
  • Constants HUGGINGFACE_CO_EXAMPLES_TELEMETRY, HUGGINGFACE_CO_PREFIX, and HUGGINGFACE_CO_RESOLVE_ENDPOINT have been removed. Please use huggingface_hub.constants.ENDPOINT instead.

Linked PR: #42391.

Requirements update

transformers v5 pins the huggingface_hub version to >=1.0.0. See this migration guide to learn more about that major release. Here are the main aspects to know about:

  • huggingface_hub switched its HTTP backend from requests to httpx. This change was made to improve performance and to support synchronous and asynchronous requests the same way. If you are currently catching requests.HTTPError errors in your codebase, you'll need to switch to httpx.HTTPError.
  • Related to the above, it is no longer possible to set proxies from your script. To handle proxies, you must set the HTTP_PROXY / HTTPS_PROXY environment variables.
  • hf_transfer, and therefore HF_HUB_ENABLE_HF_TRANSFER, have been completely dropped in favor of hf_xet. This should be transparent for most users. Please let us know if you notice any downside!

typer-slim has been added as a required dependency; it is used to implement both the hf and transformers CLIs.

New model additions in v5

CWM


The Code World Model (CWM) model was proposed in CWM: An Open-Weights LLM for Research on Code Generation with World Models by Meta FAIR CodeGen Team. CWM is an LLM for code generation and reasoning about code that has, in particular, been trained to better represent and reason about how code and commands affect the state of a program or system. Specifically, we mid-trained CWM on a large number of observation-action trajectories from Python execution traces and agentic interactions in containerized environments. We post-trained with extensive multi-task RL in verifiable coding, math, and multi-turn software engineering environments.

SAM3


SAM3 (Segment Anything Model 3) was introduced in SAM 3: Segment Anything with Concepts.

The SAM3 addition adds four new architectures:

  • Sam3
  • Sam3Tracker
  • Sam3TrackerVideo
  • Sam3Video

SAM3 performs Promptable Concept Segmentation (PCS) on images. PCS takes text and/or image exemplars as input (e.g., "yellow school bus"), and predicts instance and semantic masks for every single object matching the concept.

Sam3Tracker and Sam3TrackerVideo perform Promptable Visual Segmentation (PVS) on images. PVS takes interactive visual prompts (points, boxes, masks) or text inputs to segment a specific object instance per prompt. This is the task that SAM 1 and SAM 2 focused on, and SAM 3 improves upon it. Sam3Tracker and Sam3TrackerVideo are updated versions of SAM2 Video that maintain the same API while providing improved performance and capabilities.

SAM3 Video performs Promptable Concept Segmentation (PCS) on videos. PCS takes text as input (e.g., "yellow school bus"), and predicts instance and semantic masks for every single object matching the concept, while preserving object identities across video frames. The model combines a detection module (SAM3) with a tracking module (SAM2-style tracker) to enable robust object tracking across video frames using text prompts.

LFM2 MoE


LFM2-MoE is a Mixture-of-Experts (MoE) variant of LFM2. The LFM2 family is optimized for on-device inference by combining short‑range, input‑aware gated convolutions with grouped‑query attention (GQA) in a layout tuned to maximize quality under strict speed and memory constraints.

LFM2‑MoE keeps this fast backbone and introduces sparse MoE feed‑forward networks to add representational capacity without significantly increasing the active compute path. The first LFM2-MoE release is LFM2-8B-A1B, with 8.3B total parameters and 1.5B active parameters. The model excels in quality (comparable to 3-4B dense models) and speed (faster than other 1.5B class models).

VideoLlama 3


The VideoLLaMA3 model is a major update to VideoLLaMA2 from Alibaba DAMO Academy.

  • [model] Add VideoLLaMA3 implementation by @lkhl in #40499

AudioFlamingo 3


Audio Flamingo 3 (AF3) is a fully open large audio–language model designed for robust understanding and reasoning over speech, environmental sounds, and music. AF3 pairs a Whisper-style audio encoder with a causal language model and performs replace-in-place audio–text fusion: the processor aligns post-pool audio frames to a dedicated placeholder token and the model replaces those token slots with projected audio embeddings during the forward pass.

The model checkpoint is available at: nvidia/audio-flamingo-3-hf

Highlights:

  • Unified audio encoder across speech, sound, and music.
  • Long-audio support via windowing and post-pool alignment (up to 10 minutes maximum). The model processes audio in 30-second windows with a hard limit of 20 windows (10 minutes total). Audio longer than 10 minutes will be truncated.
  • Deterministic fusion that preserves sequence length by replacing audio placeholder tokens with audio embeddings.

Nanochat

NanoChat is a compact decoder-only transformer model designed for educational purposes and efficient training. The model features several fundamental architectural choices that are common in modern transformer models, which makes it a good starting point for understanding the principles behind them. NanoChat is a variant of the Llama architecture, with a simplified attention mechanism and normalization layers.

Bugfixes and improvements

Significant community contributions

The following contributors have made significant changes to the library over the last release:

  • @ArthurZucker
    • JetMoe Fix jetmoe after #40132 (#41324)
    • [ModularChecker] QOL for the modular checker (#41361)
    • [CB] Refactors the way we access paged (#41370)
    • Update from pretrained error when loading (#33380)
    • 🤦 CB nit! (#41413)
    • [from_pretrained] Small refactor from_pretrained: move around unrelated stuff (#41445)
    • update deps table (#42120)
    • Refactor weight loading (#41580)
    • Update conversion mapping to separate renaming from converting (#42254)
    • Auto convert tekken.json (#42299)
    • fix tekken pattern matching (#42363)
    • Small tp fix (#42366)
    • Fix tp (#42368)
    • misc don't recreate it (#42394)
  • @vasqu
    • 🚨 [v5] Remove relative position embeddings (for bert like models) (#41170)
    • [v5] Sync Bert and Bart eager attention (#41248)
    • [JetMoe] Fix KV head repetition and padding free (#41423)
    • 🚨 [Attention Masks] Bidirectional masks for encoder and encoder-decoder models (#41265)
    • [CI] Fix copies on main (#41486)
    • [Docs] Fix changed references (#41614)
    • [Executorch] Simplify for encoder models (#41627)
    • [Ernie 4.5 Moe] Fix Moe and offloading (#41385)
    • [Masks] Fix mask handling in eager for vision models (#41625)
    • [Attn] Allow dynamic causality in SDPA via Kwargs (#41692)
    • [Onnx docs] Remove some traces (#41791)
    • 🚨 [Clip] Fix masking and enable flash attention on all model types (#41750)
    • [Attn Masks] Non-vmap default for attention masks (#41852)
    • [T5Gemma] Fix cross attention cache (#41890)
    • [Pop2Piano] Fix cache usage (#42170)
    • [PEFT] Fix prefix tuning (#41696)
    • [PEFT] Fix the general test for prefix tuning (#42185)
    • [Pop2Piano] Fix tied weights (#42193)
    • [BLT] Fix cache usage (#42188)
    • [CI] Skip EfficientLoFTR test (#42327)
    • [Attn Masks] Lift bidirectional mask restriction on eager (#42325)
    • [Attn Masks] Add skip option for non-packed sequences (#42367)
    • [Mistral Tokenizers] Fix tokenizer detection (#42389)
    • [FA] Cleanup loading logic (#41427)
    • [CI] Add to run slow (#42459)
  • @ydshieh
    • [testing] update test_longcat_generation_cpu (#41368)
    • [testing] Fix JetMoeIntegrationTest (#41377)
    • Pickle - part 2 (#41476)
    • Try to remove pickle - BloomTokenizerFast (#41466)
    • [testing] reduce runtime of HunYuanMoEV1IntegrationTest:test_model_generation (#41373)
    • delete some tokenizer tests using pickle (#41514)
    • torch 2.9 don't ❤️ torchcodec 💔 (#41610)
    • Update a dataset reop link (#41618)
    • Remove the head masking block in some vision models (#41620)
    • improve utils/check_bad_commit.py (#41658)
    • torch 2.9 still don't ❤️ torchcodec 0.8 💔 (#41686)
    • path validation for security reason (#41256)
    • pin torchcodec on CI docker image (#41703)
    • further improve utils/check_bad_commit.py (#41658) (#41690)
    • Revert "Remove upper version bound of pandas" (#41744)
    • Fix bark after #41445 (#41645)
    • flash attn pytest marker (#41781)
    • unpin torch/torchcodec for CircleCI (#41839)
    • further reducing flakiness in utils/check_bad_commit.py (#41658) (#41815)
    • CI workflow for Flash Attn (#41857)
    • Update some workflow files (#41892)
    • Minor fix in docker image build workflow (#41949)
    • Run slow v2 (#41914)
    • Fix detectron2 installation in docker files (#41975)
    • Fix autoawq[kernels] installation in quantization docker file (#41978)
    • Fix torchcodec version in quantization docker file (#41988)
    • Fix run slow v2: empty report when there is only one model (#42002)
    • Fix torch+deepspeed docker file (#41985)
    • fix deeepspeed in AMD docker file (#42025)
    • Change trigger time for AMD CI (#42034)
    • Remove some custom datasets defined in codebase (#41511)
    • Cleanup workflow - part 1 (#42023)
    • Fix pr_slow_ci_suggestion.yml after #42023 (#42049)
    • Avoid explicit checkout in workflow (#42057)
    • Be careful at explicit checkout actions (#42060)
    • Fix another Argument list too long in pr_slow_ci_suggestion.yml (#42061)
    • Revert back to use GitHub context (#42066)
    • Fix inconsistency of commit sha during the workflow run (#42074)
    • Revert "permissions worflows fix" (#42110)
    • pin pytest<9 for now (#42162)
    • Update test_dynamic_cache_exportability_multiple_run (failing on torch 2.10 nightly) (#42212)
    • Reduce timing on CircleCI - part 1 (Use @slow for IntegrationTests) (#42206)
    • Make tests run in less time by reducing batch_size (#42213)
    • Revert "Make tests run in less time by reducing batch_size" (#42258)
    • delete already deprecated models (#42235)
    • Remove doc files of other langs for deleted models (#42276)
    • [testing] fix cwm (#42261)
  • @cyyever
    • Remove unnecessary list comprehension (#41305)
    • Remove unused function patameters (#41358)
    • Use accelerator API to free device memory (#41195)
    • Remove Python 3.9 classifier (#41410)
    • Remove KERAS_NLP_IMPORT_ERROR (#41468)
    • Import Callable from collections.abc (#41130)
    • Remove infer_device (#41088)
    • Fix Latex typesetting in documentation (#41177)
    • Fix typsetting and content of llm_tutorial_optimization.md (#41172)
    • More markdown file fixes (#41599)
    • Format MarkDown documentation and tiny fixes (#41638)
    • Fix typos in documentation (#41641)
    • Fix confusing cls assignment (#41642)
    • Use | for Optional and Union typing (#41646)
    • Remove require_torch_bf16_gpu (#40979)
    • Fix MarkDown syntax (#41676)
    • Use | for Optional and Union typing (#41675)
    • Enable faiss-cpu on Windows (#41678)
    • Fix Pylint warnings (#41644)
    • Enable FURB rules in ruff (#41395)
    • Remove upper version bound of pandas (#41677)
    • Fix documentation issues (#41726)
    • Apply RUFF PIE rules (#41727)
    • Replace Optional and Union typing with | in some source files (#42294)
    • Replace Optional and Union typing with | in some source files (#42372)
  • @yao-matrix
    • make some ut cases pass on xpu w/ latest torch (#41337)
    • fix asr ut failures (#41332)
    • enable new model uts to xpu and fix some failures on xpu (#41386)
    • enable some falcon-mamba uts on xpu (#41428)
    • enhance patched_tearDown to support python 3.11+ (#41429)
    • fix gemma3n case failure (#41426)
    • upgrade xpu docker file to torch 2.8 (#41551)
    • make apollo test case pass (#41805)
    • extend bitnet cases to xpu, all 8 cases pass (#41831)
    • extend 2 trainer test cases to xpu (#41829)
    • extend 2 blip2 and falcon_h1 test cases to xpu (#41825)
    • make lfm2_moe integration test pass on XPU (#41796)
    • fix some ut failures on XPU w/ torch 2.9 (#41923)
    • fix some ut failures on XPU w/ torch 2.9 (#41941)
    • fix prepare_config_and_inputs_for_common bug in llava test (#41942)
    • make recurrent_gemma and voxtral cases pass on xpu (#41958)
    • extend fp_quant cases to xpu (#41833)
    • fix tensor device placement issue of 2 UT cases (#41921)
    • fix continuous batching issues, extend ut cases to xpu (#41830)
  • @MekkCyber
    • [kernels] Kernel Config (#41232)
    • Fixing comments in init file (#41414)
    • [kernels] Cleanup deta kernel (#41470)
    • Cleaning hub kernels (#41477)
    • Remove DISABLE_KERNEL_MAPPING flag (#41475)
    • [kernels] Remove RWKV kernel finally ! (#41493)
    • [kernels] rm yoso kernel (#41495)
    • [kernels] rm mra kernels (#41507)
    • Revert "add rmsnorm kernels support for Intel XPU" (#41579)
    • [kernels] refactor function kernel calling (#41577)
    • Erroring when KernelConfig is passed without use_kernels = True (#41657)
    • Small Fix for imports (#41411)
    • [kernels] Add version to function mapping (#41685)
    • [quantization] fix compressed_tensors tests (#41780)
    • [quantization] Skip Fp8 tests when hardware capability < 8.9 (#41785)
    • [quantization] fix torchao tests after 0.14.0 release (#41777)
    • revert changes in _is_package_available (#41891)
    • [kernels] Add Tests & CI for kernels (#41765)
    • [kernels] change import time in KernelConfig (#42004)
    • [kernels] Fix XPU layernorm kernel (#41583)
    • [core] Fix torchao (#42289)
    • [core] fix mxfp4 (#42382)
    • [fp8] fix scales param name (#42434)
    • [quantization] make torchao tests slow (#42482)
  • @paulpak58
    • [Cache] lfm2 cache: allocate empty kv layers during init (#41396)
    • [Model] Lfm2Moe (#41401)
  • @gante
    • 🚨 [v5] Prune prune_heads (#41417)
    • [v5] rm utils/tf_ops/ (#41402)
    • [causallm tester] automate pipeline mappings + bloom tests (#41318)
    • 🚨 [v5] generate delegates default cache initialization to the model (#41505)
  • @zRzRzRzRzRzRzR
    • Update GLM-4.1V MMRope implementation (#41182)
    • Update GLM-4.6 doc (#41471)
    • Add aux loss for GLM-4.5V (#41564)
    • 4.1V Model and GLM-4.5V Model Conversion Code Updates (#41784)
    • GLM-V update with new processor (#42122)
  • @jacobkahn
    • Add Code World Model (CWM) (#41199)
  • @molbap
    • Update philosophy (#41438)
    • [QoL] modular conversion shows LoC saved (#41500)
    • Double router compute? (#41653)
    • Add vision contribution guide (#41456)
    • Modernize CLIP modeling code (#41546)
    • handle inputs from Siglip/Siglip2 non-automapped encoder layers (#41930)
    • Fix processor test for glm (#42233)
    • Tiny doc fix (#42296)
    • tiny fix for deepseekocr support [vllm] (#42423)
  • @Wauplin
    • Bump to hfh 1.0.0.rc5 to fix test (#41508)
    • Migrate transformers cli to Typer (#41487)
    • Remove deprecated use_auth_token parameter (#41666)
    • added more breaking changes
    • [cleanup] Don't use Repository in create_dummy_models.py script (#42380)
    • [cleanup] Remove deprecated load config from file (#42383)
    • [cleanup] Offline mode and cache dir from huggingface_hub constants + cleanup in PushToHubMixin (#42391)
  • @remi-or
    • Restore cuda graphs to continuous batching (#41421)
    • Fix an import error with PreTrainModel (#41571)
    • Add iter to DynamicCache (#41569)
    • Gemma3 fixes (#41572)
    • Benchmark overhaul (#41408)
    • Fix fp32_ln for various models (#41605)
    • Fix EncoderDecoder cache (#41612)
    • Switch to CB if cache_implementation == paged (#41655)
    • Small changes to benchmarking script (#41662)
    • Bump AMD docker (#41792)
    • Add a safeguard around a flaky test in gemma2 (#41811)
    • Use indices as position_ids in modernebert (#41789)
    • Move the Mi355 to regular docker (#41989)
    • More data in benchmarking (#41848)
    • Reduce the number of benchmark in the CI (#42008)
    • New docker from AMD (#42208)
    • Add prefix sharing to continuous batching (#42094)
    • Update torchcodec to match torchaudio version (#42288)
    • Gemma3 hybrid fix (#42287)
    • Make benchmarking lighter: clean-up result files and remove non-needed arguments (#42357)
    • Many small fixes for the CI (#42364)
    • Benchmark simplification (#42408)
  • @lkhl
    • [model] Add VideoLLaMA3 implementation (#40499)
  • @philiproeleveld
    • Add logits_to_keep to many older CausalLM models (#41335)
  • @AlphaOrOmega
    • Adding superglue fast image processing (#41394)
  • @echarlaix
    • [v5] Remove deprecated tranformers.onnx (#41700)
  • @Aravind-11
    • Add GLPNImageProcessorFast (#41725)
    • T5 migration to new masking interface (#41804)
    • 🚨 Remove generic output_attentions warning (#42334)
  • @DeXtAr47-oss
    • add fuyu fast image processors (#41817)
  • @lashahub
    • [models] Add AudioFlamingo3 integration (#40290)
  • @lilin-1
  • @burtenshaw
    • [MODEL] Nanochat implementation (#41634)
  • @itazap
    • rm slow tokenizers (#40936)