Skip to content

Commit 08bf630

Browse files
authored
Refactor SteadystateProblem initialization (AMICI-dev#2844)
Remove `SteadystateProblem::initializeForwardProblem` and move initialization to `ForwardProblem`. This simplifies things since there we don't need to guess whether we are performing pre- or post-equilibration.
1 parent b112466 commit 08bf630

File tree

3 files changed

+20
-36
lines changed

3 files changed

+20
-36
lines changed

include/amici/steadystateproblem.h

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,7 @@ class SteadystateProblem {
232232
*
233233
* Tries to determine the steady state of the ODE system and computes
234234
* steady state sensitivities if requested.
235+
* Expects that solver, model, and ws_ are already initialized.
235236
*
236237
* @param solver The solver instance
237238
* @param model The model instance
@@ -469,17 +470,6 @@ class SteadystateProblem {
469470
*/
470471
void runSteadystateSimulationBwd(Solver const& solver, Model& model);
471472

472-
/**
473-
* @brief Initialize forward computation
474-
* @param it Index of the current output time point.
475-
* @param solver pointer to the solver object
476-
* @param model pointer to the model object
477-
* @param t0 Initial time for the steady state simulation.
478-
*/
479-
void initializeForwardProblem(
480-
int it, Solver const& solver, Model& model, realtype t0
481-
);
482-
483473
/**
484474
* @brief Update member variables to indicate that state_.x has been
485475
* updated and xdot_, delta_, etc. need to be recomputed.

src/forwardproblem.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,18 @@ void ForwardProblem::handlePreequilibration() {
175175

176176
preeq_problem_.emplace(&ws_, *solver, *model);
177177
auto t0 = std::isnan(model->t0Preeq()) ? model->t0() : model->t0Preeq();
178+
179+
// The solver was not run before, set up everything.
180+
// TODO: For pre-equilibration in combination with adjoint sensitivities,
181+
// we will need to use a separate solver instance because we still need the
182+
// forward solver for each period for backward integration.
183+
auto roots_found = std::vector<int>(model->ne, 0);
184+
model->initialize(
185+
t0, ws_.x, ws_.dx, ws_.sx, ws_.sdx,
186+
solver->getSensitivityOrder() >= SensitivityOrder::first, roots_found
187+
);
188+
solver->setup(t0, model, ws_.x, ws_.dx, ws_.sx, ws_.sdx);
189+
178190
preeq_problem_->workSteadyStateProblem(*solver, *model, -1, t0);
179191

180192
ws_.x = preeq_problem_->getState();
@@ -270,6 +282,10 @@ void ForwardProblem::handlePostequilibration() {
270282
posteq_problem_.emplace(&ws_, *solver, *model);
271283
auto it = getCurrentTimeIteration();
272284
auto t0 = it < 1 ? model->t0() : model->getTimepoint(it - 1);
285+
286+
// The solver was run before, extract current state from solver.
287+
solver->writeSolution(ws_.t, ws_.x, ws_.dx, ws_.sx);
288+
Expects(t0 == ws_.t);
273289
posteq_problem_->workSteadyStateProblem(*solver, *model, it, t0);
274290
}
275291
}

src/steadystateproblem.cpp

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -184,10 +184,11 @@ void SteadystateProblem::workSteadyStateProblem(
184184
);
185185
}
186186

187-
initializeForwardProblem(it, solver, model, t0);
188-
189187
// Compute steady state, track computation time
190188
CpuTimer cpu_timer;
189+
ws_->t = t0;
190+
flagUpdatedState();
191+
newton_solver_.reinitialize();
191192
findSteadyState(solver, model, it, t0);
192193

193194
// Check whether state sensitivities still need to be computed.
@@ -433,29 +434,6 @@ SteadyStateStatus SteadystateProblem::findSteadyStateBySimulation(
433434
}
434435
}
435436

436-
void SteadystateProblem::initializeForwardProblem(
437-
int const it, Solver const& solver, Model& model, realtype const t0
438-
) {
439-
newton_solver_.reinitialize();
440-
441-
// Process solver handling for pre- or postequilibration.
442-
if (it == -1) {
443-
// The solver was not run before, set up everything.
444-
auto roots_found = std::vector<int>(model.ne, 0);
445-
model.initialize(
446-
t0, ws_->x, ws_->dx, ws_->sx, ws_->sdx,
447-
solver.getSensitivityOrder() >= SensitivityOrder::first, roots_found
448-
);
449-
solver.setup(t0, &model, ws_->x, ws_->dx, ws_->sx, ws_->sdx);
450-
} else {
451-
// The solver was run before, extract current state from solver.
452-
solver.writeSolution(ws_->t, ws_->x, ws_->dx, ws_->sx);
453-
}
454-
455-
ws_->t = t0;
456-
flagUpdatedState();
457-
}
458-
459437
void SteadystateProblem::computeSteadyStateQuadrature(
460438
Solver const& solver, Model& model, realtype t0
461439
) {

0 commit comments

Comments
 (0)