This repository was archived by the owner on Sep 10, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 812
Expand file tree
/
Copy path benchmark_generation_utils.py
More file actions
53 lines (39 loc) · 1.48 KB
/
benchmark_generation_utils.py
File metadata and controls
53 lines (39 loc) · 1.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import time
from functools import partial
from torch.utils.data import DataLoader
from torcheval.metrics.functional import word_error_rate
from torchtext.datasets import Multi30k
from torchtext.models import T5_BASE_GENERATION, T5_3B_GENERATION
from torchtext.prototype.generate import GenerationUtils
# Benchmark configuration: number of sentence pairs per batch.
multi_batch_size = 16
# Translation direction: English source, German target.
language_pair = ("en", "de")
# Test split of the Multi30k dataset as a torchdata datapipe of (en, de) pairs.
multi_datapipe = Multi30k(split="test", language_pair=language_pair)
# T5 task prefix prepended to every source sentence before generation.
task = "translate English to German"
def apply_prefix(task, x):
    """Prepend the task prefix to the source sentence of a (source, target) pair."""
    source, target = x[0], x[1]
    return f"{task}: {source}", target
# Build the evaluation pipeline: prefix every source sentence with the task,
# group examples into batches, then pivot each batch of rows into a dict of
# columns keyed "english" / "german".
multi_datapipe = (
    multi_datapipe.map(partial(apply_prefix, task))
    .batch(multi_batch_size)
    .rows2columnar(["english", "german"])
)
# batch_size=None: the datapipe already yields fully-formed batches.
multi_dataloader = DataLoader(multi_datapipe, batch_size=None)
def benchmark_beam_search_wer(beam_size: int = 8):
    """Benchmark T5-base beam-search generation on one Multi30k batch.

    Loads the T5_BASE_GENERATION bundle, generates German translations for a
    single batch of English sentences using beam search, then prints the
    generation latency and the word error rate against the reference
    translations.

    Args:
        beam_size: Number of beams used during generation. Defaults to 8,
            matching the original hard-coded value.
    """
    model = T5_BASE_GENERATION.get_model()
    transform = T5_BASE_GENERATION.transform()
    seq_generator = GenerationUtils(model)

    # One batch of column-oriented data from the Multi30k pipeline.
    batch = next(iter(multi_dataloader))
    input_text = batch["english"]
    target = batch["german"]

    model_input = transform(input_text)
    # Time only the generation step; model/transform setup is excluded.
    # (`time` was imported at module level but previously unused — a
    # benchmark function should actually measure something.)
    start = time.perf_counter()
    model_output = seq_generator.generate(
        model_input,
        num_beams=beam_size,
        beam_threshold=1000,
        vocab_size=model.config.vocab_size,
        eos_score=-1.0,
        eos_idx=1,  # eos/pad ids as in the original benchmark (T5 convention — confirm against tokenizer)
        pad_idx=0,
    )
    elapsed = time.perf_counter() - start

    output_text = transform.decode(model_output.tolist())
    print(f"beam_size={beam_size} generation time: {elapsed:.2f}s")
    print(word_error_rate(output_text, target))
# Run the benchmark only when executed directly as a script.
if __name__ == "__main__":
    benchmark_beam_search_wer()