Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
/Manifest.toml
/docs/Manifest.toml
/docs/build/
LocalPreferences.toml
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Version [1.1.3] - 2025-01-08

- Improved the way objects are shared and distributed across processes to avoid memory issues. [#21]

## Version [1.1.2] - 2024-11-18

### Changed
Expand Down
36 changes: 18 additions & 18 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,23 +1,7 @@
authors = ["Patrick Altmeyer <[email protected]>"]
name = "TaijaParallel"
uuid = "bf1c2c22-5e42-4e78-8b6b-92e6c673eeb0"
authors = ["Patrick Altmeyer <[email protected]>"]
version = "1.1.2"

[deps]
CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
TaijaBase = "10284c91-9f28-4c9a-abbf-ee43576dfff6"

[weakdeps]
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"

[extensions]
MPIExt = "MPI"
version = "1.1.3"

[compat]
Aqua = "0.8"
Expand All @@ -33,10 +17,26 @@ TaijaBase = "1"
Test = "1"
julia = "1.10"

[deps]
CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
TaijaBase = "10284c91-9f28-4c9a-abbf-ee43576dfff6"

[extensions]
MPIExt = "MPI"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "MPI", "Test"]

[weakdeps]
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
3 changes: 2 additions & 1 deletion ext/MPIExt/evaluate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ function TaijaBase.parallelize(
if parallelizer.rank == 0 && verbose
# Generating counterfactuals with progress bar:
output = []
@showprogress desc = "Evaluating counterfactuals ..." for x in zip(
@showprogress desc = "Evaluating counterfactuals using MPI ..." for x in zip(
eachcol(worker_chunk)...,
)
with_logger(NullLogger()) do
Expand All @@ -70,6 +70,7 @@ function TaijaBase.parallelize(
second_parallelizer,
f,
eachcol(worker_chunk)...;
verbose=verbose,
kwargs...,
)
end
Expand Down
15 changes: 9 additions & 6 deletions ext/MPIExt/generate_counterfactual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ function TaijaBase.parallelize(
# Extract positional arguments:
counterfactuals = args[1] |> x -> TaijaBase.vectorize_collection(x)
target = args[2] |> x -> isa(x, AbstractArray) ? x : fill(x, length(counterfactuals))
data = args[3] |> x -> isa(x, AbstractArray) ? x : fill(x, length(counterfactuals))
@assert !isa(args[3], AbstractArray) "Cannot generate counterfactuals for mutliple datasets in parallel."
data = args[3]
M = args[4] |> x -> isa(x, AbstractArray) ? x : fill(x, length(counterfactuals))
generator = args[5] |> x -> isa(x, AbstractArray) ? x : fill(x, length(counterfactuals))

# Break down into chunks:
args = zip(counterfactuals, target, data, M, generator)
args = zip(counterfactuals, target, M, generator)
if !isnothing(n_each)
chunks = chunk_obs(args, n_each, parallelizer.n_proc)
else
Expand All @@ -46,12 +47,13 @@ function TaijaBase.parallelize(
worker_chunk = TaijaParallel.split_obs(chunk, parallelizer.n_proc)
worker_chunk = MPI.scatter(worker_chunk, parallelizer.comm)
worker_chunk = stack(worker_chunk; dims = 1)
_x, _target, _M, _generator = eachcol(worker_chunk)
if !parallelizer.threaded
if parallelizer.rank == 0 && verbose
# Generating counterfactuals with progress bar:
output = []
@showprogress desc = "Generating counterfactuals ..." for x in zip(
eachcol(worker_chunk)...,
@showprogress desc = "Generating counterfactuals using MPI ..." for x in zip(
_x, _target, fill(data, length(_generator)), _M, _generator,
)
with_logger(NullLogger()) do
push!(output, f(x...; kwargs...))
Expand All @@ -60,7 +62,7 @@ function TaijaBase.parallelize(
else
# Generating counterfactuals without progress bar:
output = with_logger(NullLogger()) do
f.(eachcol(worker_chunk)...; kwargs...)
f.(_x, _target, data, _M, _generator; kwargs...)
end
end
else
Expand All @@ -69,7 +71,8 @@ function TaijaBase.parallelize(
output = TaijaBase.parallelize(
second_parallelizer,
f,
eachcol(worker_chunk)...;
_x, _target, data, _M, _generator;
verbose = verbose,
kwargs...,
)
end
Expand Down
2 changes: 1 addition & 1 deletion src/CounterfactualExplanations.jl/threads/evaluate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ function TaijaBase.parallelize(
if verbose
prog = ProgressMeter.Progress(
length(counterfactuals);
desc = "Evaluating counterfactuals ...",
desc = "Evaluating counterfactuals using multi-threading ...",
showspeed = true,
color = :green,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,18 @@ function TaijaBase.parallelize(
args = zip(counterfactuals, target, data, M, generator)

# Preallocate a vector for storing results in the original order
ces = Vector{CounterfactualExplanations.AbstractCounterfactualExplanation}(undef, length(args))
return_flattened = get(kwargs, :return_flattened, false)
if return_flattened
ces = Vector{CounterfactualExplanations.FlattenedCE}(undef, length(args))
else
ces = Vector{CounterfactualExplanations.CounterfactualExplanation}(undef, length(args))
end

# Verbosity setup:
if verbose
prog = ProgressMeter.Progress(
length(args);
desc="Generating counterfactuals ...",
desc="Generating counterfactuals using multi-threading ...",
showspeed=true,
color=:green,
)
Expand Down
3 changes: 0 additions & 3 deletions src/TaijaParallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@ macro with_parallelizer(parallelizer, expr)
expr = expr.args[end]
end

Meta.show_sexpr(expr)
println("")

# Unpack arguments:
pllr = esc(parallelizer)
f = esc(expr.args[1])
Expand Down
14 changes: 11 additions & 3 deletions test/CounterfactualExplanations.jl/CounterfactualExplanations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ using MPIPreferences
nprocs_str = get(ENV, "JULIA_MPI_TEST_NPROCS", "")
nprocs = nprocs_str == "" ? clamp(Sys.CPU_THREADS, 2, 8) : parse(Int, nprocs_str)

@testset "Threads" begin
include("threads.jl")
end
# @testset "Threads" begin
# include("threads.jl")
# end

@testset "MPI" begin
n = nprocs # number of processes
Expand All @@ -15,3 +15,11 @@ end
end
@test true
end

# @testset "Comparison" begin
# n = nprocs
# mpiexec() do exe # MPI wrapper
# run(`$exe -n $n $(Base.julia_cmd()) CounterfactualExplanations.jl/comparison.jl`)
# end
# @test true
# end
44 changes: 44 additions & 0 deletions test/CounterfactualExplanations.jl/comparison.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Consistency check: generate counterfactuals for the same factuals three
# times -- without a parallelizer, with `ThreadsParallelizer` and with
# `MPIParallelizer` -- and test that all three runs return the same
# counterfactuals.
# NOTE(review): presumably meant to be launched through `mpiexec` like the
# other MPI test script -- confirm against the test runner.
using CounterfactualExplanations
using CounterfactualExplanations: counterfactual
using CounterfactualExplanations.DataPreprocessing: CounterfactualData
using CounterfactualExplanations.Convergence
using CounterfactualExplanations.Evaluation: benchmark
using CounterfactualExplanations.Models
using Logging
using TaijaData
using TaijaParallel
using Test

# Load the linearly separable toy data and wrap it for counterfactual search.
data = TaijaData.load_linearly_separable()
counterfactual_data = CounterfactualData(data[1], data[2])

# Shared setup reused unchanged by all three runs below: an MLP classifier,
# a decision-threshold convergence criterion and the generic generator.
M = fit_model(counterfactual_data, :MLP)
conv = DecisionThresholdConvergence(decision_threshold=0.95)
generator = GenericGenerator()
factual = 1
target = 2
# Draw 1000 indices (with replacement) among the samples that `M` labels as
# `factual`, then select the corresponding factuals.
# NOTE(review): no RNG seed is set, so the draw (and presumably the fitted
# model) differs between executions of this script -- confirm whether that
# matters when several MPI ranks each execute it.
chosen = rand(findall(predict_label(M, counterfactual_data) .== factual), 1000)
xs = select_factual(counterfactual_data, chosen)

# Baseline: `parallelizer` is `nothing`, i.e. no parallelizer is used.
parallelizer = nothing
ces = @with_parallelizer parallelizer begin
    generate_counterfactual(xs, target, counterfactual_data, M, generator; convergence=conv, initialization=:identity)
end

# Same call, dispatched through the multi-threading parallelizer.
parallelizer = ThreadsParallelizer()
ces_threads = @with_parallelizer parallelizer begin
    generate_counterfactual(xs, target, counterfactual_data, M, generator; convergence=conv, initialization=:identity)
end

# Same call, dispatched through the MPI parallelizer. MPI is only loaded and
# initialised here, after the two runs above have completed.
using MPI
MPI.Init()
parallelizer = TaijaParallel.MPIParallelizer(MPI.COMM_WORLD)
ces_mpi = @with_parallelizer parallelizer begin
    generate_counterfactual(xs, target, counterfactual_data, M, generator; convergence=conv, initialization=:identity)
end

# Both parallel runs must reproduce the baseline counterfactuals exactly
# (element-wise `==`, no tolerance).
@test all(counterfactual.(ces) .== counterfactual.(ces_threads))
@test all(counterfactual.(ces) .== counterfactual.(ces_mpi))
38 changes: 36 additions & 2 deletions test/CounterfactualExplanations.jl/mpi.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
using CounterfactualExplanations
using CounterfactualExplanations: counterfactual
using CounterfactualExplanations.DataPreprocessing: CounterfactualData
using CounterfactualExplanations.Convergence
using CounterfactualExplanations.Evaluation: benchmark
using CounterfactualExplanations.Models
using Logging
using TaijaData
using TaijaParallel
Expand All @@ -8,12 +12,42 @@ using Test
# Initialize MPI
using MPI
MPI.Init()
parallelizer = TaijaParallel.MPIParallelizer(MPI.COMM_WORLD)

data = TaijaData.load_linearly_separable()
counterfactual_data = CounterfactualData(data[1], data[2])
parallelizer = TaijaParallel.MPIParallelizer(MPI.COMM_WORLD)

# Select factuals:
M = fit_model(counterfactual_data, :MLP)
conv = MaxIterConvergence(10)
generator = GenericGenerator()
factual = 1
target = 2
chosen = rand(findall(predict_label(M, counterfactual_data) .== factual), 1000)
xs = select_factual(counterfactual_data, chosen)
target = fill(2, length(xs))

ces = TaijaParallel.parallelize(
parallelizer,
CounterfactualExplanations.generate_counterfactual,
xs,
target,
counterfactual_data,
M,
generator;
convergence=conv,
initialization=:identity,
)

nsteps = (ce -> total_steps(ce)).(ces)
if MPI.Comm_rank(MPI.COMM_WORLD) == 0
println("Total steps: ", nsteps)
end
@test allequal(nsteps)

# Benchmark CE with MPI
with_logger(NullLogger()) do
bmk = benchmark(counterfactual_data; parallelizer = parallelizer)
bmk = benchmark(counterfactual_data; parallelizer=parallelizer)
end
MPI.Finalize()
@test MPI.Finalized()
2 changes: 2 additions & 0 deletions test/CounterfactualExplanations.jl/threads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,6 @@ evals = @with_parallelizer parallelizer begin
evaluate(ces)
end

bmk = benchmark(counterfactual_data; convergence=:generator_conditions, parallelizer=parallelizer)

@test true
Loading
Loading