Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/jaxsim/api/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def semi_implicit_euler_integration(
W_Q̇_B = jaxsim.math.Quaternion.derivative(
quaternion=data.base_orientation,
omega=W_ω_WB,
omega_in_body_fixed=False,
).squeeze()

W_p_B = data.base_position + dt * W_ṗ_B
Expand Down
1 change: 0 additions & 1 deletion src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ def system_position_dynamics(
W_Q̇_B = Quaternion.derivative(
quaternion=W_Q_B,
omega=W_ω_WB,
omega_in_body_fixed=False,
K=baumgarte_quaternion_regularization,
).squeeze()

Expand Down
74 changes: 31 additions & 43 deletions src/jaxsim/math/quaternion.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def from_dcm(dcm: jtp.Matrix) -> jtp.Vector:
def derivative(
quaternion: jtp.Vector,
omega: jtp.Vector,
omega_in_body_fixed: bool = False,
K: float = 0.1,
) -> jtp.Vector:
"""
Expand All @@ -77,59 +76,48 @@ def derivative(
Args:
quaternion: Quaternion in XYZW representation.
omega: Angular velocity vector.
omega_in_body_fixed (bool): Whether the angular velocity is in the body-fixed frame.
K (float): A scaling factor.

Returns:
The derivative of the quaternion.
"""
ω = omega.squeeze()
quaternion = quaternion.squeeze()

def Q_body(q: jtp.Vector) -> jtp.Matrix:
qw, qx, qy, qz = q

return jnp.array(
[
[qw, -qx, -qy, -qz],
[qx, qw, -qz, qy],
[qy, qz, qw, -qx],
[qz, -qy, qx, qw],
]
)

def Q_inertial(q: jtp.Vector) -> jtp.Matrix:
qw, qx, qy, qz = q

return jnp.array(
[
[qw, -qx, -qy, -qz],
[qx, qw, qz, -qy],
[qy, -qz, qw, qx],
[qz, qy, -qx, qw],
]
)

Q = jax.lax.cond(
pred=omega_in_body_fixed,
true_fun=Q_body,
false_fun=Q_inertial,
operand=quaternion,
q = quaternion.squeeze()

# Construct pure quaternion: (scalar damping term, angular velocity components)
ω_quat = jnp.hstack([K * safe_norm(ω) * (1 - safe_norm(quaternion)), ω])

# Quaternion multiplication using index tables.
# This approach avoids using the explicit quaternion multiplication formula
# by encoding the necessary element-wise products and signs via indexed operations.
# Given two quaternions q and w, their Hamilton product q ⊗ w can be written
# as a combination of q[i] * w[j] terms with appropriate signs.
# i_idx and j_idx define which elements of the outer product q ⊗ w to select.
# For example, i_idx[1][2] = 2 and j_idx[1][2] = 3 means: take q[2] * w[3] for this term.
# sign_matrix[i][j] gives the sign (+1 or -1) to apply to each q[i] * w[j] term,
# depending on quaternion multiplication rules.
# This indexed summation reproduces the Hamilton product of quaternions in a
# vectorized way, and is suitable for use with JAX.
i_idx = jnp.array([[0, 1, 2, 3], [0, 1, 2, 3], [0, 2, 3, 1], [0, 3, 1, 2]])
j_idx = jnp.array([[0, 1, 2, 3], [1, 0, 3, 2], [2, 0, 1, 3], [3, 0, 2, 1]])
sign_matrix = jnp.array(
[
[1, -1, -1, -1],
[1, 1, 1, -1],
[1, 1, 1, -1],
[1, 1, 1, -1],
]
)

norm_ω = safe_norm(ω)
# Compute quaternion derivative via Einstein summation
q_outer = jnp.outer(q, ω_quat)

qd = 0.5 * (
Q
@ jnp.hstack(
[
K * norm_ω * (1 - safe_norm(quaternion)),
ω,
]
)
Qd = jnp.sum(
sign_matrix * q_outer[..., i_idx, j_idx],
axis=-1,
)

return jnp.vstack(qd)
return 0.5 * Qd

@staticmethod
def integration(
Expand Down
1 change: 0 additions & 1 deletion tests/test_api_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,6 @@ def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array:
W_Q̇_B = Quaternion.derivative(
quaternion=data.base_orientation,
omega=B_ω_WB,
omega_in_body_fixed=True,
K=0.0,
).squeeze()

Expand Down
1 change: 0 additions & 1 deletion tests/test_api_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,6 @@ def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array:
W_Q̇_B = jaxsim.math.Quaternion.derivative(
quaternion=data.base_orientation,
omega=B_ω_WB,
omega_in_body_fixed=True,
K=0.0,
).squeeze()

Expand Down
1 change: 0 additions & 1 deletion tests/test_api_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,6 @@ def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array:
W_Q̇_B = jaxsim.math.Quaternion.derivative(
quaternion=data.base_orientation,
omega=B_ω_WB,
omega_in_body_fixed=True,
K=0.0,
).squeeze()

Expand Down