Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 58 additions & 19 deletions crates/cairo-lang-sierra-generator/src/local_variables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,13 @@ pub fn analyze_ap_changes<'db>(

let mut ctx = analysis.analyzer;
let peeled_used_after_revoke: OrderedHashSet<_> =
ctx.used_after_revoke.iter().map(|var| ctx.peel_aliases(var)).copied().collect();
ctx.used_after_revoke.iter().map(|var| ctx.peel_aliases(var).variable_id).collect();
// Any used after revoke variable that might be revoked should be a local.
let locals: OrderedHashSet<VariableId> = ctx
.used_after_revoke
.iter()
.filter(|var| ctx.might_be_revoked(&peeled_used_after_revoke, var))
.map(|var| ctx.peel_aliases(var))
.cloned()
.map(|var| ctx.peel_aliases(var).variable_id)
.collect();
let mut need_ap_alignment = OrderedHashSet::default();
if !root_info.known_ap_change {
Expand Down Expand Up @@ -115,6 +114,11 @@ struct CalledBlockInfo {
introduced_vars: Vec<VariableId>,
}

struct VarSource {
variable_id: VariableId,
allow_const: bool,
}

/// Context for the find_local_variables logic.
struct FindLocalsContext<'db, 'a> {
db: &'db dyn Database,
Expand All @@ -127,9 +131,9 @@ struct FindLocalsContext<'db, 'a> {
constants: UnorderedHashSet<VariableId>,
/// A mapping of variables which are the same in the context of finding locals.
/// I.e. if `aliases[var_id]` is local than var_id is also local.
aliases: UnorderedHashMap<VariableId, VariableId>,
aliases: UnorderedHashMap<VariableId, VarSource>,
/// A mapping from partial param variables to the containing variable.
partial_param_parents: UnorderedHashMap<VariableId, VariableId>,
partial_param_parents: UnorderedHashMap<VariableId, VarSource>,
}

pub type LoweredDemand = Demand<VariableId, ()>;
Expand Down Expand Up @@ -241,11 +245,13 @@ struct BranchInfo {

impl<'db, 'a> FindLocalsContext<'db, 'a> {
/// Given a variable that might be an alias follow aliases until we get the original variable.
pub fn peel_aliases(&'a self, mut var: &'a VariableId) -> &'a VariableId {
pub fn peel_aliases(&'a self, mut var: &'a VariableId) -> VarSource {
let mut allow_const = true;
while let Some(alias) = self.aliases.get(var) {
var = alias;
var = &alias.variable_id;
allow_const &= alias.allow_const;
}
var
VarSource { variable_id: *var, allow_const }
}

/// Return true if the variable might be revoked by ap changes.
Expand All @@ -258,19 +264,31 @@ impl<'db, 'a> FindLocalsContext<'db, 'a> {
peeled_used_after_revoke: &OrderedHashSet<VariableId>,
var: &VariableId,
) -> bool {
if self.constants.contains(var) {
let mut peeled = self.peel_aliases(var);
if self.constants.contains(&peeled.variable_id) {
return false;
}
let mut peeled = self.peel_aliases(var);
if self.non_ap_based.contains(peeled) {
if self.non_ap_based.contains(&peeled.variable_id) {
return false;
}
// In the case of partial params, we check if one of its ancestors is a local variable, or
// will be used after the revoke, and thus will be used as a local variable. If that
// is the case, then 'var' can not be revoked.
while let Some(parent) = self.partial_param_parents.get(peeled) {
peeled = self.peel_aliases(parent);
if self.non_ap_based.contains(peeled) || peeled_used_after_revoke.contains(peeled) {

let mut allow_const = peeled.allow_const;
while let Some(parent) = self.partial_param_parents.get(&peeled.variable_id) {
allow_const &= parent.allow_const;
peeled = self.peel_aliases(&parent.variable_id);
allow_const &= peeled.allow_const;
if self.non_ap_based.contains(&peeled.variable_id) {
return false;
}
// If the variable parent-peel chain ends in a const, but the allow_const flag is false
// (i.e. the chain went through a libfunc that doesn't allow consts),
// then the variable can be revoked as the saved const will be ap based.
if peeled_used_after_revoke.contains(&peeled.variable_id)
&& (!self.constants.contains(&peeled.variable_id) || allow_const)
{
return false;
}
}
Expand Down Expand Up @@ -309,10 +327,22 @@ impl<'db, 'a> FindLocalsContext<'db, 'a> {
for (var, output_info) in zip_eq(output_vars.iter(), var_output_infos.iter()) {
match output_info.ref_info {
OutputVarReferenceInfo::SameAsParam { param_idx } => {
self.aliases.insert(*var, input_vars[param_idx].var_id);
self.aliases.insert(
*var,
VarSource {
variable_id: input_vars[param_idx].var_id,
allow_const: _params_signatures[param_idx].allow_const,
},
);
}
OutputVarReferenceInfo::PartialParam { param_idx } => {
self.partial_param_parents.insert(*var, input_vars[param_idx].var_id);
self.partial_param_parents.insert(
*var,
VarSource {
variable_id: input_vars[param_idx].var_id,
allow_const: _params_signatures[param_idx].allow_const,
},
);
}
OutputVarReferenceInfo::Deferred(DeferredOutputKind::Const)
| OutputVarReferenceInfo::NewLocalVar
Expand Down Expand Up @@ -387,12 +417,21 @@ impl<'db, 'a> FindLocalsContext<'db, 'a> {
)
}
lowering::Statement::Snapshot(statement_snapshot) => {
self.aliases.insert(statement_snapshot.original(), statement_snapshot.input.var_id);
self.aliases.insert(statement_snapshot.snapshot(), statement_snapshot.input.var_id);
self.aliases.insert(
statement_snapshot.original(),
VarSource { variable_id: statement_snapshot.input.var_id, allow_const: true },
);
self.aliases.insert(
statement_snapshot.snapshot(),
VarSource { variable_id: statement_snapshot.input.var_id, allow_const: true },
);
BranchInfo { known_ap_change: true }
}
lowering::Statement::Desnap(statement_desnap) => {
self.aliases.insert(statement_desnap.output, statement_desnap.input.var_id);
self.aliases.insert(
statement_desnap.output,
VarSource { variable_id: statement_desnap.input.var_id, allow_const: true },
);
BranchInfo { known_ap_change: true }
}
};
Expand Down
Loading