diff --git a/include/amici/abstract_model.h b/include/amici/abstract_model.h index b24b9a2e43..32841f2888 100644 --- a/include/amici/abstract_model.h +++ b/include/amici/abstract_model.h @@ -2,14 +2,18 @@ #define AMICI_ABSTRACT_MODEL_H #include "amici/defines.h" -#include "amici/splinefunctions.h" -#include "amici/sundials_matrix_wrapper.h" -#include "amici/vector.h" + +#include +#include #include +#include namespace amici { class Solver; +class HermiteSpline; +class SUNMatrixWrapper; +class AmiVector; /** * @brief Abstract base class of amici::Model defining functions that need to diff --git a/include/amici/model.h b/include/amici/model.h index 5a032d765f..6ee7ef2ab4 100644 --- a/include/amici/model.h +++ b/include/amici/model.h @@ -4,13 +4,11 @@ #include "amici/abstract_model.h" #include "amici/defines.h" #include "amici/event.h" -#include "amici/logging.h" #include "amici/model_dimensions.h" #include "amici/model_state.h" #include "amici/simulation_parameters.h" #include "amici/splinefunctions.h" #include "amici/sundials_matrix_wrapper.h" -#include "amici/vector.h" #include #include @@ -20,6 +18,9 @@ namespace amici { class ExpData; class Model; class Solver; +class Logger; +class AmiVector; +class AmiVectorArray; } // namespace amici diff --git a/python/sdist/amici/_codegen/cxx_functions.py b/python/sdist/amici/_codegen/cxx_functions.py index a90829090d..574093e60a 100644 --- a/python/sdist/amici/_codegen/cxx_functions.py +++ b/python/sdist/amici/_codegen/cxx_functions.py @@ -45,6 +45,20 @@ class _FunctionInfo: default_return_value: str = "" header: list[str] = field(default_factory=list) + def __post_init__(self): + common_header = [ + '#include "amici/symbolic_functions.h"', + '#include "amici/defines.h"', + "", + # std::{min,find} + "#include ", + ] + if self.sparse: + common_header.append("#include ") + common_header.append("#include ") + + self.header = common_header + self.header + def arguments(self, ode: bool = True) -> str: """Get the arguments for the ODE or DAE function""" if ode or not self.dae_arguments: @@ -324,6 +338,7 @@ def var_in_signature(self, varname: str, ode: bool = True) -> bool: "realtype *x0_fixedParameters, const realtype t, " "const realtype *p, const realtype *k, " "gsl::span reinitialization_state_idxs", + header=["#include "], ), "sx0": _FunctionInfo( "realtype *sx0, const realtype t, const realtype *x, " @@ -333,6 +348,7 @@ def var_in_signature(self, varname: str, ode: bool = True) -> bool: "realtype *sx0_fixedParameters, const realtype t, " "const realtype *x0, const realtype *p, const realtype *k, " "const int ip, gsl::span reinitialization_state_idxs", + header=["#include "], ), "xdot": _FunctionInfo( "realtype *xdot, const realtype t, const realtype *x, " diff --git a/python/sdist/amici/de_export.py b/python/sdist/amici/de_export.py index 90ca718319..34c1cb0bbc 100644 --- a/python/sdist/amici/de_export.py +++ b/python/sdist/amici/de_export.py @@ -423,18 +423,7 @@ def _write_function_file(self, function: str) -> None: func_info = self.functions[function] # function header - lines.extend( - [ - '#include "amici/symbolic_functions.h"', - '#include "amici/defines.h"', - '#include "sundials/sundials_types.h"', - "", - "#include ", - "#include ", - "", - *func_info.header, - ] - ) + lines.extend(func_info.header) # extract symbols that need definitions from signature # don't add includes for files that won't be generated. diff --git a/src/abstract_model.cpp b/src/abstract_model.cpp index f7db90f2c8..48cf39dad9 100644 --- a/src/abstract_model.cpp +++ b/src/abstract_model.cpp @@ -1,4 +1,6 @@ #include "amici/abstract_model.h" +#include "amici/exception.h" +#include "amici/splinefunctions.h" namespace amici {