Skip to content
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
138 commits
Select commit Hold shift + click to select a range
d1a3f2c
done for feature generator
pat-alt Nov 8, 2023
5a20ed2
formatting
pat-alt Nov 8, 2023
f1b9f44
Simplify methods
RaunoArike Nov 23, 2023
1dc1876
Move the variance kwarg
RaunoArike Nov 24, 2023
961f2fe
Small fixes
RaunoArike Nov 24, 2023
1678158
Merge branch 'main' into 358-simplify-the-core-strct
RaunoArike Nov 24, 2023
2c3ca7d
Move files out of generate_counterfactual
RaunoArike Nov 24, 2023
906fcce
Merge branch '358-simplify-the-core-strct' of https://github.com/Juli…
RaunoArike Nov 24, 2023
abe3fb0
Add further improvements
RaunoArike Nov 25, 2023
7299479
Small fixes
RaunoArike Nov 25, 2023
d6dcc90
Merge branch 'main' into 343-overload-the-update-method-preferred-if-…
RaunoArike Nov 25, 2023
463b7d7
Minor fixes
RaunoArike Nov 25, 2023
4fbf2c8
Small fixes
RaunoArike Nov 25, 2023
3cd5ae4
Create structs for convergence types
RaunoArike Nov 27, 2023
e516cc6
Fix formatting
RaunoArike Nov 27, 2023
24b6e79
Minor fix
RaunoArike Nov 27, 2023
59400a2
Merge branch 'main' into 358-simplify-the-core-strct
RaunoArike Nov 27, 2023
cef0524
Fix benchmarking of feature tweak
RaunoArike Nov 27, 2023
d069ffa
Add error test for feature tweak
RaunoArike Nov 27, 2023
7511d68
Rewrite convergence logic
RaunoArike Nov 28, 2023
6bbc668
Small fix
RaunoArike Nov 28, 2023
9638643
Fix bugs
RaunoArike Nov 28, 2023
9a6684a
Update src/convergence/invalidation_rate.jl
RaunoArike Nov 28, 2023
cf9e2c6
Update src/convergence/decision_threshold.jl
RaunoArike Nov 28, 2023
df71519
Minor changes
RaunoArike Nov 28, 2023
403b9a7
Merge branch '358-simplify-the-core-strct' of https://github.com/Juli…
RaunoArike Nov 28, 2023
0d649e3
Merge branch 'main' into 343-overload-the-update-method-preferred-if-…
RaunoArike Nov 28, 2023
93184bf
Minor fix
RaunoArike Nov 28, 2023
1d798d6
Bugfix
RaunoArike Nov 29, 2023
af70576
Merge branch '343-overload-the-update-method-preferred-if-feasible' i…
RaunoArike Nov 29, 2023
2a1fe2a
Small fix
RaunoArike Nov 29, 2023
9e44507
Small fixes
RaunoArike Nov 29, 2023
7f4e086
Small fixes
RaunoArike Nov 29, 2023
73f5161
Merge branch '343-overload-the-update-method-preferred-if-feasible' i…
RaunoArike Nov 29, 2023
491f086
Fix failing tests
RaunoArike Nov 29, 2023
618f89e
Fix formatting
RaunoArike Nov 29, 2023
40cca87
Merge branch 'main' into 358-simplify-the-core-strct
RaunoArike Nov 29, 2023
1a717d2
Minor fixes
RaunoArike Nov 30, 2023
e6f234c
Small improvements
RaunoArike Nov 30, 2023
45d6a99
Update src/convergence/invalidation_rate.jl
RaunoArike Nov 30, 2023
6a42b10
Update src/generators/gradient_based/loss.jl
RaunoArike Nov 30, 2023
b6e58f5
Minor additions
RaunoArike Nov 30, 2023
36227c2
Resolve circular dependency
RaunoArike Nov 30, 2023
e99f69c
Merge branch '358-simplify-the-core-strct' of https://github.com/Juli…
RaunoArike Nov 30, 2023
49fa6b2
Fix failing tests
RaunoArike Nov 30, 2023
9be1220
Update test/generators/probe.jl
RaunoArike Dec 1, 2023
1dfacf0
Add missing method
RaunoArike Dec 1, 2023
a47daff
Merge branch '358-simplify-the-core-strct' of https://github.com/Juli…
RaunoArike Dec 1, 2023
2648df3
Small fix
RaunoArike Dec 1, 2023
5a84574
Small fix
RaunoArike Dec 1, 2023
08cd3c4
Merge branch 'main' into 358-simplify-the-core-strct
RaunoArike Dec 4, 2023
bf75071
Remove redundant generate_perturbations method
RaunoArike Dec 5, 2023
cb37729
small addition to doc string
pat-alt Dec 6, 2023
1f0f904
Minor fixes
RaunoArike Dec 6, 2023
0f2560e
Merge branch '358-simplify-the-core-strct' of https://github.com/Juli…
RaunoArike Dec 6, 2023
436baea
maybe this way?
pat-alt Dec 6, 2023
5add2a2
go
pat-alt Dec 6, 2023
27433ea
huh
pat-alt Dec 6, 2023
bd557ea
tryin to errors
pat-alt Dec 6, 2023
b77e8a8
Update test/generators/probe.jl
RaunoArike Dec 6, 2023
e6fcaa0
Update src/generators/gradient_based/probe.jl
RaunoArike Dec 6, 2023
14efbf0
error with probe
pat-alt Dec 6, 2023
1adddb7
Fix comment
RaunoArike Dec 6, 2023
b84a381
Merge branch '358-simplify-the-core-strct' of https://github.com/Juli…
RaunoArike Dec 6, 2023
db2962b
uff
pat-alt Dec 6, 2023
3ba24fb
returning warning in case InvalidationRateConvergence not used with P…
pat-alt Dec 6, 2023
aad6c08
bleh
pat-alt Dec 6, 2023
36607e5
tryin with a try catch
pat-alt Dec 7, 2023
a1d4066
trying again
pat-alt Dec 7, 2023
cde70ae
namespacing issue
pat-alt Dec 7, 2023
e7835da
another attempt without try/catch
pat-alt Dec 7, 2023
ba74aca
small fix
pat-alt Dec 7, 2023
f1539cb
hm
pat-alt Dec 7, 2023
3c5cc82
uh
pat-alt Dec 7, 2023
6d0c869
uh
pat-alt Dec 7, 2023
bb79e6d
s
pat-alt Dec 7, 2023
4efb266
bam
pat-alt Dec 7, 2023
ab91387
come on
pat-alt Dec 7, 2023
0150597
omg
pat-alt Dec 7, 2023
df3e2f2
bloody hell
pat-alt Dec 7, 2023
f8b72de
shoot me aljeblieft
pat-alt Dec 7, 2023
ff833ee
as if
pat-alt Dec 7, 2023
10f94dd
now pls
pat-alt Dec 7, 2023
addbbec
maybe this way?
pat-alt Dec 6, 2023
80f23d2
go
pat-alt Dec 6, 2023
1abc592
huh
pat-alt Dec 6, 2023
ed30593
tryin to errors
pat-alt Dec 6, 2023
a41e1ce
Update test/generators/probe.jl
RaunoArike Dec 6, 2023
893ff67
Update src/generators/gradient_based/probe.jl
RaunoArike Dec 6, 2023
96968ae
error with probe
pat-alt Dec 6, 2023
a7a1211
uff
pat-alt Dec 6, 2023
b4986c7
returning warning in case InvalidationRateConvergence not used with P…
pat-alt Dec 6, 2023
338167d
bleh
pat-alt Dec 6, 2023
6c23034
tryin with a try catch
pat-alt Dec 7, 2023
d95841c
trying again
pat-alt Dec 7, 2023
fdd4759
namespacing issue
pat-alt Dec 7, 2023
4344cb4
another attempt without try/catch
pat-alt Dec 7, 2023
e86c613
small fix
pat-alt Dec 7, 2023
c1ef286
hm
pat-alt Dec 7, 2023
612fb94
uh
pat-alt Dec 7, 2023
16844b7
uh
pat-alt Dec 7, 2023
c61765d
s
pat-alt Dec 7, 2023
bd0f42c
bam
pat-alt Dec 7, 2023
56d7279
come on
pat-alt Dec 7, 2023
a8c8139
omg
pat-alt Dec 7, 2023
903033e
bloody hell
pat-alt Dec 7, 2023
a07455d
shoot me aljeblieft
pat-alt Dec 7, 2023
d07cdad
as if
pat-alt Dec 7, 2023
f58db86
now pls
pat-alt Dec 7, 2023
196004b
Merge branch '358-simplify-the-core-strct' of https://github.com/Juli…
pat-alt Dec 8, 2023
c159f1a
Revert "now pls"
pat-alt Dec 8, 2023
e7dcc80
Revert "as if"
pat-alt Dec 8, 2023
338c360
Revert "shoot me aljeblieft"
pat-alt Dec 8, 2023
6d9234c
Revert "bloody hell"
pat-alt Dec 8, 2023
a8ac346
Revert "omg"
pat-alt Dec 8, 2023
ea14380
Revert "come on"
pat-alt Dec 8, 2023
83a6789
Revert "bam"
pat-alt Dec 8, 2023
785520f
Revert "s"
pat-alt Dec 8, 2023
bcab248
Revert "uh"
pat-alt Dec 8, 2023
f2263f3
Revert "uh"
pat-alt Dec 8, 2023
86a8ee5
Revert "hm"
pat-alt Dec 8, 2023
a303056
Revert "small fix"
pat-alt Dec 8, 2023
93fefc9
Revert "another attempt without try/catch"
pat-alt Dec 8, 2023
9af1968
Revert "namespacing issue"
pat-alt Dec 8, 2023
1e248b1
Revert "trying again"
pat-alt Dec 8, 2023
cc3ceb6
Revert "tryin with a try catch"
pat-alt Dec 8, 2023
bf4f6fa
Revert "bleh"
pat-alt Dec 8, 2023
263b886
Revert "returning warning in case InvalidationRateConvergence not use…
pat-alt Dec 8, 2023
a4b5945
Revert "uff"
pat-alt Dec 8, 2023
83a3500
Revert "error with probe"
pat-alt Dec 8, 2023
ae3c599
Revert "Update src/generators/gradient_based/probe.jl"
pat-alt Dec 8, 2023
e2112c4
Revert "Update test/generators/probe.jl"
pat-alt Dec 8, 2023
99e71f9
Revert "tryin to errors"
pat-alt Dec 8, 2023
e3b84b7
Revert "huh"
pat-alt Dec 8, 2023
c8b68ea
Revert "go"
pat-alt Dec 8, 2023
7e1adc0
Revert "maybe this way?"
pat-alt Dec 8, 2023
0067cf5
unmerged file
pat-alt Dec 8, 2023
661fa7e
Revert "unmerged file"
pat-alt Dec 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/CounterfactualExplanations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ export Generator, @objective, @threshold
###
include("counterfactuals/Counterfactuals.jl")
export CounterfactualExplanation
export generate_counterfactual
export initialize!, update!
export total_steps, converged, terminated, path, target_probs
export animate_path
Expand All @@ -93,9 +94,6 @@ export animate_path
include("data/Data.jl")
using .Data

