@@ -95,43 +95,81 @@ ShapedWeights::operator nvinfer1::Weights() const
9595}
9696
9797template <typename DType>
98- void transpose2DWeights (ShapedWeights const & weights, nvinfer1::Dims const & new_shape , ShapedWeights* result)
98+ void transpose4DWeights (ShapedWeights const & weights, nvinfer1::Permutation const perm , ShapedWeights* result)
9999{
100+ nvinfer1::Dims original_shape = weights.shape ;
101+ nvinfer1::Dims new_shape = result->shape ;
102+ int nbDims = new_shape.nbDims ;
100103 DType const * src = reinterpret_cast <DType*>(weights.values );
101104 DType* dst = reinterpret_cast <DType*>(result->values );
102- int src_stride = weights.shape .d [1 ];
103- int dst_stride = result->shape .d [1 ];
104- for (int i = 0 ; i < new_shape.d [0 ]; ++i)
105+
106+ nvinfer1::Dims expanded_original_shape{4 , {1 , 1 , 1 , 1 }};
107+ nvinfer1::Dims expanded_new_shape{4 , {1 , 1 , 1 , 1 }};
108+ nvinfer1::Permutation expanded_perm{0 , 1 , 2 , 3 };
109+
110+ int pad = 4 - nbDims;
111+ for (int i = 0 ; i < nbDims; ++i)
112+ {
113+ expanded_original_shape.d [pad + i] = original_shape.d [i];
114+ expanded_new_shape.d [pad + i] = new_shape.d [i];
115+ expanded_perm.order [pad + i] = perm.order [i] + pad;
116+ }
117+
118+ int src_strides[4 ] = {1 , 1 , 1 , 1 };
119+ int dst_strides[4 ] = {1 , 1 , 1 , 1 };
120+
121+ for (int i = 2 ; i >= 0 ; --i)
122+ {
123+ src_strides[i] = expanded_original_shape.d [i + 1 ] * src_strides[i + 1 ];
124+ dst_strides[i] = expanded_new_shape.d [i + 1 ] * dst_strides[i + 1 ];
125+ }
126+
127+ for (int n = 0 ; n < expanded_original_shape.d [0 ]; ++n)
105128 {
106- for (int j = 0 ; j < new_shape .d [1 ]; ++j )
129+ for (int c = 0 ; c < expanded_original_shape .d [1 ]; ++c )
107130 {
108- dst[i * dst_stride + j] = src[j * src_stride + i];
131+ for (int h = 0 ; h < expanded_original_shape.d [2 ]; ++h)
132+ {
133+ for (int w = 0 ; w < expanded_original_shape.d [3 ]; ++w)
134+ {
135+ int src_index = 0 ;
136+ int dst_index = 0 ;
137+ int src_coord[4 ] = {n, c, h, w};
138+ int dst_coord[4 ];
139+ for (int i = 0 ; i < 4 ; ++i)
140+ {
141+ dst_coord[i] = src_coord[expanded_perm.order [i]];
142+ src_index += src_coord[i] * src_strides[i];
143+ dst_index += dst_coord[i] * dst_strides[i];
144+ }
145+ dst[dst_index] = src[src_index];
146+ }
147+ }
109148 }
110149 }
111150}
112151
113152bool transposeWeights (ShapedWeights const & weights, nvinfer1::Permutation const & perm, ShapedWeights* result)
114153{
115154 nvinfer1::Dims shape = weights.shape ;
155+ int nbDims = shape.nbDims ;
116156 nvinfer1::Dims new_shape;
117- new_shape.nbDims = shape. nbDims ;
118- for (int d = 0 ; d < shape. nbDims ; ++d)
157+ new_shape.nbDims = nbDims;
158+ for (int d = 0 ; d < nbDims; ++d)
119159 {
120160 new_shape.d [d] = shape.d [perm.order [d]];
121161 result->shape .d [d] = new_shape.d [d];
122162 }
123- // TODO: Need to generalize this transpose implementation
124- assert (perm.order [0 ] == 1 && perm.order [1 ] == 0 );
125163
126- if (shape.nbDims == 2 )
164+ if (shape.nbDims <= 4 )
127165 {
128166 if (weights.type == ::ONNX_NAMESPACE::TensorProto::FLOAT)
129167 {
130- transpose2DWeights <float >(weights, new_shape , result);
168+ transpose4DWeights <float >(weights, perm , result);
131169 }
132170 else if (weights.type == ::ONNX_NAMESPACE::TensorProto::FLOAT16)
133171 {
134- transpose2DWeights <uint16_t >(weights, new_shape , result);
172+ transpose4DWeights <uint16_t >(weights, perm , result);
135173 }
136174 else
137175 {
0 commit comments