Skip to content

Commit 023b1c6

Browse files
committed
update to avi/typelattice, incoming type lattice overhaul
1 parent 5b8423a commit 023b1c6

File tree

5 files changed

+79
-31
lines changed

5 files changed

+79
-31
lines changed

src/Cthulhu.jl

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@ using InteractiveUtils
77
using UUIDs
88
using REPL: REPL, AbstractTerminal
99

10-
using Core: MethodInstance
10+
import Core: MethodInstance, OpaqueClosure
1111
const Compiler = Core.Compiler
12-
using Core.Compiler: MethodMatch, LimitedAccuracy, ignorelimited
12+
import Core.Compiler: MethodMatch, LimitedAccuracy, widenconst, ignorelimited, Const,
13+
PartialStruct, InterConditional, PartialOpaque
1314
import Base: unwrapva, isvarargtype
15+
1416
const mapany = Base.mapany
1517

1618
# branch on https://github.com/JuliaLang/julia/pull/42125
@@ -21,6 +23,15 @@ else
2123
macro constprop(_, ex); esc(ex); end
2224
end
2325

26+
const IS_OVERHAULED = isdefined(Core.Compiler, :LatticeElement)
27+
@static if IS_OVERHAULED
28+
import Core.Compiler: ⊤, ⊥, LatticeElement, NativeType, LatticeElement, isConst,
29+
isLimitedAccuracy, Argtypes, SSAValueTypes
30+
else
31+
const Argtypes = Vector{Any}
32+
const SSAValueTypes = Vector{Any}
33+
end
34+
2435
Base.@kwdef mutable struct CthulhuConfig
2536
enable_highlighter::Bool = false
2637
highlighter::Cmd = `pygmentize -l`
@@ -167,15 +178,36 @@ const descend = descend_code_typed
167178

168179
descend(interp::CthulhuInterpreter, mi::MethodInstance; kwargs...) = _descend(interp, mi; iswarn=false, interruptexc=false, kwargs...)
169180

181+
@static if IS_OVERHAULED
182+
import Core.Compiler: ConditionalInfo
170183
function codeinst_rt(code::CodeInstance)
171184
rettype = code.rettype
172185
if isdefined(code, :rettype_const)
173186
rettype_const = code.rettype_const
174187
if isa(rettype_const, Vector{Any}) && !(Vector{Any} <: rettype)
175-
return Core.PartialStruct(rettype, rettype_const)
176-
elseif isa(rettype_const, Core.PartialOpaque) && rettype <: Core.OpaqueClosure
188+
return PartialStruct(rettype, rettype_const)
189+
elseif isa(rettype_const, PartialOpaque) && rettype <: OpaqueClosure
177190
return rettype_const
178-
elseif isa(rettype_const, Core.InterConditional) && !(Core.InterConditional <: rettype)
191+
elseif isa(rettype_const, ConditionalInfo) && !(ConditionalInfo <: rettype)
192+
@assert rettype_const.inter
193+
return InterConditional(rettype_const.slot_id, rettype_const.vtype, rettype_const.elsetype), mi
194+
else
195+
return Const(rettype_const)
196+
end
197+
else
198+
return rettype
199+
end
200+
end
201+
else # @static if IS_OVERHAULED
202+
function codeinst_rt(code::CodeInstance)
203+
rettype = code.rettype
204+
if isdefined(code, :rettype_const)
205+
rettype_const = code.rettype_const
206+
if isa(rettype_const, Vector{Any}) && !(Vector{Any} <: rettype)
207+
return PartialStruct(rettype, rettype_const)
208+
elseif isa(rettype_const, PartialOpaque) && rettype <: OpaqueClosure
209+
return rettype_const
210+
elseif isa(rettype_const, InterConditional) && !(InterConditional <: rettype)
179211
return rettype_const
180212
else
181213
return Const(rettype_const)
@@ -184,6 +216,7 @@ function codeinst_rt(code::CodeInstance)
184216
return rettype
185217
end
186218
end
219+
end # @static if IS_OVERHAULED
187220

