Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/jabs/behavior_search/behavior_search_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def _search_behaviors_gen(
)

case PredictionBehaviorSearchQuery() as pred_query:
proj_settings = project.settings_manager.project_settings
proj_settings = project.settings_manager.project_info
if pred_query.behavior_label is None:
behavior_dict = proj_settings.get("behavior", {})
behaviors = list(behavior_dict.keys())
Expand Down
392 changes: 341 additions & 51 deletions src/jabs/classifier/classifier.py

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions src/jabs/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,6 @@
# some defaults for compressing hdf5 output
COMPRESSION = "gzip"
COMPRESSION_OPTS_DEFAULT = 6

DEFAULT_CALIBRATION_METHOD = "auto" # can be 'auto', 'isotonic', or 'sigmoid'
DEFAULT_CALIBRATION_CV = 3
2 changes: 1 addition & 1 deletion src/jabs/feature_extraction/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class IdentityFeatures:

def __init__(
self,
source_file: str,
source_file: str | Path,
identity: int,
directory: str | Path | None,
pose_est: PoseEstimation,
Expand Down
12 changes: 12 additions & 0 deletions src/jabs/project/export_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import jabs.feature_extraction
import jabs.version
from jabs.constants import DEFAULT_CALIBRATION_CV, DEFAULT_CALIBRATION_METHOD
from jabs.project.project_utils import to_safe_name
from jabs.utils import FINAL_TRAIN_SEED

Expand Down Expand Up @@ -64,6 +65,17 @@ def export_training_data(
write_project_settings(out_h5, project.settings_manager.get_behavior(behavior), "settings")
out_h5.attrs["classifier_type"] = classifier_type.value
out_h5.attrs["training_seed"] = training_seed
out_h5.attrs["calibrate_probabilities"] = project.settings_manager.jabs_settings.get(
"calibrate_probabilities", False
)
if out_h5.attrs["calibrate_probabilities"]:
out_h5.attrs["calibration_method"] = project.settings_manager.jabs_settings.get(
"calibration_method", DEFAULT_CALIBRATION_METHOD
)
out_h5.attrs["calibration_cv"] = project.settings_manager.jabs_settings.get(
"calibration_cv", DEFAULT_CALIBRATION_CV
)

feature_group = out_h5.create_group("features")
for feature, data in features["per_frame"].items():
feature_group.create_dataset(f"per_frame/{feature}", data=data)
Expand Down
4 changes: 1 addition & 3 deletions src/jabs/project/prediction_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,7 @@ def load_predictions(self, video: str, behavior: str):
file_base = Path(video).with_suffix("").name + ".h5"
path = self._project.project_paths.prediction_dir / file_base

nident = self._project.settings_manager.project_settings["video_files"][video][
"identities"
]
nident = self._project.settings_manager.project_info["video_files"][video]["identities"]

try:
with h5py.File(path, "r") as h5:
Expand Down
9 changes: 2 additions & 7 deletions src/jabs/project/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class Project:
"""

def __init__(
self, project_path, use_cache=True, enable_video_check=True, enable_session_tracker=True
self, project_path, use_cache=True, enable_video_check=True, enable_session_tracker=False
):
self._paths = ProjectPaths(Path(project_path), use_cache=use_cache)
self._paths.create_directories()
Expand All @@ -69,7 +69,7 @@ def __init__(
self._session_tracker = SessionTracker(self, tracking_enabled=enable_session_tracker)

# write out the defaults to the project file
if self._settings_manager.project_settings.get("defaults") != self.get_project_defaults():
if self._settings_manager.project_info.get("defaults") != self.get_project_defaults():
self._settings_manager.save_project_file({"defaults": self.get_project_defaults()})

# Start a session tracker for this project.
Expand Down Expand Up @@ -107,11 +107,6 @@ def classifier_dir(self):
"""get the classifier directory"""
return self._paths.classifier_dir

@property
def settings(self):
"""get the project metadata and preferences."""
return self._settings_manager.project_settings

@property
def settings_manager(self) -> SettingsManager:
"""get the project settings manager"""
Expand Down
19 changes: 18 additions & 1 deletion src/jabs/project/read_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import h5py
import pandas as pd

from jabs.constants import DEFAULT_CALIBRATION_CV, DEFAULT_CALIBRATION_METHOD
from jabs.types import ClassifierType, ProjectDistanceUnit


Expand Down Expand Up @@ -79,9 +80,10 @@ def load_training_data(training_file: Path):
with h5py.File(training_file, "r") as in_h5:
features["min_pose_version"] = in_h5.attrs["min_pose_version"]
features["behavior"] = in_h5.attrs["behavior"]
features["settings"] = read_project_settings(in_h5["settings"])
features["behavior_settings"] = read_project_settings(in_h5["settings"])
features["training_seed"] = in_h5.attrs["training_seed"]
features["classifier_type"] = ClassifierType(in_h5.attrs["classifier_type"])

# convert the string distance_unit attr to corresponding
# ProjectDistanceUnit enum
unit = in_h5.attrs.get("distance_unit")
Expand All @@ -92,6 +94,21 @@ def load_training_data(training_file: Path):
else:
features["distance_unit"] = ProjectDistanceUnit[unit]

features["jabs_settings"] = {}

# load other jabs settings that might or might not be present
calibrate_probabilities = in_h5.attrs.get("calibrate_probabilities", False)
if calibrate_probabilities:
features["jabs_settings"].update(
{
"calibrate_probabilities": calibrate_probabilities,
"calibration_method": in_h5.attrs.get(
"calibration_method", DEFAULT_CALIBRATION_METHOD
),
"calibration_cv": in_h5.attrs.get("calibration_cv", DEFAULT_CALIBRATION_CV),
}
)

features["labels"] = in_h5["label"][:]
features["groups"] = in_h5["group"][:]

Expand Down
11 changes: 10 additions & 1 deletion src/jabs/project/settings_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,23 @@ def save_project_file(self, data: dict | None = None):
tmp.replace(self._paths.project_file)

@property
def project_settings(self) -> dict:
def project_info(self) -> dict:
"""Get a copy of the current project properties and settings.

Returns:
dict
"""
return dict(self._project_info)

@property
def jabs_settings(self) -> dict:
"""Get a copy of general JABS settings from project file

Returns:
dict
"""
return dict(self._project_info.get("settings", {}))

@property
def behavior_names(self) -> list[str]:
"""Get a list of all behaviors defined in the project settings.
Expand Down
2 changes: 1 addition & 1 deletion src/jabs/project/video_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def get_video_identity_count(self, video_name: str) -> int:

def _load_video_metadata(self):
"""Load metadata for each video and calculate total identities."""
video_metadata = self._settings_manager.project_settings.get("video_files", {})
video_metadata = self._settings_manager.project_info.get("video_files", {})
flush = False
for video in self._videos:
vinfo = video_metadata.get(video, {})
Expand Down
17 changes: 9 additions & 8 deletions src/jabs/scripts/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def classify_pose(
prediction_labels = np.full((pose_est.num_identities, pose_est.num_frames), -1, dtype=np.int8)
prediction_prob = np.zeros_like(prediction_labels, dtype=np.float32)

classifier_settings = classifier.project_settings
classifier_settings = classifier.behavior_settings

print(f"Classifying {input_pose_file}...")

Expand Down Expand Up @@ -137,13 +137,13 @@ def classify_pose(
data = Classifier.combine_data(per_frame_features, window_features)

if data.shape[0] > 0:
pred = classifier.predict(data)
pred_prob = classifier.predict_proba(data)
positive_proba = pred_prob[:, 1]

# Keep the probability for the predicted class only.
# The following code uses some
# numpy magic to use the pred array as column indexes
# for each row of the pred_prob array we just computed.
# Derive predicted labels by thresholding at 0.5
pred = (positive_proba >= classifier.TRUE_THRESHOLD).astype(int)

# Keep the probability of the predicted class
pred_prob = pred_prob[np.arange(len(pred_prob)), pred]

# Only copy out predictions where there was a valid pose
Expand Down Expand Up @@ -188,7 +188,7 @@ def train(training_file: Path) -> Classifier:
Classifier: The trained classifier instance.
"""
classifier = Classifier.from_training_file(training_file)
classifier_settings = classifier.project_settings
classifier_settings = classifier.behavior_settings

print("Training classifier for:", classifier.behavior_name)
print(f" Classifier Type: {__CLASSIFIER_CHOICES[classifier.classifier_type]}")
Expand All @@ -197,6 +197,7 @@ def train(training_file: Path) -> Classifier:
print(f" Balanced Labels: {classifier_settings['balance_labels']}")
print(f" Symmetric Behavior: {classifier_settings['symmetric_behavior']}")
print(f" CM Units: {bool(classifier_settings['cm_units'])}")
print(f" Calibrate Probabilities: {classifier.calibrate_probabilities}")

return classifier

Expand Down Expand Up @@ -315,7 +316,7 @@ def classify_main():
sys.exit(str(e))

behavior = classifier.behavior_name
classifier_settings = classifier.project_settings
classifier_settings = classifier.behavior_settings

print(f"Classifying using trained classifier: {args.classifier}")
try:
Expand Down
6 changes: 3 additions & 3 deletions src/jabs/scripts/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def export_training(ctx, directory: Path, behavior: str, classifier: str, outfil
jabs_project = Project(directory, enable_session_tracker=False)

# validate that the behavior exists in the project
if behavior not in jabs_project.settings["behavior"]:
if behavior not in jabs_project.settings_manager.project_info["behavior"]:
raise click.ClickException(f"Behavior '{behavior}' not found in project.")

console = Console()
Expand Down Expand Up @@ -143,11 +143,11 @@ def rename_behavior(ctx, directory: Path, old_name: str, new_name: str) -> None:
jabs_project = Project(directory, enable_session_tracker=False)

# validate that the old behavior exists in the project
if old_name not in jabs_project.settings["behavior"]:
if old_name not in jabs_project.settings_manager.project_info["behavior"]:
raise click.ClickException(f"Behavior '{old_name}' not found in project.")

# validate that the new behavior does not already exist in the project
if new_name in jabs_project.settings["behavior"]:
if new_name in jabs_project.settings_manager.project_info["behavior"]:
raise click.ClickException(f"Behavior '{new_name}' already exists in project.")

console = Console()
Expand Down
2 changes: 1 addition & 1 deletion src/jabs/scripts/initialize_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def validation_job_producer():

# save window sizes to project settings
deduped_window_sizes = set(
project.settings_manager.project_settings.get("window_sizes", []) + window_sizes
project.settings_manager.project_info.get("window_sizes", []) + window_sizes
)
project.settings_manager.save_project_file({"window_sizes": list(deduped_window_sizes)})

Expand Down
2 changes: 1 addition & 1 deletion src/jabs/scripts/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def main():
print(f"\nClassifier: {classifier.classifier_name}")
print(f"Behavior: {features['behavior']}")
unit = (
"cm" if classifier.project_settings["cm_units"] == ProjectDistanceUnit.CM else "pixel"
"cm" if classifier.behavior_settings["cm_units"] == ProjectDistanceUnit.CM else "pixel"
)
print(f"Feature Distance Unit: {unit}")
print("-" * 70)
Expand Down
4 changes: 2 additions & 2 deletions src/jabs/ui/behavior_search_dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def __init__(self, project: Project, parent: QtWidgets.QWidget | None = None):
self.setModal(True)
self.resize(500, 320)

proj_settings = project.settings
self._behavior_labels = sorted(proj_settings.get("behavior", {}).keys())
proj_info = project.settings_manager.project_info
self._behavior_labels = sorted(proj_info.get("behavior", {}).keys())

# === Main Layout ===
main_layout = QtWidgets.QVBoxLayout(self)
Expand Down
6 changes: 3 additions & 3 deletions src/jabs/ui/central_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def set_project(self, project: Project) -> None:
self._labels = None
self._loaded_video = None

self._controls.update_project_settings(project.settings)
self._controls.update_project_settings(project.settings_manager.project_info)
self._search_bar_widget.update_project(project)
self._update_timeline_search_results()

Expand Down Expand Up @@ -655,7 +655,7 @@ def _update_classifier_controls(self) -> None:
self._controls.set_classifier_selection(self._classifier.classifier_type)

# does the classifier match the current settings?
classifier_settings = self._classifier.project_settings
classifier_settings = self._classifier.behavior_settings
if (
classifier_settings is not None
and classifier_settings.get("window_size", None) == self.window_size
Expand All @@ -675,7 +675,7 @@ def _train_button_clicked(self) -> None:
# make sure video playback is stopped
self._player_widget.stop()

# setup training thread
# setup training thread, training thread will configure self._classifier with current settings
self._training_thread = TrainingThread(
self._classifier,
self._project,
Expand Down
27 changes: 11 additions & 16 deletions src/jabs/ui/classification_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,23 +130,18 @@ def check_termination_requested() -> None:

check_termination_requested()
if data.shape[0] > 0:
# make predictions
# Note: this makes predictions for all frames in the video, even those without valid pose
# We will later filter these out when saving the predictions to disk
# consider changing this to only predict on frames with valid pose
predictions[video][identity] = self._classifier.predict(data)

# also get the probabilities
# get predicted probabilities for the positive class (class 1)
prob = self._classifier.predict_proba(data)
# Save the probability for the predicted class only.
# The following code uses some
# numpy magic to use the _predictions array as column indexes
# for each row of the 'prob' array we just computed.
probabilities[video][identity] = prob[
np.arange(len(prob)), predictions[video][identity]
]

# save the indexes for the predicted frames
positive_proba = prob[:, 1]

# derive binary predictions by thresholding probabilities
preds = (positive_proba >= self._classifier.TRUE_THRESHOLD).astype(int)
predictions[video][identity] = preds

# save probability of the predicted class for each frame
probabilities[video][identity] = prob[np.arange(len(prob)), preds]

# store the frame indexes corresponding to each prediction
frame_indexes[video][identity] = feature_values["frame_indexes"]
else:
predictions[video][identity] = np.array(0)
Expand Down
16 changes: 15 additions & 1 deletion src/jabs/ui/main_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .progress_dialog import create_progress_dialog
from .project_loader_thread import ProjectLoaderThread
from .project_pruning_dialog import ProjectPruningDialog
from .settings_dialog import JabsSettingsDialog
from .stacked_timeline_widget import StackedTimelineWidget
from .user_guide_dialog import UserGuideDialog
from .util import send_file_to_recycle_bin
Expand Down Expand Up @@ -119,6 +120,13 @@ def __init__(self, app_name: str, app_name_long: str, *args, **kwargs) -> None:
self._clear_cache.triggered.connect(self._clear_cache_action)
app_menu.addAction(self._clear_cache)

# model calibration settings
self._settings_action = QtGui.QAction("JABS Settings", self)
self._settings_action.setStatusTip("Open settings dialog")
self._settings_action.setEnabled(False)
self._settings_action.triggered.connect(self._open_settings_dialog)
app_menu.addAction(self._settings_action)

# exit action
exit_action = QtGui.QAction(f" &Quit {self._app_name}", self)
exit_action.setShortcut(QtGui.QKeySequence("Ctrl+Q"))
Expand Down Expand Up @@ -524,7 +532,7 @@ def behavior_label_add_event(self, behaviors: list[str]) -> None:
"""handle project updates required when user adds new behavior labels"""
# check for new behaviors
for behavior in behaviors:
if behavior not in self._project.settings_manager.project_settings["behavior"]:
if behavior not in self._project.settings_manager.project_info["behavior"]:
# save new behavior with default settings
self._project.settings_manager.save_behavior(behavior, {})

Expand Down Expand Up @@ -725,6 +733,7 @@ def _project_loaded_callback(self) -> None:
self._project.feature_manager.can_use_segmentation_features
)
self._clear_cache.setEnabled(True)
self._settings_action.setEnabled(True)
available_objects = self._project.feature_manager.static_objects
for static_object, menu_item in self.enable_landmark_features.items():
if static_object in available_objects:
Expand Down Expand Up @@ -990,3 +999,8 @@ def _view_license(self) -> None:
"""View the license agreement (JABS->View License Agreement menu action)"""
dialog = LicenseAgreementDialog(self, view_only=True)
dialog.exec_()

def _open_settings_dialog(self) -> None:
"""Open the settings dialog (JABS->Settings menu action)"""
dialog = JabsSettingsDialog(parent=self, settings_manager=self._project.settings_manager)
dialog.exec_()
Loading