Skip to content

Commit f862576

Browse files
day9011day9011
andauthored
[fix] fix conv2d subsampling4 mask bug (#1602)
* fix utils/executor.py L76. fix transducer config use bug. * fix flake8 error * fix Conv2dSubsampling4 mask bug * fix subsampling mask bug Co-authored-by: day9011 <day9011@gmail.com>
1 parent 05b1411 commit f862576

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

wenet/transformer/subsampling.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def forward(
131131
b, c, t, f = x.size()
132132
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
133133
x, pos_emb = self.pos_enc(x, offset)
134-
return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2]
134+
return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2]
135135

136136

137137
class Conv2dSubsampling6(BaseSubsampling):
@@ -182,7 +182,7 @@ def forward(
182182
b, c, t, f = x.size()
183183
x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
184184
x, pos_emb = self.pos_enc(x, offset)
185-
return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-4:3]
185+
return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]
186186

187187

188188
class Conv2dSubsampling8(BaseSubsampling):
@@ -237,4 +237,4 @@ def forward(
237237
b, c, t, f = x.size()
238238
x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
239239
x, pos_emb = self.pos_enc(x, offset)
240-
return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2]
240+
return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]

0 commit comments

Comments
 (0)