Skip to content

Commit 4704d76

Browse files
committed
added new PyReason Rule object
1 parent 5d02ac5 commit 4704d76

File tree

7 files changed

+77
-68
lines changed

7 files changed

+77
-68
lines changed

pyreason/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
settings.verbose = False
2121
load_graphml(graph_path)
22-
add_rule('popular(x) <-1 popular(y), Friends(x,y), owns(y,z), owns(x,z)', 'popular_rule')
22+
add_rule(Rule('popular(x) <-1 popular(y), Friends(x,y), owns(y,z), owns(x,z)', 'popular_rule'))
2323
add_fact(Fact('popular-fact', 'Mary', 'popular', [1, 1], 0, 2))
2424
reason(timesteps=2)
2525

pyreason/pyreason.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import pyreason.scripts.numba_wrapper.numba_types.label_type as label
1717
import pyreason.scripts.numba_wrapper.numba_types.rule_type as rule
1818
from pyreason.scripts.facts.fact import Fact
19+
from pyreason.scripts.rules.rule import Rule
1920
import pyreason.scripts.numba_wrapper.numba_types.fact_node_type as fact_node
2021
import pyreason.scripts.numba_wrapper.numba_types.fact_edge_type as fact_edge
2122
import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval
@@ -449,29 +450,15 @@ def load_inconsistent_predicate_list(path: str) -> None:
449450
__ipl = yaml_parser.parse_ipl(path)
450451

451452

452-
def add_rule(rule_text: str, name: str, infer_edges: bool = False, set_static: bool = False, immediate_rule: bool = False) -> None:
453+
def add_rule(pr_rule: Rule) -> None:
453454
"""Add a rule to pyreason from text format. This format is not as modular as the YAML format.
454-
1. It is not possible to specify thresholds. Threshold is greater than or equal to 1 by default
455-
2. It is not possible to have weights for different clauses. Weights are 1 by default with bias 0
456-
TODO: Add threshold class where we can pass this as a parameter
457-
TODO: Add weights as a parameter
458-
459-
Example:
460-
`'pred1(x,y) : [0.2, 1] <- pred2(a, b) : [1,1], pred3(b, c)'`
461-
462-
:param rule_text: The rule in text format
463-
:param name: The name of the rule. This will appear in the rule trace
464-
:param infer_edges: Whether to infer new edges after edge rule fires
465-
:param set_static: Whether to set the atom in the head as static if the rule fires. The bounds will no longer change
466-
:param immediate_rule: Whether the rule is immediate. Immediate rules check for more applicable rules immediately after being applied
467455
"""
468456
global __rules
469457

470-
r = rule_parser.parse_rule(rule_text, name, infer_edges, set_static, immediate_rule)
471458
# Add to collection of rules
472459
if __rules is None:
473460
__rules = numba.typed.List.empty_list(rule.rule_type)
474-
__rules.append(r)
461+
__rules.append(pr_rule.rule)
475462

476463

477464
def add_rules_from_file(file_path: str, infer_edges: bool = False) -> None:
@@ -488,7 +475,7 @@ def add_rules_from_file(file_path: str, infer_edges: bool = False) -> None:
488475

489476
rule_offset = 0 if __rules is None else len(__rules)
490477
for i, r in enumerate(rules):
491-
add_rule(r, f'rule_{i+rule_offset}', infer_edges)
478+
add_rule(Rule(r, f'rule_{i+rule_offset}', infer_edges))
492479

493480

494481
def add_fact(pyreason_fact: Fact) -> None:
@@ -578,7 +565,7 @@ def _reason(timesteps, convergence_threshold, convergence_bound_threshold):
578565
if __graph is None:
579566
raise Exception('Graph not loaded. Use `load_graph` to load the graphml file')
580567
if __rules is None:
581-
raise Exception('Rules not loaded. Use `load_rules` to load the rules yaml file')
568+
raise Exception('There are no rules, use `add_rule` or `add_rules_from_file`')
582569

583570
# Check variables that are highly recommended. Warnings
584571
if __node_labels is None and __edge_labels is None:

pyreason/scripts/numba_wrapper/numba_types/rule_type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pyreason.scripts.numba_wrapper.numba_types.label_type as label
22
import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval
3-
from pyreason.scripts.rules.rule import Rule
3+
from pyreason.scripts.rules.rule_internal import Rule
44

55
from numba import types
66
from numba.extending import typeof_impl

pyreason/scripts/rules/rule.py

