Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit bcb9104

Browse files
authored
Adding dunder methods to cpp Vectors (#852)
* Adding dunder methods to cpp Vectors
* Fix length
* Fixed pytest errors
* Fixed stylecheck
1 parent b887daa commit bcb9104

3 files changed

Lines changed: 23 additions & 6 deletions

File tree

test/experimental/test_vectors.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@
1616

1717

1818
class TestVectors(TorchtextTestCase):
19+
def tearDown(self):
20+
super().tearDown()
21+
torch._C._jit_clear_class_registry()
22+
torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
23+
1924
def test_empty_vectors(self):
2025
tokens = []
2126
vectors = []

torchtext/csrc/vectors.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,25 +41,28 @@ struct Vectors : torch::CustomClassHolder {
4141
}
4242
}
4343

44-
torch::Tensor GetItem(const std::string &token) const {
44+
torch::Tensor __getitem__(const std::string &token) const {
4545
if (stovectors_.find(token) != stovectors_.end()) {
4646
return stovectors_.at(token);
4747
}
4848
return unk_tensor_;
4949
}
5050

51-
void AddItem(const std::string &token, const torch::Tensor &vector) {
51+
void __setitem__(const std::string &token, const torch::Tensor &vector) {
5252
stovectors_[token] = vector;
5353
}
54+
55+
int64_t __len__() { return stovectors_.size(); }
5456
};
5557

5658
// Registers our custom class with torch.
5759
static auto vectors =
5860
torch::class_<Vectors>("torchtext", "Vectors")
5961
.def(torch::init<std::vector<std::string>, std::vector<torch::Tensor>,
6062
torch::Tensor>())
61-
.def("GetItem", &Vectors::GetItem)
62-
.def("AddItem", &Vectors::AddItem)
63+
.def("__getitem__", &Vectors::__getitem__)
64+
.def("__setitem__", &Vectors::__setitem__)
65+
.def("__len__", &Vectors::__len__)
6366
.def_pickle(
6467
// __getstate__
6568
[](const c10::intrusive_ptr<Vectors> &self)

torchtext/experimental/vectors.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def __getitem__(self, token: str) -> Tensor:
256256
Returns:
257257
vector (Tensor): a tensor (the vector) corresponding to the associated token.
258258
"""
259-
return self.vectors.GetItem(token)
259+
return self.vectors[token]
260260

261261
@torch.jit.export
262262
def __setitem__(self, token: str, vector: Tensor):
@@ -271,7 +271,16 @@ def __setitem__(self, token: str, vector: Tensor):
271271
if vector.dtype != torch.float:
272272
raise TypeError("`vector` should be of data type `torch.float` but it's of type " + vector.dtype)
273273

274-
self.vectors.AddItem(token, vector.float())
274+
self.vectors[token] = vector.float()
275+
276+
@torch.jit.export
277+
def __len__(self):
278+
r"""Get length of vectors object.
279+
280+
Returns:
281+
length (int): the length of the vectors.
282+
"""
283+
return len(self.vectors)
275284

276285

277286
CHECKSUMS_GLOVE = {

0 commit comments

Comments
 (0)