Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion codegen/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,10 @@ def __call__(
argument_type_gen=argument_type_gen
).convert_arguments(arguments)

# +1 for the return value
num_boxed_args = len(binding_list) + 1
# This safety check does not account for optional args with default values. ET itself doesnt support default args, but when supported is added this check can be relaxed to >= # of non default arg.
safety_check = f"""ET_KERNEL_CHECK_MSG(context, stack.size() == {num_boxed_args}, InvalidProgram, /*void*/, \"Expected %\" ET_PRIsize_t \"args received %\" ET_PRIsize_t, (size_t){num_boxed_args}, stack.size());"""
# for each C++ argument, generate the conversion code
code_connector = "\n\t"
arg_connector = ", "
Expand Down Expand Up @@ -292,12 +296,13 @@ def __call__(
{indent} context.fail(torch::executor::Error::Internal);
{indent}}}"""
newline = "\n "
return "\n".join(
temp = "\n".join(
[
f"""
Kernel(
"{f.namespace}::{f.func.name}",{newline + '"' + (k + '",') if k != "default" else ""}
[]({contextArg.defn()}, Span<EValue*> stack) {{
{safety_check}
{code_connector.join(code_list)}

{exception_boundary_begin}
Expand All @@ -313,6 +318,7 @@ def __call__(
for k in used_kernel_keys
]
)
return temp


def gen_unboxing(
Expand Down Expand Up @@ -534,6 +540,7 @@ def gen_headers(
"headers": [
"#include <executorch/runtime/core/exec_aten/exec_aten.h> // at::Tensor etc.",
"#include <executorch/runtime/kernel/kernel_runtime_context.h>",
"#include <executorch/runtime/core/error.h>",
],
}
if use_aten_lib:
Expand Down
1 change: 1 addition & 0 deletions codegen/templates/RegisterCodegenUnboxedKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include <executorch/runtime/core/span.h>
#include <executorch/runtime/kernel/operator_registry.h>
#include <executorch/runtime/platform/profiler.h>
Expand Down
1 change: 1 addition & 0 deletions codegen/templates/RegisterKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
// This implements register_all_kernels() API that is declared in
// RegisterKernels.h
#include "RegisterKernels.h"
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include "${fn_header}" // Generated Function import headers

namespace torch {
Expand Down
5 changes: 4 additions & 1 deletion codegen/test/test_executorch_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,7 @@ def test_codegen_unboxed_specialized(self) -> None:
"custom_1::op_1",
"v1/7;0,1,2,3|7;0,1,2,3|7;0,1,2,3",
[](torch::executor::KernelRuntimeContext & context, Span<EValue*> stack) {
ET_KERNEL_CHECK_MSG(context, stack.size() == 1, InvalidProgram, /*void*/, \"Expected %\" ET_PRIsize_t \"args received %\" ET_PRIsize_t, (size_t)1, stack.size());
"""
+ """

Expand Down Expand Up @@ -606,6 +607,7 @@ def test_codegen_unboxed_default(self) -> None:
Kernel(
"custom_1::op_1",
[](torch::executor::KernelRuntimeContext & context, Span<EValue*> stack) {
ET_KERNEL_CHECK_MSG(context, stack.size() == 1, InvalidProgram, /*void*/, \"Expected %\" ET_PRIsize_t \"args received %\" ET_PRIsize_t, (size_t)1, stack.size());
"""
+ """

Expand All @@ -621,7 +623,6 @@ def test_codegen_unboxed_default(self) -> None:
),
"""
)

self.assertEqual(expected_str, result)

result = ComputeCodegenUnboxedKernels(
Expand All @@ -633,6 +634,7 @@ def test_codegen_unboxed_default(self) -> None:
Kernel(
"custom_1::op_1",
[](torch::executor::KernelRuntimeContext & context, Span<EValue*> stack) {
ET_KERNEL_CHECK_MSG(context, stack.size() == 1, InvalidProgram, /*void*/, "Expected %" ET_PRIsize_t "args received %" ET_PRIsize_t, (size_t)1, stack.size());
"""
+ """

Expand Down Expand Up @@ -676,6 +678,7 @@ def test_codegen_unboxed_default_kernel_key_selected(self) -> None:
Kernel(
"custom_1::op_1",
[](torch::executor::KernelRuntimeContext & context, Span<EValue*> stack) {
ET_KERNEL_CHECK_MSG(context, stack.size() == 1, InvalidProgram, /*void*/, "Expected %" ET_PRIsize_t "args received %" ET_PRIsize_t, (size_t)1, stack.size());
"""
+ """

Expand Down
2 changes: 2 additions & 0 deletions shim_et/xplat/executorch/codegen/codegen.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,7 @@ def executorch_generated_lib(
exported_deps = [
"//executorch/codegen:macros",
"//executorch/runtime/kernel:kernel_runtime_context" + aten_suffix,
"//executorch/runtime/core/exec_aten/util:tensor_util" + aten_suffix,
],
feature = feature,
)
Expand Down Expand Up @@ -933,6 +934,7 @@ def executorch_generated_lib(
exported_deps = [
"//executorch/runtime/core/exec_aten:lib" + aten_suffix,
"//executorch/runtime/kernel:kernel_runtime_context" + aten_suffix,
"//executorch/runtime/core/exec_aten/util:tensor_util" + aten_suffix,
],
xplat_deps = xplat_deps,
fbcode_deps = fbcode_deps,
Expand Down
Loading