Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions afqinsight/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import pandas as pd
from scipy.optimize import linear_sum_assignment

__all__ = ["mahalonobis_dist_match"]
__all__ = ["mahalanobis_dist_match"]


def mahalonobis_dist_match(
def mahalanobis_dist_match(
data=None, test=None, ctrl=None, status_col=None, feature_cols=None, threshold=0.2
):
"""
Expand Down Expand Up @@ -77,12 +77,12 @@ def mahalonobis_dist_match(
(
"There are NaNs in test or ctrl data. "
"Please replace these NaNs using interpolation or by removing "
"the subjects with NaNs before calling mahalonobis_dist_match. "
"the subjects with NaNs before calling mahalanobis_dist_match. "
)
)

# calculate Mahalonobis distance between test and control
nbrs = _mahalonobis_dist(test, ctrl)
# calculate Mahalanobis distance between test and control
nbrs = _mahalanobis_dist(test, ctrl)

# assign neighbors using Munkres algorithm
row_ind, col_ind = linear_sum_assignment(nbrs)
Expand Down Expand Up @@ -118,9 +118,9 @@ def mahalonobis_dist_match(
return data.iloc[all_idx]


def _mahalonobis_dist(arr1, arr2):
def _mahalanobis_dist(arr1, arr2):
"""
Calculate the Mahalonobis distance between two 2d arrays along the first axis.
Calculate the Mahalanobis distance between two 2d arrays along the first axis.

Parameters
----------
Expand All @@ -132,7 +132,7 @@ def _mahalonobis_dist(arr1, arr2):
Returns
-------
nbrs : array-like of shape (n_samples1, n_samples2)
Mahalonobis distance between each sample in the
Mahalanobis distance between each sample in the
two input arrays.
"""
v_inv = np.linalg.inv(np.cov(np.concatenate((arr1, arr2), axis=0).T, ddof=0))
Expand Down
26 changes: 13 additions & 13 deletions afqinsight/tests/test_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
data[f"feature_{ii}"] = [*test_features[:, ii], *ctrl_features[:, ii]]


def test_mahalonobis_dist():
nbrs = aim._mahalonobis_dist(test_features, ctrl_features)
def test_mahalanobis_dist():
nbrs = aim._mahalanobis_dist(test_features, ctrl_features)

v_inv = np.linalg.inv(
np.cov(np.concatenate((test_features, ctrl_features), axis=0).T, ddof=0)
Expand All @@ -48,8 +48,8 @@ def test_mahalonobis_dist():
assert_array_almost_equal(nbrs, nbrs_scipy)


def test_mahalonobis_dist_match_df():
matched_df = aim.mahalonobis_dist_match(
def test_mahalanobis_dist_match_df():
matched_df = aim.mahalanobis_dist_match(
data=data,
status_col="status",
feature_cols=["feature_0", "feature_1", "feature_2", "feature_3"],
Expand All @@ -64,16 +64,16 @@ def test_mahalonobis_dist_match_df():
)


def test_mahalonobis_dist_match_df_err():
def test_mahalanobis_dist_match_df_err():
with pytest.raises(ValueError): # no status column
aim.mahalonobis_dist_match(data=data)
aim.mahalanobis_dist_match(data=data)
with pytest.raises(ValueError): # status has more than two unique values
aim.mahalonobis_dist_match(data=data, status_col="feature_0")
aim.mahalanobis_dist_match(data=data, status_col="feature_0")


def test_mahalonobis_dist_match_df_feauture_none():
def test_mahalanobis_dist_match_df_feauture_none():
data_wo_extras = data.drop(columns=["eid", "age"])
matched_df = aim.mahalonobis_dist_match(
matched_df = aim.mahalanobis_dist_match(
data=data_wo_extras,
status_col="status",
threshold=1,
Expand All @@ -87,8 +87,8 @@ def test_mahalonobis_dist_match_df_feauture_none():
)


def test_mahalonobis_dist_match():
matched_df = aim.mahalonobis_dist_match(
def test_mahalanobis_dist_match():
matched_df = aim.mahalanobis_dist_match(
test=test_features, ctrl=ctrl_features, threshold=1
)

Expand All @@ -103,10 +103,10 @@ def test_mahalonobis_dist_match():
)


def test_mahalonobis_dist_match_err():
def test_mahalanobis_dist_match_err():
with pytest.raises(ValueError): # test if nan error is raised
test_features_with_nans = test_features.copy()
test_features_with_nans[0, 2] = np.nan
aim.mahalonobis_dist_match(
aim.mahalanobis_dist_match(
test=test_features_with_nans, ctrl=ctrl_features, threshold=1
)
Loading