Skip to content

Commit 7be6c7f

Browse files
[test] Fix pooler unit tests on gpu (#1165)
Summary: Pull Request resolved: #1165. Fixes the `no attribute "cuda"` runtime error that occurred when testing poolers with CUDA. The unit tests previously passed on GitHub Actions only because those machines run without a GPU. Test Plan: Ran the unit tests with a GPU on learnfair. Reviewed By: ebsmothers Differential Revision: D32805045 Pulled By: Ryan-Qiyu-Jiang fbshipit-source-id: 5d2272fc99688d71bbf9cd3b7b71c051deb1de36
1 parent 6f3f40f commit 7be6c7f

File tree

1 file changed

+10
-16
lines changed

1 file changed

+10
-16
lines changed

tests/modules/test_poolers.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import mmf.modules.poolers as poolers
66
import torch
7+
from mmf.utils.general import get_current_device
78

89

910
class TestModulePoolers(unittest.TestCase):
@@ -13,36 +14,29 @@ def setUp(self):
1314
self.num_tokens = 10
1415
self.embedding_size = 768
1516
self.token_len = 10
17+
self.device = get_current_device()
1618
self.encoded_layers = [
17-
torch.randn(self.batch_size, self.token_len, self.embedding_size),
18-
torch.randn(self.batch_size, self.token_len, self.embedding_size),
19-
torch.randn(self.batch_size, self.token_len, self.embedding_size),
19+
torch.randn(self.batch_size, self.token_len, self.embedding_size).to(
20+
self.device
21+
)
22+
for _ in range(3)
2023
]
21-
self.pad_mask = torch.randn(self.batch_size, self.token_len)
24+
self.pad_mask = torch.randn(self.batch_size, self.token_len).to(self.device)
2225

2326
def test_AverageConcat(self):
24-
pool_fn = poolers.AverageConcatLastN(self.k)
27+
pool_fn = poolers.AverageConcatLastN(self.k).to(self.device)
2528
out = pool_fn(self.encoded_layers, self.pad_mask)
26-
if torch.cuda.is_available():
27-
pool_fn.cuda()
28-
out = pool_fn(self.encoded_layers.cuda(), self.pad_mask.cuda())
2929

3030
assert torch.Size([self.batch_size, self.embedding_size * self.k]) == out.shape
3131

3232
def test_AverageKFromLast(self):
33-
pool_fn = poolers.AverageKFromLast(self.k)
33+
pool_fn = poolers.AverageKFromLast(self.k).to(self.device)
3434
out = pool_fn(self.encoded_layers, self.pad_mask)
35-
if torch.cuda.is_available():
36-
pool_fn.cuda()
37-
out = pool_fn([self.encoded_layers.cuda(), self.pad_mask.cuda()])
3835

3936
assert torch.Size([self.batch_size, self.embedding_size]) == out.shape
4037

4138
def test_AverageSumLastK(self):
42-
pool_fn = poolers.AverageSumLastK(self.k)
39+
pool_fn = poolers.AverageSumLastK(self.k).to(self.device)
4340
out = pool_fn(self.encoded_layers, self.pad_mask)
44-
if torch.cuda.is_available():
45-
pool_fn.cuda()
46-
out = pool_fn([self.encoded_layers.cuda(), self.pad_mask.cuda()])
4741

4842
assert torch.Size([self.batch_size, self.embedding_size]) == out.shape

0 commit comments

Comments (0)