diff --git a/src/grad/reverse.jl b/src/grad/reverse.jl index da8b335..9bddd93 100644 --- a/src/grad/reverse.jl +++ b/src/grad/reverse.jl @@ -8,7 +8,7 @@ using .ReverseDiff ReverseDiff.@grad function (ev::Eval)(args...) Z = ev.fwd(ReverseDiff.value.(args)...) Z, Δ -> begin - isnothing(ev.rev) && error("no gradient definition here!") + ev.rev===nothing && throw("No gradient definition found! Running `@tullio` with keyword `verbose=true` may print the reason") ev.rev(ReverseDiff.value(Δ), Z, ReverseDiff.value.(args)...) end end diff --git a/src/grad/tracker.jl b/src/grad/tracker.jl index 53e1618..723914f 100644 --- a/src/grad/tracker.jl +++ b/src/grad/tracker.jl @@ -8,7 +8,7 @@ using .Tracker Tracker.@grad function (ev::Eval)(args...) Z = ev.fwd(Tracker.data.(args)...) Z, Δ -> begin - isnothing(ev.rev) && error("no gradient definition here!") + ev.rev===nothing && throw("No gradient definition found! Running `@tullio` with keyword `verbose=true` may print the reason") tuple(ev.rev(Tracker.data(Δ), Z, Tracker.data.(args)...)...) end end diff --git a/src/grad/zygote.jl b/src/grad/zygote.jl index 82808f9..f719c79 100644 --- a/src/grad/zygote.jl +++ b/src/grad/zygote.jl @@ -4,7 +4,7 @@ using .Zygote Zygote.@adjoint function (ev::Eval)(args...) Z = ev.fwd(args...) Z, Δ -> begin - isnothing(ev.rev) && error("no gradient definition here!") + ev.rev===nothing && throw("No gradient definition found! Running `@tullio` with keyword `verbose=true` may print the reason") tuple(nothing, ev.rev(Δ, Z, args...)...) end end diff --git a/src/macro.jl b/src/macro.jl index 0f0aa0b..d49b92f 100644 --- a/src/macro.jl +++ b/src/macro.jl @@ -127,6 +127,7 @@ function parse_options(exs...) ) expr = nothing nograd = Symbol[] + safe = Symbol[] ranges = Tuple[] for ex in exs # Actual options: @@ -150,6 +151,16 @@ function parse_options(exs...) throw("this accepts nograd=A or nograd=(A,B,C)") end + # Safe keyword + elseif isexpr(ex, :(=)) && ex.args[1] == :safe + if ex.args[2] isa Symbol + push!(safe, ex.args[2]) + elseif isexpr(ex.args[2], :tuple) + append!(safe, ex.args[2].args) + else + throw("this accepts safe=i or safe=(i,j,k)") + end + # Ranges specified outside: elseif isexpr(ex, :call) && ex.args[1] in [:in, :∈] push!(ranges, (ex.args[2], ex.args[3])) @@ -190,6 +201,7 @@ function parse_options(exs...) avx=opts[:avx], cuda=opts[:cuda], nograd=nograd, + safe=safe, ), ranges, expr end @@ -588,7 +600,7 @@ detectunsafe(expr, list, store) = MacroTools_postwalk(expr) do ex MacroTools_postwalk(i) do x @capture_(x, B_[inner__]) || return x # Now we have found an array which indexes another one, mark its indices unsafe - append!(list, filter(j -> j isa Symbol, inner)) + append!(list, setdiff(filter(j -> j isa Symbol, inner), store.safe)) unique!(list) # and don't compute a gradient for the inner array B isa Symbol && push!(store.nograd, B) diff --git a/src/symbolic.jl b/src/symbolic.jl index a073756..1e26631 100644 --- a/src/symbolic.jl +++ b/src/symbolic.jl @@ -3,6 +3,8 @@ using DiffRules +const _CSE = Ref(true) + function insert_symbolic_gradient(axislist, store) dZ = Symbol(DEL, ZED) @@ -32,12 +34,12 @@ function insert_symbolic_gradient(axislist, store) inbody, prebody = [], [] for (dt, t) in unique(targets) - drdt = leibnitz(store.right, t) + drdt = leibnitz(store.right, t, store.nograd) deltar = if store.finaliser == :identity simplitimes(simpliconj(drdt), :($dZ[$(store.leftraw...)])) else rhs = :($ZED[$(store.leftraw...)]) - dldr = leibfinal(store.finaliser, rhs) + dldr = leibfinal(store.finaliser, rhs, store.nograd) simplitimes(simpliconj(drdt), simpliconj(dldr), :($dZ[$(store.leftraw...)])) end if store.redfun == :+ @@ -50,7 +52,7 @@ function insert_symbolic_gradient(axislist, store) end end store.verbose>0 && @info "symbolic gradients" inbody - ex_body = commonsubex(quote $(inbody...) end) + ex_body = _CSE[] ? commonsubex(quote $(inbody...) end) : quote $(inbody...) end ex_pre, ex_post = if store.redfun == :* # then nonzero LHS are handled already, but harder cases here: product_grad(prebody, store) @@ -84,16 +86,16 @@ function insert_symbolic_gradient(axislist, store) end -leibfinal(fun::Symbol, res) = +leibfinal(fun::Symbol, res, no=()) = if fun == :log :(exp(-$res)) # this exp gets done at every element :( # :(inv(exp($res))) else - _leibfinal(:($fun($RHS)), res) + _leibfinal(:($fun($RHS)), res, no) end -_leibfinal(out, res) = begin - grad1 = leibnitz(out, RHS) +_leibfinal(out, res, no) = begin + grad1 = leibnitz(out, RHS, no) grad2 = MacroTools_postwalk(grad1) do ex # @show ex ex == out ex == out ? res : ex @@ -103,13 +105,13 @@ _leibfinal(out, res) = begin end end -leibfinal(ex::Expr, res) = begin +leibfinal(ex::Expr, res, no=()) = begin if ex.head == :call && ex.args[1] isa Expr && ex.args[1].head == :(->) && ex.args[1].args[1] == RHS # then it came from underscores inner = ex.args[1].args[2] if inner isa Expr && inner.head == :block lines = filter(a -> !(a isa LineNumberNode), inner.args) - length(lines) == 1 && return _leinfinal(first(lines), res) + length(lines) == 1 && return _leibfinal(first(lines), res, no) # not tested! end end throw("couldn't understand finaliser") @@ -191,9 +193,9 @@ symbwalk(targets, store) = ex -> begin return ex end -leibnitz(s::Number, target) = 0 -leibnitz(s::Symbol, target) = s == target ? 1 : 0 -leibnitz(ex::Expr, target) = begin +leibnitz(s::Number, target, no=()) = 0 +leibnitz(s::Symbol, target, no=()) = s == target ? 1 : 0 +leibnitz(ex::Expr, target, no=()) = begin ex == target && return 1 @capture_(ex, B_[ijk__]) && return 0 if ex.head == Symbol("'") @@ -202,34 +204,35 @@ leibnitz(ex::Expr, target) = begin end ex.head == :call || throw("expected a functionn call, got $ex.") fun = ex.args[1] + fun in no && return 0 if fun == :log # catch log(a*b) and especially log(a/b) arg = ex.args[2] if arg isa Expr && arg.args[1] == :* && length(arg.args) == 3 newex = :(log($(arg.args[2])) + log($(arg.args[3]))) - return leibnitz(newex, target) + return leibnitz(newex, target, no) elseif arg isa Expr && arg.args[1] == :/ newex = :(log($(arg.args[2])) - log($(arg.args[3]))) - return leibnitz(newex, target) + return leibnitz(newex, target, no) end end if length(ex.args) == 2 # one-arg function fx = mydiffrule(fun, ex.args[2]) - dx = leibnitz(ex.args[2], target) + dx = leibnitz(ex.args[2], target, no) return simplitimes(fx, dx) elseif length(ex.args) == 3 # two-arg function fx, fy = mydiffrule(fun, ex.args[2:end]...) - dx = leibnitz(ex.args[2], target) - dy = leibnitz(ex.args[3], target) + dx = leibnitz(ex.args[2], target, no) + dy = leibnitz(ex.args[3], target, no) return simpliplus(simplitimes(fx, dx), simplitimes(fy, dy)) elseif fun in [:+, :*] - fun == :* && return leibnitz(:(*($(ex.args[2]), *($(ex.args[3:end]...)))), target) - dxs = [leibnitz(x, target) for x in ex.args[2:end]] + fun == :* && return leibnitz(:(*($(ex.args[2]), *($(ex.args[3:end]...)))), target, no) + dxs = [leibnitz(x, target, no) for x in ex.args[2:end]] fun == :+ && return simpliplus(dxs...) elseif length(ex.args) == 4 # three-arg function such as ifelse fx, fy, fz = mydiffrule(fun, ex.args[2:end]...) - dx = leibnitz(ex.args[2], target) - dy = leibnitz(ex.args[3], target) - dz = leibnitz(ex.args[4], target) + dx = leibnitz(ex.args[2], target, no) + dy = leibnitz(ex.args[3], target, no) + dz = leibnitz(ex.args[4], target, no) return simpliplus(simplitimes(fx, dx), simplitimes(fy, dy), simplitimes(fz, dz)) end throw("don't know how to handle $ex.")