Skip to content

Commit b98c2ea

Browse files
committed
Reduce includes in models
Remove unnecessary includes in the generated code. Reduces serial compilation time by a second or two.
1 parent f67cbcc commit b98c2ea

File tree

2 files changed

+17
-12
lines changed

2 files changed

+17
-12
lines changed

python/sdist/amici/_codegen/cxx_functions.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,20 @@ class _FunctionInfo:
4545
default_return_value: str = ""
4646
header: list[str] = field(default_factory=list)
4747

48+
def __post_init__(self):
49+
common_header = [
50+
'#include "amici/symbolic_functions.h"',
51+
'#include "amici/defines.h"',
52+
"",
53+
# std::{min,find}
54+
"#include <algorithm>",
55+
]
56+
if self.sparse:
57+
common_header.append("#include <sundials/sundials_types.h>")
58+
common_header.append("#include <gsl/gsl-lite.hpp>")
59+
60+
self.header = common_header + self.header
61+
4862
def arguments(self, ode: bool = True) -> str:
4963
"""Get the arguments for the ODE or DAE function"""
5064
if ode or not self.dae_arguments:
@@ -324,6 +338,7 @@ def var_in_signature(self, varname: str, ode: bool = True) -> bool:
324338
"realtype *x0_fixedParameters, const realtype t, "
325339
"const realtype *p, const realtype *k, "
326340
"gsl::span<const int> reinitialization_state_idxs",
341+
header=["#include <gsl/gsl-lite.hpp>"],
327342
),
328343
"sx0": _FunctionInfo(
329344
"realtype *sx0, const realtype t, const realtype *x, "
@@ -333,6 +348,7 @@ def var_in_signature(self, varname: str, ode: bool = True) -> bool:
333348
"realtype *sx0_fixedParameters, const realtype t, "
334349
"const realtype *x0, const realtype *p, const realtype *k, "
335350
"const int ip, gsl::span<const int> reinitialization_state_idxs",
351+
header=["#include <gsl/gsl-lite.hpp>"],
336352
),
337353
"xdot": _FunctionInfo(
338354
"realtype *xdot, const realtype t, const realtype *x, "

python/sdist/amici/de_export.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -459,18 +459,7 @@ def _write_function_file(self, function: str) -> None:
459459
func_info = self.functions[function]
460460

461461
# function header
462-
lines.extend(
463-
[
464-
'#include "amici/symbolic_functions.h"',
465-
'#include "amici/defines.h"',
466-
'#include "sundials/sundials_types.h"',
467-
"",
468-
"#include <gsl/gsl-lite.hpp>",
469-
"#include <algorithm>",
470-
"",
471-
*func_info.header,
472-
]
473-
)
462+
lines.extend(func_info.header)
474463

475464
# extract symbols that need definitions from signature
476465
# don't add includes for files that won't be generated.

0 commit comments

Comments
 (0)