include("generate_counterfactual/generate_counterfactual.jl")
export generate_counterfactual

include("evaluation/Evaluation.jl")
using .Evaluation

Expand Down
3 changes: 3 additions & 0 deletions src/counterfactuals/Counterfactuals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,19 @@ using .Models
using ChainRulesCore
using Flux
using MLUtils
using MultivariateStats
using Statistics
using StatsBase

include("core_struct.jl")
include("convergence.jl")
include("encodings.jl")
include("generate_counterfactual.jl")
include("info_extraction.jl")
include("initialisation.jl")
include("latent_space_mappings.jl")
include("path_tracking.jl")
include("printing.jl")
include("search.jl")
include("utils.jl")
include("vectorised.jl")
87 changes: 57 additions & 30 deletions src/counterfactuals/convergence.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,47 +4,73 @@
A convenience method to determine if the counterfactual search has terminated.
"""
function terminated(ce::CounterfactualExplanation)
if ce.M isa Models.TreeModel
return in_target_class(ce)
end
return converged(ce) || steps_exhausted(ce)
end

"""
in_target_class(ce::CounterfactualExplanation)
converged(ce::CounterfactualExplanation)

Check if the counterfactual is in the target class.
A convenience method to determine if the counterfactual search has converged.
The search is considered to have converged only if the counterfactual is valid.
"""
function in_target_class(ce::CounterfactualExplanation)
return Models.predict_label(ce.M, ce.data, decode_state(ce))[1] == ce.target
function converged(ce::CounterfactualExplanation)
return converged(ce, Val(ce.convergence[:converge_when]))
end

"""
converged(ce::CounterfactualExplanation)
converged(ce::CounterfactualExplanation, ::Val{:decision_threshold})

