diff --git a/src/jaxsim/api/integrators.py b/src/jaxsim/api/integrators.py index 71840e03e..494d85c1b 100644 --- a/src/jaxsim/api/integrators.py +++ b/src/jaxsim/api/integrators.py @@ -52,7 +52,6 @@ def semi_implicit_euler_integration( W_Q̇_B = jaxsim.math.Quaternion.derivative( quaternion=data.base_orientation, omega=W_ω_WB, - omega_in_body_fixed=False, ).squeeze() W_p_B = data.base_position + dt * W_ṗ_B diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index 8c8a949b6..1969933f7 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -128,7 +128,6 @@ def system_position_dynamics( W_Q̇_B = Quaternion.derivative( quaternion=W_Q_B, omega=W_ω_WB, - omega_in_body_fixed=False, K=baumgarte_quaternion_regularization, ).squeeze() diff --git a/src/jaxsim/math/quaternion.py b/src/jaxsim/math/quaternion.py index d5a58869c..3c88ebfa5 100644 --- a/src/jaxsim/math/quaternion.py +++ b/src/jaxsim/math/quaternion.py @@ -68,7 +68,6 @@ def from_dcm(dcm: jtp.Matrix) -> jtp.Vector: def derivative( quaternion: jtp.Vector, omega: jtp.Vector, - omega_in_body_fixed: bool = False, K: float = 0.1, ) -> jtp.Vector: """ @@ -77,59 +76,48 @@ def derivative( Args: quaternion: Quaternion in XYZW representation. omega: Angular velocity vector. - omega_in_body_fixed (bool): Whether the angular velocity is in the body-fixed frame. K (float): A scaling factor. Returns: The derivative of the quaternion. """ ω = omega.squeeze() - quaternion = quaternion.squeeze() - - def Q_body(q: jtp.Vector) -> jtp.Matrix: - qw, qx, qy, qz = q - - return jnp.array( - [ - [qw, -qx, -qy, -qz], - [qx, qw, -qz, qy], - [qy, qz, qw, -qx], - [qz, -qy, qx, qw], - ] - ) - - def Q_inertial(q: jtp.Vector) -> jtp.Matrix: - qw, qx, qy, qz = q - - return jnp.array( - [ - [qw, -qx, -qy, -qz], - [qx, qw, qz, -qy], - [qy, -qz, qw, qx], - [qz, qy, -qx, qw], - ] - ) - - Q = jax.lax.cond( - pred=omega_in_body_fixed, - true_fun=Q_body, - false_fun=Q_inertial, - operand=quaternion, + q = quaternion.squeeze() + + # Construct pure quaternion: (scalar damping term, angular velocity components) + ω_quat = jnp.hstack([K * safe_norm(ω) * (1 - safe_norm(quaternion)), ω]) + + # Quaternion multiplication using index tables. + # This approach avoids using the explicit quaternion multiplication formula + # by encoding the necessary element-wise products and signs via indexed operations. + # Given two quaternions q and w, their Hamilton product q ⊗ w can be written + # as a combination of q[i] * w[j] terms with appropriate signs. + # i_idx and j_idx define which elements of the outer product q ⊗ w to select. + # For example, i_idx[1][2] = 2 and j_idx[1][2] = 3 means: take q[2] * w[3] for this term. + # sign_matrix[i][j] gives the sign (+1 or -1) to apply to each q[i] * w[j] term, + # depending on quaternion multiplication rules. + # This indexed summation reproduces the Hamilton product of quaternions in a + # vectorized way, and is suitable for use with JAX. + i_idx = jnp.array([[0, 1, 2, 3], [0, 1, 2, 3], [0, 2, 3, 1], [0, 3, 1, 2]]) + j_idx = jnp.array([[0, 1, 2, 3], [1, 0, 3, 2], [2, 0, 1, 3], [3, 0, 2, 1]]) + sign_matrix = jnp.array( + [ + [1, -1, -1, -1], + [1, 1, 1, -1], + [1, 1, 1, -1], + [1, 1, 1, -1], + ] ) - norm_ω = safe_norm(ω) + # Compute quaternion derivative via Einstein summation + q_outer = jnp.outer(q, ω_quat) - qd = 0.5 * ( - Q - @ jnp.hstack( - [ - K * norm_ω * (1 - safe_norm(quaternion)), - ω, - ] - ) + Qd = jnp.sum( + sign_matrix * q_outer[..., i_idx, j_idx], + axis=-1, ) - return jnp.vstack(qd) + return 0.5 * Qd @staticmethod def integration( diff --git a/tests/test_api_frame.py b/tests/test_api_frame.py index e97030524..2c480598f 100644 --- a/tests/test_api_frame.py +++ b/tests/test_api_frame.py @@ -267,7 +267,6 @@ def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array: W_Q̇_B = Quaternion.derivative( quaternion=data.base_orientation, omega=B_ω_WB, - omega_in_body_fixed=True, K=0.0, ).squeeze() diff --git a/tests/test_api_link.py b/tests/test_api_link.py index ee7b3c965..0f653abd2 100644 --- a/tests/test_api_link.py +++ b/tests/test_api_link.py @@ -357,7 +357,6 @@ def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array: W_Q̇_B = jaxsim.math.Quaternion.derivative( quaternion=data.base_orientation, omega=B_ω_WB, - omega_in_body_fixed=True, K=0.0, ).squeeze() diff --git a/tests/test_api_model.py b/tests/test_api_model.py index 27a0775d5..41993c64c 100644 --- a/tests/test_api_model.py +++ b/tests/test_api_model.py @@ -458,7 +458,6 @@ def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array: W_Q̇_B = jaxsim.math.Quaternion.derivative( quaternion=data.base_orientation, omega=B_ω_WB, - omega_in_body_fixed=True, K=0.0, ).squeeze()