Skip to content

Commit 60f67c5

Browse files
authored
Add generic types to JavaTemplate (#5280)
* Add generic support to template parameters * Add generics to JavaTemplate * Fix kotlin recipe
1 parent f87d1ad commit 60f67c5

24 files changed

Lines changed: 846 additions & 451 deletions

rewrite-groovy/src/main/java/org/openrewrite/groovy/GroovyTemplate.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.openrewrite.java.tree.J;
2525
import org.openrewrite.java.tree.JavaCoordinates;
2626

27+
import java.util.Collections;
2728
import java.util.HashSet;
2829
import java.util.Set;
2930
import java.util.function.Consumer;
@@ -32,6 +33,7 @@ public class GroovyTemplate extends JavaTemplate {
3233
private GroovyTemplate(boolean contextSensitive, GroovyParser.Builder parser, String code, Set<String> imports, Consumer<String> onAfterVariableSubstitution, Consumer<String> onBeforeParseTemplate) {
3334
super(
3435
code,
36+
Collections.emptySet(),
3537
onAfterVariableSubstitution,
3638
new GroovyTemplateParser(
3739
contextSensitive,

rewrite-groovy/src/main/java/org/openrewrite/groovy/internal/template/GroovySubstitutions.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717

1818
import org.openrewrite.java.internal.template.Substitutions;
1919

20+
import java.util.Collections;
21+
2022
public class GroovySubstitutions extends Substitutions {
2123
public GroovySubstitutions(String code, Object[] parameters) {
22-
super(code, parameters);
24+
super(code, Collections.emptySet(), parameters);
2325
}
2426
}
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/*
2+
* Copyright 2024 the original author or authors.
3+
* <p>
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
* <p>
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
* <p>
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.openrewrite.java;
17+
18+
import org.junit.jupiter.api.Test;
19+
import org.openrewrite.ExecutionContext;
20+
import org.openrewrite.java.tree.Expression;
21+
import org.openrewrite.java.tree.J;
22+
import org.openrewrite.test.RewriteTest;
23+
24+
import java.util.Objects;
25+
26+
import static org.assertj.core.api.Assertions.assertThat;
27+
import static org.openrewrite.java.Assertions.java;
28+
import static org.openrewrite.test.RewriteTest.toRecipe;
29+
30+
class JavaTemplateGenericsTest implements RewriteTest {
31+
32+
@Test
33+
void genericTypes() {
34+
JavaTemplate invalidPrintf = JavaTemplate.builder("System.out.printf(#{any(T)})")
35+
.genericTypes("T")
36+
.build();
37+
JavaTemplate invalidSort = JavaTemplate.builder("java.util.Collections.sort(#{any(java.util.List<T>)}, #{any(C)})")
38+
.genericTypes("T", "C extends java.util.Comparator<?>")
39+
.build();
40+
JavaTemplate validPrintf = JavaTemplate.builder("System.out.printf(#{any(T)})")
41+
.genericTypes("T extends String")
42+
.build();
43+
JavaTemplate validSort = JavaTemplate.builder("java.util.Collections.sort(#{any(java.util.List<T>)}, #{any(C)})")
44+
.genericTypes("T", "C extends java.util.Comparator<? super T>")
45+
.build();
46+
47+
rewriteRun(
48+
spec -> spec.recipe(toRecipe(() -> new JavaVisitor<>() {
49+
@Override
50+
public J visitVariableDeclarations(J.VariableDeclarations multiVariable, ExecutionContext executionContext) {
51+
J.VariableDeclarations.NamedVariable variable = multiVariable.getVariables().get(0);
52+
if ("o".equals(variable.getSimpleName())) {
53+
Expression exp = Objects.requireNonNull(variable.getInitializer());
54+
J.MethodInvocation res1 = invalidPrintf.apply(getCursor(), multiVariable.getCoordinates().replace(), exp);
55+
assertThat(res1.getMethodType()).isNull();
56+
J.MethodInvocation res2 = invalidSort.apply(getCursor(), multiVariable.getCoordinates().replace(), exp, exp);
57+
assertThat(res2.getMethodType()).isNull();
58+
J.MethodInvocation res3 = validPrintf.apply(getCursor(), multiVariable.getCoordinates().replace(), exp);
59+
assertThat(res3.getMethodType()).isNotNull();
60+
J.MethodInvocation res4 = validSort.apply(getCursor(), multiVariable.getCoordinates().replace(), exp, exp);
61+
assertThat(res4.getMethodType()).isNotNull();
62+
return res3;
63+
}
64+
return super.visitVariableDeclarations(multiVariable, executionContext);
65+
}
66+
})),
67+
java(
68+
"""
69+
class Test {
70+
void test() {
71+
Object o = any();
72+
}
73+
static native <T> T any();
74+
}
75+
""",
76+
"""
77+
class Test {
78+
void test() {
79+
System.out.printf(any());
80+
}
81+
static native <T> T any();
82+
}
83+
"""
84+
)
85+
);
86+
}
87+
}

rewrite-java/src/main/antlr/TemplateParameterLexer.g4

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,10 @@ RBRACK : '>';
1010
WILDCARD : '?';
1111
LSBRACK : '[';
1212
RSBRACK : ']';
13+
AND : '&';
1314

14-
Variance
15-
: 'extends'
16-
| 'super'
17-
;
15+
Extends : 'extends';
16+
Super : 'super';
1817

1918
FullyQualifiedName
2019
: 'boolean'

rewrite-java/src/main/antlr/TemplateParameterParser.g4

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ matcherPattern
77
| parameterName
88
;
99

10+
genericPattern
11+
: genericName (Extends (type AND)* type)?
12+
;
13+
1014
typedPattern
1115
: (parameterName COLON)? patternType
1216
;
@@ -25,7 +29,8 @@ typeParameter
2529
;
2630

2731
variance
28-
: WILDCARD Variance
32+
: WILDCARD Extends
33+
| WILDCARD Super
2934
;
3035

3136
typeArray
@@ -36,6 +41,10 @@ parameterName
3641
: Identifier
3742
;
3843

44+
genericName
45+
: Identifier
46+
;
47+
3948
typeName
4049
: FullyQualifiedName
4150
| Identifier

rewrite-java/src/main/java/org/openrewrite/java/JavaTemplate.java

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
*/
1616
package org.openrewrite.java;
1717

18+
import java.util.Collections;
1819
import lombok.Getter;
1920
import lombok.Value;
2021
import lombok.experimental.NonFinal;
@@ -80,21 +81,24 @@ protected static Path getTemplateClasspathDir() {
8081

8182
@Getter
8283
private final String code;
84+
@Getter
85+
private final Set<String> genericTypes;
8386

8487
private final Consumer<String> onAfterVariableSubstitution;
8588
private final JavaTemplateParser templateParser;
8689

8790
private JavaTemplate(boolean contextSensitive, JavaParser.Builder<?, ?> parser, String code, Set<String> imports,
88-
Consumer<String> onAfterVariableSubstitution, Consumer<String> onBeforeParseTemplate) {
89-
this(code, onAfterVariableSubstitution, new JavaTemplateParser(contextSensitive, augmentClasspath(parser), onAfterVariableSubstitution, onBeforeParseTemplate, imports));
91+
Set<String> genericTypes, Consumer<String> onAfterVariableSubstitution, Consumer<String> onBeforeParseTemplate) {
92+
this(code, genericTypes, onAfterVariableSubstitution, new JavaTemplateParser(contextSensitive, augmentClasspath(parser), onAfterVariableSubstitution, onBeforeParseTemplate, imports));
9093
}
9194

9295
private static JavaParser.Builder<?, ?> augmentClasspath(JavaParser.Builder<?, ?> parserBuilder) {
9396
return parserBuilder.addClasspathEntry(getTemplateClasspathDir());
9497
}
9598

96-
protected JavaTemplate(String code, Consumer<String> onAfterVariableSubstitution, JavaTemplateParser templateParser) {
99+
protected JavaTemplate(String code, Set<String> genericTypes, Consumer<String> onAfterVariableSubstitution, JavaTemplateParser templateParser) {
97100
this.code = code;
101+
this.genericTypes = genericTypes;
98102
this.onAfterVariableSubstitution = onAfterVariableSubstitution;
99103
this.templateParser = templateParser;
100104
}
@@ -121,7 +125,7 @@ public <J2 extends J> J2 apply(Cursor scope, JavaCoordinates coordinates, Object
121125
}
122126

123127
protected Substitutions substitutions(Object[] parameters) {
124-
return new Substitutions(code, parameters);
128+
return new Substitutions(code, genericTypes, parameters);
125129
}
126130

127131
@Incubating(since = "8.0.0")
@@ -174,6 +178,7 @@ public static class Builder {
174178

175179
private final String code;
176180
private final Set<String> imports = new HashSet<>();
181+
private final Set<String> genericTypes = new HashSet<>();
177182

178183
private boolean contextSensitive;
179184

@@ -223,6 +228,11 @@ public Builder staticImports(String... fullyQualifiedMemberTypeNames) {
223228
return this;
224229
}
225230

231+
public Builder genericTypes(String... genericTypes) {
232+
Collections.addAll(this.genericTypes, genericTypes);
233+
return this;
234+
}
235+
226236
private void validateImport(String typeName) {
227237
if (StringUtils.isBlank(typeName)) {
228238
throw new IllegalArgumentException("Imports must not be blank");
@@ -249,7 +259,7 @@ public Builder doBeforeParseTemplate(Consumer<String> beforeParseTemplate) {
249259
}
250260

251261
public JavaTemplate build() {
252-
return new JavaTemplate(contextSensitive, parser.clone(), code, imports,
262+
return new JavaTemplate(contextSensitive, parser.clone(), code, imports, genericTypes,
253263
onAfterVariableSubstitution, onBeforeParseTemplate);
254264
}
255265
}

rewrite-java/src/main/java/org/openrewrite/java/JavaTemplateSemanticallyEqual.java

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ static TemplateMatchResult matchesTemplate(JavaTemplate template, Cursor input)
5151
throw new IllegalArgumentException("Only expressions and statements can be matched against a template: " + input.getClass());
5252
}
5353

54-
J[] parameters = createTemplateParameters(template.getCode());
54+
J[] parameters = createTemplateParameters(template.getCode(), template.getGenericTypes());
5555
try {
5656
J templateTree = template.apply(input, coordinates, (Object[]) parameters);
5757
return matchTemplate(templateTree, input);
@@ -61,10 +61,11 @@ static TemplateMatchResult matchesTemplate(JavaTemplate template, Cursor input)
6161
}
6262
}
6363

64-
private static J[] createTemplateParameters(String code) {
64+
private static J[] createTemplateParameters(String code, Set<String> genericTypes) {
6565
PropertyPlaceholderHelper propertyPlaceholderHelper = new PropertyPlaceholderHelper(
6666
"#{", "}", null);
6767

68+
Map<String, JavaType.GenericTypeVariable> generics = TypeParameter.parseGenericTypes(genericTypes);
6869
List<J> parameters = new ArrayList<>();
6970
String substituted = code;
7071
Map<String, String> typedPatternByName = new HashMap<>();
@@ -73,24 +74,7 @@ private static J[] createTemplateParameters(String code) {
7374
substituted = propertyPlaceholderHelper.replacePlaceholders(substituted, key -> {
7475
String s;
7576
if (!key.isEmpty()) {
76-
BaseErrorListener errorListener = new BaseErrorListener() {
77-
@Override
78-
public void syntaxError(Recognizer<?, ?> recognizer, Object offendingSymbol,
79-
int line, int charPositionInLine, String msg, RecognitionException e) {
80-
throw new IllegalArgumentException(
81-
String.format("Syntax error at line %d:%d %s.", line, charPositionInLine, msg), e);
82-
}
83-
};
84-
85-
TemplateParameterLexer lexer = new TemplateParameterLexer(CharStreams.fromString(key));
86-
lexer.removeErrorListeners();
87-
lexer.addErrorListener(errorListener);
88-
89-
TemplateParameterParser parser = new TemplateParameterParser(new CommonTokenStream(lexer));
90-
parser.removeErrorListeners();
91-
parser.addErrorListener(errorListener);
92-
93-
TemplateParameterParser.MatcherPatternContext ctx = parser.matcherPattern();
77+
TemplateParameterParser.MatcherPatternContext ctx = TypeParameter.parser(key).matcherPattern();
9478
if (ctx.typedPattern() == null) {
9579
String paramName = ctx.parameterName().Identifier().getText();
9680
s = typedPatternByName.get(paramName);
@@ -99,7 +83,7 @@ public void syntaxError(Recognizer<?, ?> recognizer, Object offendingSymbol,
9983
}
10084
} else {
10185
TypedPatternContext typedPattern = ctx.typedPattern();
102-
JavaType type = typedParameter(key, typedPattern);
86+
JavaType type = typedParameter(key, typedPattern, generics);
10387
s = TypeUtils.toString(type);
10488

10589
String name = null;
@@ -126,12 +110,12 @@ public void syntaxError(Recognizer<?, ?> recognizer, Object offendingSymbol,
126110
return parameters.toArray(new J[0]);
127111
}
128112

129-
private static JavaType typedParameter(String key, TypedPatternContext typedPattern) {
113+
private static JavaType typedParameter(String key, TypedPatternContext typedPattern, Map<String, JavaType.GenericTypeVariable> generics) {
130114
String matcherName = typedPattern.patternType().matcherName().Identifier().getText();
131115
if ("any".equals(matcherName)) {
132-
return TypeParameter.toFullyQualifiedName(typedPattern.patternType().type());
116+
return TypeParameter.toJavaType(typedPattern.patternType().type(), generics);
133117
} else if ("anyArray".equals(matcherName)) {
134-
return new JavaType.Array(null, TypeParameter.toFullyQualifiedName(typedPattern.patternType().type()), null);
118+
return new JavaType.Array(null, TypeParameter.toJavaType(typedPattern.patternType().type(), generics), null);
135119
} else {
136120
throw new IllegalArgumentException("Unsupported template matcher '" + key + "'");
137121
}

0 commit comments

Comments
 (0)