3434import org .openrewrite .scala .marker .SObject ;
3535import org .openrewrite .scala .marker .TypeProjection ;
3636import org .openrewrite .scala .marker .ScalaForLoop ;
37- import org .openrewrite .scala .marker .AsInstanceOfPrefix ;
3837import org .openrewrite .scala .marker .TypeAscription ;
3938import org .openrewrite .scala .marker .UnderscorePlaceholderLambda ;
4039import org .openrewrite .scala .tree .S ;
@@ -200,12 +199,18 @@ public J visitTry(J.Try tryable, PrintOutputCapture<P> p) {
200199 p .append ("try" );
201200 visit (tryable .getBody (), p );
202201 if (!tryable .getCatches ().isEmpty ()) {
203- J .Try .Catch firstCatch = tryable .getCatches ().get (0 );
204- visitSpace (firstCatch .getPrefix (), Space .Location .CATCH_PREFIX , p );
205- p .append ("catch {" );
206- for (J .Try .Catch aCatch : tryable .getCatches ()) {
202+ // Print catch block with cases from AST whitespace
203+ for (int i = 0 ; i < tryable .getCatches ().size (); i ++) {
204+ J .Try .Catch aCatch = tryable .getCatches ().get (i );
205+ if (i == 0 ) {
206+ // First catch — prefix is the space before "catch"
207+ visitSpace (aCatch .getPrefix (), Space .Location .CATCH_PREFIX , p );
208+ p .append ("catch {" );
209+ }
210+ // Print case with AST whitespace
207211 J .VariableDeclarations varDecl = aCatch .getParameter ().getTree ();
208- p .append ("\n case" );
212+ visitSpace (varDecl .getPrefix (), Space .Location .VARIABLE_DECLARATIONS_PREFIX , p );
213+ p .append ("case" );
209214 if (!varDecl .getVariables ().isEmpty ()) {
210215 visit (varDecl .getVariables ().get (0 ).getName (), p );
211216 }
@@ -218,7 +223,10 @@ public J visitTry(J.Try tryable, PrintOutputCapture<P> p) {
218223 visit (stmt , p );
219224 }
220225 }
221- p .append ("\n }" );
226+ // Close catch block — use end space from last catch body if available
227+ J .Try .Catch lastCatch = tryable .getCatches ().get (tryable .getCatches ().size () - 1 );
228+ visitSpace (lastCatch .getBody ().getEnd (), Space .Location .BLOCK_END , p );
229+ p .append ("}" );
222230 }
223231 if (tryable .getPadding ().getFinally () != null ) {
224232 visitSpace (tryable .getPadding ().getFinally ().getBefore (), Space .Location .TRY_FINALLY , p );
@@ -249,33 +257,22 @@ public J visitSwitch(J.Switch switch_, PrintOutputCapture<P> p) {
249257 public J visitCase (J .Case case_ , PrintOutputCapture <P > p ) {
250258 beforeSyntax (case_ , Space .Location .CASE_PREFIX , p );
251259 p .append ("case" );
252- List <JRightPadded <J >> labels = case_ .getPadding ().getCaseLabels ().getPadding ().getElements ();
253- for (JRightPadded <J > label : labels ) {
254- visit (label .getElement (), p );
255- visitSpace (label .getAfter (), Space .Location .CASE_LABEL , p );
260+ List <JRightPadded <J >> labelPadding = case_ .getPadding ().getCaseLabels ().getPadding ().getElements ();
261+ for (int li = 0 ; li < labelPadding .size (); li ++) {
262+ JRightPadded <J > lp = labelPadding .get (li );
263+ visit (lp .getElement (), p );
264+ // The last label's after space is the space before "if" guard (if any)
265+ if (li == labelPadding .size () - 1 && case_ .getGuard () != null ) {
266+ visitSpace (lp .getAfter (), JRightPadded .Location .CASE .getAfterLocation (), p );
267+ }
256268 }
257269 if (case_ .getGuard () != null ) {
258270 p .append ("if" );
259271 visit (case_ .getGuard (), p );
260272 }
261273 p .append (" =>" );
262274 if (case_ .getPadding ().getBody () != null ) {
263- J bodyElement = case_ .getPadding ().getBody ().getElement ();
264- // For Scala match cases, unwrap blocks to avoid printing { }
265- if (bodyElement instanceof J .Block ) {
266- J .Block block = (J .Block ) bodyElement ;
267- visitSpace (block .getPrefix (), Space .Location .BLOCK_PREFIX , p );
268- List <JRightPadded <Statement >> paddedStatements = block .getPadding ().getStatements ();
269- for (int i = 0 ; i < paddedStatements .size (); i ++) {
270- JRightPadded <Statement > paddedStmt = paddedStatements .get (i );
271- visit (paddedStmt .getElement (), p );
272- if (i < paddedStatements .size () - 1 ) {
273- visitSpace (paddedStmt .getAfter (), Space .Location .BLOCK_STATEMENT_SUFFIX , p );
274- }
275- }
276- } else {
277- visit (bodyElement , p );
278- }
275+ visit (case_ .getPadding ().getBody ().getElement (), p );
279276 }
280277 afterSyntax (case_ , p );
281278 return case_ ;
@@ -401,7 +398,19 @@ public J visitMethodDeclaration(J.MethodDeclaration method, PrintOutputCapture<P
401398 if (!procedureSyntax ) {
402399 p .append (" =" );
403400 }
404- visit (actualBody , p );
401+ // If body is OmitBraces block with single statement, print just the statement
402+ if (actualBody instanceof J .Block ) {
403+ J .Block bodyBlock = (J .Block ) actualBody ;
404+ boolean omit = bodyBlock .getMarkers ().findFirst (
405+ org .openrewrite .scala .marker .OmitBraces .class ).isPresent ();
406+ if (omit && bodyBlock .getStatements ().size () == 1 ) {
407+ visit (bodyBlock .getStatements ().get (0 ), p );
408+ } else {
409+ visit (actualBody , p );
410+ }
411+ } else {
412+ visit (actualBody , p );
413+ }
405414 }
406415 } else if (method .getBody () != null ) {
407416 // Normal method body
@@ -481,10 +490,15 @@ public J visit(@Nullable Tree tree, PrintOutputCapture<P> p) {
481490 return visitWildcard ((S .Wildcard ) tree , p );
482491 } else if (tree instanceof S .TuplePattern ) {
483492 return visitTuplePattern ((S .TuplePattern ) tree , p );
484- } else if (tree instanceof S .BlockExpression ) {
485- return visitBlockExpression ((S .BlockExpression ) tree , p );
486- } else if (tree instanceof S .ExpressionStatement ) {
487- return visitExpressionStatement ((S .ExpressionStatement ) tree , p );
493+ } else if (tree instanceof S .StatementExpression ) {
494+ // Transparent — visit the inner statement
495+ return visit (((S .StatementExpression ) tree ).getStatement (), p );
496+ } else if (tree instanceof S .TypeAscription ) {
497+ return visitTypeAscription ((S .TypeAscription ) tree , p );
498+ } else if (tree instanceof S .TypeAlias ) {
499+ return visitTypeAlias ((S .TypeAlias ) tree , p );
500+ } else if (tree instanceof S .PatternDefinition ) {
501+ return visitPatternDefinition ((S .PatternDefinition ) tree , p );
488502 }
489503 return super .visit (tree , p );
490504 }
@@ -794,8 +808,13 @@ private void visitTypeParameters(@Nullable JContainer<J.TypeParameter> typeParam
794808
795809 @ Override
796810 public J visitBlock (J .Block block , PrintOutputCapture <P > p ) {
797- // Check if this block has the OmitBraces marker (for objects without body)
811+ // OmitBraces blocks print statements without { } — used for braceless bodies,
812+ // synthetic lambda body blocks, and expression-position blocks
798813 if (block .getMarkers ().findFirst (org .openrewrite .scala .marker .OmitBraces .class ).isPresent ()) {
814+ beforeSyntax (block , Space .Location .BLOCK_PREFIX , p );
815+ visitStatements (block .getPadding ().getStatements (), JRightPadded .Location .BLOCK_STATEMENT , p );
816+ visitSpace (block .getEnd (), Space .Location .BLOCK_END , p );
817+ afterSyntax (block , p );
799818 return block ;
800819 }
801820 // Scala 3 braceless (indentation-based) blocks use `:` instead of `{}`
@@ -865,26 +884,20 @@ public J visitForEachLoop(J.ForEachLoop forEachLoop, PrintOutputCapture<P> p) {
865884 return super .visitForEachLoop (forEachLoop , p );
866885 }
867886
887+ public J visitTypeAscription (S .TypeAscription typeAscription , PrintOutputCapture <P > p ) {
888+ beforeSyntax (typeAscription , Space .Location .LANGUAGE_EXTENSION , p );
889+ visit (typeAscription .getExpression (), p );
890+ p .append (':' );
891+ visit (typeAscription .getTypeTree (), p );
892+ afterSyntax (typeAscription , p );
893+ return typeAscription ;
894+ }
895+
868896 @ Override
869897 public J visitTypeCast (J .TypeCast typeCast , PrintOutputCapture <P > p ) {
870- if (typeCast .getMarkers ().findFirst (TypeAscription .class ).isPresent ()) {
871- // Scala type ascription: expr: Type
872- beforeSyntax (typeCast , Space .Location .TYPE_CAST_PREFIX , p );
873- visit (typeCast .getExpression (), p );
874- if (typeCast .getClazz () instanceof J .ControlParentheses ) {
875- J .ControlParentheses <?> controlParens = (J .ControlParentheses <?>) typeCast .getClazz ();
876- visitSpace (controlParens .getPrefix (), Space .Location .CONTROL_PARENTHESES_PREFIX , p );
877- p .append (':' );
878- visitRightPadded (controlParens .getPadding ().getTree (), JRightPadded .Location .PARENTHESES , "" , p );
879- }
880- afterSyntax (typeCast , p );
881- return typeCast ;
882- }
883- // Existing asInstanceOf handling
898+ // asInstanceOf handling
884899 beforeSyntax (typeCast , Space .Location .TYPE_CAST_PREFIX , p );
885900 visit (typeCast .getExpression (), p );
886- typeCast .getMarkers ().findFirst (AsInstanceOfPrefix .class )
887- .ifPresent (sp -> visitSpace (sp .getPrefix (), Space .Location .LANGUAGE_EXTENSION , p ));
888901 p .append (".asInstanceOf" );
889902 if (typeCast .getClazz () instanceof J .ControlParentheses ) {
890903 J .ControlParentheses <?> controlParens = (J .ControlParentheses <?>) typeCast .getClazz ();
@@ -967,8 +980,12 @@ public J visitVariable(J.VariableDeclarations.NamedVariable variable, PrintOutpu
967980 // In Scala, type annotation comes after the name
968981 J .VariableDeclarations parent = getCursor ().getParentOrThrow ().getValue ();
969982 if (parent .getTypeExpression () != null ) {
983+ // Print space before colon if present (e.g., `given IntSchema : SchemaFor[Int]`)
984+ // Stored in varargs field (repurposed, unused in Scala)
985+ if (parent .getVarargs () != null ) {
986+ visitSpace (parent .getVarargs (), Space .Location .VARARGS , p );
987+ }
970988 p .append (":" );
971- // The type expression should have the space after colon in its prefix
972989 visit (parent .getTypeExpression (), p );
973990
974991 // If there's an initializer, use visitLeftPadded to handle it properly
@@ -1111,7 +1128,8 @@ public J visitMethodInvocation(J.MethodInvocation method, PrintOutputCapture<P>
11111128 visit (method .getName (), p );
11121129 }
11131130
1114- // Print the block argument directly (no parentheses)
1131+ // Print the block argument — it's typically an S.StatementExpression(J.Block)
1132+ // The J.Block contains the lambda. visitBlock prints the { } braces.
11151133 if (method .getArguments () != null ) {
11161134 for (Expression arg : method .getArguments ()) {
11171135 visit (arg , p );
@@ -1229,16 +1247,17 @@ public J visitWildcard(S.Wildcard wildcard, PrintOutputCapture<P> p) {
12291247 return wildcard ;
12301248 }
12311249
1232- public J visitExpressionStatement (S .ExpressionStatement expressionStatement , PrintOutputCapture <P > p ) {
1233- visit (expressionStatement .getExpression (), p );
1234- return expressionStatement ;
1250+ public J visitTypeAlias (S .TypeAlias typeAlias , PrintOutputCapture <P > p ) {
1251+ beforeSyntax (typeAlias , Space .Location .LANGUAGE_EXTENSION , p );
1252+ p .append (typeAlias .getText ());
1253+ afterSyntax (typeAlias , p );
1254+ return typeAlias ;
12351255 }
12361256
1237- public J visitBlockExpression (S .BlockExpression blockExpression , PrintOutputCapture <P > p ) {
1238- beforeSyntax (blockExpression , Space .Location .LANGUAGE_EXTENSION , p );
1239- // Simply visit the contained block - it will print itself with braces
1240- visit (blockExpression .getBlock (), p );
1241- afterSyntax (blockExpression , p );
1242- return blockExpression ;
1257+ public J visitPatternDefinition (S .PatternDefinition patDef , PrintOutputCapture <P > p ) {
1258+ beforeSyntax (patDef , Space .Location .LANGUAGE_EXTENSION , p );
1259+ p .append (patDef .getText ());
1260+ afterSyntax (patDef , p );
1261+ return patDef ;
12431262 }
12441263}
0 commit comments