@@ -416,7 +416,7 @@ tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> jit_conv_doubl
416416
417417// ===========================================================
418418
419- TORCH_LIBRARY_FRAGMENT (torch_tp_jit , m) {
419+ TORCH_LIBRARY_FRAGMENT (libtorch_tp_jit , m) {
420420 m.class_ <TorchJITProduct>(" TorchJITProduct" )
421421 .def (torch::init<string, Map_t, Map_t, Map_t, Map_t>())
422422 .def (" __obj_flatten__" , &TorchJITProduct::__obj_flatten__)
@@ -437,9 +437,9 @@ TORCH_LIBRARY_FRAGMENT(torch_tp_jit, m) {
437437 return c10::make_intrusive<TorchJITProduct>(get<0 >(state), get<1 >(state), get<2 >(state), get<3 >(state), get<4 >(state));
438438 });
439439
440- m.def (" jit_tp_forward(__torch__.torch.classes.torch_tp_jit .TorchJITProduct jit, Tensor L1_in, Tensor L2_in, Tensor W) -> Tensor" );
441- m.def (" jit_tp_backward(__torch__.torch.classes.torch_tp_jit .TorchJITProduct jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad) -> (Tensor, Tensor, Tensor)" );
442- m.def (" jit_tp_double_backward(__torch__.torch.classes.torch_tp_jit .TorchJITProduct jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad) -> (Tensor, Tensor, Tensor, Tensor)" );
440+ m.def (" jit_tp_forward(__torch__.torch.classes.libtorch_tp_jit .TorchJITProduct jit, Tensor L1_in, Tensor L2_in, Tensor W) -> Tensor" );
441+ m.def (" jit_tp_backward(__torch__.torch.classes.libtorch_tp_jit .TorchJITProduct jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad) -> (Tensor, Tensor, Tensor)" );
442+ m.def (" jit_tp_double_backward(__torch__.torch.classes.libtorch_tp_jit .TorchJITProduct jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad) -> (Tensor, Tensor, Tensor, Tensor)" );
443443
444444
445445 m.class_ <TorchJITConv>(" TorchJITConv" )
@@ -462,12 +462,12 @@ TORCH_LIBRARY_FRAGMENT(torch_tp_jit, m) {
462462 return c10::make_intrusive<TorchJITConv>(get<0 >(state), get<1 >(state), get<2 >(state), get<3 >(state), get<4 >(state));
463463 });
464464
465- m.def (" jit_conv_forward(__torch__.torch.classes.torch_tp_jit .TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> Tensor" );
466- m.def (" jit_conv_backward(__torch__.torch.classes.torch_tp_jit .TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor)" );
467- m.def (" jit_conv_double_backward(__torch__.torch.classes.torch_tp_jit .TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor, Tensor)" );
465+ m.def (" jit_conv_forward(__torch__.torch.classes.libtorch_tp_jit .TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> Tensor" );
466+ m.def (" jit_conv_backward(__torch__.torch.classes.libtorch_tp_jit .TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor)" );
467+ m.def (" jit_conv_double_backward(__torch__.torch.classes.libtorch_tp_jit .TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor, Tensor)" );
468468};
469469
470- TORCH_LIBRARY_IMPL (torch_tp_jit , CUDA, m) {
470+ TORCH_LIBRARY_IMPL (libtorch_tp_jit , CUDA, m) {
471471 m.impl (" jit_tp_forward" , &jit_tp_forward);
472472 m.impl (" jit_tp_backward" , &jit_tp_backward);
473473 m.impl (" jit_tp_double_backward" , &jit_tp_double_backward);
@@ -477,4 +477,4 @@ TORCH_LIBRARY_IMPL(torch_tp_jit, CUDA, m) {
477477 m.impl (" jit_conv_double_backward" , &jit_conv_double_backward);
478478};
479479
480- PYBIND11_MODULE (torch_tp_jit , m) {}
480+ PYBIND11_MODULE (libtorch_tp_jit , m) {}
0 commit comments