Skip to content

Commit 94136bb

Browse files
committed
Merge branch 'main' of https://github.com/lab-v2/pyreason
2 parents d3eb7a6 + cad7ad6 commit 94136bb

13 files changed

+299
-13
lines changed

.gitignore

+5
Original file line numberDiff line numberDiff line change
@@ -160,4 +160,9 @@ cython_debug/
160160
# and can be added to the global gitignore or merged into this file. For a more nuclear
161161
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
162162
.idea/
163+
164+
*env/
165+
166+
# Sphinx Documentation
163167
/docs/source/_static/css/fonts/
168+

docs/group-chat-example.md

+130
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# Custom Thresholds Example
2+
3+
Here is an example that utilizes custom thresholds.
4+
5+
The following graph represents a network of People and a Text Message in their group chat.
6+
<img src="../media/group_chat_graph.png"/>
7+
8+
In this case, we want to know when a text message has been viewed by all members of the group chat.
9+
10+
## Graph
11+
First, lets create the group chat.
12+
13+
```python
14+
import networkx as nx
15+
16+
# Create an empty graph
17+
G = nx.Graph()
18+
19+
# Add nodes
20+
nodes = ["TextMessage", "Zach", "Justin", "Michelle", "Amy"]
21+
G.add_nodes_from(nodes)
22+
23+
# Add edges with attribute 'HaveAccess'
24+
edges = [
25+
("Zach", "TextMessage", {"HaveAccess": 1}),
26+
("Justin", "TextMessage", {"HaveAccess": 1}),
27+
("Michelle", "TextMessage", {"HaveAccess": 1}),
28+
("Amy", "TextMessage", {"HaveAccess": 1})
29+
]
30+
G.add_edges_from(edges)
31+
32+
```
33+
34+
## Rules and Custom Thresholds
35+
Considering that we only want a text message to be considered viewed by all if it has been viewed by everyone that can view it, we define the rule as follows:
36+
37+
```text
38+
ViewedByAll(x) <- HaveAccess(x,y), Viewed(y)
39+
```
40+
41+
The `head` of the rule is `ViewedByAll(x)` and the body is `HaveAccess(x,y), Viewed(y)`. The head and body are separated by an arrow which means the rule will start evaluating from
42+
timestep 0.
43+
44+
We add the rule into pyreason with:
45+
46+
```python
47+
import pyreason as pr
48+
from pyreason import Threshold
49+
50+
user_defined_thresholds = [
51+
Threshold("greater_equal", ("number", "total"), 1),
52+
Threshold("greater_equal", ("percent", "total"), 100),
53+
]
54+
55+
pr.add_rule(pr.Rule('ViewedByAll(x) <- HaveAccess(x,y), Viewed(y)', 'viewed_by_all_rule', user_defined_thresholds))
56+
```
57+
Where `viewed_by_all_rule` is the name of the rule. This helps to understand which rule/s are fired during reasoning later on.
58+
59+
The `user_defined_thresholds` are a list of custom thresholds of the format: (quantifier, quantifier_type, thresh) where:
60+
- quantifier can be greater_equal, greater, less_equal, less, equal
61+
- quantifier_type is a tuple where the first element can be either number or percent and the second element can be either total or available
62+
- thresh represents the numerical threshold value to compare against
63+
64+
The custom thresholds are created corresponding to the two clauses (HaveAccess(x,y) and Viewed(y)) as below:
65+
- ('greater_equal', ('number', 'total'), 1) (there needs to be at least one person who has access to TextMessage for the first clause to be satisfied)
66+
- ('greater_equal', ('percent', 'total'), 100) (100% of people who have access to TextMessage need to view the message for second clause to be satisfied)
67+
68+
## Facts
69+
The facts determine the initial conditions of elements in the graph. They can be specified from the graph attributes but in that
70+
case they will be immutable later on. Adding PyReason facts gives us more flexibility.
71+
72+
In our case we want one person to view the TextMessage in a particular interval of timestep.
73+
For example, we create facts stating:
74+
- Zach and Justin view the TextMessage from at timestep 0
75+
- Michelle views the TextMessage at timestep 1
76+
- Amy views the TextMessage at timestep 2
77+
78+
We add the facts in PyReason as below:
79+
```python
80+
import pyreason as pr
81+
82+
pr.add_fact(pr.Fact("seen-fact-zach", "Zach", "Viewed", [1, 1], 0, 0, static=True))
83+
pr.add_fact(pr.Fact("seen-fact-justin", "Justin", "Viewed", [1, 1], 0, 0, static=True))
84+
pr.add_fact(pr.Fact("seen-fact-michelle", "Michelle", "Viewed", [1, 1], 1, 1, static=True))
85+
pr.add_fact(pr.Fact("seen-fact-amy", "Amy", "Viewed", [1, 1], 2, 2, static=True))
86+
```
87+
88+
This allows us to specify the component that has an initial condition, the initial condition itself in the form of bounds
89+
as well as the start and end time of this condition.
90+
91+
## Running PyReason
92+
Find the full code for this example [here](../tests/test_custom_thresholds.py)
93+
94+
The main line that runs the reasoning in that file is:
95+
```python
96+
interpretation = pr.reason(timesteps=3)
97+
```
98+
This specifies how many timesteps to run for.
99+
100+
## Expected Output
101+
After running the python file, the expected output is:
102+
103+
```
104+
TIMESTEP - 0
105+
Empty DataFrame
106+
Columns: [component, ViewedByAll]
107+
Index: []
108+
109+
TIMESTEP - 1
110+
Empty DataFrame
111+
Columns: [component, ViewedByAll]
112+
Index: []
113+
114+
TIMESTEP - 2
115+
component ViewedByAll
116+
0 TextMessage [1.0, 1.0]
117+
118+
TIMESTEP - 3
119+
component ViewedByAll
120+
0 TextMessage [1.0, 1.0]
121+
122+
```
123+
124+
1. For timestep 0, we set `Zach -> Viewed: [1,1]` and `Justin -> Viewed: [1,1]` in the facts
125+
2. For timestep 1, Michelle views the TextMessage as stated in facts `Michelle -> Viewed: [1,1]`
126+
3. For timestep 2, since Amy has just viewed the TextMessage, therefore `Amy -> Viewed: [1,1]`. As per the rule,
127+
since all the people have viewed the TextMessage, the message is marked as ViewedByAll.
128+
129+
130+
We also output two CSV files detailing all the events that took place during reasoning (one for nodes, one for edges)

