Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions mypyc/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,7 @@ def get_mypy_config(
mypyc_sources = all_sources

if compiler_options.separate:
mypyc_sources = [
src for src in mypyc_sources if src.path and not src.path.endswith("__init__.py")
]
mypyc_sources = [src for src in mypyc_sources if src.path]

if not mypyc_sources:
return mypyc_sources, all_sources, options
Expand Down
33 changes: 31 additions & 2 deletions mypyc/codegen/emitmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
)
from mypyc.codegen.literals import Literals
from mypyc.common import (
EXT_SUFFIX,
IS_FREE_THREADED,
MODULE_PREFIX,
PREFIX,
Expand Down Expand Up @@ -1286,11 +1287,39 @@ def emit_module_init_func(
f"if (unlikely({module_static} == NULL))",
" goto fail;",
)

emitter.emit_line(f'modname = PyUnicode_FromString("{module_name}");')
emitter.emit_line("if (modname == NULL) CPyError_OutOfMemory();")
if self.group_name:
shared_lib_mod_name = shared_lib_name(self.group_name)
emitter.emit_line("PyObject *mod_dict = PyImport_GetModuleDict();")
emitter.emit_line(
f'PyObject *shared_lib = PyDict_GetItemString(mod_dict, "{shared_lib_mod_name}");'
Comment thread
p-sawicki marked this conversation as resolved.
Outdated
)
emitter.emit_line("if (shared_lib == NULL) goto fail;")
emitter.emit_line(
'PyObject *shared_lib_file = PyObject_GetAttrString(shared_lib, "__file__");'
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use PyDict_GetItemStringRef instead, as this returns a borrowed reference, and we decref it below.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you might be thinking of PyDict_GetItemString? this one returns a new reference according to docs https://docs.python.org/3/c-api/object.html#c.PyObject_GetAttrString

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I was confusing the two.

)
emitter.emit_line("if (shared_lib_file == NULL) goto fail;")
else:
emitter.emit_line(
f'PyObject *shared_lib_file = PyUnicode_FromString("{module_name + EXT_SUFFIX}");'
Comment thread
p-sawicki marked this conversation as resolved.
)
emitter.emit_line(f'PyObject *ext_suffix = PyUnicode_FromString("{EXT_SUFFIX}");')
emitter.emit_line("if (ext_suffix == NULL) CPyError_OutOfMemory();")
is_pkg = int(self.source_paths[module_name].endswith("__init__.py"))
emitter.emit_line(f"Py_ssize_t is_pkg = {is_pkg};")

emitter.emit_line(
f"int rv = CPyImport_SetDunderAttrs({module_static}, modname, shared_lib_file, ext_suffix, is_pkg);"
)
emitter.emit_line("Py_DECREF(ext_suffix);")
emitter.emit_line("Py_DECREF(shared_lib_file);")
emitter.emit_line("if (rv < 0) goto fail;")

# Register in sys.modules early so that circular imports via
# CPyImport_ImportNative can detect that this module is already
# being initialized and avoid re-executing the module body.
emitter.emit_line(f'modname = PyUnicode_FromString("{module_name}");')
emitter.emit_line("if (modname == NULL) CPyError_OutOfMemory();")
emitter.emit_line(
f"if (PyObject_SetItem(PyImport_GetModuleDict(), modname, {module_static}) < 0)"
)
Expand Down
2 changes: 2 additions & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,8 @@ PyObject *CPyImport_ImportNative(PyObject *module_name,
CPyModule **module_static,
PyObject *shared_lib_file, PyObject *ext_suffix,
Py_ssize_t is_package);
int CPyImport_SetDunderAttrs(PyObject *module, PyObject *module_name, PyObject *shared_lib_file,
PyObject *ext_suffix, Py_ssize_t is_package);

PyObject *CPySingledispatch_RegisterFunction(PyObject *singledispatch_func, PyObject *cls,
PyObject *func);
Expand Down
109 changes: 67 additions & 42 deletions mypyc/lib-rt/misc_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -1225,6 +1225,47 @@ static int CPyImport_InitSpecClasses(void) {
return 0;
}

// Set __package__ before executing the module body so it is available
// during module initialization. For a package, __package__ is the module
// name itself. For a non-package submodule "a.b.c", it is "a.b". For a
// top-level non-package module, it is "".
static int CPyImport_SetModulePackage(PyObject *modobj, PyObject *module_name,
Py_ssize_t is_package) {
PyObject *pkg = NULL;
int rc = PyObject_GetOptionalAttrString(modobj, "__package__", &pkg);
if (rc < 0) {
return -1;
}
if (pkg != NULL && pkg != Py_None) {
Py_DECREF(pkg);
return 0;
}
Py_XDECREF(pkg);

PyObject *package_name = NULL;
if (is_package) {
package_name = module_name;
Py_INCREF(package_name);
} else {
Py_ssize_t name_len = PyUnicode_GetLength(module_name);
if (name_len < 0) {
return -1;
}
Py_ssize_t dot = PyUnicode_FindChar(module_name, '.', 0, name_len, -1);
if (dot >= 0) {
package_name = PyUnicode_Substring(module_name, 0, dot);
} else {
package_name = PyUnicode_FromString("");
}
}
if (package_name == NULL) {
return -1;
}
rc = PyObject_SetAttrString(modobj, "__package__", package_name);
Py_DECREF(package_name);
return rc;
}

// Derive and set __file__ on modobj from the shared library path, module name,
// and extension suffix. Returns 0 on success, -1 on error.
static int CPyImport_SetModuleFile(PyObject *modobj, PyObject *module_name,
Expand Down Expand Up @@ -1509,47 +1550,7 @@ PyObject *CPyImport_ImportNative(PyObject *module_name,
goto fail;
}

// Set __package__ before executing the module body so it is available
// during module initialization. For a package, __package__ is the module
// name itself. For a non-package submodule "a.b.c", it is "a.b". For a
// top-level non-package module, it is "".
{
PyObject *pkg = NULL;
if (PyObject_GetOptionalAttrString(modobj, "__package__", &pkg) < 0) {
goto fail;
}
if (pkg == NULL || pkg == Py_None) {
Py_XDECREF(pkg);
PyObject *package_name;
if (is_package) {
package_name = module_name;
Py_INCREF(package_name);
} else if (dot >= 0) {
package_name = PyUnicode_Substring(module_name, 0, dot);
} else {
package_name = PyUnicode_FromString("");
if (package_name == NULL) {
CPyError_OutOfMemory();
}
}
if (PyObject_SetAttrString(modobj, "__package__", package_name) < 0) {
Py_DECREF(package_name);
goto fail;
}
Py_DECREF(package_name);
} else {
Py_DECREF(pkg);
}
}

if (CPyImport_SetModuleFile(modobj, module_name, shared_lib_file, ext_suffix,
is_package) < 0) {
goto fail;
}
if (is_package && CPyImport_SetModulePath(modobj) < 0) {
goto fail;
}
if (CPyImport_SetModuleSpec(modobj, module_name, is_package) < 0) {
if (CPyImport_SetDunderAttrs(modobj, module_name, shared_lib_file, ext_suffix, is_package) < 0) {
goto fail;
}

Expand Down Expand Up @@ -1577,10 +1578,34 @@ PyObject *CPyImport_ImportNative(PyObject *module_name,
PyErr_Restore(exc_type, exc_val, exc_tb);
Py_XDECREF(parent_module);
Py_XDECREF(child_name);
Py_DECREF(modobj);
Py_CLEAR(*module_static);
return NULL;
}

int CPyImport_SetDunderAttrs(PyObject *module, PyObject *module_name, PyObject *shared_lib_file,
PyObject *ext_suffix, Py_ssize_t is_package)
{
int res = CPyImport_SetModulePackage(module, module_name, is_package);
if (res < 0) {
return res;
}

res = CPyImport_SetModuleFile(module, module_name, shared_lib_file, ext_suffix,
is_package);
if (res < 0) {
return res;
}

if (is_package) {
res = CPyImport_SetModulePath(module);
if (res < 0) {
return res;
}
}

return CPyImport_SetModuleSpec(module, module_name, is_package);
}

#if CPY_3_14_FEATURES

#include "internal/pycore_object.h"
Expand Down
32 changes: 32 additions & 0 deletions mypyc/test-data/run-multimodule.test
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,38 @@ globals()['A'] = None
[file driver.py]
import other_main

[case testNonNativeImportInPackageFile]
# The import is really non-native only in separate compilation mode where __init__.py and
# other_cache.py are in different libraries and the import uses the standard Python procedure.
# Python imports are resolved using __path__ and __spec__ from the package file so this checks
# that they are set up correctly.
[file other/__init__.py]
from other.other_cache import Cache

x = 1
[file other/other_cache.py]
class Cache:
pass

[file driver.py]
import other

[case testRelativeImportInPackageFile]
# Relative imports from a compiled package __init__ depend on package metadata being
# available while the package module body is executing.
[file other/__init__.py]
assert __package__ == "other"
from .other_cache import Cache

x = 1
[file other/other_cache.py]
class Cache:
pass

[file driver.py]
import other
assert other.Cache.__name__ == "Cache"

[case testMultiModuleSameNames]
# Use same names in both modules
import other
Expand Down
Loading