188221
# `@constprop :aggressive` here in order to make sure the constant propagation of `allow_no_codeinf`
189222
@constprop :aggressive function lookup(interp::CthulhuInterpreter, mi::MethodInstance, optimize::Bool; allow_no_codeinf::Bool=false)
@@ -193,7 +226,11 @@ end
193226
infos = interp.unopt[mi].stmt_infos
194227
slottypes = codeinf.slottypes
195228
if isnothing(slottypes)
196-
slottypes = Any[ Any for i = 1:length(codeinf.slotflags) ]
229+
@static if IS_OVERHAULED
230+
slottypes = LatticeElement[ ⊤ for i = 1:length(codeinf.slotflags) ]
231+
else
232+
slottypes = Any[ Any for i = 1:length(codeinf.slotflags) ]
233+
end
197234
end
198235
else
199236
codeinst = interp.opt[mi]
@@ -209,7 +246,11 @@ end
209246
# But with coverage on, the empty function body isn't empty due to :code_coverage_effect expressions.
210247
codeinf = nothing
211248
infos = []
212-
slottypes = Any[Base.unwrap_unionall(mi.specTypes).parameters...]
249+
@static if IS_OVERHAULED
250+
slottypes = AbstractLatticce[NativeType(t) for t in Base.unwrap_unionall(mi.specTypes).parameters]
251+
else
252+
slottypes = Any[Base.unwrap_unionall(mi.specTypes).parameters...]
253+
end
213254
else
214255
Core.eval(Main, quote
215256
interp = $interp
@@ -219,7 +260,7 @@ end
219260
error("couldn't find the source; inspect `Main.interp` and `Main.mi`")
220261
end
221262
end
222-
(codeinf, rt, infos, slottypes::Vector{Any})
263+
return codeinf, rt, infos, slottypes::Argtypes
223264
end
224265

225266
##

src/callsite.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ struct MICallInfo <: CallInfo
77
mi::MethodInstance
88
rt
99
function MICallInfo(mi::MethodInstance, @nospecialize(rt))
10-
if isa(rt, LimitedAccuracy)
10+
if @static IS_OVERHAULED ? isLimitedAccuracy(rt) : isa(rt, LimitedAccuracy)
1111
return LimitedCallInfo(new(mi, ignorelimited(rt)))
1212
else
1313
return new(mi, rt)
@@ -36,9 +36,9 @@ struct UncachedCallInfo <: WrappedCallInfo
3636
end
3737

3838
struct PureCallInfo <: CallInfo
39-
argtypes::Vector{Any}
39+
argtypes::Argtypes
4040
rt
41-
PureCallInfo(argtypes::Vector{Any}, @nospecialize(rt)) =
41+
PureCallInfo(argtypes::Argtypes, @nospecialize(rt)) =
4242
new(argtypes, rt)
4343
end
4444
get_mi(::PureCallInfo) = nothing
@@ -239,7 +239,7 @@ end
239239
function show_callinfo(limiter, ci::Union{MultiCallInfo, FailedCallInfo, GeneratedCallInfo})
240240
types = (ci.sig::DataType).parameters
241241
ft, tt... = types
242-
f = Compiler.singleton_type(ft)
242+
f = Compiler.singleton_type((@static IS_OVERHAULED ? LatticeElement : identity)(ft))
243243
if f !== nothing
244244
name = "$f"
245245
elseif ft isa Union

src/codeview.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,11 @@ function cthulhu_typed(io::IO, debuginfo::Symbol,
119119
if isa(src, Core.CodeInfo)
120120
# we're working on pre-optimization state, need to ignore `LimitedAccuracy`
121121
src = copy(src)
122-
src.ssavaluetypes = Base.mapany(ignorelimited, src.ssavaluetypes::Vector{Any})
122+
@static if IS_OVERHAULED
123+
src.ssavaluetypes = SSAValueTypes([ignorelimited(t) for t in src.ssavaluetypes::SSAValueTypes])
124+
else
125+
src.ssavaluetypes = Base.mapany(ignorelimited, src.ssavaluetypes::SSAValueTypes)
126+
end
123127
src.rettype = ignorelimited(src.rettype)
124128

125129
if src.slotnames !== nothing
@@ -144,8 +148,9 @@ function cthulhu_typed(io::IO, debuginfo::Symbol,
144148
isa(mi, MethodInstance) || throw("`mi::MethodInstance` is required")
145149
code = src isa IRCode ? src.stmts.inst : src.code
146150
cst = Vector{Int}(undef, length(code))
151+
sptypes = Core.Compiler.sptypes_from_meth_instance(mi)
147152
params = Core.Compiler.OptimizationParams(Core.Compiler.NativeInterpreter())
148-
maxcost = Core.Compiler.statement_costs!(cst, code, src, Any[mi.sparam_vals...], false, params)
153+
maxcost = Core.Compiler.statement_costs!(cst, code, src, sptypes, false, params)
149154
nd = ndigits(maxcost)
150155
_lineprinter = lineprinter(src)
151156
function preprinter(io, linestart, idx)
@@ -195,8 +200,9 @@ function show_variables(io, src, slotnames)
195200
slottypes = src.slottypes
196201
for i = 1:length(slotnames)
197202
print(io, " ", slotnames[i])
198-
if isa(slottypes, Vector{Any})
199-
InteractiveUtils.warntype_type_printer(io, slottypes[i], true)
203+
if isa(slottypes, Argtypes)
204+
typ = (@static IS_OVERHAULED ? Core.Compiler.unwraptype : identity)(slottypes[i])
205+
InteractiveUtils.warntype_type_printer(io, typ, true)
200206
end
201207
println(io)
202208
end

src/interpreter.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ end
9090
@static if isdefined(Compiler, :is_stmt_inline)
9191
function Compiler.inlining_policy(
9292
interp::CthulhuInterpreter, @nospecialize(src), stmt_flag::UInt8,
93-
mi::MethodInstance, argtypes::Vector{Any})
94-
@assert isa(src, OptimizedSource) || isnothing(src)
93+
mi::MethodInstance, argtypes::Argtypes)
94+
@assert isa(src, OptimizedSource) || isnothing(src) "`inlining_policy(::CthulhuInterpreter, ...)` got unexpected source: `$(typeof(src))`"
9595
if isa(src, OptimizedSource)
9696
if Compiler.is_stmt_inline(stmt_flag) || src.isinlineable
9797
return src.ir
@@ -100,7 +100,7 @@ function Compiler.inlining_policy(
100100
# the default inlining policy may try additional effor to find the source in a local cache
101101
return Base.@invoke Compiler.inlining_policy(
102102
interp::AbstractInterpreter, nothing, stmt_flag::UInt8,
103-
mi::MethodInstance, argtypes::Vector{Any})
103+
mi::MethodInstance, argtypes::Argtypes)
104104
end
105105
return nothing
106106
end

src/reflection.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@ function transform(::Val{:CuFunction}, callsite, callexpr, CI, mi, slottypes; pa
2323
sptypes = sptypes_from_meth_instance(mi)
2424
tt = argextype(callexpr.args[4], CI, sptypes, slottypes)
2525
ft = argextype(callexpr.args[3], CI, sptypes, slottypes)
26-
isa(tt, Const) || return callsite
26+
(@static IS_OVERHAULED ? isConst(tt) : isa(tt, Const)) || return callsite
2727
return Callsite(callsite.id, CuCallInfo(callinfo(Tuple{widenconst(ft), tt.val.parameters...}, Nothing, params)), callsite.head)
2828
end
2929

3030
function find_callsites(interp::CthulhuInterpreter, CI::Union{Core.CodeInfo, IRCode},
3131
stmt_info::Union{Vector, Nothing}, mi::Core.MethodInstance,
32-
slottypes::Vector{Any}, optimize::Bool=true;
32+
slottypes::Argtypes, optimize::Bool=true;
3333
params=current_params())
3434
sptypes = sptypes_from_meth_instance(mi)
3535
callsites = Callsite[]
@@ -47,25 +47,27 @@ function find_callsites(interp::CthulhuInterpreter, CI::Union{Core.CodeInfo, IRC
4747
if !optimize
4848
args = (ignorelhs(c)::Expr).args
4949
end
50-
types = mapany(function (@nospecialize(arg),)
51-
t = argextype(arg, CI, sptypes, slottypes)
52-
return widenconst(ignorelimited(t))
53-
end, args)
50+
argtypes = @static IS_OVERHAULED ?
51+
LatticeElement[argextype(arg, CI, sptypes, slottypes) for arg in args] :
52+
mapany(function (@nospecialize(arg),)
53+
t = argextype(arg, CI, sptypes, slottypes)
54+
return widenconst(ignorelimited(t))
55+
end, args)
5456
was_return_type = false
5557
if isa(info, Core.Compiler.ReturnTypeCallInfo)
5658
info = info.info
5759
was_return_type = true
5860
end
5961

60-
callinfos = process_info(interp, info, types, rt, optimize)
62+
callinfos = process_info(interp, info, argtypes, rt, optimize)
6163
if !was_return_type && isempty(callinfos)
6264
continue
6365
end
6466
callsite = let
6567
if length(callinfos) == 1
6668
callinfo = callinfos[1]
6769
else
68-
callinfo = MultiCallInfo(Compiler.argtypes_to_type(types), rt, callinfos)
70+
callinfo = MultiCallInfo(Compiler.argtypes_to_type(argtypes), rt, callinfos)
6971
end
7072
if was_return_type
7173
callinfo = ReturnTypeCallInfo(callinfo)
@@ -113,10 +115,9 @@ function find_callsites(interp::CthulhuInterpreter, CI::Union{Core.CodeInfo, IRC
113115
return callsites
114116
end
115117

116-
function process_info(interp, @nospecialize(info), types, @nospecialize(rt), optimize::Bool)
118+
function process_info(interp, @nospecialize(info), argtypes::Argtypes, @nospecialize(rt), optimize::Bool)
117119
is_cached(@nospecialize(key)) = haskey(optimize ? interp.opt : interp.unopt, key)
118-
process_recursive(@nospecialize(newinfo)) = process_info(interp, newinfo, types, rt, optimize)
119-
120+
process_recursive(@nospecialize(newinfo)) = process_info(interp, newinfo, argtypes, rt, optimize)
120121

121122
if isa(info, MethodMatchInfo)
122123
if info.results === missing
@@ -130,7 +131,7 @@ function process_info(interp, @nospecialize(info), types, @nospecialize(rt), opt
130131
return is_cached(mi) ? mici : UncachedCallInfo(mici)
131132
end
132133
elseif isa(info, MethodResultPure)
133-
return Any[PureCallInfo(types, rt)]
134+
return Any[PureCallInfo(argtypes, rt)]
134135
elseif isa(info, UnionSplitInfo)
135136
return mapreduce(process_recursive, vcat, info.matches; init=[])::Vector{Any}
136137
elseif isa(info, UnionSplitApplyCallInfo)

0 commit comments

Comments
 (0)