Skip to content

Commit e0f99a7

Browse files
committed
Handle untpd.ParsedTry in Scala 3 parser
The Scala 3 compiler uses two different AST representations for try/catch/finally: `Trees.Try` (typed/desugared) and `untpd.ParsedTry` (untyped/parsed). PR #7260 added handling for `Trees.Try` but `ParsedTry` was falling through to `visitUnknown` since it extends `Trees.Tree` directly and not `Trees.Try`. Add `visitParsedTry`/`visitParsedTryImpl` to handle `ParsedTry` by extracting cases from its `handler` (a `Match` tree) and mapping them to `J.Try` AST nodes. Update `ScalaPrinter` to use AST-stored spacing for catch block formatting.
1 parent deca4de commit e0f99a7

3 files changed

Lines changed: 183 additions & 2 deletions

File tree

rewrite-scala/src/main/java/org/openrewrite/scala/ScalaPrinter.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,8 @@ public J visitTry(J.Try tryable, PrintOutputCapture<P> p) {
204204
p.append("catch {");
205205
for (J.Try.Catch aCatch : tryable.getCatches()) {
206206
J.VariableDeclarations varDecl = aCatch.getParameter().getTree();
207-
p.append("\n case");
207+
visitSpace(aCatch.getParameter().getPrefix(), Space.Location.CONTROL_PARENTHESES_PREFIX, p);
208+
p.append("case");
208209
if (!varDecl.getVariables().isEmpty()) {
209210
visit(varDecl.getVariables().get(0).getName(), p);
210211
}
@@ -217,7 +218,8 @@ public J visitTry(J.Try tryable, PrintOutputCapture<P> p) {
217218
visit(stmt, p);
218219
}
219220
}
220-
p.append("\n}");
221+
visitSpace(firstCatch.getBody().getEnd(), Space.Location.BLOCK_END, p);
222+
p.append("}");
221223
}
222224
if (tryable.getPadding().getFinally() != null) {
223225
visitSpace(tryable.getPadding().getFinally().getBefore(), Space.Location.TRY_FINALLY, p);

rewrite-scala/src/main/scala/org/openrewrite/scala/internal/ScalaTreeVisitor.scala

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ class ScalaTreeVisitor(
204204
case func: untpd.Function => visitFunction(func)
205205
case typed: Trees.Typed[?] => visitTyped(typed)
206206
case tuple: untpd.Tuple => visitTuple(tuple)
207+
case parsedTry: untpd.ParsedTry => visitParsedTry(parsedTry)
207208
case tryTree: Trees.Try[?] => visitTryTree(tryTree)
208209
case matchTree: Trees.Match[?] => visitMatchTree(matchTree)
209210
case thisTree: Trees.This[?] => visitThis(thisTree)
@@ -5257,6 +5258,126 @@ class ScalaTreeVisitor(
52575258
new J.Try(Tree.randomId(), prefix, Markers.EMPTY, null, body, catches, finallyBlock)
52585259
}
52595260

5261+
5262+
private def visitParsedTry(parsedTry: untpd.ParsedTry): J = {
5263+
val savedCursor = cursor
5264+
try { visitParsedTryImpl(parsedTry) } catch { case _: Exception => cursor = savedCursor; visitUnknown(parsedTry) }
5265+
}
5266+
5267+
private def visitParsedTryImpl(parsedTry: untpd.ParsedTry): J.Try = {
5268+
val prefix = extractPrefix(parsedTry.span)
5269+
val tryStart = Math.max(0, parsedTry.span.start - offsetAdjustment)
5270+
if (tryStart >= cursor && tryStart + 3 <= source.length) cursor = tryStart + 3
5271+
5272+
val body = visitTree(parsedTry.expr) match {
5273+
case block: J.Block => block
5274+
case expr: Expression =>
5275+
val stmts = new util.ArrayList[JRightPadded[Statement]]()
5276+
stmts.add(JRightPadded.build(expr.asInstanceOf[Statement]))
5277+
new J.Block(Tree.randomId(), Space.EMPTY, Markers.EMPTY, JRightPadded.build(false), stmts, Space.EMPTY)
5278+
case _ => return visitUnknown(parsedTry).asInstanceOf[J.Try]
5279+
}
5280+
5281+
val cases: List[Trees.CaseDef[?]] = parsedTry.handler match {
5282+
case m: Trees.Match[?] => m.cases.asInstanceOf[List[Trees.CaseDef[?]]]
5283+
case _ => scala.collection.immutable.Nil
5284+
}
5285+
5286+
val catches = new util.ArrayList[J.Try.Catch]()
5287+
if (!parsedTry.handler.isEmpty && parsedTry.handler.span.exists) {
5288+
val catchSearch = if (cursor < source.length) source.substring(cursor, Math.min(cursor + 50, source.length)) else ""
5289+
val catchIdx = catchSearch.indexOf("catch")
5290+
val catchPrefix = if (catchIdx > 0) Space.format(catchSearch.substring(0, catchIdx)) else Space.EMPTY
5291+
if (catchIdx >= 0) cursor = cursor + catchIdx + 5
5292+
val braceSearch = if (cursor < source.length) source.substring(cursor, Math.min(cursor + 20, source.length)) else ""
5293+
val braceIdx = braceSearch.indexOf('{')
5294+
if (braceIdx >= 0) cursor = cursor + braceIdx + 1
5295+
5296+
for (caseDef <- cases) {
5297+
// Extract space before "case" keyword (e.g., newline + indentation)
5298+
val casePrefixSpace = extractPrefix(caseDef.span)
5299+
val caseStart = Math.max(0, caseDef.span.start - offsetAdjustment)
5300+
if (caseStart >= cursor) cursor = caseStart
5301+
val caseSearch = if (cursor < source.length) source.substring(cursor, Math.min(cursor + 20, source.length)) else ""
5302+
val caseKwIdx = caseSearch.indexOf("case")
5303+
if (caseKwIdx >= 0) cursor = cursor + caseKwIdx + 4
5304+
5305+
val arrowSearch = if (cursor < source.length) source.substring(cursor, Math.min(cursor + 200, source.length)) else ""
5306+
val arrowIdx = arrowSearch.indexOf("=>")
5307+
5308+
val paramName = caseDef.pat match {
5309+
case bind: Trees.Bind[?] => bind.name.toString
5310+
case typed: Trees.Typed[?] => typed.expr match {
5311+
case id: Trees.Ident[?] => id.name.toString
5312+
case _ => "_"
5313+
}
5314+
case _ => extractSource(caseDef.pat.span)
5315+
}
5316+
val paramType = caseDef.pat match {
5317+
case bind: Trees.Bind[?] => bind.body match {
5318+
case typed: Trees.Typed[?] => visitTree(typed.tpt) match { case tt: TypeTree => tt; case id: J.Identifier => id; case _ => null }
5319+
case _ => null
5320+
}
5321+
case typed: Trees.Typed[?] =>
5322+
{ val colonSearch = source.indexOf(':', Math.max(0, typed.expr.span.end - offsetAdjustment)); if (colonSearch >= 0) cursor = colonSearch + 1 }
5323+
visitTree(typed.tpt) match { case tt: TypeTree => tt; case id: J.Identifier => id; case _ => null }
5324+
case _ => null
5325+
}
5326+
updateCursor(caseDef.pat.span.end)
5327+
if (arrowIdx >= 0) { val a = source.indexOf("=>", cursor); if (a >= 0) cursor = a + 2 }
5328+
5329+
val paramId = new J.Identifier(Tree.randomId(), Space.format(" "), Markers.EMPTY, Collections.emptyList(), paramName, null, null)
5330+
val namedVar = new J.VariableDeclarations.NamedVariable(Tree.randomId(), Space.EMPTY, Markers.EMPTY, paramId, Collections.emptyList(), null, null)
5331+
val varDecl = new J.VariableDeclarations(Tree.randomId(), Space.EMPTY, Markers.EMPTY,
5332+
Collections.emptyList(), Collections.emptyList(), paramType, null, Collections.emptyList(),
5333+
Collections.singletonList(JRightPadded.build(namedVar)))
5334+
val controlParens = new J.ControlParentheses[J.VariableDeclarations](Tree.randomId(), casePrefixSpace, Markers.EMPTY, JRightPadded.build(varDecl))
5335+
5336+
val caseBody = visitTree(caseDef.body) match {
5337+
case block: J.Block => block
5338+
case expr: Expression =>
5339+
val s = new util.ArrayList[JRightPadded[Statement]](); s.add(JRightPadded.build(expr.asInstanceOf[Statement]))
5340+
new J.Block(Tree.randomId(), Space.EMPTY, Markers.EMPTY, JRightPadded.build(false), s, Space.EMPTY)
5341+
case stmt: Statement =>
5342+
val s = new util.ArrayList[JRightPadded[Statement]](); s.add(JRightPadded.build(stmt))
5343+
new J.Block(Tree.randomId(), Space.EMPTY, Markers.EMPTY, JRightPadded.build(false), s, Space.EMPTY)
5344+
case _ => new J.Block(Tree.randomId(), Space.EMPTY, Markers.EMPTY, JRightPadded.build(false), new util.ArrayList(), Space.EMPTY)
5345+
}
5346+
updateCursor(caseDef.span.end)
5347+
catches.add(new J.Try.Catch(Tree.randomId(), catchPrefix, Markers.EMPTY, controlParens, caseBody))
5348+
}
5349+
// Extract closing brace prefix and store it in the first catch's body end space
5350+
val closeBracePrefix = if (cursor < source.length) {
5351+
val r = source.substring(cursor, Math.min(cursor + 50, source.length))
5352+
val ci = r.indexOf('}')
5353+
if (ci > 0) { val space = Space.format(r.substring(0, ci)); cursor = cursor + ci + 1; space }
5354+
else if (ci == 0) { cursor = cursor + 1; Space.EMPTY }
5355+
else Space.EMPTY
5356+
} else Space.EMPTY
5357+
if (!catches.isEmpty) {
5358+
val firstCatch = catches.get(0)
5359+
val updatedBody = firstCatch.getBody.withEnd(closeBracePrefix)
5360+
catches.set(0, firstCatch.withBody(updatedBody))
5361+
}
5362+
}
5363+
5364+
val finallyBlock: JLeftPadded[J.Block] = if (!parsedTry.finalizer.isEmpty && parsedTry.finalizer.span.exists) {
5365+
val fs = if (cursor < source.length) source.substring(cursor, Math.min(cursor + 50, source.length)) else ""
5366+
val fi = fs.indexOf("finally"); val fSpace = if (fi > 0) Space.format(fs.substring(0, fi)) else Space.EMPTY
5367+
if (fi >= 0) cursor = cursor + fi + 7
5368+
val fb = visitTree(parsedTry.finalizer) match {
5369+
case block: J.Block => block
5370+
case expr: Expression =>
5371+
val s = new util.ArrayList[JRightPadded[Statement]](); s.add(JRightPadded.build(expr.asInstanceOf[Statement]))
5372+
new J.Block(Tree.randomId(), Space.EMPTY, Markers.EMPTY, JRightPadded.build(false), s, Space.EMPTY)
5373+
case _ => null
5374+
}
5375+
if (fb != null) JLeftPadded.build(fb).withBefore(fSpace) else null
5376+
} else null
5377+
5378+
updateCursor(parsedTry.span.end)
5379+
new J.Try(Tree.randomId(), prefix, Markers.EMPTY, null, body, catches, finallyBlock)
5380+
}
52605381
private def visitMatchTree(matchTree: Trees.Match[?]): J = {
52615382
val savedCursor = cursor
52625383
try { visitMatchImpl(matchTree) } catch { case _: Exception => cursor = savedCursor; visitUnknown(matchTree) }
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/*
2+
* Copyright 2025 the original author or authors.
3+
* <p>
4+
* Licensed under the Moderne Source Available License (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
* <p>
8+
* https://docs.moderne.io/licensing/moderne-source-available-license
9+
* <p>
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.openrewrite.scala.tree;
17+
18+
import org.junit.jupiter.api.Test;
19+
import org.openrewrite.test.RewriteTest;
20+
21+
import static org.openrewrite.scala.Assertions.scala;
22+
23+
class ParsedTryTest implements RewriteTest {
24+
25+
@Test
26+
void tryFinallyWithThrow() {
27+
rewriteRun(
28+
scala(
29+
"""
30+
object Test {
31+
try {
32+
println("risky")
33+
} finally {
34+
throw new RuntimeException("fail")
35+
}
36+
}
37+
"""
38+
)
39+
);
40+
}
41+
42+
@Test
43+
void tryCatchThrowable() {
44+
rewriteRun(
45+
scala(
46+
"""
47+
object Test {
48+
try {
49+
println("risky")
50+
} catch {
51+
case e: Throwable => println("caught")
52+
}
53+
}
54+
"""
55+
)
56+
);
57+
}
58+
}

0 commit comments

Comments
 (0)