Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 54 additions & 7 deletions enzyme/Enzyme/MLIR/Dialect/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,6 @@ class ReverseRetOpt final : public OpRewritePattern<AutoDiffOp> {

// skip primal return
if (val == Activity::enzyme_constnoneed ||
val == Activity::enzyme_activenoneed ||
val == Activity::enzyme_dupnoneed) {
newRetActivityArgs.push_back(iattr);
continue;
Expand All @@ -636,15 +635,35 @@ class ReverseRetOpt final : public OpRewritePattern<AutoDiffOp> {

switch (val) {
case Activity::enzyme_active:
if (!res.use_empty()) {
outs_args.push_back(res);
out_ty.push_back(res.getType());
newRetActivityArgs.push_back(iattr);
} else {
if (res.use_empty()) {
changed = true;
auto new_activenn = ActivityAttr::get(rewriter.getContext(),
Activity::enzyme_activenoneed);
newRetActivityArgs.push_back(new_activenn);
} else {
int in_idx = 0;
for (auto act : inpActivity) {
auto v = cast<ActivityAttr>(act).getValue();
in_idx +=
(v == Activity::enzyme_dup || v == Activity::enzyme_dupnoneed)
? 2
: 1;
}
in_idx += out_idx;
auto dres = uop.getInputs()[in_idx];

if (matchPattern(dres, m_Zero()) ||
matchPattern(dres, m_AnyZeroFloat())) {
changed = true;
auto new_const = ActivityAttr::get(rewriter.getContext(),
Activity::enzyme_const);
newRetActivityArgs.push_back(new_const);
} else {
newRetActivityArgs.push_back(iattr);
}

outs_args.push_back(res);
out_ty.push_back(res.getType());
}
break;

Expand All @@ -668,7 +687,31 @@ class ReverseRetOpt final : public OpRewritePattern<AutoDiffOp> {
newRetActivityArgs.push_back(iattr);
break;

case Activity::enzyme_activenoneed:
case Activity::enzyme_activenoneed: {
int in_idx = 0;
for (auto act : inpActivity) {
auto v = cast<ActivityAttr>(act).getValue();
in_idx +=
(v == Activity::enzyme_dup || v == Activity::enzyme_dupnoneed)
? 2
: 1;
}
in_idx += out_idx;

auto dres = uop.getInputs()[in_idx];

if (matchPattern(dres, m_Zero()) ||
matchPattern(dres, m_AnyZeroFloat())) {
changed = true;
auto new_constnn = ActivityAttr::get(rewriter.getContext(),
Activity::enzyme_constnoneed);
newRetActivityArgs.push_back(new_constnn);
} else {
newRetActivityArgs.push_back(iattr);
}

continue;
}
case Activity::enzyme_constnoneed:
case Activity::enzyme_dupnoneed:
break;
Expand Down Expand Up @@ -763,6 +806,10 @@ class ReverseRetOpt final : public OpRewritePattern<AutoDiffOp> {
} else if (new_val == Activity::enzyme_constnoneed &&
old_val == Activity::enzyme_const) {
++oldIdx; // skip const primal
} else if (new_val == Activity::enzyme_const &&
old_val == Activity::enzyme_active) {
uop.getOutputs()[oldIdx++].replaceAllUsesWith(
newOp.getOutputs()[newIdx++]);
}
}
}
Expand Down
16 changes: 16 additions & 0 deletions enzyme/test/MLIR/ReverseMode/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,20 @@ module {
// CHECK: enzyme.autodiff @square2(%arg0, %arg1, %arg2, %arg3){{.*}}activity = [#enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>]{{.*}}ret_activity = [#enzyme<activity enzyme_activenoneed>, #enzyme<activity enzyme_activenoneed>]{{.*}}
return %cst : f32
}

// Test 5: active -> const for ret_activity (iff derivative is 0)
func.func @test5(%x: f32, %y: f32, %dr0: f32) -> (f32,f32,f32,f32) {
%cst = arith.constant 0.0000e+00 : f32
%r:4 = enzyme.autodiff @square2(%x,%y,%dr0,%cst) { activity=[#enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>], ret_activity=[#enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>] } : (f32,f32,f32,f32) -> (f32,f32,f32,f32)
// CHECK: %{{.*}} = enzyme.autodiff @square2(%arg0, %arg1, %arg2, %cst){{.*}}activity = [#enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>]{{.*}}ret_activity = [#enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>]{{.*}}
return %r#0,%r#1,%r#2,%r#3 : f32,f32,f32,f32
}

// Test 6: active -> activenoneed/const -> constnoneed for ret_activity
func.func @test6(%x: f32, %y: f32, %dr0: f32) -> (f32,f32,f32) {
%cst = arith.constant 0.0000e+00 : f32
%r:4 = enzyme.autodiff @square2(%x,%y,%dr0,%cst) { activity=[#enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>], ret_activity=[#enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>] } : (f32,f32,f32,f32) -> (f32,f32,f32,f32)
// CHECK: %{{.*}} = enzyme.autodiff @square2(%arg0, %arg1, %arg2, %cst){{.*}}activity = [#enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>]{{.*}}ret_activity = [#enzyme<activity enzyme_active>, #enzyme<activity enzyme_constnoneed>]{{.*}}
return %r#0,%r#2,%r#3 : f32,f32,f32
}
}
Loading