diff --git a/tests/test_api_model_hw_parametrization.py b/tests/test_api_model_hw_parametrization.py index 9d389710a..ab5d2f49e 100644 --- a/tests/test_api_model_hw_parametrization.py +++ b/tests/test_api_model_hw_parametrization.py @@ -118,6 +118,13 @@ def test_model_scaling_against_rod( assert jnp.allclose(scaled_metadata.L_H_G, pre_scaled_metadata.L_H_G, atol=1e-6) assert jnp.allclose(scaled_metadata.L_H_vis, pre_scaled_metadata.L_H_vis, atol=1e-6) + # Compare collidable points positions + assert jnp.allclose( + jaxsim_model_garpez_scaled.kin_dyn_parameters.contact_parameters.point, + updated_model.kin_dyn_parameters.contact_parameters.point, + atol=1e-6, + ) + def test_update_hw_parameters_vmap( jaxsim_model_garpez: js.model.JaxSimModel, @@ -209,94 +216,97 @@ def test_export_updated_model( ), density=jnp.ones(4), ) - - # Update the model with the scaling parameters - updated_model: js.model.JaxSimModel = js.model.update_hw_parameters( - model, scaling_parameters + identity_scaling = ScalingFactors( + dims=jnp.ones((model.number_of_links(), 3)), + density=jnp.ones(model.number_of_links()), ) - # Export the updated model - exported_model_urdf = updated_model.export_updated_model() - assert isinstance(exported_model_urdf, str), "Exported model URDF is not a string." - - # Convert the URDF string to a ROD model - exported_model_sdf = rod.Sdf.load(exported_model_urdf, is_urdf=True) - assert isinstance( - exported_model_sdf, rod.Sdf - ), "Failed to load exported model as ROD Sdf." - assert ( - len(exported_model_sdf.models()) == 1 - ), "Exported ROD model does not contain exactly one model." - exported_model_rod = exported_model_sdf.models()[0] - - # Get the pre-scaled ROD model - pre_scaled_model_rod = rod.Sdf.load(jaxsim_model_garpez_scaled.built_from).models()[ - 0 - ] - assert isinstance( - pre_scaled_model_rod, rod.Model - ), "Failed to load pre-scaled model as ROD Model." - - # Validate that the exported model matches the pre-scaled model - for link_idx, link_name in enumerate(model.link_names()): + def get_link_by_name(model, name): try: - exported_link = next( - link for link in exported_model_rod.links() if link.name == link_name - ) - except StopIteration: + return next(link for link in model.links() if link.name == name) + except StopIteration as err: raise ValueError( - f"Link '{link_name}' not found in exported model. " - f"Available links: {[link.name for link in exported_model_rod.links()]}" - ) from None + f"Link '{name}' not found. Available links: {[l.name for l in model.links()]}" + ) from err - pre_scaled_link = next( - link for link in pre_scaled_model_rod.links() if link.name == link_name - ) + def compare_geometries(exported_link, ref_link, label=""): + exported_geom = exported_link.visual.geometry.geometry() + ref_geom = ref_link.visual.geometry.geometry() - # Compare shape dimensions - exported_geometry = exported_link.visual.geometry.geometry() - pre_scaled_geometry = pre_scaled_link.visual.geometry.geometry() - - # Ensure both geometries have the same attributes for comparison - exported_values = jnp.array( - [ - getattr(exported_geometry, attr, 0) - for attr in vars(exported_geometry) - if hasattr(pre_scaled_geometry, attr) - ] - ) - pre_scaled_values = jnp.array( - [ - getattr(pre_scaled_geometry, attr, 0) - for attr in vars(pre_scaled_geometry) - if hasattr(exported_geometry, attr) - ] - ) - - assert jnp.allclose(exported_values, pre_scaled_values, atol=1e-6), ( - f"Mismatch in geometry dimensions for link {link_name}: " - f"expected {pre_scaled_values}, got {exported_values}" - ) + attrs = [attr for attr in vars(exported_geom) if hasattr(ref_geom, attr)] + exported_vals = jnp.array([getattr(exported_geom, attr) for attr in attrs]) + ref_vals = jnp.array([getattr(ref_geom, attr) for attr in attrs]) + assert jnp.allclose( + exported_vals, ref_vals, atol=1e-6 + ), f"Geometry mismatch in {label} model." - # Compare mass + def compare_mass_and_inertia(exported_link, ref_link, label=""): assert exported_link.inertial.mass == pytest.approx( - pre_scaled_link.inertial.mass, abs=1e-4 - ), ( - f"Mismatch in mass for link {link_name}: " - f"expected {pre_scaled_link.inertial.mass}, got {exported_link.inertial.mass}" - ) - - # Compare inertia tensors + ref_link.inertial.mass, abs=1e-4 + ), f"Mass mismatch in {label} model." assert jnp.allclose( exported_link.inertial.inertia.matrix(), - pre_scaled_link.inertial.inertia.matrix(), + ref_link.inertial.inertia.matrix(), atol=1e-4, - ), ( - f"Mismatch in inertia tensor for link {link_name}: " - f"expected {pre_scaled_link.inertial.inertia.matrix()}, " - f"got {exported_link.inertial.inertia.matrix()}" + ), f"Inertia matrix mismatch in {label} model." + + def compare_collisions(exported_link, ref_link, label=""): + geom_types = ["box", "sphere", "cylinder"] + for geom_type in geom_types: + exp_geom = getattr(exported_link.collision.geometry, geom_type) + ref_geom = getattr(ref_link.collision.geometry, geom_type) + if ref_geom is not None: + if geom_type == "box": + assert jnp.allclose( + jnp.array(exp_geom.size), jnp.array(ref_geom.size), atol=1e-6 + ) + elif geom_type == "sphere": + assert jnp.isclose(exp_geom.radius, ref_geom.radius, atol=1e-6) + elif geom_type == "cylinder": + assert jnp.isclose(exp_geom.radius, ref_geom.radius, atol=1e-6) + assert jnp.isclose(exp_geom.length, ref_geom.length, atol=1e-6) + return + pytest.skip( + f"Collision geometry type for link {exported_link.name} not supported." ) + def validate_model(updated_model, ref_model, label): + + urdf = updated_model.export_updated_model() + assert isinstance(urdf, str), f"{label}: Exported URDF is not a string." + + exported_sdf = rod.Sdf.load(urdf, is_urdf=True) + assert ( + len(exported_sdf.models()) == 1 + ), f"{label}: Exported model does not contain exactly one ROD model." + + exported_model = exported_sdf.models()[0] + + for link_name in model.link_names(): + + exported_link = get_link_by_name(exported_model, link_name) + ref_link = get_link_by_name(ref_model, link_name) + + compare_geometries(exported_link, ref_link, label=label) + + compare_mass_and_inertia(exported_link, ref_link, label=label) + + compare_collisions(exported_link, ref_link, label=label) + + # Test both scaled and identity-scaled updates + for scaling, label in [ + (scaling_parameters, "SCALED"), + (identity_scaling, "IDENTITY SCALED"), + ]: + # Load reference ROD model + if label == "IDENTITY SCALED": + ref_model = rod.Sdf.load(jaxsim_model_garpez.built_from).models()[0] + else: + ref_model = rod.Sdf.load(jaxsim_model_garpez_scaled.built_from).models()[0] + + updated_model = js.model.update_hw_parameters(model, scaling) + validate_model(updated_model, ref_model, label) + def test_hw_parameters_optimization(jaxsim_model_garpez: js.model.JaxSimModel): """