Skip to content

Commit 445ea7b

Browse files
JohnnyMorganzaatxe
andauthored
CST: Handle generic definitions (#268)
This PR adds the generic definition nodes into the function and type alias definitions. This involves serializing AstGenericType / AstGenericTypePack. Generic definitions can be punctuated, and are separated by commas. However, there is only 1 list of overall comma positions in the CST node. Hence we implement a way to split the commas so that we can separate `generics` and `genericPacks`. Note that this means the last element of `generics` in the punctuated list may still have a comma. --------- Co-authored-by: ariel <[email protected]>
1 parent 95b1ffd commit 445ea7b

File tree

6 files changed

+204
-4
lines changed

6 files changed

+204
-4
lines changed

batteries/syntax/ast_types.luau

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,10 @@ export type AstExprIndexExpr = {
103103
}
104104

105105
export type AstFunctionBody = {
106-
-- TODO: generics
106+
openGenerics: Token<"<">?,
107+
generics: Punctuated<AstGenericType>?,
108+
genericPacks: Punctuated<AstGenericTypePack>?,
109+
closeGenerics: Token<">">?,
107110
openParens: Token<"(">,
108111
-- TODO: parameters
109112
parameters: Punctuated<Token<string>>,
@@ -330,6 +333,10 @@ export type AstStatTypeAlias = {
330333
["export"]: Token<"export">?,
331334
typeToken: Token<"type">,
332335
name: Token,
336+
openGenerics: Token<"<">?,
337+
generics: Punctuated<AstGenericType>?,
338+
genericPacks: Punctuated<AstGenericTypePack>?,
339+
closeGenerics: Token<">">?,
333340
equals: Token<"=">,
334341
type: AstType,
335342
}
@@ -362,6 +369,21 @@ export type AstStat =
362369
| AstStatTypeAlias
363370
| AstStatTypeFunction
364371

372+
export type AstGenericType = {
373+
tag: "generic",
374+
name: Token<string>,
375+
equals: Token<"=">?,
376+
default: AstType?,
377+
}
378+
379+
export type AstGenericTypePack = {
380+
tag: "genericpack",
381+
name: Token<string>,
382+
ellipsis: Token<"...">,
383+
equals: Token<"=">?,
384+
default: AstTypePack?,
385+
}
386+
365387
export type AstTypeReference = {
366388
tag: "reference",
367389
prefix: Token<string>?,

batteries/syntax/visitor.luau

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,13 +260,46 @@ local function visitCompoundAssign(node: T.AstStatCompoundAssign, visitor: Visit
260260
end
261261
end
262262

263+
local function visitGeneric(node: T.AstGenericType, visitor: Visitor)
264+
visitToken(node.name, visitor)
265+
if node.equals then
266+
visitToken(node.equals, visitor)
267+
end
268+
if node.default then
269+
visitType(node.default, visitor)
270+
end
271+
end
272+
273+
local function visitGenericPack(node: T.AstGenericTypePack, visitor: Visitor)
274+
visitToken(node.name, visitor)
275+
visitToken(node.ellipsis, visitor)
276+
if node.equals then
277+
visitToken(node.equals, visitor)
278+
end
279+
if node.default then
280+
visitTypePack(node.default, visitor)
281+
end
282+
end
283+
263284
local function visitTypeAlias(node: T.AstStatTypeAlias, visitor: Visitor)
264285
if visitor.visitTypeAlias(node) then
265286
if node.export then
266287
visitToken(node.export, visitor)
267288
end
268289
visitToken(node.typeToken, visitor)
269290
visitToken(node.name, visitor)
291+
if node.openGenerics then
292+
visitToken(node.openGenerics, visitor)
293+
end
294+
if node.generics then
295+
visitPunctuated(node.generics, visitor, visitGeneric)
296+
end
297+
if node.genericPacks then
298+
visitPunctuated(node.genericPacks, visitor, visitGenericPack)
299+
end
300+
if node.closeGenerics then
301+
visitToken(node.closeGenerics, visitor)
302+
end
270303
visitToken(node.equals, visitor)
271304
visitType(node.type, visitor)
272305
end
@@ -343,6 +376,18 @@ local function visitBinary(node: T.AstExprBinary, visitor: Visitor)
343376
end
344377

345378
local function visitFunctionBody(node: T.AstFunctionBody, visitor: Visitor)
379+
if node.openGenerics then
380+
visitToken(node.openGenerics, visitor)
381+
end
382+
if node.generics then
383+
visitPunctuated(node.generics, visitor, visitGeneric)
384+
end
385+
if node.genericPacks then
386+
visitPunctuated(node.genericPacks, visitor, visitGenericPack)
387+
end
388+
if node.closeGenerics then
389+
visitToken(node.closeGenerics, visitor)
390+
end
346391
visitToken(node.openParens, visitor)
347392
visitPunctuated(node.parameters, visitor, visitLocal)
348393
visitToken(node.closeParens, visitor)

luau/src/luau.cpp

Lines changed: 120 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -842,7 +842,29 @@ struct AstSerialize : public Luau::AstVisitor
842842
const auto* cstNode = lookupCstNode<Luau::CstExprFunction>(node);
843843

844844
lua_rawcheckstack(L, 3);
845-
lua_createtable(L, 0, 7);
845+
lua_createtable(L, 0, 9);
846+
847+
if (node->generics.size > 0 || node->genericPacks.size > 0)
848+
{
849+
if (cstNode)
850+
serializeToken(cstNode->openGenericsPosition, "<");
851+
else
852+
lua_pushnil(L);
853+
lua_setfield(L, -2, "openGenerics");
854+
855+
auto commas = cstNode ? cstNode->genericsCommaPositions : Luau::AstArray<Luau::Position>{};
856+
serializePunctuated(node->generics, commas, ",");
857+
lua_setfield(L, -2, "generics");
858+
859+
serializePunctuated(node->genericPacks, splitArray(commas, node->generics.size), ",");
860+
lua_setfield(L, -2, "genericPacks");
861+
862+
if (cstNode)
863+
serializeToken(cstNode->closeGenericsPosition, ">");
864+
else
865+
lua_pushnil(L);
866+
lua_setfield(L, -2, "closeGenerics");
867+
}
846868

847869
if (node->self)
848870
serialize(node->self, /* createToken= */ false);
@@ -859,7 +881,7 @@ struct AstSerialize : public Luau::AstVisitor
859881
serializePunctuated(node->args, cstNode ? cstNode->argsCommaPositions : Luau::AstArray<Luau::Position>{}, ",");
860882
lua_setfield(L, -2, "parameters");
861883

862-
// TODO: generics, return types, etc.
884+
// TODO: return types
863885

864886
if (node->vararg)
865887
serialize(node->varargLocation);
@@ -1479,6 +1501,13 @@ struct AstSerialize : public Luau::AstVisitor
14791501
lua_setfield(L, -2, "body");
14801502
}
14811503

1504+
static Luau::AstArray<Luau::Position> splitArray(Luau::AstArray<Luau::Position> arr, size_t index)
1505+
{
1506+
if (arr.size < index)
1507+
return arr;
1508+
return {arr.data + index, arr.size - index};
1509+
}
1510+
14821511
void serializeStat(Luau::AstStatTypeAlias* node)
14831512
{
14841513
lua_rawcheckstack(L, 2);
@@ -1500,7 +1529,27 @@ struct AstSerialize : public Luau::AstVisitor
15001529
serializeToken(node->nameLocation.begin, node->name.value);
15011530
lua_setfield(L, -2, "name");
15021531

1503-
// TODO: generics
1532+
if (node->generics.size > 0 || node->genericPacks.size > 0)
1533+
{
1534+
if (cstNode)
1535+
serializeToken(cstNode->genericsOpenPosition, "<");
1536+
else
1537+
lua_pushnil(L);
1538+
lua_setfield(L, -2, "openGenerics");
1539+
1540+
auto commas = cstNode ? cstNode->genericsCommaPositions : Luau::AstArray<Luau::Position>{};
1541+
serializePunctuated(node->generics, commas, ",");
1542+
lua_setfield(L, -2, "generics");
1543+
1544+
serializePunctuated(node->genericPacks, splitArray(commas, node->generics.size), ",");
1545+
lua_setfield(L, -2, "genericPacks");
1546+
1547+
if (cstNode)
1548+
serializeToken(cstNode->genericsClosePosition, ">");
1549+
else
1550+
lua_pushnil(L);
1551+
lua_setfield(L, -2, "closeGenerics");
1552+
}
15041553

15051554
if (cstNode)
15061555
{
@@ -1950,6 +1999,62 @@ struct AstSerialize : public Luau::AstVisitor
19501999
lua_setfield(L, -2, "closeParens");
19512000
}
19522001

2002+
void serializeType(Luau::AstGenericType* node)
2003+
{
2004+
lua_rawcheckstack(L, 2);
2005+
lua_createtable(L, 0, preambleSize + 3);
2006+
2007+
serializeNodePreamble(node, "generic");
2008+
2009+
const auto cstNode = lookupCstNode<Luau::CstGenericType>(node);
2010+
2011+
serializeToken(node->location.begin, node->name.value);
2012+
lua_setfield(L, -2, "name");
2013+
2014+
if (node->defaultValue && cstNode)
2015+
serializeToken(*cstNode->defaultEqualsPosition, "=");
2016+
else
2017+
lua_pushnil(L);
2018+
lua_setfield(L, -2, "equals");
2019+
2020+
if (node->defaultValue)
2021+
node->defaultValue->visit(this);
2022+
else
2023+
lua_pushnil(L);
2024+
lua_setfield(L, -2, "default");
2025+
}
2026+
2027+
void serializeType(Luau::AstGenericTypePack* node)
2028+
{
2029+
lua_rawcheckstack(L, 2);
2030+
lua_createtable(L, 0, preambleSize + 3);
2031+
2032+
serializeNodePreamble(node, "generic");
2033+
2034+
const auto cstNode = lookupCstNode<Luau::CstGenericTypePack>(node);
2035+
2036+
serializeToken(node->location.begin, node->name.value);
2037+
lua_setfield(L, -2, "name");
2038+
2039+
if (cstNode)
2040+
serializeToken(cstNode->ellipsisPosition, "...");
2041+
else
2042+
lua_pushnil(L);
2043+
lua_setfield(L, -2, "ellipsis");
2044+
2045+
if (node->defaultValue && cstNode)
2046+
serializeToken(*cstNode->defaultEqualsPosition, "=");
2047+
else
2048+
lua_pushnil(L);
2049+
lua_setfield(L, -2, "equals");
2050+
2051+
if (node->defaultValue)
2052+
node->defaultValue->visit(this);
2053+
else
2054+
lua_pushnil(L);
2055+
lua_setfield(L, -2, "default");
2056+
}
2057+
19532058
void serializeType(Luau::AstTypeError* node)
19542059
{
19552060
// TODO: types
@@ -2348,6 +2453,18 @@ struct AstSerialize : public Luau::AstVisitor
23482453
serializeType(node);
23492454
return false;
23502455
}
2456+
2457+
bool visit(Luau::AstGenericType* node) override
2458+
{
2459+
serializeType(node);
2460+
return false;
2461+
}
2462+
2463+
bool visit(Luau::AstGenericTypePack* node) override
2464+
{
2465+
serializeType(node);
2466+
return false;
2467+
}
23512468
};
23522469

23532470
int luau_parse(lua_State* L)
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
function a<A>() end
2+
3+
function b<A...>() end
4+
5+
function c<A, B>() end
6+
7+
function d<A..., B...>() end
8+
9+
function e<A, B, C..., D...>() end
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
type A<X> = string
2+
type B<X...> = string
3+
type C<X, Y...> = string
4+
type D<X = string> = string
5+
type E<X = string, Y... = ...string> = string

tests/testAstSerializer.spec.luau

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ local function test_roundtrippableAst()
151151
"tests/astSerializerTests/compound-types-3.luau",
152152
"tests/astSerializerTests/function-declaration-1.luau",
153153
"tests/astSerializerTests/function-declaration-2.luau",
154+
"tests/astSerializerTests/function-declaration-3.luau",
154155
"tests/astSerializerTests/generic-for-loop-1.luau",
155156
"tests/astSerializerTests/if-expression-1.luau",
156157
"tests/astSerializerTests/if-expression-2.luau",
@@ -161,6 +162,7 @@ local function test_roundtrippableAst()
161162
"tests/astSerializerTests/numeric-for-loop-1.luau",
162163
"tests/astSerializerTests/table-1.luau",
163164
"tests/astSerializerTests/table-2.luau",
165+
"tests/astSerializerTests/type-alias-1.luau",
164166
"tests/astSerializerTests/type-assertion-1.luau",
165167
"tests/astSerializerTests/type-function-1.luau",
166168
"tests/astSerializerTests/type-singletons-1.luau",

0 commit comments

Comments
 (0)