Skip to content

Commit a872d7c

Browse files
tlienartablaom
andauthored
Fixes for quantile regression (#148)
* fix a doc typo * fixes following discussion around #147 --------- Co-authored-by: Anthony D. Blaom <[email protected]>
1 parent d310ff8 commit a872d7c

File tree

13 files changed

+127
-39
lines changed

13 files changed

+127
-39
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ jobs:
4141
${{ runner.os }}-
4242
- uses: julia-actions/julia-buildpkg@v1
4343
- uses: julia-actions/julia-runtest@v1
44+
env:
45+
RUN_COMPARISONS: "false"
4446
- uses: julia-actions/julia-processcoverage@v1
4547
- uses: codecov/codecov-action@v1
4648
with:

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJLinearModels"
22
uuid = "6ee0df7b-362f-4a72-a706-9e79364fb692"
33
authors = ["Thibaut Lienart <[email protected]>"]
4-
version = "0.9.1"
4+
version = "0.9.2"
55

66
[deps]
77
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"

src/fit/solvers.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,14 @@ Newton solver. This is a full Hessian solver and should be avoided for
4646
`newton_options` are the [options of Newton's method](https://julianlsolvers.github.io/Optim.jl/stable/#algo/newton/)
4747
4848
## Example
49+
4950
```julia
5051
using MLJLinearModels, Optim
5152
52-
solver = MLJLinearModels.Newton(optim_options = Optim.Options(time_limit = 20),
53-
newton_options = (linesearch = Optim.LineSearches.HagerZhang()),))
53+
solver = MLJLinearModels.Newton(
54+
optim_options = Optim.Options(time_limit = 20),
55+
newton_options = (linesearch = Optim.LineSearches.HagerZhang()),)
56+
)
5457
```
5558
"""
5659
@with_kw struct Newton{O,S} <: Solver
@@ -70,13 +73,15 @@ generally be preferred for larger scale cases.
7073
`newtoncg_options` are the [options of Krylov Trust Region method](https://github.com/JuliaNLSolvers/Optim.jl/blob/master/src/multivariate/solvers/second_order/krylov_trust_region.jl)
7174
7275
## Example
76+
7377
```julia
7478
using MLJLinearModels, Optim
7579
76-
solver = MLJLinearModels.Newton(optim_options = Optim.Options(time_limit = 20),
77-
newtoncg_options = (eta = 0.2,))
80+
solver = MLJLinearModels.NewtonCG(
81+
optim_options = Optim.Options(time_limit = 20),
82+
newtoncg_options = (eta = 0.2,)
83+
)
7884
```
79-
8085
"""
8186
@with_kw struct NewtonCG{O,S} <: Solver
8287
optim_options::O = Optim.Options(f_tol=1e-4)
@@ -95,8 +100,10 @@ LBFGS quasi-Newton solver. See [the wikipedia entry](https://en.wikipedia.org/wi
95100
```julia
96101
using MLJLinearModels, Optim
97102
98-
solver = MLJLinearModels.Newton(optim_options = Optim.Options(time_limit = 20),
99-
lbfgs_options = (linesearch = Optim.LineSearches.HagerZhang()),))
103+
solver = MLJLinearModels.LBFGS(
104+
optim_options = Optim.Options(time_limit = 20),
105+
lbfgs_options = (linesearch = Optim.LineSearches.HagerZhang()),)
106+
)
100107
```
101108
"""
102109
@with_kw struct LBFGS{O,S} <: Solver

src/loss-penalty/robust.jl

Lines changed: 51 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,46 @@ export RobustLoss,
33
BisquareRho, Bisquare, LogisticRho, Logistic,
44
FairRho, Fair, TalwarRho, Talwar, QuantileRho, Quantile
55

6+
#=
7+
In the non-penalised case:
8+
9+
β⋆ = arg min ∑ ρ(yᵢ - ⟨xᵢ, β⟩)
10+
11+
where ρ is a weighing function such as, for instance, the pinball loss for
12+
the quantile regression.
13+
14+
It is useful to define the following quantities:
15+
16+
ψ(r) = ρ'(r) (first derivative)
17+
ϕ(r) = ψ'(r) (second derivative)
18+
ω(r) = ψ(r)/r (weighing function used for IWLS), a threshold can be passed
19+
to clip weights
20+
21+
Some refs:
22+
- https://josephsalmon.eu/enseignement/UW/STAT593/QuantileRegression.pdf
23+
=#
24+
625
abstract type RobustRho end
726

8-
abstract type RobustRho1P{δ} <: RobustRho end # one parameter
27+
# robust rho with only one parameter
28+
abstract type RobustRho1P{δ} <: RobustRho end
929

1030
struct RobustLoss{ρ} <: AtomicLoss where ρ <: RobustRho
1131
rho:
1232
end
1333

14-
(rl::RobustLoss)(x::AVR, y::AVR) = rl(x .- y)
15-
(rl::RobustLoss)(r::AVR) = rl.rho(r)
34+
(rl::RobustLoss)(::AVR, y::AVR) = rl(y .- )
35+
(rl::RobustLoss)(r::AVR) = rl.rho(r)
1636

17-
# ψ(r) = ρ'(r) (first derivative)
18-
# ω(r) = ψ(r)/r (weighing function) a thresh can be passed to clip weights
19-
# ϕ(r) = ψ'(r) (second derivative)
2037

2138
"""
2239
$TYPEDEF
2340
2441
Huber weighing of the residuals corresponding to
2542
2643
``ρ(z) = z²/2`` if `|z|≤δ` and `ρ(z)=δ(|z|-δ/2)` otherwise.
44+
45+
Note: symmetric weighing.
2746
"""
2847
struct HuberRho{δ} <: RobustRho1P{δ}
2948
HuberRho::Real=1.0; delta::Real=δ) = new{delta}()
@@ -33,7 +52,7 @@ Huber(δ::Real=1.0; delta::Real=δ) = HuberRho(delta)
3352
(::HuberRho{δ})(r::AVR) where δ = begin
3453
ar = abs.(r)
3554
w = ar .<= δ
36-
return sum( r.^2/2 .* w .+ δ .* (ar .- δ/2) .* .!w )
55+
return sum( @. ifelse(w, r^2/2, δ * (ar - δ/2) ) )
3756
end
3857

3958
ψ(::Type{HuberRho{δ}} ) where δ = (r, w) -> r * w + δ * sign(r) * (1.0 - w)
@@ -47,6 +66,8 @@ $TYPEDEF
4766
Andrews weighing of the residuals corresponding to
4867
4968
``ρ(z) = -cos(πz/δ)/(π/δ)²`` if `|z|≤δ` and `ρ(δ)` otherwise.
69+
70+
Note: symmetric weighing.
5071
"""
5172
struct AndrewsRho{δ} <: RobustRho1P{δ}
5273
AndrewsRho::Real=1.0; delta::Real=δ) = new{delta}()
@@ -58,7 +79,7 @@ Andrews(δ::Real=1.0; delta::Real=δ) = AndrewsRho(delta)
5879
w = ar .<= δ
5980
c = π/δ
6081
κ =/π)^2
61-
return sum( -cos.(c .* r) .* κ .* w .+ κ .* .!w )
82+
return sum( @. ifelse(w, -cos(c * r) * κ, κ) )
6283
end
6384

6485
# Note, sinc(x) = sin(πx)/πx, well defined everywhere
@@ -74,6 +95,8 @@ $TYPEDEF
7495
Bisquare weighing of the residuals corresponding to
7596
7697
``ρ(z) = δ²/6 (1-(1-(z/δ)²)³)`` if `|z|≤δ` and `δ²/6` otherwise.
98+
99+
Note: symmetric weighing.
77100
"""
78101
struct BisquareRho{δ} <: RobustRho1P{δ}
79102
BisquareRho::Real=1.0; delta::Real=δ) = new{delta}()
@@ -84,7 +107,7 @@ Bisquare(δ::Real=1.0; delta::Real=δ) = BisquareRho(delta)
84107
ar = abs.(r)
85108
w = ar .<= δ
86109
κ = δ^2/6
87-
return sum( κ * (1.0 .- (1.0 .- (r ./ δ).^2).^3) .* w + κ .* .!w )
110+
return sum( @. ifelse(w, κ * (1 - (1 - (r / δ)^2)^3), κ) )
88111
end
89112

90113
ψ(::Type{BisquareRho{δ}} ) where δ = (r, w) -> w * r * (1.0 - (r / δ)^2)^2
@@ -97,14 +120,16 @@ $TYPEDEF
97120
Logistic weighing of the residuals corresponding to
98121
99122
``ρ(z) = δ² log(cosh(z/δ))``
123+
124+
Note: symmetric weighing.
100125
"""
101126
struct LogisticRho{δ} <: RobustRho1P{δ}
102127
LogisticRho::Real=1.0; delta::Real=δ) = new{delta}()
103128
end
104129
Logistic::Real=1.0; delta::Real=δ) = LogisticRho(delta)
105130

