1- import csv
21import logging
32import os
43
1615
1716
1817def _infer_shape (f ):
19- num_lines , vector_dim = 0 , None
18+ num_lines = 0
2019 for line in f :
21- if vector_dim is None :
22- row = line .rstrip ().split (b" " )
23- vector = row [1 :]
24- # Assuming word, [vector] format
25- if len (vector ) > 2 :
26- # The header present in some (w2v) formats contains two elements.
27- vector_dim = len (vector )
28- num_lines += 1 # First element read
29- else :
30- num_lines += 1
20+ num_lines += 1
3121 f .seek (0 )
32- return num_lines , vector_dim
22+ return num_lines
3323
3424
35- def _load_token_and_vectors_from_file (file_path ):
25+ def _load_token_and_vectors_from_file (file_path , delimiter = " " ):
3626 stoi , tokens , vectors , dup_tokens = {}, [], [], []
27+ dim = None
3728 with open (file_path , "rb" ) as f :
38- num_lines , _ = _infer_shape (f )
29+ num_lines = _infer_shape (f )
3930 for line in tqdm (f , unit_scale = 0 , unit = "lines" , total = num_lines ):
40- token , entries = line .rstrip ().split (b" " , 1 )
41- vector = torch .tensor ([float (c ) for c in entries .split (b" " )], dtype = torch .float )
31+ # token and entries are seperated by delimeter
32+ token , entries = line .rstrip ().split (bytes (delimiter , "utf-8" ), 1 )
33+ # we assume entries are always seperated by " "
34+ entries = entries .split (b" " )
35+
36+ if dim is None and len (entries ) > 1 :
37+ dim = len (entries )
38+ elif len (entries ) == 1 :
39+ logger .warning ("Skipping token {} with 1-dimensional "
40+ "vector {}; likely a header" .format (token , entries ))
41+ continue
42+ elif dim != len (entries ):
43+ raise RuntimeError (
44+ "Vector for token {} has {} dimensions, but previously "
45+ "read vectors have {} dimensions. All vectors must have "
46+ "the same number of dimensions." .format (token , len (entries ),
47+ dim ))
48+
49+ vector = torch .tensor ([float (c ) for c in entries ], dtype = torch .float )
4250 try :
4351 if isinstance (token , bytes ):
4452 token = token .decode ("utf-8" )
@@ -47,7 +55,7 @@ def _load_token_and_vectors_from_file(file_path):
4755 continue
4856
4957 if token in stoi :
50- dup_tokens .append (token , len (vectors ) + 1 )
58+ dup_tokens .append (( token , len (vectors ) + 1 ) )
5159 continue
5260
5361 stoi [token ] = len (vectors )
@@ -131,11 +139,11 @@ def GloVe(name="840B", dim=300, unk_tensor=None, root=".data", validate_file=Tru
131139 ValueError: if unexpected duplicate tokens are found in GloVe file.
132140
133141 """
134- dup_token_glove_840b = ("����������������������������������������������������������������������"
135- "����������������������������������������������������������������������"
136- "����������������������������������������������������������������������"
137- "����������������������������������������������������������������������"
138- "������������������������������������������������������" , 140649 )
142+ dup_token_glove_840b = [ ("����������������������������������������������������������������������"
143+ "����������������������������������������������������������������������"
144+ "����������������������������������������������������������������������"
145+ "����������������������������������������������������������������������"
146+ "������������������������������������������������������" , 140649 )]
139147 urls = {
140148 "42B" : "https://nlp.stanford.edu/data/glove.42B.300d.zip" ,
141149 "840B" : "https://nlp.stanford.edu/data/glove.840B.300d.zip" ,
@@ -176,43 +184,40 @@ def GloVe(name="840B", dim=300, unk_tensor=None, root=".data", validate_file=Tru
176184 tokens , vectors , dup_tokens = _load_token_and_vectors_from_file (extracted_file_path_with_correct_dim )
177185
178186 # Ensure there is only 1 expected duplicate token present for 840B dataset
179- if dup_tokens :
180- if not (len (dup_tokens ) == 1 and dup_tokens [0 ] == dup_token_glove_840b [0 ] and
181- dup_tokens [1 ] == dup_token_glove_840b [1 ]):
182- raise ValueError ("Found duplicate tokens in file: {}" .format (str (dup_tokens )))
187+ if dup_tokens and dup_tokens != dup_token_glove_840b :
188+ raise ValueError ("Found duplicate tokens in file: {}" .format (str (dup_tokens )))
183189
184190 vectors_obj = Vectors (tokens , vectors , unk_tensor = unk_tensor )
185191 torch .save (vectors_obj , cached_vectors_file_path )
186192 return vectors_obj
187193
188194
def vectors_from_file_object(file_like_object, delimiter=",", unk_tensor=None):
    r"""Create a Vectors object from a csv file like object.

    Note that the tensor corresponding to each vector is of type `torch.float`.

    Format for csv file:
        token1<delimiter> num1 num2 num3
        token2<delimiter> num4 num5 num6
        ...
        token_n<delimiter> num_m num_j num_k

    Args:
        file_like_object (FileObject): a file like object to read data from.
        delimiter (char): a character to delimit between the token and the vector. Default value is ","
        unk_tensor (Tensor): a 1d tensor representing the vector associated with an unknown token.

    Returns:
        Vectors: a Vectors object.

    Raises:
        ValueError: if duplicate tokens are found in the file.

    """
    # NOTE(review): the loader re-opens the data by path, so `file_like_object`
    # must be backed by a real on-disk file exposing `.name` — confirm callers.
    tokens, vectors, dup_tokens = _load_token_and_vectors_from_file(file_like_object.name, delimiter=delimiter)
    if dup_tokens:
        raise ValueError("Found duplicate tokens in file: {}".format(str(dup_tokens)))
    return Vectors(tokens, vectors, unk_tensor=unk_tensor)
217222
218223
0 commit comments