media/group_chat_graph.png

24.5 KB
Loading

pyreason/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
add_fact(Fact('popular-fact', 'Mary', 'popular', [1, 1], 0, 2))
2424
reason(timesteps=2)
2525

26+
reset()
27+
reset_rules()
28+
2629
# Update cache status
2730
cache_status['initialized'] = True
2831
with open(cache_status_path, 'w') as file:

pyreason/pyreason.py

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import pyreason.scripts.numba_wrapper.numba_types.rule_type as rule
1818
from pyreason.scripts.facts.fact import Fact
1919
from pyreason.scripts.rules.rule import Rule
20+
from pyreason.scripts.threshold.threshold import Threshold
2021
import pyreason.scripts.numba_wrapper.numba_types.fact_node_type as fact_node
2122
import pyreason.scripts.numba_wrapper.numba_types.fact_edge_type as fact_edge
2223
import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval

pyreason/scripts/rules/rule.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,17 @@ class Rule:
66
Example text:
77
`'pred1(x,y) : [0.2, 1] <- pred2(a, b) : [1,1], pred3(b, c)'`
88
9-
1. It is not possible to specify thresholds. Threshold is greater than or equal to 1 by default
10-
2. It is not possible to have weights for different clauses. Weights are 1 by default with bias 0
11-
TODO: Add threshold class where we can pass this as a parameter
9+
1. It is not possible to have weights for different clauses. Weights are 1 by default with bias 0
1210
TODO: Add weights as a parameter
1311
"""
14-
def __init__(self, rule_text: str, name: str, infer_edges: bool = False, set_static: bool = False, immediate_rule: bool = False):
12+
def __init__(self, rule_text: str, name: str, infer_edges: bool = False, set_static: bool = False, immediate_rule: bool = False, custom_thresholds=None):
1513
"""
1614
:param rule_text: The rule in text format
1715
:param name: The name of the rule. This will appear in the rule trace
1816
:param infer_edges: Whether to infer new edges after edge rule fires
1917
:param set_static: Whether to set the atom in the head as static if the rule fires. The bounds will no longer change
2018
:param immediate_rule: Whether the rule is immediate. Immediate rules check for more applicable rules immediately after being applied
2119
"""
22-
self.rule = rule_parser.parse_rule(rule_text, name, infer_edges, set_static, immediate_rule)
20+
if custom_thresholds is None:
21+
custom_thresholds = []
22+
self.rule = rule_parser.parse_rule(rule_text, name, custom_thresholds, infer_edges, set_static, immediate_rule)

pyreason/scripts/threshold/__init__.py

Whitespace-only changes.
+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
class Threshold:
2+
"""
3+
A class representing a threshold for a clause in a rule.
4+
5+
Attributes:
6+
quantifier (str): The comparison operator, e.g., 'greater_equal', 'less', etc.
7+
quantifier_type (tuple): A tuple indicating the type of quantifier, e.g., ('number', 'total').
8+
thresh (int): The numerical threshold value to compare against.
9+
10+
Methods:
11+
to_tuple(): Converts the Threshold instance into a tuple compatible with numba types.
12+
"""
13+
14+
def __init__(self, quantifier, quantifier_type, thresh):
15+
"""
16+
Initializes a Threshold instance.
17+
18+
Args:
19+
quantifier (str): The comparison operator for the threshold.
20+
quantifier_type (tuple): The type of quantifier ('number' or 'percent', 'total' or 'available').
21+
thresh (int): The numerical value for the threshold.
22+
"""
23+
24+
if quantifier not in ("greater_equal", "greater", "less_equal", "less", "equal"):
25+
raise ValueError("Invalid quantifier")
26+
27+
if quantifier_type[0] not in ("number", "percent") or quantifier_type[1] not in ("total", "available"):
28+
raise ValueError("Invalid quantifier type")
29+
30+
self.quantifier = quantifier
31+
self.quantifier_type = quantifier_type
32+
self.thresh = thresh
33+
34+
def to_tuple(self):
35+
"""
36+
Converts the Threshold instance into a tuple compatible with numba types.
37+
38+
Returns:
39+
tuple: A tuple representation of the Threshold instance.
40+
"""
41+
return (self.quantifier, self.quantifier_type, self.thresh)

pyreason/scripts/utils/rule_parser.py

+18-8
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval
77

88

9-
def parse_rule(rule_text: str, name: str, infer_edges: bool = False, set_static: bool = False, immediate_rule: bool = False) -> rule.Rule:
9+
def parse_rule(rule_text: str, name: str, custom_thresholds: list, infer_edges: bool = False, set_static: bool = False, immediate_rule: bool = False) -> rule.Rule:
1010
# First remove all spaces from line
1111
r = rule_text.replace(' ', '')
1212

@@ -152,7 +152,23 @@ def parse_rule(rule_text: str, name: str, infer_edges: bool = False, set_static:
152152
# Array to store clauses for nodes: node/edge, [subset]/[subset1, subset2], label, interval, operator
153153
clauses = numba.typed.List.empty_list(numba.types.Tuple((numba.types.string, label.label_type, numba.types.ListType(numba.types.string), interval.interval_type, numba.types.string)))
154154

155-
# Loop though clauses
155+
# gather count of clauses for threshold validation
156+
num_clauses = len(body_clauses)
157+
158+
if custom_thresholds and (len(custom_thresholds) != num_clauses):
159+
raise Exception('The length of custom thresholds {} is not equal to number of clauses {}'
160+
.format(len(custom_thresholds), num_clauses))
161+
162+
# If no custom thresholds provided, use defaults
163+
# otherwise loop through user-defined thresholds and convert to numba compatible format
164+
if not custom_thresholds:
165+
for _ in range(num_clauses):
166+
thresholds.append(('greater_equal', ('number', 'total'), 1.0))
167+
else:
168+
for threshold in custom_thresholds:
169+
thresholds.append(threshold.to_tuple())
170+
171+
# # Loop though clauses
156172
for body_clause, predicate, variables, bounds in zip(body_clauses, body_predicates, body_variables, body_bounds):
157173
# Neigh criteria
158174
clause_type = 'node' if len(variables) == 1 else 'edge'
@@ -165,12 +181,6 @@ def parse_rule(rule_text: str, name: str, infer_edges: bool = False, set_static:
165181
bnd = interval.closed(bounds[0], bounds[1])
166182
clauses.append((clause_type, l, subset, bnd, op))
167183

168-
# Threshold.
169-
quantifier = 'greater_equal'
170-
quantifier_type = ('number', 'total')
171-
thresh = 1
172-
thresholds.append((quantifier, quantifier_type, thresh))
173-
174184
# Assert that there are two variables in the head of the rule if we infer edges
175185
# Add edges between head variables if necessary
176186
if infer_edges:

tests/group_chat_graph.graphml

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
<?xml version='1.0' encoding='utf-8'?>
2+
<graphml xmlns="http://graphml.graphdrawing.org/xmlns" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://graphml.graphdrawing.org/xmlns http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd">
3+
4+
<key id="HaveAccess" for="edge" attr.name="HaveAccess" attr.type="long" />
5+
<graph edgedefault="undirected">
6+
<node id="TextMessage" />
7+
<node id="Zach" />
8+
<node id="Justin" />
9+
<node id="Michelle" />
10+
<node id="Amy" />
11+
12+
<edge source="Zach" target="TextMessage">
13+
<data key="HaveAccess">1</data>
14+
</edge>
15+
<edge source="Justin" target="TextMessage">
16+
<data key="HaveAccess">1</data>
17+
</edge>
18+
19+
<edge source="Amy" target="TextMessage">
20+
<data key="HaveAccess">1</data>
21+
</edge>
22+
<edge source="Michelle" target="TextMessage">
23+
<data key="HaveAccess">1</data>
24+
</edge>
25+
</graph>
26+
</graphml>

tests/test_custom_thresholds.py

+62
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Test if the simple program works with thresholds defined
2+
import pyreason as pr
3+
from pyreason import Threshold
4+
5+
6+
def test_custom_thresholds():
7+
# Reset PyReason
8+
pr.reset()
9+
pr.reset_rules()
10+
11+
# Modify the paths based on where you've stored the files we made above
12+
graph_path = "./tests/group_chat_graph.graphml"
13+
14+
# Modify pyreason settings to make verbose and to save the rule trace to a file
15+
pr.settings.verbose = True # Print info to screen
16+
17+
# Load all the files into pyreason
18+
pr.load_graphml(graph_path)
19+
20+
# add custom thresholds
21+
user_defined_thresholds = [
22+
Threshold("greater_equal", ("number", "total"), 1),
23+
Threshold("greater_equal", ("percent", "total"), 100),
24+
]
25+
26+
pr.add_rule(
27+
pr.Rule(
28+
"ViewedByAll(x) <- HaveAccess(x,y), Viewed(y)",
29+
"viewed_by_all_rule",
30+
custom_thresholds=user_defined_thresholds,
31+
)
32+
)
33+
34+
pr.add_fact(pr.Fact("seen-fact-zach", "Zach", "Viewed", [1, 1], 0, 3))
35+
pr.add_fact(pr.Fact("seen-fact-justin", "Justin", "Viewed", [1, 1], 0, 3))
36+
pr.add_fact(pr.Fact("seen-fact-michelle", "Michelle", "Viewed", [1, 1], 1, 3))
37+
pr.add_fact(pr.Fact("seen-fact-amy", "Amy", "Viewed", [1, 1], 2, 3))
38+
39+
# Run the program for three timesteps to see the diffusion take place
40+
interpretation = pr.reason(timesteps=3)
41+
42+
# Display the changes in the interpretation for each timestep
43+
dataframes = pr.filter_and_sort_nodes(interpretation, ["ViewedByAll"])
44+
for t, df in enumerate(dataframes):
45+
print(f"TIMESTEP - {t}")
46+
print(df)
47+
print()
48+
49+
assert (
50+
len(dataframes[0]) == 0
51+
), "At t=0 the TextMessage should not have been ViewedByAll"
52+
assert (
53+
len(dataframes[2]) == 1
54+
), "At t=2 the TextMessage should have been ViewedByAll"
55+
56+
# TextMessage should be ViewedByAll in t=2
57+
assert "TextMessage" in dataframes[2]["component"].values and dataframes[2].iloc[
58+
0
59+
].ViewedByAll == [
60+
1,
61+
1,
62+
], "TextMessage should have ViewedByAll bounds [1,1] for t=2 timesteps"

tests/test_hello_world.py

+4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33

44

55
def test_hello_world():
6+
# Reset PyReason
7+
pr.reset()
8+
pr.reset_rules()
9+
610
# Modify the paths based on where you've stored the files we made above
711
graph_path = './tests/friends_graph.graphml'
812

tests/test_hello_world_parallel.py

+4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33

44

55
def test_hello_world_parallel():
6+
# Reset PyReason
7+
pr.reset()
8+
pr.reset_rules()
9+
610
# Modify the paths based on where you've stored the files we made above
711
graph_path = './tests/friends_graph.graphml'
812

0 commit comments

Comments
 (0)