106131
(::LogisticRho{δ})(r::AVR) where δ = begin
107-
return sum( δ^2 .* log.(cosh.(r ./ δ)) )
132+
return sum( @. δ^2 * log(cosh(r / δ)) )
108133
end
109134

110135
# similar to sinc, to avoid NaNs if tanh(0)/0 (lim is 1.0)
@@ -121,15 +146,17 @@ $TYPEDEF
121146
Fair weighing of the residuals corresponding to
122147
123148
``ρ(z) = δ² (|z|/δ - log(1+|z|/δ))``
149+
150+
Note: symmetric weighing.
124151
"""
125152
struct FairRho{δ} <: RobustRho1P{δ}
126153
FairRho::Real=1.0; delta::Real=δ) = new{delta}()
127154
end
128155
Fair::Real=1.0; delta::Real=δ) = FairRho(delta)
129156

130157
(::FairRho{δ})(r::AVR) where δ = begin
131-
sr = abs.(r) ./ δ
132-
return sum( δ^2 .* (sr .- log1p.(sr)) )
158+
sr = @. abs(r) / δ
159+
return sum( @. δ^2 * (sr - log1p(sr)) )
133160
end
134161

135162
ψ(::Type{FairRho{δ}} ) where δ = (r, _) -> δ * r / (abs(r) + δ)
@@ -143,15 +170,17 @@ $TYPEDEF
143170
Talwar weighing of the residuals corresponding to
144171
145172
``ρ(z) = z²/2`` if `|z|≤δ` and `ρ(z)=ρ(δ)` otherwise.
173+
174+
Note: symmetric weighing.
146175
"""
147176
struct TalwarRho{δ} <: RobustRho1P{δ}
148177
TalwarRho::Real=1.0; delta::Real=δ) = new{delta}()
149178
end
150179
Talwar::Real=1.0; delta::Real=δ) = TalwarRho(delta)
151180

