Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 85 additions & 75 deletions tests/test_api_model_hw_parametrization.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,13 @@ def test_model_scaling_against_rod(
assert jnp.allclose(scaled_metadata.L_H_G, pre_scaled_metadata.L_H_G, atol=1e-6)
assert jnp.allclose(scaled_metadata.L_H_vis, pre_scaled_metadata.L_H_vis, atol=1e-6)

# Compare collidable points positions
assert jnp.allclose(
jaxsim_model_garpez_scaled.kin_dyn_parameters.contact_parameters.point,
updated_model.kin_dyn_parameters.contact_parameters.point,
atol=1e-6,
)


def test_update_hw_parameters_vmap(
jaxsim_model_garpez: js.model.JaxSimModel,
Expand Down Expand Up @@ -209,94 +216,97 @@ def test_export_updated_model(
),
density=jnp.ones(4),
)

# Update the model with the scaling parameters
updated_model: js.model.JaxSimModel = js.model.update_hw_parameters(
model, scaling_parameters
identity_scaling = ScalingFactors(
dims=jnp.ones((model.number_of_links(), 3)),
density=jnp.ones(model.number_of_links()),
)

# Export the updated model
exported_model_urdf = updated_model.export_updated_model()
assert isinstance(exported_model_urdf, str), "Exported model URDF is not a string."

# Convert the URDF string to a ROD model
exported_model_sdf = rod.Sdf.load(exported_model_urdf, is_urdf=True)
assert isinstance(
exported_model_sdf, rod.Sdf
), "Failed to load exported model as ROD Sdf."
assert (
len(exported_model_sdf.models()) == 1
), "Exported ROD model does not contain exactly one model."
exported_model_rod = exported_model_sdf.models()[0]

# Get the pre-scaled ROD model
pre_scaled_model_rod = rod.Sdf.load(jaxsim_model_garpez_scaled.built_from).models()[
0
]
assert isinstance(
pre_scaled_model_rod, rod.Model
), "Failed to load pre-scaled model as ROD Model."

# Validate that the exported model matches the pre-scaled model
for link_idx, link_name in enumerate(model.link_names()):
def get_link_by_name(model, name):
try:
exported_link = next(
link for link in exported_model_rod.links() if link.name == link_name
)
except StopIteration:
return next(link for link in model.links() if link.name == name)
except StopIteration as err:
raise ValueError(
f"Link '{link_name}' not found in exported model. "
f"Available links: {[link.name for link in exported_model_rod.links()]}"
) from None
f"Link '{name}' not found. Available links: {[l.name for l in model.links()]}"
) from err

pre_scaled_link = next(
link for link in pre_scaled_model_rod.links() if link.name == link_name
)
def compare_geometries(exported_link, ref_link, label=""):
exported_geom = exported_link.visual.geometry.geometry()
ref_geom = ref_link.visual.geometry.geometry()

# Compare shape dimensions
exported_geometry = exported_link.visual.geometry.geometry()
pre_scaled_geometry = pre_scaled_link.visual.geometry.geometry()

# Ensure both geometries have the same attributes for comparison
exported_values = jnp.array(
[
getattr(exported_geometry, attr, 0)
for attr in vars(exported_geometry)
if hasattr(pre_scaled_geometry, attr)
]
)
pre_scaled_values = jnp.array(
[
getattr(pre_scaled_geometry, attr, 0)
for attr in vars(pre_scaled_geometry)
if hasattr(exported_geometry, attr)
]
)

assert jnp.allclose(exported_values, pre_scaled_values, atol=1e-6), (
f"Mismatch in geometry dimensions for link {link_name}: "
f"expected {pre_scaled_values}, got {exported_values}"
)
attrs = [attr for attr in vars(exported_geom) if hasattr(ref_geom, attr)]
exported_vals = jnp.array([getattr(exported_geom, attr) for attr in attrs])
ref_vals = jnp.array([getattr(ref_geom, attr) for attr in attrs])
assert jnp.allclose(
exported_vals, ref_vals, atol=1e-6
), f"Geometry mismatch in {label} model."

# Compare mass
def compare_mass_and_inertia(exported_link, ref_link, label=""):
assert exported_link.inertial.mass == pytest.approx(
pre_scaled_link.inertial.mass, abs=1e-4
), (
f"Mismatch in mass for link {link_name}: "
f"expected {pre_scaled_link.inertial.mass}, got {exported_link.inertial.mass}"
)

# Compare inertia tensors
ref_link.inertial.mass, abs=1e-4
), f"Mass mismatch in {label} model."
assert jnp.allclose(
exported_link.inertial.inertia.matrix(),
pre_scaled_link.inertial.inertia.matrix(),
ref_link.inertial.inertia.matrix(),
atol=1e-4,
), (
f"Mismatch in inertia tensor for link {link_name}: "
f"expected {pre_scaled_link.inertial.inertia.matrix()}, "
f"got {exported_link.inertial.inertia.matrix()}"
), f"Inertia matrix mismatch in {label} model."

def compare_collisions(exported_link, ref_link, label=""):
geom_types = ["box", "sphere", "cylinder"]
for geom_type in geom_types:
exp_geom = getattr(exported_link.collision.geometry, geom_type)
ref_geom = getattr(ref_link.collision.geometry, geom_type)
if ref_geom is not None:
if geom_type == "box":
assert jnp.allclose(
jnp.array(exp_geom.size), jnp.array(ref_geom.size), atol=1e-6
)
elif geom_type == "sphere":
assert jnp.isclose(exp_geom.radius, ref_geom.radius, atol=1e-6)
elif geom_type == "cylinder":
assert jnp.isclose(exp_geom.radius, ref_geom.radius, atol=1e-6)
assert jnp.isclose(exp_geom.length, ref_geom.length, atol=1e-6)
return
pytest.skip(
f"Collision geometry type for link {exported_link.name} not supported."
)

def validate_model(updated_model, ref_model, label):

urdf = updated_model.export_updated_model()
assert isinstance(urdf, str), f"{label}: Exported URDF is not a string."

exported_sdf = rod.Sdf.load(urdf, is_urdf=True)
assert (
len(exported_sdf.models()) == 1
), f"{label}: Exported model does not contain exactly one ROD model."

exported_model = exported_sdf.models()[0]

for link_name in model.link_names():

exported_link = get_link_by_name(exported_model, link_name)
ref_link = get_link_by_name(ref_model, link_name)

compare_geometries(exported_link, ref_link, label=label)

compare_mass_and_inertia(exported_link, ref_link, label=label)

compare_collisions(exported_link, ref_link, label=label)

# Test both scaled and identity-scaled updates
for scaling, label in [
(scaling_parameters, "SCALED"),
(identity_scaling, "IDENTITY SCALED"),
]:
# Load reference ROD model
if label == "IDENTITY SCALED":
ref_model = rod.Sdf.load(jaxsim_model_garpez.built_from).models()[0]
else:
ref_model = rod.Sdf.load(jaxsim_model_garpez_scaled.built_from).models()[0]

updated_model = js.model.update_hw_parameters(model, scaling)
validate_model(updated_model, ref_model, label)


def test_hw_parameters_optimization(jaxsim_model_garpez: js.model.JaxSimModel):
"""
Expand Down
Loading