Skip to content

Commit f9c9f9a

Browse files
CST: Handle parsing expressions directly (#232)
To help make writing tests easier, we want to support returning a CST when parsing expressions directly. This PR adds CST support to the parseexpr function (as well as adding a wrapper around it in the batteries syntax package). We extend the visitor set up to allow visiting an overall expression. We also add "end" visitors. For now, this is only for expression and blocks (as necessary for codemod tooling), however we may want to add "end" visitors for other nodes in the future. The printer logic is also modified to support printing expressions directly.
1 parent dc85cd3 commit f9c9f9a

File tree

4 files changed

+81
-41
lines changed

4 files changed

+81
-41
lines changed

batteries/syntax/parser.luau

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@ local function parse(source: string): T.AstStatBlock
88
return luau.parse(source).root
99
end
1010

11+
local function parseexpr(source: string): T.AstExpr
12+
return luau.parseexpr(source)
13+
end
14+
1115
return {
1216
parse = parse,
17+
parseexpr = parseexpr,
1318
}

batteries/syntax/printer.luau

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,31 +46,47 @@ local function printString(expr: T.AstExprConstantString): string
4646
return result
4747
end
4848

49-
--- Returns a string representation of an AstStatBlock
50-
local function printBlock(block: T.AstStatBlock): string
51-
local printer = visitor.createVisitor()
52-
local result = ""
49+
type PrintVisitor = visitor.Visitor & {
50+
result: string,
51+
}
52+
53+
local function printVisitor()
54+
local printer = visitor.createVisitor() :: PrintVisitor
55+
printer.result = ""
5356

5457
printer.visitToken = function(node: T.Token)
55-
result ..= printToken(node)
58+
printer.result ..= printToken(node)
5659
return false
5760
end
5861

5962
printer.visitString = function(node: T.AstExprConstantString)
60-
result ..= printString(node)
63+
printer.result ..= printString(node)
6164
return false
6265
end
6366

6467
printer.visitTypeString = function(node: T.AstTypeSingletonString)
65-
result ..= printString(node)
68+
printer.result ..= printString(node)
6669
return false
6770
end
6871

72+
return printer
73+
end
74+
75+
--- Returns a string representation of an AstStatBlock
76+
local function printBlock(block: T.AstStatBlock): string
77+
local printer = printVisitor()
6978
visitor.visitBlock(block, printer)
79+
return printer.result
80+
end
7081

71-
return result
82+
--- Returns a string representation of an AstExpr
83+
function printExpr(block: T.AstExpr): string
84+
local printer = printVisitor()
85+
visitor.visitExpression(block, printer)
86+
return printer.result
7287
end
7388

7489
return {
7590
print = printBlock,
91+
printexpr = printExpr,
7692
}

batteries/syntax/visitor.luau

Lines changed: 48 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
local T = require("./ast_types")
44

5-
type Visitor = {
5+
export type Visitor = {
66
visitBlock: (T.AstStatBlock) -> boolean,
7+
visitBlockEnd: (T.AstStatBlock) -> (),
78
visitIf: (T.AstStatIf) -> boolean,
89
visitWhile: (T.AstStatWhile) -> boolean,
910
visitRepeat: (T.AstStatRepeat) -> boolean,
@@ -17,6 +18,8 @@ type Visitor = {
1718
visitLocalFunction: (T.AstStatLocalFunction) -> boolean,
1819
visitTypeAlias: (T.AstStatTypeAlias) -> boolean,
1920

21+
visitExpression: (T.AstExpr) -> boolean,
22+
visitExpressionEnd: (T.AstExpr) -> (),
2023
visitLocalReference: (T.AstExprLocal) -> boolean,
2124
visitGlobal: (T.AstExprGlobal) -> boolean,
2225
visitCall: (T.AstExprCall) -> boolean,
@@ -50,6 +53,7 @@ end
5053

5154
local defaultVisitor: Visitor = {
5255
visitBlock = alwaysVisit :: any,
56+
visitBlockEnd = alwaysVisit :: any,
5357
visitIf = alwaysVisit :: any,
5458
visitWhile = alwaysVisit :: any,
5559
visitRepeat = alwaysVisit :: any,
@@ -63,6 +67,8 @@ local defaultVisitor: Visitor = {
6367
visitLocalFunction = alwaysVisit :: any,
6468
visitTypeAlias = alwaysVisit :: any,
6569

70+
visitExpression = alwaysVisit :: any,
71+
visitExpressionEnd = alwaysVisit :: any,
6672
visitLocalReference = alwaysVisit :: any,
6773
visitGlobal = alwaysVisit :: any,
6874
visitCall = alwaysVisit :: any,
@@ -118,6 +124,8 @@ local function visitBlock(block: T.AstStatBlock, visitor: Visitor)
118124
for _, statement in block.statements do
119125
visitStatement(statement, visitor)
120126
end
127+
128+
visitor.visitBlockEnd(block)
121129
end
122130
end
123131

@@ -447,38 +455,42 @@ local function visitTypeGroup(node: T.AstTypeGroup, visitor: Visitor)
447455
end
448456

449457
function visitExpression(expression: T.AstExpr, visitor: Visitor)
450-
if expression.tag == "nil" then
451-
visitNil(expression, visitor)
452-
elseif expression.tag == "boolean" then
453-
visitBoolean(expression, visitor)
454-
elseif expression.tag == "number" then
455-
visitNumber(expression, visitor)
456-
elseif expression.tag == "string" then
457-
visitString(expression, visitor)
458-
elseif expression.tag == "local" then
459-
visitLocalReference(expression, visitor)
460-
elseif expression.tag == "global" then
461-
visitGlobal(expression, visitor)
462-
elseif expression.tag == "vararg" then
463-
visitVarargs(expression, visitor)
464-
elseif expression.tag == "call" then
465-
visitCall(expression, visitor)
466-
elseif expression.tag == "unary" then
467-
visitUnary(expression, visitor)
468-
elseif expression.tag == "binary" then
469-
visitBinary(expression, visitor)
470-
elseif expression.tag == "function" then
471-
visitAnonymousFunction(expression, visitor)
472-
elseif expression.tag == "table" then
473-
visitTable(expression, visitor)
474-
elseif expression.tag == "indexname" then
475-
visitIndexName(expression, visitor)
476-
elseif expression.tag == "index" then
477-
visitIndexExpr(expression, visitor)
478-
elseif expression.tag == "group" then
479-
visitGroup(expression, visitor)
480-
else
481-
exhaustiveMatch(expression.tag)
458+
if visitor.visitExpression(expression) then
459+
if expression.tag == "nil" then
460+
visitNil(expression, visitor)
461+
elseif expression.tag == "boolean" then
462+
visitBoolean(expression, visitor)
463+
elseif expression.tag == "number" then
464+
visitNumber(expression, visitor)
465+
elseif expression.tag == "string" then
466+
visitString(expression, visitor)
467+
elseif expression.tag == "local" then
468+
visitLocalReference(expression, visitor)
469+
elseif expression.tag == "global" then
470+
visitGlobal(expression, visitor)
471+
elseif expression.tag == "vararg" then
472+
visitVarargs(expression, visitor)
473+
elseif expression.tag == "call" then
474+
visitCall(expression, visitor)
475+
elseif expression.tag == "unary" then
476+
visitUnary(expression, visitor)
477+
elseif expression.tag == "binary" then
478+
visitBinary(expression, visitor)
479+
elseif expression.tag == "function" then
480+
visitAnonymousFunction(expression, visitor)
481+
elseif expression.tag == "table" then
482+
visitTable(expression, visitor)
483+
elseif expression.tag == "indexname" then
484+
visitIndexName(expression, visitor)
485+
elseif expression.tag == "index" then
486+
visitIndexExpr(expression, visitor)
487+
elseif expression.tag == "group" then
488+
visitGroup(expression, visitor)
489+
else
490+
exhaustiveMatch(expression.tag)
491+
end
492+
493+
visitor.visitExpressionEnd(expression)
482494
end
483495
end
484496

@@ -543,4 +555,7 @@ end
543555
return {
544556
createVisitor = createVisitor,
545557
visitBlock = visitBlock,
558+
visitStatement = visitStatement,
559+
visitExpression = visitExpression,
560+
visitType = visitType,
546561
}

luau/src/luau.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,16 @@ struct ExprResult
5757

5858
static ExprResult parseExpr(std::string& source)
5959
{
60+
// TODO: this is very bad, fix it!
61+
FFlag::LuauStoreCSTData2.value = true;
62+
6063
auto allocator = std::make_shared<Luau::Allocator>();
6164
auto names = std::make_shared<Luau::AstNameTable>(*allocator);
6265

6366
Luau::ParseOptions options;
6467
options.captureComments = true;
6568
options.allowDeclarationSyntax = false;
69+
options.storeCstData = true;
6670

6771
auto parseResult = Luau::Parser::parseExpr(source.data(), source.size(), *names, *allocator, options);
6872

0 commit comments

Comments
 (0)