diff --git a/benchmark/profile_restful_api.py b/benchmark/profile_restful_api.py index 1579188910..3f81dc3e49 100644 --- a/benchmark/profile_restful_api.py +++ b/benchmark/profile_restful_api.py @@ -12,6 +12,7 @@ import argparse import asyncio import csv +import io import json import os import random @@ -27,9 +28,12 @@ import aiohttp import numpy as np +import pybase64 import requests +from PIL import Image from tqdm.asyncio import tqdm -from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast +from transformers import (AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerBase, + PreTrainedTokenizerFast) AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=None) @@ -54,6 +58,7 @@ class RequestFuncInput: prompt_len: int output_len: int model: str + image_data: Optional[List[str]] extra_request_body: Dict[str, Any] @@ -65,8 +70,8 @@ class RequestFuncOutput: ttft: float = 0.0 # Time to first token itl: List[float] = field(default_factory=list) # List of inter-token latencies prompt_len: int = 0 - error: str = '' output_len: int = 0 + error: str = '' def remove_prefix(text: str, prefix: str) -> str: @@ -223,6 +228,114 @@ async def async_request_openai_completions( return output +async def async_request_openai_chat_completions( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + assert api_url.endswith('chat/completions'), "OpenAI Chat Completions API URL must end with 'chat/completions'." + + if request_func_input.image_data: + # Build multi-image content: a list of image_url entries followed by the text + content_items = [{ + 'type': 'image_url', + 'image_url': { + 'url': img_url + }, + } for img_url in request_func_input.image_data] + content_items.append({'type': 'text', 'text': request_func_input.prompt}) + messages = [ + { + 'role': 'user', + 'content': content_items, + }, + ] + else: + messages = [{'role': 'user', 'content': request_func_input.prompt}] + + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + payload = { + 'model': request_func_input.model, + 'messages': messages, + 'temperature': 0.0, + 'max_completion_tokens': request_func_input.output_len, + 'stream': not args.disable_stream, + 'ignore_eos': not args.disable_ignore_eos, + **request_func_input.extra_request_body, + } + headers = {'Authorization': f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = '' + output_len = request_func_input.output_len + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post(url=api_url, json=payload, headers=headers) as response: + if response.status == 200: + if args.disable_stream: + # Non-streaming response + response_json = await response.json() + output.generated_text = response_json['choices'][0]['message']['content'] + output.success = True + output.latency = time.perf_counter() - st + output.ttft = (output.latency) # For non-streaming, TTFT = total latency + output.output_len = response_json.get('usage', {}).get('completion_tokens', output_len) + else: + # Streaming response + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = remove_prefix(chunk_bytes.decode('utf-8'), 'data: ') + latency = time.perf_counter() - st + if chunk == '[DONE]': + pass + else: + data = json.loads(chunk) + + # Check if this chunk 
contains content + delta = data.get('choices', [{}])[0].get('delta', {}) + content = delta.get('content', '') + + if content: + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = timestamp - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - most_recent_timestamp) + + most_recent_timestamp = timestamp + generated_text += content + + # Check for usage info in final chunk + output_len = (data.get('usage') or {}).get('completion_tokens', output_len) + + output.generated_text = generated_text + output.success = True + output.latency = latency + output.output_len = output_len + else: + output.error = ((response.reason or '') + ': ' + (await response.text())) + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = ''.join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + async def async_request_sglang_generate( request_func_input: RequestFuncInput, pbar: Optional[tqdm] = None, @@ -333,12 +446,27 @@ def get_tokenizer(pretrained_model_name_or_path: str, ) -> Union[PreTrainedToken return AutoTokenizer.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True) +def get_processor(pretrained_model_name_or_path: str, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + assert (pretrained_model_name_or_path is not None and pretrained_model_name_or_path != '') + if pretrained_model_name_or_path.endswith('.json') or pretrained_model_name_or_path.endswith('.model'): + from sglang.srt.utils.hf_transformers_utils import get_processor + + return get_processor(pretrained_model_name_or_path) + + if pretrained_model_name_or_path is not None and not os.path.exists(pretrained_model_name_or_path): + pretrained_model_name_or_path = get_model(pretrained_model_name_or_path) + return AutoProcessor.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True) + + ASYNC_REQUEST_FUNCS = { 'sglang': async_request_sglang_generate, 'sglang-native': async_request_sglang_generate, 'sglang-oai': async_request_openai_completions, + 'sglang-oai-chat': async_request_openai_chat_completions, 'vllm': async_request_openai_completions, + 'vllm-chat': async_request_openai_chat_completions, 'lmdeploy': async_request_openai_completions, + 'lmdeploy-chat': async_request_openai_chat_completions, 'trt': async_request_trt_llm, 'gserver': async_request_gserver, } @@ -348,6 +476,8 @@ def get_tokenizer(pretrained_model_name_or_path: str, ) -> Union[PreTrainedToken class BenchmarkMetrics: completed: int total_input: int + total_input_text: int + total_input_vision: int total_output: int total_output_retokenized: int request_throughput: float @@ -407,10 +537,26 @@ def download_and_cache_file(url: str, filename: Optional[str] = None): return filename +@dataclass +class DatasetRow: + prompt: str + prompt_len: int + output_len: int + text_prompt_len: Optional[int] = None + vision_prompt_len: Optional[int] = None + image_data: Optional[List[str]] = None + + def __post_init__(self): + if self.text_prompt_len is None: + self.text_prompt_len = self.prompt_len + if self.vision_prompt_len is None: + self.vision_prompt_len = 0 + + def sample_sharegpt_requests(dataset_path: str, num_requests: int, tokenizer: PreTrainedTokenizerBase, - fixed_output_len: Optional[int] = None) -> List[Tuple[str, int, int]]: + fixed_output_len: Optional[int] = None) -> List[DatasetRow]: if fixed_output_len is not None and fixed_output_len < 4: raise ValueError('output_len too small') @@ -430,7 
+576,7 @@ def sample_sharegpt_requests, random.shuffle(dataset) # Filter out sequences that are too long or too short - filtered_dataset: List[Tuple[str, int, int]] = [] + filtered_dataset: List[DatasetRow] = [] for i in range(len(dataset)): if len(filtered_dataset) == num_requests: break @@ -448,13 +594,26 @@ def sample_sharegpt_requests, if prompt_len > 1024 or (prompt_len + output_len > 2048 and fixed_output_len is None): # Prune too long sequences. continue - filtered_dataset.append((prompt, prompt_len, output_len)) - print(f'#Input tokens: {np.sum([x[1] for x in filtered_dataset])}') - print(f'#Output tokens: {np.sum([x[2] for x in filtered_dataset])}') + filtered_dataset.append(DatasetRow( + prompt=prompt, + prompt_len=prompt_len, + output_len=output_len, + )) + + print(f'#Input tokens: {sum(x.prompt_len for x in filtered_dataset)}') + print(f'#Output tokens: {sum(x.output_len for x in filtered_dataset)}') return filtered_dataset +def compute_random_lens(full_len: int, range_ratio: float, num: int): + return np.random.randint( + max(int(full_len * range_ratio), 1), + full_len + 1, + size=num, + ) + + def sample_random_requests( input_len: int, output_len: int, @@ -462,17 +621,17 @@ def sample_random_requests( range_ratio: float, tokenizer: PreTrainedTokenizerBase, dataset_path: str, -) -> List[Tuple[str, int, int]]: +) -> List[DatasetRow]: - input_lens = np.random.randint( - max(int(input_len * range_ratio), 1), - input_len + 1, - size=num_prompts, + input_lens = compute_random_lens( + full_len=input_len, + range_ratio=range_ratio, + num=num_prompts, ) - output_lens = np.random.randint( - int(output_len * range_ratio), - output_len + 1, - size=num_prompts, + output_lens = compute_random_lens( + full_len=output_len, + range_ratio=range_ratio, + num=num_prompts, ) # Sample token ids from ShareGPT and repeat/truncate them to @@ -496,7 +655,7 @@ def sample_random_requests( random.shuffle(dataset) # Filter out sequences that are too long or too short - input_requests: List[Tuple[str, int, int]] = [] + input_requests: List[DatasetRow] = [] origin_output_lens: List[int] = [] for i in range(num_prompts): # Tokenize the prompts and completions. @@ -513,17 +672,222 @@ def sample_random_requests( ratio = (input_lens[i] + prompt_len - 1) // prompt_len input_ids = (prompt_token_ids * ratio)[:input_lens[i]] prompt = tokenizer.decode(input_ids) - input_requests.append((prompt, int(input_lens[i]), int(output_lens[i]))) + input_requests.append(DatasetRow( + prompt=prompt, + prompt_len=int(input_lens[i]), + output_len=int(output_lens[i]), + )) - print(f'#Input tokens: {np.sum([x[1] for x in input_requests])}') - print(f'#Output tokens: {np.sum([x[2] for x in input_requests])}') + print(f'#Input tokens: {sum(x.prompt_len for x in input_requests)}') + print(f'#Output tokens: {sum(x.output_len for x in input_requests)}') return input_requests +def parse_image_resolution(image_resolution: str) -> Tuple[int, int]: + """Parse image resolution into (width, height). + + Supports presets '4k', '1080p', '720p' and '360p', as well as a custom 'heightxwidth' format (e.g., '1080x1920' means + height=1080, width=1920); the result is always returned as (width, height).
+ """ + resolution_to_size = { + '4k': (3840, 2160), + '1080p': (1920, 1080), + '720p': (1280, 720), + '360p': (640, 360), + } + if image_resolution in resolution_to_size: + return resolution_to_size[image_resolution] + + res = image_resolution.strip().lower() + if 'x' in res: + parts = res.split('x') + if len(parts) == 2 and parts[0].isdigit() and parts[1].isdigit(): + height = int(parts[0]) + width = int(parts[1]) + if height > 0 and width > 0: + return (width, height) + + raise ValueError(f'Unsupported image resolution: {image_resolution}. ' + "Choose from 4k, 1080p, 720p, 360p, or provide custom 'heightxwidth' (e.g., 1080x1920).") + + +def gen_mm_prompt(tokenizer, image_pad_id, token_num): + """Generate a random prompt of specified token length using tokenizer + vocabulary.""" + all_available_tokens = list(tokenizer.get_vocab().values()) + if image_pad_id: + all_available_tokens.remove(image_pad_id) + selected_tokens = random.choices(all_available_tokens, k=token_num) + return tokenizer.decode(selected_tokens) + + +def create_mm_data_row(text_prompt, images: list, images_base64, output_len, processor, backend): + try: + content_items = [{'type': 'image', 'image': {'url': image_base64}} for image_base64 in images_base64] + content_items.append({'type': 'text', 'text': text_prompt}) + prompt_str = processor.apply_chat_template( + [{ + 'role': 'user', + 'content': content_items + }], + add_generation_prompt=True, + tokenize=False, + ) + except Exception as e: + # Note (Xinyuan): This is a workaround for an issue where some tokenizers + # do not support content as a list. (e.g. InternVL) + print(f'Error applying chat template: {e}, fallback to <image> tag') + # Some tokenizers do not support list content; fall back to an <image> placeholder in the text + prompt_str = f'<image>{text_prompt}' + + # Calculate total tokens (text + vision) + prompt_len = processor( + text=[prompt_str], + images=images, + padding=False, + return_tensors='pt', + )['input_ids'].numel() + + # Calculate text-only tokens + try: + # Create text-only version of the prompt + text_only_prompt = processor.apply_chat_template( + [{ + 'role': 'user', + 'content': text_prompt + }], + add_generation_prompt=True, + tokenize=False, + ) + text_prompt_len = processor( + text=[text_only_prompt], + padding=False, + return_tensors='pt', + )['input_ids'].numel() + except Exception: + # Fallback: just tokenize the text prompt directly + tokenizer_to_use = (processor.tokenizer if hasattr(processor, 'tokenizer') else processor) + text_prompt_len = len(tokenizer_to_use.encode(text_prompt)) + + # Vision tokens = total tokens - text tokens + vision_prompt_len = prompt_len - text_prompt_len + + use_raw_prompt = backend in [ + 'sglang', + 'sglang-oai', + 'sglang-oai-chat', + 'vllm', + 'vllm-chat', + 'lmdeploy', + 'lmdeploy-chat', + ] + return DatasetRow( + prompt=text_prompt if use_raw_prompt else prompt_str, + prompt_len=prompt_len, + output_len=output_len, + text_prompt_len=text_prompt_len, + vision_prompt_len=vision_prompt_len, + image_data=images_base64, + ) + + +def sample_image_requests( + num_requests: int, + image_count: int, + input_len: int, + output_len: int, + range_ratio: float, + processor: AutoProcessor, + image_content: str, + image_format: str, + image_resolution: str, + backend: str, +) -> List[DatasetRow]: + """Generate requests with images. + + - Each request includes ``image_count`` images. + - Supported resolutions: 4k (3840x2160), 1080p (1920x1080), 720p (1280x720), 360p (640x360), + or custom 'heightxwidth' (e.g., 1080x1920).
+ - Text lengths follow the 'random' dataset sampling rule. ``prompt_len`` + only counts text tokens and excludes image data. + """ + + # Parse resolution (supports presets and 'heightxwidth') + width, height = parse_image_resolution(image_resolution) + + # Check for potentially problematic combinations and warn user + if width * height >= 1920 * 1080 and image_count * num_requests >= 100: + warnings.warn( + f'High resolution ({width}x{height}) with {image_count * num_requests} total images ' + f'may take a long time. Consider reducing resolution or image count.', + UserWarning, + stacklevel=2, + ) + + # Sample text lengths + input_lens = compute_random_lens( + full_len=input_len, + range_ratio=range_ratio, + num=num_requests, + ) + output_lens = compute_random_lens( + full_len=output_len, + range_ratio=range_ratio, + num=num_requests, + ) + + def _gen_random_image_data_uri(width: int = width, height: int = height) -> Tuple[Image.Image, str, int]: + if image_content == 'blank': + # Generate blank white image + arr = np.full((height, width, 3), 255, dtype=np.uint8) + else: + # Generate random colored image + arr = (np.random.rand(height, width, 3) * 255).astype(np.uint8) + img = Image.fromarray(arr) + buf = io.BytesIO() + img.save(buf, format=image_format, quality=85) + encoded = pybase64.b64encode(buf.getvalue()).decode('utf-8') + image_data = f'data:image/{image_format};base64,{encoded}' + image_bytes = len(image_data.encode('utf-8')) + return img, image_data, image_bytes + + dataset: List[DatasetRow] = [] + total_image_bytes = 0 + for i in range(num_requests): + # Generate text prompt + text_prompt = gen_mm_prompt( + processor.tokenizer, + processor.image_token_id if hasattr(processor, 'image_token_id') else None, + int(input_lens[i]), + ) + + # Generate image list + images, images_base64, images_bytes = zip(*[_gen_random_image_data_uri() for _ in range(image_count)]) + total_image_bytes += sum(list(images_bytes)) + + data_row = create_mm_data_row( + text_prompt, + list(images), + list(images_base64), + int(output_lens[i]), + processor, + backend, + ) + + dataset.append(data_row) + avg_image_bytes = total_image_bytes // num_requests if num_requests > 0 else 0 + + print(f'#Input tokens: {np.sum([x.prompt_len for x in dataset])}') + print(f'#Output tokens: {np.sum([x.output_len for x in dataset])}') + print(f'\nCreated {len(dataset)} {image_content} {image_format} images \ + with average {avg_image_bytes} bytes per request') + return dataset + + async def get_request( - input_requests: List[Tuple[str, int, int]], + input_requests: List[DatasetRow], request_rate: float, -) -> AsyncGenerator[Tuple[str, int, int], None]: +) -> AsyncGenerator[DatasetRow, None]: input_requests = iter(input_requests) for request in input_requests: yield request @@ -539,7 +903,7 @@ async def get_request( def calculate_metrics( - input_requests: List[Tuple[str, int, int]], + input_requests: List[DatasetRow], outputs: List[RequestFuncOutput], dur_s: float, tokenizer: PreTrainedTokenizerBase, @@ -548,18 +912,23 @@ def calculate_metrics( output_lens: List[int] = [] retokenized_output_lens: List[int] = [] total_input = 0 + total_input_text = 0 + total_input_vision = 0 completed = 0 itls: List[float] = [] tpots: List[float] = [] ttfts: List[float] = [] e2e_latencies: List[float] = [] + for i in range(len(outputs)): if outputs[i].success: output_len = outputs[i].output_len output_lens.append(output_len) retokenized_output_len = len(tokenizer.encode(outputs[i].generated_text, add_special_tokens=False)) 
retokenized_output_lens.append(retokenized_output_len) - total_input += input_requests[i][1] + total_input += input_requests[i].prompt_len + total_input_text += input_requests[i].text_prompt_len + total_input_vision += input_requests[i].vision_prompt_len if output_len > 1: tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1)) itls += outputs[i].itl @@ -581,6 +950,8 @@ def calculate_metrics( metrics = BenchmarkMetrics( completed=completed, total_input=total_input, + total_input_text=total_input_text, + total_input_vision=total_input_vision, total_output=sum(output_lens), total_output_retokenized=sum(retokenized_output_lens), request_throughput=completed / dur_s, @@ -611,7 +982,7 @@ async def benchmark( api_url: str, model_id: str, tokenizer: PreTrainedTokenizerBase, - input_requests: List[Tuple[str, int, int]], + input_requests: List[DatasetRow], request_rate: float, disable_tqdm: bool, extra_request_body: Dict[str, Any], @@ -624,14 +995,15 @@ async def benchmark( if not args.disable_warmup: print('Starting initial single prompt test run...') start_warmup = time.perf_counter() - test_prompt, test_prompt_len, test_output_len = input_requests[0] + test_request = input_requests[0] test_input = RequestFuncInput( model=model_id, - prompt=test_prompt, + prompt=test_request.prompt, api_url=api_url, - prompt_len=test_prompt_len, - output_len=test_output_len, + prompt_len=test_request.prompt_len, + output_len=test_request.output_len, extra_request_body=extra_request_body, + image_data=test_request.image_data, ) test_output = await request_func(request_func_input=test_input) if not test_output.success: @@ -648,13 +1020,13 @@ async def benchmark( benchmark_start_time = time.perf_counter() tasks: List[asyncio.Task] = [] async for request in get_request(input_requests, request_rate): - prompt, prompt_len, output_len = request request_func_input = RequestFuncInput( model=model_id, - prompt=prompt, + prompt=request.prompt, api_url=api_url, - prompt_len=prompt_len, - output_len=output_len, + prompt_len=request.prompt_len, + output_len=request.output_len, + image_data=request.image_data, extra_request_body=extra_request_body, ) tasks.append(asyncio.create_task(request_func(request_func_input=request_func_input, pbar=pbar))) @@ -679,6 +1051,8 @@ async def benchmark( print('{:<40} {:<10}'.format('Successful requests:', metrics.completed)) print('{:<40} {:<10.2f}'.format('Benchmark duration (s):', benchmark_duration)) print('{:<40} {:<10}'.format('Total input tokens:', metrics.total_input)) + print('{:<40} {:<10}'.format('Total input text tokens:', metrics.total_input_text)) + print('{:<40} {:<10}'.format('Total input vision tokens:', metrics.total_input_vision)) print('{:<40} {:<10}'.format('Total generated tokens:', metrics.total_output)) print('{:<40} {:<10}'.format('Total generated tokens (retokenized):', metrics.total_output_retokenized)) print('{:<40} {:<10.2f}'.format('Request throughput (req/s):', metrics.request_throughput)) @@ -820,8 +1194,11 @@ def run_benchmark(args_: argparse.Namespace): 'sglang': 30000, 'sglang-native': 30000, 'sglang-oai': 30000, + 'sglang-oai-chat': 30000, 'lmdeploy': 23333, + 'lmdeploy-chat': 23333, 'vllm': 8000, + 'vllm-chat': 8000, 'trt': 8000, 'gserver': 9988, }.get(args.backend, 30000) @@ -833,6 +1210,9 @@ def run_benchmark(args_: argparse.Namespace): elif args.backend in ['sglang-oai', 'vllm', 'lmdeploy']: api_url = (f'{args.base_url}/v1/completions' if args.base_url else f'http://{args.host}:{args.port}/v1/completions') + elif args.backend in 
['lmdeploy-chat', 'vllm-chat', 'sglang-oai-chat']: + api_url = (f'{args.base_url}/v1/chat/completions' + if args.base_url else f'http://{args.host}:{args.port}/v1/chat/completions') elif args.backend == 'trt': api_url = ( f'{args.base_url}/v2/models/ensemble/generate_stream' @@ -898,6 +1278,20 @@ def run_benchmark(args_: argparse.Namespace): tokenizer=tokenizer, dataset_path=args.dataset_path, ) + elif args.dataset_name == 'image': + processor = get_processor(model_id) + input_requests = sample_image_requests( + num_requests=args.num_prompts, + image_count=args.image_count, + input_len=args.random_input_len, + output_len=args.random_output_len, + range_ratio=args.random_range_ratio, + processor=processor, + image_content=args.image_content, + image_format=args.image_format, + image_resolution=args.image_resolution, + backend=args.backend, + ) else: raise ValueError(f'Unknown dataset: {args.dataset_name}') @@ -969,7 +1363,7 @@ def set_ulimit(target_soft_limit=65535): '--dataset-name', type=str, default='sharegpt', - choices=['sharegpt', 'random'], + choices=['sharegpt', 'random', 'image'], help='Name of the dataset to benchmark on.', ) parser.add_argument('--dataset-path', type=str, default='', help='Path to the dataset.') @@ -1017,6 +1411,34 @@ def set_ulimit(target_soft_limit=65535): help='Range of sampled ratio of input/output length, ' 'used only for random dataset.', ) + # image dataset args + parser.add_argument( + '--image-count', + type=int, + default=1, + help='Number of images per request (only available with the image dataset)', + ) + parser.add_argument( + '--image-resolution', + type=str, + default='1080p', + help=('Resolution of images for image dataset. ' + "Supports presets 4k/1080p/720p/360p or custom 'heightxwidth' (e.g., 1080x1920)."), + ) + parser.add_argument( + '--image-format', + type=str, + default='jpeg', + help=('Format of images for image dataset. ' + 'Supports jpeg and png.'), + ) + parser.add_argument( + '--image-content', + type=str, + default='random', + help=('Content for images for image dataset. ' + 'Supports random and blank.'), + ) parser.add_argument( '--request-rate', type=float, diff --git a/lmdeploy/pytorch/engine/engine_instance.py b/lmdeploy/pytorch/engine/engine_instance.py index 2db185b6ad..6c73d92068 100644 --- a/lmdeploy/pytorch/engine/engine_instance.py +++ b/lmdeploy/pytorch/engine/engine_instance.py @@ -97,13 +97,12 @@ def _get_extra_outputs(self, resp: Response): routed_experts = resp.data.get('routed_experts', None) if resp.data else None if routed_experts is not None and resp.type in [ResponseType.FINISH, ResponseType.CANCEL]: if self._enable_transfer_obj_ref: - import base64 - + import pybase64 import ray ref = ray.put(routed_experts) data = ray.cloudpickle.dumps(ref) - outputs['routed_experts'] = base64.b64encode(data).decode('utf-8') + outputs['routed_experts'] = pybase64.b64encode(data).decode('utf-8') else: outputs['routed_experts'] = routed_experts return outputs diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index d4e4fae1bd..bea404ff6d 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -1,6 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import asyncio -import base64 import functools import time from contextlib import contextmanager @@ -10,6 +9,7 @@ from typing import Any, Dict, List, Optional import numpy as np +import pybase64 import torch import torch.distributed as dist from torch.profiler import ProfilerActivity, profile, record_function @@ -1128,7 +1128,7 @@ def _construct(item): if isinstance(serialized_data, list): serialized_data = serialized_data[self.dist_ctx.tp_group.rank] model = self.patched_model.get_model() - weights = ForkingPickler.loads(base64.b64decode(serialized_data)) + weights = ForkingPickler.loads(pybase64.b64decode(serialized_data)) if request.load_format == 'flattened_bucket': metadata: List[FlattenedTensorMetadata] = weights['metadata'] if metadata: diff --git a/lmdeploy/pytorch/models/qwen3_vl.py b/lmdeploy/pytorch/models/qwen3_vl.py index 6844c6b8d0..60c3617ffe 100644 --- a/lmdeploy/pytorch/models/qwen3_vl.py +++ b/lmdeploy/pytorch/models/qwen3_vl.py @@ -1,7 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. +from functools import lru_cache from typing import Any, Dict, Iterable, List, Optional, Tuple +import numpy as np import torch from torch import nn from transformers.configuration_utils import PretrainedConfig @@ -326,102 +328,104 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: for _ in range(len(config.deepstack_visual_indexes)) ]) + @staticmethod + @lru_cache(maxsize=1024) + def rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor: + h_div = h // spatial_merge_size + w_div = w // spatial_merge_size + + hpos_ids = np.broadcast_to(np.arange(h).reshape(h, 1), (h, w)) + hpos_ids = hpos_ids.reshape( + h_div, + spatial_merge_size, + w_div, + spatial_merge_size, + ) + hpos_ids = hpos_ids.transpose(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = np.broadcast_to(np.arange(w).reshape(1, w), (h, w)) + wpos_ids = wpos_ids.reshape( + h_div, + spatial_merge_size, + w_div, + spatial_merge_size, + ) + wpos_ids = wpos_ids.transpose(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + + return torch.from_numpy(np.stack([hpos_ids, wpos_ids], axis=-1)) + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: - merge_size = self.spatial_merge_size - - max_hw = int(grid_thw[:, 1:].max().item()) - freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) - device = freq_table.device - - total_tokens = int(torch.prod(grid_thw, dim=1).sum().item()) - pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) - - offset = 0 - for num_frames, height, width in grid_thw: - merged_h, merged_w = height // merge_size, width // merge_size - - block_rows = torch.arange(merged_h, device=device) # block row indices - block_cols = torch.arange(merged_w, device=device) # block col indices - intra_row = torch.arange(merge_size, device=device) # intra-block row offsets - intra_col = torch.arange(merge_size, device=device) # intra-block col offsets - - # Compute full-resolution positions - row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] - col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] - - row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) - col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) - - coords = torch.stack((row_idx, col_idx), dim=-1) - - if num_frames > 1: - coords = coords.repeat(num_frames, 1) - - num_tokens = coords.shape[0] - pos_ids[offset:offset + num_tokens] = coords - offset += 
num_tokens - - embeddings = freq_table[pos_ids] # lookup rotary embeddings - embeddings = embeddings.flatten(1) - return embeddings - - def fast_pos_embed_interpolate(self, grid_thw): - grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] - - idx_list = [[] for _ in range(4)] - weight_list = [[] for _ in range(4)] - - for t, h, w in zip(grid_ts, grid_hs, grid_ws): - h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) - w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) - - h_idxs_floor = h_idxs.int() - w_idxs_floor = w_idxs.int() - h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) - w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) - - dh = h_idxs - h_idxs_floor - dw = w_idxs - w_idxs_floor - - base_h = h_idxs_floor * self.num_grid_per_side - base_h_ceil = h_idxs_ceil * self.num_grid_per_side - - indices = [ - (base_h[None].T + w_idxs_floor[None]).flatten(), - (base_h[None].T + w_idxs_ceil[None]).flatten(), - (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), - (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), - ] - - weights = [ - ((1 - dh)[None].T * (1 - dw)[None]).flatten(), - ((1 - dh)[None].T * dw[None]).flatten(), - (dh[None].T * (1 - dw)[None]).flatten(), - (dh[None].T * dw[None]).flatten(), - ] - - for i in range(4): - idx_list[i].extend(indices[i].tolist()) - weight_list[i].extend(weights[i].tolist()) - - idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device) - weight_tensor = torch.tensor(weight_list, - dtype=self.pos_embed.weight.dtype, - device=self.pos_embed.weight.device) - pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None] - patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] - - patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) - - patch_pos_embeds_permute = [] - merge_size = self.config.spatial_merge_size - for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): - pos_embed = pos_embed.repeat(t, 1) - pos_embed = (pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, - -1).permute(0, 1, 3, 2, 4, 5).flatten(0, 4)) - patch_pos_embeds_permute.append(pos_embed) - patch_pos_embeds = torch.cat(patch_pos_embeds_permute) - return patch_pos_embeds + """Rotary position embedding.""" + pos_ids = [] + + for t, h, w in grid_thw: + base = self.rot_pos_ids(int(h), int(w), self.spatial_merge_size) + pos_ids.append(base if t == 1 else base.repeat(t, 1)) + + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + + return rotary_pos_emb + + # copy from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen3_vl.py#L474 + def fast_pos_embed_interpolate(self, grid_thw: List[List[int]]) -> torch.Tensor: + num_grid_per_side = self.num_grid_per_side + m_size = self.spatial_merge_size + hidden_dim = self.pos_embed.embedding_dim + device = self.pos_embed.weight.device + + outputs = [] + for t, h, w in grid_thw: + h_idxs = torch.linspace(0, num_grid_per_side - 1, h, dtype=torch.float32, device=device) + w_idxs = torch.linspace(0, num_grid_per_side - 1, w, dtype=torch.float32, device=device) + + h_floor = h_idxs.to(torch.long) + w_floor = w_idxs.to(torch.long) + h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1) + w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1) + + dh = h_idxs - 
h_floor + dw = w_idxs - w_floor + + # Create meshgrid view for all h, w vars + dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing='ij') + h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing='ij') + h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing='ij') + + # original computation of weights + # w00 = (1 - dh_grid) * (1 - dw_grid) + # w01 = (1 - dh_grid) * dw_grid + # w10 = dh_grid * (1 - dw_grid) + # w11 = dh_grid * dw_grid + # we reuse w11 here to avoid duplicate + # dh_grid * dw_grid computation + w11 = dh_grid * dw_grid + w10 = dh_grid - w11 + w01 = dw_grid - w11 + w00 = 1 - dh_grid - w01 + + h_grid = torch.stack([h_floor_grid, h_floor_grid, h_ceil_grid, h_ceil_grid]) + w_grid = torch.stack([w_floor_grid, w_ceil_grid, w_floor_grid, w_ceil_grid]) + h_grid_idx = h_grid * num_grid_per_side + + indices = (h_grid_idx + w_grid).reshape(4, -1) + weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1) + weights = weights.to(dtype=self.pos_embed.weight.dtype, device=device) + + embeds = self.pos_embed(indices) + embeds *= weights + combined = embeds.sum(dim=0) + + combined = combined.reshape(h // m_size, m_size, w // m_size, m_size, hidden_dim) + combined = combined.permute(0, 2, 1, 3, 4).reshape(1, -1, hidden_dim) + repeated = combined.expand(t, -1, -1).reshape(-1, hidden_dim) + outputs.append(repeated) + + return torch.cat(outputs, dim=0) def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor, pos_embeds: torch.Tensor) -> torch.Tensor: diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index 8ce9373258..9b819c5271 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -1,7 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import asyncio -import base64 import copy import json import math @@ -16,6 +15,7 @@ from typing import Any, Dict, List, Optional import numpy as np +import pybase64 import torch import yaml from torch.nn.utils.rnn import pad_sequence @@ -328,7 +328,7 @@ def _construct(item): with torch.cuda.device(self.devices[0]): if isinstance(request.serialized_named_tensors, str): - weights = ForkingPickler.loads(base64.b64decode(request.serialized_named_tensors)) + weights = ForkingPickler.loads(pybase64.b64decode(request.serialized_named_tensors)) weights = {k: _construct(v) for k, v in weights} else: weights = request.serialized_named_tensors diff --git a/lmdeploy/utils.py b/lmdeploy/utils.py index 321e1c700d..779289627d 100644 --- a/lmdeploy/utils.py +++ b/lmdeploy/utils.py @@ -455,10 +455,10 @@ def serialize_state_dict(state_dict: dict) -> str: Returns: str: serialized state dict. """ - import base64 from io import BytesIO from multiprocessing.reduction import ForkingPickler + import pybase64 from torch.multiprocessing.reductions import reduce_tensor # flattened_tensor @@ -472,7 +472,7 @@ def serialize_state_dict(state_dict: dict) -> str: buf = BytesIO() ForkingPickler(buf).dump(data) buf.seek(0) - return base64.b64encode(buf.read()).decode('utf-8') + return pybase64.b64encode(buf.read()).decode('utf-8') def is_dlblas_installed(): diff --git a/lmdeploy/vl/utils.py b/lmdeploy/vl/utils.py index d933a2611a..d66a75d5d5 100644 --- a/lmdeploy/vl/utils.py +++ b/lmdeploy/vl/utils.py @@ -1,9 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-import base64 import os from io import BytesIO from typing import Union +import pybase64 import requests from PIL import Image, ImageFile @@ -40,13 +40,13 @@ def encode_image_base64(image: Union[str, Image.Image]) -> str: # use dummy image image = Image.new('RGB', (32, 32)) image.save(buffered, format='PNG') - res = base64.b64encode(buffered.getvalue()).decode('utf-8') + res = pybase64.b64encode(buffered.getvalue()).decode('utf-8') return res def load_image_from_base64(image: Union[bytes, str]) -> Image.Image: """Load image from base64 format.""" - return Image.open(BytesIO(base64.b64decode(image))) + return Image.open(BytesIO(pybase64.b64decode(image))) def load_image(image_url: Union[str, Image.Image]) -> Image.Image: diff --git a/requirements/runtime_ascend.txt b/requirements/runtime_ascend.txt index 927e6f7c3d..a54f6ce03b 100644 --- a/requirements/runtime_ascend.txt +++ b/requirements/runtime_ascend.txt @@ -12,6 +12,7 @@ partial_json_parser peft<=0.11.1 pillow protobuf +pybase64 pydantic>2.0.0 pyzmq ray diff --git a/requirements/runtime_camb.txt b/requirements/runtime_camb.txt index 5b37b003c0..f150e57a0e 100644 --- a/requirements/runtime_camb.txt +++ b/requirements/runtime_camb.txt @@ -10,6 +10,7 @@ partial_json_parser peft<=0.11.1 pillow protobuf +pybase64 pydantic>2.0.0 pyzmq safetensors diff --git a/requirements/runtime_cuda.txt b/requirements/runtime_cuda.txt index bba7fa4068..03376695c0 100644 --- a/requirements/runtime_cuda.txt +++ b/requirements/runtime_cuda.txt @@ -13,6 +13,7 @@ peft<=0.14.0 pillow prometheus_client protobuf +pybase64 pydantic>2.0.0 pyzmq ray diff --git a/requirements/runtime_maca.txt b/requirements/runtime_maca.txt index 70202d5ce5..263e268f42 100644 --- a/requirements/runtime_maca.txt +++ b/requirements/runtime_maca.txt @@ -10,6 +10,7 @@ partial_json_parser peft<=0.11.1 pillow protobuf +pybase64 pydantic>2.0.0 pyzmq safetensors diff --git a/requirements/runtime_rocm.txt b/requirements/runtime_rocm.txt index cf8091d251..b9b3f7b0dd 100644 --- a/requirements/runtime_rocm.txt +++ b/requirements/runtime_rocm.txt @@ -11,6 +11,7 @@ partial_json_parser peft<=0.14.0 pillow protobuf +pybase64 pydantic>2.0.0 pyzmq ray
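For reference, a minimal sketch (not part of the patch) of the request body that the new *-chat backends send when the image dataset is used; the helper names below are illustrative only, and it assumes Pillow, numpy and pybase64 are available:

# Illustrative sketch: mirrors what async_request_openai_chat_completions sends
# for --dataset-name image with a *-chat backend. encode_blank_image and
# build_image_chat_payload are hypothetical names used only in this example.
import io

import numpy as np
import pybase64
from PIL import Image


def encode_blank_image(width: int = 640, height: int = 360, fmt: str = 'jpeg') -> str:
    """Return a base64 data URI for a blank white image, like _gen_random_image_data_uri."""
    arr = np.full((height, width, 3), 255, dtype=np.uint8)
    buf = io.BytesIO()
    Image.fromarray(arr).save(buf, format=fmt, quality=85)
    encoded = pybase64.b64encode(buf.getvalue()).decode('utf-8')
    return f'data:image/{fmt};base64,{encoded}'


def build_image_chat_payload(model: str, prompt: str, image_count: int = 1, output_len: int = 32) -> dict:
    """Build a chat/completions body with image_url entries followed by the text prompt."""
    content = [{'type': 'image_url', 'image_url': {'url': encode_blank_image()}} for _ in range(image_count)]
    content.append({'type': 'text', 'text': prompt})
    return {
        'model': model,
        'messages': [{'role': 'user', 'content': content}],
        'temperature': 0.0,
        'max_completion_tokens': output_len,
        'stream': True,
        'ignore_eos': True,
    }


if __name__ == '__main__':
    payload = build_image_chat_payload('dummy-model', 'Describe the image.', image_count=2)
    # 2 image_url entries followed by 1 text entry
    assert len(payload['messages'][0]['content']) == 3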
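And a small self-check (again illustrative, not part of the patch) that the factored bilinear weights in the rewritten fast_pos_embed_interpolate are equivalent to the explicit products computed by the removed implementation:

# Illustrative check of the weight factoring reused in fast_pos_embed_interpolate.
import torch

dh = torch.rand(4, 5)  # fractional row offsets
dw = torch.rand(4, 5)  # fractional column offsets

# Factored form from the patch (reuses dh * dw once).
w11 = dh * dw
w10 = dh - w11
w01 = dw - w11
w00 = 1 - dh - w01

# Textbook bilinear weights.
assert torch.allclose(w00, (1 - dh) * (1 - dw))
assert torch.allclose(w01, (1 - dh) * dw)
assert torch.allclose(w10, dh * (1 - dw))
# The four corner weights always sum to 1.
assert torch.allclose(w00 + w01 + w10 + w11, torch.ones_like(dh))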