-
Notifications
You must be signed in to change notification settings - Fork 946
Expand file tree
/
Copy pathtest_sample.py
More file actions
156 lines (122 loc) · 5.4 KB
/
test_sample.py
File metadata and controls
156 lines (122 loc) · 5.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# Copyright (c) Facebook, Inc. and its affiliates.
import unittest
import tests.test_utils as test_utils
import torch
from mmf.common.sample import (
Sample,
SampleList,
convert_batch_to_sample_list,
to_device,
)
class TestSample(unittest.TestCase):
    """Checks attribute-style and item-style access on ``Sample``."""

    def test_sample_working(self):
        """Values set via attribute or key are readable both ways, also after ``update``."""
        sample = Sample()
        sample.x = 1
        sample["y"] = 2

        # Whatever way a field was written, both access styles must return it.
        for field_name, expected in (("x", 1), ("y", 2)):
            self.assertEqual(getattr(sample, field_name), expected)
            self.assertEqual(sample[field_name], expected)

        # ``update`` takes a plain dict; the nested dict under "b" must end up
        # supporting attribute access as well.
        sample.update({"a": 3, "b": {"c": 4}})

        self.assertEqual(sample.a, 3)
        self.assertEqual(sample["a"], 3)
        self.assertEqual(sample.b.c, 4)
        self.assertEqual(sample["b"].c, 4)
class TestSampleList(unittest.TestCase):
    """Tests for ``SampleList``: memory pinning, dict conversion and equality."""

    @test_utils.skip_if_no_cuda
    def test_pin_memory(self):
        """After ``pin_memory()``, ``y`` and ``z.y`` are pinned while ``x`` and ``z.x`` are not."""
        sample_list = test_utils.build_random_sample_list()
        sample_list.pin_memory()

        pin_list = [sample_list.y, sample_list.z.y]
        non_pin_list = [sample_list.x, sample_list.z.x]

        # Assert per field so a failure identifies the offending entry instead
        # of reporting an opaque accumulated flag.
        for index, field in enumerate(pin_list):
            self.assertTrue(
                field.is_pinned(), msg=f"pin_list[{index}] should be pinned"
            )

        for index, field in enumerate(non_pin_list):
            # Fields without an ``is_pinned`` method count as not pinned.
            self.assertFalse(
                hasattr(field, "is_pinned") and field.is_pinned(),
                msg=f"non_pin_list[{index}] should not be pinned",
            )

    def test_to_dict(self):
        """``to_dict()`` returns a plain ``dict`` holding the top-level and nested keys."""
        sample_list = test_utils.build_random_sample_list()
        sample_dict = sample_list.to_dict()
        self.assertIsInstance(sample_dict, dict)
        # hasattr won't work anymore: a plain dict has no attribute access to keys
        self.assertFalse(hasattr(sample_dict, "x"))

        keys_to_assert = ["x", "y", "z", "z.x", "z.y"]
        for key in keys_to_assert:
            current = sample_dict
            # Dotted keys are walked one level at a time; a plain key is simply
            # a path of length one.
            for sub_key in key.split("."):
                self.assertIn(
                    sub_key, current, msg=f"missing {sub_key!r} while resolving {key!r}"
                )
                current = current[sub_key]

    def test_equal(self):
        """Sample lists compare equal to a copy and unequal once a field differs."""
        sample_list1 = test_utils.build_random_sample_list()
        sample_list2 = sample_list1.copy()

        # Copy with an extra list-valued field.
        sample_list3 = sample_list1.copy()
        sample_list3.add_field("new", [1, 2, 3, 4, 5])

        # Copy with an extra all-zeros tensor field.
        sample_list4 = sample_list1.copy()
        tensor_size = sample_list1.get_batch_size()
        sample_list4.add_field("new", torch.zeros(tensor_size))

        # Empty sample list.
        sample_list5 = SampleList()

        # Sample list whose only field is an empty sample list.
        sample_list6 = SampleList()
        sample_list6.add_field("new", SampleList())

        # Sample list whose only field is a plain dict.
        sample_list7 = SampleList()
        dict_example = {"a": 1, "b": 2}
        sample_list7.add_field("new", dict_example)

        # Copy with an extra all-ones tensor field (same field name as
        # sample_list4, different values).
        sample_list8 = sample_list1.copy()
        sample_list8.add_field("new", torch.ones(tensor_size))

        self.assertEqual(sample_list1, sample_list2)
        self.assertNotEqual(sample_list1, sample_list3)
        self.assertNotEqual(sample_list1, sample_list4)
        self.assertNotEqual(sample_list2, sample_list4)
        self.assertNotEqual(sample_list1, sample_list5)
        self.assertNotEqual(sample_list1, sample_list6)
        self.assertNotEqual(sample_list1, sample_list7)
        self.assertNotEqual(sample_list5, sample_list6)
        self.assertNotEqual(sample_list6, sample_list7)
        self.assertNotEqual(sample_list6, sample_list5)
        self.assertNotEqual(sample_list6, sample_list1)
        self.assertNotEqual(sample_list1, sample_list8)
        # Kept as an explicit ``==`` check: assertNotEqual would evaluate
        # ``!=`` instead of the equality operator.
        self.assertFalse(sample_list4 == sample_list8)
