@@ -322,26 +322,25 @@ class ScalaTreeVisitor(
322322 // Check if this is an annotation pattern (will be handled specially when called from visitClassDef)
323323 // Annotations look like Apply(Select(New(...), <init>), args) with @ in source
324324 // Constructor calls look the same but have "new" in source
325- val isAnnotationPattern = app.fun match {
326- case sel : Trees .Select [? ] if sel.name.toString == " <init>" =>
327- sel.qualifier match {
328- case newNode : Trees .New [? ] =>
329- // Check if the source has @ before the type (annotation) or "new" (constructor)
330- if (app.span.exists) {
331- val adjustedStart = Math .max(0 , app.span.start - offsetAdjustment)
332- val adjustedEnd = Math .max(0 , app.span.end - offsetAdjustment)
333- if (adjustedStart < adjustedEnd && adjustedEnd <= source.length) {
334- val sourceText = source.substring(adjustedStart, adjustedEnd)
335- sourceText.trim.startsWith(" @" )
336- } else {
337- false
338- }
339- } else {
340- false
341- }
342- case _ => false
343- }
344- case _ => false
325+ val isAnnotationPattern = {
326+ def checkAnnotation (fun : Trees .Tree [? ]): Boolean = fun match {
327+ case sel : Trees .Select [? ] if sel.name.toString == " <init>" =>
328+ sel.qualifier match {
329+ case _ : Trees .New [? ] =>
330+ if (app.span.exists) {
331+ val adjustedStart = Math .max(0 , app.span.start - offsetAdjustment)
332+ val adjustedEnd = Math .max(0 , app.span.end - offsetAdjustment)
333+ if (adjustedStart < adjustedEnd && adjustedEnd <= source.length) {
334+ source.substring(adjustedStart, adjustedEnd).trim.startsWith(" @" )
335+ } else false
336+ } else false
337+ case _ => false
338+ }
339+ // Handle TypeApply wrapper for annotations with type params like @throws[Exception]
340+ case ta : Trees .TypeApply [? ] => checkAnnotation(ta.fun)
341+ case _ => false
342+ }
343+ checkAnnotation(app.fun)
345344 }
346345
347346 if (isAnnotationPattern) {
@@ -406,19 +405,24 @@ class ScalaTreeVisitor(
406405
407406
408407 // Extract the annotation type and arguments
409- val (annotationType, args) = app.fun match {
408+ // Handle both direct Select(New(...), <init>) and TypeApply(Select(New(...), <init>), types)
409+ def extractAnnotationType (fun : Trees .Tree [? ]): Option [Trees .Ident [? ]] = fun match {
410410 case sel : Trees .Select [? ] if sel.name.toString == " <init>" =>
411411 sel.qualifier match {
412- case newTree : Trees .New [? ] =>
413- val typeIdent = newTree.tpt match {
414- case id : Trees .Ident [? ] => id
415- case _ => return visitUnknown(app)
416- }
417- (typeIdent, app.args)
418- case _ => return visitUnknown(app)
412+ case newTree : Trees .New [? ] => newTree.tpt match {
413+ case id : Trees .Ident [? ] => Some (id)
414+ case _ => None
415+ }
416+ case _ => None
419417 }
420- case _ => return visitUnknown(app)
418+ case ta : Trees .TypeApply [? ] => extractAnnotationType(ta.fun)
419+ case _ => None
420+ }
421+ val annotationType = extractAnnotationType(app.fun) match {
422+ case Some (id) => id
423+ case None => return visitUnknown(app)
421424 }
425+ val args = app.args
422426
423427 // Create the annotation type
424428 val annotTypeTree = new J .Identifier (
@@ -455,7 +459,20 @@ class ScalaTreeVisitor(
455459 case e : Expression => e
456460 case _ => visitUnknown(arg)
457461 }
458- argList.add(JRightPadded .build(expr.asInstanceOf [Expression ]).withAfter(Space .EMPTY ))
462+ // Skip trailing comma between arguments
463+ val afterSpace = if (i < args.size - 1 ) {
464+ updateCursor(arg.span.end)
465+ if (cursor < source.length) {
466+ val after = source.substring(cursor, Math .min(cursor + 20 , source.length))
467+ val commaIdx = after.indexOf(',' )
468+ if (commaIdx >= 0 ) {
469+ val sp = Space .format(after.substring(0 , commaIdx))
470+ cursor = cursor + commaIdx + 1
471+ sp
472+ } else Space .EMPTY
473+ } else Space .EMPTY
474+ } else Space .EMPTY
475+ argList.add(JRightPadded .build(expr.asInstanceOf [Expression ]).withAfter(afterSpace))
459476 }
460477 JContainer .build(
461478 Space .EMPTY ,
@@ -3521,28 +3538,25 @@ class ScalaTreeVisitor(
35213538 */
35223539
35233540 private def visitBlock (block : Trees .Block [? ]): J .Block = {
3524- val prefix = extractPrefix(block.span)
3525-
3526- // Move cursor past the opening brace — but only if the block starts with one
35273541 val blockStart = Math .max(0 , block.span.start - offsetAdjustment)
35283542 val blockEndAdj = Math .max(0 , block.span.end - offsetAdjustment)
3529- val blockStartsWithBrace = blockStart < source.length && source.charAt(blockStart) == '{'
3530- if (blockStartsWithBrace) {
3543+
3544+ // Find the opening brace — it may be at blockStart, before blockStart (while/for body),
3545+ // or between cursor and the first child
3546+ val savedCursorBeforePrefix = cursor
3547+ val prefix = extractPrefix(block.span) // advances cursor to blockStart
3548+
3549+ // Now find and advance past '{'
3550+ if (blockStart < source.length && source.charAt(blockStart) == '{' ) {
3551+ // Brace at block span start
35313552 cursor = blockStart + 1
3532- } else {
3533- // Check if there's a brace between cursor and the first statement
3534- val firstChildStart = if (block.stats.nonEmpty) {
3535- Math .max(0 , block.stats.head.span.start - offsetAdjustment)
3536- } else if (! block.expr.isEmpty) {
3537- Math .max(0 , block.expr.span.start - offsetAdjustment)
3538- } else blockStart
3539-
3540- if (cursor < firstChildStart && firstChildStart <= source.length) {
3541- val between = source.substring(cursor, firstChildStart)
3542- val braceIdx = between.indexOf('{' )
3543- if (braceIdx >= 0 ) {
3544- cursor = cursor + braceIdx + 1
3545- }
3553+ } else if (savedCursorBeforePrefix < blockStart) {
3554+ // Check if there's a brace BEFORE the block span (e.g., while body)
3555+ val beforeSpan = source.substring(savedCursorBeforePrefix, blockStart)
3556+ val braceIdx = beforeSpan.indexOf('{' )
3557+ if (braceIdx >= 0 ) {
3558+ // The brace is before the block span — cursor should be past it
3559+ cursor = savedCursorBeforePrefix + braceIdx + 1
35463560 }
35473561 }
35483562
@@ -4482,58 +4496,43 @@ class ScalaTreeVisitor(
44824496 }
44834497
44844498 private def visitDefDef (dd : Trees .DefDef [? ]): J = {
4485- // Fall back to J.Unknown for cases we can't handle yet
4486- val hasAnnotations = dd.mods != null && dd.mods.annotations.nonEmpty
4487- if (hasAnnotations) {
4488- return visitUnknown(dd)
4489- }
4490-
4491- // Check for cases we need to fall back to J.Unknown
44924499 val adjustedStart = Math .max(0 , dd.span.start - offsetAdjustment)
44934500 val adjustedEnd = Math .max(0 , dd.span.end - offsetAdjustment)
4501+
4502+ // Detect procedure syntax (no = before body) and nested braces — compiler flattens these
44944503 if (adjustedStart < adjustedEnd && adjustedEnd <= source.length) {
44954504 val defSource = source.substring(adjustedStart, adjustedEnd)
4496- // Procedure syntax (no = before body)
44974505 val braceIdx = defSource.indexOf('{' )
44984506 val equalsIdx = defSource.indexOf('=' )
44994507 if (braceIdx >= 0 && (equalsIdx < 0 || equalsIdx > braceIdx)) {
4500- return visitUnknown(dd)
4508+ return visitUnknown(dd) // Procedure syntax
45014509 }
4502- // Nested braces after = (compiler flattens these, breaking round-trip)
45034510 if (equalsIdx >= 0 && braceIdx >= 0 ) {
45044511 val afterFirstBrace = defSource.indexOf('{' , braceIdx + 1 )
4505- if (afterFirstBrace >= 0 ) {
4506- val between = defSource.substring(braceIdx + 1 , afterFirstBrace).trim()
4507- if (between.isEmpty) {
4508- return visitUnknown(dd)
4509- }
4512+ if (afterFirstBrace >= 0 && defSource.substring(braceIdx + 1 , afterFirstBrace).trim().isEmpty) {
4513+ return visitUnknown(dd) // Nested braces
45104514 }
45114515 }
4512- // Fall back for methods with while/for loops (block whitespace issues)
4513- if (defSource.contains(" while " ) || defSource.contains(" while(" )) {
4514- return visitUnknown(dd)
4515- }
45164516 }
45174517
4518- // Fall back for parameterless methods like `def name: Type` (no parens in source)
4518+ // Detect parameterless methods (def name: Type = ...) — no parens in source
4519+ var hasParensInSource = true
45194520 if (adjustedStart < adjustedEnd && adjustedEnd <= source.length) {
45204521 val defSource = source.substring(adjustedStart, adjustedEnd)
45214522 val defIdx = defSource.indexOf(" def " )
45224523 if (defIdx >= 0 ) {
45234524 val afterDef = defSource.substring(defIdx + 4 )
4524- // Find end of name (first non-alphanumeric char)
45254525 val nameEnd = afterDef.indexWhere(c => ! c.isLetterOrDigit && c != '_' )
45264526 if (nameEnd >= 0 ) {
45274527 val afterName = afterDef.substring(nameEnd).trim()
4528- // If the next meaningful char after name is : or = (not (), the method has no parens
45294528 if (afterName.startsWith(" :" ) || afterName.startsWith(" =" )) {
4530- return visitUnknown(dd)
4529+ hasParensInSource = false
45314530 }
45324531 }
45334532 }
45344533 }
45354534
4536- // Fall back for methods with function types or annotated parameters
4535+ // Fall back for function types (Int => Int) and annotated parameters (@unchecked)
45374536 val hasUnsupportedParams = dd.paramss.exists(_.exists {
45384537 case vd : Trees .ValDef [? ] =>
45394538 vd.tpt.isInstanceOf [untpd.Function ] ||
@@ -4546,18 +4545,58 @@ class ScalaTreeVisitor(
45464545
45474546 val savedCursor = cursor
45484547 try {
4549- visitDefDefImpl(dd)
4548+ visitDefDefImpl(dd, hasParensInSource )
45504549 } catch {
45514550 case _ : Exception =>
45524551 cursor = savedCursor
45534552 visitUnknown(dd)
45544553 }
45554554 }
45564555
4557- private def visitDefDefImpl (dd : Trees .DefDef [? ]): J .MethodDeclaration = {
4556+ private def visitDefDefImpl (dd : Trees .DefDef [? ], hasParensInSource : Boolean = true ): J .MethodDeclaration = {
45584557 val leadingAnnotations = new util.ArrayList [J .Annotation ]()
4558+ val hasAnnotations = dd.mods != null && dd.mods.annotations.nonEmpty
45594559 val prefix = extractPrefix(dd.span)
45604560
4561+ if (hasAnnotations) {
4562+ for (annot <- dd.mods.annotations) {
4563+ // First try the proper AST visitor
4564+ val savedAnnotCursor = cursor
4565+ val annotResult = try { visitTree(annot) } catch { case _ : Exception => cursor = savedAnnotCursor; null }
4566+
4567+ annotResult match {
4568+ case ann : J .Annotation => leadingAnnotations.add(ann)
4569+ case _ =>
4570+ // Fallback: build annotation from source text
4571+ cursor = savedAnnotCursor // Reset cursor — visitTree may have advanced it incorrectly
4572+ if (annot.span.exists) {
4573+ val annotStart = Math .max(0 , annot.span.start - offsetAdjustment)
4574+ val annotEnd = Math .max(0 , annot.span.end - offsetAdjustment)
4575+ // Find the @ before the annotation span (it's in the source but might not be in the span)
4576+ val searchStart = Math .max(0 , annotStart - 5 )
4577+ val atIdx = if (searchStart < source.length) source.lastIndexOf('@' , annotStart) else - 1
4578+ val fullStart = if (atIdx >= 0 && atIdx >= searchStart) atIdx else annotStart
4579+
4580+ // Prefix = space from cursor to @ position
4581+ val annotPrefix = if (cursor < fullStart && fullStart <= source.length) {
4582+ Space .format(source.substring(cursor, fullStart))
4583+ } else Space .EMPTY
4584+
4585+ // Source = everything from @ to annotation end
4586+ val annotSource = if (fullStart < annotEnd && annotEnd <= source.length) {
4587+ source.substring(fullStart, annotEnd)
4588+ } else " "
4589+ val annotName = if (annotSource.startsWith(" @" )) annotSource.substring(1 ) else annotSource
4590+
4591+ val annotId = new J .Identifier (Tree .randomId(), Space .EMPTY , Markers .EMPTY ,
4592+ Collections .emptyList(), annotName, null , null )
4593+ leadingAnnotations.add(new J .Annotation (Tree .randomId(), annotPrefix, Markers .EMPTY , annotId, null ))
4594+ cursor = annotEnd
4595+ }
4596+ }
4597+ }
4598+ }
4599+
45614600 val adjustedEnd = Math .max(0 , dd.span.end - offsetAdjustment)
45624601 var modifierText = " "
45634602 var defIndex = - 1
@@ -4573,6 +4612,12 @@ class ScalaTreeVisitor(
45734612
45744613 val (modifiers, _) = extractModifiersFromText(dd.mods, modifierText)
45754614
4615+ // If annotations but no modifiers, capture the space between annotations and "def"
4616+ if (hasAnnotations && modifiers.isEmpty() && modifierText.nonEmpty && modifierText.trim().isEmpty) {
4617+ modifiers.add(new J .Modifier (Tree .randomId(), Space .format(modifierText), Markers .EMPTY ,
4618+ " def" , J .Modifier .Type .LanguageExtension , Collections .emptyList()))
4619+ }
4620+
45764621 val defKeywordPos = if (defIndex >= 0 ) cursor + defIndex + " def" .length else cursor
45774622 cursor = defKeywordPos
45784623
@@ -4686,7 +4731,7 @@ class ScalaTreeVisitor(
46864731 }
46874732
46884733 JContainer .build(parenSpace, jParams, Markers .EMPTY )
4689- } else if (valueParamLists.nonEmpty) {
4734+ } else if (valueParamLists.nonEmpty && hasParensInSource ) {
46904735 // Empty parameter list ()
46914736 val searchEnd = Math .min(cursor + 50 , source.length)
46924737 val searchText = source.substring(cursor, searchEnd)
@@ -4700,7 +4745,9 @@ class ScalaTreeVisitor(
47004745 }
47014746 JContainer .build(parenSpace, new util.ArrayList [JRightPadded [Statement ]](), Markers .EMPTY )
47024747 } else {
4703- JContainer .empty[Statement ]()
4748+ // Parameterless method — mark so printer omits ()
4749+ JContainer .build(Space .EMPTY , new util.ArrayList [JRightPadded [Statement ]](),
4750+ Markers .build(Collections .singletonList(new org.openrewrite.scala.marker.OmitBraces (Tree .randomId()))))
47044751 }
47054752
47064753 // Handle return type `: ReturnType` — only if explicitly written in source
0 commit comments