Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
83 commits
Select commit Hold shift + click to select a range
d1a3f2c
done for feature generator
pat-alt Nov 8, 2023
5a20ed2
formatting
pat-alt Nov 8, 2023
f1b9f44
Simplify methods
RaunoArike Nov 23, 2023
1dc1876
Move the variance kwarg
RaunoArike Nov 24, 2023
961f2fe
Small fixes
RaunoArike Nov 24, 2023
1678158
Merge branch 'main' into 358-simplify-the-core-strct
RaunoArike Nov 24, 2023
2c3ca7d
Move files out of generate_counterfactual
RaunoArike Nov 24, 2023
906fcce
Merge branch '358-simplify-the-core-strct' of https://github.com/Juli…
RaunoArike Nov 24, 2023
abe3fb0
Add further improvements
RaunoArike Nov 25, 2023
7299479
Small fixes
RaunoArike Nov 25, 2023
d6dcc90
Merge branch 'main' into 343-overload-the-update-method-preferred-if-…
RaunoArike Nov 25, 2023
463b7d7
Minor fixes
RaunoArike Nov 25, 2023
4fbf2c8
Small fixes
RaunoArike Nov 25, 2023
3cd5ae4
Create structs for convergence types
RaunoArike Nov 27, 2023
e516cc6
Fix formatting
RaunoArike Nov 27, 2023
24b6e79
Minor fix
RaunoArike Nov 27, 2023
59400a2
Merge branch 'main' into 358-simplify-the-core-strct
RaunoArike Nov 27, 2023
cef0524
Fix benchmarking of feature tweak
RaunoArike Nov 27, 2023
d069ffa
Add error test for feature tweak
RaunoArike Nov 27, 2023
7511d68
Rewrite convergence logic
RaunoArike Nov 28, 2023
6bbc668
Small fix
RaunoArike Nov 28, 2023
9638643
Fix bugs
RaunoArike Nov 28, 2023
9a6684a
Update src/convergence/invalidation_rate.jl
RaunoArike Nov 28, 2023
cf9e2c6
Update src/convergence/decision_threshold.jl
RaunoArike Nov 28, 2023
df71519
Minor changes
RaunoArike Nov 28, 2023
403b9a7
Merge branch '358-simplify-the-core-strct' of https://github.com/Juli…
RaunoArike Nov 28, 2023
0d649e3
Merge branch 'main' into 343-overload-the-update-method-preferred-if-…
RaunoArike Nov 28, 2023
93184bf
Minor fix
RaunoArike Nov 28, 2023
1d798d6
Bugfix
RaunoArike Nov 29, 2023
af70576
Merge branch '343-overload-the-update-method-preferred-if-feasible' i…
RaunoArike Nov 29, 2023
2a1fe2a
Small fix
RaunoArike Nov 29, 2023
9e44507
Small fixes
RaunoArike Nov 29, 2023
7f4e086
Small fixes
RaunoArike Nov 29, 2023
73f5161
Merge branch '343-overload-the-update-method-preferred-if-feasible' i…
RaunoArike Nov 29, 2023
491f086
Fix failing tests
RaunoArike Nov 29, 2023
618f89e
Fix formatting
RaunoArike Nov 29, 2023
40cca87
Merge branch 'main' into 358-simplify-the-core-strct
RaunoArike Nov 29, 2023
1a717d2
Minor fixes
RaunoArike Nov 30, 2023
e6f234c
Small improvements
RaunoArike Nov 30, 2023
45d6a99
Update src/convergence/invalidation_rate.jl
RaunoArike Nov 30, 2023
6a42b10
Update src/generators/gradient_based/loss.jl
RaunoArike Nov 30, 2023
b6e58f5
Minor additions
RaunoArike Nov 30, 2023
36227c2
Resolve circular dependency
RaunoArike Nov 30, 2023
e99f69c
Merge branch '358-simplify-the-core-strct' of https://github.com/Juli…
RaunoArike Nov 30, 2023
49fa6b2
Fix failing tests
RaunoArike Nov 30, 2023
9be1220
Update test/generators/probe.jl
RaunoArike Dec 1, 2023
1dfacf0
Add missing method
RaunoArike Dec 1, 2023
a47daff
Merge branch '358-simplify-the-core-strct' of https://github.com/Juli…
RaunoArike Dec 1, 2023
2648df3
Small fix
RaunoArike Dec 1, 2023
5a84574
Small fix
RaunoArike Dec 1, 2023
08cd3c4
Merge branch 'main' into 358-simplify-the-core-strct
RaunoArike Dec 4, 2023
bf75071
Remove redundant generate_perturbations method
RaunoArike Dec 5, 2023
cb37729
small addition to doc string
pat-alt Dec 6, 2023
1f0f904
Minor fixes
RaunoArike Dec 6, 2023
0f2560e
Merge branch '358-simplify-the-core-strct' of https://github.com/Juli…
RaunoArike Dec 6, 2023
436baea
maybe this way?
pat-alt Dec 6, 2023
5add2a2
go
pat-alt Dec 6, 2023
27433ea
huh
pat-alt Dec 6, 2023
bd557ea
tryin to errors
pat-alt Dec 6, 2023
b77e8a8
Update test/generators/probe.jl
RaunoArike Dec 6, 2023
e6fcaa0
Update src/generators/gradient_based/probe.jl
RaunoArike Dec 6, 2023
14efbf0
error with probe
pat-alt Dec 6, 2023
1adddb7
Fix comment
RaunoArike Dec 6, 2023
b84a381
Merge branch '358-simplify-the-core-strct' of https://github.com/Juli…
RaunoArike Dec 6, 2023
db2962b
uff
pat-alt Dec 6, 2023
3ba24fb
returning warning in case InvalidationRateConvergence not used with P…
pat-alt Dec 6, 2023
aad6c08
bleh
pat-alt Dec 6, 2023
36607e5
tryin with a try catch
pat-alt Dec 7, 2023
a1d4066
trying again
pat-alt Dec 7, 2023
cde70ae
namespacing issue
pat-alt Dec 7, 2023
e7835da
another attempt without try/catch
pat-alt Dec 7, 2023
ba74aca
small fix
pat-alt Dec 7, 2023
f1539cb
hm
pat-alt Dec 7, 2023
3c5cc82
uh
pat-alt Dec 7, 2023
6d0c869
uh
pat-alt Dec 7, 2023
bb79e6d
s
pat-alt Dec 7, 2023
4efb266
bam
pat-alt Dec 7, 2023
ab91387
come on
pat-alt Dec 7, 2023
0150597
omg
pat-alt Dec 7, 2023
df3e2f2
bloody hell
pat-alt Dec 7, 2023
f8b72de
shoot me aljeblieft
pat-alt Dec 7, 2023
ff833ee
as if
pat-alt Dec 7, 2023
10f94dd
now pls
pat-alt Dec 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/CounterfactualExplanations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ export AbstractCounterfactualExplanation
export AbstractFittedModel
export AbstractGenerator
export AbstractParallelizer
export AbstractConvergence

