@@ -8,18 +8,22 @@ namespace converters {
88namespace impl {
99namespace {
1010
11- nvinfer1::ILayer* add_elementwise (ConversionCtx* ctx, nvinfer1::ElementWiseOperation op, nvinfer1::ITensor* self, nvinfer1::ITensor* other, float scalar=1 ) {
11+ nvinfer1::ILayer* add_elementwise (ConversionCtx* ctx, nvinfer1::ElementWiseOperation op, nvinfer1::ITensor* self, nvinfer1::ITensor* other, const std::string& name, float scalar=1 ) {
1212 auto self_dims = self->getDimensions ();
13+ auto self_dims_vec = util::toVec (self_dims);
1314 auto other_dims = other->getDimensions ();
15+ auto other_dims_vec = util::toVec (other_dims);
16+ auto other_batch = other_dims_vec[0 ];
1417
15- TRTORCH_CHECK (util::volume (self_dims) == util::volume (other_dims), " Found inputs to elementwise operation do not have the same number of elements:\n Found: self " << self_dims << " other " << other_dims);
18+ // TODO: Proper broadcast check
19+ TRTORCH_CHECK (util::volume (self_dims) == util::volume (other_dims) || util::volume (self_dims) == util::volume (other_dims) / other_batch, " Found inputs to elementwise operation do not have the same number of elements or is not broadcastable:\n Found: self " << self_dims << " other " << other_dims);
1620
1721 if (self_dims != other_dims) {
1822 LOG_DEBUG (" Input shape dont match inserting shuffle layers to reshape to " << self_dims);
19- auto other_shuffle = ctx->net ->addShuffle (*other );
20- other_shuffle ->setReshapeDimensions (self_dims );
21- other_shuffle ->setName (std::string (" [Reshape other to " + util::toStr (self_dims) + ' ] ' ).c_str ());
22- other = other_shuffle ->getOutput (0 );
23+ auto self_shuffle = ctx->net ->addShuffle (*self );
24+ self_shuffle ->setReshapeDimensions (util::toDimsPad (self_dims_vec, other_dims_vec. size ()) );
25+ self_shuffle ->setName (std::string (" [Reshape self to " + util::toStr (self_dims) + " for broadcasting ( " + name + " )] " ).c_str ());
26+ self = self_shuffle ->getOutput (0 );
2327 }
2428
2529
@@ -72,7 +76,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
7276 auto self = args[0 ].ITensor ();
7377 auto other = args[1 ].ITensor ();
7478 auto scalar = args[2 ].unwrapToScalar ().to <float >();
75- auto add = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kSUM , self, other, scalar);
79+ auto add = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kSUM , self, other, util::node_info (n), scalar);
7680
7781 TRTORCH_CHECK (add, " Unable to create add layer from node: " << *n);
7882
@@ -89,7 +93,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
8993 auto self = args[0 ].ITensor ();
9094 auto other = args[1 ].ITensor ();
9195 auto scalar = args[2 ].unwrapToScalar ().to <float >();
92- auto add = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kSUM , self, other, scalar);
96+ auto add = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kSUM , self, other, util::node_info (n), scalar);
9397
9498 TRTORCH_CHECK (add, " Unable to create add layer from node: " << *n);
9599
@@ -106,7 +110,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
106110 auto self = args[0 ].ITensor ();
107111 auto other = args[1 ].ITensor ();
108112 auto scalar = args[2 ].unwrapToScalar ().to <float >();
109- auto sub = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kSUB , self, other, scalar);
113+ auto sub = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kSUB , self, other, util::node_info (n), scalar);
110114
111115 TRTORCH_CHECK (sub, " Unable to create sub layer from node: " << *n);
112116
@@ -122,7 +126,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
122126 // Should implement self / other
123127 auto self = args[0 ].ITensor ();
124128 auto other = args[1 ].ITensor ();
125- auto div = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kDIV , self, other);
129+ auto div = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kDIV , self, other, util::node_info (n) );
126130
127131 TRTORCH_CHECK (div, " Unable to create div layer from node: " << *n);
128132
@@ -138,7 +142,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
138142 // TODO: Remove with functionalization
139143 auto self = args[0 ].ITensor ();
140144 auto other = args[1 ].ITensor ();
141- auto div = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kDIV , self, other);
145+ auto div = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kDIV , self, other, util::node_info (n) );
142146
143147 TRTORCH_CHECK (div, " Unable to create div layer from node: " << *n);
144148
@@ -154,7 +158,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
154158 // Should implement self * other
155159 auto self = args[0 ].ITensor ();
156160 auto other = args[1 ].ITensor ();
157- auto mul = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kPROD , self, other);
161+ auto mul = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kPROD , self, other, util::node_info (n) );
158162
159163 TRTORCH_CHECK (mul, " Unable to create mul layer from node: " << *n);
160164
@@ -170,7 +174,7 @@ auto element_wise_registrations = RegisterNodeConversionPatterns()
170174 // TODO: Remove with functionalization
171175 auto self = args[0 ].ITensor ();
172176 auto other = args[1 ].ITensor ();
173- auto mul = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kPROD , self, other);
177+ auto mul = add_elementwise (ctx, nvinfer1::ElementWiseOperation::kPROD , self, other, util::node_info (n) );
174178
175179 TRTORCH_CHECK (mul, " Unable to create mul layer from node: " << *n);
176180
0 commit comments