Skip to content

Commit a302511

Browse files
authored
improve return_type handling (#240)
- improve `ReturnTypeCallInfo` printing (see the example below) - don't error when trying to descend into a failed `return_type` site > before ```julia julia> only_ints(::Integer) = 1; julia> descend(; optimize=false) do t1 = Base._return_type(only_ints, Tuple{Int}) # successful `return_type` t2 = Base._return_type(only_ints, Tuple{Float64}) # failed `return_type` t1, t2 end (::var"#3#4")() in Main at REPL[3]:2 Variables #self#::Core.Const(var"#3#4"()) t2::Type{Union{}} t1::Type{Int64} │ ─ %-1 = invoke #3()::Core.Const((Int64, Union{})) @ REPL[3]:2 within `#3` 1 ─ %1 = Base._return_type::Core.Const(Core.Compiler.return_type) │ %2 = Core.apply_type(Main.Tuple, Main.Int)::Core.Const(Tuple{Int64}) │ (t1 = (%1)(Main.only_ints, %2))::Core.Const(Int64) │ @ REPL[3]:3 within `#3` │ %4 = Base._return_type::Core.Const(Core.Compiler.return_type) │ %5 = Core.apply_type(Main.Tuple, Main.Float64)::Core.Const(Tuple{Float64}) │ (t2 = (%4)(Main.only_ints, %5))::Core.Const(Union{}) │ @ REPL[3]:4 within `#3` │ %7 = Core.tuple(t1::Core.Const(Int64), t2::Core.Const(Union{}))::Core.Const((Int64, Union{})) └── return %7 Select a call to descend into or ↩ to ascend. [q]uit. [b]ookmark. Toggles: [o]ptimize, [w]arn, [h]ide type-stable statements, [d]ebuginfo, [r]emarks, [i]nlining costs, [t]ype annotations, [s]yntax highlight for Source/LLVM/Native. Show: [S]ource code, [A]ST, [T]yped code, [L]LVM IR, [N]ative code Actions: [E]dit source code, [R]evise and redisplay Advanced: dump [P]arams cache. • %3 = return_type < only_ints(::Int64)::Core.Const(Int64) > %6 = return_type < → return_type(::typeof(only_ints),::Type{Tuple{Float64}})::Core.Const(Union{}) > ↩ ``` > this commit ```julia julia> only_ints(::Integer) = 1; julia> descend(; optimize=false) do t1 = Base._return_type(only_ints, Tuple{Int}) # successful `return_type` t2 = Base._return_type(only_ints, Tuple{Float64}) # failed `return_type` t1, t2 end (::var"#3#4")() in Main at REPL[3]:2 Variables #self#::Core.Const(var"#3#4"()) t2::Type{Union{}} t1::Type{Int64} │ ─ %-1 = invoke #3()::Core.Const((Int64, Union{})) @ REPL[3]:2 within `#3` 1 ─ %1 = Base._return_type::Core.Const(Core.Compiler.return_type) │ %2 = Core.apply_type(Main.Tuple, Main.Int)::Core.Const(Tuple{Int64}) │ (t1 = (%1)(Main.only_ints, %2))::Core.Const(Int64) │ @ REPL[3]:3 within `#3` │ %4 = Base._return_type::Core.Const(Core.Compiler.return_type) │ %5 = Core.apply_type(Main.Tuple, Main.Float64)::Core.Const(Tuple{Float64}) │ (t2 = (%4)(Main.only_ints, %5))::Core.Const(Union{}) │ @ REPL[3]:4 within `#3` │ %7 = Core.tuple(t1::Core.Const(Int64), t2::Core.Const(Union{}))::Core.Const((Int64, Union{})) └── return %7 Select a call to descend into or ↩ to ascend. [q]uit. [b]ookmark. Toggles: [o]ptimize, [w]arn, [h]ide type-stable statements, [d]ebuginfo, [r]emarks, [i]nlining costs, [t]ype annotations, [s]yntax highlight for Source/LLVM/Native. Show: [S]ource code, [A]ST, [T]yped code, [L]LVM IR, [N]ative code Actions: [E]dit source code, [R]evise and redisplay Advanced: dump [P]arams cache. • %3 = return_type < only_ints(::Int64)::Int64 > %6 = return_type < only_ints(::Float64)::Union{} > ↩ ```
1 parent 5b8423a commit a302511

File tree

5 files changed

+88
-58
lines changed

5 files changed

+88
-58
lines changed

src/Cthulhu.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using REPL: REPL, AbstractTerminal
99

1010
using Core: MethodInstance
1111
const Compiler = Core.Compiler
12-
using Core.Compiler: MethodMatch, LimitedAccuracy, ignorelimited
12+
import Core.Compiler: MethodMatch, LimitedAccuracy, ignorelimited, specialize_method
1313
import Base: unwrapva, isvarargtype
1414
const mapany = Base.mapany
1515

@@ -462,7 +462,7 @@ end
462462

463463
function get_specialization(@nospecialize(TT))
464464
match = Base._which(TT)
465-
mi = Core.Compiler.specialize_method(match)
465+
mi = specialize_method(match)
466466
return mi
467467
end
468468

src/callsite.jl

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using Unicode
22

3-
abstract type CallInfo; end
3+
abstract type CallInfo end
44

55
# Call could be resolved to a singular MI
66
struct MICallInfo <: CallInfo
@@ -97,10 +97,10 @@ get_rt(tci::OCCallInfo) = get_rt(tci.ci)
9797

9898
# Special handling for ReturnTypeCall
9999
struct ReturnTypeCallInfo <: CallInfo
100-
called_mi::CallInfo
100+
vmi::CallInfo # virtualized method call
101101
end
102-
get_mi(rtci::ReturnTypeCallInfo) = get_mi(rtci.called_mi)
103-
get_rt(rtci::ReturnTypeCallInfo) = get_rt(rtci.called_mi)
102+
get_mi((; vmi)::ReturnTypeCallInfo) = isa(vmi, FailedCallInfo) ? nothing : get_mi(vmi)
103+
get_rt((; vmi)::ReturnTypeCallInfo) = Type{isa(vmi, FailedCallInfo) ? Union{} : widenconst(get_rt(vmi))}
104104

105105
struct ConstPropCallInfo <: CallInfo
106106
mi::CallInfo
@@ -232,7 +232,7 @@ function show_callinfo(limiter, mici::MICallInfo)
232232
mi = mici.mi
233233
tt = (Base.unwrap_unionall(mi.specTypes)::DataType).parameters[2:end]
234234
name = (mi.def::Method).name
235-
rt = mici.rt
235+
rt = get_rt(mici)
236236
__show_limited(limiter, name, tt, rt)
237237
end
238238

@@ -247,8 +247,7 @@ function show_callinfo(limiter, ci::Union{MultiCallInfo, FailedCallInfo, Generat
247247
else
248248
name = "→ (::$(nameof(ft)))"
249249
end
250-
rt = ci.rt
251-
__show_limited(limiter, name::String, tt, rt)
250+
__show_limited(limiter, name::String, tt, get_rt(ci))
252251
end
253252

254253
function show_callinfo(limiter, (; argtypes, rt)::PureCallInfo)
@@ -262,7 +261,19 @@ function show_callinfo(limiter, ci::ConstPropCallInfo)
262261
# XXX: The first argument could be const-overriden too
263262
name = ci.result.linfo.def.name
264263
tt = ci.result.argtypes[2:end]
265-
__show_limited(limiter, name, tt, (ignorewrappers(ci.mi)::MICallInfo).rt)
264+
__show_limited(limiter, name, tt, get_rt(ignorewrappers(ci.mi)::MICallInfo))
265+
end
266+
267+
function show_callinfo(limiter, (; vmi)::ReturnTypeCallInfo)
268+
if isa(vmi, FailedCallInfo)
269+
ft = Base.tuple_type_head(vmi.sig)
270+
f = Compiler.singleton_type(ft)
271+
name = isnothing(f) ? "unknown" : string(f)
272+
tt = Base.tuple_type_tail(vmi.sig).parameters
273+
__show_limited(limiter, name, tt, vmi.rt)
274+
else
275+
show_callinfo(limiter, vmi)
276+
end
266277
end
267278

268279
function Base.show(io::IO, c::Callsite)
@@ -306,7 +317,7 @@ function Base.show(io::IO, c::Callsite)
306317
print(limiter, " >")
307318
elseif isa(info, ReturnTypeCallInfo)
308319
print(limiter, " = return_type < ")
309-
show_callinfo(limiter, info.called_mi)
320+
show_callinfo(limiter, info)
310321
print(limiter, " >")
311322
elseif isa(info, CuCallInfo)
312323
print(limiter, " = cucall < ")
@@ -342,7 +353,7 @@ is_callsite(info::WrappedCallInfo, mi::MethodInstance) = is_callsite(get_wrapped
342353
is_callsite(info::ConstPropCallInfo, mi::MethodInstance) = is_callsite(info.mi, mi)
343354
is_callsite(info::TaskCallInfo, mi::MethodInstance) = is_callsite(info.ci, mi)
344355
is_callsite(info::InvokeCallInfo, mi::MethodInstance) = is_callsite(info.ci, mi)
345-
is_callsite(info::ReturnTypeCallInfo, mi::MethodInstance) = is_callsite(info.called_mi, mi)
356+
is_callsite(info::ReturnTypeCallInfo, mi::MethodInstance) = is_callsite(info.vmi, mi)
346357
is_callsite(info::CuCallInfo, mi::MethodInstance) = is_callsite(info.cumi, mi)
347358
function is_callsite(info::MultiCallInfo, mi::MethodInstance)
348359
for csi in info.callinfos

src/interpreter.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ function Compiler.finish(state::InferenceState, interp::CthulhuInterpreter)
7676
end
7777

7878
function Compiler.transform_result_for_cache(interp::CthulhuInterpreter, linfo::MethodInstance,
79-
valid_worlds::Core.Compiler.WorldRange, @nospecialize(inferred_result))
79+
valid_worlds::WorldRange, @nospecialize(inferred_result))
8080
if isa(inferred_result, OptimizationState)
8181
opt = inferred_result
8282
if isdefined(opt, :ir)

src/reflection.jl

Lines changed: 39 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,16 @@
55
using Base.Meta
66
import .Compiler: widenconst, argextype, Const, MethodMatchInfo,
77
UnionSplitApplyCallInfo, UnionSplitInfo, ConstCallInfo,
8-
MethodResultPure, ApplyCallInfo
8+
MethodResultPure, ApplyCallInfo,
9+
sptypes_from_meth_instance, argtypes_to_type
10+
import Base: may_invoke_generator
911

10-
const sptypes_from_meth_instance = Core.Compiler.sptypes_from_meth_instance
11-
const may_invoke_generator = Base.may_invoke_generator
1212
function code_for_method(method, metharg, methsp, world, preexisting=false)
1313
@static if VERSION v"1.8.0-DEV.369"
1414
# https://github.com/JuliaLang/julia/pull/41920
15-
Core.Compiler.specialize_method(method, metharg, methsp; preexisting)
15+
specialize_method(method, metharg, methsp; preexisting)
1616
else
17-
Core.Compiler.specialize_method(method, metharg, methsp, preexisting)
17+
specialize_method(method, metharg, methsp, preexisting)
1818
end
1919
end
2020

@@ -27,6 +27,8 @@ function transform(::Val{:CuFunction}, callsite, callexpr, CI, mi, slottypes; pa
2727
return Callsite(callsite.id, CuCallInfo(callinfo(Tuple{widenconst(ft), tt.val.parameters...}, Nothing, params)), callsite.head)
2828
end
2929

30+
const ArgTypes = Vector{Any}
31+
3032
function find_callsites(interp::CthulhuInterpreter, CI::Union{Core.CodeInfo, IRCode},
3133
stmt_info::Union{Vector, Nothing}, mi::Core.MethodInstance,
3234
slottypes::Vector{Any}, optimize::Bool=true;
@@ -47,28 +49,17 @@ function find_callsites(interp::CthulhuInterpreter, CI::Union{Core.CodeInfo, IRC
4749
if !optimize
4850
args = (ignorelhs(c)::Expr).args
4951
end
50-
types = mapany(function (@nospecialize(arg),)
51-
t = argextype(arg, CI, sptypes, slottypes)
52-
return widenconst(ignorelimited(t))
53-
end, args)
54-
was_return_type = false
55-
if isa(info, Core.Compiler.ReturnTypeCallInfo)
56-
info = info.info
57-
was_return_type = true
58-
end
59-
60-
callinfos = process_info(interp, info, types, rt, optimize)
61-
if !was_return_type && isempty(callinfos)
62-
continue
63-
end
52+
argtypes = mapany(function (@nospecialize(arg),)
53+
t = argextype(arg, CI, sptypes, slottypes)
54+
return widenconst(ignorelimited(t))
55+
end, args)
56+
callinfos = process_info(interp, info, argtypes, rt, optimize)
57+
isempty(callinfos) && continue
6458
callsite = let
6559
if length(callinfos) == 1
6660
callinfo = callinfos[1]
6761
else
68-
callinfo = MultiCallInfo(Compiler.argtypes_to_type(types), rt, callinfos)
69-
end
70-
if was_return_type
71-
callinfo = ReturnTypeCallInfo(callinfo)
62+
callinfo = MultiCallInfo(argtypes_to_type(argtypes), rt, callinfos)
7263
end
7364
Callsite(id, callinfo, c.head)
7465
end
@@ -113,10 +104,9 @@ function find_callsites(interp::CthulhuInterpreter, CI::Union{Core.CodeInfo, IRC
113104
return callsites
114105
end
115106

116-
function process_info(interp, @nospecialize(info), types, @nospecialize(rt), optimize::Bool)
107+
function process_info(interp, @nospecialize(info), argtypes::ArgTypes, @nospecialize(rt), optimize::Bool)
117108
is_cached(@nospecialize(key)) = haskey(optimize ? interp.opt : interp.unopt, key)
118-
process_recursive(@nospecialize(newinfo)) = process_info(interp, newinfo, types, rt, optimize)
119-
109+
process_recursive(@nospecialize(newinfo)) = process_info(interp, newinfo, argtypes, rt, optimize)
120110

121111
if isa(info, MethodMatchInfo)
122112
if info.results === missing
@@ -125,12 +115,12 @@ function process_info(interp, @nospecialize(info), types, @nospecialize(rt), opt
125115

126116
matches = info.results.matches
127117
return mapany(matches) do match::Core.MethodMatch
128-
mi = Core.Compiler.specialize_method(match)
118+
mi = specialize_method(match)
129119
mici = MICallInfo(mi, rt)
130120
return is_cached(mi) ? mici : UncachedCallInfo(mici)
131121
end
132122
elseif isa(info, MethodResultPure)
133-
return Any[PureCallInfo(types, rt)]
123+
return Any[PureCallInfo(argtypes, rt)]
134124
elseif isa(info, UnionSplitInfo)
135125
return mapreduce(process_recursive, vcat, info.matches; init=[])::Vector{Any}
136126
elseif isa(info, UnionSplitApplyCallInfo)
@@ -158,16 +148,33 @@ function process_info(interp, @nospecialize(info), types, @nospecialize(rt), opt
158148
elseif (@static isdefined(Compiler, :OpaqueClosureCreateInfo) && true) && isa(info, Compiler.OpaqueClosureCreateInfo)
159149
# TODO: Add ability to descend into OCs at creation site
160150
return []
151+
elseif isa(info, Compiler.ReturnTypeCallInfo)
152+
newargtypes = argtypes[2:end]
153+
callinfos = process_info(interp, info.info, newargtypes, unwrapType(widenconst(rt)), optimize)
154+
if length(callinfos) == 1
155+
vmi = only(callinfos)
156+
else
157+
@assert isempty(callinfos)
158+
argt = unwrapType(widenconst(newargtypes[2]))::DataType
159+
sig = Tuple{widenconst(newargtypes[1]), argt.parameters...}
160+
vmi = FailedCallInfo(sig, Union{})
161+
end
162+
return Any[ReturnTypeCallInfo(vmi)]
161163
elseif info == false
162164
return []
163165
else
164-
@show CI
165-
@show c
166-
@show info
167-
error()
166+
@eval Main begin
167+
interp = $interp
168+
info = $info
169+
argtypes = $argtypes
170+
rt = $rt
171+
optimize = $optimize
172+
end
173+
error("inspect `Main.interp|info|argtypes|rt|optimize`")
168174
end
169175
end
170176

177+
unwrapType(@nospecialize t) = Compiler.isType(t) ? t.parameters[1] : t
171178

172179
ignorelhs(@nospecialize(x)) = isexpr(x, :(=)) ? last(x.args) : x
173180
function is_call_expr(x::Expr, optimize::Bool)

test/runtests.jl

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -234,22 +234,34 @@ struct SingletonPureCallable{N} end
234234
@test s == "SingletonPureCallable{1}()(::Float64)::Float64"
235235
end
236236

237-
# Failed return_type
238-
only_ints(::Integer) = 1
239-
return_type_failure(::T) where T = Base._return_type(only_ints, Tuple{T})
240-
let callsites = find_callsites_by_ftt(return_type_failure, Tuple{Float64}, optimize=false)
241-
@test length(callsites) == 1
242-
callinfo = callsites[1].info
243-
@test callinfo isa Cthulhu.ReturnTypeCallInfo
244-
callinfo = callinfo.called_mi
245-
@test callinfo isa Cthulhu.MultiCallInfo
237+
@testset "ReturnTypeCallInfo" begin
238+
only_ints(::Integer) = 1
239+
240+
callsites = find_callsites_by_ftt(; optimize=false) do
241+
t1 = Base._return_type(only_ints, Tuple{Int}) # successful `return_type`
242+
t2 = Base._return_type(only_ints, Tuple{Float64}) # failed `return_type`
243+
t1, t2
244+
end
245+
@test length(callsites) == 2
246+
callinfo1 = callsites[1].info
247+
@test callinfo1 isa Cthulhu.ReturnTypeCallInfo
248+
@test callinfo1.vmi isa Cthulhu.MICallInfo
246249
io = IOBuffer()
247-
Cthulhu.show_callinfo(io, callinfo)
248-
@test String(take!(io)) == "→ return_type(::typeof(only_ints),::Type{Tuple{Float64}})::Core.Const(Union{})"
250+
Cthulhu.show_callinfo(io, callinfo1)
251+
@test String(take!(io)) == "only_ints(::$Int)::$Int"
249252
io = IOBuffer()
250253
print(io, callsites[1])
251-
@test occursin("return_type < → return_type", String(take!(io)))
252-
@test length(callinfo.callinfos) == 0
254+
@test occursin("return_type < only_ints(::$Int)::$Int >", String(take!(io)))
255+
256+
callinfo2 = callsites[2].info
257+
@test callinfo2 isa Cthulhu.ReturnTypeCallInfo
258+
@test callinfo2.vmi isa Cthulhu.FailedCallInfo
259+
io = IOBuffer()
260+
Cthulhu.show_callinfo(io, callinfo2)
261+
@test String(take!(io)) == "only_ints(::Float64)::Union{}"
262+
io = IOBuffer()
263+
print(io, callsites[2])
264+
@test occursin("return_type < only_ints(::Float64)::Union{} >", String(take!(io)))
253265
end
254266

255267
# tasks

0 commit comments

Comments
 (0)