Skip to content

Add methods to average quaternions #76

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ dependencies = [
[project.optional-dependencies]
mapping = [
"scipy>=1.7",
"scikit-learn"
]
tests = [
"pytest",
]

[project.urls]
Expand All @@ -44,11 +48,11 @@ ignore = [
"COM812", # Conflicts with ruff formatter
"B904", # We don't always want to raise from another exception
"PT011", # Checking exception messages is undesirable
"S101", # Useful in testing
]

[tool.ruff.lint.pydocstyle]
convention = "google"

[tool.ruff.lint.per-file-ignores]
"tests/*" = ["S101"]
"benchmarks/*.ipynb" = ["D103", "E501", "C400"]
2 changes: 2 additions & 0 deletions rowan/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
log,
log10,
logb,
mean,
multiply,
norm,
normalize,
Expand Down Expand Up @@ -81,6 +82,7 @@
"log",
"logb",
"log10",
"mean",
"multiply",
"norm",
"normalize",
Expand Down
38 changes: 38 additions & 0 deletions rowan/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,44 @@ def power(q, n):
return powers


def mean(q, weights=None):
r"""Compute the mean of an array of quaternions.

This algorithm is based on :cite:`Markley 2007`, and computes the (weighted)
average quaternion of an input array via a maximum likelihood method. Intuitively,
this method computes the member of SO(3) that is "most likely" to represent the
inputs. For a more rigorous understanding, consult the original work.

Args:
q ((:, 4) :class:`numpy.ndarray`): Array of quaternions.
weights ((:,) :class:`numpy.ndarray`) | None:
Scalar weight for each quaternion. Default value = None, which assumes
all weights are 1.

Returns:
(4, ) :class:`numpy.ndarray`: Mean of ``q``.

Example::

>>> rowan.mean([[1, 0, 0, 0], [-1, 0, 0, 0]])
array([1, 0, 0, 0])
"""
# NOTE: Markley takes quaternions as columns [xyzw].T, so transposes are flipped
q = np.atleast_2d(q)
if len(q.shape) != 2:
raise ValueError("Mean must be taken along the 0th axis of an (N, 4) array.")

M = (q.T @ q) if weights is None else (q.T @ (q * weights[:, None]))
np.testing.assert_allclose(
M,
M.T,
atol=1e-12,
err_msg="Matrix is not symmetric! eigh is not valid for this calculation",
)
# TODO: should we update/refine the method with Weiszfelt?
return np.linalg.eigh(M)[1][:, -1] # eigh returns eigenvectors sorted by eigenvalue


def conjugate(q):
r"""Conjugates an array of quaternions.

Expand Down
3 changes: 1 addition & 2 deletions rowan/mapping/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,8 +458,7 @@ def icp( # noqa: C901
except ImportError:
raise ImportError(
"Running without unique_match requires "
"scikit-learn. Please install sklearn and try "
"again.",
"scikit-learn. Please install scikit-learn and try again.",
)

# Copy points so we have originals available.
Expand Down
30 changes: 30 additions & 0 deletions tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,33 @@ def test_isnan(self):
quats = np.random.random_sample(shape)
assert rowan.isnan(quats).shape == quats.shape[:-1]
assert not np.any(rowan.isnan(quats))

def test_mean(self, N=128):
"""Test mean taken between quaternions."""
rng = np.random.default_rng(seed=0)
qs = rowan.random.rand(N)
# Verify mean of one quaternion (or duplicates of the same quat) == q_input
for q in qs:
for n in [1, 2, 3]:
assert rowan.isclose(q, rowan.mean([q] * n)) or rowan.isclose(
q, -rowan.mean([q] * n)
)
assert rowan.isclose(
q, rowan.mean([q] * n, weights=np.ones(n))
) or rowan.isclose(q, -rowan.mean([q] * n, weights=np.ones(n)))

def mean_two_quats(q0, q1, w0=1, w1=1):
"""Compute the maximum-likelihood mean of two quaternions in closed form."""
z = np.sqrt(np.square(w0 - w1) + 4 * w0 * w1 * np.square(np.dot(q0, q1)))
s0 = np.sqrt((w0 * (w0 - w1 + z)) / (z * (w0 + w1 + z)))
s1 = np.sqrt((w1 * (w1 - w0 + z)) / (z * (w0 + w1 + z)))
return s0 * q0 + np.sign(np.dot(q0, q1)) * s1 * q1

# Split list of quaternions in half and zip into pairs
for w in [np.ones(2), rng.random(2)]:
for q0, q1 in zip(qs[: N // 2, :], qs[N // 2 :, :]):
assert rowan.isclose(
rowan.mean([q0, q1], weights=w), mean_two_quats(q0, q1, w[0], w[1])
) or rowan.isclose(
rowan.mean([q0, q1], weights=w), -mean_two_quats(q0, q1, w[0], w[1])
)