Skip to content

Commit 20acf99

Browse files
[mypyc] Move API table definitions to .c files (python#21183)
lib-rt API tables are defined as static variables in lib-rt headers, so each translation unit (TU) gets its own independent instance of these variables. This becomes a problem in multi-file compilation mode when using `BytesWriter.write`: mypyc translates this function to `CPyBytesWriter_Write`, which is defined in `byteswriter_extra_ops.c`. With multi-file compilation, `byteswriter_extra_ops.c` is compiled as its own translation unit and gets linked with the C extension compiled from Python files. The C extension TU copies the `librt.strings` capsule contents into its own copy of the global table, but because the table is static, the copied contents are not visible in the `byteswriter_extra_ops.c` TU, whose table stays zero-initialized. This results in a segfault with the following call chain: [`CPyBytesWriter_Write`](https://github.com/python/mypy/blob/master/mypyc/lib-rt/byteswriter_extra_ops.c#L8) -> [`CPyBytesWriter_EnsureSize`](https://github.com/python/mypy/blob/master/mypyc/lib-rt/byteswriter_extra_ops.c#L20) -> [`LibRTStrings_ByteWriter_grow_buffer_internal`](https://github.com/python/mypy/blob/master/mypyc/lib-rt/byteswriter_extra_ops.h#L26) -> [`LibRTStrings_API[5]`](https://github.com/python/mypy/blob/master/mypyc/lib-rt/strings/librt_strings.h#L49) -> a call through a NULL function pointer, because the table is all zeros. To fix this, declare the tables as `extern` and define them in .c files so there's only one global version of each variable. There's one new .c file for each lib-rt module, which is compiled or included when mypyc detects a dependency on that module. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent ea8efbd commit 20acf99

25 files changed

Lines changed: 621 additions & 300 deletions

mypyc/build.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,7 @@ def build_using_shared_lib(
357357
deps: list[str],
358358
build_dir: str,
359359
extra_compile_args: list[str],
360+
extra_include_dirs: list[str],
360361
) -> list[Extension]:
361362
"""Produce the list of extension modules when a shared library is needed.
362363
@@ -373,7 +374,7 @@ def build_using_shared_lib(
373374
get_extension()(
374375
shared_lib_name(group_name),
375376
sources=cfiles,
376-
include_dirs=[include_dir(), build_dir],
377+
include_dirs=[include_dir(), build_dir] + extra_include_dirs,
377378
depends=deps,
378379
extra_compile_args=extra_compile_args,
379380
)
@@ -399,7 +400,10 @@ def build_using_shared_lib(
399400

400401

401402
def build_single_module(
402-
sources: list[BuildSource], cfiles: list[str], extra_compile_args: list[str]
403+
sources: list[BuildSource],
404+
cfiles: list[str],
405+
extra_compile_args: list[str],
406+
extra_include_dirs: list[str],
403407
) -> list[Extension]:
404408
"""Produce the list of extension modules for a standalone extension.
405409
@@ -409,7 +413,7 @@ def build_single_module(
409413
get_extension()(
410414
sources[0].module,
411415
sources=cfiles,
412-
include_dirs=[include_dir()],
416+
include_dirs=[include_dir()] + extra_include_dirs,
413417
extra_compile_args=extra_compile_args,
414418
)
415419
]
@@ -513,7 +517,9 @@ def mypyc_build(
513517
*,
514518
separate: bool | list[tuple[list[str], str | None]] = False,
515519
only_compile_paths: Iterable[str] | None = None,
516-
skip_cgen_input: tuple[list[list[tuple[str, str]]], list[str]] | None = None,
520+
skip_cgen_input: (
521+
tuple[list[list[tuple[str, str]]], list[tuple[str, list[str], bool]]] | None
522+
) = None,
517523
always_use_shared_lib: bool = False,
518524
) -> tuple[emitmodule.Groups, list[tuple[list[str], list[str]]], list[SourceDep]]:
519525
"""Do the front and middle end of mypyc building, producing and writing out C source."""
@@ -547,7 +553,10 @@ def mypyc_build(
547553
write_file(os.path.join(compiler_options.target_dir, "ops.txt"), ops_text)
548554
else:
549555
group_cfiles = skip_cgen_input[0]
550-
source_deps = [SourceDep(d) for d in skip_cgen_input[1]]
556+
source_deps = [
557+
SourceDep(path, include_dirs=dirs, internal=internal)
558+
for (path, dirs, internal) in skip_cgen_input[1]
559+
]
551560

552561
# Write out the generated C and collect the files for each group
553562
# Should this be here??
@@ -664,7 +673,9 @@ def mypycify(
664673
strip_asserts: bool = False,
665674
multi_file: bool = False,
666675
separate: bool | list[tuple[list[str], str | None]] = False,
667-
skip_cgen_input: tuple[list[list[tuple[str, str]]], list[str]] | None = None,
676+
skip_cgen_input: (
677+
tuple[list[list[tuple[str, str]]], list[tuple[str, list[str], bool]]] | None
678+
) = None,
668679
target_dir: str | None = None,
669680
include_runtime_files: bool | None = None,
670681
strict_dunder_typing: bool = False,
@@ -781,12 +792,19 @@ def mypycify(
781792
# runtime library in. Otherwise it just gets #included to save on
782793
# compiler invocations.
783794
shared_cfilenames = []
795+
include_dirs = set()
784796
if not compiler_options.include_runtime_files:
785797
# Collect all files to copy: runtime files + conditional source files
786798
files_to_copy = list(RUNTIME_C_FILES)
787799
for source_dep in source_deps:
788800
files_to_copy.append(source_dep.path)
789801
files_to_copy.append(source_dep.get_header())
802+
include_dirs.update(source_dep.include_dirs)
803+
804+
if compiler_options.depends_on_librt_internal:
805+
files_to_copy.append("internal/librt_internal_api.h")
806+
files_to_copy.append("internal/librt_internal_api.c")
807+
include_dirs.add("internal")
790808

791809
# Copy all files
792810
for name in files_to_copy:
@@ -797,6 +815,7 @@ def mypycify(
797815
shared_cfilenames.append(rt_file)
798816

799817
extensions = []
818+
extra_include_dirs = [os.path.join(include_dir(), dir) for dir in include_dirs]
800819
for (group_sources, lib_name), (cfilenames, deps) in zip(groups, group_cfilenames):
801820
if lib_name:
802821
extensions.extend(
@@ -807,11 +826,14 @@ def mypycify(
807826
deps,
808827
build_dir,
809828
cflags,
829+
extra_include_dirs,
810830
)
811831
)
812832
else:
813833
extensions.extend(
814-
build_single_module(group_sources, cfilenames + shared_cfilenames, cflags)
834+
build_single_module(
835+
group_sources, cfilenames + shared_cfilenames, cflags, extra_include_dirs
836+
)
815837
)
816838

817839
if install_librt:

mypyc/codegen/emitmodule.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -436,13 +436,23 @@ def load_scc_from_cache(
436436
return modules
437437

438438

439-
def collect_source_dependencies(modules: dict[str, ModuleIR]) -> set[SourceDep]:
440-
"""Collect all SourceDep dependencies from all modules."""
439+
def collect_source_dependencies(
440+
modules: dict[str, ModuleIR], *, internal: bool = True
441+
) -> set[SourceDep]:
442+
"""Collect all SourceDep dependencies from all modules.
443+
444+
If internal is set to False, returns only the dependencies that can be exported to C extensions
445+
dependent on the one currently being compiled.
446+
"""
441447
source_deps: set[SourceDep] = set()
442448
for module in modules.values():
443449
for dep in module.dependencies:
444450
if isinstance(dep, SourceDep):
445-
source_deps.add(dep)
451+
if internal == dep.internal:
452+
source_deps.add(dep)
453+
else:
454+
capsule_dep = dep.internal_dep() if internal else dep.external_dep()
455+
source_deps.add(capsule_dep)
446456
return source_deps
447457

448458

@@ -585,6 +595,8 @@ def generate_c_for_modules(self) -> list[tuple[str, str]]:
585595
source_deps = collect_source_dependencies(self.modules)
586596
for source_dep in sorted(source_deps, key=lambda d: d.path):
587597
base_emitter.emit_line(f'#include "{source_dep.path}"')
598+
if self.compiler_options.depends_on_librt_internal:
599+
base_emitter.emit_line('#include "internal/librt_internal_api.c"')
588600
base_emitter.emit_line(f'#include "__native{self.short_group_suffix}.h"')
589601
base_emitter.emit_line(f'#include "__native_internal{self.short_group_suffix}.h"')
590602
emitter = base_emitter
@@ -634,26 +646,27 @@ def generate_c_for_modules(self) -> list[tuple[str, str]]:
634646
ext_declarations.emit_line(f"#define MYPYC_NATIVE{self.group_suffix}_H")
635647
ext_declarations.emit_line("#include <Python.h>")
636648
ext_declarations.emit_line("#include <CPy.h>")
637-
if self.compiler_options.depends_on_librt_internal:
638-
ext_declarations.emit_line("#include <internal/librt_internal.h>")
639-
if any(LIBRT_BASE64 in mod.dependencies for mod in self.modules.values()):
640-
ext_declarations.emit_line("#include <base64/librt_base64.h>")
641-
if any(LIBRT_STRINGS in mod.dependencies for mod in self.modules.values()):
642-
ext_declarations.emit_line("#include <strings/librt_strings.h>")
643-
if any(LIBRT_TIME in mod.dependencies for mod in self.modules.values()):
644-
ext_declarations.emit_line("#include <time/librt_time.h>")
645-
if any(LIBRT_VECS in mod.dependencies for mod in self.modules.values()):
646-
ext_declarations.emit_line("#include <vecs/librt_vecs.h>")
647-
# Include headers for conditional source files
648-
source_deps = collect_source_dependencies(self.modules)
649-
for source_dep in sorted(source_deps, key=lambda d: d.path):
650-
ext_declarations.emit_line(f'#include "{source_dep.get_header()}"')
649+
650+
def emit_dep_headers(decls: Emitter, internal: bool) -> None:
651+
suffix = "_api" if internal else ""
652+
if self.compiler_options.depends_on_librt_internal:
653+
decls.emit_line(f'#include "internal/librt_internal{suffix}.h"')
654+
# Include headers for conditional source files
655+
source_deps = collect_source_dependencies(self.modules, internal=internal)
656+
for source_dep in sorted(source_deps, key=lambda d: d.path):
657+
decls.emit_line(f'#include "{source_dep.get_header()}"')
658+
659+
emit_dep_headers(ext_declarations, False)
651660

652661
declarations = Emitter(self.context)
653662
declarations.emit_line(f"#ifndef MYPYC_LIBRT_INTERNAL{self.group_suffix}_H")
654663
declarations.emit_line(f"#define MYPYC_LIBRT_INTERNAL{self.group_suffix}_H")
655664
declarations.emit_line("#include <Python.h>")
656665
declarations.emit_line("#include <CPy.h>")
666+
667+
if not self.compiler_options.include_runtime_files:
668+
emit_dep_headers(declarations, True)
669+
657670
declarations.emit_line(f'#include "__native{self.short_group_suffix}.h"')
658671
declarations.emit_line()
659672
declarations.emit_line("int CPyGlobalsInit(void);")

mypyc/ir/deps.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from typing import Final
24

35

@@ -17,17 +19,48 @@ def __eq__(self, other: object) -> bool:
1719
def __hash__(self) -> int:
1820
return hash(("Capsule", self.name))
1921

22+
def internal_dep(self) -> SourceDep:
23+
"""Internal source dependency of the capsule that should only be included in the C extensions
24+
that depend on the capsule, e.g. by importing a type or function from the capsule.
25+
"""
26+
module = self.name.split(".")[-1]
27+
return SourceDep(f"{module}/librt_{module}_api.c", include_dirs=[module])
28+
29+
# TODO: This SourceDep is really only used for its associated header so it would make more sense
30+
# to add a separate type. Alternatively, see if this can be removed altogether if we move the
31+
# definitions that depend on this header from the external header of the C extension.
32+
def external_dep(self) -> SourceDep:
33+
"""External source dependency of the capsule that may be included in external headers of C
34+
extensions that depend on the capsule.
35+
36+
The external headers of the C extensions are included by other C extensions that don't
37+
necessarily import the capsule. However, they may need type definitions from the capsule
38+
for types that are used in the exports table of the included C extensions.
39+
40+
Only the external header should be included in this case because if the other C extension
41+
doesn't import the capsule, it also doesn't include the definition for its API table and
42+
including the internal header would result in undefined symbols.
43+
"""
44+
module = self.name.split(".")[-1]
45+
return SourceDep(f"{module}/librt_{module}.c", include_dirs=[module], internal=False)
46+
2047

2148
class SourceDep:
2249
"""Defines a C source file that a primitive may require.
2350
2451
Each source file must also have a corresponding .h file (replace .c with .h)
2552
that gets implicitly #included if the source is used.
53+
include_dirs are passed to the C compiler when the file is compiled as a
54+
shared library separate from the C extension.
2655
"""
2756

28-
def __init__(self, path: str) -> None:
57+
def __init__(
58+
self, path: str, *, include_dirs: list[str] | None = None, internal: bool = True
59+
) -> None:
2960
# Relative path from mypyc/lib-rt, e.g. 'bytes_extra_ops.c'
3061
self.path: Final = path
62+
self.include_dirs: Final = include_dirs or []
63+
self.internal: Final = internal
3164

3265
def __repr__(self) -> str:
3366
return f"SourceDep(path={self.path!r})"

mypyc/ir/module_ir.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,13 @@ def serialize(self) -> JsonDict:
4141
if isinstance(dep, Capsule):
4242
serialized_deps.append({"type": "Capsule", "name": dep.name})
4343
elif isinstance(dep, SourceDep):
44-
serialized_deps.append({"type": "SourceDep", "path": dep.path})
44+
source_dep: JsonDict = {
45+
"type": "SourceDep",
46+
"path": dep.path,
47+
"include_dirs": dep.include_dirs,
48+
"internal": dep.internal,
49+
}
50+
serialized_deps.append(source_dep)
4551

4652
return {
4753
"fullname": self.fullname,
@@ -69,7 +75,13 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> ModuleIR:
6975
if dep_dict["type"] == "Capsule":
7076
deps.add(Capsule(dep_dict["name"]))
7177
elif dep_dict["type"] == "SourceDep":
72-
deps.add(SourceDep(dep_dict["path"]))
78+
deps.add(
79+
SourceDep(
80+
dep_dict["path"],
81+
include_dirs=dep_dict["include_dirs"],
82+
internal=dep_dict["internal"],
83+
)
84+
)
7385
module.dependencies = deps
7486

7587
return module

mypyc/lib-rt/base64/librt_base64.h

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -7,44 +7,4 @@
77
#define LIBRT_BASE64_API_VERSION 2
88
#define LIBRT_BASE64_API_LEN 4
99

10-
static void *LibRTBase64_API[LIBRT_BASE64_API_LEN];
11-
12-
#define LibRTBase64_ABIVersion (*(int (*)(void)) LibRTBase64_API[0])
13-
#define LibRTBase64_APIVersion (*(int (*)(void)) LibRTBase64_API[1])
14-
#define LibRTBase64_b64encode_internal (*(PyObject* (*)(PyObject *source, bool urlsafe)) LibRTBase64_API[2])
15-
#define LibRTBase64_b64decode_internal (*(PyObject* (*)(PyObject *source, bool urlsafe)) LibRTBase64_API[3])
16-
17-
static int
18-
import_librt_base64(void)
19-
{
20-
PyObject *mod = PyImport_ImportModule("librt.base64");
21-
if (mod == NULL)
22-
return -1;
23-
Py_DECREF(mod); // we import just for the side effect of making the below work.
24-
void *capsule = PyCapsule_Import("librt.base64._C_API", 0);
25-
if (capsule == NULL)
26-
return -1;
27-
memcpy(LibRTBase64_API, capsule, sizeof(LibRTBase64_API));
28-
if (LibRTBase64_ABIVersion() != LIBRT_BASE64_ABI_VERSION) {
29-
char err[128];
30-
snprintf(err, sizeof(err), "ABI version conflict for librt.base64, expected %d, found %d",
31-
LIBRT_BASE64_ABI_VERSION,
32-
LibRTBase64_ABIVersion()
33-
);
34-
PyErr_SetString(PyExc_ValueError, err);
35-
return -1;
36-
}
37-
if (LibRTBase64_APIVersion() < LIBRT_BASE64_API_VERSION) {
38-
char err[128];
39-
snprintf(err, sizeof(err),
40-
"API version conflict for librt.base64, expected %d or newer, found %d (hint: upgrade librt)",
41-
LIBRT_BASE64_API_VERSION,
42-
LibRTBase64_APIVersion()
43-
);
44-
PyErr_SetString(PyExc_ValueError, err);
45-
return -1;
46-
}
47-
return 0;
48-
}
49-
5010
#endif // LIBRT_BASE64_H
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#include "librt_base64_api.h"
2+
3+
void *LibRTBase64_API[LIBRT_BASE64_API_LEN] = {0};
4+
5+
int
6+
import_librt_base64(void)
7+
{
8+
PyObject *mod = PyImport_ImportModule("librt.base64");
9+
if (mod == NULL)
10+
return -1;
11+
Py_DECREF(mod); // we import just for the side effect of making the below work.
12+
void *capsule = PyCapsule_Import("librt.base64._C_API", 0);
13+
if (capsule == NULL)
14+
return -1;
15+
memcpy(LibRTBase64_API, capsule, sizeof(LibRTBase64_API));
16+
if (LibRTBase64_ABIVersion() != LIBRT_BASE64_ABI_VERSION) {
17+
char err[128];
18+
snprintf(err, sizeof(err), "ABI version conflict for librt.base64, expected %d, found %d",
19+
LIBRT_BASE64_ABI_VERSION,
20+
LibRTBase64_ABIVersion()
21+
);
22+
PyErr_SetString(PyExc_ValueError, err);
23+
return -1;
24+
}
25+
if (LibRTBase64_APIVersion() < LIBRT_BASE64_API_VERSION) {
26+
char err[128];
27+
snprintf(err, sizeof(err),
28+
"API version conflict for librt.base64, expected %d or newer, found %d (hint: upgrade librt)",
29+
LIBRT_BASE64_API_VERSION,
30+
LibRTBase64_APIVersion()
31+
);
32+
PyErr_SetString(PyExc_ValueError, err);
33+
return -1;
34+
}
35+
return 0;
36+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#ifndef LIBRT_BASE64_API_H
2+
#define LIBRT_BASE64_API_H
3+
4+
#include "librt_base64.h"
5+
6+
extern void *LibRTBase64_API[LIBRT_BASE64_API_LEN];
7+
8+
#define LibRTBase64_ABIVersion (*(int (*)(void)) LibRTBase64_API[0])
9+
#define LibRTBase64_APIVersion (*(int (*)(void)) LibRTBase64_API[1])
10+
#define LibRTBase64_b64encode_internal (*(PyObject* (*)(PyObject *source, bool urlsafe)) LibRTBase64_API[2])
11+
#define LibRTBase64_b64decode_internal (*(PyObject* (*)(PyObject *source, bool urlsafe)) LibRTBase64_API[3])
12+
13+
int import_librt_base64(void);
14+
15+
#endif // LIBRT_BASE64_API_H

mypyc/lib-rt/byteswriter_extra_ops.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
#include <Python.h>
99

1010
#include "mypyc_util.h"
11-
#include "strings/librt_strings.h"
11+
#include "strings/librt_strings_api.h"
1212
#include "strings/librt_strings_common.h"
1313

1414
// BytesWriter: Length and capacity

0 commit comments

Comments
 (0)