152181
(::TalwarRho{δ})(r::AVR) where δ = begin
153-
w = abs.(r) .<= δ
154-
return sum( r.^2 ./ 2 .* w .+ δ^2/2 .* .!w)
182+
w = @. abs(r) <= δ
183+
return sum( @. ifelse(w, r^2 / 2, δ^2/2) )
155184
end
156185

157186
ψ(::Type{TalwarRho{δ}} ) where δ = (r, w) -> w * r
@@ -164,7 +193,11 @@ $TYPEDEF
164193
165194
Quantile regression weighing of the residuals corresponding to
166195
167-
``ρ(z) = z(δ - 1(z<0))``
196+
``ρ(z) = -z(δ - 1(z>=0))``
197+
198+
Note: asymetric weighing, the "-" sign is because similar libraries like
199+
quantreg for instance define the residual as `y-Xθ` while we do the opposite
200+
(out of convenience for gradients etc).
168201
"""
169202
struct QuantileRho{δ} <: RobustRho1P{δ}
170203
QuantileRho::Real=1.0; delta::Real=δ) = new{delta}()
@@ -173,9 +206,9 @@ end
173206
Quantile::Real=1.0; delta::Real=δ) = QuantileRho(delta)
174207

175208
(::QuantileRho{δ})(r::AVR) where δ = begin
176-
return sum( r .*.- (r .<= 0.0)) )
209+
return sum( @. -r *- (r >= 0)) )
177210
end
178211

179-
ψ(::Type{QuantileRho{δ}} ) where δ = (r, _) -> (δ - (r <= 0.0))
180-
ω(::Type{QuantileRho{δ}}, τ) where δ = (r, _) -> (δ - (r <= 0.0)) / clip(r, τ)
212+
ψ(::Type{QuantileRho{δ}} ) where δ = (r, _) -> ((r >= 0.0) - δ)
213+
ω(::Type{QuantileRho{δ}}, τ) where δ = (r, _) -> ((r >= 0.0) - δ) / clip(-r, τ)
181214
ϕ(::Type{QuantileRho{δ}} ) where δ = (_, _) -> error("Newton(CG) not available for Quantile Reg.")

src/mlj/regressors.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ See also [`ElasticNetRegressor`](@ref).
104104
"whether to scale the penalty with the number of observations."
105105
scale_penalty_with_samples::Bool = true
106106
"""any instance of `MLJLinearModels.Analytical`. Use `Analytical()` for
107-
Cholesky and `CG()=Analytical(iteration=true)` for conjugate-gradient.
107+
Cholesky and `CG()=Analytical(iterative=true)` for conjugate-gradient.
108108
If `solver = nothing` (default) then `Analytical()` is used. """
109109
solver::Option{Solver} = nothing
110110
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
23
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
34
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
45
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

test/benchmarks/elementary_functions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using MLJLinearModels
22
using BenchmarkTools, Random, LinearAlgebra
3-
DO_COMPARISONS = false; include("../testutils.jl")
3+
include("../testutils.jl")
44

55
n, p = 50_000, 500
66
((X, y, θ), (X1, y1, θ1)) = generate_continuous(n, p; seed=512, sparse=0.5)

test/benchmarks/logistic-multinomial.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using MLJLinearModels
22
using BenchmarkTools, Random, LinearAlgebra
3-
DO_COMPARISONS = false; include("../testutils.jl")
3+
include("../testutils.jl")
44

55
n, p = 50_000, 500
66
((X, y, θ), (X1, y1, θ1)) = generate_binary(n, p; seed=525)

test/benchmarks/ridge-lasso.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using MLJLinearModels, StableRNGs
22
using BenchmarkTools, Random, LinearAlgebra
3-
DO_COMPARISONS = false; include("../testutils.jl")
3+
include("../testutils.jl")
44

55
n, p = 50_000, 500
66
((X, y, θ), (X1, y1, θ1)) = generate_continuous(n, p; seed=512, sparse=0.5)

test/benchmarks/robust.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Follow up from issue #147 comparing quantreg more specifically.
2+
3+
if DO_COMPARISONS
4+
@testset "Comp-QR-147" begin
5+
using CSV, DataFrames
6+
7+
dataset = CSV.read(download("http://freakonometrics.free.fr/rent98_00.txt"), DataFrame)
8+
tau = 0.3
9+
10+
y = Vector(dataset[!,:rent_euro])
11+
X = Matrix(dataset[!,[:area, :yearc]])
12+
X1 = hcat(X[:,1], X[:, 2], ones(size(X, 1)))
13+
14+
qr = QuantileRegression(tau; penalty=:none)
15+
obj = objective(qr, X, y)
16+
17+
θ_lbfgs = fit(qr, X, y)
18+
@test isapprox(obj(θ_lbfgs), 226_639, rtol=1e-4)
19+
20+
# in this case QR with BR method does better
21+
θ_qr_br = rcopy(getproperty(QUANTREG, :rq_fit_br)(X1, y; tau=tau))[:coefficients]
22+
@test isapprox(obj(θ_qr_br), 207_551, rtol=1e-4)
23+
24+
# lasso doesn't
25+
θ_qr_lasso = rcopy(getproperty(QUANTREG, :rq_fit_lasso)(X1, y; tau=tau))[:coefficients]
26+
obj(θ_qr_lasso) # 229_172
27+
@test 228_000 obj(θ_qr_lasso) 231_000
28+
end
29+
end

test/fit/quantile.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,24 +46,35 @@ y1a = outlify(y1, 0.1)
4646
θ_ls = fit(LinearRegression(), X, y1a)
4747
θ_lbfgs = fit(rr, X, y1a, solver=LBFGS())
4848
θ_iwls = fit(rr, X, y1a, solver=IWLSCG())
49-
θ_qr_br = rcopy(QUANTREG.rq_fit_br(X1, y1a))[:coefficients]
50-
θ_qr_fnb = rcopy(QUANTREG.rq_fit_fnb(X1, y1a))[:coefficients]
49+
θ_qr_br = rcopy(QUANTREG.rq_fit_br(X1, y1a, tau=δ))[:coefficients]
50+
θ_qr_fnb = rcopy(QUANTREG.rq_fit_fnb(X1, y1a, tau=δ))[:coefficients]
5151
# NOTE: we take θ_qr_br as reference point
5252
@test isapprox(J(θ_ls), 505.45286, rtol=1e-5)
5353
@test J(θ_qr_br) 409.570777 # <- ref value
5454
# Their IP algorithm essentially gives the same answer
5555
@test (J(θ_qr_fnb) - J(θ_qr_br)) 1e-10
5656
# Our algorithms are close enough
57-
@test isapprox(J(θ_lbfgs), 409.57154, rtol=1e-5)
57+
@test isapprox(J(θ_lbfgs), 409.57608, rtol=1e-5)
5858
@test isapprox(J(θ_iwls), 409.59, rtol=1e-4)
59+
60+
# Let's try this again but with a δ different from 0.5
61+
δ = 0.75
62+
rr = QuantileRegression(δ, lambda=0)
63+
J = objective(rr, X, y1a)
64+
65+
θ_lbfgs = fit(rr, X, y1a, solver=LBFGS())
66+
θ_qr_br = rcopy(QUANTREG.rq_fit_br(X1, y1a, tau=δ))[:coefficients]
67+
68+
@test isapprox(J(θ_qr_br), 404.6161, rtol=1e-4)
69+
@test isapprox(J(θ_lbfgs), 404.6195, rtol=1e-4)
5970
end
6071
end
6172

6273
###########################
6374
## With Sparsity penalty ##
6475
###########################
6576

66-
n, p = 500, 100
77+
Jn, p = 500, 100
6778
((X, y, θ), (X1, y1, θ1)) = generate_continuous(n, p; seed=51112, sparse=0.1)
6879
# pepper with outliers
6980
y1a = outlify(y1, 0.1)
@@ -90,9 +101,9 @@ y1a = outlify(y1, 0.1)
90101
θ_ls = X1 \ y1a
91102
θ_fista = fit(rr, X, y1a, solver=FISTA())
92103
θ_ista = fit(rr, X, y1a, solver=ISTA())
93-
θ_qr_lasso = rcopy(QUANTREG.rq_fit_lasso(X1, y1a))[:coefficients]
104+
θ_qr_lasso = rcopy(QUANTREG.rq_fit_lasso(X1, y1a, lambda=λ))[:coefficients]
94105
@test isapprox(J(θ_ls), 888.3748, rtol=1e-5)
95-
@test isapprox(J(θ_qr_lasso), 425.5, rtol=1e-3)
106+
@test isapprox(J(θ_qr_lasso), 425.2, rtol=1e-3)
96107
# Our algorithms are close enough
97108
@test isapprox(J(θ_fista), 425.0526, rtol=1e-5)
98109
@test isapprox(J(θ_ista), 425.4113, rtol=1e-5)
@@ -101,6 +112,6 @@ y1a = outlify(y1, 0.1)
101112
@test nnz(θ_fista) == 88
102113
@test nnz(θ_ista) == 82
103114
# in this case fista is best
104-
@test J(θ_fista) < J(θ_ista) < J(θ_qr_lasso)
115+
@test J(θ_fista) < J(θ_qr_lasso) < J(θ_ista)
105116
end
106117
end

test/loss-penalty/robust.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ y = randn(rng, 10)
55
r = x .- y
66

77
@testset "Robust Loss" begin
8-
δ = 0.5
8+
δ = 0.75
99
rlδ = RobustLoss(Huber(δ))
1010
@test rlδ isa RobustLoss{HuberRho{δ}}
1111
@test rlδ(r) == rlδ(x, y) == sum(ifelse(abs(rᵢ)δ, rᵢ^2/2, δ*(abs(rᵢ)-δ/2)) for rᵢ in r)

0 commit comments

Comments
 (0)