Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def build(
The optional name of the model overriding the physics model name.
contact_model:
The contact model to consider.
If not specified, a relaxed-constraints rigid contacts model is used.
If not specified, a soft contact model is used.
contact_params: The parameters of the contact model.
actuation_params: The parameters of the actuation model.
integrator: The integrator to use for the simulation.
Expand Down Expand Up @@ -282,7 +282,7 @@ def build(
contact_model = (
contact_model
if contact_model is not None
else jaxsim.rbda.contacts.RelaxedRigidContacts.build()
else jaxsim.rbda.contacts.SoftContacts.build()
)

if contact_params is None:
Expand Down Expand Up @@ -2344,9 +2344,6 @@ def update_hw_parameters(

Returns:
The updated JaxSimModel object with modified hardware parameters.

Note:
This function can be used only with models using Relax-Rigid contact model.
"""

kin_dyn_params: KinDynParameters = model.kin_dyn_parameters
Expand Down
4 changes: 0 additions & 4 deletions src/jaxsim/rbda/contacts/soft.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,16 +130,12 @@ class SoftContacts(common.ContactModel):
@classmethod
def build(
cls: type[Self],
model: js.model.JaxSimModel | None = None,
**kwargs,
) -> Self:
"""
Create a `SoftContacts` instance with specified parameters.

Args:
model:
The robot model considered by the contact model.
If passed, it is used to estimate good default parameters.
**kwargs: Additional parameters to pass to the contact model.

Returns:
Expand Down
21 changes: 17 additions & 4 deletions tests/test_api_model_hw_parametrization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import jaxsim.api as js
from jaxsim.api.kin_dyn_parameters import HwLinkMetadata, ScalingFactors
from jaxsim.rbda.contacts import SoftContactsParams


def test_update_hw_link_parameters(jaxsim_model_garpez: js.model.JaxSimModel):
Expand Down Expand Up @@ -401,6 +402,18 @@ def test_hw_parameters_collision_scaling(
# Define the scaling factor for the model
scaling_factor = 5.0

# Recompute K and D, since the mass is scaled by scaling_factor^3
# and the expected static compression of the terrain is approximately
# proportional to mass/K and divided by the 4 contact points.
K = model.contact_params.K * (scaling_factor**2)

# Strongly overdamped, to avoid oscillations due to the high mass
# and the low penetration allowed.
D = 8 * jnp.sqrt(K)

with model.editable(validate=False) as model:
model.contact_params = SoftContactsParams(K=K, D=D)

# Define the nominal radius of the sphere
nominal_height = model.kin_dyn_parameters.hw_link_metadata.geometry[0, 2]

Expand All @@ -413,6 +426,9 @@ def test_hw_parameters_collision_scaling(
# Update the model with the scaling parameters
updated_model = js.model.update_hw_parameters(model, scaling_parameters)

# Compute the expected height (nominal radius * scaling factor)
expected_height = nominal_height * scaling_factor / 2

# Simulate the box falling under gravity
data = js.data.JaxSimModelData.build(
model=updated_model,
Expand All @@ -424,7 +440,7 @@ def test_hw_parameters_collision_scaling(
base_position=jnp.array(
[
*jax.random.uniform(subkey, shape=(2,)),
nominal_height * scaling_factor + 0.01,
expected_height + 0.05,
]
),
)
Expand All @@ -440,9 +456,6 @@ def test_hw_parameters_collision_scaling(
# Get the final height of the box's base
updated_base_height = data.base_position[2]

# Compute the expected height (nominal radius * scaling factor)
expected_height = nominal_height * scaling_factor / 2

# Assert that the box settles at the expected height
assert jnp.isclose(
updated_base_height, expected_height, atol=1e-3
Expand Down
2 changes: 1 addition & 1 deletion tests/test_automatic_differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def test_ad_soft_contacts(
model = jaxsim_models_types

with model.editable(validate=False) as model:
model.contact_model = jaxsim.rbda.contacts.SoftContacts.build(model=model)
model.contact_model = jaxsim.rbda.contacts.SoftContacts.build()

_, subkey1, subkey2, subkey3 = jax.random.split(prng_key, num=4)
p = jax.random.uniform(subkey1, shape=(3,), minval=-1)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,7 @@ def test_simulation_with_kinematic_constraints_4_bar_linkage(
data_t0 = js.data.JaxSimModelData.build(
model=model,
velocity_representation=VelRepr.Inertial,
base_position=jnp.array([0.0, 0.0, 0.10]),
)

# ====
Expand Down Expand Up @@ -607,7 +608,7 @@ def test_simulation_with_kinematic_constraints_4_bar_linkage(
pos1 = H_frame1[:3, 3]
pos2 = H_frame2[:3, 3]
assert pos1 == pytest.approx(
pos2, abs=1e-6
pos2, abs=1e-5
), f"Frame position mismatch. pos1={pos1}, pos2={pos2}, diff={pos1 - pos2}"

# Orientation check
Expand Down