Skip to content

Commit c8467b0

Browse files
committed
minor changes to names and parameters
1 parent 01ff01b commit c8467b0

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

pyreason/scripts/learning/classification/classifier.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class LogicIntegratedClassifier(torch.nn.Module):
1212
Class to integrate a PyTorch model with PyReason. The output of the model is returned to the
1313
user in the form of PyReason facts. The user can then add these facts to the logic program and reason using them.
1414
"""
15-
def __init__(self, model, class_names: List[str], model_name: str = 'classifier', interface_modes: ModelInterfaceOptions = None):
15+
def __init__(self, model, class_names: List[str], model_name: str = 'classifier', interface_options: ModelInterfaceOptions = None):
1616
"""
1717
:param model:
1818
:param class_names:
@@ -21,7 +21,7 @@ def __init__(self, model, class_names: List[str], model_name: str = 'classifier'
2121
self.model = model
2222
self.class_names = class_names
2323
self.model_name = model_name
24-
self.interface_modes = interface_modes
24+
self.interface_options = interface_options
2525

2626
def get_class_facts(self, t1: int, t2: int) -> List[Fact]:
2727
"""
@@ -49,7 +49,7 @@ def forward(self, x, t1: int = 0, t2: int = 0):
4949

5050
# Convert logits to probabilities assuming a multi-class classification.
5151
probabilities = F.softmax(output, dim=1).squeeze()
52-
opts = self.interface_modes
52+
opts = self.interface_options
5353

5454
# Prepare threshold tensor.
5555
threshold = torch.tensor(opts.threshold, dtype=probabilities.dtype, device=probabilities.device)

tests/test_classifier.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,16 @@ def test_classifier_integration():
1919
# Create integration options.
2020
# Only probabilities exceeding 0.6 will be considered.
2121
# For those, if set_lower_bound is True, lower bound becomes 0.95; if set_upper_bound is False, upper bound is forced to 1.
22-
dummy_options = pr.ModelInterfaceOptions(
23-
threshold=0.6,
22+
interface_options = pr.ModelInterfaceOptions(
23+
threshold=0.4,
2424
set_lower_bound=True,
2525
set_upper_bound=False,
2626
snap_value=0.95
2727
)
2828

2929
# Create an instance of LogicIntegratedClassifier.
3030
logic_classifier = pr.LogicIntegratedClassifier(model, class_names, model_name="classifier",
31-
interface_modes=dummy_options)
31+
interface_options=interface_options)
3232

3333
# Create a dummy input tensor with 10 features.
3434
input_tensor = torch.rand(1, 10)

0 commit comments

Comments
 (0)