@@ -12,15 +12,17 @@ class LogicIntegratedClassifier(torch.nn.Module):
12
12
Class to integrate a PyTorch model with PyReason. The output of the model is returned to the
13
13
user in the form of PyReason facts. The user can then add these facts to the logic program and reason using them.
14
14
"""
15
- def __init__ (self , model , class_names : List [str ], model_name : str = 'classifier' , interface_options : ModelInterfaceOptions = None ):
15
+ def __init__ (self , model , class_names : List [str ], identifier : str = 'classifier' , interface_options : ModelInterfaceOptions = None ):
16
16
"""
17
- :param model:
18
- :param class_names:
17
+ :param model: PyTorch model to be integrated.
18
+ :param class_names: List of class names for the model output.
19
+ :param identifier: Identifier for the model, used as the constant in the facts.
20
+ :param interface_options: Options for the model interface, including threshold and snapping behavior.
19
21
"""
20
22
super (LogicIntegratedClassifier , self ).__init__ ()
21
23
self .model = model
22
24
self .class_names = class_names
23
- self .model_name = model_name
25
+ self .identifier = identifier
24
26
self .interface_options = interface_options
25
27
26
28
def get_class_facts (self , t1 : int , t2 : int ) -> List [Fact ]:
@@ -33,7 +35,7 @@ def get_class_facts(self, t1: int, t2: int) -> List[Fact]:
33
35
"""
34
36
facts = []
35
37
for c in self .class_names :
36
- fact = Fact (f'{ self . model_name } ({ c } )' , name = f'{ self .model_name } -{ c } -fact' , start_time = t1 , end_time = t2 )
38
+ fact = Fact (f'{ c } ({ self . identifier } )' , name = f'{ self .identifier } -{ c } -fact' , start_time = t1 , end_time = t2 )
37
39
facts .append (fact )
38
40
return facts
39
41
@@ -82,8 +84,8 @@ def forward(self, x, t1: int = 0, t2: int = 0) -> Tuple[torch.Tensor, torch.Tens
82
84
facts = []
83
85
for class_name , bounds in zip (self .class_names , bounds_list ):
84
86
lower , upper = bounds
85
- fact_str = f'{ self . model_name } ({ class_name } ) : [{ lower :.3f} , { upper :.3f} ]'
86
- fact = Fact (fact_str , name = f'{ self .model_name } -{ class_name } -fact' , start_time = t1 , end_time = t2 )
87
+ fact_str = f'{ class_name } ({ self . identifier } ) : [{ lower :.3f} , { upper :.3f} ]'
88
+ fact = Fact (fact_str , name = f'{ self .identifier } -{ class_name } -fact' , start_time = t1 , end_time = t2 )
87
89
facts .append (fact )
88
90
return output , probabilities , facts
89
91
0 commit comments