A convenience method to determine if the counterfactual search has converged. The search is considered to have converged only if the counterfactual is valid.
Checks if the counterfactual search has converged when the convergence criterion is the decision threshold.
"""
function converged(ce::CounterfactualExplanation)
if ce.generator isa GrowingSpheresGenerator || ce.generator isa FeatureTweakGenerator
conv = ce.search[:converged]
elseif ce.convergence[:converge_when] == :decision_threshold
conv = threshold_reached(ce)
elseif ce.convergence[:converge_when] == :generator_conditions
conv = threshold_reached(ce) && Generators.conditions_satisfied(ce.generator, ce)
elseif ce.convergence[:converge_when] == :max_iter
conv = false
elseif ce.convergence[:converge_when] == :invalidation_rate
ir = Generators.invalidation_rate(ce)
# gets the label from an array, not sure why it is an array though.
label = predict_label(ce.M, ce.data, ce.x′)[1]
conv = label == ce.target && ce.params[:invalidation_rate] > ir
elseif (ce.convergence[:converge_when] == :early_stopping)
conv = steps_exhausted(ce)
else
@error "Convergence criterion not recognized."
end

return conv
function converged(ce::CounterfactualExplanation, ::Val{:decision_threshold})
return threshold_reached(ce)
end

"""
converged(ce::CounterfactualExplanation, ::Val{:generator_conditions})

