Skip to content

Commit 8fe8742

Browse files
authored
add generator offline throughput benchmark (#675) (#680)
1 parent 1926549 commit 8fe8742

File tree

5 files changed

+649
-0
lines changed

5 files changed

+649
-0
lines changed

benchmarks/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""TorchForge benchmarking utilities."""

benchmarks/generator/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""Generator throughput benchmarking tools."""

benchmarks/generator/datasets.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Lightweight dataset utilities for generator throughput benchmarking.
9+
"""
10+
11+
import random
12+
import uuid
13+
from dataclasses import dataclass
14+
15+
from vllm import __version__ as vllm_version
16+
17+
18+
def _release_tuple(version: str) -> tuple[int, ...]:
    """Return the leading numeric release components of a version string.

    Parsing stops at the first component that is not purely numeric, so
    ``"0.13.0rc1"`` and ``"0.13.0.dev3+gabc"`` both yield ``(0, 13, 0)``.

    Args:
        version: Dotted version string such as ``"0.13.0"``.

    Returns:
        Tuple of the integer components found before any non-numeric suffix.
    """
    release = []
    for piece in version.split("."):
        digits = ""
        for char in piece:
            if char not in "0123456789":
                break
            digits += char
        if not digits:
            break
        release.append(int(digits))
        if len(digits) != len(piece):
            # Component carries a suffix (e.g. "0rc1"): keep the number, stop here.
            break
    return tuple(release)


# Compare versions numerically. As plain strings, "0.9.0" >= "0.13.0" is True
# because "9" sorts after "1", which would select the wrong import path.
if _release_tuple(vllm_version) >= (0, 13, 0):
    from vllm.tokenizers import TokenizerLike as Tokenizer
else:
    from vllm.transformers_utils.tokenizer import AnyTokenizer as Tokenizer
22+
23+
24+
@dataclass
class BenchmarkRequest:
    """One prompt, plus its length bookkeeping, for a benchmark run.

    Attributes:
        prompt: Text prompt to generate from.
        prompt_len: Prompt length, in tokens.
        expected_output_len: Expected generation length, in tokens.
        request_id: Identifier that is unique to this request.
    """

    prompt: str
    prompt_len: int
    expected_output_len: int
    request_id: str
38+
39+
40+
class RandomDataset:
    """Generates prompts with random token sequences of specified lengths.

    Args:
        tokenizer: Tokenizer used to decode sampled token ids into prompt
            text; its ``vocab_size`` bounds the sampled ids.
        num_requests: Number of benchmark requests to generate
        input_len: Target input prompt length in tokens
        output_len: Target output generation length in tokens
        range_ratio: Variance ratio for input/output lengths (0.0-1.0).
            0.0 means fixed lengths, 0.2 means ±20% variance.

    Raises:
        ValueError: If ``range_ratio`` is outside ``[0.0, 1.0]``.
    """

    def __init__(
        self,
        tokenizer: Tokenizer,
        num_requests: int,
        input_len: int,
        output_len: int,
        range_ratio: float = 0.0,
    ):
        # A ratio above 1 would make the lower bound negative, and a negative
        # ratio would invert the sampling range; reject both up front.
        if not 0.0 <= range_ratio <= 1.0:
            raise ValueError(
                f"range_ratio must be between 0.0 and 1.0, got {range_ratio}"
            )

        self.tokenizer = tokenizer
        self.num_requests = num_requests
        self.input_len = input_len
        self.output_len = output_len
        self.range_ratio = range_ratio
        self.vocab_size = tokenizer.vocab_size

    def _sample_length(self, target_len: int) -> int:
        """Sample a length around ``target_len`` based on ``range_ratio``.

        Args:
            target_len: Length to centre the sampling range on.

        Returns:
            ``target_len`` unchanged when ``range_ratio`` is 0.0; otherwise a
            uniformly sampled integer in the ±``range_ratio`` window, never
            smaller than 1.
        """
        if self.range_ratio == 0.0:
            return target_len

        # Floor at one token: truncation with a wide ratio or a short target
        # could otherwise yield 0, i.e. an empty prompt or zero output budget.
        min_len = max(1, int(target_len * (1 - self.range_ratio)))
        max_len = max(min_len, int(target_len * (1 + self.range_ratio)))
        return random.randint(min_len, max_len)

    def generate(self) -> list[BenchmarkRequest]:
        """Generate benchmark requests with random token sequences.

        ``prompt_len`` on each request is the number of token ids sampled
        before decoding.
        NOTE(review): re-tokenizing the decoded text may not reproduce exactly
        that many tokens - confirm if exact prompt lengths matter.

        Returns:
            List of BenchmarkRequest objects with random prompts
        """
        requests = []

        for i in range(self.num_requests):
            # Sample lengths with variance
            prompt_len = self._sample_length(self.input_len)
            output_len = self._sample_length(self.output_len)

            token_ids = [
                random.randint(0, self.vocab_size - 1) for _ in range(prompt_len)
            ]
            prompt = self.tokenizer.decode(token_ids)

            requests.append(
                BenchmarkRequest(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    expected_output_len=output_len,
                    request_id=f"random-{i}-{uuid.uuid4().hex[:8]}",
                )
            )

        return requests
104+
105+
106+
class FixedDataset:
    """Baseline dataset that issues the same prompt over and over.

    Args:
        tokenizer: Tokenizer used to encode the prompt and measure its length
        prompt: Text prompt shared by every request
        num_requests: How many requests to emit
        output_len: Target output generation length in tokens
    """

    def __init__(
        self,
        tokenizer: Tokenizer,
        prompt: str,
        num_requests: int,
        output_len: int,
    ):
        self.tokenizer = tokenizer
        self.prompt = prompt
        self.num_requests = num_requests
        self.output_len = output_len
        # Encode once up front; every generated request reuses this length.
        encoded_prompt = tokenizer.encode(prompt)
        self.prompt_len = len(encoded_prompt)

    def generate(self) -> list[BenchmarkRequest]:
        """Build ``num_requests`` requests that all carry the fixed prompt.

        Returns:
            List of BenchmarkRequest objects, each with its own request id
        """
        return [
            BenchmarkRequest(
                prompt=self.prompt,
                prompt_len=self.prompt_len,
                expected_output_len=self.output_len,
                request_id=f"fixed-{index}-{uuid.uuid4().hex[:8]}",
            )
            for index in range(self.num_requests)
        ]

benchmarks/generator/metrics.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""Metrics collection and reporting for generator throughput benchmarks.
8+
9+
Based on vLLM's throughput benchmark metrics patterns.
10+
Reference: vllm/benchmarks/throughput.py (lines 762-809)
11+
"""
12+
13+
import json
14+
from dataclasses import asdict, dataclass
15+
16+
from forge.data_models.completion import Completion
17+
18+
19+
@dataclass
20+
class ThroughputMetrics:
21+
"""Throughput benchmark metrics for offline inference.
22+
Reference: https://github.com/vllm-project/vllm/blob/main/vllm/benchmarks/throughput.py
23+
24+
Attributes:
25+
elapsed_time: Total wall-clock time in seconds
26+
num_requests: Total number of requests processed
27+
num_completions: Total number of completions (requests * n samples)
28+
total_prompt_tokens: Sum of all prompt tokens
29+
total_output_tokens: Sum of all generated output tokens
30+
total_tokens: Sum of prompt and output tokens
31+
requests_per_second: Request throughput (requests/sec)
32+
completions_per_second: Completion throughput (completions/sec)
33+
tokens_per_second: Total token throughput (tokens/sec)
34+
output_tokens_per_second: Output token throughput (output tokens/sec)
35+
model: Optional model name for reporting
36+
config: Optional benchmark configuration dict
37+
"""
38+
39+
elapsed_time: float
40+
num_requests: int
41+
num_completions: int
42+
total_prompt_tokens: int
43+
total_output_tokens: int
44+
total_tokens: int
45+
requests_per_second: float
46+
completions_per_second: float
47+
tokens_per_second: float
48+
output_tokens_per_second: float
49+
model: str | None = None
50+
config: dict | None = None
51+
52+
53+
def extract_token_counts(completions: list[list[Completion]]) -> tuple[int, int]:
    """Sum prompt and output token counts across all generator completions.

    Args:
        completions: One inner list per Generator.generate() call, each
            holding the Completion objects that call returned.

    Returns:
        Tuple of (total_prompt_tokens, total_output_tokens)
    """
    prompt_total = 0
    output_total = 0

    # Walk every completion of every request in a single flattened pass.
    for sample in (item for batch in completions for item in batch):
        # prompt_ids / token_ids are torch.Tensor of shape (seq_len,), so the
        # first dimension is the token count.
        prompt_total += sample.prompt_ids.shape[0]
        output_total += sample.token_ids.shape[0]

    return prompt_total, output_total
74+
75+
76+
def calculate_metrics(
    completions: list[list[Completion]],
    elapsed_time: float,
    model: str | None = None,
    config: dict | None = None,
) -> ThroughputMetrics:
    """Derive throughput metrics from completions and the measured run time.

    Args:
        completions: One inner list of completions per Generator.generate() call
        elapsed_time: Run duration in seconds
        model: Optional model name, stored on the result
        config: Optional benchmark configuration, stored on the result

    Returns:
        ThroughputMetrics holding the counts and per-second rates; every rate
        is 0.0 when ``elapsed_time`` is not positive
    """
    request_count = len(completions)
    completion_count = sum(map(len, completions))
    prompt_tokens, output_tokens = extract_token_counts(completions)
    combined_tokens = prompt_tokens + output_tokens

    def per_second(amount: int) -> float:
        # Guard against a zero or negative duration instead of dividing by it.
        if elapsed_time > 0:
            return amount / elapsed_time
        return 0.0

    return ThroughputMetrics(
        elapsed_time=elapsed_time,
        num_requests=request_count,
        num_completions=completion_count,
        total_prompt_tokens=prompt_tokens,
        total_output_tokens=output_tokens,
        total_tokens=combined_tokens,
        requests_per_second=per_second(request_count),
        completions_per_second=per_second(completion_count),
        tokens_per_second=per_second(combined_tokens),
        output_tokens_per_second=per_second(output_tokens),
        model=model,
        config=config,
    )
116+
117+
118+
def print_metrics(metrics: ThroughputMetrics) -> None:
    """Write a human-readable summary of the metrics to stdout.

    Args:
        metrics: ThroughputMetrics to report
    """
    heavy_rule = "=" * 55
    light_rule = "-" * 55

    print(heavy_rule)
    print("Throughput Benchmark Results".center(55))
    print(heavy_rule)

    # The model line is only shown when a model name was recorded.
    if metrics.model:
        print(f"Model: {metrics.model}")

    # Average number of completions per request; 0 when there were no requests.
    if metrics.num_requests > 0:
        samples_per_request = metrics.num_completions / metrics.num_requests
    else:
        samples_per_request = 0

    print(f"Requests: {metrics.num_requests}")
    print(
        f"Completions: {metrics.num_completions} ({samples_per_request:.1f} per request)"
    )
    print(f"Elapsed Time: {metrics.elapsed_time:.2f} seconds")
    print(light_rule)
    print(f"Total Prompt Tokens: {metrics.total_prompt_tokens}")
    print(f"Total Output Tokens: {metrics.total_output_tokens}")
    print(f"Total Tokens: {metrics.total_tokens}")
    print(light_rule)
    print("Throughput:")
    print(f" Requests/sec: {metrics.requests_per_second:.2f}")
    print(f" Completions/sec: {metrics.completions_per_second:.2f}")
    print(f" Total Tokens/sec: {metrics.tokens_per_second:.2f}")
    print(f" Output Tokens/sec: {metrics.output_tokens_per_second:.2f}")
    print(heavy_rule)
154+
155+
156+
def save_metrics_json(metrics: ThroughputMetrics, output_path: str) -> None:
    """Save metrics to JSON file.

    Args:
        metrics: ThroughputMetrics to save
        output_path: Path to output JSON file

    Raises:
        TypeError: If the metrics (e.g. a value inside ``config``) cannot be
            serialized to JSON. Nothing is written in that case.
        OSError: If the output file cannot be opened or written.
    """
    # Serialize before opening the file: opening with "w" truncates it, so a
    # serialization failure afterwards would leave an empty or partial file.
    serialized = json.dumps(asdict(metrics), indent=2)

    # Explicit encoding keeps the output independent of the platform default.
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(serialized)

    print(f"\nMetrics saved to: {output_path}")

0 commit comments

Comments
 (0)