diff --git a/src/jabs/behavior_search/behavior_search_util.py b/src/jabs/behavior_search/behavior_search_util.py index 5334b5ac..324780eb 100644 --- a/src/jabs/behavior_search/behavior_search_util.py +++ b/src/jabs/behavior_search/behavior_search_util.py @@ -114,7 +114,7 @@ def _search_behaviors_gen( ) case PredictionBehaviorSearchQuery() as pred_query: - proj_settings = project.settings_manager.project_settings + proj_settings = project.settings_manager.project_info if pred_query.behavior_label is None: behavior_dict = proj_settings.get("behavior", {}) behaviors = list(behavior_dict.keys()) diff --git a/src/jabs/classifier/classifier.py b/src/jabs/classifier/classifier.py index a38bfed3..5317b61c 100644 --- a/src/jabs/classifier/classifier.py +++ b/src/jabs/classifier/classifier.py @@ -6,22 +6,30 @@ from pathlib import Path import joblib +import matplotlib +import matplotlib.pyplot as plt import numpy as np import pandas as pd +from sklearn.calibration import CalibratedClassifierCV, calibration_curve from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier from sklearn.exceptions import InconsistentVersionWarning from sklearn.metrics import ( accuracy_score, + brier_score_loss, confusion_matrix, precision_recall_fscore_support, ) from sklearn.model_selection import LeaveOneGroupOut, train_test_split +from jabs.constants import DEFAULT_CALIBRATION_CV, DEFAULT_CALIBRATION_METHOD from jabs.project import Project, TrackLabels, load_training_data from jabs.types import ClassifierType from jabs.utils import hash_file -_VERSION = 9 +matplotlib.use("Agg") # use non-GUI backend to avoid thread warnings + + +_VERSION = 10 _classifier_choices = [ClassifierType.RANDOM_FOREST, ClassifierType.GRADIENT_BOOSTING] @@ -30,11 +38,11 @@ # we were able to import xgboost, make it available as an option: _classifier_choices.append(ClassifierType.XGBOOST) except Exception: - # we were unable to import the xgboost module. It's either not - # installed (it should be if the user used our requirements-old.txt) - # or it may have been unable to be imported due to a missing - # libomp. Either way, we won't add it to the available choices and - # we can otherwise ignore this exception + # we were unable to import the xgboost module -- possibly due to a missing + # libomp (which is not available by default on macOS). Mac users should + # install libomp via Homebrew (brew install libomp) to enable XGBoost support (this is + # detailed in the installation instructions). + # we won't add it to the available choices and we can otherwise ignore this exception _xgboost = None @@ -50,6 +58,8 @@ class Classifier: """ LABEL_THRESHOLD = 20 + TRUE_THRESHOLD = 0.5 + CALIBRATION_METHODS: typing.ClassVar[list[str]] = ["auto", "isotonic", "sigmoid"] _CLASSIFIER_NAMES: typing.ClassVar[dict] = { ClassifierType.RANDOM_FOREST: "Random Forest", @@ -60,7 +70,8 @@ class Classifier: def __init__(self, classifier=ClassifierType.RANDOM_FOREST, n_jobs=1): self._classifier_type = classifier self._classifier = None - self._project_settings = None + self._behavior_settings = None + self._jabs_settings = None self._behavior = None self._feature_names = None self._n_jobs = n_jobs @@ -91,7 +102,9 @@ def from_training_file(cls, path: Path): classifier = cls() classifier.behavior_name = behavior - classifier.set_dict_settings(loaded_training_data["settings"]) + classifier.set_behavior_settings(loaded_training_data["behavior_settings"]) + classifier._jabs_settings = loaded_training_data["jabs_settings"] + classifier_type = ClassifierType(loaded_training_data["classifier_type"]) if classifier_type in classifier.classifier_choices(): classifier.set_classifier(classifier_type) @@ -99,6 +112,7 @@ def from_training_file(cls, path: Path): print( f"Specified classifier type {classifier_type.name} is unavailable, using default: {classifier.classifier_type.name}" ) + training_features = classifier.combine_data( loaded_training_data["per_frame"], loaded_training_data["window"] ) @@ -141,10 +155,10 @@ def classifier_hash(self) -> str: return "NO HASH" @property - def project_settings(self) -> dict: - """return a copy of dictionary of project settings for this classifier""" - if self._project_settings is not None: - return dict(self._project_settings) + def behavior_settings(self) -> dict: + """return a copy of dictionary of behavior-specific settings for this classifier""" + if self._behavior_settings is not None: + return dict(self._behavior_settings) return {} @property @@ -167,6 +181,45 @@ def feature_names(self) -> list: """returns the list of feature names used when training this classifier""" return self._feature_names + @property + def calibrate_probabilities(self) -> bool: + """return whether the classifier is set to calibrate probabilities""" + if self._jabs_settings is not None: + return self._jabs_settings.get("calibrate_probabilities", False) + return False + + @staticmethod + def _choose_auto_calibration_method( + labels: np.ndarray, calibration_cv: int + ) -> tuple[str, dict]: + """Choose 'isotonic' or 'sigmoid' based on data size per calibration fold. + + Heuristic: + - Compute class counts on the *training set labels* passed in. + - Estimate per-fold calibration set size as min(pos, neg) / calibration_cv + (because CalibratedClassifierCV uses 1/cv of the train split for calibration). + - If per-fold per-class counts >= 500 ➜ 'isotonic', else 'sigmoid'. + + Returns: + (method, info_dict) where info_dict contains counts used for logging. + """ + # count positive and negative labels + pos = int(np.sum(labels == TrackLabels.Label.BEHAVIOR)) + neg = int(np.sum(labels == TrackLabels.Label.NOT_BEHAVIOR)) + min_per_class = min(pos, neg) + per_fold_per_class = max(0, min_per_class // calibration_cv) + + # Threshold for isotonic safety + threshold = 500 + method = "isotonic" if per_fold_per_class >= threshold else "sigmoid" + return method, { + "pos_total": pos, + "neg_total": neg, + "cv": calibration_cv, + "per_fold_per_class": per_fold_per_class, + "threshold": threshold, + } + @staticmethod def train_test_split(per_frame_features, window_features, label_data): """split features and labels into training and test datasets @@ -361,19 +414,22 @@ def set_project_settings(self, project: Project): if no behavior is currently set will use project defaults """ if self._behavior is None: - self._project_settings = project.get_project_defaults() + self._behavior_settings = project.get_project_defaults() else: - self._project_settings = project.settings_manager.get_behavior(self._behavior) + self._behavior_settings = project.settings_manager.get_behavior(self._behavior) - def set_dict_settings(self, settings: dict): - """assign project settings via a dict to the classifier + # grab other JABS settings from settings manager, some might be used by the classifier + self._jabs_settings = project.settings_manager.jabs_settings + + def set_behavior_settings(self, settings: dict): + """assign behavior-specific settings via a dict to the classifier Args: settings: dict of project settings. Must be same structure as project.settings_manager.get_behavior TODO: Add checks to enforce conformity to project settings """ - self._project_settings = dict(settings) + self._behavior_settings = dict(settings) def classifier_choices(self): """get the available classifier types @@ -390,7 +446,7 @@ def classifier_choices(self): """ return {d: self._CLASSIFIER_NAMES[d] for d in _classifier_choices} - def train(self, data, random_seed: int | None = None): + def train(self, data: dict, random_seed: int | None = None) -> None: """train the classifier Args: @@ -403,7 +459,7 @@ def train(self, data, random_seed: int | None = None): raises ValueError for having either unset project settings or an unset classifier """ - if self._project_settings is None: + if self._behavior_settings is None: raise ValueError("Project settings for classifier unset, cannot train classifier.") # Assume that feature names is provided, otherwise extract it from the dataframe @@ -416,32 +472,116 @@ def train(self, data, random_seed: int | None = None): features = data["training_data"] labels = data["training_labels"] # Symmetric augmentation should occur before balancing so that the class with more labels can sample from the whole set - if self._project_settings.get("symmetric_behavior", False): + if self._behavior_settings.get("symmetric_behavior", False): features, labels = self.augment_symmetric(features, labels) - if self._project_settings.get("balance_labels", False): + if self._behavior_settings.get("balance_labels", False): features, labels = self.downsample_balance(features, labels, random_seed) - if self._classifier_type == ClassifierType.RANDOM_FOREST: - self._classifier = self._fit_random_forest(features, labels, random_seed=random_seed) - elif self._classifier_type == ClassifierType.GRADIENT_BOOSTING: - self._classifier = self._fit_gradient_boost(features, labels, random_seed=random_seed) - elif _xgboost is not None and self._classifier_type == ClassifierType.XGBOOST: + # Optional probability calibration + if self.calibrate_probabilities: + # get and validate calibration settings + calibration_method = self._jabs_settings.get( + "calibration_method", DEFAULT_CALIBRATION_METHOD + ) + if calibration_method.lower() not in self.CALIBRATION_METHODS: + raise ValueError( + f"Invalid calibration method: {calibration_method}. Must be one of {self.CALIBRATION_METHODS}" + ) + calibration_cv = self._jabs_settings.get("calibration_cv", DEFAULT_CALIBRATION_CV) + + # Auto-select method if requested, always figure out what the auto method would be because some of the + # selection info is still useful for warnings/logging purposes if the user specified a method explicitly + auto_method, auto_method_info = self._choose_auto_calibration_method( + labels, calibration_cv + ) + if calibration_method.lower() == "auto": + calibration_method = auto_method + else: + # Optional safety warning: isotonic with small per-fold sets can overfit + if ( + str(calibration_method).lower() == "isotonic" + and auto_method_info["per_fold_per_class"] < auto_method_info["threshold"] + ): + warnings.warn( + ( + "Isotonic calibration selected but per-fold per-class count appears small " + f"(~{auto_method_info['per_fold_per_class']}). Consider 'sigmoid' or lowering calibration_cv." + ), + RuntimeWarning, + stacklevel=2, + ) + + # Build an unfitted base estimator + if self._classifier_type == ClassifierType.RANDOM_FOREST: + base_estimator = self._make_random_forest(random_seed=random_seed) + elif self._classifier_type == ClassifierType.GRADIENT_BOOSTING: + base_estimator = self._make_gradient_boost(random_seed=random_seed) + elif _xgboost is not None and self._classifier_type == ClassifierType.XGBOOST: + base_estimator = self._make_xgboost(random_seed=random_seed) + else: + raise ValueError("Unsupported classifier") + + # Wrap with calibrated classifier and fit + self._classifier = CalibratedClassifierCV( + estimator=base_estimator, method=calibration_method, cv=calibration_cv + ) with warnings.catch_warnings(): warnings.simplefilter("ignore", category=FutureWarning) - self._classifier = self._fit_xgboost(features, labels, random_seed=random_seed) + self._classifier.fit(self._clean_features_for_training(features), labels) else: - raise ValueError("Unsupported classifier") + # Fit without calibration (original behavior) + if self._classifier_type == ClassifierType.RANDOM_FOREST: + self._classifier = self._fit_random_forest( + features, labels, random_seed=random_seed + ) + elif self._classifier_type == ClassifierType.GRADIENT_BOOSTING: + self._classifier = self._fit_gradient_boost( + features, labels, random_seed=random_seed + ) + elif _xgboost is not None and self._classifier_type == ClassifierType.XGBOOST: + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=FutureWarning) + self._classifier = self._fit_xgboost(features, labels, random_seed=random_seed) + else: + raise ValueError("Unsupported classifier") # Classifier may have been re-used from a prior training, blank the logging attributes self._classifier_file = None self._classifier_hash = None self._classifier_source = None + def _clean_features_for_training(self, features: pd.DataFrame): + """Clean feature matrix prior to fitting based on classifier type. + + For XGBoost, only replace +/- inf with 0 (XGBoost can handle NaN). + For sklearn tree models, also fill NaNs with 0. + """ + if self._classifier_type == ClassifierType.XGBOOST: + return features.replace([np.inf, -np.inf], 0) + return features.replace([np.inf, -np.inf], 0).fillna(0) + + def _make_random_forest(self, random_seed: int | None = None): + if random_seed is not None: + return RandomForestClassifier(n_jobs=self._n_jobs, random_state=random_seed) + return RandomForestClassifier(n_jobs=self._n_jobs) + + def _make_gradient_boost(self, random_seed: int | None = None): + if random_seed is not None: + return GradientBoostingClassifier(random_state=random_seed) + return GradientBoostingClassifier() + + def _make_xgboost(self, random_seed: int | None = None): + if random_seed is not None: + return _xgboost.XGBClassifier(n_jobs=self._n_jobs, random_state=random_seed) + return _xgboost.XGBClassifier(n_jobs=self._n_jobs) + def sort_features_to_classify(self, features): """sorts features to match the current classifier""" - if self._classifier_type == ClassifierType.XGBOOST: + if isinstance(self._classifier, CalibratedClassifierCV): + # Use the training-time feature order we stored + classifier_columns = self._feature_names + elif self._classifier_type == ClassifierType.XGBOOST: classifier_columns = self._classifier.get_booster().feature_names - # sklearn places feature names in the same spot else: classifier_columns = self._classifier.feature_names_in_ features_sorted = features[classifier_columns] @@ -518,7 +658,8 @@ def load(self, path: Path): self._classifier = c._classifier self._behavior = c._behavior - self._project_settings = c._project_settings + self._behavior_settings = c._behavior_settings + self._jabs_settings = c._jabs_settings self._classifier_type = c._classifier_type if c._classifier_file is not None: self._classifier_file = c._classifier_file @@ -554,6 +695,100 @@ def confusion_matrix(truth, predictions): """return the confusion matrix using sklearn's confusion_matrix function""" return confusion_matrix(truth, predictions) + @staticmethod + def brier_score(truth: np.ndarray, probabilities: np.ndarray) -> float: + """Return the Brier score (lower is better). + + Args: + truth (ndarray): array of true binary labels (0/1). + probabilities (ndarray): array of predicted probabilities for the positive class; can be shape (n_samples,) + or a (n_samples, 2) array from `predict_proba`. + + Returns: + float Brier score. + """ + if probabilities.ndim == 2: + # assume columns [P(neg), P(pos)] as returned by predict_proba + probabilities = probabilities[:, 1] + return brier_score_loss(truth, probabilities) + + @staticmethod + def plot_reliability( + truth: np.ndarray, + probabilities: np.ndarray, + out_path: Path | str, + n_bins: int = 10, + strategy: str = "uniform", + title: str | None = None, + show_hist: bool = True, + ) -> dict: + """Create and save a reliability (calibration) plot. + + Args: + truth: Binary ground truth labels (0 or 1). + probabilities: Predicted probabilities (2D array where second column is positive class). + out_path: File path to save the reliability plot. + n_bins: Number of bins for calibration curve. + strategy: Binning strategy ('uniform' or 'quantile'). + title: Optional plot title. + show_hist: If True, adds a histogram of predicted probabilities below the curve. + + Returns: + Dict with calibration data: 'bins', 'mean_pred', 'frac_pos', and 'counts'. + """ + prob = probabilities[:, 1] + y = np.asarray(truth).astype(int) + + pos = int(np.sum(y == 1)) + neg = int(np.sum(y == 0)) + if pos == 0 or neg == 0: + warnings.warn( + "plot_reliability: need both positive and negative labels.", stacklevel=2 + ) + + # Compute calibration curve + frac_pos, mean_pred = calibration_curve(y, prob, n_bins=n_bins, strategy=strategy) + + # Bin edges and counts + if strategy == "uniform": + bins = np.linspace(0.0, 1.0, n_bins + 1) + counts, _ = np.histogram(prob, bins=bins) + else: + q = np.linspace(0.0, 1.0, n_bins + 1) + bins = np.quantile(prob, q) + bins = np.unique(bins) + counts, _ = np.histogram(prob, bins=bins) + + # Plot + fig, ax = plt.subplots(figsize=(6.5, 4.5)) + ax.plot([0, 1], [0, 1], "--", color="gray", label="Perfect calibration") + ax.plot(mean_pred, frac_pos, marker="o", color="C0", label="Model") + ax.set_xlabel("Predicted probability") + ax.set_ylabel("Empirical frequency") + if title: + ax.set_title(title) + ax.legend(loc="best") + + if show_hist: + ax_hist = ax.twinx() + ax_hist.set_ylim(0, max(counts) * 1.2 if counts.size else 1) + display_bins = bins if strategy == "uniform" else 10 + ax_hist.hist(prob, bins=display_bins, alpha=0.25, color="C1") + ax_hist.set_yticks([]) + + fig.tight_layout() + out_path = Path(out_path) + out_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(out_path, dpi=150) + plt.close(fig) + + return { + "bins": bins, + "mean_pred": mean_pred, + "frac_pos": frac_pos, + "counts": counts, + } + @staticmethod def combine_data(per_frame, window): """combine feature sets together @@ -568,27 +803,67 @@ def combine_data(per_frame, window): return pd.concat([per_frame, window], axis=1) def _fit_random_forest(self, features, labels, random_seed: int | None = None): - if random_seed is not None: - classifier = RandomForestClassifier(n_jobs=self._n_jobs, random_state=random_seed) - else: - classifier = RandomForestClassifier(n_jobs=self._n_jobs) + classifier = self._make_random_forest(random_seed=random_seed) return classifier.fit(features.replace([np.inf, -np.inf], 0).fillna(0), labels) def _fit_gradient_boost(self, features, labels, random_seed: int | None = None): - if random_seed is not None: - classifier = GradientBoostingClassifier(random_state=random_seed) - else: - classifier = GradientBoostingClassifier() + classifier = self._make_gradient_boost(random_seed=random_seed) return classifier.fit(features.replace([np.inf, -np.inf], 0).fillna(0), labels) def _fit_xgboost(self, features, labels, random_seed: int | None = None): - if random_seed is not None: - classifier = _xgboost.XGBClassifier(n_jobs=self._n_jobs, random_state=random_seed) - else: - classifier = _xgboost.XGBClassifier(n_jobs=self._n_jobs) + classifier = self._make_xgboost(random_seed=random_seed) classifier.fit(features.replace([np.inf, -np.inf]), labels) return classifier + def _get_estimator_with_feature_importances(self): + """Return the underlying estimator that exposes `feature_importances_`, if available. + + Handles calibrated classifiers by retrieving the estimator from the first + calibrated fold. Returns None if no estimator with `feature_importances_` is found. + """ + est = self._classifier + # If wrapped by CalibratedClassifierCV, peel off the estimator + if isinstance(est, CalibratedClassifierCV): + try: + cc0 = est.calibrated_classifiers_[0] + est = cc0.estimator + except Exception: + return None + # Some sklearn/xgboost estimators expose feature_importances_ + return est if hasattr(est, "feature_importances_") else None + + def get_calibrated_feature_importances(self): + """Return averaged feature importances across calibrated folds. + + For CalibratedClassifierCV with tree-based base estimators (RF/GBT/XGBoost), + this computes the mean and std of `feature_importances_` across + `calibrated_classifiers_` estimators and returns a list of tuples: + [(feature_name, mean_importance, std_importance), ...] sorted by mean desc. + + Returns None if unavailable (e.g., non-tree base estimators). + """ + if not isinstance(self._classifier, CalibratedClassifierCV): + return None + try: + base_ests = [cc.estimator for cc in self._classifier.calibrated_classifiers_] + except Exception: + return None + + # get the base estimators that have feature_importances_ + base_ests = [be for be in base_ests if hasattr(be, "feature_importances_")] + if not base_ests: + return None + + # get the mean and standard deviation of feature importances from the base estimators + importances = np.vstack([be.feature_importances_ for be in base_ests]) + mean_imp = importances.mean(axis=0) + std_imp = importances.std(axis=0) + + # combine with feature names and sort by mean importance + items = list(zip(self._feature_names, mean_imp, std_imp, strict=True)) + items.sort(key=lambda t: t[1], reverse=True) + return items + def print_feature_importance(self, feature_list, limit=20): """print the most important features and their importance @@ -596,20 +871,35 @@ def print_feature_importance(self, feature_list, limit=20): feature_list: list of feature names used in the classifier limit: maximum number of features to print, defaults to 20 """ - # Get numerical feature importance - importances = list(self._classifier.feature_importances_) - # List of tuples with variable and importance + # Prefer calibrated importances if available + if isinstance(self._classifier, CalibratedClassifierCV): + items = self.get_calibrated_feature_importances() + if items is not None: + print(f"{'Feature Name':100} Mean Importance Std") + print("-" * 120) + for name, mean_imp, std_imp in items[:limit]: + print(f"{name:100} {mean_imp:0.4f} {std_imp:0.4f}") + return + # fall through to base-estimator single-source path if calibrated but no importances + + # Fallback: single estimator feature_importances_ + est = self._get_estimator_with_feature_importances() + if est is None: + print("Feature importances are unavailable for the current classifier.") + return + importances = list(est.feature_importances_) + names = feature_list if feature_list is not None else (self._feature_names or []) + if len(importances) != len(names): + names = [f"feature_{i}" for i in range(len(importances))] feature_importance = [ - (feature, round(importance, 2)) - for feature, importance in zip(feature_list, importances, strict=True) + (feature, round(importance, 4)) + for feature, importance in zip(names, importances, strict=False) ] - # Sort the feature importance by most important first feature_importance = sorted(feature_importance, key=lambda x: x[1], reverse=True) - # Print out the feature and importance print(f"{'Feature Name':100} Importance") print("-" * 120) for feature, importance in feature_importance[:limit]: - print(f"{feature:100} {importance:0.2f}") + print(f"{feature:100} {importance:0.4f}") @staticmethod def count_label_threshold(all_counts: dict): diff --git a/src/jabs/constants.py b/src/jabs/constants.py index a4a585d0..fa40016e 100644 --- a/src/jabs/constants.py +++ b/src/jabs/constants.py @@ -8,3 +8,6 @@ # some defaults for compressing hdf5 output COMPRESSION = "gzip" COMPRESSION_OPTS_DEFAULT = 6 + +DEFAULT_CALIBRATION_METHOD = "auto" # can be 'auto', 'isotonic', or 'sigmoid' +DEFAULT_CALIBRATION_CV = 3 diff --git a/src/jabs/feature_extraction/features.py b/src/jabs/feature_extraction/features.py index 5991a7f2..98f46da4 100644 --- a/src/jabs/feature_extraction/features.py +++ b/src/jabs/feature_extraction/features.py @@ -85,7 +85,7 @@ class IdentityFeatures: def __init__( self, - source_file: str, + source_file: str | Path, identity: int, directory: str | Path | None, pose_est: PoseEstimation, diff --git a/src/jabs/project/export_training.py b/src/jabs/project/export_training.py index 3565454d..1c4458fd 100644 --- a/src/jabs/project/export_training.py +++ b/src/jabs/project/export_training.py @@ -7,6 +7,7 @@ import jabs.feature_extraction import jabs.version +from jabs.constants import DEFAULT_CALIBRATION_CV, DEFAULT_CALIBRATION_METHOD from jabs.project.project_utils import to_safe_name from jabs.utils import FINAL_TRAIN_SEED @@ -64,6 +65,17 @@ def export_training_data( write_project_settings(out_h5, project.settings_manager.get_behavior(behavior), "settings") out_h5.attrs["classifier_type"] = classifier_type.value out_h5.attrs["training_seed"] = training_seed + out_h5.attrs["calibrate_probabilities"] = project.settings_manager.jabs_settings.get( + "calibrate_probabilities", False + ) + if out_h5.attrs["calibrate_probabilities"]: + out_h5.attrs["calibration_method"] = project.settings_manager.jabs_settings.get( + "calibration_method", DEFAULT_CALIBRATION_METHOD + ) + out_h5.attrs["calibration_cv"] = project.settings_manager.jabs_settings.get( + "calibration_cv", DEFAULT_CALIBRATION_CV + ) + feature_group = out_h5.create_group("features") for feature, data in features["per_frame"].items(): feature_group.create_dataset(f"per_frame/{feature}", data=data) diff --git a/src/jabs/project/prediction_manager.py b/src/jabs/project/prediction_manager.py index af3407d3..2dc1d776 100644 --- a/src/jabs/project/prediction_manager.py +++ b/src/jabs/project/prediction_manager.py @@ -126,9 +126,7 @@ def load_predictions(self, video: str, behavior: str): file_base = Path(video).with_suffix("").name + ".h5" path = self._project.project_paths.prediction_dir / file_base - nident = self._project.settings_manager.project_settings["video_files"][video][ - "identities" - ] + nident = self._project.settings_manager.project_info["video_files"][video]["identities"] try: with h5py.File(path, "r") as h5: diff --git a/src/jabs/project/project.py b/src/jabs/project/project.py index 03abd2ec..19eaa8af 100644 --- a/src/jabs/project/project.py +++ b/src/jabs/project/project.py @@ -55,7 +55,7 @@ class Project: """ def __init__( - self, project_path, use_cache=True, enable_video_check=True, enable_session_tracker=True + self, project_path, use_cache=True, enable_video_check=True, enable_session_tracker=False ): self._paths = ProjectPaths(Path(project_path), use_cache=use_cache) self._paths.create_directories() @@ -69,7 +69,7 @@ def __init__( self._session_tracker = SessionTracker(self, tracking_enabled=enable_session_tracker) # write out the defaults to the project file - if self._settings_manager.project_settings.get("defaults") != self.get_project_defaults(): + if self._settings_manager.project_info.get("defaults") != self.get_project_defaults(): self._settings_manager.save_project_file({"defaults": self.get_project_defaults()}) # Start a session tracker for this project. @@ -107,11 +107,6 @@ def classifier_dir(self): """get the classifier directory""" return self._paths.classifier_dir - @property - def settings(self): - """get the project metadata and preferences.""" - return self._settings_manager.project_settings - @property def settings_manager(self) -> SettingsManager: """get the project settings manager""" diff --git a/src/jabs/project/read_training.py b/src/jabs/project/read_training.py index 5e5ed492..0cdbf303 100644 --- a/src/jabs/project/read_training.py +++ b/src/jabs/project/read_training.py @@ -4,6 +4,7 @@ import h5py import pandas as pd +from jabs.constants import DEFAULT_CALIBRATION_CV, DEFAULT_CALIBRATION_METHOD from jabs.types import ClassifierType, ProjectDistanceUnit @@ -79,9 +80,10 @@ def load_training_data(training_file: Path): with h5py.File(training_file, "r") as in_h5: features["min_pose_version"] = in_h5.attrs["min_pose_version"] features["behavior"] = in_h5.attrs["behavior"] - features["settings"] = read_project_settings(in_h5["settings"]) + features["behavior_settings"] = read_project_settings(in_h5["settings"]) features["training_seed"] = in_h5.attrs["training_seed"] features["classifier_type"] = ClassifierType(in_h5.attrs["classifier_type"]) + # convert the string distance_unit attr to corresponding # ProjectDistanceUnit enum unit = in_h5.attrs.get("distance_unit") @@ -92,6 +94,21 @@ def load_training_data(training_file: Path): else: features["distance_unit"] = ProjectDistanceUnit[unit] + features["jabs_settings"] = {} + + # load other jabs settings that might or might not be present + calibrate_probabilities = in_h5.attrs.get("calibrate_probabilities", False) + if calibrate_probabilities: + features["jabs_settings"].update( + { + "calibrate_probabilities": calibrate_probabilities, + "calibration_method": in_h5.attrs.get( + "calibration_method", DEFAULT_CALIBRATION_METHOD + ), + "calibration_cv": in_h5.attrs.get("calibration_cv", DEFAULT_CALIBRATION_CV), + } + ) + features["labels"] = in_h5["label"][:] features["groups"] = in_h5["group"][:] diff --git a/src/jabs/project/settings_manager.py b/src/jabs/project/settings_manager.py index d6aaeaf6..0924d3f2 100644 --- a/src/jabs/project/settings_manager.py +++ b/src/jabs/project/settings_manager.py @@ -57,7 +57,7 @@ def save_project_file(self, data: dict | None = None): tmp.replace(self._paths.project_file) @property - def project_settings(self) -> dict: + def project_info(self) -> dict: """Get a copy of the current project properties and settings. Returns: @@ -65,6 +65,15 @@ def project_settings(self) -> dict: """ return dict(self._project_info) + @property + def jabs_settings(self) -> dict: + """Get a copy of general JABS settings from project file + + Returns: + dict + """ + return dict(self._project_info.get("settings", {})) + @property def behavior_names(self) -> list[str]: """Get a list of all behaviors defined in the project settings. diff --git a/src/jabs/project/video_manager.py b/src/jabs/project/video_manager.py index e02ef73d..89f00cd2 100644 --- a/src/jabs/project/video_manager.py +++ b/src/jabs/project/video_manager.py @@ -142,7 +142,7 @@ def get_video_identity_count(self, video_name: str) -> int: def _load_video_metadata(self): """Load metadata for each video and calculate total identities.""" - video_metadata = self._settings_manager.project_settings.get("video_files", {}) + video_metadata = self._settings_manager.project_info.get("video_files", {}) flush = False for video in self._videos: vinfo = video_metadata.get(video, {}) diff --git a/src/jabs/scripts/classify.py b/src/jabs/scripts/classify.py index cfd38c1e..d8e4a94c 100755 --- a/src/jabs/scripts/classify.py +++ b/src/jabs/scripts/classify.py @@ -105,7 +105,7 @@ def classify_pose( prediction_labels = np.full((pose_est.num_identities, pose_est.num_frames), -1, dtype=np.int8) prediction_prob = np.zeros_like(prediction_labels, dtype=np.float32) - classifier_settings = classifier.project_settings + classifier_settings = classifier.behavior_settings print(f"Classifying {input_pose_file}...") @@ -137,13 +137,13 @@ def classify_pose( data = Classifier.combine_data(per_frame_features, window_features) if data.shape[0] > 0: - pred = classifier.predict(data) pred_prob = classifier.predict_proba(data) + positive_proba = pred_prob[:, 1] - # Keep the probability for the predicted class only. - # The following code uses some - # numpy magic to use the pred array as column indexes - # for each row of the pred_prob array we just computed. + # Derive predicted labels by thresholding at 0.5 + pred = (positive_proba >= classifier.TRUE_THRESHOLD).astype(int) + + # Keep the probability of the predicted class pred_prob = pred_prob[np.arange(len(pred_prob)), pred] # Only copy out predictions where there was a valid pose @@ -188,7 +188,7 @@ def train(training_file: Path) -> Classifier: Classifier: The trained classifier instance. """ classifier = Classifier.from_training_file(training_file) - classifier_settings = classifier.project_settings + classifier_settings = classifier.behavior_settings print("Training classifier for:", classifier.behavior_name) print(f" Classifier Type: {__CLASSIFIER_CHOICES[classifier.classifier_type]}") @@ -197,6 +197,7 @@ def train(training_file: Path) -> Classifier: print(f" Balanced Labels: {classifier_settings['balance_labels']}") print(f" Symmetric Behavior: {classifier_settings['symmetric_behavior']}") print(f" CM Units: {bool(classifier_settings['cm_units'])}") + print(f" Calibrate Probabilities: {classifier.calibrate_probabilities}") return classifier @@ -315,7 +316,7 @@ def classify_main(): sys.exit(str(e)) behavior = classifier.behavior_name - classifier_settings = classifier.project_settings + classifier_settings = classifier.behavior_settings print(f"Classifying using trained classifier: {args.classifier}") try: diff --git a/src/jabs/scripts/cli.py b/src/jabs/scripts/cli.py index b9fabb34..f053aa09 100644 --- a/src/jabs/scripts/cli.py +++ b/src/jabs/scripts/cli.py @@ -83,7 +83,7 @@ def export_training(ctx, directory: Path, behavior: str, classifier: str, outfil jabs_project = Project(directory, enable_session_tracker=False) # validate that the behavior exists in the project - if behavior not in jabs_project.settings["behavior"]: + if behavior not in jabs_project.settings_manager.project_info["behavior"]: raise click.ClickException(f"Behavior '{behavior}' not found in project.") console = Console() @@ -143,11 +143,11 @@ def rename_behavior(ctx, directory: Path, old_name: str, new_name: str) -> None: jabs_project = Project(directory, enable_session_tracker=False) # validate that the old behavior exists in the project - if old_name not in jabs_project.settings["behavior"]: + if old_name not in jabs_project.settings_manager.project_info["behavior"]: raise click.ClickException(f"Behavior '{old_name}' not found in project.") # validate that the new behavior does not already exist in the project - if new_name in jabs_project.settings["behavior"]: + if new_name in jabs_project.settings_manager.project_info["behavior"]: raise click.ClickException(f"Behavior '{new_name}' already exists in project.") console = Console() diff --git a/src/jabs/scripts/initialize_project.py b/src/jabs/scripts/initialize_project.py index d67cd9b5..3b5aca34 100755 --- a/src/jabs/scripts/initialize_project.py +++ b/src/jabs/scripts/initialize_project.py @@ -295,7 +295,7 @@ def validation_job_producer(): # save window sizes to project settings deduped_window_sizes = set( - project.settings_manager.project_settings.get("window_sizes", []) + window_sizes + project.settings_manager.project_info.get("window_sizes", []) + window_sizes ) project.settings_manager.save_project_file({"window_sizes": list(deduped_window_sizes)}) diff --git a/src/jabs/scripts/stats.py b/src/jabs/scripts/stats.py index 2eda38ba..6d522e99 100644 --- a/src/jabs/scripts/stats.py +++ b/src/jabs/scripts/stats.py @@ -124,7 +124,7 @@ def main(): print(f"\nClassifier: {classifier.classifier_name}") print(f"Behavior: {features['behavior']}") unit = ( - "cm" if classifier.project_settings["cm_units"] == ProjectDistanceUnit.CM else "pixel" + "cm" if classifier.behavior_settings["cm_units"] == ProjectDistanceUnit.CM else "pixel" ) print(f"Feature Distance Unit: {unit}") print("-" * 70) diff --git a/src/jabs/ui/behavior_search_dialog.py b/src/jabs/ui/behavior_search_dialog.py index d6f5bd35..4cbf4305 100644 --- a/src/jabs/ui/behavior_search_dialog.py +++ b/src/jabs/ui/behavior_search_dialog.py @@ -58,8 +58,8 @@ def __init__(self, project: Project, parent: QtWidgets.QWidget | None = None): self.setModal(True) self.resize(500, 320) - proj_settings = project.settings - self._behavior_labels = sorted(proj_settings.get("behavior", {}).keys()) + proj_info = project.settings_manager.project_info + self._behavior_labels = sorted(proj_info.get("behavior", {}).keys()) # === Main Layout === main_layout = QtWidgets.QVBoxLayout(self) diff --git a/src/jabs/ui/central_widget.py b/src/jabs/ui/central_widget.py index f1f0fc64..8493b00b 100644 --- a/src/jabs/ui/central_widget.py +++ b/src/jabs/ui/central_widget.py @@ -277,7 +277,7 @@ def set_project(self, project: Project) -> None: self._labels = None self._loaded_video = None - self._controls.update_project_settings(project.settings) + self._controls.update_project_settings(project.settings_manager.project_info) self._search_bar_widget.update_project(project) self._update_timeline_search_results() @@ -655,7 +655,7 @@ def _update_classifier_controls(self) -> None: self._controls.set_classifier_selection(self._classifier.classifier_type) # does the classifier match the current settings? - classifier_settings = self._classifier.project_settings + classifier_settings = self._classifier.behavior_settings if ( classifier_settings is not None and classifier_settings.get("window_size", None) == self.window_size @@ -675,7 +675,7 @@ def _train_button_clicked(self) -> None: # make sure video playback is stopped self._player_widget.stop() - # setup training thread + # setup training thread, training thread will configure self._classifier with current settings self._training_thread = TrainingThread( self._classifier, self._project, diff --git a/src/jabs/ui/classification_thread.py b/src/jabs/ui/classification_thread.py index 972b067d..4fc8c2c3 100644 --- a/src/jabs/ui/classification_thread.py +++ b/src/jabs/ui/classification_thread.py @@ -130,23 +130,18 @@ def check_termination_requested() -> None: check_termination_requested() if data.shape[0] > 0: - # make predictions - # Note: this makes predictions for all frames in the video, even those without valid pose - # We will later filter these out when saving the predictions to disk - # consider changing this to only predict on frames with valid pose - predictions[video][identity] = self._classifier.predict(data) - - # also get the probabilities + # get predicted probabilities for the positive class (class 1) prob = self._classifier.predict_proba(data) - # Save the probability for the predicted class only. - # The following code uses some - # numpy magic to use the _predictions array as column indexes - # for each row of the 'prob' array we just computed. - probabilities[video][identity] = prob[ - np.arange(len(prob)), predictions[video][identity] - ] - - # save the indexes for the predicted frames + positive_proba = prob[:, 1] + + # derive binary predictions by thresholding probabilities + preds = (positive_proba >= self._classifier.TRUE_THRESHOLD).astype(int) + predictions[video][identity] = preds + + # save probability of the predicted class for each frame + probabilities[video][identity] = prob[np.arange(len(prob)), preds] + + # store the frame indexes corresponding to each prediction frame_indexes[video][identity] = feature_values["frame_indexes"] else: predictions[video][identity] = np.array(0) diff --git a/src/jabs/ui/main_window.py b/src/jabs/ui/main_window.py index e4863f8d..3277fad9 100644 --- a/src/jabs/ui/main_window.py +++ b/src/jabs/ui/main_window.py @@ -20,6 +20,7 @@ from .progress_dialog import create_progress_dialog from .project_loader_thread import ProjectLoaderThread from .project_pruning_dialog import ProjectPruningDialog +from .settings_dialog import JabsSettingsDialog from .stacked_timeline_widget import StackedTimelineWidget from .user_guide_dialog import UserGuideDialog from .util import send_file_to_recycle_bin @@ -119,6 +120,13 @@ def __init__(self, app_name: str, app_name_long: str, *args, **kwargs) -> None: self._clear_cache.triggered.connect(self._clear_cache_action) app_menu.addAction(self._clear_cache) + # model calibration settings + self._settings_action = QtGui.QAction("JABS Settings", self) + self._settings_action.setStatusTip("Open settings dialog") + self._settings_action.setEnabled(False) + self._settings_action.triggered.connect(self._open_settings_dialog) + app_menu.addAction(self._settings_action) + # exit action exit_action = QtGui.QAction(f" &Quit {self._app_name}", self) exit_action.setShortcut(QtGui.QKeySequence("Ctrl+Q")) @@ -524,7 +532,7 @@ def behavior_label_add_event(self, behaviors: list[str]) -> None: """handle project updates required when user adds new behavior labels""" # check for new behaviors for behavior in behaviors: - if behavior not in self._project.settings_manager.project_settings["behavior"]: + if behavior not in self._project.settings_manager.project_info["behavior"]: # save new behavior with default settings self._project.settings_manager.save_behavior(behavior, {}) @@ -725,6 +733,7 @@ def _project_loaded_callback(self) -> None: self._project.feature_manager.can_use_segmentation_features ) self._clear_cache.setEnabled(True) + self._settings_action.setEnabled(True) available_objects = self._project.feature_manager.static_objects for static_object, menu_item in self.enable_landmark_features.items(): if static_object in available_objects: @@ -990,3 +999,8 @@ def _view_license(self) -> None: """View the license agreement (JABS->View License Agreement menu action)""" dialog = LicenseAgreementDialog(self, view_only=True) dialog.exec_() + + def _open_settings_dialog(self) -> None: + """Open the settings dialog (JABS->Settings menu action)""" + dialog = JabsSettingsDialog(parent=self, settings_manager=self._project.settings_manager) + dialog.exec_() diff --git a/src/jabs/ui/settings_dialog.py b/src/jabs/ui/settings_dialog.py new file mode 100644 index 00000000..0038249d --- /dev/null +++ b/src/jabs/ui/settings_dialog.py @@ -0,0 +1,340 @@ +from PySide6.QtCore import Qt, QTimer, Signal +from PySide6.QtGui import QResizeEvent, QShowEvent +from PySide6.QtWidgets import ( + QAbstractScrollArea, + QCheckBox, + QComboBox, + QDialog, + QDialogButtonBox, + QFrame, + QGridLayout, + QGroupBox, + QLabel, + QLayout, + QScrollArea, + QSizePolicy, + QSpacerItem, + QSpinBox, + QToolButton, + QVBoxLayout, + QWidget, +) + +from jabs.classifier import Classifier +from jabs.constants import DEFAULT_CALIBRATION_CV, DEFAULT_CALIBRATION_METHOD +from jabs.project.settings_manager import SettingsManager + + +class CollapsibleSection(QWidget): + """A simple collapsible section with a header ToolButton and a content area.""" + + sizeChanged = Signal() + + def __init__(self, title: str, content: QWidget, parent: QWidget | None = None) -> None: + super().__init__(parent) + self._content = content + self._toggle_btn = QToolButton(self) + self._toggle_btn.setStyleSheet("QToolButton { border: none; }") + self._toggle_btn.setToolButtonStyle(Qt.ToolButtonStyle.ToolButtonTextBesideIcon) + self._toggle_btn.setArrowType(Qt.ArrowType.RightArrow) + self._toggle_btn.setText(title) + self._toggle_btn.setCheckable(True) + self._toggle_btn.setChecked(False) + self._toggle_btn.toggled.connect(self._on_toggled) + + line = QFrame(self) + line.setFrameShape(QFrame.Shape.HLine) + line.setFrameShadow(QFrame.Shadow.Sunken) + + self._content.setVisible(False) + # Ensure the collapsible widget and its content expand to fit content + self.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred) + self._content.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred) + + lay = QVBoxLayout(self) + lay.setContentsMargins(0, 0, 0, 0) + lay.addWidget(self._toggle_btn) + lay.addWidget(line) + lay.addWidget(self._content) + + def _on_toggled(self, checked: bool) -> None: + self._toggle_btn.setArrowType( + Qt.ArrowType.DownArrow if checked else Qt.ArrowType.RightArrow + ) + self._content.setVisible(checked) + self._content.updateGeometry() + + # Ask ancestors to recompute layout so the page grows inside the scroll area + parent = self.parentWidget() + if parent is not None and parent.layout() is not None: + parent.layout().activate() + + if self.layout() is not None: + self.layout().activate() + + # Let ancestors recompute size hints and notify listeners + if parent is not None: + parent.updateGeometry() + self.updateGeometry() + self.sizeChanged.emit() + + +class JabsSettingsDialog(QDialog): + """ + Dialog for changing project settings. + + Args: + settings_manager (SettingsManager): Project settings manager used to load and save settings. + parent (QWidget | None, optional): Parent widget for this dialog. Defaults to None. + """ + + def __init__(self, settings_manager: SettingsManager, parent: QWidget | None = None) -> None: + super().__init__(parent) + self.setWindowTitle("Project Settings") + self._settings_manager = settings_manager + + # Allow resizing and show scrollbars if content overflows + self.setSizeGripEnabled(True) + + # Widgets + self._calibrate_checkbox = QCheckBox( + "Enable probability calibration (calibrate_probabilities)" + ) + self._method_selection = QComboBox() + self._method_selection.addItems(Classifier.CALIBRATION_METHODS) + self._cv_selection = QSpinBox() + self._cv_selection.setRange(2, 10) + self._cv_selection.setAccelerated(True) + self._cv_selection.setToolTip("Number of CV folds used inside the calibrator") + self._save_reliability_checkbox = QCheckBox("Save reliability plots") + self._save_reliability_checkbox.setToolTip( + "If enabled, save reliability (calibration) plots after training/validation." + ) + + # Load current values from project settings + current_settings = settings_manager.jabs_settings + calibrate = current_settings.get("calibrate_probabilities", False) + method = current_settings.get("calibration_method", DEFAULT_CALIBRATION_METHOD) + cv = current_settings.get("calibration_cv", DEFAULT_CALIBRATION_CV) + save_reliability = current_settings.get("save_reliability_plots", False) + + self._calibrate_checkbox.setChecked(calibrate) + self._method_selection.setCurrentIndex(max(0, self._method_selection.findText(method))) + self._cv_selection.setValue(cv) + self._save_reliability_checkbox.setChecked(save_reliability) + + # Layout for form + form = QWidget(self) + grid = QGridLayout(form) + grid.setContentsMargins(0, 0, 0, 0) + grid.setHorizontalSpacing(12) + grid.setVerticalSpacing(8) + grid.setColumnStretch(0, 0) # labels column: natural size + grid.setColumnStretch(1, 0) # inputs column: natural size + grid.setColumnStretch( + 2, 1 + ) # consume extra width on the right (keeps content left-aligned) + + # Keep inputs compact; whitespace grows in column 2 + self._method_selection.setSizeAdjustPolicy(QComboBox.SizeAdjustPolicy.AdjustToContents) + self._method_selection.setFixedWidth(self._method_selection.sizeHint().width() + 24) + self._method_selection.setSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed) + + self._cv_selection.setFixedWidth(90) + self._cv_selection.setSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed) + + self._calibrate_checkbox.setSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed) + self._save_reliability_checkbox.setSizePolicy( + QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed + ) + + grid.addWidget(QLabel("Calibrate probabilities:"), 0, 0, Qt.AlignmentFlag.AlignRight) + grid.addWidget(self._calibrate_checkbox, 0, 1) + + grid.addWidget(QLabel("Calibration method:"), 1, 0, Qt.AlignmentFlag.AlignRight) + grid.addWidget(self._method_selection, 1, 1) + + grid.addWidget(QLabel("calibration cv (folds):"), 2, 0, Qt.AlignmentFlag.AlignRight) + grid.addWidget(self._cv_selection, 2, 1) + + grid.addWidget(QLabel("Save reliability plots:"), 3, 0, Qt.AlignmentFlag.AlignRight) + grid.addWidget(self._save_reliability_checkbox, 3, 1) + grid.addItem( + QSpacerItem(0, 0, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Minimum), 0, 2, 4, 1 + ) + + # Help / inline docs (rich text) + help_label = QLabel(self) + help_label.setTextFormat(Qt.TextFormat.RichText) + help_label.setWordWrap(True) + help_label.setText( + """ +
Calibrate probabilities remaps raw model scores to better probabilities using + cross-validation inside training. This improves log-loss, Brier score, and makes thresholding + (e.g., show if p ≥ 0.7) more reliable.
+ +calibration_cv setting) increase the data required for selecting
+ isotonic.Guidance: If your dataset is large (thousands of labeled frames and roughly balanced),
+ auto will select isotonic. If it selects sigmoid, you can collect more labels or reduce
+ calibration_cv to allow isotonic to activate.
Tip: Most users should leave calibration_method = auto.
Saving reliability plots: If Save reliability plots is enabled, JABS will write reliability
+ figures after training/validation to <project dir>/plots/<timestamp>/.
+ Each run creates a new timestamped folder so results are easy to compare.