Skip to content

Use nograd keyword for functions too #71

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/grad/reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using .ReverseDiff
# Gradient rule, registered through `ReverseDiff.@grad`, for calling an `Eval` object `ev` on `args...`.
# Returns the forward result `Z` together with a closure of `Δ` that produces the gradients.
ReverseDiff.@grad function (ev::Eval)(args...)
# Forward pass: `ev.fwd` is applied to `ReverseDiff.value` of every argument
# (presumably the plain, untracked data -- confirm against ReverseDiff's docs).
Z = ev.fwd(ReverseDiff.value.(args)...)
Z, Δ -> begin
# NOTE(review): the next two lines guard the same condition (`ev.rev` is `nothing`).
# This text is a rendered diff reporting "1 addition & 1 deletion", so the first line
# looks like the removed version and the second its replacement; only one of them
# is expected to exist in the real file.
isnothing(ev.rev) && error("no gradient definition here!")
ev.rev===nothing && throw("No gradient definition found! Running `@tullio` with keyword `verbose=true` may print the reason")
# Reverse pass: `ev.rev` receives `Δ`, the forward result `Z`, and the arguments,
# with `Δ` and the arguments passed through `ReverseDiff.value` first.
ev.rev(ReverseDiff.value(Δ), Z, ReverseDiff.value.(args)...)
end
end
2 changes: 1 addition & 1 deletion src/grad/tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using .Tracker
# Gradient rule, registered through `Tracker.@grad`, for calling an `Eval` object `ev` on `args...`.
# Returns the forward result `Z` together with a closure of `Δ` that produces the gradients.
Tracker.@grad function (ev::Eval)(args...)
# Forward pass: `ev.fwd` is applied to `Tracker.data` of every argument
# (presumably the plain, untracked data -- confirm against Tracker's docs).
Z = ev.fwd(Tracker.data.(args)...)
Z, Δ -> begin
# NOTE(review): the next two lines guard the same condition (`ev.rev` is `nothing`).
# This text is a rendered diff reporting "1 addition & 1 deletion", so the first line
# looks like the removed version and the second its replacement; only one of them
# is expected to exist in the real file.
isnothing(ev.rev) && error("no gradient definition here!")
ev.rev===nothing && throw("No gradient definition found! Running `@tullio` with keyword `verbose=true` may print the reason")
# Reverse pass: the result of `ev.rev` is splatted into `tuple(...)`, so the closure
# always returns a `Tuple` whatever collection `ev.rev` produced.
tuple(ev.rev(Tracker.data(Δ), Z, Tracker.data.(args)...)...)
end
end
2 changes: 1 addition & 1 deletion src/grad/zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using .Zygote
# Adjoint, registered through `Zygote.@adjoint`, for calling an `Eval` object `ev` on `args...`.
# Returns the forward result `Z` together with a closure of `Δ` that produces the gradients.
Zygote.@adjoint function (ev::Eval)(args...)
# Forward pass: unlike the ReverseDiff / Tracker rules, the arguments are passed to `ev.fwd` unchanged.
Z = ev.fwd(args...)
Z, Δ -> begin
# NOTE(review): the next two lines guard the same condition (`ev.rev` is `nothing`).
# This text is a rendered diff reporting "1 addition & 1 deletion", so the first line
# looks like the removed version and the second its replacement; only one of them
# is expected to exist in the real file.
isnothing(ev.rev) && error("no gradient definition here!")
ev.rev===nothing && throw("No gradient definition found! Running `@tullio` with keyword `verbose=true` may print the reason")
# Reverse pass: a leading `nothing` is prepended to the splatted result of `ev.rev`
# (presumably the gradient slot for `ev` itself -- confirm against Zygote's `@adjoint` docs).
tuple(nothing, ev.rev(Δ, Z, args...)...)
end
end
Expand Down
14 changes: 13 additions & 1 deletion src/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ function parse_options(exs...)
)
expr = nothing
nograd = Symbol[]
safe = Symbol[]
ranges = Tuple[]
for ex in exs
# Actual options:
Expand All @@ -150,6 +151,16 @@ function parse_options(exs...)
throw("this accepts nograd=A or nograd=(A,B,C)")
end

# Safe keyword
elseif isexpr(ex, :(=)) && ex.args[1] == :safe
if ex.args[2] isa Symbol
push!(safe, ex.args[2])
elseif isexpr(ex.args[2], :tuple)
append!(safe, ex.args[2].args)
else
throw("this accepts safe=i or safe=(i,j,k)")
end

# Ranges specified outside:
elseif isexpr(ex, :call) && ex.args[1] in [:in, :∈]
push!(ranges, (ex.args[2], ex.args[3]))
Expand Down Expand Up @@ -190,6 +201,7 @@ function parse_options(exs...)
avx=opts[:avx],
cuda=opts[:cuda],
nograd=nograd,
safe=safe,
), ranges, expr
end

Expand Down Expand Up @@ -588,7 +600,7 @@ detectunsafe(expr, list, store) = MacroTools_postwalk(expr) do ex
MacroTools_postwalk(i) do x
@capture_(x, B_[inner__]) || return x
# Now we have found an array which indexes another one, mark its indices unsafe
append!(list, filter(j -> j isa Symbol, inner))
append!(list, setdiff(filter(j -> j isa Symbol, inner), store.safe))
unique!(list)
# and don't compute a gradient for the inner array
B isa Symbol && push!(store.nograd, B)
Expand Down
47 changes: 25 additions & 22 deletions src/symbolic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

using DiffRules

const _CSE = Ref(true)

function insert_symbolic_gradient(axislist, store)

