Skip to content

Commit 7767c5f

Browse files
authored
Add PythonDependencyFile trait for polymorphic dependency management (#7274)
1 parent 83bd388 commit 7767c5f

25 files changed

Lines changed: 3152 additions & 848 deletions

rewrite-python/src/main/java/org/openrewrite/python/AddDependency.java

Lines changed: 58 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,20 @@
1919
import lombok.Value;
2020
import org.jspecify.annotations.Nullable;
2121
import org.openrewrite.*;
22-
import org.openrewrite.marker.Markers;
2322
import org.openrewrite.python.internal.PyProjectHelper;
2423
import org.openrewrite.python.internal.PythonDependencyExecutionContextView;
25-
import org.openrewrite.python.marker.PythonResolutionResult;
26-
import org.openrewrite.toml.TomlIsoVisitor;
27-
import org.openrewrite.toml.tree.Space;
24+
import org.openrewrite.python.trait.PythonDependencyFile;
2825
import org.openrewrite.toml.tree.Toml;
29-
import org.openrewrite.toml.tree.TomlRightPadded;
30-
import org.openrewrite.toml.tree.TomlType;
3126

32-
import java.util.*;
33-
34-
import static org.openrewrite.Tree.randomId;
27+
import java.nio.file.Path;
28+
import java.util.Collections;
29+
import java.util.HashSet;
30+
import java.util.Map;
31+
import java.util.Set;
3532

3633
/**
37-
* Add a dependency to the {@code [project].dependencies} array in pyproject.toml.
34+
* Add a dependency to a Python project. Supports {@code pyproject.toml}
35+
* (with scope and group targeting), {@code requirements.txt}, and {@code Pipfile}.
3836
* When uv is available, the uv.lock file is regenerated to reflect the change.
3937
*/
4038
@EqualsAndHashCode(callSuper = false)
@@ -54,9 +52,9 @@ public class AddDependency extends ScanningRecipe<AddDependency.Accumulator> {
5452
String version;
5553

5654
@Option(displayName = "Scope",
57-
description = "The dependency scope to add to. Defaults to `project.dependencies`.",
58-
valid = {"project.dependencies", "project.optional-dependencies", "dependency-groups",
59-
"tool.uv.constraint-dependencies", "tool.uv.override-dependencies"},
55+
description = "The dependency scope to add to. For pyproject.toml this targets a specific TOML section. " +
56+
"For requirements files, `null` matches all files, empty string matches only `requirements.txt`, " +
57+
"and a value like `dev` matches `requirements-dev.txt`. Defaults to `project.dependencies`.",
6058
example = "project.dependencies",
6159
required = false)
6260
@Nullable
@@ -90,12 +88,13 @@ public String getInstanceNameSuffix() {
9088

9189
@Override
9290
public String getDescription() {
93-
return "Add a dependency to the `[project].dependencies` array in `pyproject.toml`. " +
91+
return "Add a dependency to a Python project. Supports `pyproject.toml` " +
92+
"(with scope/group targeting), `requirements.txt`, and `Pipfile`. " +
9493
"When `uv` is available, the `uv.lock` file is regenerated.";
9594
}
9695

9796
static class Accumulator {
98-
final Set<String> projectsToUpdate = new HashSet<>();
97+
final Set<Path> projectsToUpdate = new HashSet<>();
9998
}
10099

101100
@Override
@@ -105,141 +104,70 @@ public Accumulator getInitialValue(ExecutionContext ctx) {
105104

106105
@Override
107106
public TreeVisitor<?, ExecutionContext> getScanner(Accumulator acc) {
108-
return new TomlIsoVisitor<ExecutionContext>() {
109-
@Override
110-
public Toml.Document visitDocument(Toml.Document document, ExecutionContext ctx) {
111-
String sourcePath = document.getSourcePath().toString();
112-
113-
if (sourcePath.endsWith("uv.lock")) {
114-
PythonDependencyExecutionContextView.view(ctx).getExistingLockContents().put(
115-
PyProjectHelper.correspondingPyprojectPath(sourcePath),
116-
document.printAll());
117-
return document;
118-
}
107+
return new TreeVisitor<Tree, ExecutionContext>() {
108+
final PythonDependencyFile.Matcher matcher = new PythonDependencyFile.Matcher();
119109

120-
if (!sourcePath.endsWith("pyproject.toml")) {
121-
return document;
110+
@Override
111+
public Tree preVisit(Tree tree, ExecutionContext ctx) {
112+
stopAfterPreVisit();
113+
if (!(tree instanceof SourceFile)) {
114+
return tree;
122115
}
123-
Optional<PythonResolutionResult> resolution = document.getMarkers()
124-
.findFirst(PythonResolutionResult.class);
125-
if (!resolution.isPresent()) {
126-
return document;
116+
SourceFile sourceFile = (SourceFile) tree;
117+
if (tree instanceof Toml.Document && sourceFile.getSourcePath().endsWith("uv.lock")) {
118+
PythonDependencyExecutionContextView.view(ctx).getExistingLockContents().put(
119+
PyProjectHelper.correspondingPyprojectPath(sourceFile.getSourcePath()),
120+
((Toml.Document) tree).printAll());
121+
return tree;
127122
}
128-
129-
PythonResolutionResult marker = resolution.get();
130-
131-
// Check if the dependency already exists in the target scope
132-
if (PyProjectHelper.findDependencyInScope(marker, packageName, scope, groupName) != null) {
133-
return document;
123+
PythonDependencyFile trait = matcher.get(getCursor()).orElse(null);
124+
if (trait != null && PyProjectHelper.findDependencyInScope(trait.getMarker(), packageName, scope, groupName) == null) {
125+
acc.projectsToUpdate.add(sourceFile.getSourcePath());
134126
}
135-
136-
acc.projectsToUpdate.add(sourcePath);
137-
return document;
127+
return tree;
138128
}
139129
};
140130
}
141131

142132
@Override
143133
public TreeVisitor<?, ExecutionContext> getVisitor(Accumulator acc) {
144-
return new TomlIsoVisitor<ExecutionContext>() {
145-
@Override
146-
public Toml.Document visitDocument(Toml.Document document, ExecutionContext ctx) {
147-
String sourcePath = document.getSourcePath().toString();
134+
if (acc.projectsToUpdate.isEmpty()) {
135+
return TreeVisitor.noop();
136+
}
137+
return new TreeVisitor<Tree, ExecutionContext>() {
138+
final PythonDependencyFile.Matcher matcher = new PythonDependencyFile.Matcher();
148139

149-
if (sourcePath.endsWith("pyproject.toml") && acc.projectsToUpdate.contains(sourcePath)) {
150-
return addDependencyToPyproject(document, ctx, acc);
140+
@Override
141+
public Tree preVisit(Tree tree, ExecutionContext ctx) {
142+
stopAfterPreVisit();
143+
if (!(tree instanceof SourceFile)) {
144+
return tree;
145+
}
146+
SourceFile sourceFile = (SourceFile) tree;
147+
Path sourcePath = sourceFile.getSourcePath();
148+
149+
if (acc.projectsToUpdate.contains(sourcePath)) {
150+
PythonDependencyFile trait = matcher.get(getCursor()).orElse(null);
151+
if (trait != null) {
152+
String ver = version != null ? version : "";
153+
Map<String, String> additions = Collections.singletonMap(packageName, ver);
154+
PythonDependencyFile updated = trait.withAddedDependencies(additions, scope, groupName);
155+
if (updated.getTree() != tree) {
156+
return updated.afterModification(ctx);
157+
}
158+
}
151159
}
152160

153-
if (sourcePath.endsWith("uv.lock")) {
154-
Toml.Document updatedLock = PyProjectHelper.maybeUpdateUvLock(document, ctx);
161+
if (tree instanceof Toml.Document && sourcePath.endsWith("uv.lock")) {
162+
Toml.Document updatedLock = PyProjectHelper.maybeUpdateUvLock((Toml.Document) tree, ctx);
155163
if (updatedLock != null) {
156164
return updatedLock;
157165
}
158166
}
159167

160-
return document;
168+
return tree;
161169
}
162170
};
163171
}
164172

165-
private Toml.Document addDependencyToPyproject(Toml.Document document, ExecutionContext ctx, Accumulator acc) {
166-
String pep508 = version != null ? packageName + PyProjectHelper.normalizeVersionConstraint(version) : packageName;
167-
168-
Toml.Document updated = (Toml.Document) new TomlIsoVisitor<ExecutionContext>() {
169-
@Override
170-
public Toml.Array visitArray(Toml.Array array, ExecutionContext ctx) {
171-
Toml.Array a = super.visitArray(array, ctx);
172-
173-
if (!PyProjectHelper.isInsideDependencyArray(getCursor(), scope, groupName)) {
174-
return a;
175-
}
176-
177-
Toml.Literal newLiteral = new Toml.Literal(
178-
randomId(),
179-
Space.EMPTY,
180-
Markers.EMPTY,
181-
TomlType.Primitive.String,
182-
"\"" + pep508 + "\"",
183-
pep508
184-
);
185-
186-
List<TomlRightPadded<Toml>> existingPadded = a.getPadding().getValues();
187-
List<TomlRightPadded<Toml>> newPadded = new ArrayList<>();
188-
189-
// An empty TOML array [] is represented as a single Toml.Empty element
190-
boolean isEmpty = existingPadded.size() == 1 &&
191-
existingPadded.get(0).getElement() instanceof Toml.Empty;
192-
if (existingPadded.isEmpty() || isEmpty) {
193-
newPadded.add(new TomlRightPadded<>(newLiteral, Space.EMPTY, Markers.EMPTY));
194-
} else {
195-
// Check if the last element is Toml.Empty (trailing comma marker)
196-
TomlRightPadded<Toml> lastPadded = existingPadded.get(existingPadded.size() - 1);
197-
boolean hasTrailingComma = lastPadded.getElement() instanceof Toml.Empty;
198-
199-
if (hasTrailingComma) {
200-
// Insert before the Empty element. The Empty's position
201-
// stores the whitespace before ']'.
202-
// Find the last real element to copy its prefix formatting
203-
int lastRealIdx = existingPadded.size() - 2;
204-
Toml lastRealElement = existingPadded.get(lastRealIdx).getElement();
205-
Toml.Literal formattedLiteral = newLiteral.withPrefix(lastRealElement.getPrefix());
206-
207-
// Copy all existing elements up to (not including) the Empty
208-
for (int i = 0; i <= lastRealIdx; i++) {
209-
newPadded.add(existingPadded.get(i));
210-
}
211-
// Add new literal with empty after (comma added by printer)
212-
newPadded.add(new TomlRightPadded<>(formattedLiteral, Space.EMPTY, Markers.EMPTY));
213-
// Keep the Empty element for trailing comma + closing bracket whitespace
214-
newPadded.add(lastPadded);
215-
} else {
216-
// No trailing comma — the last real element's after has the space before ']'
217-
Toml lastElement = lastPadded.getElement();
218-
// For multi-line arrays, use same prefix; for inline, use single space
219-
Space newPrefix = lastElement.getPrefix().getWhitespace().contains("\n")
220-
? lastElement.getPrefix()
221-
: Space.SINGLE_SPACE;
222-
Toml.Literal formattedLiteral = newLiteral.withPrefix(newPrefix);
223-
224-
// Copy all existing elements but set last one's after to empty
225-
for (int i = 0; i < existingPadded.size() - 1; i++) {
226-
newPadded.add(existingPadded.get(i));
227-
}
228-
newPadded.add(lastPadded.withAfter(Space.EMPTY));
229-
// New element gets the after from the old last element
230-
newPadded.add(new TomlRightPadded<>(formattedLiteral, lastPadded.getAfter(), Markers.EMPTY));
231-
}
232-
}
233-
234-
return a.getPadding().withValues(newPadded);
235-
}
236-
}.visitNonNull(document, ctx);
237-
238-
if (updated != document) {
239-
updated = PyProjectHelper.regenerateLockAndRefreshMarker(updated, ctx);
240-
}
241-
242-
return updated;
243-
}
244-
245173
}

rewrite-python/src/main/java/org/openrewrite/python/Assertions.java

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,45 @@ public static SourceSpecs setupCfg(@Nullable String before,
237237
return text;
238238
}
239239

240+
public static SourceSpecs pipfile(@Language("toml") @Nullable String before) {
241+
return pipfile(before, s -> {
242+
});
243+
}
244+
245+
public static SourceSpecs pipfile(@Language("toml") @Nullable String before,
246+
Consumer<SourceSpec<Toml.Document>> spec) {
247+
SourceSpec<Toml.Document> toml = new SourceSpec<>(
248+
Toml.Document.class, null, PipfileParser.builder(), before,
249+
SourceSpec.ValidateSource.noop,
250+
ctx -> {
251+
}
252+
);
253+
toml.path("Pipfile");
254+
spec.accept(toml);
255+
return toml;
256+
}
257+
258+
public static SourceSpecs pipfile(@Language("toml") @Nullable String before,
259+
@Language("toml") @Nullable String after) {
260+
return pipfile(before, after, s -> {
261+
});
262+
}
263+
264+
public static SourceSpecs pipfile(@Language("toml") @Nullable String before,
265+
@Language("toml") @Nullable String after,
266+
Consumer<SourceSpec<Toml.Document>> spec) {
267+
SourceSpec<Toml.Document> toml = new SourceSpec<>(
268+
Toml.Document.class, null, PipfileParser.builder(), before,
269+
SourceSpec.ValidateSource.noop,
270+
ctx -> {
271+
}
272+
);
273+
toml.path("Pipfile");
274+
toml.after(s -> after);
275+
spec.accept(toml);
276+
return toml;
277+
}
278+
240279
public static SourceSpecs python(@Language("py") @Nullable String before) {
241280
return python(before, s -> {
242281
});

0 commit comments

Comments
 (0)