# Traits:
include("traits/traits.jl")
Expand Down Expand Up @@ -79,11 +80,16 @@ export generator_catalogue
export generate_perturbations, conditions_satisfied, mutability_constraints
export Generator, @objective, @threshold

### Convergence
include("convergence/Convergence.jl")
using .Convergence

### CounterfactualExplanation
# argmin
###
include("counterfactuals/Counterfactuals.jl")
export CounterfactualExplanation
export generate_counterfactual
export initialize!, update!
export total_steps, converged, terminated, path, target_probs
export animate_path
Expand All @@ -93,9 +99,6 @@ export animate_path
include("data/Data.jl")
using .Data

include("generate_counterfactual/generate_counterfactual.jl")
export generate_counterfactual

include("evaluation/Evaluation.jl")
using .Evaluation

Expand Down
3 changes: 3 additions & 0 deletions src/base_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,6 @@ Base.broadcastable(gen::AbstractGenerator) = Ref(gen)

"An abstract type for parallelizers."
abstract type AbstractParallelizer end

"An abstract type that serves as the base type for convergence objects."
abstract type AbstractConvergence end
60 changes: 60 additions & 0 deletions src/convergence/Convergence.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
module Convergence

using Distributions
using Flux
using LinearAlgebra
using ..CounterfactualExplanations
using ..Generators
using ..Models
using ..Objectives

include("decision_threshold.jl")
include("generator_conditions.jl")
include("invalidation_rate.jl")
include("max_iter.jl")

"""
convergence_catalogue

A dictionary containing all convergence criteria.
"""
const convergence_catalogue = Dict(
:decision_threshold => DecisionThresholdConvergence(),
:generator_conditions => GeneratorConditionsConvergence(),
:max_iter => MaxIterConvergence(),
:invalidation_rate => InvalidationRateConvergence(),
)

"""
get_convergence_type(convergence::AbstractConvergence)

Returns the convergence object.
"""
function get_convergence_type(convergence::AbstractConvergence)
return convergence
end