Lines changed: 20 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,22 @@
1-
class Rule:
2-
3-
def __init__(self, rule_name, rule_type, target, delta, clauses, bnd, thresholds, ann_fn, weights, edges, static, immediate_rule):
4-
self._rule_name = rule_name
5-
self._type = rule_type
6-
self._target = target
7-
self._delta = delta
8-
self._clauses = clauses
9-
self._bnd = bnd
10-
self._thresholds = thresholds
11-
self._ann_fn = ann_fn
12-
self._weights = weights
13-
self._edges = edges
14-
self._static = static
15-
self._immediate_rule = immediate_rule
16-
17-
def get_rule_name(self):
18-
return self._rule_name
19-
20-
def get_rule_type(self):
21-
return self._type
22-
23-
def get_target(self):
24-
return self._target
1+
import pyreason.scripts.utils.rule_parser as rule_parser
252

26-
def get_delta(self):
27-
return self._delta
283

29-
def get_neigh_criteria(self):
30-
return self._clauses
31-
32-
def get_bnd(self):
33-
return self._bnd
34-
35-
def get_thresholds(self):
36-
return self._thresholds
37-
38-
def get_annotation_function(self):
39-
return self._ann_fn
40-
41-
def get_edges(self):
42-
return self._edges
43-
44-
def is_static(self):
45-
return self._static
46-
47-
def is_immediate_rule(self):
48-
return self._immediate_rule
4+
class Rule:
5+
"""
6+
Example text:
7+
`'pred1(x,y) : [0.2, 1] <- pred2(a, b) : [1,1], pred3(b, c)'`
8+
9+
1. It is not possible to specify thresholds. Threshold is greater than or equal to 1 by default
10+
2. It is not possible to have weights for different clauses. Weights are 1 by default with bias 0
11+
TODO: Add threshold class where we can pass this as a parameter
12+
TODO: Add weights as a parameter
13+
"""
14+
def __init__(self, rule_text: str, name: str, infer_edges: bool = False, set_static: bool = False, immediate_rule: bool = False):
15+
"""
16+
:param rule_text: The rule in text format
17+
:param name: The name of the rule. This will appear in the rule trace
18+
:param infer_edges: Whether to infer new edges after edge rule fires
19+
:param set_static: Whether to set the atom in the head as static if the rule fires. The bounds will no longer change
20+
:param immediate_rule: Whether the rule is immediate. Immediate rules check for more applicable rules immediately after being applied
21+
"""
22+
self.rule = rule_parser.parse_rule(rule_text, name, infer_edges, set_static, immediate_rule)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
class Rule:
2+
3+
def __init__(self, rule_name, rule_type, target, delta, clauses, bnd, thresholds, ann_fn, weights, edges, static, immediate_rule):
4+
self._rule_name = rule_name
5+
self._type = rule_type
6+
self._target = target
7+
self._delta = delta
8+
self._clauses = clauses
9+
self._bnd = bnd
10+
self._thresholds = thresholds
11+
self._ann_fn = ann_fn
12+
self._weights = weights
13+
self._edges = edges
14+
self._static = static
15+
self._immediate_rule = immediate_rule
16+
17+
def get_rule_name(self):
18+
return self._rule_name
19+
20+
def get_rule_type(self):
21+
return self._type
22+
23+
def get_target(self):
24+
return self._target
25+
26+
def get_delta(self):
27+
return self._delta
28+
29+
def get_neigh_criteria(self):
30+
return self._clauses
31+
32+
def get_bnd(self):
33+
return self._bnd
34+
35+
def get_thresholds(self):
36+
return self._thresholds
37+
38+
def get_annotation_function(self):
39+
return self._ann_fn
40+
41+
def get_edges(self):
42+
return self._edges
43+
44+
def is_static(self):
45+
return self._static
46+
47+
def is_immediate_rule(self):
48+
return self._immediate_rule

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
setup(
1010
name='pyreason',
11-
version='2.0.4',
11+
version='2.1.0',
1212
author='Dyuman Aditya',
1313
author_email='[email protected]',
1414
description='An explainable inference software supporting annotated, real valued, graph based and temporal logic',

tests/test_hello_world.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def test_hello_world():
1111

1212
# Load all the files into pyreason
1313
pr.load_graphml(graph_path)
14-
pr.add_rule('popular(x) <-1 popular(y), Friends(x,y), owns(y,z), owns(x,z)', 'popular_rule')
14+
pr.add_rule(pr.Rule('popular(x) <-1 popular(y), Friends(x,y), owns(y,z), owns(x,z)', 'popular_rule'))
1515
pr.add_fact(pr.Fact('popular-fact', 'Mary', 'popular', [1, 1], 0, 2))
1616

1717
# Run the program for two timesteps to see the diffusion take place

0 commit comments

Comments
 (0)