@@ -17,6 +17,16 @@ namespace pyapi {
1717 return field_name; \
1818 }
1919
20+ // TODO: Make this error message more informative
21+ #define ADD_ENUM_GET_SET (field_name, type, max_val ) \
22+ void set_##field_name(int64_t val) { \
23+ TRTORCH_CHECK (val < max_val, " Invalid enum value for field" ); \
24+ field_name = static_cast <type>(val); \
25+ } \
26+ int64_t get_##field_name() { \
27+ return static_cast <int64_t >(field_name); \
28+ }
29+
2030struct InputRange : torch::CustomClassHolder {
2131 std::vector<int64_t > min;
2232 std::vector<int64_t > opt;
@@ -59,7 +69,7 @@ struct Device : torch::CustomClassHolder {
5969 allow_gpu_fallback(false ) // allow_gpu_fallback
6070 {}
6171
62- ADD_FIELD_GET_SET (device_type, DeviceType);
72+ ADD_ENUM_GET_SET (device_type, DeviceType, 1 );
6373 ADD_FIELD_GET_SET (gpu_id, int64_t );
6474 ADD_FIELD_GET_SET (dla_core, int64_t );
6575 ADD_FIELD_GET_SET (allow_gpu_fallback, bool );
@@ -77,28 +87,22 @@ enum class EngineCapability : int8_t {
7787std::string to_str (EngineCapability value);
7888nvinfer1::EngineCapability toTRTEngineCapability (EngineCapability value);
7989
80- // TODO: Make this error message more informative
81- #define ADD_ENUM_GET_SET (field_name, type, max_val ) \
82- void set_##field_name(int64_t val) { \
83- TRTORCH_CHECK (val < max_val, " Invalid enum value for field" ); \
84- field_name = static_cast <type>(val); \
85- } \
86- int64_t get_##field_name() { \
87- return static_cast <int64_t >(field_name); \
88- }
89-
9090struct CompileSpec : torch::CustomClassHolder {
9191 core::CompileSpec toInternalCompileSpec ();
9292 std::string stringify ();
9393 void appendInputRange (const c10::intrusive_ptr<InputRange>& ir) {
9494 input_ranges.push_back (*ir);
9595 }
9696
97- ADD_ENUM_GET_SET (op_precision, DataType, 3 );
97+ void setDeviceIntrusive (const c10::intrusive_ptr<Device>& d) {
98+ device = *d;
99+ }
100+
101+ ADD_ENUM_GET_SET (op_precision, DataType, 2 );
98102 ADD_FIELD_GET_SET (refit, bool );
99103 ADD_FIELD_GET_SET (debug, bool );
100104 ADD_FIELD_GET_SET (strict_types, bool );
101- ADD_ENUM_GET_SET (capability, EngineCapability, 3 );
105+ ADD_ENUM_GET_SET (capability, EngineCapability, 2 );
102106 ADD_FIELD_GET_SET (num_min_timing_iters, int64_t );
103107 ADD_FIELD_GET_SET (num_avg_timing_iters, int64_t );
104108 ADD_FIELD_GET_SET (workspace_size, int64_t );
0 commit comments