Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 40 additions & 48 deletions parser/prism/Translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,22 @@ template <typename... Tail> bool hasExpr(const std::unique_ptr<parser::Node> &he
return hasExpr(head) && hasExpr(tail...);
}

// Helper to extract desugared expression or return EmptyTree if node is null.
ExpressionPtr takeDesugaredExprOrEmptyTree(const std::unique_ptr<parser::Node> &node) {
if (node == nullptr) {
return MK::EmptyTree();
}
ENFORCE(node->hasDesugaredExpr(), "Node has no desugared expression");
return node->takeDesugaredExpr();
}

// Helper template to convert nodes to any store type with takeDesugaredExpr or EmptyTree for nulls.
// This is used to convert a NodeVec to the store type argument for nodes including `Send`, `InsSeq`.
template <typename StoreType> StoreType nodeVecToStore(const sorbet::parser::NodeVec &nodes) {
StoreType store;
store.reserve(nodes.size());
for (const auto &node : nodes) {
store.emplace_back(node ? node->takeDesugaredExpr() : sorbet::ast::MK::EmptyTree());
store.emplace_back(takeDesugaredExprOrEmptyTree(node));
}
return store;
}
Expand Down Expand Up @@ -1110,7 +1119,7 @@ unique_ptr<parser::Node> Translator::translate(pm_node_t *node, bool preserveCon
ExpressionPtr breakArgs;
if (arguments.size() == 1) {
auto &first = arguments[0];
breakArgs = first == nullptr ? MK::EmptyTree() : first->takeDesugaredExpr();
breakArgs = takeDesugaredExprOrEmptyTree(first);
} else {
auto args = nodeVecToStore<ast::Array::ENTRY_store>(arguments);
auto arrayLocation = parser.translateLocation(breakNode->arguments->base.location);
Expand Down Expand Up @@ -1630,8 +1639,7 @@ unique_ptr<parser::Node> Translator::translate(pm_node_t *node, bool preserveCon
blockExpr = desugarSymbolProc(symbol);
} else {
auto blockLoc = translateLoc(prismBlock->location);
auto blockBodyExpr =
blockBody == nullptr ? MK::EmptyTree() : blockBody->takeDesugaredExpr();
auto blockBodyExpr = takeDesugaredExprOrEmptyTree(blockBody);
blockExpr = MK::Block(blockLoc, move(blockBodyExpr), move(blockParamsStore));
}

Expand Down Expand Up @@ -1748,7 +1756,7 @@ unique_ptr<parser::Node> Translator::translate(pm_node_t *node, bool preserveCon
blockExpr = desugarSymbolProc(symbol);
} else {
auto blockLoc = translateLoc(prismBlock->location);
auto blockBodyExpr = blockBody == nullptr ? MK::EmptyTree() : blockBody->takeDesugaredExpr();
auto blockBodyExpr = takeDesugaredExprOrEmptyTree(blockBody);
blockExpr = MK::Block(blockLoc, move(blockBodyExpr), move(blockParamsStore));
}

Expand Down Expand Up @@ -1829,7 +1837,7 @@ unique_ptr<parser::Node> Translator::translate(pm_node_t *node, bool preserveCon
}

// Start with the else clause as the final else
ExpressionPtr resultExpr = elseClause == nullptr ? MK::EmptyTree() : elseClause->takeDesugaredExpr();
ExpressionPtr resultExpr = takeDesugaredExprOrEmptyTree(elseClause);

// Build the if ladder backwards from the last "in" to the first
for (auto it = inNodes.rbegin(); it != inNodes.rend(); ++it) {
Expand All @@ -1841,7 +1849,7 @@ unique_ptr<parser::Node> Translator::translate(pm_node_t *node, bool preserveCon
auto matchExpr = MK::RaiseUnimplemented(patternLoc);

// The body is the statements from the "in" clause
auto thenExpr = inPattern->body != nullptr ? inPattern->body->takeDesugaredExpr() : MK::EmptyTree();
auto thenExpr = takeDesugaredExprOrEmptyTree(inPattern->body);

// Collect pattern variable assignments from the pattern
ast::InsSeq::STATS_store vars;
Expand Down Expand Up @@ -1910,7 +1918,7 @@ unique_ptr<parser::Node> Translator::translate(pm_node_t *node, bool preserveCon

ast::Send::ARGS_store args;
args.reserve(2 + whenNodes.size() + totalPatterns); // +2 is for the predicate and the patterns count
args.emplace_back(predicate == nullptr ? MK::EmptyTree() : predicate->takeDesugaredExpr());
args.emplace_back(takeDesugaredExprOrEmptyTree(predicate));
args.emplace_back(MK::Int(locZeroLen, totalPatterns));

for (auto &whenNodePtr : whenNodes) {
Expand All @@ -1919,20 +1927,18 @@ unique_ptr<parser::Node> Translator::translate(pm_node_t *node, bool preserveCon
// Each pattern node already has a desugared expression (populated by translateMulti +
// NodeWithExpr). Consume them now; the wrapper's placeholder expression is intentionally ignored.
for (auto &patternNode : whenNodeWrapped->patterns) {
args.emplace_back(patternNode == nullptr ? MK::EmptyTree() : patternNode->takeDesugaredExpr());
args.emplace_back(takeDesugaredExprOrEmptyTree(patternNode));
}
}

for (auto &whenNodePtr : whenNodes) {
auto whenNodeWrapped = parser::NodeWithExpr::cast_node<parser::When>(whenNodePtr.get());
ENFORCE(whenNodeWrapped != nullptr, "case without a when?");
// The body node also carries a real expression once translateStatements has run.
auto bodyExpr =
whenNodeWrapped->body == nullptr ? MK::EmptyTree() : whenNodeWrapped->body->takeDesugaredExpr();
args.emplace_back(move(bodyExpr));
args.emplace_back(takeDesugaredExprOrEmptyTree(whenNodeWrapped->body));
}

args.emplace_back(elseClause == nullptr ? MK::EmptyTree() : elseClause->takeDesugaredExpr());
args.emplace_back(takeDesugaredExprOrEmptyTree(elseClause));

// Desugar to `::Magic.caseWhen(predicate, num_patterns, patterns..., bodies..., else)`
auto expr = MK::Send(location, MK::Magic(locZeroLen), core::Names::caseWhen(), locZeroLen, args.size(),
Expand All @@ -1955,15 +1961,15 @@ unique_ptr<parser::Node> Translator::translate(pm_node_t *node, bool preserveCon

// The if/else ladder for the entire case statement, starting with the else clause as the final `else` when
// building backwards
ExpressionPtr resultExpr = elseClause == nullptr ? MK::EmptyTree() : elseClause->takeDesugaredExpr();
ExpressionPtr resultExpr = takeDesugaredExprOrEmptyTree(elseClause);

for (auto it = whenNodes.rbegin(); it != whenNodes.rend(); ++it) {
auto whenNodeWrapped = parser::NodeWithExpr::cast_node<parser::When>(it->get());
ENFORCE(whenNodeWrapped != nullptr, "case without a when?");

ExpressionPtr patternsResult; // the if/else ladder for this when clause's patterns
for (auto &patternNode : whenNodeWrapped->patterns) {
auto patternExpr = patternNode == nullptr ? MK::EmptyTree() : patternNode->takeDesugaredExpr();
auto patternExpr = takeDesugaredExprOrEmptyTree(patternNode);
auto patternLoc = patternExpr.loc();

ExpressionPtr testExpr;
Expand Down Expand Up @@ -1998,8 +2004,7 @@ unique_ptr<parser::Node> Translator::translate(pm_node_t *node, bool preserveCon
}
}

auto thenExpr =
whenNodeWrapped->body != nullptr ? whenNodeWrapped->body->takeDesugaredExpr() : MK::EmptyTree();
auto thenExpr = takeDesugaredExprOrEmptyTree(whenNodeWrapped->body);
resultExpr = MK::If(whenNodeWrapped->loc, move(patternsResult), move(thenExpr), move(resultExpr));
}

Expand Down Expand Up @@ -2232,7 +2237,7 @@ unique_ptr<parser::Node> Translator::translate(pm_node_t *node, bool preserveCon
}
}

auto methodBody = body == nullptr ? MK::EmptyTree() : body->takeDesugaredExpr();
auto methodBody = takeDesugaredExprOrEmptyTree(body);

auto methodExpr = MK::Method(location, declLoc, name, move(paramsStore), move(methodBody));

Expand Down Expand Up @@ -2413,7 +2418,7 @@ unique_ptr<parser::Node> Translator::translate(pm_node_t *node, bool preserveCon
canProvideNiceDesugar = parser::NodeWithExpr::isa_node<parser::LVarLhs>(variable.get());
}

auto bodyExpr = body ? body->takeDesugaredExpr() : MK::EmptyTree();
auto bodyExpr = takeDesugaredExprOrEmptyTree(body);
auto collectionExpr = collection->takeDesugaredExpr();
auto locZeroLen = location.copyWithZeroLength();
ast::MethodDef::PARAMS_store params;
Expand Down Expand Up @@ -2988,7 +2993,7 @@ unique_ptr<parser::Node> Translator::translate(pm_node_t *node, bool preserveCon
ExpressionPtr nextArgs;
if (arguments.size() == 1) {
auto &first = arguments[0];
nextArgs = first == nullptr ? MK::EmptyTree() : first->takeDesugaredExpr();
nextArgs = takeDesugaredExprOrEmptyTree(first);
} else {
auto args = nodeVecToStore<ast::Array::ENTRY_store>(arguments);
auto arrayLocation = parser.translateLocation(nextNode->arguments->base.location);
Expand Down Expand Up @@ -3148,8 +3153,8 @@ unique_ptr<parser::Node> Translator::translate(pm_node_t *node, bool preserveCon
auto recv = MK::Magic(location);
auto locZeroLen = core::LocOffsets{location.beginPos(), location.beginPos()};

auto fromExpr = left ? left->takeDesugaredExpr() : MK::EmptyTree();
auto toExpr = right ? right->takeDesugaredExpr() : MK::EmptyTree();
auto fromExpr = takeDesugaredExprOrEmptyTree(left);
auto toExpr = takeDesugaredExprOrEmptyTree(right);

auto excludeEndExpr = isExclusive ? MK::True(location) : MK::False(location);

Expand Down Expand Up @@ -3280,7 +3285,7 @@ unique_ptr<parser::Node> Translator::translate(pm_node_t *node, bool preserveCon
ExpressionPtr returnArgs;
if (returnValues.size() == 1) {
auto &first = returnValues[0];
returnArgs = first == nullptr ? MK::EmptyTree() : first->takeDesugaredExpr();
returnArgs = takeDesugaredExprOrEmptyTree(first);
} else {
auto args = nodeVecToStore<ast::Array::ENTRY_store>(std::move(returnValues));
auto arrayLocation = parser.translateLocation(returnNode->arguments->base.location);
Expand Down Expand Up @@ -3487,7 +3492,7 @@ unique_ptr<parser::Node> Translator::translate(pm_node_t *node, bool preserveCon
}
} else {
auto cond = predicate->takeDesugaredExpr();
auto body = statements ? statements->takeDesugaredExpr() : MK::EmptyTree();
auto body = takeDesugaredExprOrEmptyTree(statements);
if (beginModifier) {
auto breaker =
MK::If(location, std::move(cond), MK::Break(location, MK::EmptyTree()), MK::EmptyTree());
Expand Down Expand Up @@ -3529,7 +3534,7 @@ unique_ptr<parser::Node> Translator::translate(pm_node_t *node, bool preserveCon
}
} else {
auto cond = predicate->takeDesugaredExpr();
auto body = statements ? statements->takeDesugaredExpr() : MK::EmptyTree();
auto body = takeDesugaredExprOrEmptyTree(statements);
if (beginModifier) {
// TODO using bang (aka !) is not semantically correct because it can be overridden by the user.
auto negatedCond =
Expand Down Expand Up @@ -4926,11 +4931,7 @@ unique_ptr<parser::Node> Translator::translateRescue(pm_begin_node *parentBeginN
// Regular local variable
varExpr = var->takeDesugaredExpr();

if (rescueBody != nullptr) {
rescueBodyExpr = rescueBody->takeDesugaredExpr();
} else {
rescueBodyExpr = ast::MK::EmptyTree();
}
rescueBodyExpr = takeDesugaredExprOrEmptyTree(rescueBody);
} else if (isReference) {
// Non-local reference (lvalue exception variables like @ex, @@ex, $ex)
// Create a temp variable and wrap the body
Expand All @@ -4945,7 +4946,7 @@ unique_ptr<parser::Node> Translator::translateRescue(pm_begin_node *parentBeginN
ast::InsSeq::STATS_store stats;
stats.emplace_back(move(assignExpr));

auto bodyExpr = rescueBody != nullptr ? rescueBody->takeDesugaredExpr() : ast::MK::EmptyTree();
auto bodyExpr = takeDesugaredExprOrEmptyTree(rescueBody);
rescueBodyExpr = ast::MK::InsSeq(varLoc, move(stats), move(bodyExpr));
} else {
// For bare rescue clauses with no variable, create a <rescueTemp> variable
Expand All @@ -4957,7 +4958,7 @@ unique_ptr<parser::Node> Translator::translateRescue(pm_begin_node *parentBeginN
: rescueKeywordLoc;
varExpr = ast::MK::Local(syntheticVarLoc, rescueTemp);

rescueBodyExpr = rescueBody != nullptr ? rescueBody->takeDesugaredExpr() : ast::MK::EmptyTree();
rescueBodyExpr = takeDesugaredExprOrEmptyTree(rescueBody);
}

auto rescueCaseExpr = ast::make_expression<ast::RescueCase>(resbodyLoc, move(astExceptions), move(varExpr),
Expand Down Expand Up @@ -5028,12 +5029,7 @@ unique_ptr<parser::Node> Translator::translateRescue(pm_begin_node *parentBeginN
}

// Build the ast::Rescue expression
ast::ExpressionPtr bodyExpr;
if (bodyNode != nullptr) {
bodyExpr = bodyNode->takeDesugaredExpr();
} else {
bodyExpr = ast::MK::EmptyTree();
}
auto bodyExpr = takeDesugaredExprOrEmptyTree(bodyNode);

// Extract RescueCase expressions from each Resbody node
ast::Rescue::RESCUE_CASE_store rescueCases;
Expand All @@ -5047,8 +5043,7 @@ unique_ptr<parser::Node> Translator::translateRescue(pm_begin_node *parentBeginN
}

// Extract the else expression
ast::ExpressionPtr elseExpr;
elseExpr = (elseNode != nullptr) ? elseNode->takeDesugaredExpr() : ast::MK::EmptyTree();
auto elseExpr = takeDesugaredExprOrEmptyTree(elseNode);

// Build the ast::Rescue expression (ensure is EmptyTree since this is translateRescue, not translateEnsure)
auto rescueExpr = ast::make_expression<ast::Rescue>(rescueLoc, move(bodyExpr), move(rescueCases), move(elseExpr),
Expand Down Expand Up @@ -5127,7 +5122,7 @@ NodeVec Translator::translateEnsure(pm_begin_node *beginNode) {
auto rescue = ast::cast_tree<ast::Rescue>(bodyExpr);
ENFORCE(rescue != nullptr, "translatedRescue should be a Rescue node");

rescue->ensure = ensureBody != nullptr ? ensureBody->takeDesugaredExpr() : ast::MK::EmptyTree();
rescue->ensure = takeDesugaredExprOrEmptyTree(ensureBody);

translatedEnsure =
make_node_with_expr<parser::Ensure>(move(bodyExpr), loc, move(translatedRescue), move(ensureBody));
Expand Down Expand Up @@ -5158,11 +5153,8 @@ NodeVec Translator::translateEnsure(pm_begin_node *beginNode) {
} else {
// Build ast::Rescue expression with ensure field set
// When there's no rescue clause, create a new Rescue with empty rescue cases
ast::ExpressionPtr bodyExpr;
bodyExpr = (bodyNode != nullptr) ? bodyNode->takeDesugaredExpr() : ast::MK::EmptyTree();

ast::ExpressionPtr ensureExpr =
(ensureBody != nullptr) ? ensureBody->takeDesugaredExpr() : ast::MK::EmptyTree();
auto bodyExpr = takeDesugaredExprOrEmptyTree(bodyNode);
auto ensureExpr = takeDesugaredExprOrEmptyTree(ensureBody);

// Create ast::Rescue with empty rescue cases
ast::Rescue::RESCUE_CASE_store emptyCases;
Expand Down Expand Up @@ -5250,8 +5242,8 @@ unique_ptr<parser::Node> Translator::translateIfNode(core::LocOffsets location,
}

auto condExpr = predicate->takeDesugaredExpr();
auto thenExpr = ifTrue ? ifTrue->takeDesugaredExpr() : MK::EmptyTree();
auto elseExpr = ifFalse ? ifFalse->takeDesugaredExpr() : MK::EmptyTree();
auto thenExpr = takeDesugaredExprOrEmptyTree(ifTrue);
auto elseExpr = takeDesugaredExprOrEmptyTree(ifFalse);
auto ifNode = MK::If(location, move(condExpr), move(thenExpr), move(elseExpr));
return make_node_with_expr<parser::If>(move(ifNode), location, move(predicate), move(ifTrue), move(ifFalse));
}
Expand Down
Loading