Skip to content

Commit c2f93bf

Browse files
github-actions[bot]CompatHelper Juliasethaxen
authored
CompatHelper: bump compat for PosteriorStats to 0.4, (keep existing compat) (#50)
* CompatHelper: bump compat for PosteriorStats to 0.4, (keep existing compat) * Only define and test WAIC conversions if it exists * Increment patch number --------- Co-authored-by: CompatHelper Julia <[email protected]> Co-authored-by: Seth Axen <[email protected]>
1 parent 27ec12b commit c2f93bf

File tree

3 files changed

+31
-28
lines changed

3 files changed

+31
-28
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ArviZPythonPlots"
22
uuid = "4a6e88f0-2c8e-11ee-0601-e94153f0eada"
33
authors = ["Seth Axen <[email protected]>"]
4-
version = "0.1.11"
4+
version = "0.1.12"
55

66
[deps]
77
CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
@@ -22,7 +22,7 @@ DimensionalData = "0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29"
2222
InferenceObjects = "0.4.13"
2323
Markdown = "1"
2424
OrderedCollections = "1"
25-
PosteriorStats = "0.3"
25+
PosteriorStats = "0.3, 0.4"
2626
PythonCall = "0.9"
2727
PythonPlot = "1"
2828
Random = "1"

src/conversions.jl

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,24 +20,26 @@ function PythonCall.Py(d::PSISLOOResult)
2020
return arviz.stats.ELPDData(; data, index)
2121
end
2222

23-
function PythonCall.Py(d::WAICResult)
24-
estimates = elpd_estimates(d)
25-
pointwise = elpd_estimates(d; pointwise=true)
26-
ds = convert_to_dataset((waic_i=pointwise.elpd,))
27-
pyds = PythonCall.Py(ds)
28-
entries = (
29-
elpd_waic=estimates.elpd,
30-
se=estimates.se_elpd,
31-
p_waic=estimates.p,
32-
n_samples="unknown",
33-
n_data_points=length(pointwise.elpd),
34-
warning=false,
35-
waic_i=pyds.waic_i,
36-
scale="log",
37-
)
38-
data = pylist(values(entries))
39-
index = pylist(map(pystr, keys(entries)))
40-
return arviz.stats.ELPDData(; data, index)
23+
@static if isdefined(PosteriorStats, :WAICResult)
24+
function PythonCall.Py(d::WAICResult)
25+
estimates = elpd_estimates(d)
26+
pointwise = elpd_estimates(d; pointwise=true)
27+
ds = convert_to_dataset((waic_i=pointwise.elpd,))
28+
pyds = PythonCall.Py(ds)
29+
entries = (
30+
elpd_waic=estimates.elpd,
31+
se=estimates.se_elpd,
32+
p_waic=estimates.p,
33+
n_samples="unknown",
34+
n_data_points=length(pointwise.elpd),
35+
warning=false,
36+
waic_i=pyds.waic_i,
37+
scale="log",
38+
)
39+
data = pylist(values(entries))
40+
index = pylist(map(pystr, keys(entries)))
41+
return arviz.stats.ELPDData(; data, index)
42+
end
4143
end
4244

4345
function rekey(nt::NamedTuple, old_new_keys::Pair...)
@@ -48,14 +50,16 @@ end
4850
function PythonCall.Py(mc::ModelComparisonResult)
4951
table = Tables.columntable(mc)
5052
se_pairs = (:se_elpd => :se, :se_elpd_diff => :dse)
51-
est_pairs = if eltype(mc.elpd_result) <: PSISLOOResult
52-
(:elpd => :elpd_loo, :p => :p_loo)
53-
elseif eltype(mc.elpd_result) <: WAICResult
54-
(:elpd => :elpd_waic, :p => :p_waic)
55-
end
53+
est_pairs = _estimate_name_map(eltype(mc.elpd_result))
5654
nrows = Tables.rowcount(table)
5755
new_cols = (warning=fill(false, nrows), scale=fill("log", nrows))
5856
table_new = merge(rekey(table, est_pairs..., se_pairs...), new_cols)
5957
pdf = topandas(Val(:DataFrame), table_new; index_name="name")
6058
return pdf
6159
end
60+
61+
_estimate_name_map(::Type{<:PSISLOOResult}) = (:elpd => :elpd_loo, :p => :p_loo)
62+
63+
@static if isdefined(PosteriorStats, :WAICResult)
64+
_estimate_name_map(::Type{<:WAICResult}) = (:elpd => :elpd_waic, :p => :p_waic)
65+
end

test/test_conversions.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ using Test
2525
@test pyconvert(Array{Float64}, py_loo_result.pareto_k.values)
2626
pyconvert(Array{Float64}, loo_py_result.pareto_k.values) rtol = 1e-1
2727
end
28-
@testset "WAICResult" begin
28+
isdefined(PosteriorStats, :WAICResult) && @testset "WAICResult" begin
2929
idata = load_example_data("centered_eight")
3030
waic_result = waic(idata)
3131
waic_py_result = ArviZPythonPlots.arviz.waic(idata; pointwise=true)
@@ -36,8 +36,7 @@ using Test
3636
)
3737
@test pyconvert(Float64, py_waic_result.elpd_waic)
3838
pyconvert(Float64, waic_py_result.elpd_waic) rtol = 1e-3
39-
@test pyconvert(Float64, py_waic_result.se) pyconvert(Float64, waic_py_result.se) rtol =
40-
1e-1
39+
@test pyconvert(Float64, py_waic_result.se) pyconvert(Float64, waic_py_result.se) rtol = 1e-1
4140
@test pyconvert(Float64, py_waic_result.p_waic)
4241
pyconvert(Float64, waic_py_result.p_waic) rtol = 1e-3
4342
@test pyconvert(Array{Float64}, py_waic_result.waic_i.values)

0 commit comments

Comments
 (0)