Skip to content

Commit 251dd8a

Browse files
committed
Scala parser: support parameterless methods, annotated methods, remove while fallback
- Parameterless methods (def name: Type) now map to J.MethodDeclaration with OmitBraces marker on empty param container — printer omits () - Annotated methods (@deprecated def ...) now map to J.MethodDeclaration with annotations extracted and def-modifier carrying annotation-to-def spacing - Remove while loop fallback — while loops in method bodies handled by parser - Remove AppliedTypeTree fallback (already fixed in #7260) - 3 remaining fallbacks: procedure syntax, nested braces, function type params - 5 test failures: 3 complex annotations (with args), 2 while body block whitespace
1 parent 9258932 commit 251dd8a

2 files changed

Lines changed: 54 additions & 37 deletions

File tree

rewrite-scala/src/main/java/org/openrewrite/scala/ScalaPrinter.java

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -287,24 +287,36 @@ public J visitCase(J.Case case_, PrintOutputCapture<P> p) {
287287
public J visitMethodDeclaration(J.MethodDeclaration method, PrintOutputCapture<P> p) {
288288
beforeSyntax(method, Space.Location.METHOD_DECLARATION_PREFIX, p);
289289
visit(method.getLeadingAnnotations(), p);
290+
boolean defAlreadyPrinted = false;
290291
for (J.Modifier m : method.getModifiers()) {
291-
visit(m, p);
292+
if ("def".equals(m.getKeyword()) && m.getType() == J.Modifier.Type.LanguageExtension) {
293+
visitSpace(m.getPrefix(), Space.Location.MODIFIER_PREFIX, p);
294+
p.append("def");
295+
defAlreadyPrinted = true;
296+
} else {
297+
visit(m, p);
298+
}
292299
}
293-
294-
if (!method.getModifiers().isEmpty()) {
295-
p.append(" ");
300+
if (!defAlreadyPrinted) {
301+
if (!method.getModifiers().isEmpty()) {
302+
p.append(" ");
303+
}
304+
p.append("def");
296305
}
297-
p.append("def");
298306
visit(method.getName(), p);
299307

300308
if (method.getPadding().getTypeParameters() != null) {
301309
visit(method.getPadding().getTypeParameters(), p);
302310
}
303311

304-
// Print parameters (name: Type)
312+
// Print parameters — skip parens for parameterless methods (marked with OmitBraces)
305313
JContainer<Statement> params = method.getPadding().getParameters();
306-
visitSpace(params.getBefore(), Space.Location.METHOD_DECLARATION_PARAMETERS, p);
307-
p.append('(');
314+
boolean hasParens = !params.getMarkers().findFirst(
315+
org.openrewrite.scala.marker.OmitBraces.class).isPresent();
316+
if (hasParens) {
317+
visitSpace(params.getBefore(), Space.Location.METHOD_DECLARATION_PARAMETERS, p);
318+
p.append('(');
319+
}
308320
List<JRightPadded<Statement>> paramList = params.getPadding().getElements();
309321
for (int i = 0; i < paramList.size(); i++) {
310322
JRightPadded<Statement> param = paramList.get(i);
@@ -342,7 +354,9 @@ public J visitMethodDeclaration(J.MethodDeclaration method, PrintOutputCapture<P
342354
p.append(',');
343355
}
344356
}
345-
p.append(')');
357+
if (hasParens) {
358+
p.append(')');
359+
}
346360

347361
if (method.getReturnTypeExpression() != null) {
348362
p.append(':');

rewrite-scala/src/main/scala/org/openrewrite/scala/internal/ScalaTreeVisitor.scala

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4482,58 +4482,43 @@ class ScalaTreeVisitor(
44824482
}
44834483

44844484
private def visitDefDef(dd: Trees.DefDef[?]): J = {
4485-
// Fall back to J.Unknown for cases we can't handle yet
4486-
val hasAnnotations = dd.mods != null && dd.mods.annotations.nonEmpty
4487-
if (hasAnnotations) {
4488-
return visitUnknown(dd)
4489-
}
4490-
4491-
// Check for cases we need to fall back to J.Unknown
44924485
val adjustedStart = Math.max(0, dd.span.start - offsetAdjustment)
44934486
val adjustedEnd = Math.max(0, dd.span.end - offsetAdjustment)
4487+
4488+
// Detect procedure syntax (no = before body) and nested braces — compiler flattens these
44944489
if (adjustedStart < adjustedEnd && adjustedEnd <= source.length) {
44954490
val defSource = source.substring(adjustedStart, adjustedEnd)
4496-
// Procedure syntax (no = before body)
44974491
val braceIdx = defSource.indexOf('{')
44984492
val equalsIdx = defSource.indexOf('=')
44994493
if (braceIdx >= 0 && (equalsIdx < 0 || equalsIdx > braceIdx)) {
4500-
return visitUnknown(dd)
4494+
return visitUnknown(dd) // Procedure syntax
45014495
}
4502-
// Nested braces after = (compiler flattens these, breaking round-trip)
45034496
if (equalsIdx >= 0 && braceIdx >= 0) {
45044497
val afterFirstBrace = defSource.indexOf('{', braceIdx + 1)
4505-
if (afterFirstBrace >= 0) {
4506-
val between = defSource.substring(braceIdx + 1, afterFirstBrace).trim()
4507-
if (between.isEmpty) {
4508-
return visitUnknown(dd)
4509-
}
4498+
if (afterFirstBrace >= 0 && defSource.substring(braceIdx + 1, afterFirstBrace).trim().isEmpty) {
4499+
return visitUnknown(dd) // Nested braces
45104500
}
45114501
}
4512-
// Fall back for methods with while/for loops (block whitespace issues)
4513-
if (defSource.contains("while ") || defSource.contains("while(")) {
4514-
return visitUnknown(dd)
4515-
}
45164502
}
45174503

4518-
// Fall back for parameterless methods like `def name: Type` (no parens in source)
4504+
// Detect parameterless methods (def name: Type = ...) — no parens in source
4505+
var hasParensInSource = true
45194506
if (adjustedStart < adjustedEnd && adjustedEnd <= source.length) {
45204507
val defSource = source.substring(adjustedStart, adjustedEnd)
45214508
val defIdx = defSource.indexOf("def ")
45224509
if (defIdx >= 0) {
45234510
val afterDef = defSource.substring(defIdx + 4)
4524-
// Find end of name (first non-alphanumeric char)
45254511
val nameEnd = afterDef.indexWhere(c => !c.isLetterOrDigit && c != '_')
45264512
if (nameEnd >= 0) {
45274513
val afterName = afterDef.substring(nameEnd).trim()
4528-
// If the next meaningful char after name is : or = (not (), the method has no parens
45294514
if (afterName.startsWith(":") || afterName.startsWith("=")) {
4530-
return visitUnknown(dd)
4515+
hasParensInSource = false
45314516
}
45324517
}
45334518
}
45344519
}
45354520

4536-
// Fall back for methods with function types or annotated parameters
4521+
// Fall back for function types (Int => Int) and annotated parameters (@unchecked)
45374522
val hasUnsupportedParams = dd.paramss.exists(_.exists {
45384523
case vd: Trees.ValDef[?] =>
45394524
vd.tpt.isInstanceOf[untpd.Function] ||
@@ -4546,18 +4531,28 @@ class ScalaTreeVisitor(
45464531

45474532
val savedCursor = cursor
45484533
try {
4549-
visitDefDefImpl(dd)
4534+
visitDefDefImpl(dd, hasParensInSource)
45504535
} catch {
45514536
case _: Exception =>
45524537
cursor = savedCursor
45534538
visitUnknown(dd)
45544539
}
45554540
}
45564541

4557-
private def visitDefDefImpl(dd: Trees.DefDef[?]): J.MethodDeclaration = {
4542+
private def visitDefDefImpl(dd: Trees.DefDef[?], hasParensInSource: Boolean = true): J.MethodDeclaration = {
45584543
val leadingAnnotations = new util.ArrayList[J.Annotation]()
4544+
val hasAnnotations = dd.mods != null && dd.mods.annotations.nonEmpty
45594545
val prefix = extractPrefix(dd.span)
45604546

4547+
if (hasAnnotations) {
4548+
for (annot <- dd.mods.annotations) {
4549+
visitTree(annot) match {
4550+
case ann: J.Annotation => leadingAnnotations.add(ann)
4551+
case _ =>
4552+
}
4553+
}
4554+
}
4555+
45614556
val adjustedEnd = Math.max(0, dd.span.end - offsetAdjustment)
45624557
var modifierText = ""
45634558
var defIndex = -1
@@ -4573,6 +4568,12 @@ class ScalaTreeVisitor(
45734568

45744569
val (modifiers, _) = extractModifiersFromText(dd.mods, modifierText)
45754570

4571+
// If annotations but no modifiers, capture the space between annotations and "def"
4572+
if (hasAnnotations && modifiers.isEmpty() && modifierText.nonEmpty && modifierText.trim().isEmpty) {
4573+
modifiers.add(new J.Modifier(Tree.randomId(), Space.format(modifierText), Markers.EMPTY,
4574+
"def", J.Modifier.Type.LanguageExtension, Collections.emptyList()))
4575+
}
4576+
45764577
val defKeywordPos = if (defIndex >= 0) cursor + defIndex + "def".length else cursor
45774578
cursor = defKeywordPos
45784579

@@ -4686,7 +4687,7 @@ class ScalaTreeVisitor(
46864687
}
46874688

46884689
JContainer.build(parenSpace, jParams, Markers.EMPTY)
4689-
} else if (valueParamLists.nonEmpty) {
4690+
} else if (valueParamLists.nonEmpty && hasParensInSource) {
46904691
// Empty parameter list ()
46914692
val searchEnd = Math.min(cursor + 50, source.length)
46924693
val searchText = source.substring(cursor, searchEnd)
@@ -4700,7 +4701,9 @@ class ScalaTreeVisitor(
47004701
}
47014702
JContainer.build(parenSpace, new util.ArrayList[JRightPadded[Statement]](), Markers.EMPTY)
47024703
} else {
4703-
JContainer.empty[Statement]()
4704+
// Parameterless method — mark so printer omits ()
4705+
JContainer.build(Space.EMPTY, new util.ArrayList[JRightPadded[Statement]](),
4706+
Markers.build(Collections.singletonList(new org.openrewrite.scala.marker.OmitBraces(Tree.randomId()))))
47044707
}
47054708

47064709
// Handle return type `: ReturnType` — only if explicitly written in source

0 commit comments

Comments
 (0)