@@ -322,26 +322,25 @@ class ScalaTreeVisitor(
322322 // Check if this is an annotation pattern (will be handled specially when called from visitClassDef)
323323 // Annotations look like Apply(Select(New(...), <init>), args) with @ in source
324324 // Constructor calls look the same but have "new" in source
325- val isAnnotationPattern = app.fun match {
326- case sel : Trees .Select [? ] if sel.name.toString == " <init>" =>
327- sel.qualifier match {
328- case newNode : Trees .New [? ] =>
329- // Check if the source has @ before the type (annotation) or "new" (constructor)
330- if (app.span.exists) {
331- val adjustedStart = Math .max(0 , app.span.start - offsetAdjustment)
332- val adjustedEnd = Math .max(0 , app.span.end - offsetAdjustment)
333- if (adjustedStart < adjustedEnd && adjustedEnd <= source.length) {
334- val sourceText = source.substring(adjustedStart, adjustedEnd)
335- sourceText.trim.startsWith(" @" )
336- } else {
337- false
338- }
339- } else {
340- false
341- }
342- case _ => false
343- }
344- case _ => false
325+ val isAnnotationPattern = {
326+ def checkAnnotation (fun : Trees .Tree [? ]): Boolean = fun match {
327+ case sel : Trees .Select [? ] if sel.name.toString == " <init>" =>
328+ sel.qualifier match {
329+ case _ : Trees .New [? ] =>
330+ if (app.span.exists) {
331+ val adjustedStart = Math .max(0 , app.span.start - offsetAdjustment)
332+ val adjustedEnd = Math .max(0 , app.span.end - offsetAdjustment)
333+ if (adjustedStart < adjustedEnd && adjustedEnd <= source.length) {
334+ source.substring(adjustedStart, adjustedEnd).trim.startsWith(" @" )
335+ } else false
336+ } else false
337+ case _ => false
338+ }
339+ // Handle TypeApply wrapper for annotations with type params like @throws[Exception]
340+ case ta : Trees .TypeApply [? ] => checkAnnotation(ta.fun)
341+ case _ => false
342+ }
343+ checkAnnotation(app.fun)
345344 }
346345
347346 if (isAnnotationPattern) {
@@ -406,19 +405,24 @@ class ScalaTreeVisitor(
406405
407406
408407 // Extract the annotation type and arguments
409- val (annotationType, args) = app.fun match {
408+ // Handle both direct Select(New(...), <init>) and TypeApply(Select(New(...), <init>), types)
409+ def extractAnnotationType (fun : Trees .Tree [? ]): Option [Trees .Ident [? ]] = fun match {
410410 case sel : Trees .Select [? ] if sel.name.toString == " <init>" =>
411411 sel.qualifier match {
412- case newTree : Trees .New [? ] =>
413- val typeIdent = newTree.tpt match {
414- case id : Trees .Ident [? ] => id
415- case _ => return visitUnknown(app)
416- }
417- (typeIdent, app.args)
418- case _ => return visitUnknown(app)
412+ case newTree : Trees .New [? ] => newTree.tpt match {
413+ case id : Trees .Ident [? ] => Some (id)
414+ case _ => None
415+ }
416+ case _ => None
419417 }
420- case _ => return visitUnknown(app)
418+ case ta : Trees .TypeApply [? ] => extractAnnotationType(ta.fun)
419+ case _ => None
421420 }
421+ val annotationType = extractAnnotationType(app.fun) match {
422+ case Some (id) => id
423+ case None => return visitUnknown(app)
424+ }
425+ val args = app.args
422426
423427 // Create the annotation type
424428 val annotTypeTree = new J .Identifier (
@@ -455,7 +459,20 @@ class ScalaTreeVisitor(
455459 case e : Expression => e
456460 case _ => visitUnknown(arg)
457461 }
458- argList.add(JRightPadded .build(expr.asInstanceOf [Expression ]).withAfter(Space .EMPTY ))
462+ // Skip trailing comma between arguments
463+ val afterSpace = if (i < args.size - 1 ) {
464+ updateCursor(arg.span.end)
465+ if (cursor < source.length) {
466+ val after = source.substring(cursor, Math .min(cursor + 20 , source.length))
467+ val commaIdx = after.indexOf(',' )
468+ if (commaIdx >= 0 ) {
469+ val sp = Space .format(after.substring(0 , commaIdx))
470+ cursor = cursor + commaIdx + 1
471+ sp
472+ } else Space .EMPTY
473+ } else Space .EMPTY
474+ } else Space .EMPTY
475+ argList.add(JRightPadded .build(expr.asInstanceOf [Expression ]).withAfter(afterSpace))
459476 }
460477 JContainer .build(
461478 Space .EMPTY ,
@@ -3521,37 +3538,36 @@ class ScalaTreeVisitor(
35213538 */
35223539
35233540 private def visitBlock (block : Trees .Block [? ]): J .Block = {
3524- val prefix = extractPrefix(block.span)
3525-
3526- // Move cursor past the opening brace — but only if the block starts with one
35273541 val blockStart = Math .max(0 , block.span.start - offsetAdjustment)
35283542 val blockEndAdj = Math .max(0 , block.span.end - offsetAdjustment)
3529- val blockStartsWithBrace = blockStart < source.length && source.charAt(blockStart) == '{'
3530- if (blockStartsWithBrace) {
3543+
3544+ // Find the opening brace — it may be at blockStart, before blockStart (while/for body),
3545+ // or between cursor and the first child
3546+ val savedCursorBeforePrefix = cursor
3547+ val prefix = extractPrefix(block.span) // advances cursor to blockStart
3548+
3549+ // Now find and advance past '{'
3550+ if (blockStart < source.length && source.charAt(blockStart) == '{' ) {
3551+ // Brace at block span start
35313552 cursor = blockStart + 1
3532- } else {
3533- // Check if there's a brace between cursor and the first statement
3534- val firstChildStart = if (block.stats.nonEmpty) {
3535- Math .max(0 , block.stats.head.span.start - offsetAdjustment)
3536- } else if (! block.expr.isEmpty) {
3537- Math .max(0 , block.expr.span.start - offsetAdjustment)
3538- } else blockStart
3539-
3540- if (cursor < firstChildStart && firstChildStart <= source.length) {
3541- val between = source.substring(cursor, firstChildStart)
3542- val braceIdx = between.indexOf('{' )
3543- if (braceIdx >= 0 ) {
3544- cursor = cursor + braceIdx + 1
3545- }
3553+ } else if (savedCursorBeforePrefix < blockStart) {
3554+ // Check if there's a brace BEFORE the block span (e.g., while body)
3555+ val beforeSpan = source.substring(savedCursorBeforePrefix, blockStart)
3556+ val braceIdx = beforeSpan.indexOf('{' )
3557+ if (braceIdx >= 0 ) {
3558+ // The brace is before the block span — cursor should be past it
3559+ cursor = savedCursorBeforePrefix + braceIdx + 1
35463560 }
35473561 }
35483562
35493563 val statements = new util.ArrayList [JRightPadded [Statement ]]()
35503564
3565+
35513566 // Visit all statements in the block
35523567 for (i <- block.stats.indices) {
35533568 val stat = block.stats(i)
3554- visitTree(stat) match {
3569+ val visitResult = visitTree(stat)
3570+ visitResult match {
35553571 case null => // Skip null statements (e.g., package declarations)
35563572 case stmt : Statement =>
35573573 // Extract trailing space after this statement
@@ -4482,58 +4498,43 @@ class ScalaTreeVisitor(
44824498 }
44834499
44844500 private def visitDefDef (dd : Trees .DefDef [? ]): J = {
4485- // Fall back to J.Unknown for cases we can't handle yet
4486- val hasAnnotations = dd.mods != null && dd.mods.annotations.nonEmpty
4487- if (hasAnnotations) {
4488- return visitUnknown(dd)
4489- }
4490-
4491- // Check for cases we need to fall back to J.Unknown
44924501 val adjustedStart = Math .max(0 , dd.span.start - offsetAdjustment)
44934502 val adjustedEnd = Math .max(0 , dd.span.end - offsetAdjustment)
4503+
4504+ // Detect procedure syntax (no = before body) and nested braces — compiler flattens these
44944505 if (adjustedStart < adjustedEnd && adjustedEnd <= source.length) {
44954506 val defSource = source.substring(adjustedStart, adjustedEnd)
4496- // Procedure syntax (no = before body)
44974507 val braceIdx = defSource.indexOf('{' )
44984508 val equalsIdx = defSource.indexOf('=' )
44994509 if (braceIdx >= 0 && (equalsIdx < 0 || equalsIdx > braceIdx)) {
4500- return visitUnknown(dd)
4510+ return visitUnknown(dd) // Procedure syntax
45014511 }
4502- // Nested braces after = (compiler flattens these, breaking round-trip)
45034512 if (equalsIdx >= 0 && braceIdx >= 0 ) {
45044513 val afterFirstBrace = defSource.indexOf('{' , braceIdx + 1 )
4505- if (afterFirstBrace >= 0 ) {
4506- val between = defSource.substring(braceIdx + 1 , afterFirstBrace).trim()
4507- if (between.isEmpty) {
4508- return visitUnknown(dd)
4509- }
4514+ if (afterFirstBrace >= 0 && defSource.substring(braceIdx + 1 , afterFirstBrace).trim().isEmpty) {
4515+ return visitUnknown(dd) // Nested braces
45104516 }
45114517 }
4512- // Fall back for methods with while/for loops (block whitespace issues)
4513- if (defSource.contains(" while " ) || defSource.contains(" while(" )) {
4514- return visitUnknown(dd)
4515- }
45164518 }
45174519
4518- // Fall back for parameterless methods like `def name: Type` (no parens in source)
4520+ // Detect parameterless methods (def name: Type = ...) — no parens in source
4521+ var hasParensInSource = true
45194522 if (adjustedStart < adjustedEnd && adjustedEnd <= source.length) {
45204523 val defSource = source.substring(adjustedStart, adjustedEnd)
45214524 val defIdx = defSource.indexOf(" def " )
45224525 if (defIdx >= 0 ) {
45234526 val afterDef = defSource.substring(defIdx + 4 )
4524- // Find end of name (first non-alphanumeric char)
45254527 val nameEnd = afterDef.indexWhere(c => ! c.isLetterOrDigit && c != '_' )
45264528 if (nameEnd >= 0 ) {
45274529 val afterName = afterDef.substring(nameEnd).trim()
4528- // If the next meaningful char after name is : or = (not (), the method has no parens
45294530 if (afterName.startsWith(" :" ) || afterName.startsWith(" =" )) {
4530- return visitUnknown(dd)
4531+ hasParensInSource = false
45314532 }
45324533 }
45334534 }
45344535 }
45354536
4536- // Fall back for methods with function types or annotated parameters
4537+ // Fall back for function types (Int => Int) and annotated parameters (@unchecked)
45374538 val hasUnsupportedParams = dd.paramss.exists(_.exists {
45384539 case vd : Trees .ValDef [? ] =>
45394540 vd.tpt.isInstanceOf [untpd.Function ] ||
@@ -4546,18 +4547,58 @@ class ScalaTreeVisitor(
45464547
45474548 val savedCursor = cursor
45484549 try {
4549- visitDefDefImpl(dd)
4550+ visitDefDefImpl(dd, hasParensInSource )
45504551 } catch {
45514552 case _ : Exception =>
45524553 cursor = savedCursor
45534554 visitUnknown(dd)
45544555 }
45554556 }
45564557
4557- private def visitDefDefImpl (dd : Trees .DefDef [? ]): J .MethodDeclaration = {
4558+ private def visitDefDefImpl (dd : Trees .DefDef [? ], hasParensInSource : Boolean = true ): J .MethodDeclaration = {
45584559 val leadingAnnotations = new util.ArrayList [J .Annotation ]()
4560+ val hasAnnotations = dd.mods != null && dd.mods.annotations.nonEmpty
45594561 val prefix = extractPrefix(dd.span)
45604562
4563+ if (hasAnnotations) {
4564+ for (annot <- dd.mods.annotations) {
4565+ // First try the proper AST visitor
4566+ val savedAnnotCursor = cursor
4567+ val annotResult = try { visitTree(annot) } catch { case _ : Exception => cursor = savedAnnotCursor; null }
4568+
4569+ annotResult match {
4570+ case ann : J .Annotation => leadingAnnotations.add(ann)
4571+ case _ =>
4572+ // Fallback: build annotation from source text
4573+ cursor = savedAnnotCursor // Reset cursor — visitTree may have advanced it incorrectly
4574+ if (annot.span.exists) {
4575+ val annotStart = Math .max(0 , annot.span.start - offsetAdjustment)
4576+ val annotEnd = Math .max(0 , annot.span.end - offsetAdjustment)
4577+ // Find the @ before the annotation span (it's in the source but might not be in the span)
4578+ val searchStart = Math .max(0 , annotStart - 5 )
4579+ val atIdx = if (searchStart < source.length) source.lastIndexOf('@' , annotStart) else - 1
4580+ val fullStart = if (atIdx >= 0 && atIdx >= searchStart) atIdx else annotStart
4581+
4582+ // Prefix = space from cursor to @ position
4583+ val annotPrefix = if (cursor < fullStart && fullStart <= source.length) {
4584+ Space .format(source.substring(cursor, fullStart))
4585+ } else Space .EMPTY
4586+
4587+ // Source = everything from @ to annotation end
4588+ val annotSource = if (fullStart < annotEnd && annotEnd <= source.length) {
4589+ source.substring(fullStart, annotEnd)
4590+ } else " "
4591+ val annotName = if (annotSource.startsWith(" @" )) annotSource.substring(1 ) else annotSource
4592+
4593+ val annotId = new J .Identifier (Tree .randomId(), Space .EMPTY , Markers .EMPTY ,
4594+ Collections .emptyList(), annotName, null , null )
4595+ leadingAnnotations.add(new J .Annotation (Tree .randomId(), annotPrefix, Markers .EMPTY , annotId, null ))
4596+ cursor = annotEnd
4597+ }
4598+ }
4599+ }
4600+ }
4601+
45614602 val adjustedEnd = Math .max(0 , dd.span.end - offsetAdjustment)
45624603 var modifierText = " "
45634604 var defIndex = - 1
@@ -4573,6 +4614,12 @@ class ScalaTreeVisitor(
45734614
45744615 val (modifiers, _) = extractModifiersFromText(dd.mods, modifierText)
45754616
4617+ // If annotations but no modifiers, capture the space between annotations and "def"
4618+ if (hasAnnotations && modifiers.isEmpty() && modifierText.nonEmpty && modifierText.trim().isEmpty) {
4619+ modifiers.add(new J .Modifier (Tree .randomId(), Space .format(modifierText), Markers .EMPTY ,
4620+ " def" , J .Modifier .Type .LanguageExtension , Collections .emptyList()))
4621+ }
4622+
45764623 val defKeywordPos = if (defIndex >= 0 ) cursor + defIndex + " def" .length else cursor
45774624 cursor = defKeywordPos
45784625
@@ -4686,7 +4733,7 @@ class ScalaTreeVisitor(
46864733 }
46874734
46884735 JContainer .build(parenSpace, jParams, Markers .EMPTY )
4689- } else if (valueParamLists.nonEmpty) {
4736+ } else if (valueParamLists.nonEmpty && hasParensInSource ) {
46904737 // Empty parameter list ()
46914738 val searchEnd = Math .min(cursor + 50 , source.length)
46924739 val searchText = source.substring(cursor, searchEnd)
@@ -4700,7 +4747,9 @@ class ScalaTreeVisitor(
47004747 }
47014748 JContainer .build(parenSpace, new util.ArrayList [JRightPadded [Statement ]](), Markers .EMPTY )
47024749 } else {
4703- JContainer .empty[Statement ]()
4750+ // Parameterless method — mark so printer omits ()
4751+ JContainer .build(Space .EMPTY , new util.ArrayList [JRightPadded [Statement ]](),
4752+ Markers .build(Collections .singletonList(new org.openrewrite.scala.marker.OmitBraces (Tree .randomId()))))
47044753 }
47054754
47064755 // Handle return type `: ReturnType` — only if explicitly written in source
0 commit comments