@@ -39,11 +39,15 @@ std::string AITModel::serialize() const {
3939 pick_output_names.push_back (picojson::value (entry));
4040 }
4141 var[OUTPUT_NAMES_STR] = picojson::value (pick_output_names);
42- var[FLOATING_POINT_INPUT_DTYPE_STR] = picojson::value (std::to_string (
43- static_cast <int16_t >(aitModelImpl_.floatingPointInputDtype ().value ())));
42+ var[FLOATING_POINT_INPUT_DTYPE_STR] = picojson::value (
43+ std::to_string (
44+ static_cast <int16_t >(
45+ aitModelImpl_.floatingPointInputDtype ().value ())));
4446
45- var[FLOATING_POINT_OUTPUT_DTYPE_STR] = picojson::value (std::to_string (
46- static_cast <int16_t >(aitModelImpl_.floatingPointOutputDtype ().value ())));
47+ var[FLOATING_POINT_OUTPUT_DTYPE_STR] = picojson::value (
48+ std::to_string (
49+ static_cast <int16_t >(
50+ aitModelImpl_.floatingPointOutputDtype ().value ())));
4751
4852 result = picojson::value (var).serialize ();
4953 return result;
@@ -58,14 +62,15 @@ void AITModel::loadAsTorchClass() {
5862
5963static auto registerAITModel =
6064 torch::class_<AITModel>(" ait" , " AITModel" )
61- .def(torch::init<
62- std::string,
63- std::vector<std::string>,
64- std::vector<std::string>,
65- std::optional<at::ScalarType>,
66- std::optional<at::ScalarType>,
67- int64_t ,
68- bool >())
65+ .def(
66+ torch::init<
67+ std::string,
68+ std::vector<std::string>,
69+ std::vector<std::string>,
70+ std::optional<at::ScalarType>,
71+ std::optional<at::ScalarType>,
72+ int64_t ,
73+ bool >())
6974 .def(" forward" , &AITModel::forward)
7075 .def(" profile" , &AITModel::profile)
7176 .def(" get_library_path" , &AITModel::libraryPath)
0 commit comments