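"""Multi30k data pipeline for character-level translation.

Builds torchtext ``TranslationDataset`` splits in which source sentences are
encoded as character ids, and target sentences as both character ids and
word ids, with word- and sentence-boundary special tokens.
"""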
import itertools
import re
import torch
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.experimental.datasets.translation import DATASETS, TranslationDataset
from torchtext.experimental.functional import sequential_transforms, vocab_func
from torchtext.vocab import build_vocab_from_iterator


def build_word_vocab(data, transforms, index, init_token="<w>", eos_token="</w>"):
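    """Build a word-level vocabulary from one side of a parallel corpus.

    ``index`` selects the source (0) or target (1) sentence of each pair;
    the boundary tokens are seeded first so they always receive an id.
    """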
tok_list = [[init_token], [eos_token]]
for line in data:
tok_list.append(transforms(line[index]))
    return build_vocab_from_iterator(tok_list)


def build_char_vocab(
data, transforms, index, init_word_token="<w>", eos_word_token="</w>", init_sent_token="<s>", eos_sent_token="</s>",
):
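    """Build a character-level vocabulary from one side of a parallel corpus.

    Word- and sentence-boundary tokens are seeded first so they are
    guaranteed to appear in the vocabulary.
    """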
tok_list = [
[init_word_token],
[eos_word_token],
[init_sent_token],
[eos_sent_token],
]
for line in data:
tokens = list(itertools.chain.from_iterable(transforms(line[index])))
tok_list.append(tokens)
    return build_vocab_from_iterator(tok_list)


def char_vocab_func(vocab):
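    """Return a transform mapping ``[[char, ...], ...]`` to nested vocab ids."""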
def func(tok_iter):
        return [[vocab[char] for char in word] for word in tok_iter]

    return func


def special_char_tokens_func(
init_word_token="<w>", eos_word_token="</w>", init_sent_token="<s>", eos_sent_token="</s>",
):
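    """Return a transform that wraps each word in word-boundary tokens and
    adds sentence-boundary "words" at both ends of the token list.
    """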
def func(tok_iter):
result = [[init_word_token, init_sent_token, eos_word_token]]
result += [[init_word_token] + word + [eos_word_token] for word in tok_iter]
result += [[init_word_token, eos_sent_token, eos_word_token]]
        return result

    return func


def special_word_token_func(init_word_token="<w>", eos_word_token="</w>"):
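    """Return a transform that flanks a sentence's word list with boundary
    tokens, mirroring the two boundary "words" the character transform adds.
    """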
def func(tok_iter):
return [init_word_token] + tok_iter + [eos_word_token]
    return func


def parallel_transforms(*transforms):
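    """Return a transform applying each of ``transforms`` to the same input,
    yielding the results as a tuple.
    """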
def func(txt_input):
result = []
for transform in transforms:
result.append(transform(txt_input))
        return tuple(result)

    return func


def get_dataset():
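    """Build Multi30k train/val/test datasets.

    Source sentences are encoded as character ids only; target sentences are
    encoded both as character ids and as word ids (returned as a tuple per
    sample).
    """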
# Get the raw dataset first. This will give us the text
# version of the dataset
    train, val, test = DATASETS["Multi30k"]()
    # Materialize the splits as lists; the training split is reused below for
    # vocabulary construction
    train_data = [line for line in train]
    val_data = [line for line in val]
    test_data = [line for line in test]

    # Set up the word-level tokenizers
src_tokenizer = get_tokenizer("spacy", language="de_core_news_sm")
    tgt_tokenizer = get_tokenizer("spacy", language="en_core_web_sm")

    # Set up the character tokenizer, which splits each word into its characters
    def char_tokenizer(words):
        return [list(word) for word in words]

    def remove_extra_whitespace(line):
        return re.sub(" {2,}", " ", line)

    src_char_transform = sequential_transforms(remove_extra_whitespace, src_tokenizer, char_tokenizer)
    tgt_char_transform = sequential_transforms(remove_extra_whitespace, tgt_tokenizer, char_tokenizer)
    tgt_word_transform = sequential_transforms(remove_extra_whitespace, tgt_tokenizer)

    # Set up vocabularies (both word- and character-level)
src_char_vocab = build_char_vocab(train_data, src_char_transform, index=0)
tgt_char_vocab = build_char_vocab(train_data, tgt_char_transform, index=1)
    tgt_word_vocab = build_word_vocab(train_data, tgt_word_transform, index=1)

    # Build the final transforms: tokenization, special tokens, then
    # vocabulary lookup
src_char_transform = sequential_transforms(
src_char_transform, special_char_tokens_func(), char_vocab_func(src_char_vocab)
)
tgt_char_transform = sequential_transforms(
tgt_char_transform, special_char_tokens_func(), char_vocab_func(tgt_char_vocab)
)
tgt_word_transform = sequential_transforms(
tgt_word_transform, special_word_token_func(), vocab_func(tgt_word_vocab)
    )
    tgt_transform = parallel_transforms(tgt_char_transform, tgt_word_transform)

    train_dataset = TranslationDataset(
train_data, (src_char_vocab, tgt_char_vocab, tgt_word_vocab), (src_char_transform, tgt_transform)
)
val_dataset = TranslationDataset(
val_data, (src_char_vocab, tgt_char_vocab, tgt_word_vocab), (src_char_transform, tgt_transform)
)
test_dataset = TranslationDataset(
test_data, (src_char_vocab, tgt_char_vocab, tgt_word_vocab), (src_char_transform, tgt_transform)
    )

    return train_dataset, val_dataset, test_dataset
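

if __name__ == "__main__":
    # Illustrative usage sketch (an addition, not part of the original
    # pipeline; assumes the Multi30k files can be downloaded). Each sample is
    # a pair of ragged nested Python lists, so torch's default collate would
    # fail; an identity collate_fn keeps samples as-is, and real training
    # code would pad and convert them to tensors.
    train_dataset, val_dataset, test_dataset = get_dataset()

    # The target transform is parallel_transforms(char, word), so the target
    # side of each sample is a (char ids, word ids) tuple.
    src_char_ids, (tgt_char_ids, tgt_word_ids) = train_dataset[0]
    print("source words (as char-id lists):", len(src_char_ids))
    print("target words (as char-id lists):", len(tgt_char_ids))
    print("target word ids:", torch.tensor(tgt_word_ids))

    loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=lambda batch: batch)
    print("first batch size:", len(next(iter(loader))))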