Skip to content

Commit 50a7753

Browse files
committed
classifier updates
1 parent 96ae50e commit 50a7753

File tree

2 files changed

+24
-13
lines changed

2 files changed

+24
-13
lines changed

examples/classifier_integration_ex.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,15 @@
22
import torch
33
import torch.nn as nn
44
import networkx as nx
5+
import numpy as np
6+
import random
7+
8+
# seed_value = 41
9+
seed_value = 42
10+
random.seed(seed_value)
11+
np.random.seed(seed_value)
12+
torch.manual_seed(seed_value)
13+
514

615
# --- Part 1: Fraud Detector Model Integration ---
716

@@ -13,20 +22,20 @@
1322
# Create a dummy transaction feature vector.
1423
transaction_features = torch.rand(1, 5)
1524

16-
# Define integration options.
25+
# Define integration options
1726
# Only probabilities above 0.5 are considered for adjustment.
1827
interface_options = pr.ModelInterfaceOptions(
19-
threshold=0.5, # Only process probabilities above 0.5
28+
threshold=0.5, # Only process probabilities above 0.5
2029
set_lower_bound=True, # For high confidence, adjust the lower bound.
2130
set_upper_bound=False, # Keep the upper bound unchanged.
22-
snap_value=1.0 # Use 1.0 as the snap value.
31+
snap_value=1.0 # Use 1.0 as the snap value.
2332
)
2433

2534
# Wrap the model using LogicIntegratedClassifier
2635
fraud_detector = pr.LogicIntegratedClassifier(
2736
model,
2837
class_names,
29-
model_name="fraud_detector",
38+
identifier="fraud_detector",
3039
interface_options=interface_options
3140
)
3241

@@ -65,7 +74,7 @@
6574

6675
# Define a rule: if the fraud detector flags a transaction as fraud and the transaction info is suspicious,
6776
# then mark the associated account (AccountA) as requiring investigation.
68-
pr.add_rule(pr.Rule("requires_investigation(acc) <- account(acc), fraud_detector(fraud), suspicious_location(acc)", "investigation_rule"))
77+
pr.add_rule(pr.Rule("requires_investigation(acc) <- account(acc), fraud(fraud_detector), suspicious_location(acc)", "investigation_rule"))
6978

7079
# Define a propagation rule:
7180
# If an account requires investigation and is connected (via the "associated" relationship) to another account,
@@ -75,7 +84,7 @@
7584
# --- Part 4: Run the Reasoning Engine ---
7685

7786
# Run the reasoning engine to allow the investigation flag to propagate through the network.
78-
pr.settings.allow_ground_rules = True
87+
# pr.settings.allow_ground_rules = True
7988
pr.settings.atom_trace = True
8089
interpretation = pr.reason()
8190

pyreason/scripts/learning/classification/classifier.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,17 @@ class LogicIntegratedClassifier(torch.nn.Module):
1212
Class to integrate a PyTorch model with PyReason. The output of the model is returned to the
1313
user in the form of PyReason facts. The user can then add these facts to the logic program and reason using them.
1414
"""
15-
def __init__(self, model, class_names: List[str], model_name: str = 'classifier', interface_options: ModelInterfaceOptions = None):
15+
def __init__(self, model, class_names: List[str], identifier: str = 'classifier', interface_options: ModelInterfaceOptions = None):
1616
"""
17-
:param model:
18-
:param class_names:
17+
:param model: PyTorch model to be integrated.
18+
:param class_names: List of class names for the model output.
19+
:param identifier: Identifier for the model, used as the constant in the facts.
20+
:param interface_options: Options for the model interface, including threshold and snapping behavior.
1921
"""
2022
super(LogicIntegratedClassifier, self).__init__()
2123
self.model = model
2224
self.class_names = class_names
23-
self.model_name = model_name
25+
self.identifier = identifier
2426
self.interface_options = interface_options
2527

2628
def get_class_facts(self, t1: int, t2: int) -> List[Fact]:
@@ -33,7 +35,7 @@ def get_class_facts(self, t1: int, t2: int) -> List[Fact]:
3335
"""
3436
facts = []
3537
for c in self.class_names:
36-
fact = Fact(f'{self.model_name}({c})', name=f'{self.model_name}-{c}-fact', start_time=t1, end_time=t2)
38+
fact = Fact(f'{c}({self.identifier})', name=f'{self.identifier}-{c}-fact', start_time=t1, end_time=t2)
3739
facts.append(fact)
3840
return facts
3941

@@ -82,8 +84,8 @@ def forward(self, x, t1: int = 0, t2: int = 0) -> Tuple[torch.Tensor, torch.Tens
8284
facts = []
8385
for class_name, bounds in zip(self.class_names, bounds_list):
8486
lower, upper = bounds
85-
fact_str = f'{self.model_name}({class_name}) : [{lower:.3f}, {upper:.3f}]'
86-
fact = Fact(fact_str, name=f'{self.model_name}-{class_name}-fact', start_time=t1, end_time=t2)
87+
fact_str = f'{class_name}({self.identifier}) : [{lower:.3f}, {upper:.3f}]'
88+
fact = Fact(fact_str, name=f'{self.identifier}-{class_name}-fact', start_time=t1, end_time=t2)
8789
facts.append(fact)
8890
return output, probabilities, facts
8991

0 commit comments

Comments
 (0)