1+ #include " torch/torch.h"
2+ #include " core/util/prelude.h"
3+ #include " core/conversion/converters/converters.h"
4+ #include " NvInfer.h"
5+ #include " torch/csrc/autograd/generated/variable_factories.h"
6+
7+ #include < ATen/ATen.h>
8+ #include < vector>
9+
10+ #include < csignal>
11+
12+ namespace trtorch {
13+ namespace core {
14+ namespace conversion {
15+ namespace converters {
16+ namespace impl {
17+ namespace {
18+
19+ auto select_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
20+ .pattern({
21+ " aten::select.int(Tensor(a) self, int dim, int index) -> (Tensor(a))" ,
22+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
23+ std::cout << " select.int converter recognized" << std::endl;
24+
25+ auto in = args[0 ].ITensor ();
26+ auto axis = args[1 ].unwrapToInt ();
27+ auto ind = (int32_t ) args[2 ].unwrapToInt ();
28+
29+ // tried: vector for input
30+ // std::vector<int32_t> indices_input = {ind};
31+
32+ auto options = torch::TensorOptions ().device (torch::kCUDA , 1 ).dtype (torch::kInt32 );
33+ at::Tensor indices = torch::tensor (torch::detail::TensorDataContainer (ind), options);
34+
35+ auto weights = Weights (ctx, indices);
36+ // manually setting weights
37+ // weights.data.type = nvinfer1::DataType::kINT32;
38+
39+ auto const_layer = ctx->net ->addConstant (weights.shape , weights.data );
40+ const_layer->setName (util::node_info (n).c_str ());
41+ // manually setting output type
42+ // const_layer->setOutputType(0, nvinfer1::DataType::kINT32);
43+
44+ auto const_out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], const_layer->getOutput (0 ));
45+
46+ auto gather_layer = ctx->net ->addGather (*in, *const_out, axis);
47+ gather_layer->setName (util::node_info (n).c_str ());
48+ // manually setting output type
49+ // gather_layer->setOutputType(0, nvinfer1::DataType::kINT32);
50+
51+ auto gather_output = ctx->AssociateValueAndTensor (n->outputs ()[0 ], gather_layer->getOutput (0 ));
52+
53+ LOG_DEBUG (" Output tensor shape: " << gather_output->getDimensions ());
54+
55+ // for debugging
56+ // std::raise(SIGTRAP);
57+
58+ return true ;
59+ }
60+ });
61+
62+ } // namespace
63+ } // namespace impl
64+ } // namespace converters
65+ } // namespace conversion
66+ } // namespace core
67+ } // namespace trtorch
0 commit comments