Skip to content

Commit 9f19281

Browse files
committed
made classifier.py return probability tensor as well
1 parent c8467b0 commit 9f19281

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

pyreason/scripts/learning/classification/classifier.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,5 +85,5 @@ def forward(self, x, t1: int = 0, t2: int = 0):
8585
fact_str = f'{self.model_name}({class_name}) : [{lower:.3f}, {upper:.3f}]'
8686
fact = Fact(fact_str, name=f'{self.model_name}-{class_name}-fact', start_time=t1, end_time=t2)
8787
facts.append(fact)
88-
return output, facts
88+
return output, probabilities, facts
8989

tests/test_classifier.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def test_classifier_integration():
3838
t2 = 0
3939

4040
# Run the forward pass to get the model output and the corresponding PyReason facts.
41-
output, facts = logic_classifier(input_tensor, t1, t2)
41+
output, probabilities, facts = logic_classifier(input_tensor, t1, t2)
4242

4343
# Assert that the output is a tensor.
4444
assert isinstance(output, torch.Tensor), "The model output should be a torch.Tensor"

0 commit comments

Comments
 (0)