Checks if the counterfactual search has converged when the convergence criterion is generator_conditions.
"""
function converged(ce::CounterfactualExplanation, ::Val{:generator_conditions})
return threshold_reached(ce) && Generators.conditions_satisfied(ce.generator, ce)
end

"""
converged(ce::CounterfactualExplanation, ::Val{:max_iter})

Checks if the counterfactual search has converged when the convergence criterion is maximum iterations.
"""
function converged(ce::CounterfactualExplanation, ::Val{:max_iter})
return false
end

"""
converged(ce::CounterfactualExplanation, ::Val{:invalidation_rate})

Checks if the counterfactual search has converged when the convergence criterion is invalidation rate.
"""
function converged(ce::CounterfactualExplanation, ::Val{:invalidation_rate})
ir = Generators.invalidation_rate(ce)
label = predict_label(ce.M, ce.data, ce.x′)[1]
return label == ce.target && ce.generator.invalidation_rate > ir
end

"""
converged(ce::CounterfactualExplanation, ::Val{:early_stopping})

Checks if the counterfactual search has converged when the convergence criterion is early stopping.
"""
function converged(ce::CounterfactualExplanation, ::Val{:early_stopping})
return steps_exhausted(ce)
end

"""
converged(ce::CounterfactualExplanation, ::Val{sym})

