Skip to content

Commit 1b7ff0e

Browse files
authored
[mypyc] Set dunder attrs when adding module to sys.modules (python#21275)
Recently there was a change to add native modules to `sys.modules` before they are executed to be able to detect circular imports. This introduced a regression when the module is a package that imports objects from other files within the package, e.g. `from pkg.file import something` inside `pkg/__init__.py`. Such imports result in an exception `ModuleNotFoundError: No module named 'pkg.file'; 'pkg' is not a package.`, for example when trying to upgrade mypy in [black](https://github.com/psf/black/actions/runs/23933086642/job/69803937853?pr=5071). This error is raised because Python expects the parent module of `file` to have the `__path__` attribute set when [resolving the import](https://github.com/python/cpython/blob/main/Lib/importlib/_bootstrap.py#L1226), but we don't set this attribute before adding the `pkg` module to `sys.modules`. So use the existing functions to set the relevant dunder attributes (`__path__` for packages and `__file__`, `__spec__`, and `__package__` for all) before registering the module in `sys.modules`. Don't skip compilation for `__init__.py` files in separate compilation mode to make this possible to test. Use `Py_CLEAR` instead of `Py_DECREF` on the import object on failure in `CPyImport_ImportNative`, as the import object might be freed when deleting it from `sys.modules`. This triggered an assertion when running tests with a debug build of CPython.
1 parent 2dc89dc commit 1b7ff0e

6 files changed

Lines changed: 156 additions & 48 deletions

File tree

mypyc/build.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -215,9 +215,7 @@ def get_mypy_config(
215215
mypyc_sources = all_sources
216216

217217
if compiler_options.separate:
218-
mypyc_sources = [
219-
src for src in mypyc_sources if src.path and not src.path.endswith("__init__.py")
220-
]
218+
mypyc_sources = [src for src in mypyc_sources if src.path]
221219

222220
if not mypyc_sources:
223221
return mypyc_sources, all_sources, options
@@ -243,6 +241,10 @@ def get_mypy_config(
243241
return mypyc_sources, all_sources, options
244242

245243

244+
def is_package_source(source: BuildSource) -> bool:
    """Return True if *source* is a package initializer file.

    A source counts as a package when it has a path whose final component
    is exactly ``__init__.py``. Sources without a path are never packages.
    """
    path = source.path
    if path is None:
        return False
    return os.path.basename(path) == "__init__.py"
246+
247+
246248
def generate_c_extension_shim(
247249
full_module_name: str, module_name: str, dir_name: str, group_name: str
248250
) -> str:
@@ -388,7 +390,7 @@ def build_using_shared_lib(
388390
# since this seems to be needed for it to end up in the right place.
389391
full_module_name = source.module
390392
assert source.path
391-
if os.path.split(source.path)[1] == "__init__.py":
393+
if is_package_source(source):
392394
full_module_name += ".__init__"
393395
extensions.append(
394396
get_extension()(
@@ -534,6 +536,7 @@ def mypyc_build(
534536
use_shared_lib = (
535537
len(mypyc_sources) > 1
536538
or any("." in x.module for x in mypyc_sources)
539+
or any(is_package_source(x) for x in mypyc_sources)
537540
or always_use_shared_lib
538541
)
539542

mypyc/codegen/emitmodule.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
)
4747
from mypyc.codegen.literals import Literals
4848
from mypyc.common import (
49+
EXT_SUFFIX,
4950
IS_FREE_THREADED,
5051
MODULE_PREFIX,
5152
PREFIX,
@@ -1286,11 +1287,42 @@ def emit_module_init_func(
12861287
f"if (unlikely({module_static} == NULL))",
12871288
" goto fail;",
12881289
)
1290+
1291+
emitter.emit_line(f'modname = PyUnicode_FromString("{module_name}");')
1292+
emitter.emit_line("if (modname == NULL) CPyError_OutOfMemory();")
1293+
emitter.emit_line("int rv = 0;")
1294+
if self.group_name:
1295+
shared_lib_mod_name = shared_lib_name(self.group_name)
1296+
emitter.emit_line("PyObject *mod_dict = PyImport_GetModuleDict();")
1297+
emitter.emit_line("PyObject *shared_lib = NULL;")
1298+
emitter.emit_line(
1299+
f'rv = PyDict_GetItemStringRef(mod_dict, "{shared_lib_mod_name}", &shared_lib);'
1300+
)
1301+
emitter.emit_line("if (rv < 0) goto fail;")
1302+
emitter.emit_line(
1303+
'PyObject *shared_lib_file = PyObject_GetAttrString(shared_lib, "__file__");'
1304+
)
1305+
emitter.emit_line("if (shared_lib_file == NULL) goto fail;")
1306+
else:
1307+
emitter.emit_line(
1308+
f'PyObject *shared_lib_file = PyUnicode_FromString("{module_name + EXT_SUFFIX}");'
1309+
)
1310+
emitter.emit_line("if (shared_lib_file == NULL) CPyError_OutOfMemory();")
1311+
emitter.emit_line(f'PyObject *ext_suffix = PyUnicode_FromString("{EXT_SUFFIX}");')
1312+
emitter.emit_line("if (ext_suffix == NULL) CPyError_OutOfMemory();")
1313+
is_pkg = int(self.source_paths[module_name].endswith("__init__.py"))
1314+
emitter.emit_line(f"Py_ssize_t is_pkg = {is_pkg};")
1315+
1316+
emitter.emit_line(
1317+
f"rv = CPyImport_SetDunderAttrs({module_static}, modname, shared_lib_file, ext_suffix, is_pkg);"
1318+
)
1319+
emitter.emit_line("Py_DECREF(ext_suffix);")
1320+
emitter.emit_line("Py_DECREF(shared_lib_file);")
1321+
emitter.emit_line("if (rv < 0) goto fail;")
1322+
12891323
# Register in sys.modules early so that circular imports via
12901324
# CPyImport_ImportNative can detect that this module is already
12911325
# being initialized and avoid re-executing the module body.
1292-
emitter.emit_line(f'modname = PyUnicode_FromString("{module_name}");')
1293-
emitter.emit_line("if (modname == NULL) CPyError_OutOfMemory();")
12941326
emitter.emit_line(
12951327
f"if (PyObject_SetItem(PyImport_GetModuleDict(), modname, {module_static}) < 0)"
12961328
)

mypyc/lib-rt/CPy.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -967,6 +967,8 @@ PyObject *CPyImport_ImportNative(PyObject *module_name,
967967
CPyModule **module_static,
968968
PyObject *shared_lib_file, PyObject *ext_suffix,
969969
Py_ssize_t is_package);
970+
int CPyImport_SetDunderAttrs(PyObject *module, PyObject *module_name, PyObject *shared_lib_file,
971+
PyObject *ext_suffix, Py_ssize_t is_package);
970972

971973
PyObject *CPySingledispatch_RegisterFunction(PyObject *singledispatch_func, PyObject *cls,
972974
PyObject *func);

mypyc/lib-rt/misc_ops.c

Lines changed: 67 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1225,6 +1225,47 @@ static int CPyImport_InitSpecClasses(void) {
12251225
return 0;
12261226
}
12271227

1228+
// Set __package__ before executing the module body so it is available
1229+
// during module initialization. For a package, __package__ is the module
1230+
// name itself. For a non-package submodule "a.b.c", it is "a.b". For a
1231+
// top-level non-package module, it is "".
1232+
static int CPyImport_SetModulePackage(PyObject *modobj, PyObject *module_name,
1233+
Py_ssize_t is_package) {
1234+
PyObject *pkg = NULL;
1235+
int rc = PyObject_GetOptionalAttrString(modobj, "__package__", &pkg);
1236+
if (rc < 0) {
1237+
return -1;
1238+
}
1239+
if (pkg != NULL && pkg != Py_None) {
1240+
Py_DECREF(pkg);
1241+
return 0;
1242+
}
1243+
Py_XDECREF(pkg);
1244+
1245+
PyObject *package_name = NULL;
1246+
if (is_package) {
1247+
package_name = module_name;
1248+
Py_INCREF(package_name);
1249+
} else {
1250+
Py_ssize_t name_len = PyUnicode_GetLength(module_name);
1251+
if (name_len < 0) {
1252+
return -1;
1253+
}
1254+
Py_ssize_t dot = PyUnicode_FindChar(module_name, '.', 0, name_len, -1);
1255+
if (dot >= 0) {
1256+
package_name = PyUnicode_Substring(module_name, 0, dot);
1257+
} else {
1258+
package_name = PyUnicode_FromString("");
1259+
}
1260+
}
1261+
if (package_name == NULL) {
1262+
return -1;
1263+
}
1264+
rc = PyObject_SetAttrString(modobj, "__package__", package_name);
1265+
Py_DECREF(package_name);
1266+
return rc;
1267+
}
1268+
12281269
// Derive and set __file__ on modobj from the shared library path, module name,
12291270
// and extension suffix. Returns 0 on success, -1 on error.
12301271
static int CPyImport_SetModuleFile(PyObject *modobj, PyObject *module_name,
@@ -1509,47 +1550,7 @@ PyObject *CPyImport_ImportNative(PyObject *module_name,
15091550
goto fail;
15101551
}
15111552

1512-
// Set __package__ before executing the module body so it is available
1513-
// during module initialization. For a package, __package__ is the module
1514-
// name itself. For a non-package submodule "a.b.c", it is "a.b". For a
1515-
// top-level non-package module, it is "".
1516-
{
1517-
PyObject *pkg = NULL;
1518-
if (PyObject_GetOptionalAttrString(modobj, "__package__", &pkg) < 0) {
1519-
goto fail;
1520-
}
1521-
if (pkg == NULL || pkg == Py_None) {
1522-
Py_XDECREF(pkg);
1523-
PyObject *package_name;
1524-
if (is_package) {
1525-
package_name = module_name;
1526-
Py_INCREF(package_name);
1527-
} else if (dot >= 0) {
1528-
package_name = PyUnicode_Substring(module_name, 0, dot);
1529-
} else {
1530-
package_name = PyUnicode_FromString("");
1531-
if (package_name == NULL) {
1532-
CPyError_OutOfMemory();
1533-
}
1534-
}
1535-
if (PyObject_SetAttrString(modobj, "__package__", package_name) < 0) {
1536-
Py_DECREF(package_name);
1537-
goto fail;
1538-
}
1539-
Py_DECREF(package_name);
1540-
} else {
1541-
Py_DECREF(pkg);
1542-
}
1543-
}
1544-
1545-
if (CPyImport_SetModuleFile(modobj, module_name, shared_lib_file, ext_suffix,
1546-
is_package) < 0) {
1547-
goto fail;
1548-
}
1549-
if (is_package && CPyImport_SetModulePath(modobj) < 0) {
1550-
goto fail;
1551-
}
1552-
if (CPyImport_SetModuleSpec(modobj, module_name, is_package) < 0) {
1553+
if (CPyImport_SetDunderAttrs(modobj, module_name, shared_lib_file, ext_suffix, is_package) < 0) {
15531554
goto fail;
15541555
}
15551556

@@ -1577,10 +1578,34 @@ PyObject *CPyImport_ImportNative(PyObject *module_name,
15771578
PyErr_Restore(exc_type, exc_val, exc_tb);
15781579
Py_XDECREF(parent_module);
15791580
Py_XDECREF(child_name);
1580-
Py_DECREF(modobj);
1581+
Py_CLEAR(*module_static);
15811582
return NULL;
15821583
}
15831584

1585+
// Set the import-related dunder attributes on `module` by running the
// per-attribute helpers in a fixed order: __package__, then __file__
// (derived from shared_lib_file, module_name and ext_suffix), then the
// package path (only when is_package is nonzero), and finally the module spec.
//
// Stops at the first helper that reports failure and returns its negative
// result; otherwise returns the result of the final spec helper.
int CPyImport_SetDunderAttrs(PyObject *module, PyObject *module_name, PyObject *shared_lib_file,
                             PyObject *ext_suffix, Py_ssize_t is_package)
{
    int status;

    if ((status = CPyImport_SetModulePackage(module, module_name, is_package)) < 0) {
        return status;
    }

    if ((status = CPyImport_SetModuleFile(module, module_name, shared_lib_file, ext_suffix,
                                          is_package)) < 0) {
        return status;
    }

    // Only packages get a path attribute.
    if (is_package && (status = CPyImport_SetModulePath(module)) < 0) {
        return status;
    }

    return CPyImport_SetModuleSpec(module, module_name, is_package);
}
1608+
15841609
#if CPY_3_14_FEATURES
15851610

15861611
#include "internal/pycore_object.h"

mypyc/test-data/commandline.test

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,20 @@ print(type(Eggs(obj1=pkg1.A.B())["obj1"]).__module__)
313313
B
314314
pkg2.mod2
315315

316+
[case testCompilePackageOnlyInitPy]
317+
# cmd: pkg/__init__.py
318+
import os.path
319+
import pkg
320+
321+
print(pkg.x)
322+
assert os.path.splitext(pkg.__file__)[1] != ".py"
323+
324+
[file pkg/__init__.py]
325+
x: int = 1
326+
327+
[out]
328+
1
329+
316330
[case testStrictBytesRequired]
317331
# cmd: --no-strict-bytes a.py
318332

mypyc/test-data/run-multimodule.test

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,38 @@ globals()['A'] = None
473473
[file driver.py]
474474
import other_main
475475

476+
[case testNonNativeImportInPackageFile]
477+
# The import is really non-native only in separate compilation mode where __init__.py and
478+
# other_cache.py are in different libraries and the import uses the standard Python procedure.
479+
# Python imports are resolved using __path__ and __spec__ from the package file so this checks
480+
# that they are set up correctly.
481+
[file other/__init__.py]
482+
from other.other_cache import Cache
483+
484+
x = 1
485+
[file other/other_cache.py]
486+
class Cache:
487+
pass
488+
489+
[file driver.py]
490+
import other
491+
492+
[case testRelativeImportInPackageFile]
493+
# Relative imports from a compiled package __init__ depend on package metadata being
494+
# available while the package module body is executing.
495+
[file other/__init__.py]
496+
assert __package__ == "other"
497+
from .other_cache import Cache
498+
499+
x = 1
500+
[file other/other_cache.py]
501+
class Cache:
502+
pass
503+
504+
[file driver.py]
505+
import other
506+
assert other.Cache.__name__ == "Cache"
507+
476508
[case testMultiModuleSameNames]
477509
# Use same names in both modules
478510
import other

0 commit comments

Comments
 (0)