Skip to content

Added an option for multiple initial conditions in IDAKLU solver #4981

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## Features

- Generalise `pybamm.DiscreteTimeSum` to allow it to be embedded in other expressions ([#5044](https://github.com/pybamm-team/PyBaMM/pull/5044))
- Added an option for multiple initial conditions in IDAKLU solver ([#4981](https://github.com/pybamm-team/PyBaMM/pull/4981))

## Bug fixes

Expand Down
9 changes: 7 additions & 2 deletions src/pybamm/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ def solve(
showprogress=False,
inputs=None,
t_interp=None,
initial_conditions=None,
**kwargs,
):
"""
Expand Down Expand Up @@ -539,9 +540,13 @@ def solve(
pybamm.SolverWarning,
stacklevel=2,
)

self._solution = solver.solve(
self._built_model, t_eval, inputs=inputs, t_interp=t_interp, **kwargs
self._built_model,
t_eval,
inputs=inputs,
t_interp=t_interp,
**kwargs,
initial_conditions=initial_conditions,
)

elif self.operating_mode == "with experiment":
Expand Down
9 changes: 9 additions & 0 deletions src/pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,7 @@ def solve(
nproc=None,
calculate_sensitivities=False,
t_interp=None,
initial_conditions=None,
):
"""
Execute the solver setup and calculate the solution of the model at
Expand Down Expand Up @@ -680,7 +681,14 @@ def solve(
t_interp : None, list or ndarray, optional
The times (in seconds) at which to interpolate the solution. Defaults to None.
Only valid for solvers that support intra-solve interpolation (`IDAKLUSolver`).
initial_conditions : dict, numpy.ndarray, or list, optional
Override the model’s default `y0`. Can be:

- a dict mapping variable names → values
- a 1D array of length `n_states`
- a list of such overrides (one per parallel solve)

Only valid for IDAKLU solver.
Returns
-------
:class:`pybamm.Solution` or list of :class:`pybamm.Solution` objects.
Expand Down Expand Up @@ -852,6 +860,7 @@ def solve(
t_eval[start_index:end_index],
model_inputs_list,
t_interp,
initial_conditions,
)
else:
ninputs = len(model_inputs_list)
Expand Down
81 changes: 75 additions & 6 deletions src/pybamm/solvers/idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,7 +823,44 @@ def _demote_64_to_32(self, x: pybamm.EvaluatorJax):
def supports_parallel_solve(self):
return True

def _integrate(self, model, t_eval, inputs_list=None, t_interp=None):
def _apply_solver_initial_conditions(self, model, initial_conditions):
"""
Apply custom initial conditions to a model by overriding model.y0.

Parameters
----------
model : pybamm.BaseModel
A model with a precomputed y0 vector.
initial_conditions : dict or numpy.ndarray
Either a mapping from variable names to values (scalar or array),
or a flat numpy array matching the length of model.y0.
"""
if isinstance(initial_conditions, dict):
y0_np = (
model.y0.full() if isinstance(model.y0, casadi.DM) else model.y0.copy()
)

for var_name, value in initial_conditions.items():
found = False
for symbol, slice_info in model.y_slices.items():
if symbol.name == var_name:
var_slice = slice_info[0]
y0_np[var_slice] = value
found = True
break
if not found:
raise ValueError(f"Variable '{var_name}' not found in model")

model.y0 = casadi.DM(y0_np)

elif isinstance(initial_conditions, np.ndarray):
model.y0 = casadi.DM(initial_conditions)
else:
raise TypeError("Initial conditions must be dict or numpy array")

def _integrate(
self, model, t_eval, inputs_list=None, t_interp=None, initial_conditions=None
):
"""
Solve a DAE model defined by residuals with initial conditions y0.

Expand All @@ -838,6 +875,13 @@ def _integrate(self, model, t_eval, inputs_list=None, t_interp=None):
t_interp : None, list or ndarray, optional
The times (in seconds) at which to interpolate the solution. Defaults to `None`,
which returns the adaptive time-stepping times.
initial_conditions : dict, numpy.ndarray, or list, optional
Override the model’s default `y0`. Can be:

- a dict mapping variable names → values
- a 1D array of length `n_states`
- a list of such overrides (one per parallel solve)

"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These blocks do not render correctly in sphinx. Usually the fix is to have black lines before and/or after the list. You can view the generated docs in either the RTD CI or by running locally when you fix this.

image

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: I mentioned it on this one, but I have seen this block several times

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for telling, I was not aware of it, will be extra careful next time, I've fixed it here

if not (
model.convert_to_format == "casadi"
Expand All @@ -862,11 +906,36 @@ def _integrate(self, model, t_eval, inputs_list=None, t_interp=None):
else:
inputs = np.array([[]] * len(inputs_list))

# stack y0full and ydot0full so they are a 2D array of shape (number_of_inputs, number_of_states + number_of_parameters * number_of_states)
# note that y0full and ydot0full are currently 1D arrays (i.e. independent of inputs), but in the future we will support
# different initial conditions for different inputs (see https://github.com/pybamm-team/PyBaMM/pull/4260). For now we just repeat the same initial conditions for each input
y0full = np.vstack([model.y0full] * len(inputs_list))
ydot0full = np.vstack([model.ydot0full] * len(inputs_list))
if initial_conditions is not None:
if isinstance(initial_conditions, list):
if len(initial_conditions) != len(inputs_list):
raise ValueError(
"Number of initial conditions must match number of input sets"
)

y0_list = []

model_copy = model.new_copy()
for ic in initial_conditions:
self._apply_solver_initial_conditions(model_copy, ic)
y0_list.append(model_copy.y0.full().flatten())

y0full = np.vstack(y0_list)
ydot0full = np.zeros_like(y0full)

else:
self._apply_solver_initial_conditions(model, initial_conditions)

y0_np = model.y0.full()

y0full = np.vstack([y0_np for _ in range(len(inputs_list))])
ydot0full = np.zeros_like(y0full)
else:
# stack y0full and ydot0full so they are a 2D array of shape (number_of_inputs, number_of_states + number_of_parameters * number_of_states)
# note that y0full and ydot0full are currently 1D arrays (i.e. independent of inputs), but in the future we will support
# different initial conditions for different inputs (see https://github.com/pybamm-team/PyBaMM/pull/4260). For now we just repeat the same initial conditions for each input
y0full = np.vstack([model.y0full] * len(inputs_list))
ydot0full = np.vstack([model.ydot0full] * len(inputs_list))

atol = getattr(model, "atol", self.atol)
atol = self._check_atol_type(atol, y0full.size)
Expand Down
8 changes: 7 additions & 1 deletion src/pybamm/solvers/jax_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,9 @@ def supports_parallel_solve(self):
def requires_explicit_sensitivities(self):
return False

def _integrate(self, model, t_eval, inputs=None, t_interp=None):
def _integrate(
self, model, t_eval, inputs=None, t_interp=None, intial_conditions=None
):
"""
Solve a model defined by dydt with initial conditions y0.

Expand All @@ -220,6 +222,10 @@ def _integrate(self, model, t_eval, inputs=None, t_interp=None):
various diagnostic messages.

"""
if intial_conditions is not None: # pragma: no cover
raise NotImplementedError(
"Setting initial conditions is not yet implemented for the JAX IDAKLU solver"
)
if isinstance(inputs, dict):
inputs = [inputs]
timer = pybamm.Timer()
Expand Down
66 changes: 66 additions & 0 deletions tests/integration/test_solvers/test_idaklu.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pytest

import pybamm

Expand Down Expand Up @@ -209,3 +210,68 @@ def test_with_experiments(self):
sols[0].cycles[-1]["Current [A]"].data,
sols[1].cycles[-1]["Current [A]"].data,
)

@pytest.mark.parametrize(
"model_cls, make_ics",
[
(pybamm.lithium_ion.SPM, lambda y0: [y0, 2 * y0]),
(
pybamm.lithium_ion.DFN,
lambda y0: [y0, y0 * (1 + 0.01 * np.ones_like(y0))],
),
],
)
def test_multiple_initial_conditions_against_independent_solves(
self, model_cls, make_ics
):
model = model_cls()
geom = model.default_geometry
pv = model.default_parameter_values
pv.process_model(model)
pv.process_geometry(geom)
mesh = pybamm.Mesh(geom, model.default_submesh_types, model.default_var_pts)
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
disc.process_model(model)

t_eval = np.array([0, 1])
solver = pybamm.IDAKLUSolver()

base_sol = solver.solve(model, t_eval)
y0_base = base_sol.y[:, 0]

ics = make_ics(y0_base)
inputs = [{}] * len(ics)

multi_sols = solver.solve(
model,
t_eval,
inputs=inputs,
initial_conditions=ics,
)
assert isinstance(multi_sols, list) and len(multi_sols) == 2

indep_sols = []
for ic in ics:
sol_indep = solver.solve(
model, t_eval, inputs=[{}], initial_conditions=[ic]
)
if isinstance(sol_indep, list):
sol_indep = sol_indep[0]
indep_sols.append(sol_indep)

if model_cls is pybamm.lithium_ion.SPM:
rtol, atol = 1e-8, 1e-10
else:
rtol, atol = 1e-6, 1e-8

for idx in (0, 1):
sol_vec = multi_sols[idx]
sol_ind = indep_sols[idx]

np.testing.assert_allclose(sol_vec.t, sol_ind.t, rtol=1e-12, atol=0)
np.testing.assert_allclose(sol_vec.y, sol_ind.y, rtol=rtol, atol=atol)

if model_cls is pybamm.lithium_ion.SPM:
np.testing.assert_allclose(
sol_vec.y[:, 0], ics[idx], rtol=1e-8, atol=1e-10
)
Loading