class TestFunctions(unittest.TestCase):
    """Tests for the helpers ``to_device`` and ``convert_batch_to_sample_list``."""

    def test_to_device(self):
        """``to_device`` accepts string and ``torch.device`` targets."""
        sample_list = test_utils.build_random_sample_list()

        modified = to_device(sample_list, "cpu")
        self.assertEqual(modified.get_device(), torch.device("cpu"))

        modified = to_device(sample_list, torch.device("cpu"))
        self.assertEqual(modified.get_device(), torch.device("cpu"))

        # Requesting "cuda" is expected to land on cuda:0 when CUDA is
        # available and to stay on the CPU otherwise.
        modified = to_device(sample_list, "cuda")
        if torch.cuda.is_available():
            self.assertEqual(modified.get_device(), torch.device("cuda:0"))
        else:
            self.assertEqual(modified.get_device(), torch.device("cpu"))

        # Moving to the device the sample list already reports must hand back
        # the very same object.
        double_modified = to_device(modified, modified.get_device())
        self.assertIs(double_modified, modified)

        # A plain list of dicts, called without a device argument, comes back
        # equal to the input.
        custom_batch = [{"a": 1}]
        self.assertEqual(to_device(custom_batch), custom_batch)

    def test_convert_batch_to_sample_list(self):
        """Batches of dicts and single-element sample-list batches convert to ``SampleList``."""
        # Test list conversion
        batch = [{"a": torch.tensor([1.0, 1.0])}, {"a": torch.tensor([2.0, 2.0])}]
        sample_list = convert_batch_to_sample_list(batch)
        expected_a = torch.tensor([[1.0, 1.0], [2.0, 2.0]])
        self.assertTrue(torch.equal(expected_a, sample_list.a))

        # Test single element list, samplelist
        sample_list = SampleList()
        sample_list.add_field("a", expected_a)
        parsed_sample_list = convert_batch_to_sample_list([sample_list])
        self.assertIsInstance(parsed_sample_list, SampleList)
        self.assertIn("a", parsed_sample_list)
        self.assertTrue(torch.equal(expected_a, parsed_sample_list.a))

        # Test no tensor field
        batch = [{"a": [1]}, {"a": [2]}]
        sample_list = convert_batch_to_sample_list(batch)
        # The original used assertTrue(value, expected), which treats the
        # expected list as the failure message and only checks truthiness.
        # assertEqual actually compares the collected field to the expectation.
        self.assertEqual(sample_list.a, [[1], [2]])