"""
get_convergence_type(convergence::Symbol)

Returns the convergence object from the dictionary of default convergence types.
"""
function get_convergence_type(convergence::Symbol)
return get(
convergence_catalogue,
convergence,
() -> error("Convergence criterion not recognized: $convergence."),
)
end

export convergence_catalogue
export converged
export get_convergence_type
export hinge_loss, invalidation_rate
export threshold_reached
export DecisionThresholdConvergence
export GeneratorConditionsConvergence
export InvalidationRateConvergence
export MaxIterConvergence

end
27 changes: 27 additions & 0 deletions src/convergence/decision_threshold.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
Base.@kwdef struct DecisionThresholdConvergence <: AbstractConvergence
decision_threshold::AbstractFloat = 0.5
max_iter::Int = 100
min_success_rate::AbstractFloat = 0.75
end

"""
converged(convergence::DecisionThresholdConvergence, ce::CounterfactualExplanation)

Checks if the counterfactual search has converged when the convergence criterion is the decision threshold.
"""
function converged(
convergence::DecisionThresholdConvergence, ce::AbstractCounterfactualExplanation
)
return threshold_reached(ce)
end

"""
threshold_reached(ce::CounterfactualExplanation)

Determines if the predefined threshold for the target class probability has been reached.
"""
function threshold_reached(ce::AbstractCounterfactualExplanation)
γ = ce.convergence.decision_threshold
success_rate = sum(target_probs(ce) .>= γ) / ce.num_counterfactuals
return success_rate > ce.convergence.min_success_rate
end
17 changes: 17 additions & 0 deletions src/convergence/generator_conditions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
Base.@kwdef struct GeneratorConditionsConvergence <: AbstractConvergence
decision_threshold::AbstractFloat = 0.5
gradient_tol::AbstractFloat = 1e-2
max_iter::Int = 100
min_success_rate::AbstractFloat = 0.75
end

"""
converged(convergence::GeneratorConditionsConvergence, ce::CounterfactualExplanation)

Checks if the counterfactual search has converged when the convergence criterion is generator_conditions.
"""
function converged(
convergence::GeneratorConditionsConvergence, ce::AbstractCounterfactualExplanation
)
return threshold_reached(ce) && Generators.conditions_satisfied(ce.generator, ce)
end
18 changes: 18 additions & 0 deletions src/convergence/invalidation_rate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
Base.@kwdef struct InvalidationRateConvergence <: AbstractConvergence
invalidation_rate::AbstractFloat = 0.1
max_iter::Int = 100
variance::AbstractFloat = 0.01
end

"""
converged(convergence::InvalidationRateConvergence, ce::CounterfactualExplanation)

Checks if the counterfactual search has converged when the convergence criterion is invalidation rate.
"""
function converged(
convergence::InvalidationRateConvergence, ce::AbstractCounterfactualExplanation
)
ir = Objectives.invalidation_rate(ce)
label = Models.predict_label(ce.M, ce.data, ce.x′)[1]
return label == ce.target && convergence.invalidation_rate > ir
end
12 changes: 12 additions & 0 deletions src/convergence/max_iter.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
Base.@kwdef struct MaxIterConvergence <: AbstractConvergence
max_iter::Int = 100
end

"""
converged(convergence::MaxIterConvergence, ce::CounterfactualExplanation)

Checks if the counterfactual search has converged when the convergence criterion is maximum iterations. This means the counterfactual search will not terminate until the maximum number of iterations has been reached independently of the other convergence criteria.
"""
function converged(convergence::MaxIterConvergence, ce::AbstractCounterfactualExplanation)
return ce.search[:iteration_count] == convergence.max_iter
end
7 changes: 5 additions & 2 deletions src/counterfactuals/Counterfactuals.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
using .Convergence
using .DataPreprocessing
using .GenerativeModels
using .Generators
using .Models
using ChainRulesCore
using Flux
using MLUtils
using MultivariateStats
using Statistics
using StatsBase

include("core_struct.jl")

include("convergence.jl")
include("encodings.jl")
include("generate_counterfactual.jl")
include("growing_spheres.jl")
include("info_extraction.jl")
include("initialisation.jl")
include("latent_space_mappings.jl")
include("path_tracking.jl")
include("printing.jl")
include("search.jl")
include("utils.jl")
include("vectorised.jl")
89 changes: 0 additions & 89 deletions src/counterfactuals/convergence.jl

This file was deleted.

Loading