Skip to content

Commit c2b79b8

Browse files
committed
added torch model integration
1 parent 6cadb6d commit c2b79b8

File tree

9 files changed

+170
-0
lines changed

9 files changed

+170
-0
lines changed

.github/workflows/python-package-version-test.yml

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ jobs:
2828
run: |
2929
python -m pip install --upgrade pip
3030
python -m pip install flake8 pytest
31+
pip install torch==2.6.0
3132
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
3233
- name: Lint with flake8
3334
run: |

pyreason/pyreason.py

+9
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@
2525
import pyreason.scripts.numba_wrapper.numba_types.fact_edge_type as fact_edge
2626
import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval
2727
from pyreason.scripts.utils.reorder_clauses import reorder_clauses
28+
try:
29+
import torch
30+
except ImportError:
31+
LogicIntegratedClassifier = None
32+
ModelInterfaceOptions = None
33+
print('torch is not installed, model integration is disabled')
34+
else:
35+
from pyreason.scripts.learning.classification.classifier import LogicIntegratedClassifier
36+
from pyreason.scripts.learning.utils.model_interface import ModelInterfaceOptions
2837

2938

3039
# USER VARIABLES

pyreason/scripts/facts/fact.py

+8
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,11 @@ def __init__(self, fact_text: str, name: str = None, start_time: int = 0, end_ti
2626
self.component = component
2727
self.bound = bound
2828
self.type = fact_type
29+
30+
def __str__(self):
31+
s = f'{self.pred}({self.component}) : {self.bound}'
32+
if self.static:
33+
s += ' | static'
34+
else:
35+
s += f' | start: {self.start_time} -> end: {self.end_time}'
36+
return s

pyreason/scripts/learning/__init__.py

Whitespace-only changes.

pyreason/scripts/learning/classification/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from typing import List
2+
3+
import torch.nn
4+
import torch.nn.functional as F
5+
6+
from pyreason.scripts.facts.fact import Fact
7+
from pyreason.scripts.learning.utils.model_interface import ModelInterfaceOptions
8+
9+
10+
class LogicIntegratedClassifier(torch.nn.Module):
11+
"""
12+
Class to integrate a PyTorch model with PyReason. The output of the model is returned to the
13+
user in the form of PyReason facts. The user can then add these facts to the logic program and reason using them.
14+
"""
15+
def __init__(self, model, class_names: List[str], model_name: str = 'classifier', interface_modes: ModelInterfaceOptions = None):
16+
"""
17+
:param model:
18+
:param class_names:
19+
"""
20+
super(LogicIntegratedClassifier, self).__init__()
21+
self.model = model
22+
self.class_names = class_names
23+
self.model_name = model_name
24+
self.interface_modes = interface_modes
25+
26+
def get_class_facts(self, t1: int, t2: int) -> List[Fact]:
27+
"""
28+
Return PyReason facts to create nodes for each class. Each class node will have bounds `[1,1]` with the
29+
predicate corresponding to the model name.
30+
:param t1: Start time for the facts
31+
:param t2: End time for the facts
32+
:return: List of PyReason facts
33+
"""
34+
facts = []
35+
for c in self.class_names:
36+
fact = Fact(f'{self.model_name}({c})', name=f'{self.model_name}-{c}-fact', start_time=t1, end_time=t2)
37+
facts.append(fact)
38+
return facts
39+
40+
def forward(self, x, t1: int, t2: int):
41+
"""
42+
Forward pass of the model
43+
:param x: Input tensor
44+
:param t1: Start time for the facts
45+
:param t2: End time for the facts
46+
:return: Output tensor
47+
"""
48+
output = self.model(x)
49+
50+
# Convert logits to probabilities assuming a multi-class classification.
51+
probabilities = F.softmax(output, dim=1).squeeze()
52+
opts = self.integration_options
53+
54+
# Prepare threshold tensor.
55+
threshold = torch.tensor(opts.threshold, dtype=probabilities.dtype, device=probabilities.device)
56+
condition = probabilities > threshold
57+
58+
if opts.snap_value is not None:
59+
snap_value = torch.tensor(opts.snap_value, dtype=probabilities.dtype, device=probabilities.device)
60+
# For values that pass the threshold:
61+
lower_val = snap_value if opts.set_lower_bound else torch.tensor(0.0, dtype=probabilities.dtype,
62+
device=probabilities.device)
63+
upper_val = snap_value if opts.set_upper_bound else torch.tensor(1.0, dtype=probabilities.dtype,
64+
device=probabilities.device)
65+
else:
66+
# If no snap_value is provided, keep original probabilities for those passing threshold.
67+
lower_val = probabilities
68+
upper_val = probabilities
69+
70+
# For probabilities that pass the threshold, apply the above; else, bounds are fixed to [0,1].
71+
lower_bounds = torch.where(condition, lower_val, torch.zeros_like(probabilities))
72+
upper_bounds = torch.where(condition, upper_val, torch.ones_like(probabilities))
73+
74+
# Convert bounds to Python floats for fact creation.
75+
bounds_list = []
76+
for i in range(len(self.class_names)):
77+
lower = lower_bounds[i].item()
78+
upper = upper_bounds[i].item()
79+
bounds_list.append([lower, upper])
80+
81+
# Define time bounds for the facts.
82+
facts = []
83+
for class_name, bounds in zip(self.class_names, bounds_list):
84+
lower, upper = bounds
85+
fact_str = f'{self.model_name}({class_name}) : [{lower:.3f}, {upper:.3f}])'
86+
fact = Fact(fact_str, name=f'{self.model_name}-{class_name}-fact', start_time=t1, end_time=t2)
87+
facts.append(fact)
88+
return output, facts
89+

pyreason/scripts/learning/utils/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from dataclasses import dataclass
2+
from typing import Optional
3+
4+
5+
@dataclass
6+
class ModelInterfaceOptions:
7+
threshold: float = 0.5
8+
set_lower_bound: bool = True
9+
set_upper_bound: bool = True
10+
snap_value: Optional[float] = 1.0

tests/test_classifier.py

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Test cases for classifier integration with pyreason
2+
import pyreason as pr
3+
import torch
4+
import torch.nn as nn
5+
6+
7+
def test_classifier_integration():
8+
# Reset PyReason
9+
pr.reset()
10+
pr.reset_rules()
11+
pr.reset_settings()
12+
13+
# Create a dummy PyTorch model: input size 10, output 3 classes.
14+
model = nn.Linear(10, 3)
15+
16+
# Define class names for the output classes.
17+
class_names = ["cat", "dog", "rabbit"]
18+
19+
# Create integration options.
20+
# Only probabilities exceeding 0.6 will be considered.
21+
# For those, if set_lower_bound is True, lower bound becomes 0.95; if set_upper_bound is False, upper bound is forced to 1.
22+
dummy_options = pr.ModelInterfaceOptions(
23+
threshold=0.6,
24+
set_lower_bound=True,
25+
set_upper_bound=False,
26+
snap_value=0.95
27+
)
28+
29+
# Create an instance of LogicIntegratedClassifier.
30+
logic_classifier = pr.LogicIntegratedClassifier(model, class_names, model_name="classifier",
31+
interface_modes=dummy_options)
32+
33+
# Create a dummy input tensor with 10 features.
34+
input_tensor = torch.rand(1, 10)
35+
36+
# Set time bounds for the facts.
37+
t1 = 0
38+
t2 = 0
39+
40+
# Run the forward pass to get the model output and the corresponding PyReason facts.
41+
output, facts = logic_classifier(input_tensor, t1, t2)
42+
43+
# Assert that the output is a tensor.
44+
assert isinstance(output, torch.Tensor), "The model output should be a torch.Tensor"
45+
# Assert that we have one fact per class.
46+
assert len(facts) == len(class_names), "Expected one fact per class"
47+
48+
# Print results for visual inspection.
49+
print("Model output (logits):")
50+
print(output)
51+
print("\nGenerated PyReason Facts:")
52+
for fact in facts:
53+
print(fact)

0 commit comments

Comments
 (0)