
Commit b3c0d85 (parent 5614760)
classification integration docs and example

2 files changed: +216, -0 lines

Integrating PyReason with Machine Learning
==========================================

PyReason can be integrated with machine learning models by incorporating the model's predictions
as facts in the graph and reasoning over them with logical rules. This allows users to combine the strengths of
machine learning models with the interpretability and reasoning capabilities of PyReason.

Classifier Integration Example
------------------------------

.. note::
    Find the full, executable code `here <https://github.com/lab-v2/pyreason/blob/main/examples/classifier_integration_ex.py>`_.

In this section, we will outline how to perform ML integration using a simple classification example. We assume
that we have a fraud detection model that predicts whether a transaction is fraudulent or not, and we use its predictions
to reason over a knowledge base of account information to identify potentially fraudulent activity. For this example we
use an untrained linear model purely as a demonstration; in practice, it can be replaced by any PyTorch classification model.

We start by defining our classifier model, along with the imports used throughout this example.

.. code-block:: python

    import pyreason as pr
    import torch
    import torch.nn as nn
    import networkx as nx

    model = nn.Linear(5, 2)
    class_names = ["fraud", "legitimate"]
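
The linear layer above is only a stand-in. Any PyTorch module that maps the five transaction features to one logit per
class could be dropped in instead; the snippet below is purely an illustration of that point and is not part of the
original example.

.. code-block:: python

    # Hypothetical replacement for the stand-in linear model (illustration only).
    model = nn.Sequential(
        nn.Linear(5, 16),   # 5 transaction features -> 16 hidden units
        nn.ReLU(),
        nn.Linear(16, 2),   # 2 output logits: "fraud" and "legitimate"
    )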

Next, we define how we want to incorporate the predictions from the model into the graph. Since the model outputs a probability
over the classes, we can specify how this probability is integrated into the graph. There are a few options the user can define
using a ``ModelInterfaceOptions`` object (a small sketch of how these options combine follows the list):

1. ``threshold`` **(float)**: The threshold beyond which a prediction is incorporated as a fact in the graph. If the probability
   of the class is lower than the threshold, no information is added to the graph. Defaults to ``0.5``.
2. ``set_lower_bound`` **(bool)**: If ``True``, the (possibly snapped) probability is used as the lower bound of the fact in the
   graph; if ``False``, the lower bound will be ``0``. Defaults to ``True``.
3. ``set_upper_bound`` **(bool)**: If ``True``, the (possibly snapped) probability is used as the upper bound of the fact in the
   graph; if ``False``, the upper bound will be ``1``. Defaults to ``True``.
4. ``snap_value`` **(float)**: If set, all probabilities that cross the threshold are snapped to this value, and the upper/lower
   bounds are set to it according to the ``set_lower_bound`` and ``set_upper_bound`` options. Defaults to ``1.0``.
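
To make the interplay of these options concrete, here is a minimal illustrative sketch of the semantics described above.
It is not PyReason's internal implementation, and treating ``snap_value=None`` as "use the raw probability" is an assumption
made only for this illustration:

.. code-block:: python

    # Illustrative sketch of the option semantics described above (not library code).
    def bounds_for(prob, threshold=0.5, set_lower_bound=True,
                   set_upper_bound=True, snap_value=1.0):
        if prob < threshold:
            return None                                   # below the threshold: no fact is added
        value = prob if snap_value is None else snap_value
        lower = value if set_lower_bound else 0.0
        upper = value if set_upper_bound else 1.0
        return lower, upper

    print(bounds_for(0.8, set_upper_bound=False, snap_value=1.0))  # (1.0, 1.0), as used in this example
    print(bounds_for(0.8, snap_value=None))                        # (0.8, 0.8)
    print(bounds_for(0.3))                                         # None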

In our binary classification model, we want predictions that cross the threshold of ``0.5`` to be added to the graph.
For this example we use a ``snap_value`` of ``1.0`` and set only the lower bound of the fact from it.
Therefore, any prediction with a probability greater than ``0.5`` will be added to the graph as a fact with bounds of ``[1.0, 1.0]``.

.. code-block:: python

    interface_options = pr.ModelInterfaceOptions(
        threshold=0.5,          # Only process probabilities above 0.5
        set_lower_bound=True,   # Modify the lower bound.
        set_upper_bound=False,  # Keep the upper bound unchanged at 1.0.
        snap_value=1.0          # Use 1.0 as the snap value.
    )

Next, we create a ``LogicIntegratedClassifier`` object that helps us integrate the predictions from the model into the graph.

.. code-block:: python

    fraud_detector = pr.LogicIntegratedClassifier(
        model,
        class_names,
        model_name="fraud_detector",
        interface_options=interface_options
    )

To run the model, we perform the same steps as we would with a regular PyTorch model; in this example we use a dummy input.
Along with the logits and probabilities, the call returns a list of facts that can then be added to PyReason.

.. code-block:: python

    transaction_features = torch.rand(1, 5)

    # Get the prediction from the model
    logits, probabilities, classifier_facts = fraud_detector(transaction_features)
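
As in the full example script, the generated facts can be inspected by printing them before they are added:

.. code-block:: python

    print("Generated Classifier Facts:")
    for fact in classifier_facts:
        print(fact)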

We now add the facts to PyReason as normal.

.. code-block:: python

    # Add the classifier-generated facts.
    for fact in classifier_facts:
        pr.add_fact(fact)

Next, we define a knowledge graph that contains information about the accounts and their relationships. We also define some
context about the transaction, as well as the rules that we want to reason over together with the classifier predictions.

.. code-block:: python

    # Create a networkx graph representing a network of accounts.
    G = nx.DiGraph()

    # Add account nodes.
    G.add_node("AccountA", account=1)
    G.add_node("AccountB", account=1)
    G.add_node("AccountC", account=1)

    # Add edges with an attribute "associated".
    G.add_edge("AccountA", "AccountB", associated=1)
    G.add_edge("AccountB", "AccountC", associated=1)
    pr.load_graph(G)

    # Add additional contextual information:
    # 1. A fact indicating the transaction comes from a suspicious location. This could come from a separate fraud detection system.
    pr.add_fact(pr.Fact("suspicious_location(AccountA)", "transaction_fact"))

    # Define a rule: if the fraud detector flags the transaction as fraud and the transaction comes from a suspicious location,
    # then mark the associated account (AccountA) as requiring investigation.
    pr.add_rule(pr.Rule("requires_investigation(acc) <- account(acc), fraud_detector(fraud), suspicious_location(acc)", "investigation_rule"))

    # Define a propagation rule:
    # If an account requires investigation and is connected (via the "associated" relationship) to another account,
    # then the connected account is also flagged for investigation.
    pr.add_rule(pr.Rule("requires_investigation(y) <- requires_investigation(x), associated(x,y)", "propagation_rule"))

Finally, we run the reasoning process and print the rule trace.

.. code-block:: python

    # Run the reasoning engine to allow the investigation flag to propagate through the network.
    pr.settings.allow_ground_rules = True  # Ground rules let us use the classifier prediction facts
    pr.settings.atom_trace = True
    interpretation = pr.reason()

    trace = pr.get_rule_trace(interpretation)
    print(f"RULE TRACE: \n\n{trace[0]}\n")
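
To see which accounts end up flagged for investigation, the interpretation can also be filtered by label. A minimal
sketch, assuming ``pr.filter_and_sort_nodes`` is available as in PyReason's basic tutorial:

.. code-block:: python

    # Filter the interpretation for nodes carrying the "requires_investigation" label.
    dataframes = pr.filter_and_sort_nodes(interpretation, ["requires_investigation"])
    for t, df in enumerate(dataframes):
        print(f"TIMESTEP {t}")
        print(df)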

This simple example demonstrates how a machine learning model can be integrated with PyReason. In practice, more complex
models can be used, along with larger and richer knowledge graphs.

examples/classifier_integration_ex.py (+83 lines)

import pyreason as pr
import torch
import torch.nn as nn
import networkx as nx

# --- Part 1: Fraud Detector Model Integration ---

# Create a dummy PyTorch model for transaction fraud detection.
# Each transaction is represented by 5 features and is classified into "fraud" or "legitimate".
model = nn.Linear(5, 2)
class_names = ["fraud", "legitimate"]

# Create a dummy transaction feature vector.
transaction_features = torch.rand(1, 5)

# Define integration options.
# Only probabilities above 0.5 are considered for adjustment.
interface_options = pr.ModelInterfaceOptions(
    threshold=0.5,          # Only process probabilities above 0.5
    set_lower_bound=True,   # For high confidence, adjust the lower bound.
    set_upper_bound=False,  # Keep the upper bound unchanged.
    snap_value=1.0          # Use 1.0 as the snap value.
)

# Wrap the model using LogicIntegratedClassifier
fraud_detector = pr.LogicIntegratedClassifier(
    model,
    class_names,
    model_name="fraud_detector",
    interface_options=interface_options
)

# Run the model to obtain logits, probabilities, and generated PyReason facts.
logits, probabilities, classifier_facts = fraud_detector(transaction_features)

print("=== Fraud Detector Output ===")
print("Logits:", logits)
print("Probabilities:", probabilities)
print("\nGenerated Classifier Facts:")
for fact in classifier_facts:
    print(fact)

# Add the classifier-generated facts.
for fact in classifier_facts:
    pr.add_fact(fact)

# --- Part 2: Create and Load a Networkx Graph representing an account knowledge base ---

# Create a networkx graph representing a network of accounts.
G = nx.DiGraph()
# Add account nodes.
G.add_node("AccountA", account=1)
G.add_node("AccountB", account=1)
G.add_node("AccountC", account=1)
# Add edges with an attribute "associated".
G.add_edge("AccountA", "AccountB", associated=1)
G.add_edge("AccountB", "AccountC", associated=1)
pr.load_graph(G)

# --- Part 3: Set Up Context and Reasoning Environment ---

# Add additional contextual information:
# 1. A fact indicating the transaction comes from a suspicious location. This could come from a separate fraud detection system.
pr.add_fact(pr.Fact("suspicious_location(AccountA)", "transaction_fact"))

# Define a rule: if the fraud detector flags the transaction as fraud and the transaction comes from a suspicious location,
# then mark the associated account (AccountA) as requiring investigation.
pr.add_rule(pr.Rule("requires_investigation(acc) <- account(acc), fraud_detector(fraud), suspicious_location(acc)", "investigation_rule"))

# Define a propagation rule:
# If an account requires investigation and is connected (via the "associated" relationship) to another account,
# then the connected account is also flagged for investigation.
pr.add_rule(pr.Rule("requires_investigation(y) <- requires_investigation(x), associated(x,y)", "propagation_rule"))

# --- Part 4: Run the Reasoning Engine ---

# Run the reasoning engine to allow the investigation flag to propagate through the network.
pr.settings.allow_ground_rules = True  # Ground rules let us use the classifier prediction facts
pr.settings.atom_trace = True
interpretation = pr.reason()

trace = pr.get_rule_trace(interpretation)
print(f"RULE TRACE: \n\n{trace[0]}\n")
