Skip to content

Commit 022f9d2

Browse files
committed
Scala parser: support parameterless methods, annotated methods, remove while fallback
- Parameterless methods (def name: Type) now map to J.MethodDeclaration with OmitBraces marker on empty param container — printer omits () - Annotated methods (@deprecated def ...) now map to J.MethodDeclaration with annotations extracted and def-modifier carrying annotation-to-def spacing - Remove while loop fallback — while loops in method bodies handled by parser - Remove AppliedTypeTree fallback (already fixed in #7260) - 3 remaining fallbacks: procedure syntax, nested braces, function type params - 5 test failures: 3 complex annotations (with args), 2 while body block whitespace
1 parent 9258932 commit 022f9d2

3 files changed

Lines changed: 162 additions & 103 deletions

File tree

rewrite-scala/src/main/java/org/openrewrite/scala/ScalaPrinter.java

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -287,24 +287,36 @@ public J visitCase(J.Case case_, PrintOutputCapture<P> p) {
287287
public J visitMethodDeclaration(J.MethodDeclaration method, PrintOutputCapture<P> p) {
288288
beforeSyntax(method, Space.Location.METHOD_DECLARATION_PREFIX, p);
289289
visit(method.getLeadingAnnotations(), p);
290+
boolean defAlreadyPrinted = false;
290291
for (J.Modifier m : method.getModifiers()) {
291-
visit(m, p);
292+
if ("def".equals(m.getKeyword()) && m.getType() == J.Modifier.Type.LanguageExtension) {
293+
visitSpace(m.getPrefix(), Space.Location.MODIFIER_PREFIX, p);
294+
p.append("def");
295+
defAlreadyPrinted = true;
296+
} else {
297+
visit(m, p);
298+
}
292299
}
293-
294-
if (!method.getModifiers().isEmpty()) {
295-
p.append(" ");
300+
if (!defAlreadyPrinted) {
301+
if (!method.getModifiers().isEmpty()) {
302+
p.append(" ");
303+
}
304+
p.append("def");
296305
}
297-
p.append("def");
298306
visit(method.getName(), p);
299307

300308
if (method.getPadding().getTypeParameters() != null) {
301309
visit(method.getPadding().getTypeParameters(), p);
302310
}
303311

304-
// Print parameters (name: Type)
312+
// Print parameters — skip parens for parameterless methods (marked with OmitBraces)
305313
JContainer<Statement> params = method.getPadding().getParameters();
306-
visitSpace(params.getBefore(), Space.Location.METHOD_DECLARATION_PARAMETERS, p);
307-
p.append('(');
314+
boolean hasParens = !params.getMarkers().findFirst(
315+
org.openrewrite.scala.marker.OmitBraces.class).isPresent();
316+
if (hasParens) {
317+
visitSpace(params.getBefore(), Space.Location.METHOD_DECLARATION_PARAMETERS, p);
318+
p.append('(');
319+
}
308320
List<JRightPadded<Statement>> paramList = params.getPadding().getElements();
309321
for (int i = 0; i < paramList.size(); i++) {
310322
JRightPadded<Statement> param = paramList.get(i);
@@ -342,7 +354,9 @@ public J visitMethodDeclaration(J.MethodDeclaration method, PrintOutputCapture<P
342354
p.append(',');
343355
}
344356
}
345-
p.append(')');
357+
if (hasParens) {
358+
p.append(')');
359+
}
346360

347361
if (method.getReturnTypeExpression() != null) {
348362
p.append(':');

rewrite-scala/src/main/scala/org/openrewrite/scala/internal/ScalaTreeVisitor.scala

Lines changed: 127 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -322,26 +322,25 @@ class ScalaTreeVisitor(
322322
// Check if this is an annotation pattern (will be handled specially when called from visitClassDef)
323323
// Annotations look like Apply(Select(New(...), <init>), args) with @ in source
324324
// Constructor calls look the same but have "new" in source
325-
val isAnnotationPattern = app.fun match {
326-
case sel: Trees.Select[?] if sel.name.toString == "<init>" =>
327-
sel.qualifier match {
328-
case newNode: Trees.New[?] =>
329-
// Check if the source has @ before the type (annotation) or "new" (constructor)
330-
if (app.span.exists) {
331-
val adjustedStart = Math.max(0, app.span.start - offsetAdjustment)
332-
val adjustedEnd = Math.max(0, app.span.end - offsetAdjustment)
333-
if (adjustedStart < adjustedEnd && adjustedEnd <= source.length) {
334-
val sourceText = source.substring(adjustedStart, adjustedEnd)
335-
sourceText.trim.startsWith("@")
336-
} else {
337-
false
338-
}
339-
} else {
340-
false
341-
}
342-
case _ => false
343-
}
344-
case _ => false
325+
val isAnnotationPattern = {
326+
def checkAnnotation(fun: Trees.Tree[?]): Boolean = fun match {
327+
case sel: Trees.Select[?] if sel.name.toString == "<init>" =>
328+
sel.qualifier match {
329+
case _: Trees.New[?] =>
330+
if (app.span.exists) {
331+
val adjustedStart = Math.max(0, app.span.start - offsetAdjustment)
332+
val adjustedEnd = Math.max(0, app.span.end - offsetAdjustment)
333+
if (adjustedStart < adjustedEnd && adjustedEnd <= source.length) {
334+
source.substring(adjustedStart, adjustedEnd).trim.startsWith("@")
335+
} else false
336+
} else false
337+
case _ => false
338+
}
339+
// Handle TypeApply wrapper for annotations with type params like @throws[Exception]
340+
case ta: Trees.TypeApply[?] => checkAnnotation(ta.fun)
341+
case _ => false
342+
}
343+
checkAnnotation(app.fun)
345344
}
346345

347346
if (isAnnotationPattern) {
@@ -406,19 +405,24 @@ class ScalaTreeVisitor(
406405

407406

408407
// Extract the annotation type and arguments
409-
val (annotationType, args) = app.fun match {
408+
// Handle both direct Select(New(...), <init>) and TypeApply(Select(New(...), <init>), types)
409+
def extractAnnotationType(fun: Trees.Tree[?]): Option[Trees.Ident[?]] = fun match {
410410
case sel: Trees.Select[?] if sel.name.toString == "<init>" =>
411411
sel.qualifier match {
412-
case newTree: Trees.New[?] =>
413-
val typeIdent = newTree.tpt match {
414-
case id: Trees.Ident[?] => id
415-
case _ => return visitUnknown(app)
416-
}
417-
(typeIdent, app.args)
418-
case _ => return visitUnknown(app)
412+
case newTree: Trees.New[?] => newTree.tpt match {
413+
case id: Trees.Ident[?] => Some(id)
414+
case _ => None
415+
}
416+
case _ => None
419417
}
420-
case _ => return visitUnknown(app)
418+
case ta: Trees.TypeApply[?] => extractAnnotationType(ta.fun)
419+
case _ => None
421420
}
421+
val annotationType = extractAnnotationType(app.fun) match {
422+
case Some(id) => id
423+
case None => return visitUnknown(app)
424+
}
425+
val args = app.args
422426

423427
// Create the annotation type
424428
val annotTypeTree = new J.Identifier(
@@ -455,7 +459,20 @@ class ScalaTreeVisitor(
455459
case e: Expression => e
456460
case _ => visitUnknown(arg)
457461
}
458-
argList.add(JRightPadded.build(expr.asInstanceOf[Expression]).withAfter(Space.EMPTY))
462+
// Skip trailing comma between arguments
463+
val afterSpace = if (i < args.size - 1) {
464+
updateCursor(arg.span.end)
465+
if (cursor < source.length) {
466+
val after = source.substring(cursor, Math.min(cursor + 20, source.length))
467+
val commaIdx = after.indexOf(',')
468+
if (commaIdx >= 0) {
469+
val sp = Space.format(after.substring(0, commaIdx))
470+
cursor = cursor + commaIdx + 1
471+
sp
472+
} else Space.EMPTY
473+
} else Space.EMPTY
474+
} else Space.EMPTY
475+
argList.add(JRightPadded.build(expr.asInstanceOf[Expression]).withAfter(afterSpace))
459476
}
460477
JContainer.build(
461478
Space.EMPTY,
@@ -3521,37 +3538,36 @@ class ScalaTreeVisitor(
35213538
*/
35223539

35233540
private def visitBlock(block: Trees.Block[?]): J.Block = {
3524-
val prefix = extractPrefix(block.span)
3525-
3526-
// Move cursor past the opening brace — but only if the block starts with one
35273541
val blockStart = Math.max(0, block.span.start - offsetAdjustment)
35283542
val blockEndAdj = Math.max(0, block.span.end - offsetAdjustment)
3529-
val blockStartsWithBrace = blockStart < source.length && source.charAt(blockStart) == '{'
3530-
if (blockStartsWithBrace) {
3543+
3544+
// Find the opening brace — it may be at blockStart, before blockStart (while/for body),
3545+
// or between cursor and the first child
3546+
val savedCursorBeforePrefix = cursor
3547+
val prefix = extractPrefix(block.span) // advances cursor to blockStart
3548+
3549+
// Now find and advance past '{'
3550+
if (blockStart < source.length && source.charAt(blockStart) == '{') {
3551+
// Brace at block span start
35313552
cursor = blockStart + 1
3532-
} else {
3533-
// Check if there's a brace between cursor and the first statement
3534-
val firstChildStart = if (block.stats.nonEmpty) {
3535-
Math.max(0, block.stats.head.span.start - offsetAdjustment)
3536-
} else if (!block.expr.isEmpty) {
3537-
Math.max(0, block.expr.span.start - offsetAdjustment)
3538-
} else blockStart
3539-
3540-
if (cursor < firstChildStart && firstChildStart <= source.length) {
3541-
val between = source.substring(cursor, firstChildStart)
3542-
val braceIdx = between.indexOf('{')
3543-
if (braceIdx >= 0) {
3544-
cursor = cursor + braceIdx + 1
3545-
}
3553+
} else if (savedCursorBeforePrefix < blockStart) {
3554+
// Check if there's a brace BEFORE the block span (e.g., while body)
3555+
val beforeSpan = source.substring(savedCursorBeforePrefix, blockStart)
3556+
val braceIdx = beforeSpan.indexOf('{')
3557+
if (braceIdx >= 0) {
3558+
// The brace is before the block span — cursor should be past it
3559+
cursor = savedCursorBeforePrefix + braceIdx + 1
35463560
}
35473561
}
35483562

35493563
val statements = new util.ArrayList[JRightPadded[Statement]]()
35503564

3565+
35513566
// Visit all statements in the block
35523567
for (i <- block.stats.indices) {
35533568
val stat = block.stats(i)
3554-
visitTree(stat) match {
3569+
val visitResult = visitTree(stat)
3570+
visitResult match {
35553571
case null => // Skip null statements (e.g., package declarations)
35563572
case stmt: Statement =>
35573573
// Extract trailing space after this statement
@@ -4482,58 +4498,43 @@ class ScalaTreeVisitor(
44824498
}
44834499

44844500
private def visitDefDef(dd: Trees.DefDef[?]): J = {
4485-
// Fall back to J.Unknown for cases we can't handle yet
4486-
val hasAnnotations = dd.mods != null && dd.mods.annotations.nonEmpty
4487-
if (hasAnnotations) {
4488-
return visitUnknown(dd)
4489-
}
4490-
4491-
// Check for cases we need to fall back to J.Unknown
44924501
val adjustedStart = Math.max(0, dd.span.start - offsetAdjustment)
44934502
val adjustedEnd = Math.max(0, dd.span.end - offsetAdjustment)
4503+
4504+
// Detect procedure syntax (no = before body) and nested braces — compiler flattens these
44944505
if (adjustedStart < adjustedEnd && adjustedEnd <= source.length) {
44954506
val defSource = source.substring(adjustedStart, adjustedEnd)
4496-
// Procedure syntax (no = before body)
44974507
val braceIdx = defSource.indexOf('{')
44984508
val equalsIdx = defSource.indexOf('=')
44994509
if (braceIdx >= 0 && (equalsIdx < 0 || equalsIdx > braceIdx)) {
4500-
return visitUnknown(dd)
4510+
return visitUnknown(dd) // Procedure syntax
45014511
}
4502-
// Nested braces after = (compiler flattens these, breaking round-trip)
45034512
if (equalsIdx >= 0 && braceIdx >= 0) {
45044513
val afterFirstBrace = defSource.indexOf('{', braceIdx + 1)
4505-
if (afterFirstBrace >= 0) {
4506-
val between = defSource.substring(braceIdx + 1, afterFirstBrace).trim()
4507-
if (between.isEmpty) {
4508-
return visitUnknown(dd)
4509-
}
4514+
if (afterFirstBrace >= 0 && defSource.substring(braceIdx + 1, afterFirstBrace).trim().isEmpty) {
4515+
return visitUnknown(dd) // Nested braces
45104516
}
45114517
}
4512-
// Fall back for methods with while/for loops (block whitespace issues)
4513-
if (defSource.contains("while ") || defSource.contains("while(")) {
4514-
return visitUnknown(dd)
4515-
}
45164518
}
45174519

4518-
// Fall back for parameterless methods like `def name: Type` (no parens in source)
4520+
// Detect parameterless methods (def name: Type = ...) — no parens in source
4521+
var hasParensInSource = true
45194522
if (adjustedStart < adjustedEnd && adjustedEnd <= source.length) {
45204523
val defSource = source.substring(adjustedStart, adjustedEnd)
45214524
val defIdx = defSource.indexOf("def ")
45224525
if (defIdx >= 0) {
45234526
val afterDef = defSource.substring(defIdx + 4)
4524-
// Find end of name (first non-alphanumeric char)
45254527
val nameEnd = afterDef.indexWhere(c => !c.isLetterOrDigit && c != '_')
45264528
if (nameEnd >= 0) {
45274529
val afterName = afterDef.substring(nameEnd).trim()
4528-
// If the next meaningful char after name is : or = (not (), the method has no parens
45294530
if (afterName.startsWith(":") || afterName.startsWith("=")) {
4530-
return visitUnknown(dd)
4531+
hasParensInSource = false
45314532
}
45324533
}
45334534
}
45344535
}
45354536

4536-
// Fall back for methods with function types or annotated parameters
4537+
// Fall back for function types (Int => Int) and annotated parameters (@unchecked)
45374538
val hasUnsupportedParams = dd.paramss.exists(_.exists {
45384539
case vd: Trees.ValDef[?] =>
45394540
vd.tpt.isInstanceOf[untpd.Function] ||
@@ -4546,18 +4547,58 @@ class ScalaTreeVisitor(
45464547

45474548
val savedCursor = cursor
45484549
try {
4549-
visitDefDefImpl(dd)
4550+
visitDefDefImpl(dd, hasParensInSource)
45504551
} catch {
45514552
case _: Exception =>
45524553
cursor = savedCursor
45534554
visitUnknown(dd)
45544555
}
45554556
}
45564557

4557-
private def visitDefDefImpl(dd: Trees.DefDef[?]): J.MethodDeclaration = {
4558+
private def visitDefDefImpl(dd: Trees.DefDef[?], hasParensInSource: Boolean = true): J.MethodDeclaration = {
45584559
val leadingAnnotations = new util.ArrayList[J.Annotation]()
4560+
val hasAnnotations = dd.mods != null && dd.mods.annotations.nonEmpty
45594561
val prefix = extractPrefix(dd.span)
45604562

4563+
if (hasAnnotations) {
4564+
for (annot <- dd.mods.annotations) {
4565+
// First try the proper AST visitor
4566+
val savedAnnotCursor = cursor
4567+
val annotResult = try { visitTree(annot) } catch { case _: Exception => cursor = savedAnnotCursor; null }
4568+
4569+
annotResult match {
4570+
case ann: J.Annotation => leadingAnnotations.add(ann)
4571+
case _ =>
4572+
// Fallback: build annotation from source text
4573+
cursor = savedAnnotCursor // Reset cursor — visitTree may have advanced it incorrectly
4574+
if (annot.span.exists) {
4575+
val annotStart = Math.max(0, annot.span.start - offsetAdjustment)
4576+
val annotEnd = Math.max(0, annot.span.end - offsetAdjustment)
4577+
// Find the @ before the annotation span (it's in the source but might not be in the span)
4578+
val searchStart = Math.max(0, annotStart - 5)
4579+
val atIdx = if (searchStart < source.length) source.lastIndexOf('@', annotStart) else -1
4580+
val fullStart = if (atIdx >= 0 && atIdx >= searchStart) atIdx else annotStart
4581+
4582+
// Prefix = space from cursor to @ position
4583+
val annotPrefix = if (cursor < fullStart && fullStart <= source.length) {
4584+
Space.format(source.substring(cursor, fullStart))
4585+
} else Space.EMPTY
4586+
4587+
// Source = everything from @ to annotation end
4588+
val annotSource = if (fullStart < annotEnd && annotEnd <= source.length) {
4589+
source.substring(fullStart, annotEnd)
4590+
} else ""
4591+
val annotName = if (annotSource.startsWith("@")) annotSource.substring(1) else annotSource
4592+
4593+
val annotId = new J.Identifier(Tree.randomId(), Space.EMPTY, Markers.EMPTY,
4594+
Collections.emptyList(), annotName, null, null)
4595+
leadingAnnotations.add(new J.Annotation(Tree.randomId(), annotPrefix, Markers.EMPTY, annotId, null))
4596+
cursor = annotEnd
4597+
}
4598+
}
4599+
}
4600+
}
4601+
45614602
val adjustedEnd = Math.max(0, dd.span.end - offsetAdjustment)
45624603
var modifierText = ""
45634604
var defIndex = -1
@@ -4573,6 +4614,12 @@ class ScalaTreeVisitor(
45734614

45744615
val (modifiers, _) = extractModifiersFromText(dd.mods, modifierText)
45754616

4617+
// If annotations but no modifiers, capture the space between annotations and "def"
4618+
if (hasAnnotations && modifiers.isEmpty() && modifierText.nonEmpty && modifierText.trim().isEmpty) {
4619+
modifiers.add(new J.Modifier(Tree.randomId(), Space.format(modifierText), Markers.EMPTY,
4620+
"def", J.Modifier.Type.LanguageExtension, Collections.emptyList()))
4621+
}
4622+
45764623
val defKeywordPos = if (defIndex >= 0) cursor + defIndex + "def".length else cursor
45774624
cursor = defKeywordPos
45784625

@@ -4686,7 +4733,7 @@ class ScalaTreeVisitor(
46864733
}
46874734

46884735
JContainer.build(parenSpace, jParams, Markers.EMPTY)
4689-
} else if (valueParamLists.nonEmpty) {
4736+
} else if (valueParamLists.nonEmpty && hasParensInSource) {
46904737
// Empty parameter list ()
46914738
val searchEnd = Math.min(cursor + 50, source.length)
46924739
val searchText = source.substring(cursor, searchEnd)
@@ -4700,7 +4747,9 @@ class ScalaTreeVisitor(
47004747
}
47014748
JContainer.build(parenSpace, new util.ArrayList[JRightPadded[Statement]](), Markers.EMPTY)
47024749
} else {
4703-
JContainer.empty[Statement]()
4750+
// Parameterless method — mark so printer omits ()
4751+
JContainer.build(Space.EMPTY, new util.ArrayList[JRightPadded[Statement]](),
4752+
Markers.build(Collections.singletonList(new org.openrewrite.scala.marker.OmitBraces(Tree.randomId()))))
47044753
}
47054754

47064755
// Handle return type `: ReturnType` — only if explicitly written in source

0 commit comments

Comments
 (0)