@@ -12,7 +12,7 @@ class LogicIntegratedClassifier(torch.nn.Module):
12
12
Class to integrate a PyTorch model with PyReason. The output of the model is returned to the
13
13
user in the form of PyReason facts. The user can then add these facts to the logic program and reason using them.
14
14
"""
15
- def __init__ (self , model , class_names : List [str ], model_name : str = 'classifier' , interface_modes : ModelInterfaceOptions = None ):
15
+ def __init__ (self , model , class_names : List [str ], model_name : str = 'classifier' , interface_options : ModelInterfaceOptions = None ):
16
16
"""
17
17
:param model:
18
18
:param class_names:
@@ -21,7 +21,7 @@ def __init__(self, model, class_names: List[str], model_name: str = 'classifier'
21
21
self .model = model
22
22
self .class_names = class_names
23
23
self .model_name = model_name
24
- self .interface_modes = interface_modes
24
+ self .interface_options = interface_options
25
25
26
26
def get_class_facts (self , t1 : int , t2 : int ) -> List [Fact ]:
27
27
"""
@@ -49,7 +49,7 @@ def forward(self, x, t1: int = 0, t2: int = 0):
49
49
50
50
# Convert logits to probabilities assuming a multi-class classification.
51
51
probabilities = F .softmax (output , dim = 1 ).squeeze ()
52
- opts = self .interface_modes
52
+ opts = self .interface_options
53
53
54
54
# Prepare threshold tensor.
55
55
threshold = torch .tensor (opts .threshold , dtype = probabilities .dtype , device = probabilities .device )
0 commit comments