Skip to content

Commit 5d274e3

Browse files
committed
Handle if stmt and nil
Enough to parse the writeFile test case
1 parent 9c91f0e commit 5d274e3

File tree

4 files changed

+71
-15
lines changed

4 files changed

+71
-15
lines changed

examples/codemods/identity.luau

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ codemod.apply("examples/json.luau", identityTransform)
1414
codemod.apply("examples/main.luau", identityTransform)
1515
codemod.apply("examples/parsing.luau", identityTransform)
1616
codemod.apply("examples/time_example.luau", identityTransform)
17+
codemod.apply("examples/writeFile.luau", identityTransform)

luau/src/luau.cpp

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -472,9 +472,7 @@ struct AstSerialize : public Luau::AstVisitor
472472

473473
void serialize(Luau::AstExprConstantNil* node)
474474
{
475-
lua_rawcheckstack(L, 2);
476-
lua_createtable(L, 0, preambleSize);
477-
475+
serializeToken(node->location.begin, "nil", preambleSize);
478476
serializeNodePreamble(node, "nil");
479477
}
480478

@@ -825,29 +823,41 @@ struct AstSerialize : public Luau::AstVisitor
825823

826824
serializeNodePreamble(node, "conditional");
827825

826+
serializeToken(node->location.begin, "if");
827+
lua_setfield(L, -2, "if");
828+
828829
node->condition->visit(this);
829830
lua_setfield(L, -2, "condition");
830831

832+
serializeToken(node->thenLocation->begin, "then");
833+
lua_setfield(L, -2, "then");
834+
831835
node->thenbody->visit(this);
832836
lua_setfield(L, -2, "consequent");
833837

834838
if (node->elsebody)
839+
{
840+
LUAU_ASSERT(node->elseLocation);
841+
serializeToken(node->elseLocation->begin, "else");
842+
lua_setfield(L, -2, "else");
843+
835844
node->elsebody->visit(this);
836-
else
837-
lua_pushnil(L);
838-
lua_setfield(L, -2, "antecedent");
845+
lua_setfield(L, -2, "antecedent");
839846

840-
if (node->thenLocation)
841-
serialize(*node->thenLocation);
847+
serializeToken(node->elsebody->location.end, "end");
848+
lua_setfield(L, -2, "end");
849+
}
842850
else
851+
{
843852
lua_pushnil(L);
844-
lua_setfield(L, -2, "thenLocation");
853+
lua_setfield(L, -2, "else");
845854

846-
if (node->elseLocation)
847-
serialize(*node->elseLocation);
848-
else
849855
lua_pushnil(L);
850-
lua_setfield(L, -2, "elseLocation");
856+
lua_setfield(L, -2, "antecedent");
857+
858+
serializeToken(node->thenbody->location.end, "end");
859+
lua_setfield(L, -2, "end");
860+
}
851861
}
852862

853863
void serializeStat(Luau::AstStatWhile* node)

std/codemod/ast_types.luau

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ export type AstLocal = {
4343
name: Token<string>,
4444
}
4545

46+
export type AstExprConstantNil = Token<"nil"> & { tag: "nil" }
47+
4648
export type AstExprConstantBool = Token<"true" | "false"> & { tag: "boolean", value: boolean }
4749

4850
export type AstExprConstantNumber = Token<string> & { tag: "number", value: number }
@@ -118,6 +120,7 @@ export type AstExprBinary = {
118120
}
119121

120122
export type AstExpr =
123+
| AstExprConstantNil
121124
| AstExprConstantBool
122125
| AstExprConstantNumber
123126
| AstExprConstantString
@@ -134,6 +137,17 @@ export type AstStatBlock = {
134137
statements: { AstStat },
135138
}
136139

