Skip to content

Commit 7d42ab2

Browse files
authored
**BREAKING** change of sampleplot behaviour and defaults (#213)
This PR makes two breaking changes to the `sampleplot` plotting utility: 1. **Most default settings have been removed.** - `sampleplot` no longer imposes a default seriescolor. **How to obtain the previous color:** manually pass `seriescolor = "red"`. - The markers have been removed as they seemed more confusing than helpful (overwhelmingly large in the legend, almost invisibly tiny on the plot). The `linealpha` default has been increased to 0.35 to keep visual appearance the same. **How to obtain the previous markers:** If you actually liked the tiny markers, [check the diff](https://github.com/JuliaGaussianProcesses/AbstractGPs.jl/pull/213/files#diff-59728b925dcbfab15dd2def5624b3bd8b3a6a609570dc56379f69c8ee0bdef92L124-L127) for the previous default settings that you now need to pass manually. 2. **Multiple samples (`samples` kwarg > 1) are now plotted as a single series.** This means, for example, that passing a string as `label` only adds one legend entry for all samples, not one _per_ sample. **How to obtain previous behaviour:** If you actually want each sample to be a different series (for example, to give them different colors), you now need to call `sampleplot` multiple times instead. Additionally, this PR contains the following non-breaking minor changes: - minor fix: error message for check of `ribbon_scale` in `plot` recipe - internal code cleanup (`pop!` for custom-defined kwarg)
1 parent 4423ebc commit 7d42ab2

File tree

7 files changed

+50
-60
lines changed

7 files changed

+50
-60
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "AbstractGPs"
22
uuid = "99985d1d-32ba-4be9-9821-2ec096f28918"
33
authors = ["JuliaGaussianProcesses Team"]
4-
version = "0.4.0"
4+
version = "0.5.0"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,5 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
44
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
55

66
[compat]
7-
AbstractGPs = "0.4"
7+
AbstractGPs = "0.4, 0.5"
88
Documenter = "0.27"

examples/regression-1d/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1313
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1414

1515
[compat]
16-
AbstractGPs = "0.4"
16+
AbstractGPs = "0.4, 0.5"
1717
AdvancedHMC = "0.2"
1818
Distributions = "0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25"
1919
DynamicHMC = "2.2, 3.1"

examples/regression-1d/script.jl

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -212,19 +212,19 @@ mean(logpdf(gp_posterior(x_train, y_train, p)(x_test), y_test) for p in samples)
212212
# We sample 5 functions from each posterior GP given by the final 100 samples of kernel
213213
# parameters.
214214

215-
plt = scatter(
216-
x_train,
217-
y_train;
218-
xlim=(0, 1),
219-
xlabel="x",
220-
ylabel="y",
221-
title="posterior (AdvancedHMC)",
222-
label="Train Data",
223-
)
224-
for p in samples[(end - 100):end]
225-
sampleplot!(plt, 0:0.02:1, gp_posterior(x_train, y_train, p); samples=5)
215+
plt = plot(; xlim=(0, 1), xlabel="x", ylabel="y", title="posterior (AdvancedHMC)")
216+
for (i, p) in enumerate(samples[(end - 100):end])
217+
sampleplot!(
218+
plt,
219+
0:0.02:1,
220+
gp_posterior(x_train, y_train, p);
221+
samples=5,
222+
seriescolor="red",
223+
label=(i == 1 ? "samples" : nothing),
224+
)
226225
end
227-
scatter!(plt, x_test, y_test; label="Test Data")
226+
scatter!(plt, x_train, y_train; label="Train Data", markercolor=1)
227+
scatter!(plt, x_test, y_test; label="Test Data", markercolor=2)
228228
plt
229229

230230
# #### DynamicHMC
@@ -290,18 +290,11 @@ mean(logpdf(gp_posterior(x_train, y_train, p)(x_test), y_test) for p in samples)
290290
# We sample a function from the posterior GP for the final 100 samples of kernel
291291
# parameters.
292292

293-
plt = scatter(
294-
x_train,
295-
y_train;
296-
xlim=(0, 1),
297-
xlabel="x",
298-
ylabel="y",
299-
title="posterior (DynamicHMC)",
300-
label="Train Data",
301-
)
293+
plt = plot(; xlim=(0, 1), xlabel="x", ylabel="y", title="posterior (DynamicHMC)")
294+
scatter!(plt, x_train, y_train; label="Train Data")
302295
scatter!(plt, x_test, y_test; label="Test Data")
303296
for p in samples[(end - 100):end]
304-
sampleplot!(plt, 0:0.02:1, gp_posterior(x_train, y_train, p))
297+
sampleplot!(plt, 0:0.02:1, gp_posterior(x_train, y_train, p); seriescolor="red")
305298
end
306299
plt
307300

@@ -349,18 +342,13 @@ mean(logpdf(gp_posterior(x_train, y_train, p)(x_test), y_test) for p in samples)
349342
# We sample a function from the posterior GP for the final 100 samples of kernel
350343
# parameters.
351344

352-
plt = scatter(
353-
x_train,
354-
y_train;
355-
xlim=(0, 1),
356-
xlabel="x",
357-
ylabel="y",
358-
title="posterior (EllipticalSliceSampling)",
359-
label="Train Data",
345+
plt = plot(;
346+
xlim=(0, 1), xlabel="x", ylabel="y", title="posterior (EllipticalSliceSampling)"
360347
)
348+
scatter!(plt, x_train, y_train; label="Train Data")
361349
scatter!(plt, x_test, y_test; label="Test Data")
362350
for p in samples[(end - 100):end]
363-
sampleplot!(plt, 0:0.02:1, gp_posterior(x_train, y_train, p))
351+
sampleplot!(plt, 0:0.02:1, gp_posterior(x_train, y_train, p); seriescolor="red")
364352
end
365353
plt
366354

src/util/plotting.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
length(x) == length(gp.x) ||
55
throw(DimensionMismatch("length of `x` and `gp.x` has to be equal"))
66
scale::Float64 = pop!(plotattributes, :ribbon_scale, 1.0)
7-
scale > 0.0 || error("`bandwidth` keyword argument must be non-negative")
7+
scale >= 0.0 || error("`ribbon_scale` keyword argument must be non-negative")
88

99
# compute marginals
1010
μ, σ2 = mean_and_var(gp)
@@ -82,16 +82,19 @@ Plot samples from the projection `f` of a Gaussian process versus `x`.
8282
Make sure to load [Plots.jl](https://github.com/JuliaPlots/Plots.jl) before you use
8383
this function.
8484
85+
When plotting multiple samples, these are treated as a _single_ series (i.e.,
86+
only a single entry will be added to the legend when providing a `label`).
87+
8588
# Example
8689
8790
```julia
8891
using Plots
8992
9093
gp = GP(SqExponentialKernel())
91-
sampleplot(gp(rand(5)); samples=10, markersize=5)
94+
sampleplot(gp(rand(5)); samples=10, linealpha=1.0)
9295
```
93-
The given example plots 10 samples from the projection of the GP `gp`. The `markersize` is modified
94-
from default of 0.5 to 5.
96+
The given example plots 10 samples from the projection of the GP `gp`.
97+
The `linealpha` is modified from default of 0.35 to 1.
9598
9699
---
97100
sampleplot(x::AbstractVector, gp::AbstractGP; samples=1, kwargs...)
@@ -115,18 +118,15 @@ SamplePlot((f,)::Tuple{<:FiniteGP}) = SamplePlot((f.x, f))
115118
SamplePlot((x, gp)::Tuple{<:AbstractVector,<:AbstractGP}) = SamplePlot((gp(x, 1e-9),))
116119

117120
@recipe function f(sp::SamplePlot)
118-
nsamples::Int = get(plotattributes, :samples, 1)
121+
nsamples::Int = pop!(plotattributes, :samples, 1)
119122
samples = rand(sp.f, nsamples)
120123

124+
flat_x = repeat(vcat(sp.x, NaN), nsamples)
125+
flat_f = vec(vcat(samples, fill(NaN, 1, nsamples)))
126+
121127
# Set default attributes
122-
seriestype --> :line
123-
linealpha --> 0.2
124-
markershape --> :circle
125-
markerstrokewidth --> 0.0
126-
markersize --> 0.5
127-
markeralpha --> 0.3
128-
seriescolor --> "red"
128+
linealpha --> 0.35
129129
label --> ""
130130

131-
return sp.x, samples
131+
return flat_x, flat_f
132132
end

test/deprecations.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
gp = f(x, 0.1)
55

66
plt = @test_deprecated sampleplot(gp, 10)
7-
@test plt.n == 10
7+
@test plt.n == 1
88

99
@test_deprecated sampleplot!(gp, 4)
10-
@test plt.n == 14
10+
@test plt.n == 2
1111

1212
@test_deprecated sampleplot!(Plots.current(), gp, 3)
13-
@test plt.n == 17
13+
@test plt.n == 3
1414
end

test/util/plotting.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,20 @@
66
z = rand(10)
77
plt1 = sampleplot(z, gp)
88
@test plt1.n == 1
9-
@test plt1.series_list[1].plotattributes[:x] == sort(z)
9+
@test isequal(plt1.series_list[1].plotattributes[:x], vcat(z, NaN))
1010

11-
plt2 = sampleplot(gp; samples=10)
12-
@test plt2.n == 10
13-
sort_x = sort(x)
14-
@test all(series.plotattributes[:x] == sort_x for series in plt2.series_list)
11+
plt2 = sampleplot(gp; samples=3)
12+
@test plt2.n == 1
13+
plt2_x = plt2.series_list[1].plotattributes[:x]
14+
plt2_y = plt2.series_list[1].plotattributes[:y]
15+
@test isequal(plt2_x, vcat(x, NaN, x, NaN, x, NaN))
16+
@test length(plt2_y) == length(plt2_x)
17+
@test isnan(plt2_y[length(z) + 1]) && isnan(plt2_y[2length(z) + 2])
1518

16-
z = rand(7)
17-
plt3 = sampleplot(z, f; samples=8)
18-
@test plt3.n == 8
19-
sort_z = sort(z)
20-
@test all(series.plotattributes[:x] == sort_z for series in plt3.series_list)
19+
z3 = rand(7)
20+
plt3 = sampleplot(z3, f; samples=2)
21+
@test plt3.n == 1
22+
@test isequal(plt3.series_list[1].plotattributes[:x], vcat(z3, NaN, z3, NaN))
2123

2224
# Check recipe dispatches for `FiniteGP`s
2325
rec = RecipesBase.apply_recipe(Dict{Symbol,Any}(), gp)

0 commit comments

Comments
 (0)