Skip to content

Commit ac3334a

Browse files
committed
initial add
1 parent 0e9bc82 commit ac3334a

File tree

3 files changed

+184
-0
lines changed

3 files changed

+184
-0
lines changed

src/rtanalysis/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Analysis of response data to estimate accuracy from response time (RT)."""

src/rtanalysis/generate_testdata.py

+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""Utility module for handling the generation of test data."""
2+
3+
import numpy as np
4+
import pandas as pd
5+
import scipy.stats
6+
7+
8+
def generate_test_df(mean_rt, sd_rt, mean_accuracy, n=100):
9+
"""Generate simulated RT data for testing.
10+
11+
Parameters
12+
----------
13+
mean_rt : float
14+
Mean response time for correct trials
15+
sd_rt : float
16+
Standard deviation of the response time in correct trials
17+
mean_accuracy : float
18+
Mean accuracy across trials (between 0 and 1)
19+
n : int, optional
20+
Number of observations to generate, by default 100
21+
22+
Returns
23+
-------
24+
pd.DataFrame
25+
Generated mock data
26+
"""
27+
rt = pd.Series(scipy.stats.weibull_min.rvs(2, loc=1, size=n))
28+
29+
# get random accuracy values and threshold for intended proportion
30+
accuracy_continuous = np.random.rand(n)
31+
accuracy = pd.Series(
32+
accuracy_continuous
33+
< scipy.stats.scoreatpercentile(accuracy_continuous, 100 * mean_accuracy)
34+
)
35+
36+
# scale the correct RTs only
37+
rt_correct = rt.mask(~accuracy)
38+
rt_scaled = scale_values(rt_correct, mean_rt, sd_rt)
39+
40+
# NB: .where() replaces values where the condition is False
41+
rt_scaled_with_inaccurate_rts = rt_scaled.where(accuracy, rt)
42+
43+
return pd.DataFrame({"rt": rt_scaled_with_inaccurate_rts, "accuracy": accuracy})
44+
45+
46+
def scale_values(values, mean, sd):
47+
"""Scale values by given mean/SD.
48+
49+
Parameters
50+
----------
51+
values : array-like
52+
Values to be scaled
53+
mean : float
54+
Target mean
55+
sd : float
56+
Target standard deviation
57+
58+
Returns
59+
-------
60+
array-like
61+
Scaled values
62+
"""
63+
values = values * (sd / np.std(values))
64+
values = (values - np.mean(values)) + mean
65+
return values

src/rtanalysis/rtanalysis.py

+118
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
"""Example class to analyze reaction times.
2+
3+
Given a data frame with RT and accuracy, compute mean RT for correct trials and
4+
mean accuracy.
5+
"""
6+
7+
import pandas as pd
8+
9+
10+
class RTAnalysis:
11+
"""Response time (RT) analysis."""
12+
13+
def __init__(self, outlier_cutoff_sd=None):
14+
"""Initialize a new RTAnalysis instance.
15+
16+
Parameters
17+
----------
18+
outlier_cutoff_sd : float, optional
19+
Standard deviation cutoff for long RT outliers, by default None
20+
"""
21+
self.outlier_cutoff_sd = outlier_cutoff_sd
22+
self.mean_rt_ = None
23+
self.mean_accuracy_ = None
24+
25+
def fit(self, rt, accuracy, verbose=True):
26+
"""Fit response time to accuracy.
27+
28+
Parameters
29+
----------
30+
rt : pd.Series
31+
Response time per trial
32+
accuracy : pd.Series
33+
Accuracy per trial
34+
verbose : bool, optional
35+
Whether to print verbose output or not, by default True
36+
37+
Raises
38+
------
39+
ValueError
40+
RT/accuracy length mismatch
41+
ValueError
42+
Accuracy is 0
43+
"""
44+
rt = self._ensure_series_type(rt)
45+
accuracy = self._ensure_series_type(accuracy)
46+
47+
self._validate_length(rt, accuracy)
48+
49+
# Ensure that accuracy values are boolean.
50+
assert accuracy.dtype == bool
51+
52+
rt = self.reject_outlier_rt(rt, verbose=verbose)
53+
54+
self.mean_accuracy_ = accuracy.mean()
55+
try:
56+
assert self.mean_accuracy_ > 0
57+
except AssertionError as e:
58+
raise ValueError("Accuracy is zero!") from e
59+
60+
rt = rt.mask(~accuracy)
61+
self.mean_rt_ = rt.mean()
62+
63+
try:
64+
assert rt.min() > 0
65+
except:
66+
raise ValueError("negative response times found")
67+
if verbose:
68+
print(f"mean RT: {self.mean_rt_}")
69+
print(f"mean accuracy: {self.mean_accuracy_}")
70+
71+
@staticmethod
72+
def _validate_length(rt, accuracy):
73+
"""Validate response time and accuracy series lengths.
74+
75+
Parameters
76+
----------
77+
rt : pd.Series
78+
Response time values
79+
accuracy : _type_
80+
Accuracy values
81+
82+
Raises
83+
------
84+
ValueError
85+
Length mismatch
86+
"""
87+
same_length = rt.shape[0] == accuracy.shape[0]
88+
try:
89+
assert same_length
90+
except AssertionError as e:
91+
raise ValueError("RT and accuracy must be the same length!") from e
92+
93+
@staticmethod
94+
def _ensure_series_type(var):
95+
"""Return variable as a pandas Series.
96+
97+
Parameters
98+
----------
99+
var : Iterable
100+
Variable to be converted
101+
102+
Returns
103+
-------
104+
pd.Series
105+
Variable values as a pandas Series
106+
"""
107+
if not isinstance(var, pd.Series):
108+
var = pd.Series(var)
109+
return var
110+
111+
def reject_outlier_rt(self, rt, verbose=True):
112+
if self.outlier_cutoff_sd is None:
113+
return rt
114+
cutoff = rt.std() * self.outlier_cutoff_sd
115+
if verbose:
116+
n_excluded = (rt > cutoff).sum()
117+
print(f"Outlier rejection excluded {n_excluded} trials.")
118+
return rt.mask(rt > cutoff)

0 commit comments

Comments
 (0)