Throws an error when the `converged()` method is called on an unrecognized convergence criterion.
"""
function converged(ce::CounterfactualExplanation, ::Val{sym}) where {sym}
@error "Convergence criterion not recognized: $sym"
end

"""
Expand All @@ -64,6 +90,7 @@ end
A convenience method that determines if the predefined threshold for the target class probability has been reached for a specific sample `x`.
"""
function threshold_reached(ce::CounterfactualExplanation, x::AbstractArray)
print(ce.convergence[:min_success_rate])
γ = ce.convergence[:decision_threshold]
success_rate = sum(target_probs(ce, x) .>= γ) / ce.num_counterfactuals
return success_rate > ce.convergence[:min_success_rate]
Expand Down
15 changes: 2 additions & 13 deletions src/counterfactuals/core_struct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ mutable struct CounterfactualExplanation <: AbstractCounterfactualExplanation
data::DataPreprocessing.CounterfactualData
M::Models.AbstractFittedModel
generator::Generators.AbstractGenerator
generative_model_params::NamedTuple
params::Dict
search::Union{Dict,Nothing}
convergence::Dict
Expand All @@ -28,12 +27,8 @@ end
max_iter::Int = 100,
num_counterfactuals::Int = 1,
initialization::Symbol = :add_perturbation,
generative_model_params::NamedTuple = (;),
min_success_rate::AbstractFloat=0.99,
converge_when::Symbol=:decision_threshold,
invalidation_rate::AbstractFloat=0.5,
learning_rate::AbstractFloat=1.0,
variance::AbstractFloat=0.01,
)

Outer method to construct a `CounterfactualExplanation` structure.
Expand All @@ -46,15 +41,11 @@ function CounterfactualExplanation(
generator::Generators.AbstractGenerator;
num_counterfactuals::Int=1,
initialization::Symbol=:add_perturbation,
generative_model_params::NamedTuple=(;),
max_iter::Int=100,
decision_threshold::AbstractFloat=0.5,
gradient_tol::AbstractFloat=parameters[:τ],
min_success_rate::AbstractFloat=parameters[:min_success_rate],
converge_when::Symbol=:decision_threshold,
invalidation_rate::AbstractFloat=0.5,
learning_rate::AbstractFloat=1.0,
variance::AbstractFloat=0.01,
)

# Assertions:
Expand All @@ -67,6 +58,8 @@ function CounterfactualExplanation(
:invalidation_rate,
:early_stopping,
]
@assert !(converge_when == :invalidation_rate && isnothing(generator.invalidation_rate)) "The convergence criterion is invalidation rate but no invalidation rate has been provided."
@assert !(converge_when == :invalidation_rate && isnothing(generator.variance)) "The convergence criterion is invalidation rate but no variance has been provided."

# Factual:
x = typeof(x) == Int ? select_factual(data, x) : x
Expand All @@ -79,9 +72,6 @@ function CounterfactualExplanation(
:mutability => DataPreprocessing.mutability_constraints(data),
:latent_space => generator.latent_space,
:dim_reduction => generator.dim_reduction,
:invalidation_rate => invalidation_rate,
:learning_rate => learning_rate,
:variance => variance,
)
ids = findall(predict_label(M, data) .== target)
n_candidates = minimum([size(data.y, 2), 1000])
Expand All @@ -107,7 +97,6 @@ function CounterfactualExplanation(
data,
M,
deepcopy(generator),
generative_model_params,
params,
nothing,
convergence,
Expand Down
3 changes: 0 additions & 3 deletions src/counterfactuals/encodings.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
using MultivariateStats
using StatsBase

"""
encode_array(dt::MultivariateStats.AbstractDimensionalityReduction, x::AbstractArray)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,12 @@ function generate_counterfactual(
generator::AbstractGenerator;
num_counterfactuals::Int=1,
initialization::Symbol=:add_perturbation,
generative_model_params::NamedTuple=(;),
max_iter::Int=100,
decision_threshold::AbstractFloat=0.5,
gradient_tol::AbstractFloat=parameters[:τ],
min_success_rate::AbstractFloat=parameters[:min_success_rate],
converge_when::Symbol=:decision_threshold,
timeout::Union{Nothing,Int}=nothing,
invalidation_rate::AbstractFloat=0.1,
learning_rate::AbstractFloat=1.0,
variance::AbstractFloat=0.01,
)
# Initialize:
ce = CounterfactualExplanation(
Expand All @@ -68,15 +64,11 @@ function generate_counterfactual(
generator;
num_counterfactuals=num_counterfactuals,
initialization=initialization,
generative_model_params=generative_model_params,
max_iter=max_iter,
min_success_rate=min_success_rate,
decision_threshold=decision_threshold,
gradient_tol=gradient_tol,
converge_when=converge_when,
invalidation_rate=invalidation_rate,
learning_rate=learning_rate,
variance=variance,
)

# Search:
Expand Down
12 changes: 0 additions & 12 deletions src/counterfactuals/initialisation.jl
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This now adds random perturbations even for counterfactuals using latent space search. That seems redundant but it also doesn't seem to hurt and is probably more consistent. So happy if you're all happy with this.

Also: Not in this PR, but for consistency should we treat initialization in the same way as convergence? What do you people think?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the same way as that it is its own struct?

Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,7 @@ Initializes the starting point for the factual(s):
"""
function initialize_state(ce::CounterfactualExplanation)
@assert ce.initialization ∈ [:identity, :add_perturbation]

s′ = ce.s′
data = ce.data

# No perturbation:
if ce.initialization == :identity
return s′
end

# If latent space, initial point is random anyway:
if ce.params[:latent_space]
return s′
end
Comment on lines -13 to -23
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain to me please, why is this deleted?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function returns unmodified s′ by default, so we don't need if-statements for the cases where we want to return s′, we only need if-statements for the cases where we want to return something else.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait, now I noticed that :latent_space can be combined with the :add_perturbations initialization type. Since it doesn't matter for :latent_space whether perturbations are added or not, though, it still seems more consistent to always just do that if initialization is set to :add_perturbations.


# Add random perturbation following Slack (2021): https://arxiv.org/abs/2106.02666
if ce.initialization == :add_perturbation
Expand Down
2 changes: 1 addition & 1 deletion src/counterfactuals/latent_space_mappings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ function map_to_latent(

if ce.params[:latent_space]
generative_model = DataPreprocessing.get_generative_model(
data; ce.generative_model_params...
data; generator.generative_model_params...
)
# map counterfactual to latent space: s′=z′∼p(z|x)
s′, _, _ = GenerativeModels.rand(generative_model.encoder, s′)
Expand Down
1 change: 0 additions & 1 deletion src/counterfactuals/path_tracking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ end
Returns the counterfactual probabilities for each step of the search.
"""
function counterfactual_probability_path(ce::CounterfactualExplanation)
M = ce.M
p = map(X -> counterfactual_probability(ce, X), path(ce))
return p
end
Expand Down
2 changes: 0 additions & 2 deletions src/generate_counterfactual/generate_counterfactual.jl

This file was deleted.

37 changes: 14 additions & 23 deletions src/generators/gradient_based/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ mutable struct GradientBasedGenerator <: AbstractGradientBasedGenerator
latent_space::Bool
dim_reduction::Bool
opt::Flux.Optimise.AbstractOptimiser
invalidation_rate::Union{Nothing,AbstractFloat}
variance::Union{Nothing,AbstractFloat}
generative_model_params::NamedTuple
end

"""
Expand All @@ -25,6 +28,9 @@ end
λ::Union{Nothing,AbstractFloat,Vector{AbstractFloat}}=nothing,
latent_space::Bool::false,
opt::Flux.Optimise.AbstractOptimiser=Flux.Descent(),
invalidation_rate::AbstractFloat=nothing,
variance::AbstractFloat=nothing,
generative_model_params::NamedTuple=(;),
)

Default outer constructor for `GradientBasedGenerator`.
Expand All @@ -35,6 +41,9 @@ Default outer constructor for `GradientBasedGenerator`.
- `λ::Union{Nothing,AbstractFloat,Vector{AbstractFloat}}=nothing`: The weight of the penalty function.
- `latent_space::Bool=false`: Whether to use the latent space of a generative model to generate counterfactuals.
- `opt::Flux.Optimise.AbstractOptimiser=Flux.Descent()`: The optimizer to use for the generator.
- `invalidation_rate::AbstractFloat=nothing`: The invalidation rate of the counterfactual explanation.
- `variance::AbstractFloat=nothing`: The variance term to be used when calculating the invalidation rate of the counterfactual explanation.
- `generative_model_params::NamedTuple`: The parameters of the generative model associated with the generator.

# Returns
- `generator::GradientBasedGenerator`: A gradient-based counterfactual generator.
Expand All @@ -46,34 +55,16 @@ function GradientBasedGenerator(;
latent_space::Bool=false,
dim_reduction::Bool=false,
opt::Flux.Optimise.AbstractOptimiser=Flux.Descent(),
invalidation_rate::Union{Nothing,AbstractFloat}=nothing,
variance::Union{Nothing,AbstractFloat}=nothing,
generative_model_params::NamedTuple=(;),
)
@assert !(isnothing(λ) && !isnothing(penalty)) "Penalty function(s) provided but no penalty weight(s) provided."
@assert !(isnothing(λ) && !isnothing(penalty)) "Penalty weight(s) provided but no penalty function(s) provided."

if typeof(penalty) <: Vector
@assert length(λ) == length(penalty) || length(λ) == 1 "The number of penalty weights must match the number of penalty functions or be equal to one."
length(λ) == 1 && (λ = fill(λ[1], length(penalty))) # if only one penalty weight is provided, use it for all penalties
end
return GradientBasedGenerator(loss, penalty, λ, latent_space, dim_reduction, opt)
end

"""
Generator(;
loss::Union{Nothing,Function}=nothing,
penalty::Penalty=nothing,
λ::Union{Nothing,AbstractFloat,Vector{AbstractFloat}}=nothing,
latent_space::Bool::false,
opt::Flux.Optimise.AbstractOptimiser=Flux.Descent(),
)

An outer constructor that allows for more convenient creation of the `GradientBasedGenerator` type.
"""
function Generator(;
loss::Union{Nothing,Function}=nothing,
penalty::Penalty=nothing,
λ::Union{Nothing,AbstractFloat,Vector{<:AbstractFloat}}=nothing,
latent_space::Bool=false,
dim_reduction::Bool=false,
opt::Flux.Optimise.AbstractOptimiser=Flux.Descent(),
)
return GradientBasedGenerator(loss, penalty, λ, latent_space, dim_reduction, opt)
return GradientBasedGenerator(loss, penalty, λ, latent_space, dim_reduction, opt, invalidation_rate, variance, generative_model_params)
end
Loading