Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 45 additions & 5 deletions rewrite-python/rewrite/src/rewrite/rpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@
# Set REWRITE_PYTHON_VERSION to "2" or "2.7" to parse Python 2 code
_python_version = os.environ.get("REWRITE_PYTHON_VERSION", "3")

# Set via --recipe-install-dir; an InstallRecipes RPC for a not-yet-importable
# package pip installs it here before activating.
_recipe_install_dir: Optional[Path] = None


def _next_request_id() -> int:
"""Generate a unique request ID for outgoing requests."""
Expand Down Expand Up @@ -627,12 +631,42 @@ def _get_marketplace():
return _marketplace


def _is_package_installed(package_name: str, version: Optional[str]) -> bool:
try:
import importlib.metadata
installed = importlib.metadata.version(package_name)
except Exception:
return False
return version is None or installed == version


def _pip_install_recipe_package(package_name: str, version: Optional[str], target_dir: Path) -> None:
import importlib
import subprocess

target_dir.mkdir(parents=True, exist_ok=True)
spec = f"{package_name}=={version}" if version else package_name
cmd = [sys.executable, "-m", "pip", "install", "--target", str(target_dir), spec]
logger.info(f"pip install: {' '.join(cmd)}")
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
raise RuntimeError(
f"pip install failed for {spec} (target={target_dir}):\n{result.stderr}"
)

target_str = str(target_dir.resolve())
if target_str not in sys.path:
sys.path.insert(0, target_str)
importlib.invalidate_caches()


def handle_install_recipes(params: dict) -> dict:
"""Handle an InstallRecipes RPC request.

Activates a recipe package in the marketplace. The package should already be
installed by the caller (e.g., via pip install --target). This handler discovers
and activates the package's recipes.
Activates a recipe package in the marketplace. When `--recipe-install-dir`
is configured, a package spec that isn't already installed is pip-installed
into that directory before activation; otherwise the package must have been
installed by the caller.

Args:
params: Dict containing either:
Expand Down Expand Up @@ -675,16 +709,17 @@ def handle_install_recipes(params: dict) -> dict:
recipes_added = len(added)

elif isinstance(recipes, dict):
# Package spec with name and optional version - package should already be installed
package_name = recipes.get('packageName')
version = recipes.get('version')

if not package_name:
raise ValueError("Package name is required")

if _recipe_install_dir is not None and not _is_package_installed(package_name, version):
_pip_install_recipe_package(package_name, version, _recipe_install_dir)

logger.info(f"Activating recipes package: {package_name}")

# Get the installed version
try:
import importlib.metadata
installed_version = importlib.metadata.version(package_name)
Expand Down Expand Up @@ -1640,8 +1675,13 @@ def main():
parser.add_argument('--log-file', help='Log file path')
parser.add_argument('--metrics-csv', help='Metrics CSV output path')
parser.add_argument('--trace-rpc-messages', action='store_true', help='Enable RPC message tracing')
parser.add_argument('--recipe-install-dir', help='Directory where recipe pip packages are installed')
args = parser.parse_args()

if args.recipe_install_dir:
global _recipe_install_dir
_recipe_install_dir = Path(args.recipe_install_dir)

if args.log_file:
file_handler = logging.FileHandler(args.log_file)
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
Expand Down
31 changes: 31 additions & 0 deletions rewrite-python/rewrite/tests/rpc/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,37 @@ def fake_parse_python_source(source, path="<unknown>", relative_to=None, ty_clie
assert (tmp_path / "pkg" / "__init__.py").read_text(encoding="utf-8") == ""


def test_pip_install_recipe_package_shape(tmp_path, monkeypatch):
import rewrite.rpc.server as server
import subprocess

install_dir = tmp_path / "recipes"
captured = {}

class FakeCompletedProcess:
returncode = 0
stdout = ""
stderr = ""

def fake_run(cmd, capture_output=False, text=False):
captured["cmd"] = cmd
return FakeCompletedProcess()

monkeypatch.setattr(subprocess, "run", fake_run)

server._pip_install_recipe_package(
"openrewrite-recipes-python", "1.2.3", install_dir
)

assert install_dir.exists()
assert captured["cmd"][1:] == [
"-m", "pip", "install",
"--target", str(install_dir),
"openrewrite-recipes-python==1.2.3",
]
assert str(install_dir.resolve()) in __import__("sys").path


def test_recipe_descriptor_to_dict_emits_all_collection_keys():
from rewrite.recipe import RecipeDescriptor
from rewrite.rpc.server import _recipe_descriptor_to_dict
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,8 @@ public PythonRewriteRpc get() {
"-m", "rewrite.rpc.server",
log == null ? null : "--log-file=" + log.toAbsolutePath().normalize(),
metricsCsv == null ? null : "--metrics-csv=" + metricsCsv.toAbsolutePath().normalize(),
traceRpcMessages ? "--trace-rpc-messages" : null
traceRpcMessages ? "--trace-rpc-messages" : null,
recipeInstallDir == null ? null : "--recipe-install-dir=" + recipeInstallDir.toAbsolutePath().normalize()
);

String[] cmdArr = cmd.filter(Objects::nonNull).toArray(String[]::new);
Expand Down