diff --git a/model_tools/brain_transformation/behavior.py b/model_tools/brain_transformation/behavior.py index bc0869e..ef6d802 100644 --- a/model_tools/brain_transformation/behavior.py +++ b/model_tools/brain_transformation/behavior.py @@ -57,8 +57,11 @@ def __init__(self, identifier, activations_model, layer): def start_task(self, task: BrainModel.Task, fitting_stimuli): assert task in [BrainModel.Task.passive, BrainModel.Task.probabilities] self.current_task = task - - fitting_features = self.activations_model(fitting_stimuli, layers=[self.layer]) + if type(self.layer) is not list: + layers = [self.layer] + else: + layers = self.layer + fitting_features = self.activations_model(fitting_stimuli, layers=layers) fitting_features = fitting_features.transpose('presentation', 'neuroid') assert all(fitting_features['image_id'].values == fitting_stimuli['image_id'].values), \ "image_id ordering is incorrect" @@ -67,7 +70,11 @@ def start_task(self, task: BrainModel.Task, fitting_stimuli): def look_at(self, stimuli): if self.current_task is BrainModel.Task.passive: return - features = self.activations_model(stimuli, layers=[self.layer]) + if type(self.layer) is not list: + layers = [self.layer] + else: + layers = self.layer + features = self.activations_model(stimuli, layers=layers) features = features.transpose('presentation', 'neuroid') prediction = self.classifier.predict_proba(features) return prediction