44 changes: 43 additions & 1 deletion fastchat/model/model_adapter.py
@@ -23,6 +23,7 @@
LlamaTokenizer,
LlamaForCausalLM,
T5Tokenizer,
Gemma3ForCausalLM,
)

from fastchat.constants import CPU_ISA
@@ -36,6 +37,7 @@
from fastchat.model.model_exllama import generate_stream_exllama
from fastchat.model.model_xfastertransformer import generate_stream_xft
from fastchat.model.model_cllm import generate_stream_cllm
from fastchat.model.model_gemma3 import generate_stream_gemma3

from fastchat.model.monkey_patch_non_inplace import (
replace_llama_attn_with_non_inplace_operations,
@@ -253,7 +255,12 @@ def load_model(
kwargs = {"torch_dtype": torch.float16}
import transformers

version = tuple(int(v) for v in transformers.__version__.split("."))
try:
version = tuple(int(v) for v in transformers.__version__.split("."))
except ValueError:
# Some transformers builds use a non-numeric version format (e.g. 4.50.0.dev0)
# that breaks this parser, so fall back to a known-good default.
version = (4, 36, 0)
if version < (4, 35, 0):
# NOTE: Recent transformers library seems to fix the mps issue, also
# it has made some changes causing compatibility issues with our
@@ -414,6 +421,7 @@ def get_generate_stream_function(model: torch.nn.Module, model_path: str):
is_xft = "xft" in model_type
is_yuan = "yuan" in model_type
is_cllm = "consistency-llm" in model_path.lower()
is_gemma3 = "gemma-3" in model_path.lower()

if is_chatglm:
return generate_stream_chatglm
@@ -429,6 +437,8 @@ def get_generate_stream_function(model: torch.nn.Module, model_path: str):
return generate_stream_yuan2
elif is_cllm:
return generate_stream_cllm
elif is_gemma3:
return generate_stream_gemma3

elif peft_share_base_weights and is_peft:
# Return a curried stream function that loads the right adapter
@@ -453,6 +463,7 @@ def generate_stream_peft(
is_xft = "xft" in base_model_type
is_yuan = "yuan" in base_model_type
is_cllm = "consistency-llm" in model_path.lower()
is_gemma3 = "gemma-3" in model_path.lower()

generate_stream_function = generate_stream
if is_chatglm:
@@ -469,6 +480,8 @@
generate_stream_function = generate_stream_yuan2
elif is_cllm:
generate_stream_function = generate_stream_cllm
elif is_gemma3:
generate_stream_function = generate_stream_gemma3
for x in generate_stream_function(
model,
tokenizer,
@@ -817,6 +830,31 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
)
return model, tokenizer

class Gemma3Adapter(BaseModelAdapter):
"""The model adapter for google/gemma-3"""

def match(self, model_path: str):
return "gemma-3" in model_path.lower()

def load_model(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
device_map = from_pretrained_kwargs.get("device_map", None)
if device_map == "sequential":
device_map = "auto"
# print("From pretrained kwargs", from_pretrained_kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision)
model = Gemma3ForCausalLM.from_pretrained(
model_path,
revision=revision,
torch_dtype=torch.bfloat16,
device_map=device_map,
)
return model, tokenizer


def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("gemma")


class KoalaAdapter(BaseModelAdapter):
"""The model adapter for Koala"""
@@ -2500,8 +2538,12 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("api_based_default")





# Note: the registration order matters.
# The one registered earlier has a higher matching priority.
register_model_adapter(Gemma3Adapter)
register_model_adapter(PeftModelAdapter)
register_model_adapter(StableVicunaAdapter)
register_model_adapter(VicunaAdapter)
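The try/except around the version parse in load_model above falls back to a hard-coded tuple whenever the version string has a non-numeric segment. A more tolerant check is possible; the sketch below is illustrative only, not part of the patch, and assumes the packaging library is available (transformers itself depends on it).

# Sketch: parse transformers versions with packaging instead of splitting on "." and int().
# Assumption: packaging is installed (it is a dependency of transformers).
from packaging import version as pkg_version
import transformers

def transformers_older_than(spec: str) -> bool:
    # Version() accepts dev/rc suffixes such as "4.50.0.dev0" without raising.
    return pkg_version.Version(transformers.__version__) < pkg_version.Version(spec)

# Mirrors the `version < (4, 35, 0)` branch in load_model.
needs_legacy_path = transformers_older_than("4.35.0")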
132 changes: 132 additions & 0 deletions fastchat/model/model_gemma3.py
@@ -0,0 +1,132 @@
from threading import Thread
import gc
import torch
from transformers import TextIteratorStreamer

def generate_stream_gemma3(
model,
tokenizer,
params,
device,
context_len,
stream_interval=2,
judge_sent_end=False
):
"""Custom generate stream function for Gemma-3 models"""
# Get parameters from the request
prompt = params.get("prompt", "")
messages = params.get("messages", None)
temperature = float(params.get("temperature", 1.0))
repetition_penalty = float(params.get("repetition_penalty", 1.0))
top_p = float(params.get("top_p", 1.0))
top_k = int(params.get("top_k", -1)) # -1 means disable
max_new_tokens = int(params.get("max_new_tokens", 256))
echo = bool(params.get("echo", True))
stop_str = params.get("stop", None)
stop_token_ids = params.get("stop_token_ids", None) or []
model_name = params.get("model", None)

if tokenizer.eos_token_id not in stop_token_ids:
stop_token_ids.append(tokenizer.eos_token_id)

is_base_model = "pt" in model_name.lower() or "base" in model_name.lower()

if not is_base_model:
# Format input based on whether we have messages or a plain prompt
if messages:
inputs = tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(model.device)
else:
messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
inputs = tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(model.device)
else:
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

input_ids = inputs["input_ids"]
input_echo_len = input_ids.shape[1]

# Configure generation parameters
generate_kwargs = {
"max_new_tokens": max_new_tokens,
"do_sample": temperature > 0.0,
"temperature": temperature if temperature > 0.0 else 1.0,
}

if top_p < 1.0:
generate_kwargs["top_p"] = top_p
if top_k > 0:
generate_kwargs["top_k"] = top_k
if repetition_penalty > 1.0:
generate_kwargs["repetition_penalty"] = repetition_penalty

streamer = TextIteratorStreamer(tokenizer, skip_prompt=not echo, skip_special_tokens=True)
generate_kwargs["streamer"] = streamer

# Start generation in a separate thread
thread = Thread(target=lambda: model.generate(input_ids=input_ids, **generate_kwargs))
thread.start()

# Track generation progress
generated_tokens = 0
output_text = ""

# Stream tokens
for new_text in streamer:
output_text += new_text
generated_tokens += 1  # approximate: the streamer yields text chunks, not individual tokens

# Check for stop strings
should_stop = False
if stop_str:
if isinstance(stop_str, str):
if stop_str in output_text:
output_text = output_text[: output_text.find(stop_str)]
should_stop = True
elif isinstance(stop_str, list):
for stop in stop_str:
if stop in output_text:
output_text = output_text[: output_text.find(stop)]
should_stop = True
break

# Stream at intervals or when stopping
if generated_tokens % stream_interval == 0 or should_stop:
yield {
"text": output_text,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": generated_tokens,
"total_tokens": input_echo_len + generated_tokens,
},
"finish_reason": "stop" if should_stop else None,
}

if should_stop:
break

# Final output with finish reason
if thread.is_alive():
thread.join(timeout=3600)  # generous timeout; if generation has not finished by now, something is wrong

yield {
"text": output_text,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": generated_tokens,
"total_tokens": input_echo_len + generated_tokens,
},
"finish_reason": "length",
}

# Clean up
gc.collect()
torch.cuda.empty_cache()
if device == "xpu":
torch.xpu.empty_cache()
if device == "npu":
torch.npu.empty_cache()
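
For reference, a minimal driver for generate_stream_gemma3 might look like the sketch below. The checkpoint name, device, context length, and sampling values are illustrative assumptions, not part of the patch.

import torch
from transformers import AutoTokenizer, Gemma3ForCausalLM
from fastchat.model.model_gemma3 import generate_stream_gemma3

model_path = "google/gemma-3-1b-it"  # hypothetical example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = Gemma3ForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, device_map="auto"
)

params = {
    "model": model_path,
    "prompt": "Explain what a model adapter does in FastChat.",
    "temperature": 0.7,
    "top_p": 0.9,
    "max_new_tokens": 128,
    "echo": False,
}

printed = ""
for chunk in generate_stream_gemma3(model, tokenizer, params, device="cuda", context_len=8192):
    text = chunk["text"]
    print(text[len(printed):], end="", flush=True)  # each chunk carries the full text so far
    printed = text
print()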
10 changes: 9 additions & 1 deletion fastchat/serve/controller.py
@@ -28,7 +28,7 @@
from fastchat.utils import build_logger


logger = build_logger("controller", "controller.log")
logger = None


class DispatchMethod(Enum):
@@ -351,6 +351,7 @@ async def worker_api_get_status(request: Request):


def create_controller():
global logger
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=21001)
@@ -367,7 +368,14 @@ def create_controller():
default=False,
help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.",
)
parser.add_argument(
"--log-file",
type=str,
default="controller.log",
help="Path to the controller log file",
)
args = parser.parse_args()
logger = build_logger("controller", args.log_file)
logger.info(f"args: {args}")

controller = Controller(args.dispatch_method)
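With this change the controller's log destination can be chosen at launch time, for example: python3 -m fastchat.serve.controller --log-file /tmp/controller.log (the path here is only an example); the default value keeps the previous controller.log behavior.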
58 changes: 46 additions & 12 deletions fastchat/serve/inference.py
@@ -83,6 +83,8 @@ def generate_stream(
echo = bool(params.get("echo", True))
stop_str = params.get("stop", None)
stop_token_ids = params.get("stop_token_ids", None) or []
logprobs_requested = params.get("logprobs") is not None
top_logprobs_n = int(params.get("top_logprobs_n", 5) if logprobs_requested else 0)
if tokenizer.eos_token_id not in stop_token_ids:
stop_token_ids.append(tokenizer.eos_token_id)

@@ -116,9 +118,12 @@

past_key_values = out = None
token_logprobs = [None] # The first token has no logprobs.
top_logprobs_list = [{}] # The first token has no top logprobs.
sent_interrupt = False
finish_reason = None
stopped = False
last_sent_token_pos = 0

for i in range(max_new_tokens):
if i == 0: # prefill
if model.config.is_encoder_decoder:
@@ -142,6 +147,8 @@
shift_input_ids[0].tolist(), shift_logits[0]
):
token_logprobs.append(logit[label_id])
# Add empty top_logprobs during prefill (would need to reconstruct full logits tensor to get these)
top_logprobs_list.append({})
else: # decoding
if model.config.is_encoder_decoder:
out = model.decoder(
@@ -197,6 +204,28 @@
torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist()
)

# Calculate top logprobs for the current token if needed
if logprobs_requested and top_logprobs_n > 0:
# Get raw logits for current position
current_logits = torch.log_softmax(logits[0, -1, :], dim=-1)

# Get top tokens and their logprobs
topk_logits, topk_indices = torch.topk(current_logits, min(top_logprobs_n, len(current_logits)))

# Create dictionary of token → logprob
top_dict = {}
for logit, token_id in zip(topk_logits.tolist(), topk_indices.tolist()):
token_text = tokenizer.decode([token_id]) # Use list to ensure proper decoding
if token_text and token_text.strip(): # Check if token is non-empty after stripping
# If the same token appears with different logprobs, keep the highest one
if token_text not in top_dict or logit > top_dict[token_text]:
top_dict[token_text] = logit

top_logprobs_list.append(top_dict)
else:
top_logprobs_list.append({})


if token in stop_token_ids:
stopped = True
else:
@@ -219,21 +248,26 @@
)
ret_logprobs = None
if logprobs is not None:
# Calculate the start position for this streaming chunk
if echo:
start_pos = last_sent_token_pos
tokens_to_send = output_ids[start_pos:]
else:
start_pos = max(last_sent_token_pos, input_echo_len)
tokens_to_send = output_ids[start_pos:]

# Update last sent position for next stream chunk
last_sent_token_pos = len(output_ids)

# Format response with only new tokens
ret_logprobs = {
"text_offset": [],
"tokens": [
tokenizer.decode(token)
for token in (
output_ids if echo else output_ids[input_echo_len:]
)
],
"token_logprobs": token_logprobs
if echo
else token_logprobs[input_echo_len:],
"top_logprobs": [{}]
* len(token_logprobs if echo else token_logprobs[input_echo_len:]),
"tokens": [tokenizer.decode(token) for token in tokens_to_send],
"token_logprobs": token_logprobs[start_pos:],
"top_logprobs": top_logprobs_list[start_pos:],
}
# Compute text_offset

# Compute text_offset for just this chunk
curr_pos = 0
for text in ret_logprobs["tokens"]:
ret_logprobs["text_offset"].append(curr_pos)
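The new top-logprob bookkeeping boils down to a log-softmax, a top-k, and a decode-and-dedupe pass for each generated token. The toy example below illustrates that pattern in isolation; the vocabulary and scores are made up, while the real code uses the model's logits and the tokenizer.

import torch

# Toy stand-ins for the model's vocabulary and the logits at the current position.
vocab = {0: "<bos>", 1: "hello", 2: " world", 3: "!", 4: "  "}
logits = torch.tensor([0.1, 2.0, 1.5, 0.3, 1.9])

logprobs = torch.log_softmax(logits, dim=-1)
top_logprobs_n = 3
topk_vals, topk_ids = torch.topk(logprobs, min(top_logprobs_n, logprobs.numel()))

top_dict = {}
for lp, tid in zip(topk_vals.tolist(), topk_ids.tolist()):
    token_text = vocab[tid]  # tokenizer.decode([tid]) in the real code
    if token_text and token_text.strip():  # skip empty / whitespace-only pieces
        # Keep the highest logprob if the same surface form appears more than once.
        if token_text not in top_dict or lp > top_dict[token_text]:
            top_dict[token_text] = lp

print(top_dict)  # maps token text to logprob; the whitespace-only candidate was dropped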
3 changes: 3 additions & 0 deletions fastchat/serve/vllm_worker.py
@@ -261,6 +261,7 @@ async def api_model_details(request: Request):
parser.add_argument(
"--conv-template", type=str, default=None, help="Conversation prompt template."
)
parser.add_argument("--model_dtype", type=str, default="auto")
parser.add_argument(
"--trust_remote_code",
action="store_false",
@@ -281,6 +282,8 @@

parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
if args.model_dtype and args.model_dtype != "auto" and args.model_dtype.strip() != "":
args.dtype = args.model_dtype
if args.model_path:
args.model = args.model_path
if args.num_gpus > 1:
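In practice this lets a vLLM worker launch pass, for example, --model_dtype bfloat16 to override the engine's dtype, while the default of auto leaves vLLM's own dtype selection untouched.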