140+
export type AstStatIf = {
141+
tag: "conditional",
142+
["if"]: Token<"if">,
143+
condition: AstExpr,
144+
["then"]: Token<"then">,
145+
consequent: AstStatBlock,
146+
["else"]: Token<"else">, -- TODO: this could be elseif!
147+
antecedent: AstStatBlock,
148+
["end"]: Token<"end">,
149+
}
150+
137151
export type AstStatReturn = {
138152
tag: "return",
139153
["return"]: Token<"return">,
@@ -153,6 +167,6 @@ export type AstStatLocal = {
153167
values: Punctuated<AstExpr>,
154168
}
155169

156-
export type AstStat = AstStatBlock | AstStatReturn | AstStatExpr | AstStatLocal
170+
export type AstStat = AstStatBlock | AstStatIf | AstStatReturn | AstStatExpr | AstStatLocal
157171

158172
return {}

std/codemod/visitor.luau

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ local T = require("./ast_types")
44

55
type Visitor = {
66
visitBlock: (T.AstStatBlock) -> boolean,
7+
visitIf: (T.AstStatIf) -> boolean,
78
visitReturn: (T.AstStatReturn) -> boolean,
89
visitLocalDeclaration: (T.AstStatLocal) -> boolean,
910

@@ -16,6 +17,7 @@ type Visitor = {
1617
visitIndexName: (T.AstExprIndexName) -> boolean,
1718

1819
visitToken: (T.Token) -> boolean,
20+
visitNil: (T.AstExprConstantNil) -> boolean,
1921
visitString: (T.AstExprConstantString) -> boolean,
2022
visitBoolean: (T.AstExprConstantBool) -> boolean,
2123
visitNumber: (T.AstExprConstantNumber) -> boolean,
@@ -28,6 +30,7 @@ end
2830

2931
local defaultVisitor: Visitor = {
3032
visitBlock = alwaysVisit :: any,
33+
visitIf = alwaysVisit :: any,
3134
visitReturn = alwaysVisit :: any,
3235
visitLocalDeclaration = alwaysVisit :: any,
3336

@@ -40,6 +43,7 @@ local defaultVisitor: Visitor = {
4043
visitIndexName = alwaysVisit :: any,
4144

4245
visitToken = alwaysVisit :: any,
46+
visitNil = alwaysVisit :: any,
4347
visitString = alwaysVisit :: any,
4448
visitBoolean = alwaysVisit :: any,
4549
visitNumber = alwaysVisit :: any,
@@ -77,6 +81,23 @@ local function visitBlock(block: T.AstStatBlock, visitor: Visitor)
7781
end
7882
end
7983

84+
local function visitIf(node: T.AstStatIf, visitor: Visitor)
85+
if visitor.visitIf(node) then
86+
visitToken(node["if"], visitor)
87+
visitExpression(node.condition, visitor)
88+
visitToken(node["then"], visitor)
89+
visitBlock(node.consequent, visitor)
90+
-- TODO: special handling for elseif?
91+
if node["else"] then
92+
visitToken(node["else"], visitor)
93+
end
94+
if node.antecedent then
95+
visitBlock(node.antecedent, visitor)
96+
end
97+
visitToken(node["end"], visitor)
98+
end
99+
end
100+
80101
local function visitReturn(node: T.AstStatReturn, visitor: Visitor)
81102
if visitor.visitReturn(node) then
82103
visitToken(node["return"], visitor)
@@ -101,6 +122,12 @@ local function visitString(node: T.AstExprConstantString, visitor: Visitor)
101122
end
102123
end
103124

125+
local function visitNil(node: T.AstExprConstantNil, visitor: Visitor)
126+
if visitor.visitNil(node) then
127+
visitToken(node, visitor)
128+
end
129+
end
130+
104131
local function visitBoolean(node: T.AstExprConstantBool, visitor: Visitor)
105132
if visitor.visitBoolean(node) then
106133
visitToken(node, visitor)
@@ -189,7 +216,9 @@ local function visitIndexName(node: T.AstExprIndexName, visitor: Visitor)
189216
end
190217

191218
function visitExpression(expression: T.AstExpr, visitor: Visitor)
192-
if expression.tag == "boolean" then
219+
if expression.tag == "nil" then
220+
visitNil(expression, visitor)
221+
elseif expression.tag == "boolean" then
193222
visitBoolean(expression, visitor)
194223
elseif expression.tag == "number" then
195224
visitNumber(expression, visitor)
@@ -215,6 +244,8 @@ end
215244
function visitStatement(statement: T.AstStat, visitor: Visitor)
216245
if statement.tag == "block" then
217246
visitBlock(statement, visitor)
247+
elseif statement.tag == "conditional" then
248+
visitIf(statement, visitor)
218249
elseif statement.tag == "expression" then
219250
visitExpression(statement.expression, visitor)
220251
elseif statement.tag == "local" then

0 commit comments

Comments
 (0)