Skip to content

Commit b997df0

Browse files
CST: Improve serialization of functions, handle AstTypeFunction (#275)
This PR introduces support for serializing an AstTypeFunction. The most interesting part of this is the parameters, where the name (and hence colon) is optional. For this, we introduce a type `AstTypeFunctionParameter` to contain this information. We also serialize the vararg annotation independently. As part of this, we also extend the serialization on normal functions to include attributes, varargs, and the return type. For now we serialize a placeholder for the `:` token (in functions and for AstLocal too), which will get fixed after we sync to Luau 0.673.
1 parent 445ea7b commit b997df0

File tree

8 files changed

+362
-32
lines changed

8 files changed

+362
-32
lines changed

batteries/syntax/ast_types.luau

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ export type Punctuated<T, Separator = ","> = { Pair<T, Separator> }
4343

4444
export type AstLocal = {
4545
name: Token<string>,
46+
colon: Token<":">?,
47+
annotation: AstType?,
4648
}
4749

4850
export type AstExprGroup = {
@@ -108,16 +110,20 @@ export type AstFunctionBody = {
108110
genericPacks: Punctuated<AstGenericTypePack>?,
109111
closeGenerics: Token<">">?,
110112
openParens: Token<"(">,
111-
-- TODO: parameters
112-
parameters: Punctuated<Token<string>>,
113+
parameters: Punctuated<AstLocal>,
114+
vararg: Token<"...">?,
115+
varargColon: Token<":">?,
116+
varargAnnotation: AstTypePack?,
113117
closeParens: Token<")">,
114-
-- TODO: return type
118+
returnSpecifier: Token<":">?,
119+
returnAnnotation: AstTypePack?,
115120
body: AstStatBlock,
116121
["end"]: Token<"end">,
117122
}
118123

119124
export type AstExprAnonymousFunction = {
120125
tag: "function",
126+
attributes: { AstAttribute },
121127
["function"]: Token<"function">,
122128
body: AstFunctionBody,
123129
}
@@ -313,15 +319,19 @@ export type AstStatCompoundAssign = {
313319
value: AstExpr,
314320
}
315321

322+
export type AstAttribute = Token<"@checked" | "@native" | "@deprecated"> & { tag: "attribute" }
323+
316324
export type AstStatFunction = {
317325
tag: "function",
326+
attributes: { AstAttribute },
318327
["function"]: Token<"function">,
319328
name: AstExpr,
320329
body: AstFunctionBody,
321330
}
322331

323332
export type AstStatLocalFunction = {
324333
tag: "localfunction",
334+
attributes: { AstAttribute },
325335
["local"]: Token<"local">,
326336
["function"]: Token<"function">,
327337
name: AstLocal,
@@ -477,6 +487,26 @@ export type AstTypeTable = {
477487
closeBrace: Token<"}">,
478488
}
479489

490+
export type AstTypeFunctionParameter = {
491+
name: Token?,
492+
colon: Token<":">?,
493+
type: AstType,
494+
}
495+
496+
export type AstTypeFunction = {
497+
tag: "function",
498+
openGenerics: Token<"<">?,
499+
generics: Punctuated<AstGenericType>?,
500+
genericPacks: Punctuated<AstGenericTypePack>?,
501+
closeGenerics: Token<">">?,
502+
openParens: Token<"(">,
503+
parameters: Punctuated<AstTypeFunctionParameter>,
504+
vararg: AstTypePack?,
505+
closeParens: Token<")">,
506+
returnArrow: Token<"->">,
507+
returnTypes: AstTypePack,
508+
}
509+
480510
export type AstType =
481511
| AstTypeReference
482512
| AstTypeSingletonBool
@@ -488,13 +518,14 @@ export type AstType =
488518
| AstTypeOptional
489519
| AstTypeArray
490520
| AstTypeTable
521+
| AstTypeFunction
491522

492523
export type AstTypePackExplicit = {
493524
tag: "explicit",
494-
openParens: Token<"(">,
525+
openParens: Token<"(">?,
495526
types: Punctuated<AstType>,
496527
tailType: AstTypePack?,
497-
closeParens: Token<")">,
528+
closeParens: Token<")">?,
498529
}
499530

500531
export type AstTypePackGeneric = {
@@ -505,7 +536,8 @@ export type AstTypePackGeneric = {
505536

506537
export type AstTypePackVariadic = {
507538
tag: "variadic",
508-
ellipsis: Token<"...">,
539+
--- May be nil when present as the vararg annotation in a function body
540+
ellipsis: Token<"...">?,
509541
type: AstType,
510542
}
511543

batteries/syntax/visitor.luau

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ export type Visitor = {
4545
visitTypeIntersection: (T.AstTypeIntersection) -> boolean,
4646
visitTypeArray: (T.AstTypeArray) -> boolean,
4747
visitTypeTable: (T.AstTypeTable) -> boolean,
48+
visitTypeFunction: (T.AstTypeFunction) -> boolean,
4849

4950
visitTypePackExplicit: (T.AstTypePackExplicit) -> boolean,
5051
visitTypePackGeneric: (T.AstTypePackGeneric) -> boolean,
@@ -106,6 +107,7 @@ local defaultVisitor: Visitor = {
106107
visitTypeIntersection = alwaysVisit :: any,
107108
visitTypeArray = alwaysVisit :: any,
108109
visitTypeTable = alwaysVisit :: any,
110+
visitTypeFunction = alwaysVisit,
109111

110112
visitTypePackExplicit = alwaysVisit,
111113
visitTypePackGeneric = alwaysVisit,
@@ -140,6 +142,12 @@ end
140142
local function visitLocal(node: T.AstLocal, visitor: Visitor)
141143
if visitor.visitLocal(node) then
142144
visitToken(node.name, visitor)
145+
if node.colon then
146+
visitToken(node.colon, visitor)
147+
end
148+
if node.annotation then
149+
visitType(node.annotation, visitor)
150+
end
143151
end
144152
end
145153

@@ -390,20 +398,45 @@ local function visitFunctionBody(node: T.AstFunctionBody, visitor: Visitor)
390398
end
391399
visitToken(node.openParens, visitor)
392400
visitPunctuated(node.parameters, visitor, visitLocal)
401+
if node.vararg then
402+
visitToken(node.vararg, visitor)
403+
end
404+
if node.varargColon then
405+
visitToken(node.varargColon, visitor)
406+
end
407+
if node.varargAnnotation then
408+
visitTypePack(node.varargAnnotation, visitor)
409+
end
393410
visitToken(node.closeParens, visitor)
411+
if node.returnSpecifier then
412+
visitToken(node.returnSpecifier, visitor)
413+
end
414+
if node.returnAnnotation then
415+
visitTypePack(node.returnAnnotation, visitor)
416+
end
394417
visitBlock(node.body, visitor)
395418
visitToken(node["end"], visitor)
396419
end
397420

421+
local function visitAttribute(node: T.AstAttribute, visitor)
422+
visitToken(node, visitor)
423+
end
424+
398425
local function visitAnonymousFunction(node: T.AstExprAnonymousFunction, visitor: Visitor)
399426
if visitor.visitAnonymousFunction(node) then
427+
for _, attribute in node.attributes do
428+
visitAttribute(attribute, visitor)
429+
end
400430
visitToken(node["function"], visitor)
401431
visitFunctionBody(node.body, visitor)
402432
end
403433
end
404434

405435
local function visitFunction(node: T.AstStatFunction, visitor: Visitor)
406436
if visitor.visitFunction(node) then
437+
for _, attribute in node.attributes do
438+
visitAttribute(attribute, visitor)
439+
end
407440
visitToken(node["function"], visitor)
408441
visitExpression(node.name, visitor)
409442
visitFunctionBody(node.body, visitor)
@@ -412,6 +445,9 @@ end
412445

413446
local function visitLocalFunction(node: T.AstStatLocalFunction, visitor: Visitor)
414447
if visitor.visitLocalFunction(node) then
448+
for _, attribute in node.attributes do
449+
visitAttribute(attribute, visitor)
450+
end
415451
visitToken(node["local"], visitor)
416452
visitToken(node["function"], visitor)
417453
visitLocal(node.name, visitor)
@@ -641,14 +677,53 @@ local function visitTypeTable(node: T.AstTypeTable, visitor: Visitor)
641677
end
642678
end
643679

680+
local function visitTypeFunctionParameter(node: T.AstTypeFunctionParameter, visitor)
681+
if node.name then
682+
visitToken(node.name, visitor)
683+
end
684+
if node.colon then
685+
visitToken(node.colon, visitor)
686+
end
687+
visitType(node.type, visitor)
688+
end
689+
690+
local function visitTypeFunction(node: T.AstTypeFunction, visitor: Visitor)
691+
if visitor.visitTypeFunction(node) then
692+
if node.openGenerics then
693+
visitToken(node.openGenerics, visitor)
694+
end
695+
if node.generics then
696+
visitPunctuated(node.generics, visitor, visitGeneric)
697+
end
698+
if node.genericPacks then
699+
visitPunctuated(node.genericPacks, visitor, visitGenericPack)
700+
end
701+
if node.closeGenerics then
702+
visitToken(node.closeGenerics, visitor)
703+
end
704+
visitToken(node.openParens, visitor)
705+
visitPunctuated(node.parameters, visitor, visitTypeFunctionParameter)
706+
if node.vararg then
707+
visitTypePack(node.vararg, visitor)
708+
end
709+
visitToken(node.closeParens, visitor)
710+
visitToken(node.returnArrow, visitor)
711+
visitTypePack(node.returnTypes, visitor)
712+
end
713+
end
714+
644715
local function visitTypePackExplicit(node: T.AstTypePackExplicit, visitor: Visitor)
645716
if visitor.visitTypePackExplicit(node) then
646-
visitToken(node.openParens, visitor)
717+
if node.openParens then
718+
visitToken(node.openParens, visitor)
719+
end
647720
visitPunctuated(node.types, visitor, visitType)
648721
if node.tailType then
649722
visitTypePack(node.tailType, visitor)
650723
end
651-
visitToken(node.closeParens, visitor)
724+
if node.closeParens then
725+
visitToken(node.closeParens, visitor)
726+
end
652727
end
653728
end
654729

@@ -661,7 +736,9 @@ end
661736

662737
local function visitTypePackVariadic(node: T.AstTypePackVariadic, visitor: Visitor)
663738
if visitor.visitTypePackVariadic(node) then
664-
visitToken(node.ellipsis, visitor)
739+
if node.ellipsis then
740+
visitToken(node.ellipsis, visitor)
741+
end
665742
visitType(node.type, visitor)
666743
end
667744
end
@@ -773,6 +850,8 @@ function visitType(type: T.AstType, visitor: Visitor)
773850
visitTypeArray(type, visitor)
774851
elseif type.tag == "table" then
775852
visitTypeTable(type, visitor)
853+
elseif type.tag == "function" then
854+
visitTypeFunction(type, visitor)
776855
else
777856
exhaustiveMatch(type.tag)
778857
end

0 commit comments

Comments
 (0)