diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
index 16cf5d2b6..93f872a39 100644
--- a/fastchat/model/model_adapter.py
+++ b/fastchat/model/model_adapter.py
@@ -23,6 +23,7 @@
     LlamaTokenizer,
     LlamaForCausalLM,
     T5Tokenizer,
+    Gemma3ForCausalLM,
 )
 
 from fastchat.constants import CPU_ISA
@@ -36,6 +37,7 @@
 from fastchat.model.model_exllama import generate_stream_exllama
 from fastchat.model.model_xfastertransformer import generate_stream_xft
 from fastchat.model.model_cllm import generate_stream_cllm
+from fastchat.model.model_gemma3 import generate_stream_gemma3
 from fastchat.model.monkey_patch_non_inplace import (
     replace_llama_attn_with_non_inplace_operations,
 )
@@ -253,7 +255,12 @@ def load_model(
         kwargs = {"torch_dtype": torch.float16}
         import transformers
 
-        version = tuple(int(v) for v in transformers.__version__.split("."))
+        try:
+            version = tuple(int(v) for v in transformers.__version__.split("."))
+        except ValueError:
+            # Some transformers releases use a non-numeric version string
+            # (e.g. 4.50.0.dev0) that breaks this parser, so fall back to a default.
+            version = (4, 36, 0)
         if version < (4, 35, 0):
             # NOTE: Recent transformers library seems to fix the mps issue, also
             # it has made some changes causing compatibility issues with our
@@ -414,6 +421,7 @@ def get_generate_stream_function(model: torch.nn.Module, model_path: str):
     is_xft = "xft" in model_type
     is_yuan = "yuan" in model_type
     is_cllm = "consistency-llm" in model_path.lower()
+    is_gemma3 = "gemma-3" in model_path.lower()
 
     if is_chatglm:
         return generate_stream_chatglm
@@ -429,6 +437,8 @@ def get_generate_stream_function(model: torch.nn.Module, model_path: str):
         return generate_stream_yuan2
     elif is_cllm:
         return generate_stream_cllm
+    elif is_gemma3:
+        return generate_stream_gemma3
     elif peft_share_base_weights and is_peft:
         # Return a curried stream function that loads the right adapter
@@ -453,6 +463,7 @@ def generate_stream_peft(
             is_xft = "xft" in base_model_type
             is_yuan = "yuan" in base_model_type
             is_cllm = "consistency-llm" in model_path.lower()
+            is_gemma3 = "gemma-3" in model_path.lower()
 
             generate_stream_function = generate_stream
             if is_chatglm:
@@ -469,6 +480,8 @@ def generate_stream_peft(
                 generate_stream_function = generate_stream_yuan2
             elif is_cllm:
                 generate_stream_function = generate_stream_cllm
+            elif is_gemma3:
+                generate_stream_function = generate_stream_gemma3
             for x in generate_stream_function(
                 model,
                 tokenizer,
@@ -817,6 +830,31 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
         )
         return model, tokenizer
 
+
+class Gemma3Adapter(BaseModelAdapter):
+    """The model adapter for google/gemma-3"""
+
+    def match(self, model_path: str):
+        return "gemma-3" in model_path.lower()
+
+    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+        revision = from_pretrained_kwargs.get("revision", "main")
+        device_map = from_pretrained_kwargs.get("device_map", None)
+        if device_map == "sequential":
+            device_map = "auto"
+        tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision)
+        model = Gemma3ForCausalLM.from_pretrained(
+            model_path,
+            revision=revision,
+            torch_dtype=torch.bfloat16,
+            device_map=device_map,
+        )
+        return model, tokenizer
+
+    def get_default_conv_template(self, model_path: str) -> Conversation:
+        return get_conv_template("gemma")
+
 
 class KoalaAdapter(BaseModelAdapter):
     """The model adapter for Koala"""
@@ -2500,8 +2538,12 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
         return get_conv_template("api_based_default")
 
 
 # Note: the registration order matters.
 # The one registered earlier has a higher matching priority.
+register_model_adapter(Gemma3Adapter)
 register_model_adapter(PeftModelAdapter)
 register_model_adapter(StableVicunaAdapter)
 register_model_adapter(VicunaAdapter)
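# --- Illustrative sketch (not part of the patch) -----------------------------
# The try/except in load_model above pins version = (4, 36, 0) whenever the
# transformers version string is not purely numeric. An alternative, shown here
# only as a hedged sketch, keeps the numeric prefix of strings such as
# "4.50.0.dev0" and falls back to the same default only when nothing numeric
# can be parsed.
import transformers


def parse_transformers_version(raw: str) -> tuple:
    parts = []
    for piece in raw.split("."):
        if not piece.isdigit():
            break  # stop at the first non-numeric component, e.g. "dev0" or "rc1"
        parts.append(int(piece))
    return tuple(parts) or (4, 36, 0)


version = parse_transformers_version(transformers.__version__)  # e.g. (4, 50, 0)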
diff --git a/fastchat/model/model_gemma3.py b/fastchat/model/model_gemma3.py
new file mode 100644
index 000000000..61d41379c
--- /dev/null
+++ b/fastchat/model/model_gemma3.py
@@ -0,0 +1,132 @@
+from threading import Thread
+import gc
+
+import torch
+from transformers import TextIteratorStreamer
+
+
+def generate_stream_gemma3(
+    model,
+    tokenizer,
+    params,
+    device,
+    context_len,
+    stream_interval=2,
+    judge_sent_end=False,
+):
+    """Custom generate stream function for Gemma-3 models."""
+    # Get parameters from the request
+    prompt = params.get("prompt", "")
+    messages = params.get("messages", None)
+    temperature = float(params.get("temperature", 1.0))
+    repetition_penalty = float(params.get("repetition_penalty", 1.0))
+    top_p = float(params.get("top_p", 1.0))
+    top_k = int(params.get("top_k", -1))  # -1 means disable
+    max_new_tokens = int(params.get("max_new_tokens", 256))
+    echo = bool(params.get("echo", True))
+    stop_str = params.get("stop", None)
+    stop_token_ids = params.get("stop_token_ids", None) or []
+    model_name = params.get("model", "") or ""
+
+    if tokenizer.eos_token_id not in stop_token_ids:
+        stop_token_ids.append(tokenizer.eos_token_id)
+
+    # Base checkpoints have no chat template, so detect them by name.
+    is_base_model = "pt" in model_name.lower() or "base" in model_name.lower()
+
+    if not is_base_model:
+        # Instruction-tuned checkpoints: apply the chat template, wrapping a plain
+        # prompt into a single user message when no message list is provided.
+        if not messages:
+            messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
+        inputs = tokenizer.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+        ).to(model.device)
+    else:
+        # Base models are prompted with raw text.
+        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+    input_ids = inputs["input_ids"]
+    input_echo_len = input_ids.shape[1]
+
+    # Configure generation parameters
+    generate_kwargs = {
+        "max_new_tokens": max_new_tokens,
+        "do_sample": temperature > 0.0,
+        "temperature": temperature if temperature > 0.0 else 1.0,
+    }
+    if top_p < 1.0:
+        generate_kwargs["top_p"] = top_p
+    if top_k > 0:
+        generate_kwargs["top_k"] = top_k
+    if repetition_penalty > 1.0:
+        generate_kwargs["repetition_penalty"] = repetition_penalty
+
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=not echo, skip_special_tokens=True)
+    generate_kwargs["streamer"] = streamer
+
+    # Start generation in a separate thread
+    thread = Thread(target=lambda: model.generate(input_ids=input_ids, **generate_kwargs))
+    thread.start()
+
+    # Track generation progress. The streamer yields decoded text chunks, so
+    # generated_tokens is an approximate token count, not an exact one.
+    generated_tokens = 0
+    output_text = ""
+    should_stop = False
+
+    # Stream tokens
+    for new_text in streamer:
+        output_text += new_text
+        generated_tokens += 1
+
+        # Check for stop strings
+        should_stop = False
+        if stop_str:
+            if isinstance(stop_str, str):
+                if stop_str in output_text:
+                    output_text = output_text[: output_text.find(stop_str)]
+                    should_stop = True
+            elif isinstance(stop_str, list):
+                for stop in stop_str:
+                    if stop in output_text:
+                        output_text = output_text[: output_text.find(stop)]
+                        should_stop = True
+                        break
+
+        # Stream at intervals or when stopping
+        if generated_tokens % stream_interval == 0 or should_stop:
+            yield {
+                "text": output_text,
+                "usage": {
+                    "prompt_tokens": input_echo_len,
+                    "completion_tokens": generated_tokens,
+                    "total_tokens": input_echo_len + generated_tokens,
+                },
+                "finish_reason": "stop" if should_stop else None,
+            }
+
+        if should_stop:
+            break
+
+    # Wait for generation to finish before emitting the final chunk.
+    if thread.is_alive():
+        # Generous timeout; generation should finish long before this.
+        thread.join(timeout=3600)
+
+    # Final output with finish reason
+    yield {
+        "text": output_text,
+        "usage": {
+            "prompt_tokens": input_echo_len,
+            "completion_tokens": generated_tokens,
+            "total_tokens": input_echo_len + generated_tokens,
+        },
+        "finish_reason": "stop" if should_stop else "length",
+    }
+
+    # Clean up
+    gc.collect()
+    torch.cuda.empty_cache()
+    if device == "xpu":
+        torch.xpu.empty_cache()
+    if device == "npu":
+        torch.npu.empty_cache()
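# --- Illustrative sketch (not part of the patch) -----------------------------
# Minimal standalone use of generate_stream_gemma3. The checkpoint name and the
# sampling parameters are examples only, and access to the gated Gemma-3
# weights is assumed.
import torch
from transformers import AutoTokenizer, Gemma3ForCausalLM

from fastchat.model.model_gemma3 import generate_stream_gemma3

model_path = "google/gemma-3-4b-it"  # example instruction-tuned checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = Gemma3ForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, device_map="auto"
)

params = {
    "model": model_path,
    "prompt": "Explain beam search in one sentence.",
    "temperature": 0.7,
    "max_new_tokens": 64,
    "echo": False,
}
output = ""
for chunk in generate_stream_gemma3(
    model, tokenizer, params, device="cuda", context_len=8192
):
    output = chunk["text"]  # each chunk carries the full text generated so far
print(output)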
diff --git a/fastchat/serve/controller.py b/fastchat/serve/controller.py
index 42d928403..edbba39cb 100644
--- a/fastchat/serve/controller.py
+++ b/fastchat/serve/controller.py
@@ -28,7 +28,7 @@
 from fastchat.utils import build_logger
 
 
-logger = build_logger("controller", "controller.log")
+logger = None
 
 
 class DispatchMethod(Enum):
@@ -351,6 +351,7 @@ async def worker_api_get_status(request: Request):
 
 
 def create_controller():
+    global logger
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=21001)
@@ -367,7 +368,14 @@ def create_controller():
         default=False,
         help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.",
     )
+    parser.add_argument(
+        "--log-file",
+        type=str,
+        default="controller.log",
+        help="Path to the controller log file",
+    )
     args = parser.parse_args()
+    logger = build_logger("controller", args.log_file)
     logger.info(f"args: {args}")
 
     controller = Controller(args.dispatch_method)
diff --git a/fastchat/serve/inference.py b/fastchat/serve/inference.py
index 6d155aab7..ad676ce46 100644
--- a/fastchat/serve/inference.py
+++ b/fastchat/serve/inference.py
@@ -83,6 +83,8 @@ def generate_stream(
     echo = bool(params.get("echo", True))
     stop_str = params.get("stop", None)
     stop_token_ids = params.get("stop_token_ids", None) or []
+    logprobs_requested = params.get("logprobs") is not None
+    top_logprobs_n = int(params.get("top_logprobs_n", 5)) if logprobs_requested else 0
 
     if tokenizer.eos_token_id not in stop_token_ids:
         stop_token_ids.append(tokenizer.eos_token_id)
@@ -116,9 +118,12 @@ def generate_stream(
     past_key_values = out = None
     token_logprobs = [None]  # The first token has no logprobs.
+    top_logprobs_list = [{}]  # The first token has no top logprobs.
     sent_interrupt = False
     finish_reason = None
     stopped = False
+    last_sent_token_pos = 0
 
     for i in range(max_new_tokens):
         if i == 0:  # prefill
             if model.config.is_encoder_decoder:
@@ -142,6 +147,8 @@ def generate_stream(
                     shift_input_ids[0].tolist(), shift_logits[0]
                 ):
                     token_logprobs.append(logit[label_id])
+                    # Top logprobs are skipped during prefill; recovering them
+                    # would require keeping the full prefill logits tensor.
+                    top_logprobs_list.append({})
         else:  # decoding
             if model.config.is_encoder_decoder:
                 out = model.decoder(
@@ -197,6 +204,28 @@ def generate_stream(
                 torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist()
             )
 
+            # Calculate top logprobs for the current token if needed
+            if logprobs_requested and top_logprobs_n > 0:
+                current_logits = torch.log_softmax(logits[0, -1, :], dim=-1)
+                topk_logits, topk_indices = torch.topk(
+                    current_logits, min(top_logprobs_n, len(current_logits))
+                )
+
+                # Map token text → logprob
+                top_dict = {}
+                for logit, token_id in zip(topk_logits.tolist(), topk_indices.tolist()):
+                    token_text = tokenizer.decode([token_id])  # decode expects a list of ids
+                    if token_text and token_text.strip():  # skip empty/whitespace-only tokens
+                        # Keep the highest logprob if the same text maps to several ids
+                        if token_text not in top_dict or logit > top_dict[token_text]:
+                            top_dict[token_text] = logit
+
+                top_logprobs_list.append(top_dict)
+            else:
+                top_logprobs_list.append({})
+
         if token in stop_token_ids:
             stopped = True
         else:
@@ -219,21 +248,26 @@ def generate_stream(
             )
             ret_logprobs = None
             if logprobs is not None:
+                # Only send the tokens generated since the previous chunk.
+                start_pos = (
+                    last_sent_token_pos if echo else max(last_sent_token_pos, input_echo_len)
+                )
+                tokens_to_send = output_ids[start_pos:]
+                # Update last sent position for the next stream chunk
+                last_sent_token_pos = len(output_ids)
+
                 ret_logprobs = {
                     "text_offset": [],
-                    "tokens": [
-                        tokenizer.decode(token)
-                        for token in (
-                            output_ids if echo else output_ids[input_echo_len:]
-                        )
-                    ],
-                    "token_logprobs": token_logprobs
-                    if echo
-                    else token_logprobs[input_echo_len:],
-                    "top_logprobs": [{}]
-                    * len(token_logprobs if echo else token_logprobs[input_echo_len:]),
+                    "tokens": [tokenizer.decode(token) for token in tokens_to_send],
+                    "token_logprobs": token_logprobs[start_pos:],
+                    "top_logprobs": top_logprobs_list[start_pos:],
                 }
-                # Compute text_offset
+
+                # Compute text_offset for just this chunk
                 curr_pos = 0
                 for text in ret_logprobs["tokens"]:
                     ret_logprobs["text_offset"].append(curr_pos)
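# --- Illustrative sketch (not part of the patch) -----------------------------
# With this change each streamed chunk carries logprobs only for the tokens
# produced since the previous chunk, so a client accumulates them across chunks
# instead of reading the last chunk alone. The model path is an example;
# "logprobs" and "top_logprobs_n" are the params keys read by generate_stream.
from fastchat.model.model_adapter import load_model
from fastchat.serve.inference import generate_stream

model, tokenizer = load_model("lmsys/vicuna-7b-v1.5", device="cuda", num_gpus=1)

params = {
    "prompt": "The capital of France is",
    "temperature": 0.0,
    "max_new_tokens": 8,
    "echo": False,
    "logprobs": 1,         # any non-None value enables logprob collection
    "top_logprobs_n": 3,   # alternatives to keep per generated token
}

tokens, token_logprobs, top_logprobs = [], [], []
for chunk in generate_stream(model, tokenizer, params, device="cuda", context_len=2048):
    if chunk.get("logprobs"):
        tokens += chunk["logprobs"]["tokens"]
        token_logprobs += chunk["logprobs"]["token_logprobs"]
        top_logprobs += chunk["logprobs"]["top_logprobs"]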
diff --git a/fastchat/serve/vllm_worker.py b/fastchat/serve/vllm_worker.py
index 0af680bb5..a7cb72bb7 100644
--- a/fastchat/serve/vllm_worker.py
+++ b/fastchat/serve/vllm_worker.py
@@ -261,6 +261,7 @@ async def api_model_details(request: Request):
     parser.add_argument(
         "--conv-template", type=str, default=None, help="Conversation prompt template."
     )
+    parser.add_argument("--model_dtype", type=str, default="auto")
     parser.add_argument(
         "--trust_remote_code",
         action="store_false",
@@ -281,6 +282,8 @@ async def api_model_details(request: Request):
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()
+    if args.model_dtype and args.model_dtype.strip() not in ("", "auto"):
+        args.dtype = args.model_dtype
     if args.model_path:
         args.model = args.model_path
     if args.num_gpus > 1:
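# --- Illustrative sketch (not part of the patch) -----------------------------
# The --model_dtype value is forwarded to the vLLM engine's dtype argument. A
# launcher might validate it first; the strings below are the ones commonly
# accepted for vLLM's --dtype option and may differ between vLLM versions.
ACCEPTED_DTYPES = {"auto", "half", "float16", "bfloat16", "float", "float32"}


def resolve_dtype(model_dtype: str) -> str:
    value = (model_dtype or "auto").strip().lower()
    if value not in ACCEPTED_DTYPES:
        raise ValueError(
            f"Unsupported dtype {value!r}; expected one of {sorted(ACCEPTED_DTYPES)}"
        )
    return value


print(resolve_dtype("bfloat16"))  # -> "bfloat16"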