dZ = Symbol(DEL, ZED)
Expand Down Expand Up @@ -32,12 +34,12 @@ function insert_symbolic_gradient(axislist, store)

inbody, prebody = [], []
for (dt, t) in unique(targets)
drdt = leibnitz(store.right, t)
drdt = leibnitz(store.right, t, store.nograd)
deltar = if store.finaliser == :identity
simplitimes(simpliconj(drdt), :($dZ[$(store.leftraw...)]))
else
rhs = :($ZED[$(store.leftraw...)])
dldr = leibfinal(store.finaliser, rhs)
dldr = leibfinal(store.finaliser, rhs, store.nograd)
simplitimes(simpliconj(drdt), simpliconj(dldr), :($dZ[$(store.leftraw...)]))
end
if store.redfun == :+
Expand All @@ -50,7 +52,7 @@ function insert_symbolic_gradient(axislist, store)
end
end
store.verbose>0 && @info "symbolic gradients" inbody
ex_body = commonsubex(quote $(inbody...) end)
ex_body = _CSE[] ? commonsubex(quote $(inbody...) end) : quote $(inbody...) end

ex_pre, ex_post = if store.redfun == :* # then nonzero LHS are handled already, but harder cases here:
product_grad(prebody, store)
Expand Down Expand Up @@ -84,16 +86,16 @@ function insert_symbolic_gradient(axislist, store)

end

leibfinal(fun::Symbol, res) =
leibfinal(fun::Symbol, res, no=()) =
if fun == :log
:(exp(-$res)) # this exp gets done at every element :(
# :(inv(exp($res)))
else
_leibfinal(:($fun($RHS)), res)
_leibfinal(:($fun($RHS)), res, no)
end

_leibfinal(out, res) = begin
grad1 = leibnitz(out, RHS)
_leibfinal(out, res, no) = begin
grad1 = leibnitz(out, RHS, no)
grad2 = MacroTools_postwalk(grad1) do ex
# @show ex ex == out
ex == out ? res : ex
Expand All @@ -103,13 +105,13 @@ _leibfinal(out, res) = begin
end
end

leibfinal(ex::Expr, res) = begin
leibfinal(ex::Expr, res, no=()) = begin
if ex.head == :call && ex.args[1] isa Expr &&
ex.args[1].head == :(->) && ex.args[1].args[1] == RHS # then it came from underscores
inner = ex.args[1].args[2]
if inner isa Expr && inner.head == :block
lines = filter(a -> !(a isa LineNumberNode), inner.args)
length(lines) == 1 && return _leinfinal(first(lines), res)
length(lines) == 1 && return _leibfinal(first(lines), res, no) # not tested!
end
end
throw("couldn't understand finaliser")
Expand Down Expand Up @@ -191,9 +193,9 @@ symbwalk(targets, store) = ex -> begin
return ex
end

leibnitz(s::Number, target) = 0
leibnitz(s::Symbol, target) = s == target ? 1 : 0
leibnitz(ex::Expr, target) = begin
leibnitz(s::Number, target, no=()) = 0
leibnitz(s::Symbol, target, no=()) = s == target ? 1 : 0
leibnitz(ex::Expr, target, no=()) = begin
ex == target && return 1
@capture_(ex, B_[ijk__]) && return 0
if ex.head == Symbol("'")
Expand All @@ -202,34 +204,35 @@ leibnitz(ex::Expr, target) = begin
end
ex.head == :call || throw("expected a functionn call, got $ex.")
fun = ex.args[1]
fun in no && return 0
if fun == :log # catch log(a*b) and especially log(a/b)
arg = ex.args[2]
if arg isa Expr && arg.args[1] == :* && length(arg.args) == 3
newex = :(log($(arg.args[2])) + log($(arg.args[3])))
return leibnitz(newex, target)
return leibnitz(newex, target, no)
elseif arg isa Expr && arg.args[1] == :/
newex = :(log($(arg.args[2])) - log($(arg.args[3])))
return leibnitz(newex, target)
return leibnitz(newex, target, no)
end
end
if length(ex.args) == 2 # one-arg function
fx = mydiffrule(fun, ex.args[2])
dx = leibnitz(ex.args[2], target)
dx = leibnitz(ex.args[2], target, no)
return simplitimes(fx, dx)
elseif length(ex.args) == 3 # two-arg function
fx, fy = mydiffrule(fun, ex.args[2:end]...)
dx = leibnitz(ex.args[2], target)
dy = leibnitz(ex.args[3], target)
dx = leibnitz(ex.args[2], target, no)
dy = leibnitz(ex.args[3], target, no)
return simpliplus(simplitimes(fx, dx), simplitimes(fy, dy))
elseif fun in [:+, :*]
fun == :* && return leibnitz(:(*($(ex.args[2]), *($(ex.args[3:end]...)))), target)
dxs = [leibnitz(x, target) for x in ex.args[2:end]]
fun == :* && return leibnitz(:(*($(ex.args[2]), *($(ex.args[3:end]...)))), target, no)
dxs = [leibnitz(x, target, no) for x in ex.args[2:end]]
fun == :+ && return simpliplus(dxs...)
elseif length(ex.args) == 4 # three-arg function such as ifelse
fx, fy, fz = mydiffrule(fun, ex.args[2:end]...)
dx = leibnitz(ex.args[2], target)
dy = leibnitz(ex.args[3], target)
dz = leibnitz(ex.args[4], target)
dx = leibnitz(ex.args[2], target, no)
dy = leibnitz(ex.args[3], target, no)
dz = leibnitz(ex.args[4], target, no)
return simpliplus(simplitimes(fx, dx), simplitimes(fy, dy), simplitimes(fz, dz))
end
throw("don't know how to handle $ex.")
Expand Down