Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 45 additions & 3 deletions luau/src/luau.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ struct AstSerialize : public Luau::AstVisitor

// absolute index for the table where we're storing locals
int localTableIndex;
// reference to previously serialized token
int lastTokenRef = LUA_NOREF;

AstSerialize(lua_State* L, std::string_view source, Luau::CstNodeMap cstNodeMap, std::vector<Luau::Comment> commentLocations)
: L(L)
Expand Down Expand Up @@ -253,6 +255,26 @@ struct AstSerialize : public Luau::AstVisitor
return result;
}

// Splits a list of trivia into trailing trivia for the previos token, and leading trivia for the next token
// The trailing trivia consists of all trivia up to and including the first '\n' character seen
static std::pair<std::vector<Trivia>, std::vector<Trivia>> splitTrivia(std::vector<Trivia> trivia)
{
size_t i = 0;
for (i = 0; i < trivia.size(); i++)
{
if (trivia[i].kind == Trivia::Whitespace && trivia[i].text.find('\n') != std::string::npos)
break;
}

if (i == trivia.size())
return {trivia, {}};

auto middleIter(trivia.begin());
std::advance(middleIter, i + 1);

return {std::vector<Trivia>(trivia.begin(), middleIter), std::vector<Trivia>(middleIter, trivia.end())};
}

void serialize(Luau::Position position)
{
lua_rawcheckstack(L, 2);
Expand Down Expand Up @@ -407,9 +429,27 @@ struct AstSerialize : public Luau::AstVisitor
lua_rawcheckstack(L, 2);
lua_createtable(L, 0, nrec + 3);

// TODO: split up into leading / trailing trivia
const auto leadingTrivia = extractTrivia(position);
serializeTrivia(leadingTrivia);
const auto trivia = extractTrivia(position);
if (lastTokenRef != LUA_NOREF)
{
const auto [trailingTrivia, leadingTrivia] = splitTrivia(trivia);

lua_getref(L, lastTokenRef);
LUAU_ASSERT(lua_istable(L, -1));

serializeTrivia(trailingTrivia);
lua_setfield(L, -2, "trailingTrivia");
lua_pop(L, 1);
lua_unref(L, lastTokenRef);
lastTokenRef = LUA_NOREF;

serializeTrivia(leadingTrivia);
}
else
{
serializeTrivia(trivia);
}
LUAU_ASSERT(lua_istable(L, -2));
lua_setfield(L, -2, "leadingTrivia");

serialize(position);
Expand All @@ -422,6 +462,8 @@ struct AstSerialize : public Luau::AstVisitor
lua_rawcheckstack(L, 2);
lua_createtable(L, 0, 0);
lua_setfield(L, -2, "trailingTrivia");

lastTokenRef = lua_ref(L, -1);
}

void serializeLocals(Luau::AstArray<Luau::AstLocal*>& locals, size_t nrec = 0)
Expand Down
24 changes: 24 additions & 0 deletions tests/testAstSerializer.spec.luau
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,29 @@ local function test_tokenizeWhitespace()
assert(token.leadingTrivia[3].text == "\n")
end

local function test_triviaSplitBetweenLeadingAndTrailing()
local block = luau.parse("local x = 'test' -- comment\n" .. "-- comment 2\nlocal y = 'value'").root
assert(#block.statements == 2)

local firstStmt = block.statements[1]
assert(firstStmt.tag == "local")

local trailingToken = firstStmt.values[1].node
assert(trailingToken.tag == "string")
assert(#trailingToken.trailingTrivia == 3)
assert(trailingToken.trailingTrivia[1].text == " ")
assert(trailingToken.trailingTrivia[2].text == "-- comment")
assert(trailingToken.trailingTrivia[3].text == "\n")

local secondStmt = block.statements[2]
assert(secondStmt.tag == "local")

local leadingToken = secondStmt["local"]
assert(#leadingToken.leadingTrivia == 2)
assert(leadingToken.leadingTrivia[1].text == "-- comment 2")
assert(leadingToken.leadingTrivia[2].text == "\n")
end

local function test_roundtrippableAst()
local files = {
"examples/a.luau",
Expand All @@ -116,4 +139,5 @@ test_tokenContainsLeadingNewline()
test_tokenContainsLeadingSingleLineComment()
test_tokenContainsLeadingBlockComment()
test_tokenizeWhitespace()
test_triviaSplitBetweenLeadingAndTrailing()
test_roundtrippableAst()