
Commit 0b3cfaf

Torchscript C++ Interface (#110)
* Began writing C++ interface test.
* Switched to a CMake workflow.
* Making progress getting the test to link.
* We are having trouble linking the loaded library.
* Renamed files, linked in Python 3 appropriately.
* Successfully linked program and loaded in scripted module.
* Working on a nontrivial JITScript test.
* Call needs to be updated to work on a CUDA backend.
* Forward call works!
* Ready for linting, testing, and merge.
* Linted.
* Deleted README tutorial file, already covered by import_test.py.
* Added rpath linking, now need to fix the Python linkage issue.
* Removed print statements.
* Linted.
1 parent c8fd4f4 commit 0b3cfaf

12 files changed

Lines changed: 227 additions & 127 deletions
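At a glance: after this commit, an oeq.TensorProduct can be compiled with torch.jit.script, saved, and later executed from C++ through the renamed libtorch_tp_jit extension. A minimal export-side sketch, assuming the e3nn-style problem configuration from OEQ's README (the irreps and instruction list are illustrative, not taken from this commit):

```python
# Sketch: script an OEQ tensor product and save it for C++ consumption.
import torch
import openequivariance as oeq

X_ir, Y_ir, Z_ir = oeq.Irreps("32x5e"), oeq.Irreps("1x3e"), oeq.Irreps("32x5e")
problem = oeq.TPProblem(
    X_ir, Y_ir, Z_ir,
    [(0, 0, 0, "uvu", True)],          # single "uvu" instruction, illustrative
    shared_weights=False, internal_weights=False,
)
tp = oeq.TensorProduct(problem)

scripted = torch.jit.script(tp)        # compile the module to TorchScript
scripted.save("module_oeq.pt")         # loadable from C++ via torch::jit::load
```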


examples/readme_tutorial.py

Lines changed: 0 additions & 83 deletions
This file was deleted.

openequivariance/__init__.py

Lines changed: 15 additions & 8 deletions
```diff
@@ -10,14 +10,6 @@
 )
 from openequivariance.implementations.utils import torch_to_oeq_dtype
 
-__all__ = [
-    "TPProblem",
-    "Irreps",
-    "TensorProduct",
-    "TensorProductConv",
-    "torch_to_oeq_dtype",
-]
-
 __version__ = version("openequivariance")
 
 
@@ -30,3 +22,18 @@ def _check_package_editable():
 
 
 _editable_install_output_path = Path(__file__).parent.parent / "outputs"
+
+
+def torch_ext_so_path():
+    return openequivariance.extlib.torch_module.__file__
+
+
+__all__ = [
+    "TPProblem",
+    "Irreps",
+    "TensorProduct",
+    "TensorProductConv",
+    "torch_to_oeq_dtype",
+    "_check_package_editable",
+    "torch_ext_so_path",
+]
```
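The new torch_ext_so_path() helper exposes where the compiled extension lives on disk, which a standalone C++ build needs in order to link against the op registrations. A hypothetical build driver, assuming the OEQ_EXTLIB cache variable consumed by the CMakeLists.txt added below (the build layout here is illustrative):

```python
# Hypothetical driver: hand the extension .so to the C++ test's CMake
# configure step. -DOEQ_EXTLIB matches ${OEQ_EXTLIB} in the CMakeLists.txt
# shown later in this diff.
import subprocess
import torch.utils
import openequivariance as oeq

so_path = oeq.torch_ext_so_path()  # path of the built libtorch_tp_jit module
subprocess.run(
    [
        "cmake", "-B", "build", "-S", ".",
        f"-DCMAKE_PREFIX_PATH={torch.utils.cmake_prefix_path}",  # for find_package(Torch)
        f"-DOEQ_EXTLIB={so_path}",
    ],
    check=True,
)
subprocess.run(["cmake", "--build", "build"], check=True)
```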

openequivariance/extension/torch_tp_jit.cpp renamed to openequivariance/extension/libtorch_tp_jit.cpp

Lines changed: 9 additions & 9 deletions
```diff
@@ -416,7 +416,7 @@ tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> jit_conv_doubl
 
 // ===========================================================
 
-TORCH_LIBRARY_FRAGMENT(torch_tp_jit, m) {
+TORCH_LIBRARY_FRAGMENT(libtorch_tp_jit, m) {
     m.class_<TorchJITProduct>("TorchJITProduct")
         .def(torch::init<string, Map_t, Map_t, Map_t, Map_t>())
         .def("__obj_flatten__", &TorchJITProduct::__obj_flatten__)
@@ -437,9 +437,9 @@ TORCH_LIBRARY_FRAGMENT(torch_tp_jit, m) {
             return c10::make_intrusive<TorchJITProduct>(get<0>(state), get<1>(state), get<2>(state), get<3>(state), get<4>(state));
         });
 
-    m.def("jit_tp_forward(__torch__.torch.classes.torch_tp_jit.TorchJITProduct jit, Tensor L1_in, Tensor L2_in, Tensor W) -> Tensor");
-    m.def("jit_tp_backward(__torch__.torch.classes.torch_tp_jit.TorchJITProduct jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad) -> (Tensor, Tensor, Tensor)");
-    m.def("jit_tp_double_backward(__torch__.torch.classes.torch_tp_jit.TorchJITProduct jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad) -> (Tensor, Tensor, Tensor, Tensor)");
+    m.def("jit_tp_forward(__torch__.torch.classes.libtorch_tp_jit.TorchJITProduct jit, Tensor L1_in, Tensor L2_in, Tensor W) -> Tensor");
+    m.def("jit_tp_backward(__torch__.torch.classes.libtorch_tp_jit.TorchJITProduct jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad) -> (Tensor, Tensor, Tensor)");
+    m.def("jit_tp_double_backward(__torch__.torch.classes.libtorch_tp_jit.TorchJITProduct jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad) -> (Tensor, Tensor, Tensor, Tensor)");
 
 
     m.class_<TorchJITConv>("TorchJITConv")
@@ -462,12 +462,12 @@ TORCH_LIBRARY_FRAGMENT(torch_tp_jit, m) {
             return c10::make_intrusive<TorchJITConv>(get<0>(state), get<1>(state), get<2>(state), get<3>(state), get<4>(state));
         });
 
-    m.def("jit_conv_forward(__torch__.torch.classes.torch_tp_jit.TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> Tensor");
-    m.def("jit_conv_backward(__torch__.torch.classes.torch_tp_jit.TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor)");
-    m.def("jit_conv_double_backward(__torch__.torch.classes.torch_tp_jit.TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor, Tensor)");
+    m.def("jit_conv_forward(__torch__.torch.classes.libtorch_tp_jit.TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> Tensor");
+    m.def("jit_conv_backward(__torch__.torch.classes.libtorch_tp_jit.TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor)");
+    m.def("jit_conv_double_backward(__torch__.torch.classes.libtorch_tp_jit.TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor, Tensor)");
 };
 
-TORCH_LIBRARY_IMPL(torch_tp_jit, CUDA, m) {
+TORCH_LIBRARY_IMPL(libtorch_tp_jit, CUDA, m) {
     m.impl("jit_tp_forward", &jit_tp_forward);
     m.impl("jit_tp_backward", &jit_tp_backward);
    m.impl("jit_tp_double_backward", &jit_tp_double_backward);
@@ -477,4 +477,4 @@ TORCH_LIBRARY_IMPL(torch_tp_jit, CUDA, m) {
     m.impl("jit_conv_double_backward", &jit_conv_double_backward);
 };
 
-PYBIND11_MODULE(torch_tp_jit, m) {}
+PYBIND11_MODULE(libtorch_tp_jit, m) {}
```
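With the rename, callers address the ops under torch.ops.libtorch_tp_jit rather than torch.ops.torch_tp_jit. A sketch of invoking the forward op by hand: importing openequivariance triggers the cpp_extension build/load shown further down, tp.internal is the TorchJITProduct script object (see TensorProduct.py at the end of this diff), and the .dim and .weight_numel attributes are assumed to mirror e3nn's conventions:

```python
# Sketch: call the renamed custom op directly on a CUDA device.
import torch
import openequivariance as oeq  # import builds/loads libtorch_tp_jit

X_ir, Y_ir, Z_ir = oeq.Irreps("32x5e"), oeq.Irreps("1x3e"), oeq.Irreps("32x5e")
problem = oeq.TPProblem(X_ir, Y_ir, Z_ir, [(0, 0, 0, "uvu", True)],
                        shared_weights=False, internal_weights=False)
tp = oeq.TensorProduct(problem)

batch = 1000
L1 = torch.randn(batch, X_ir.dim, device="cuda")
L2 = torch.randn(batch, Y_ir.dim, device="cuda")
W = torch.randn(batch, problem.weight_numel, device="cuda")

out = torch.ops.libtorch_tp_jit.jit_tp_forward(tp.internal, L1, L2, W)
```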
CMakeLists.txt

Lines changed: 9 additions & 0 deletions

```diff
@@ -0,0 +1,9 @@
+cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
+project(test_oeq_jitscript_load)
+
+find_package(Torch REQUIRED)
+
+add_executable(load_jitscript load_jitscript.cpp)
+target_link_libraries(load_jitscript "${TORCH_LIBRARIES}")
+target_link_libraries(load_jitscript -Wl,--no-as-needed "${OEQ_EXTLIB}")
+set_property(TARGET load_jitscript PROPERTY CXX_STANDARD 17)
```
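A note on the second target_link_libraries line: load_jitscript references no symbols from the OEQ extension directly, so an as-needed linker would drop the dependency entirely. -Wl,--no-as-needed keeps it, forcing the .so to load at program startup, and loading it is what runs the TORCH_LIBRARY_FRAGMENT registrations so that torch::jit::load can deserialize modules calling the libtorch_tp_jit ops.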
load_jitscript.cpp

Lines changed: 62 additions & 0 deletions

```diff
@@ -0,0 +1,62 @@
+#include <torch/script.h>
+
+#include <iostream>
+#include <memory>
+
+/*
+ * This program takes in two JITScript modules that execute
+ * a tensor product in FP32 precision.
+ * The first module is compiled from e3nn, the second is
+ * OEQ's compiled module. The program checks that the
+ * two outputs are comparable.
+ */
+
+int main(int argc, const char* argv[]) {
+    if (argc != 7) {
+        std::cerr << "usage: load_jitscript "
+                  << "<path-to-e3nn-module> "
+                  << "<path-to-oeq-module> "
+                  << "<L1_dim> "
+                  << "<L2_dim> "
+                  << "<weight_numel> "
+                  << "<batch_size> "
+                  << std::endl;
+
+        return 1;
+    }
+
+    int64_t L1_dim = std::stoi(argv[3]);
+    int64_t L2_dim = std::stoi(argv[4]);
+    int64_t weight_numel = std::stoi(argv[5]);
+    int64_t batch_size = std::stoi(argv[6]);
+
+    torch::Device device(torch::kCUDA);
+    std::vector<torch::jit::IValue> inputs;
+    inputs.push_back(torch::randn({batch_size, L1_dim}, device));
+    inputs.push_back(torch::randn({batch_size, L2_dim}, device));
+    inputs.push_back(torch::randn({batch_size, weight_numel}, device));
+
+    torch::jit::script::Module module_e3nn, module_oeq;
+    try {
+        module_e3nn = torch::jit::load(argv[1]);
+        module_oeq = torch::jit::load(argv[2]);
+    }
+    catch (const c10::Error& e) {
+        std::cerr << "error loading script module" << std::endl;
+        return 1;
+    }
+
+    module_e3nn.to(device);
+    module_oeq.to(device);
+
+    at::Tensor output_e3nn = module_e3nn.forward(inputs).toTensor();
+    at::Tensor output_oeq = module_oeq.forward(inputs).toTensor();
+
+    if (at::allclose(output_e3nn, output_oeq, 1e-5, 1e-5)) {
+        return 0;
+    }
+    else {
+        std::cerr << "torch.allclose returned FALSE comparing model outputs." << std::endl;
+        return 1;
+    }
+}
```
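The e3nn reference module this program loads can be produced the same way as the OEQ module in the earlier export sketch; a hedged counterpart (the test suite's actual driver is not part of this diff, and the configuration remains illustrative):

```python
# Sketch: the e3nn half of the module pair that load_jitscript compares.
# Uses the same illustrative configuration as the earlier OEQ export sketch.
import torch
import e3nn.o3 as o3

tp_e3nn = o3.TensorProduct(
    "32x5e", "1x3e", "32x5e",
    [(0, 0, 0, "uvu", True)],
    shared_weights=False, internal_weights=False,
)
torch.jit.script(tp_e3nn).save("module_e3nn.pt")
```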

openequivariance/extlib/__init__.py

Lines changed: 31 additions & 6 deletions
```diff
@@ -1,9 +1,11 @@
 # ruff: noqa : F401, E402
+import sys
 import os
 import warnings
 from pathlib import Path
 
 from openequivariance.benchmark.logging_utils import getLogger
+from distutils import sysconfig
 
 oeq_root = str(Path(__file__).parent.parent)
 
@@ -12,6 +14,18 @@
 torch_module, generic_module = None, None
 postprocess_kernel = lambda kernel: kernel  # noqa : E731
 
+try:
+    python_lib_dir = sysconfig.get_config_var("LIBDIR")
+    major, minor = sys.version_info.major, sys.version_info.minor
+    python_lib_name = f"python{major}.{minor}"
+
+except Exception as e:
+    print("Error while retrieving Python library information:", file=sys.stderr)
+    print(e, file=sys.stderr)
+    print("Syconfig variable list:", file=sys.stderr)
+    print(sysconfig.get_config_vars(), file=sys.stderr)
+    exit(1)
+
 if not build_ext:
     from openequivariance.extlib.generic_module import (
         GenericTensorProductImpl,
@@ -32,14 +46,23 @@
 
 extra_cflags = ["-O3"]
 generic_sources = ["generic_module.cpp"]
-torch_sources = ["torch_tp_jit.cpp"]
+torch_sources = ["libtorch_tp_jit.cpp"]
+
+include_dirs, extra_link_args = (
+    ["util"],
+    [
+        f"-Wl,--no-as-needed,-rpath,{python_lib_dir}",
+        f"-L{python_lib_dir}",
+        f"-l{python_lib_name}",
+    ],
+)
 
-include_dirs, extra_link_args = ["util"], None
 if torch.version.cuda:
-    extra_link_args = ["-Wl,--no-as-needed", "-lcuda", "-lcudart", "-lnvrtc"]
+    extra_link_args.extend(["-lcuda", "-lcudart", "-lnvrtc"])
 
     try:
-        cuda_libs = library_paths("cuda")[1]
+        torch_libs, cuda_libs = library_paths("cuda")
+        extra_link_args.append("-Wl,-rpath," + torch_libs)
         extra_link_args.append("-L" + cuda_libs)
         if os.path.exists(cuda_libs + "/stubs"):
             extra_link_args.append("-L" + cuda_libs + "/stubs")
@@ -48,7 +71,9 @@
 
     extra_cflags.append("-DCUDA_BACKEND")
 elif torch.version.hip:
-    extra_link_args = ["-Wl,--no-as-needed", "-lhiprtc"]
+    extra_link_args.extend(["-lhiprtc"])
+    torch_libs = library_paths("cuda")[0]
+    extra_link_args.append("-Wl,-rpath," + torch_libs)
 
     def postprocess(kernel):
         kernel = kernel.replace("__syncwarp();", "__threadfence_block();")
@@ -72,7 +97,7 @@ def postprocess(kernel):
 
 try:
     torch_module = torch.utils.cpp_extension.load(
-        "torch_tp_jit",
+        "libtorch_tp_jit",
         torch_sources,
         extra_cflags=extra_cflags,
         extra_include_paths=include_dirs,
```
openequivariance/implementations/LoopUnrollTP.py

Lines changed: 8 additions & 8 deletions
```diff
@@ -103,7 +103,7 @@ def generate_double_backward_schedule(warps_per_block):
         global torch
         import torch
 
-        internal_cls = torch.classes.torch_tp_jit.TorchJITProduct
+        internal_cls = torch.classes.libtorch_tp_jit.TorchJITProduct
     else:
         internal_cls = extlib.JITTPImpl
 
@@ -142,7 +142,7 @@ def register_torch_fakes(cls):
        global torch
        import torch
 
-        @torch._library.register_fake_class("torch_tp_jit::TorchJITProduct")
+        @torch._library.register_fake_class("libtorch_tp_jit::TorchJITProduct")
         class TorchJITProduct:
             def __init__(
                 self,
@@ -198,19 +198,19 @@ def backward_rawptr(
            ):
                pass
 
-        @torch.library.register_fake("torch_tp_jit::jit_tp_forward")
+        @torch.library.register_fake("libtorch_tp_jit::jit_tp_forward")
         def fake_forward(jit, L1_in, L2_in, W):
             return L1_in.new_empty(
                 L1_in.shape[0], jit.wrapped_obj.kernel_dims["L3_dim"]
             )
 
-        @torch.library.register_fake("torch_tp_jit::jit_tp_backward")
+        @torch.library.register_fake("libtorch_tp_jit::jit_tp_backward")
         def fake_backward(jit, L1_in, L2_in, W, L3_grad):
             return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W)
 
     @classmethod
     def register_autograd(cls):
-        backward_op = torch.ops.torch_tp_jit.jit_tp_backward
+        backward_op = torch.ops.libtorch_tp_jit.jit_tp_backward
 
         def setup_context(ctx, inputs, output):
             ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights = inputs
@@ -222,20 +222,20 @@ def backward(ctx, grad_output):
             return None, L1_grad, L2_grad, W_grad
 
         torch.library.register_autograd(
-            "torch_tp_jit::jit_tp_forward", backward, setup_context=setup_context
+            "libtorch_tp_jit::jit_tp_forward", backward, setup_context=setup_context
         )
 
         def setup_context_double_backward(ctx, inputs, output):
             ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad = inputs
 
         def double_backward(ctx, E, F, G):
-            result = torch.ops.torch_tp_jit.jit_tp_double_backward(
+            result = torch.ops.libtorch_tp_jit.jit_tp_double_backward(
                 ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad, E, F, G
             )
             return None, result[0], result[1], result[2], result[3]
 
         torch.library.register_autograd(
-            "torch_tp_jit::jit_tp_backward",
+            "libtorch_tp_jit::jit_tp_backward",
             double_backward,
             setup_context=setup_context_double_backward,
         )
```
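Beyond the namespace rename, these registrations are what allow torch.compile to trace the custom ops with fake tensors and route gradients through jit_tp_backward (and its double-backward). A sketch of exercising that path, with the same illustrative configuration as the earlier sketches:

```python
# Sketch: the fake-class/fake-op registrations above let torch.compile
# trace the ops; autograd dispatches libtorch_tp_jit::jit_tp_backward.
import torch
import openequivariance as oeq

problem = oeq.TPProblem(
    oeq.Irreps("32x5e"), oeq.Irreps("1x3e"), oeq.Irreps("32x5e"),
    [(0, 0, 0, "uvu", True)],
    shared_weights=False, internal_weights=False,
)
tp = torch.compile(oeq.TensorProduct(problem))

batch = 1000
L1 = torch.randn(batch, 352, device="cuda", requires_grad=True)  # 32x5e: 32*11
L2 = torch.randn(batch, 7, device="cuda", requires_grad=True)    # 1x3e: 7
W = torch.randn(batch, problem.weight_numel, device="cuda", requires_grad=True)

out = tp(L1, L2, W)          # traced via the registered fake ops
out.sum().backward()         # gradient flows through jit_tp_backward
```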

openequivariance/implementations/TensorProduct.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -20,4 +20,4 @@ def name():
     def forward(
         self, L1: torch.Tensor, L2: torch.Tensor, W: torch.Tensor
     ) -> torch.Tensor:
-        return torch.ops.torch_tp_jit.jit_tp_forward(self.internal, L1, L2, W)
+        return torch.ops.libtorch_tp_jit.jit_tp_forward(self.internal, L1, L2, W)
```
