We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent a8c09f5 commit 28ea9f8Copy full SHA for 28ea9f8
1 file changed
segmentation_models_pytorch/decoders/dpt/decoder.py
@@ -32,8 +32,8 @@ def forward(
32
features = features.transpose(1, 2).contiguous()
33
34
if prefix_tokens is not None:
35
- # (batch_size, num_tokens, embed_dim) -> (batch_size, embed_dim)
36
- prefix_tokens = prefix_tokens[:, 0].expand_as(features)
+ # (batch_size, num_prefix_tokens, embed_dim) -> (batch_size, 1, embed_dim)
+ prefix_tokens = prefix_tokens[:, :1].expand_as(features)
37
features = torch.cat([features, prefix_tokens], dim=2)
38
39
# Project to embedding dimension
0 commit comments