Skip to content

Added an option for multiple initial conditions in IDAKLU solver #4981

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: develop
Choose a base branch
from
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
- Fixed a bug in the `QuickPlot` which would return empty values for 1D variables at the beginning and end of a timespan. ([#4991](https://github.com/pybamm-team/PyBaMM/pull/4991))
- Fixed a bug in the `Exponential1DSubMesh` where the mesh was not being created correctly for non-zero minimum values. ([#4989](https://github.com/pybamm-team/PyBaMM/pull/4989))

## Features

- Added an option for multiple initial conditions in IDAKLU solver ([#4981](https://github.com/pybamm-team/PyBaMM/pull/4981))

## Breaking changes

- Remove sensitivity functionality for Casadi and Scipy solvers, only `pybamm.IDAKLU` solver can calculate sensitivities. ([#4975](https://github.com/pybamm-team/PyBaMM/pull/4975))
Expand Down
9 changes: 7 additions & 2 deletions src/pybamm/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ def solve(
showprogress=False,
inputs=None,
t_interp=None,
initial_conditions=None,
**kwargs,
):
"""
Expand Down Expand Up @@ -537,9 +538,13 @@ def solve(
pybamm.SolverWarning,
stacklevel=2,
)

self._solution = solver.solve(
self._built_model, t_eval, inputs=inputs, t_interp=t_interp, **kwargs
self._built_model,
t_eval,
inputs=inputs,
t_interp=t_interp,
**kwargs,
initial_conditions=initial_conditions,
)

elif self.operating_mode == "with experiment":
Expand Down
9 changes: 8 additions & 1 deletion src/pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,7 @@ def solve(
nproc=None,
calculate_sensitivities=False,
t_interp=None,
initial_conditions=None,
):
"""
Execute the solver setup and calculate the solution of the model at
Expand Down Expand Up @@ -680,7 +681,12 @@ def solve(
t_interp : None, list or ndarray, optional
The times (in seconds) at which to interpolate the solution. Defaults to None.
Only valid for solvers that support intra-solve interpolation (`IDAKLUSolver`).

initial_conditions : dict, numpy.ndarray, or list, optional
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

put somewhere here that this is only for the idaklu solver

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and jax perhaps?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah right,added that

Override the model’s default `y0`. Can be:
- a dict mapping variable names → values
- a 1D array of length `n_states`
- a list of such overrides (one per parallel solve)
Only valid for IDAKLU solver.
Returns
-------
:class:`pybamm.Solution` or list of :class:`pybamm.Solution` objects.
Expand Down Expand Up @@ -852,6 +858,7 @@ def solve(
t_eval[start_index:end_index],
model_inputs_list,
t_interp,
initial_conditions,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this arg work for the jax solver?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not now , I'll add support for jax in a seperate PR in future otherwise this would become too long for now

)
else:
ninputs = len(model_inputs_list)
Expand Down
79 changes: 73 additions & 6 deletions src/pybamm/solvers/idaklu_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,7 +823,44 @@ def _demote_64_to_32(self, x: pybamm.EvaluatorJax):
def supports_parallel_solve(self):
return True

def _integrate(self, model, t_eval, inputs_list=None, t_interp=None):
def _apply_solver_initial_conditions(self, model, initial_conditions):
"""
Apply custom initial conditions to a model by overriding model.y0.

Parameters
----------
model : pybamm.BaseModel
A model with a precomputed y0 vector.
initial_conditions : dict or numpy.ndarray
Either a mapping from variable names to values (scalar or array),
or a flat numpy array matching the length of model.y0.
"""
if isinstance(initial_conditions, dict):
y0_np = (
model.y0.full() if isinstance(model.y0, casadi.DM) else model.y0.copy()
)

for var_name, value in initial_conditions.items():
found = False
for symbol, slice_info in model.y_slices.items():
if symbol.name == var_name:
var_slice = slice_info[0]
y0_np[var_slice] = value
found = True
break
if not found:
raise ValueError(f"Variable '{var_name}' not found in model")

model.y0 = casadi.DM(y0_np)

elif isinstance(initial_conditions, np.ndarray):
model.y0 = casadi.DM(initial_conditions)
else:
raise TypeError("Initial conditions must be dict or numpy array")

def _integrate(
self, model, t_eval, inputs_list=None, t_interp=None, initial_conditions=None
):
"""
Solve a DAE model defined by residuals with initial conditions y0.

Expand All @@ -838,6 +875,11 @@ def _integrate(self, model, t_eval, inputs_list=None, t_interp=None):
t_interp : None, list or ndarray, optional
The times (in seconds) at which to interpolate the solution. Defaults to `None`,
which returns the adaptive time-stepping times.
initial_conditions : dict, numpy.ndarray, or list, optional
Override the model’s default `y0`. Can be:
- a dict mapping variable names → values
- a 1D array of length `n_states`
- a list of such overrides (one per parallel solve)
"""
if not (
model.convert_to_format == "casadi"
Expand All @@ -862,11 +904,36 @@ def _integrate(self, model, t_eval, inputs_list=None, t_interp=None):
else:
inputs = np.array([[]] * len(inputs_list))

# stack y0full and ydot0full so they are a 2D array of shape (number_of_inputs, number_of_states + number_of_parameters * number_of_states)
# note that y0full and ydot0full are currently 1D arrays (i.e. independent of inputs), but in the future we will support
# different initial conditions for different inputs (see https://github.com/pybamm-team/PyBaMM/pull/4260). For now we just repeat the same initial conditions for each input
y0full = np.vstack([model.y0full] * len(inputs_list))
ydot0full = np.vstack([model.ydot0full] * len(inputs_list))
if initial_conditions is not None:
if isinstance(initial_conditions, list):
if len(initial_conditions) != len(inputs_list):
raise ValueError(
"Number of initial conditions must match number of input sets"
)

y0_list = []

model_copy = model.new_copy()
for ic in initial_conditions:
self._apply_solver_initial_conditions(model_copy, ic)
y0_list.append(model_copy.y0.full().flatten())

y0full = np.vstack(y0_list)
ydot0full = np.zeros_like(y0full)

else:
self._apply_solver_initial_conditions(model, initial_conditions)

y0_np = model.y0.full()

y0full = np.vstack([y0_np for _ in range(len(inputs_list))])
ydot0full = np.zeros_like(y0full)
else:
# stack y0full and ydot0full so they are a 2D array of shape (number_of_inputs, number_of_states + number_of_parameters * number_of_states)
# note that y0full and ydot0full are currently 1D arrays (i.e. independent of inputs), but in the future we will support
# different initial conditions for different inputs (see https://github.com/pybamm-team/PyBaMM/pull/4260). For now we just repeat the same initial conditions for each input
y0full = np.vstack([model.y0full] * len(inputs_list))
ydot0full = np.vstack([model.ydot0full] * len(inputs_list))

atol = getattr(model, "atol", self.atol)
atol = self._check_atol_type(atol, y0full.size)
Expand Down
8 changes: 7 additions & 1 deletion src/pybamm/solvers/jax_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,9 @@ def supports_parallel_solve(self):
def requires_explicit_sensitivities(self):
return False

def _integrate(self, model, t_eval, inputs=None, t_interp=None):
def _integrate(
self, model, t_eval, inputs=None, t_interp=None, intial_conditions=None
):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you are not planning on supporting the jax solver, you need to raise an error if initial_conditions is not None

"""
Solve a model defined by dydt with initial conditions y0.

Expand All @@ -218,6 +220,10 @@ def _integrate(self, model, t_eval, inputs=None, t_interp=None):
various diagnostic messages.

"""
if intial_conditions is not None: # pragma: no cover
raise NotImplementedError(
"Setting initial conditions is not yet implemented for the JAX IDAKLU solver"
)
if isinstance(inputs, dict):
inputs = [inputs]
timer = pybamm.Timer()
Expand Down
66 changes: 66 additions & 0 deletions tests/integration/test_solvers/test_idaklu.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pybamm
import numpy as np
import pytest


class TestIDAKLUSolver:
Expand Down Expand Up @@ -208,3 +209,68 @@ def test_with_experiments(self):
sols[0].cycles[-1]["Current [A]"].data,
sols[1].cycles[-1]["Current [A]"].data,
)

@pytest.mark.parametrize(
"model_cls, make_ics",
[
(pybamm.lithium_ion.SPM, lambda y0: [y0, 2 * y0]),
(
pybamm.lithium_ion.DFN,
lambda y0: [y0, y0 * (1 + 0.01 * np.ones_like(y0))],
),
],
)
def test_multiple_initial_conditions_against_independent_solves(
self, model_cls, make_ics
):
model = model_cls()
geom = model.default_geometry
pv = model.default_parameter_values
pv.process_model(model)
pv.process_geometry(geom)
mesh = pybamm.Mesh(geom, model.default_submesh_types, model.default_var_pts)
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
disc.process_model(model)

t_eval = np.array([0, 1])
solver = pybamm.IDAKLUSolver()

base_sol = solver.solve(model, t_eval)
y0_base = base_sol.y[:, 0]

ics = make_ics(y0_base)
inputs = [{}] * len(ics)

multi_sols = solver.solve(
model,
t_eval,
inputs=inputs,
initial_conditions=ics,
)
assert isinstance(multi_sols, list) and len(multi_sols) == 2

indep_sols = []
for ic in ics:
sol_indep = solver.solve(
model, t_eval, inputs=[{}], initial_conditions=[ic]
)
if isinstance(sol_indep, list):
sol_indep = sol_indep[0]
indep_sols.append(sol_indep)

if model_cls is pybamm.lithium_ion.SPM:
rtol, atol = 1e-8, 1e-10
else:
rtol, atol = 1e-6, 1e-8

for idx in (0, 1):
sol_vec = multi_sols[idx]
sol_ind = indep_sols[idx]

np.testing.assert_allclose(sol_vec.t, sol_ind.t, rtol=1e-12, atol=0)
np.testing.assert_allclose(sol_vec.y, sol_ind.y, rtol=rtol, atol=atol)

if model_cls is pybamm.lithium_ion.SPM:
np.testing.assert_allclose(
sol_vec.y[:, 0], ics[idx], rtol=1e-8, atol=1e-10
)
Loading
Loading