@@ -4482,58 +4482,43 @@ class ScalaTreeVisitor(
44824482 }
44834483
44844484 private def visitDefDef (dd : Trees .DefDef [? ]): J = {
4485- // Fall back to J.Unknown for cases we can't handle yet
4486- val hasAnnotations = dd.mods != null && dd.mods.annotations.nonEmpty
4487- if (hasAnnotations) {
4488- return visitUnknown(dd)
4489- }
4490-
4491- // Check for cases we need to fall back to J.Unknown
44924485 val adjustedStart = Math .max(0 , dd.span.start - offsetAdjustment)
44934486 val adjustedEnd = Math .max(0 , dd.span.end - offsetAdjustment)
4487+
4488+ // Detect procedure syntax (no = before body) and nested braces — compiler flattens these
44944489 if (adjustedStart < adjustedEnd && adjustedEnd <= source.length) {
44954490 val defSource = source.substring(adjustedStart, adjustedEnd)
4496- // Procedure syntax (no = before body)
44974491 val braceIdx = defSource.indexOf('{' )
44984492 val equalsIdx = defSource.indexOf('=' )
44994493 if (braceIdx >= 0 && (equalsIdx < 0 || equalsIdx > braceIdx)) {
4500- return visitUnknown(dd)
4494+ return visitUnknown(dd) // Procedure syntax
45014495 }
4502- // Nested braces after = (compiler flattens these, breaking round-trip)
45034496 if (equalsIdx >= 0 && braceIdx >= 0 ) {
45044497 val afterFirstBrace = defSource.indexOf('{' , braceIdx + 1 )
4505- if (afterFirstBrace >= 0 ) {
4506- val between = defSource.substring(braceIdx + 1 , afterFirstBrace).trim()
4507- if (between.isEmpty) {
4508- return visitUnknown(dd)
4509- }
4498+ if (afterFirstBrace >= 0 && defSource.substring(braceIdx + 1 , afterFirstBrace).trim().isEmpty) {
4499+ return visitUnknown(dd) // Nested braces
45104500 }
45114501 }
4512- // Fall back for methods with while/for loops (block whitespace issues)
4513- if (defSource.contains(" while " ) || defSource.contains(" while(" )) {
4514- return visitUnknown(dd)
4515- }
45164502 }
45174503
4518- // Fall back for parameterless methods like `def name: Type` (no parens in source)
4504+ // Detect parameterless methods (def name: Type = ...) — no parens in source
4505+ var hasParensInSource = true
45194506 if (adjustedStart < adjustedEnd && adjustedEnd <= source.length) {
45204507 val defSource = source.substring(adjustedStart, adjustedEnd)
45214508 val defIdx = defSource.indexOf(" def " )
45224509 if (defIdx >= 0 ) {
45234510 val afterDef = defSource.substring(defIdx + 4 )
4524- // Find end of name (first non-alphanumeric char)
45254511 val nameEnd = afterDef.indexWhere(c => ! c.isLetterOrDigit && c != '_' )
45264512 if (nameEnd >= 0 ) {
45274513 val afterName = afterDef.substring(nameEnd).trim()
4528- // If the next meaningful char after name is : or = (not (), the method has no parens
45294514 if (afterName.startsWith(" :" ) || afterName.startsWith(" =" )) {
4530- return visitUnknown(dd)
4515+ hasParensInSource = false
45314516 }
45324517 }
45334518 }
45344519 }
45354520
4536- // Fall back for methods with function types or annotated parameters
4521+ // Fall back for function types (Int => Int) and annotated parameters (@unchecked)
45374522 val hasUnsupportedParams = dd.paramss.exists(_.exists {
45384523 case vd : Trees .ValDef [? ] =>
45394524 vd.tpt.isInstanceOf [untpd.Function ] ||
@@ -4546,18 +4531,28 @@ class ScalaTreeVisitor(
45464531
45474532 val savedCursor = cursor
45484533 try {
4549- visitDefDefImpl(dd)
4534+ visitDefDefImpl(dd, hasParensInSource )
45504535 } catch {
45514536 case _ : Exception =>
45524537 cursor = savedCursor
45534538 visitUnknown(dd)
45544539 }
45554540 }
45564541
4557- private def visitDefDefImpl (dd : Trees .DefDef [? ]): J .MethodDeclaration = {
4542+ private def visitDefDefImpl (dd : Trees .DefDef [? ], hasParensInSource : Boolean = true ): J .MethodDeclaration = {
45584543 val leadingAnnotations = new util.ArrayList [J .Annotation ]()
4544+ val hasAnnotations = dd.mods != null && dd.mods.annotations.nonEmpty
45594545 val prefix = extractPrefix(dd.span)
45604546
4547+ if (hasAnnotations) {
4548+ for (annot <- dd.mods.annotations) {
4549+ visitTree(annot) match {
4550+ case ann : J .Annotation => leadingAnnotations.add(ann)
4551+ case _ =>
4552+ }
4553+ }
4554+ }
4555+
45614556 val adjustedEnd = Math .max(0 , dd.span.end - offsetAdjustment)
45624557 var modifierText = " "
45634558 var defIndex = - 1
@@ -4573,6 +4568,12 @@ class ScalaTreeVisitor(
45734568
45744569 val (modifiers, _) = extractModifiersFromText(dd.mods, modifierText)
45754570
4571+ // If annotations but no modifiers, capture the space between annotations and "def"
4572+ if (hasAnnotations && modifiers.isEmpty() && modifierText.nonEmpty && modifierText.trim().isEmpty) {
4573+ modifiers.add(new J .Modifier (Tree .randomId(), Space .format(modifierText), Markers .EMPTY ,
4574+ " def" , J .Modifier .Type .LanguageExtension , Collections .emptyList()))
4575+ }
4576+
45764577 val defKeywordPos = if (defIndex >= 0 ) cursor + defIndex + " def" .length else cursor
45774578 cursor = defKeywordPos
45784579
@@ -4686,7 +4687,7 @@ class ScalaTreeVisitor(
46864687 }
46874688
46884689 JContainer .build(parenSpace, jParams, Markers .EMPTY )
4689- } else if (valueParamLists.nonEmpty) {
4690+ } else if (valueParamLists.nonEmpty && hasParensInSource ) {
46904691 // Empty parameter list ()
46914692 val searchEnd = Math .min(cursor + 50 , source.length)
46924693 val searchText = source.substring(cursor, searchEnd)
@@ -4700,7 +4701,9 @@ class ScalaTreeVisitor(
47004701 }
47014702 JContainer .build(parenSpace, new util.ArrayList [JRightPadded [Statement ]](), Markers .EMPTY )
47024703 } else {
4703- JContainer .empty[Statement ]()
4704+ // Parameterless method — mark so printer omits ()
4705+ JContainer .build(Space .EMPTY , new util.ArrayList [JRightPadded [Statement ]](),
4706+ Markers .build(Collections .singletonList(new org.openrewrite.scala.marker.OmitBraces (Tree .randomId()))))
47044707 }
47054708
47064709 // Handle return type `: ReturnType` — only if explicitly written in source
0 commit comments