
Conversation

BSnelling (Collaborator)

Implements support for the PEtab SciML test suite: https://github.com/sebapersson/petab_sciml_testsuite/tree/main

Includes support for all `test_net` cases and a subset of `test_ude`. Test cases with frozen layers and networks in the observable formulae are not yet implemented.

@BSnelling BSnelling requested a review from a team as a code owner September 2, 2025 13:33
@BSnelling BSnelling marked this pull request as draft September 2, 2025 13:39

codecov bot commented Sep 2, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 17.87%. Comparing base (f1ece15) to head (e630100).

❗ The number of reports uploaded differs between BASE (f1ece15) and HEAD (e630100).

HEAD has 4 fewer uploads than BASE:

| Flag | BASE (f1ece15) | HEAD (e630100) |
|---|---|---|
| python | 3 | 0 |
| cpp_python | 1 | 0 |
Additional details and impacted files


```diff
@@              Coverage Diff               @@
##           jax_sciml    #2947       +/-   ##
==============================================
- Coverage      41.44%   17.87%   -23.57%
==============================================
  Files            303      104      -199
  Lines          19945    16071     -3874
  Branches        1501     1412       -89
==============================================
- Hits            8266     2873     -5393
- Misses         11654    13198     +1544
+ Partials          25        0       -25
```
| Flag | Coverage Δ |
|---|---|
| cpp_python | ? |
| petab | 15.36% <99.00%> (?) |
| python | ? |
| sbmlsuite-jax | 33.91% <15.00%> (+1.56%) ⬆️ |

Flags with carried forward coverage won't be shown.

| Files with missing lines | Coverage Δ |
|---|---|
| python/sdist/amici/de_model.py | 55.65% <100.00%> (-27.57%) ⬇️ |
| python/sdist/amici/jax/jaxcodeprinter.py | 90.32% <100.00%> (+11.01%) ⬆️ |
| python/sdist/amici/jax/model.py | 68.00% <100.00%> (+3.46%) ⬆️ |
| python/sdist/amici/jax/nn.py | 97.61% <100.00%> (+83.92%) ⬆️ |
| python/sdist/amici/jax/ode_export.py | 88.33% <ø> (+5.00%) ⬆️ |
| python/sdist/amici/jax/petab.py | 73.95% <100.00%> (+55.49%) ⬆️ |
| python/sdist/amici/petab/parameter_mapping.py | 44.81% <100.00%> (-19.23%) ⬇️ |
| python/sdist/amici/petab/petab_import.py | 92.30% <100.00%> (-7.70%) ⬇️ |

... and 268 files with indirect coverage changes


```python
][:],
dtype=jnp.float64,
),
)  # ?? hardcoded dtype not ideal ?? could infer from env somehow ??
```
Member:

Is it really necessary to set dtype here? Usually jax infers float precision from https://docs.jax.dev/en/latest/config_options.html. It might be necessary to cast this to a numpy array first if conversion from hdf5 is the problem.

Collaborator Author:

I tried casting to a numpy array, but the array persisted as float32. I've now defined the dtype based on the current jax config settings, which I think is better than hardcoding.
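A minimal sketch of that approach, assuming the weights are read from an h5py dataset (the file path and dataset name below are hypothetical):

```python
import h5py
import jax
import jax.numpy as jnp

# Select the float dtype matching the active jax configuration instead of
# hardcoding float64: with jax_enable_x64 enabled, use float64, else float32.
float_dtype = jnp.float64 if jax.config.jax_enable_x64 else jnp.float32

with h5py.File("weights.h5", "r") as f:  # hypothetical file/dataset layout
    weights = jnp.asarray(f["layer0/weight"][:], dtype=float_dtype)
```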

```python
petab.NOMINAL_VALUE,
],
)
if "input"
```
Member:

This check seems a bit too unspecific. I think you want to construct a sequence of PEtab IDs that are mapped to `$nnId.inputs[$inputArgumentIndex][$inputIndex]`?
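A rough sketch of the suggested check, assuming the networks and their input arguments are available as Python objects (`networks`, `net.inputs`, and the exact ID pattern are illustrative, not the PR's actual data structures):

```python
def nn_input_ids(networks: dict) -> set[str]:
    """Enumerate the exact PEtab IDs that reference network inputs,
    rather than substring-matching on "input"."""
    return {
        f"{nn_id}.inputs[{arg_idx}][{inp_idx}]"
        for nn_id, net in networks.items()         # assumed: network ID -> spec
        for arg_idx, arg in enumerate(net.inputs)  # assumed: list of input arguments
        for inp_idx in range(len(arg))             # assumed: flat input entries
    }

# usage: if petab_id in nn_input_ids(networks): ...
```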

Collaborator Author:

I've updated this. The complication is that the values here could be pulled either from a nominal value in the parameter table or from a value in the conditions table, depending on whether the ID appears in the parameter table. That's my understanding, anyway.
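A minimal sketch of that resolution logic, assuming pandas-style PEtab tables indexed by ID (the function name and arguments are illustrative):

```python
import pandas as pd
import petab

def resolve_input_value(
    input_id: str,
    parameter_df: pd.DataFrame,
    condition_df: pd.DataFrame,
    condition_id: str,
) -> float:
    """Resolve a network input value: prefer the nominal value from the
    parameter table; otherwise fall back to the conditions table."""
    if input_id in parameter_df.index:
        return parameter_df.loc[input_id, petab.NOMINAL_VALUE]
    return condition_df.loc[condition_id, input_id]
```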

```python
    dfs.append(df_sc)
    return pd.concat(dfs).sort_index()

def apply_grad_filter(problem: JAXProblem):
```
Member:

This is a great solution! The only thing that I am a bit worried about is that in the current implementation `stop_gradient` is only applied when calling problem methods in the context of `run_simulations`, which may lead to confusion when trying to compute gradients outside of that context. My interpretation of the PEtab problem definition is that setting `estimate=0` means that gradient computation is permanently disabled, so we should apply `apply_grad_filter` during `JAXProblem` instantiation.
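A minimal sketch of the idea, assuming parameters live in a flat array alongside a boolean estimate mask (the names and signature are illustrative, not the PR's actual API):

```python
import jax
import jax.numpy as jnp

def filter_gradients(parameters: jnp.ndarray, estimate: jnp.ndarray) -> jnp.ndarray:
    """Block gradient flow for parameters with estimate=0 by routing them
    through stop_gradient; estimated parameters pass through unchanged."""
    return jnp.where(estimate, parameters, jax.lax.stop_gradient(parameters))
```

Applying the filter at `JAXProblem` instantiation, as suggested, would make the `estimate=0` semantics hold wherever gradients are computed, not only inside `run_simulations`.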

@FFroehlich (Member)

Just checking test failures:

  • Notebook tests: also fail on the base branch, albeit for a different reason, so this is not related to the changes here
  • macOS: this looks like an issue with PRs from a fork
  • doc tests: also a problem on the base branch; h5py is missing from the doc requirements
  • sbml jax: unrelated, also failing on the base branch

@FFroehlich (Member)

> macOS: this looks like an issue with PRs from a fork

This is probably not related to PRs from forks, but rather to CMake: #2949 (review)

@FFroehlich (Member)

Just updated the base branch (hopefully without messing up any of the merge conflict resolutions); this should fix the failing macOS tests.
