3434import org .openrewrite .scala .marker .SObject ;
3535import org .openrewrite .scala .marker .TypeProjection ;
3636import org .openrewrite .scala .marker .ScalaForLoop ;
37- import org .openrewrite .scala .marker .AsInstanceOfPrefix ;
3837import org .openrewrite .scala .marker .TypeAscription ;
3938import org .openrewrite .scala .marker .UnderscorePlaceholderLambda ;
4039import org .openrewrite .scala .tree .S ;
@@ -200,12 +199,18 @@ public J visitTry(J.Try tryable, PrintOutputCapture<P> p) {
200199 p .append ("try" );
201200 visit (tryable .getBody (), p );
202201 if (!tryable .getCatches ().isEmpty ()) {
203- J .Try .Catch firstCatch = tryable .getCatches ().get (0 );
204- visitSpace (firstCatch .getPrefix (), Space .Location .CATCH_PREFIX , p );
205- p .append ("catch {" );
206- for (J .Try .Catch aCatch : tryable .getCatches ()) {
202+ // Print catch block with cases from AST whitespace
203+ for (int i = 0 ; i < tryable .getCatches ().size (); i ++) {
204+ J .Try .Catch aCatch = tryable .getCatches ().get (i );
205+ if (i == 0 ) {
206+ // First catch — prefix is the space before "catch"
207+ visitSpace (aCatch .getPrefix (), Space .Location .CATCH_PREFIX , p );
208+ p .append ("catch {" );
209+ }
210+ // Print case with AST whitespace
207211 J .VariableDeclarations varDecl = aCatch .getParameter ().getTree ();
208- p .append ("\n case" );
212+ visitSpace (varDecl .getPrefix (), Space .Location .VARIABLE_DECLARATIONS_PREFIX , p );
213+ p .append ("case" );
209214 if (!varDecl .getVariables ().isEmpty ()) {
210215 visit (varDecl .getVariables ().get (0 ).getName (), p );
211216 }
@@ -218,7 +223,10 @@ public J visitTry(J.Try tryable, PrintOutputCapture<P> p) {
218223 visit (stmt , p );
219224 }
220225 }
221- p .append ("\n }" );
226+ // Close catch block — use end space from last catch body if available
227+ J .Try .Catch lastCatch = tryable .getCatches ().get (tryable .getCatches ().size () - 1 );
228+ visitSpace (lastCatch .getBody ().getEnd (), Space .Location .BLOCK_END , p );
229+ p .append ("}" );
222230 }
223231 if (tryable .getPadding ().getFinally () != null ) {
224232 visitSpace (tryable .getPadding ().getFinally ().getBefore (), Space .Location .TRY_FINALLY , p );
@@ -249,29 +257,21 @@ public J visitSwitch(J.Switch switch_, PrintOutputCapture<P> p) {
249257 public J visitCase (J .Case case_ , PrintOutputCapture <P > p ) {
250258 beforeSyntax (case_ , Space .Location .CASE_PREFIX , p );
251259 p .append ("case" );
252- List <JRightPadded <J >> labels = case_ .getPadding ().getCaseLabels ().getPadding ().getElements ();
253- for (JRightPadded <J > label : labels ) {
254- visit (label .getElement (), p );
255- visitSpace (label .getAfter (), Space .Location .CASE_LABEL , p );
260+ for (J label : case_ .getCaseLabels ()) {
261+ visit (label , p );
256262 }
257263 if (case_ .getGuard () != null ) {
258- p .append ("if" );
264+ p .append (" if" );
259265 visit (case_ .getGuard (), p );
260266 }
261267 p .append (" =>" );
262268 if (case_ .getPadding ().getBody () != null ) {
263269 J bodyElement = case_ .getPadding ().getBody ().getElement ();
264- // For Scala match cases, unwrap blocks to avoid printing { }
270+ // For Scala match cases, unwrap single-statement blocks to avoid { }
265271 if (bodyElement instanceof J .Block ) {
266272 J .Block block = (J .Block ) bodyElement ;
267- visitSpace (block .getPrefix (), Space .Location .BLOCK_PREFIX , p );
268- List <JRightPadded <Statement >> paddedStatements = block .getPadding ().getStatements ();
269- for (int i = 0 ; i < paddedStatements .size (); i ++) {
270- JRightPadded <Statement > paddedStmt = paddedStatements .get (i );
271- visit (paddedStmt .getElement (), p );
272- if (i < paddedStatements .size () - 1 ) {
273- visitSpace (paddedStmt .getAfter (), Space .Location .BLOCK_STATEMENT_SUFFIX , p );
274- }
273+ for (Statement stmt : block .getStatements ()) {
274+ visit (stmt , p );
275275 }
276276 } else {
277277 visit (bodyElement , p );
@@ -401,7 +401,19 @@ public J visitMethodDeclaration(J.MethodDeclaration method, PrintOutputCapture<P
401401 if (!procedureSyntax ) {
402402 p .append (" =" );
403403 }
404- visit (actualBody , p );
404+ // If body is OmitBraces block with single statement, print just the statement
405+ if (actualBody instanceof J .Block ) {
406+ J .Block bodyBlock = (J .Block ) actualBody ;
407+ boolean omit = bodyBlock .getMarkers ().findFirst (
408+ org .openrewrite .scala .marker .OmitBraces .class ).isPresent ();
409+ if (omit && bodyBlock .getStatements ().size () == 1 ) {
410+ visit (bodyBlock .getStatements ().get (0 ), p );
411+ } else {
412+ visit (actualBody , p );
413+ }
414+ } else {
415+ visit (actualBody , p );
416+ }
405417 }
406418 } else if (method .getBody () != null ) {
407419 // Normal method body
@@ -481,10 +493,15 @@ public J visit(@Nullable Tree tree, PrintOutputCapture<P> p) {
481493 return visitWildcard ((S .Wildcard ) tree , p );
482494 } else if (tree instanceof S .TuplePattern ) {
483495 return visitTuplePattern ((S .TuplePattern ) tree , p );
484- } else if (tree instanceof S .BlockExpression ) {
485- return visitBlockExpression ((S .BlockExpression ) tree , p );
486- } else if (tree instanceof S .ExpressionStatement ) {
487- return visitExpressionStatement ((S .ExpressionStatement ) tree , p );
496+ } else if (tree instanceof S .StatementExpression ) {
497+ // Transparent — visit the inner statement
498+ return visit (((S .StatementExpression ) tree ).getStatement (), p );
499+ } else if (tree instanceof S .TypeAscription ) {
500+ return visitTypeAscription ((S .TypeAscription ) tree , p );
501+ } else if (tree instanceof S .TypeAlias ) {
502+ return visitTypeAlias ((S .TypeAlias ) tree , p );
503+ } else if (tree instanceof S .PatternDefinition ) {
504+ return visitPatternDefinition ((S .PatternDefinition ) tree , p );
488505 }
489506 return super .visit (tree , p );
490507 }
@@ -794,8 +811,13 @@ private void visitTypeParameters(@Nullable JContainer<J.TypeParameter> typeParam
794811
795812 @ Override
796813 public J visitBlock (J .Block block , PrintOutputCapture <P > p ) {
797- // Check if this block has the OmitBraces marker (for objects without body)
814+ // OmitBraces blocks print statements without { } — used for braceless bodies,
815+ // synthetic lambda body blocks, and expression-position blocks
798816 if (block .getMarkers ().findFirst (org .openrewrite .scala .marker .OmitBraces .class ).isPresent ()) {
817+ beforeSyntax (block , Space .Location .BLOCK_PREFIX , p );
818+ visitStatements (block .getPadding ().getStatements (), JRightPadded .Location .BLOCK_STATEMENT , p );
819+ visitSpace (block .getEnd (), Space .Location .BLOCK_END , p );
820+ afterSyntax (block , p );
799821 return block ;
800822 }
801823 // Scala 3 braceless (indentation-based) blocks use `:` instead of `{}`
@@ -865,26 +887,20 @@ public J visitForEachLoop(J.ForEachLoop forEachLoop, PrintOutputCapture<P> p) {
865887 return super .visitForEachLoop (forEachLoop , p );
866888 }
867889
890+ public J visitTypeAscription (S .TypeAscription typeAscription , PrintOutputCapture <P > p ) {
891+ beforeSyntax (typeAscription , Space .Location .LANGUAGE_EXTENSION , p );
892+ visit (typeAscription .getExpression (), p );
893+ p .append (':' );
894+ visit (typeAscription .getTypeTree (), p );
895+ afterSyntax (typeAscription , p );
896+ return typeAscription ;
897+ }
898+
868899 @ Override
869900 public J visitTypeCast (J .TypeCast typeCast , PrintOutputCapture <P > p ) {
870- if (typeCast .getMarkers ().findFirst (TypeAscription .class ).isPresent ()) {
871- // Scala type ascription: expr: Type
872- beforeSyntax (typeCast , Space .Location .TYPE_CAST_PREFIX , p );
873- visit (typeCast .getExpression (), p );
874- if (typeCast .getClazz () instanceof J .ControlParentheses ) {
875- J .ControlParentheses <?> controlParens = (J .ControlParentheses <?>) typeCast .getClazz ();
876- visitSpace (controlParens .getPrefix (), Space .Location .CONTROL_PARENTHESES_PREFIX , p );
877- p .append (':' );
878- visitRightPadded (controlParens .getPadding ().getTree (), JRightPadded .Location .PARENTHESES , "" , p );
879- }
880- afterSyntax (typeCast , p );
881- return typeCast ;
882- }
883- // Existing asInstanceOf handling
901+ // asInstanceOf handling
884902 beforeSyntax (typeCast , Space .Location .TYPE_CAST_PREFIX , p );
885903 visit (typeCast .getExpression (), p );
886- typeCast .getMarkers ().findFirst (AsInstanceOfPrefix .class )
887- .ifPresent (sp -> visitSpace (sp .getPrefix (), Space .Location .LANGUAGE_EXTENSION , p ));
888904 p .append (".asInstanceOf" );
889905 if (typeCast .getClazz () instanceof J .ControlParentheses ) {
890906 J .ControlParentheses <?> controlParens = (J .ControlParentheses <?>) typeCast .getClazz ();
@@ -967,8 +983,12 @@ public J visitVariable(J.VariableDeclarations.NamedVariable variable, PrintOutpu
967983 // In Scala, type annotation comes after the name
968984 J .VariableDeclarations parent = getCursor ().getParentOrThrow ().getValue ();
969985 if (parent .getTypeExpression () != null ) {
986+ // Print space before colon if present (e.g., `given IntSchema : SchemaFor[Int]`)
987+ // Stored in varargs field (repurposed, unused in Scala)
988+ if (parent .getVarargs () != null ) {
989+ visitSpace (parent .getVarargs (), Space .Location .VARARGS , p );
990+ }
970991 p .append (":" );
971- // The type expression should have the space after colon in its prefix
972992 visit (parent .getTypeExpression (), p );
973993
974994 // If there's an initializer, use visitLeftPadded to handle it properly
@@ -1111,7 +1131,8 @@ public J visitMethodInvocation(J.MethodInvocation method, PrintOutputCapture<P>
11111131 visit (method .getName (), p );
11121132 }
11131133
1114- // Print the block argument directly (no parentheses)
1134+ // Print the block argument — it's typically an S.StatementExpression(J.Block)
1135+ // The J.Block contains the lambda. visitBlock prints the { } braces.
11151136 if (method .getArguments () != null ) {
11161137 for (Expression arg : method .getArguments ()) {
11171138 visit (arg , p );
@@ -1229,16 +1250,17 @@ public J visitWildcard(S.Wildcard wildcard, PrintOutputCapture<P> p) {
12291250 return wildcard ;
12301251 }
12311252
1232- public J visitExpressionStatement (S .ExpressionStatement expressionStatement , PrintOutputCapture <P > p ) {
1233- visit (expressionStatement .getExpression (), p );
1234- return expressionStatement ;
1253+ public J visitTypeAlias (S .TypeAlias typeAlias , PrintOutputCapture <P > p ) {
1254+ beforeSyntax (typeAlias , Space .Location .LANGUAGE_EXTENSION , p );
1255+ p .append (typeAlias .getText ());
1256+ afterSyntax (typeAlias , p );
1257+ return typeAlias ;
12351258 }
12361259
1237- public J visitBlockExpression (S .BlockExpression blockExpression , PrintOutputCapture <P > p ) {
1238- beforeSyntax (blockExpression , Space .Location .LANGUAGE_EXTENSION , p );
1239- // Simply visit the contained block - it will print itself with braces
1240- visit (blockExpression .getBlock (), p );
1241- afterSyntax (blockExpression , p );
1242- return blockExpression ;
1260+ public J visitPatternDefinition (S .PatternDefinition patDef , PrintOutputCapture <P > p ) {
1261+ beforeSyntax (patDef , Space .Location .LANGUAGE_EXTENSION , p );
1262+ p .append (patDef .getText ());
1263+ afterSyntax (patDef , p );
1264+ return patDef ;
12431265 }
12441266}
0 commit comments