-
-
Notifications
You must be signed in to change notification settings - Fork 634
Added an option for multiple initial conditions in IDAKLU solver #4981
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
23e50f9
fe86f2b
8b07bbd
1b3e547
b09df51
320197e
2c4833c
e86adc1
32242e1
d261930
72ef8a4
72b8ed4
25e8941
22028e0
6ac742d
c41f945
a154289
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -823,7 +823,44 @@ def _demote_64_to_32(self, x: pybamm.EvaluatorJax): | |
def supports_parallel_solve(self): | ||
return True | ||
|
||
def _integrate(self, model, t_eval, inputs_list=None, t_interp=None): | ||
def _apply_solver_initial_conditions(self, model, initial_conditions): | ||
""" | ||
Apply custom initial conditions to a model by overriding model.y0. | ||
|
||
Parameters | ||
---------- | ||
model : pybamm.BaseModel | ||
A model with a precomputed y0 vector. | ||
initial_conditions : dict or numpy.ndarray | ||
Either a mapping from variable names to values (scalar or array), | ||
or a flat numpy array matching the length of model.y0. | ||
""" | ||
if isinstance(initial_conditions, dict): | ||
y0_np = ( | ||
model.y0.full() if isinstance(model.y0, casadi.DM) else model.y0.copy() | ||
) | ||
|
||
for var_name, value in initial_conditions.items(): | ||
found = False | ||
for symbol, slice_info in model.y_slices.items(): | ||
if symbol.name == var_name: | ||
var_slice = slice_info[0] | ||
y0_np[var_slice] = value | ||
found = True | ||
break | ||
if not found: | ||
raise ValueError(f"Variable '{var_name}' not found in model") | ||
|
||
model.y0 = casadi.DM(y0_np) | ||
|
||
elif isinstance(initial_conditions, np.ndarray): | ||
model.y0 = casadi.DM(initial_conditions) | ||
else: | ||
raise TypeError("Initial conditions must be dict or numpy array") | ||
|
||
def _integrate( | ||
self, model, t_eval, inputs_list=None, t_interp=None, initial_conditions=None | ||
): | ||
""" | ||
Solve a DAE model defined by residuals with initial conditions y0. | ||
|
||
|
@@ -838,6 +875,13 @@ def _integrate(self, model, t_eval, inputs_list=None, t_interp=None): | |
t_interp : None, list or ndarray, optional | ||
The times (in seconds) at which to interpolate the solution. Defaults to `None`, | ||
which returns the adaptive time-stepping times. | ||
initial_conditions : dict, numpy.ndarray, or list, optional | ||
Override the model’s default `y0`. Can be: | ||
|
||
- a dict mapping variable names → values | ||
- a 1D array of length `n_states` | ||
- a list of such overrides (one per parallel solve) | ||
|
||
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note: I mentioned it on this one, but I have seen this block several times There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks for telling, I was not aware of it, will be extra careful next time, I've fixed it here |
||
if not ( | ||
model.convert_to_format == "casadi" | ||
|
@@ -862,11 +906,36 @@ def _integrate(self, model, t_eval, inputs_list=None, t_interp=None): | |
else: | ||
inputs = np.array([[]] * len(inputs_list)) | ||
|
||
# stack y0full and ydot0full so they are a 2D array of shape (number_of_inputs, number_of_states + number_of_parameters * number_of_states) | ||
# note that y0full and ydot0full are currently 1D arrays (i.e. independent of inputs), but in the future we will support | ||
# different initial conditions for different inputs (see https://github.com/pybamm-team/PyBaMM/pull/4260). For now we just repeat the same initial conditions for each input | ||
y0full = np.vstack([model.y0full] * len(inputs_list)) | ||
ydot0full = np.vstack([model.ydot0full] * len(inputs_list)) | ||
if initial_conditions is not None: | ||
if isinstance(initial_conditions, list): | ||
if len(initial_conditions) != len(inputs_list): | ||
raise ValueError( | ||
"Number of initial conditions must match number of input sets" | ||
) | ||
|
||
y0_list = [] | ||
|
||
model_copy = model.new_copy() | ||
for ic in initial_conditions: | ||
self._apply_solver_initial_conditions(model_copy, ic) | ||
y0_list.append(model_copy.y0.full().flatten()) | ||
|
||
y0full = np.vstack(y0_list) | ||
ydot0full = np.zeros_like(y0full) | ||
|
||
else: | ||
self._apply_solver_initial_conditions(model, initial_conditions) | ||
|
||
y0_np = model.y0.full() | ||
|
||
y0full = np.vstack([y0_np for _ in range(len(inputs_list))]) | ||
ydot0full = np.zeros_like(y0full) | ||
else: | ||
# stack y0full and ydot0full so they are a 2D array of shape (number_of_inputs, number_of_states + number_of_parameters * number_of_states) | ||
# note that y0full and ydot0full are currently 1D arrays (i.e. independent of inputs), but in the future we will support | ||
# different initial conditions for different inputs (see https://github.com/pybamm-team/PyBaMM/pull/4260). For now we just repeat the same initial conditions for each input | ||
y0full = np.vstack([model.y0full] * len(inputs_list)) | ||
ydot0full = np.vstack([model.ydot0full] * len(inputs_list)) | ||
|
||
atol = getattr(model, "atol", self.atol) | ||
atol = self._check_atol_type(atol, y0full.size) | ||
|
Uh oh!
There was an error while loading. Please reload this page.