Skip to content

Commit b112466

Browse files
authored
Refactor SteadyStateProblem / FwdSimWorkspace (AMICI-dev#2835)
Use `FwdSimWorkspace` in `SteadyStateProblem` instead of its own `SimulationParameters`. Required to implement event-handling during pre-equilibration (AMICI-dev#2775) later on.
1 parent 56fb8cc commit b112466

File tree

4 files changed

+66
-46
lines changed

4 files changed

+66
-46
lines changed

include/amici/forwardproblem.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ struct FwdSimWorkspace {
127127
, rootvals(gsl::narrow<decltype(rootvals)::size_type>(model->ne), 0.0)
128128

129129
{}
130+
/** current simulation time */
131+
realtype t{NAN};
130132

131133
/** state vector (dimension: nx_solver) */
132134
AmiVector x;

include/amici/steadystateproblem.h

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ class ExpData;
1212
class Solver;
1313
class Model;
1414
class BackwardProblem;
15+
struct FwdSimWorkspace;
1516

1617
/**
1718
* @brief Computes the weighted root-mean-square norm.
@@ -218,10 +219,13 @@ class SteadystateProblem {
218219
/**
219220
* @brief Constructor
220221
*
222+
* @param ws Workspace for forward simulation
221223
* @param solver Solver instance
222224
* @param model Model instance
223225
*/
224-
explicit SteadystateProblem(Solver const& solver, Model& model);
226+
explicit SteadystateProblem(
227+
FwdSimWorkspace* ws, Solver const& solver, Model& model
228+
);
225229

226230
/**
227231
* @brief Compute the steady state in the forward case.
@@ -260,7 +264,7 @@ class SteadystateProblem {
260264
* @return stored SimulationState
261265
*/
262266
[[nodiscard]] SimulationState const& getFinalSimulationState() const {
263-
return state_;
267+
return final_state_;
264268
}
265269

266270
/**
@@ -274,14 +278,14 @@ class SteadystateProblem {
274278
* @brief Return state at steady state
275279
* @return x
276280
*/
277-
[[nodiscard]] AmiVector const& getState() const { return state_.x; }
281+
[[nodiscard]] AmiVector const& getState() const { return final_state_.x; }
278282

279283
/**
280284
* @brief Return state sensitivity at steady state
281285
* @return sx
282286
*/
283287
[[nodiscard]] AmiVectorArray const& getStateSensitivity() const {
284-
return state_.sx;
288+
return final_state_.sx;
285289
}
286290

287291
/**
@@ -310,7 +314,7 @@ class SteadystateProblem {
310314
* @brief Get model time at which steady state was found through simulation.
311315
* @return Time at which steady state was found (model time units).
312316
*/
313-
[[nodiscard]] realtype getSteadyStateTime() const { return state_.t; }
317+
[[nodiscard]] realtype getSteadyStateTime() const { return final_state_.t; }
314318

315319
/**
316320
* @brief Get the weighted root mean square of the residuals.
@@ -496,12 +500,10 @@ class SteadystateProblem {
496500
*/
497501
void updateRightHandSide(Model& model);
498502

503+
/** Workspace for forward simulation */
504+
FwdSimWorkspace* ws_;
499505
/** WRMS computer for x */
500506
WRMSComputer wrms_computer_x_;
501-
/** time derivative state vector */
502-
AmiVector xdot_;
503-
/** state differential sensitivities */
504-
AmiVectorArray sdx_;
505507
/** adjoint state vector */
506508
AmiVector xB_;
507509
/** integral over adjoint state vector */
@@ -511,7 +513,8 @@ class SteadystateProblem {
511513
/** weighted root-mean-square error */
512514
realtype wrms_{NAN};
513515

514-
SimulationState state_;
516+
/** The simulation state at the end of the forward problem. */
517+
SimulationState final_state_;
515518

516519
/** stores diagnostic information about employed number of steps */
517520
std::vector<int> numsteps_{std::vector<int>(3, 0)};

src/forwardproblem.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ void ForwardProblem::handlePreequilibration() {
173173

174174
ConditionContext cc2(model, edata, FixedParameterContext::preequilibration);
175175

176-
preeq_problem_.emplace(*solver, *model);
176+
preeq_problem_.emplace(&ws_, *solver, *model);
177177
auto t0 = std::isnan(model->t0Preeq()) ? model->t0() : model->t0Preeq();
178178
preeq_problem_->workSteadyStateProblem(*solver, *model, -1, t0);
179179

@@ -267,7 +267,7 @@ void ForwardProblem::handleMainSimulation() {
267267

268268
void ForwardProblem::handlePostequilibration() {
269269
if (getCurrentTimeIteration() < model->nt()) {
270-
posteq_problem_.emplace(*solver, *model);
270+
posteq_problem_.emplace(&ws_, *solver, *model);
271271
auto it = getCurrentTimeIteration();
272272
auto t0 = it < 1 ? model->t0() : model->getTimepoint(it - 1);
273273
posteq_problem_->workSteadyStateProblem(*solver, *model, it, t0);

src/steadystateproblem.cpp

Lines changed: 49 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -115,19 +115,20 @@ void computeQBfromQ(
115115
}
116116
}
117117

118-
SteadystateProblem::SteadystateProblem(Solver const& solver, Model& model)
119-
: wrms_computer_x_(
118+
SteadystateProblem::SteadystateProblem(
119+
FwdSimWorkspace* ws, Solver const& solver, Model& model
120+
)
121+
: ws_(ws)
122+
, wrms_computer_x_(
120123
model.nx_solver, solver.getSunContext(),
121124
solver.getAbsoluteToleranceSteadyState(),
122125
solver.getRelativeToleranceSteadyState(),
123126
AmiVector(model.get_steadystate_mask(), solver.getSunContext())
124127
)
125-
, xdot_(model.nx_solver, solver.getSunContext())
126-
, sdx_(model.nx_solver, model.nplist(), solver.getSunContext())
127128
, xB_(model.nJ * model.nx_solver, solver.getSunContext())
128129
, xQ_(model.nJ * model.nx_solver, solver.getSunContext())
129130
, xQB_(model.nplist(), solver.getSunContext())
130-
, state_(
131+
, final_state_(
131132
{.t = INFINITY,
132133
.x = AmiVector(model.nx_solver, solver.getSunContext()),
133134
.dx = AmiVector(model.nx_solver, solver.getSunContext()),
@@ -222,7 +223,7 @@ void SteadystateProblem::workSteadyStateProblem(
222223
// This might still fail, if the Jacobian is singular and
223224
// simulation did not find a steady state.
224225
newton_solver_.computeNewtonSensis(
225-
state_.sx, model, {state_.t, state_.x, state_.dx}
226+
ws_->sx, model, {ws_->t, ws_->x, ws_->dx}
226227
);
227228
} catch (NewtonFailure const&) {
228229
throw AmiException(
@@ -232,6 +233,11 @@ void SteadystateProblem::workSteadyStateProblem(
232233
}
233234
}
234235
cpu_time_ = cpu_timer.elapsed_milliseconds();
236+
final_state_.state = model.getModelState();
237+
final_state_.t = ws_->t;
238+
final_state_.x = ws_->x;
239+
final_state_.dx = ws_->dx;
240+
final_state_.sx = ws_->sx;
235241
}
236242

237243
void SteadystateProblem::workSteadyStateBackwardProblem(
@@ -328,7 +334,7 @@ void SteadystateProblem::findSteadyStateByNewtonsMethod(
328334
try {
329335
updateRightHandSide(model);
330336
newtons_method_.run(
331-
xdot_, {state_.t, state_.x, state_.dx}, wrms_computer_x_
337+
ws_->xdot, {ws_->t, ws_->x, ws_->dx}, wrms_computer_x_
332338
);
333339
steady_state_status_[stage] = SteadyStateStatus::success;
334340
} catch (NewtonFailure const& ex) {
@@ -381,7 +387,7 @@ SteadyStateStatus SteadystateProblem::findSteadyStateBySimulation(
381387
sim_solver->setSensitivityMethod(SensitivityMethod::none);
382388
sim_solver->setSensitivityOrder(SensitivityOrder::none);
383389
}
384-
sim_solver->setup(t0, &model, state_.x, state_.dx, state_.sx, sdx_);
390+
sim_solver->setup(t0, &model, ws_->x, ws_->dx, ws_->sx, ws_->sdx);
385391
runSteadystateSimulationFwd(*sim_solver, model);
386392
} else {
387393
// Postequilibration -> Solver was already created, use that one
@@ -437,17 +443,16 @@ void SteadystateProblem::initializeForwardProblem(
437443
// The solver was not run before, set up everything.
438444
auto roots_found = std::vector<int>(model.ne, 0);
439445
model.initialize(
440-
t0, state_.x, state_.dx, state_.sx, sdx_,
446+
t0, ws_->x, ws_->dx, ws_->sx, ws_->sdx,
441447
solver.getSensitivityOrder() >= SensitivityOrder::first, roots_found
442448
);
443-
solver.setup(t0, &model, state_.x, state_.dx, state_.sx, sdx_);
449+
solver.setup(t0, &model, ws_->x, ws_->dx, ws_->sx, ws_->sdx);
444450
} else {
445451
// The solver was run before, extract current state from solver.
446-
solver.writeSolution(state_.t, state_.x, state_.dx, state_.sx);
452+
solver.writeSolution(ws_->t, ws_->x, ws_->dx, ws_->sx);
447453
}
448454

449-
state_.t = t0;
450-
state_.state = model.getModelState();
455+
ws_->t = t0;
451456
flagUpdatedState();
452457
}
453458

@@ -502,11 +507,13 @@ void SteadystateProblem::getQuadratureByLinSolve(Model& model) {
502507
try {
503508
// compute integral over xB and write to xQ
504509
newton_solver_.prepareLinearSystemB(
505-
model, {state_.t, state_.x, state_.dx}
510+
model, {final_state_.t, final_state_.x, final_state_.dx}
506511
);
507512
newton_solver_.solveLinearSystem(xQ_);
508513
// Compute the quadrature as the inner product xQ * dxdotdp
509-
computeQBfromQ(model, xQ_, xQB_, {state_.t, state_.x, state_.dx});
514+
computeQBfromQ(
515+
model, xQ_, xQB_, {final_state_.t, final_state_.x, final_state_.dx}
516+
);
510517
hasQuadrature_ = true;
511518

512519
// Finalize by setting adjoint state to zero (its steady state)
@@ -524,17 +531,17 @@ void SteadystateProblem::getQuadratureBySimulation(
524531
// x is not time-dependent, no forward trajectory is needed.
525532

526533
// Set starting timepoint for the simulation solver
527-
state_.t = t0;
534+
final_state_.t = t0;
528535
// xQ was written in getQuadratureByLinSolve() -> set to zero
529536
xQ_.zero();
530537

531538
auto sim_solver = std::unique_ptr<Solver>(solver.clone());
532539
sim_solver->logger = solver.logger;
533540
sim_solver->setSensitivityMethod(SensitivityMethod::none);
534541
sim_solver->setSensitivityOrder(SensitivityOrder::none);
535-
sim_solver->setup(t0, &model, xB_, xB_, state_.sx, sdx_);
542+
sim_solver->setup(t0, &model, xB_, xB_, final_state_.sx, ws_->sdx);
536543
sim_solver->setupSteadystate(
537-
t0, &model, state_.x, state_.dx, xB_, xB_, xQ_
544+
t0, &model, final_state_.x, final_state_.dx, xB_, xB_, xQ_
538545
);
539546

540547
// perform integration and quadrature
@@ -570,11 +577,11 @@ realtype SteadystateProblem::getWrmsState(Model& model) {
570577
updateRightHandSide(model);
571578

572579
if (newton_step_conv_) {
573-
newtons_method_.compute_step(xdot_, {state_.t, state_.x, state_.dx});
574-
return wrms_computer_x_.wrms(newtons_method_.get_delta(), state_.x);
580+
newtons_method_.compute_step(ws_->xdot, {ws_->t, ws_->x, ws_->dx});
581+
return wrms_computer_x_.wrms(newtons_method_.get_delta(), ws_->x);
575582
}
576583

577-
return wrms_computer_x_.wrms(xdot_, state_.x);
584+
return wrms_computer_x_.wrms(ws_->xdot, ws_->x);
578585
}
579586

580587
realtype
@@ -589,12 +596,12 @@ SteadystateProblem::getWrmsFSA(Model& model, WRMSComputer& wrms_computer_sx) {
589596
xdot_updated_ = false;
590597
for (int ip = 0; ip < model.nplist(); ++ip) {
591598
model.fsxdot(
592-
state_.t, state_.x, state_.dx, ip, state_.sx[ip], state_.dx, xdot_
599+
ws_->t, ws_->x, ws_->dx, ip, ws_->sx[ip], ws_->dx, ws_->xdot
593600
);
594601
if (newton_step_conv_) {
595-
newton_solver_.solveLinearSystem(xdot_);
602+
newton_solver_.solveLinearSystem(ws_->xdot);
596603
}
597-
wrms = wrms_computer_sx.wrms(xdot_, state_.sx[ip]);
604+
wrms = wrms_computer_sx.wrms(ws_->xdot, ws_->sx[ip]);
598605
// ideally this function would report the maximum of all wrms over
599606
// all ip, but for practical purposes we can just report the wrms for
600607
// the first ip where we know that the convergence threshold is not
@@ -673,7 +680,7 @@ void SteadystateProblem::runSteadystateSimulationFwd(
673680

674681
// check for maxsteps
675682
if (sim_steps >= solver.getMaxSteps()) {
676-
throw IntegrationFailure(AMICI_TOO_MUCH_WORK, state_.t);
683+
throw IntegrationFailure(AMICI_TOO_MUCH_WORK, ws_->t);
677684
}
678685

679686
// increase counter
@@ -687,9 +694,9 @@ void SteadystateProblem::runSteadystateSimulationFwd(
687694
// ensure stable computation.
688695
// The value is not important for AMICI_ONE_STEP mode, only the
689696
// direction w.r.t. current t.
690-
solver.step(std::max(state_.t, 1.0) * 10);
697+
solver.step(std::max(ws_->t, 1.0) * 10);
691698

692-
solver.writeSolution(state_.t, state_.x, state_.dx, state_.sx);
699+
solver.writeSolution(ws_->t, ws_->x, ws_->dx, ws_->sx);
693700
flagUpdatedState();
694701
}
695702

@@ -738,8 +745,14 @@ void SteadystateProblem::runSteadystateSimulationBwd(
738745
// exact steadystate is less important, as xB = xQdot may even not
739746
// converge to zero at all. So we need xQBdot, hence compute xQB
740747
// first.
741-
computeQBfromQ(model, xQ_, xQB_, {state_.t, state_.x, state_.dx});
742-
computeQBfromQ(model, xB_, xQBdot, {state_.t, state_.x, state_.dx});
748+
computeQBfromQ(
749+
model, xQ_, xQB_,
750+
{final_state_.t, final_state_.x, final_state_.dx}
751+
);
752+
computeQBfromQ(
753+
model, xB_, xQBdot,
754+
{final_state_.t, final_state_.x, final_state_.dx}
755+
);
743756
wrms_ = wrms_computer_xQB_.wrms(xQBdot, xQB_);
744757
if (wrms_ < conv_thresh) {
745758
break; // converged
@@ -748,7 +761,7 @@ void SteadystateProblem::runSteadystateSimulationBwd(
748761

749762
// check for maxsteps
750763
if (sim_steps >= max_steps) {
751-
throw IntegrationFailureB(AMICI_TOO_MUCH_WORK, state_.t);
764+
throw IntegrationFailureB(AMICI_TOO_MUCH_WORK, final_state_.t);
752765
}
753766

754767
// increase counter
@@ -762,9 +775,11 @@ void SteadystateProblem::runSteadystateSimulationBwd(
762775
// ensure stable computation.
763776
// The value is not important for AMICI_ONE_STEP mode, only the
764777
// direction w.r.t. current t.
765-
solver.step(std::max(state_.t, 1.0) * 10);
778+
solver.step(std::max(final_state_.t, 1.0) * 10);
766779

767-
solver.writeSolution(&state_.t, xB_, state_.dx, state_.sx, xQ_);
780+
solver.writeSolution(
781+
&final_state_.t, xB_, final_state_.dx, final_state_.sx, xQ_
782+
);
768783
}
769784
}
770785

@@ -776,14 +791,14 @@ void SteadystateProblem::flagUpdatedState() {
776791
void SteadystateProblem::updateSensiSimulation(Solver const& solver) {
777792
if (sensis_updated_)
778793
return;
779-
state_.sx = solver.getStateSensitivity(state_.t);
794+
ws_->sx = solver.getStateSensitivity(ws_->t);
780795
sensis_updated_ = true;
781796
}
782797

783798
void SteadystateProblem::updateRightHandSide(Model& model) {
784799
if (xdot_updated_)
785800
return;
786-
model.fxdot(state_.t, state_.x, state_.dx, xdot_);
801+
model.fxdot(ws_->t, ws_->x, ws_->dx, ws_->xdot);
787802
xdot_updated_ = true;
788803
}
789804

0 commit comments

Comments
 (0)