
Commit 3fd765e

Add benchmarking script for generation utils
1 parent 79af4ae commit 3fd765e

1 file changed: +47 -0 lines changed

@@ -0,0 +1,47 @@
import time
from functools import partial

from torch.utils.data import DataLoader
from torcheval.metrics.functional import word_error_rate
from torchtext.data.metrics import bleu_score
from torchtext.datasets import CNNDM
from torchtext.datasets import Multi30k
from torchtext.models import T5_BASE_GENERATION
from torchtext.prototype.generate import GenerationUtils

multi_batch_size = 5
language_pair = ("en", "de")
multi_datapipe = Multi30k(split="test", language_pair=language_pair)
task = "translate English to German"


def apply_prefix(task, x):
    return f"{task}: " + x[0], x[1]


multi_datapipe = multi_datapipe.map(partial(apply_prefix, task))
multi_datapipe = multi_datapipe.batch(multi_batch_size)
multi_datapipe = multi_datapipe.rows2columnar(["english", "german"])
multi_dataloader = DataLoader(multi_datapipe, batch_size=None)


def benchmark_beam_search_wer():
    model = T5_BASE_GENERATION.get_model()
    transform = T5_BASE_GENERATION.transform()

    seq_generator = GenerationUtils(model)

    batch = next(iter(multi_dataloader))
    input_text = batch["english"]
    target = batch["german"]
    beam_size = 4

    model_input = transform(input_text)
    model_output = seq_generator.generate(model_input, num_beams=beam_size, vocab_size=model.config.vocab_size)
    output_text = transform.decode(model_output.tolist())

    print(word_error_rate(output_text, target))


if __name__ == "__main__":
    benchmark_beam_search_wer()
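
The script imports CNNDM and bleu_score but only exercises the Multi30k/WER path above. A companion summarization benchmark could follow the same pattern, reusing GenerationUtils and scoring the output with BLEU. The sketch below is not part of this commit: it assumes CNNDM yields (article, abstract) pairs, and the names benchmark_beam_search_bleu and cnndm_batch_size are illustrative only.

cnndm_batch_size = 3  # illustrative; a small batch keeps the benchmark quick


def benchmark_beam_search_bleu():
    # Sketch only: mirrors the Multi30k/WER flow above on a CNNDM batch.
    model = T5_BASE_GENERATION.get_model()
    transform = T5_BASE_GENERATION.transform()
    seq_generator = GenerationUtils(model)

    # Assumes CNNDM yields (article, abstract) pairs; reuse apply_prefix with the T5 summarization task.
    cnndm_datapipe = CNNDM(split="test")
    cnndm_datapipe = cnndm_datapipe.map(partial(apply_prefix, "summarize"))
    cnndm_datapipe = cnndm_datapipe.batch(cnndm_batch_size)
    cnndm_datapipe = cnndm_datapipe.rows2columnar(["article", "abstract"])
    cnndm_dataloader = DataLoader(cnndm_datapipe, batch_size=None)

    batch = next(iter(cnndm_dataloader))
    model_input = transform(batch["article"])
    model_output = seq_generator.generate(model_input, num_beams=4, vocab_size=model.config.vocab_size)
    output_text = transform.decode(model_output.tolist())

    # bleu_score expects tokenized candidates and, per candidate, a list of tokenized references.
    candidates = [summary.split() for summary in output_text]
    references = [[abstract.split()] for abstract in batch["abstract"]]
    print(bleu_score(candidates, references))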
