Skip to content

Commit 383f361

Browse files
committed
updated the way we handle reasoning again with new facts. Made it modular with same format of pyreason facts. Updated docs
1 parent b7cc033 commit 383f361

File tree

7 files changed

+105
-31
lines changed

7 files changed

+105
-31
lines changed

docs/source/user_guide/8_advanced_usage.rst

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,10 @@ Reasoning Multiple Times
1818
-------------------------
1919
PyReason allows you to reason over the graph multiple times. This can be useful when you want to reason over the graph iteratively
2020
and add facts that were not available before. To reason over the graph multiple times, you can set ``again=True`` in ``pr.reason(again=True)``.
21-
To specify additional facts, use the ``node_facts`` or ``edge_facts`` parameters in ``pr.reason(...)``. These parameters allow you to add additional facts to the graph before reasoning again.
21+
To specify additional facts, use the ``facts`` parameter in ``pr.reason(...)``. These parameters allow you to add additional
22+
facts to the graph before reasoning again. The facts are specified as a list of PyReason facts.
23+
24+
.. note::
25+
When reasoning multiple times, the time continues to increment. Therefore any facts that are added should take this into account.
26+
The timestep parameter specifies how many additional timesteps to reason. For example, if the initial reasoning converges at
27+
timestep 5, and you want to reason for 3 more timesteps, you can set ``timestep=3`` in ``pr.reason(timestep=3, again=True, facts=[some_new_fact])``.

pyreason/.cache_status.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
initialized: true
1+
initialized: false

pyreason/pyreason.py

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pandas as pd
77
import memory_profiler as mp
88
import warnings
9-
from typing import List, Type, Callable, Tuple
9+
from typing import List, Type, Callable, Tuple, Optional
1010

1111
from pyreason.scripts.utils.output import Output
1212
from pyreason.scripts.utils.filter import Filter
@@ -423,24 +423,24 @@ def allow_ground_rules(self, value: bool) -> None:
423423

424424

425425
# VARIABLES
426-
__graph = None
427-
__rules = None
428-
__clause_maps = None
429-
__node_facts = None
430-
__edge_facts = None
431-
__ipl = None
432-
__specific_node_labels = None
433-
__specific_edge_labels = None
434-
435-
__non_fluent_graph_facts_node = None
436-
__non_fluent_graph_facts_edge = None
437-
__specific_graph_node_labels = None
438-
__specific_graph_edge_labels = None
426+
__graph: Optional[nx.DiGraph] = None
427+
__rules: Optional[numba.typed.List] = None
428+
__clause_maps: Optional[dict] = None
429+
__node_facts: Optional[numba.typed.List] = None
430+
__edge_facts: Optional[numba.typed.List] = None
431+
__ipl: Optional[numba.typed.List] = None
432+
__specific_node_labels: Optional[numba.typed.List] = None
433+
__specific_edge_labels: Optional[numba.typed.List] = None
434+
435+
__non_fluent_graph_facts_node: Optional[numba.typed.List] = None
436+
__non_fluent_graph_facts_edge: Optional[numba.typed.List] = None
437+
__specific_graph_node_labels: Optional[numba.typed.List] = None
438+
__specific_graph_edge_labels: Optional[numba.typed.List] = None
439439

440440
__annotation_functions = []
441441

442442
__timestamp = ''
443-
__program = None
443+
__program: Optional[Program] = None
444444

445445
__graphml_parser = GraphmlParser()
446446
settings = _Settings()
@@ -624,16 +624,15 @@ def add_annotation_function(function: Callable) -> None:
624624
__annotation_functions.append(function)
625625

626626

627-
def reason(timesteps: int = -1, convergence_threshold: int = -1, convergence_bound_threshold: float = -1, queries: List[Query] = None, again: bool = False, node_facts: List[Type[fact_node.Fact]] = None, edge_facts: List[Type[fact_edge.Fact]] = None):
627+
def reason(timesteps: int = -1, convergence_threshold: int = -1, convergence_bound_threshold: float = -1, queries: List[Query] = None, again: bool = False, facts: List[Fact] = None):
628628
"""Function to start the main reasoning process. Graph and rules must already be loaded.
629629
630630
:param timesteps: Max number of timesteps to run. -1 specifies run till convergence. If reasoning again, this is the number of timesteps to reason for extra (no zero timestep), defaults to -1
631631
:param convergence_threshold: Maximum number of interpretations that have changed between timesteps or fixed point operations until considered convergent. Program will end at convergence. -1 => no changes, perfect convergence, defaults to -1
632632
:param convergence_bound_threshold: Maximum change in any interpretation (bounds) between timesteps or fixed point operations until considered convergent, defaults to -1
633633
:param queries: A list of PyReason query objects that can be used to filter the ruleset based on the query. Default is None
634634
:param again: Whether to reason again on an existing interpretation, defaults to False
635-
:param node_facts: New node facts to use during the next reasoning process. Other facts from file will be discarded, defaults to None
636-
:param edge_facts: New edge facts to use during the next reasoning process. Other facts from file will be discarded, defaults to None
635+
:param facts: New facts to use during the next reasoning process when reasoning again. Other facts from file will be discarded, defaults to None
637636
:return: The final interpretation after reasoning.
638637
"""
639638
global settings, __timestamp
@@ -654,10 +653,10 @@ def reason(timesteps: int = -1, convergence_threshold: int = -1, convergence_bou
654653
else:
655654
if settings.memory_profile:
656655
start_mem = mp.memory_usage(max_usage=True)
657-
mem_usage, interp = mp.memory_usage((_reason_again, [timesteps, convergence_threshold, convergence_bound_threshold, node_facts, edge_facts]), max_usage=True, retval=True)
656+
mem_usage, interp = mp.memory_usage((_reason_again, [timesteps, convergence_threshold, convergence_bound_threshold, facts]), max_usage=True, retval=True)
658657
print(f"\nProgram used {mem_usage-start_mem} MB of memory")
659658
else:
660-
interp = _reason_again(timesteps, convergence_threshold, convergence_bound_threshold, node_facts, edge_facts)
659+
interp = _reason_again(timesteps, convergence_threshold, convergence_bound_threshold, facts)
661660

662661
return interp
663662

@@ -746,20 +745,31 @@ def _reason(timesteps, convergence_threshold, convergence_bound_threshold, queri
746745
return interpretation
747746

748747

749-
def _reason_again(timesteps, convergence_threshold, convergence_bound_threshold, node_facts, edge_facts):
748+
def _reason_again(timesteps, convergence_threshold, convergence_bound_threshold, facts):
750749
# Globals
751750
global __graph, __rules, __node_facts, __edge_facts, __ipl, __specific_node_labels, __specific_edge_labels, __graphml_parser
752751
global settings, __timestamp, __program
753752

754753
assert __program is not None, 'To run `reason_again` you need to have reasoned once before'
755754

756-
# Extend current set of facts with the new facts supplied
757-
all_edge_facts = numba.typed.List.empty_list(fact_edge.fact_type)
755+
# Parse new facts and Extend current set of facts with the new facts supplied
758756
all_node_facts = numba.typed.List.empty_list(fact_node.fact_type)
759-
if node_facts is not None:
760-
all_node_facts.extend(numba.typed.List(node_facts))
761-
if edge_facts is not None:
762-
all_edge_facts.extend(numba.typed.List(edge_facts))
757+
all_edge_facts = numba.typed.List.empty_list(fact_edge.fact_type)
758+
fact_cnt = 1
759+
for fact in facts:
760+
if fact.type == 'node':
761+
print(fact.name)
762+
if fact.name is None:
763+
fact.name = f'fact_{len(__node_facts)+len(__edge_facts)+fact_cnt}'
764+
f = fact_node.Fact(fact.name, fact.component, fact.pred, fact.bound, fact.start_time, fact.end_time, fact.static)
765+
all_node_facts.append(f)
766+
fact_cnt += 1
767+
else:
768+
if fact.name is None:
769+
fact.name = f'fact_{len(__node_facts)+len(__edge_facts)+fact_cnt}'
770+
f = fact_edge.Fact(fact.name, fact.component, fact.pred, fact.bound, fact.start_time, fact.end_time, fact.static)
771+
all_edge_facts.append(f)
772+
fact_cnt += 1
763773

764774
# Run Program and get final interpretation
765775
interpretation = __program.reason_again(timesteps, convergence_threshold, convergence_bound_threshold, all_node_facts, all_edge_facts, settings.verbose)

pyreason/scripts/interpretation/interpretation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,8 @@ def _init_facts(facts_node, facts_edge, facts_to_be_applied_node, facts_to_be_ap
209209
return max_time
210210

211211
def _start_fp(self, rules, max_facts_time, verbose, again):
212+
if again:
213+
self.num_ga.append(self.num_ga[-1])
212214
fp_cnt, t = self.reason(self.interpretations_node, self.interpretations_edge, self.predicate_map_node, self.predicate_map_edge, self.tmax, self.prev_reasoning_data, rules, self.nodes, self.edges, self.neighbors, self.reverse_neighbors, self.rules_to_be_applied_node, self.rules_to_be_applied_edge, self.edges_to_be_added_node_rule, self.edges_to_be_added_edge_rule, self.rules_to_be_applied_node_trace, self.rules_to_be_applied_edge_trace, self.facts_to_be_applied_node, self.facts_to_be_applied_edge, self.facts_to_be_applied_node_trace, self.facts_to_be_applied_edge_trace, self.ipl, self.rule_trace_node, self.rule_trace_edge, self.rule_trace_node_atoms, self.rule_trace_edge_atoms, self.reverse_graph, self.atom_trace, self.save_graph_attributes_to_rule_trace, self.persistent, self.inconsistency_check, self.store_interpretation_changes, self.update_mode, self.allow_ground_rules, max_facts_time, self.annotation_functions, self._convergence_mode, self._convergence_delta, self.num_ga, verbose, again)
213215
self.time = t - 1
214216
# If we need to reason again, store the next timestep to start from

pyreason/scripts/interpretation/interpretation_parallel.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,8 @@ def _init_facts(facts_node, facts_edge, facts_to_be_applied_node, facts_to_be_ap
209209
return max_time
210210

211211
def _start_fp(self, rules, max_facts_time, verbose, again):
212+
if again:
213+
self.num_ga.append(self.num_ga[-1])
212214
fp_cnt, t = self.reason(self.interpretations_node, self.interpretations_edge, self.predicate_map_node, self.predicate_map_edge, self.tmax, self.prev_reasoning_data, rules, self.nodes, self.edges, self.neighbors, self.reverse_neighbors, self.rules_to_be_applied_node, self.rules_to_be_applied_edge, self.edges_to_be_added_node_rule, self.edges_to_be_added_edge_rule, self.rules_to_be_applied_node_trace, self.rules_to_be_applied_edge_trace, self.facts_to_be_applied_node, self.facts_to_be_applied_edge, self.facts_to_be_applied_node_trace, self.facts_to_be_applied_edge_trace, self.ipl, self.rule_trace_node, self.rule_trace_edge, self.rule_trace_node_atoms, self.rule_trace_edge_atoms, self.reverse_graph, self.atom_trace, self.save_graph_attributes_to_rule_trace, self.persistent, self.inconsistency_check, self.store_interpretation_changes, self.update_mode, self.allow_ground_rules, max_facts_time, self.annotation_functions, self._convergence_mode, self._convergence_delta, self.num_ga, verbose, again)
213215
self.time = t - 1
214216
# If we need to reason again, store the next timestep to start from
@@ -218,7 +220,7 @@ def _start_fp(self, rules, max_facts_time, verbose, again):
218220
print('Fixed Point iterations:', fp_cnt)
219221

220222
@staticmethod
221-
@numba.njit(cache=True, parallel=False)
223+
@numba.njit(cache=True, parallel=True)
222224
def reason(interpretations_node, interpretations_edge, predicate_map_node, predicate_map_edge, tmax, prev_reasoning_data, rules, nodes, edges, neighbors, reverse_neighbors, rules_to_be_applied_node, rules_to_be_applied_edge, edges_to_be_added_node_rule, edges_to_be_added_edge_rule, rules_to_be_applied_node_trace, rules_to_be_applied_edge_trace, facts_to_be_applied_node, facts_to_be_applied_edge, facts_to_be_applied_node_trace, facts_to_be_applied_edge_trace, ipl, rule_trace_node, rule_trace_edge, rule_trace_node_atoms, rule_trace_edge_atoms, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, persistent, inconsistency_check, store_interpretation_changes, update_mode, allow_ground_rules, max_facts_time, annotation_functions, convergence_mode, convergence_delta, num_ga, verbose, again):
223225
t = prev_reasoning_data[0]
224226
fp_cnt = prev_reasoning_data[1]

tests/test_hello_world.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,3 @@ def test_hello_world():
4848
# John should be popular in timestep 3
4949
assert 'John' in dataframes[2]['component'].values and dataframes[2].iloc[1].popular == [1, 1], 'John should have popular bounds [1,1] for t=2 timesteps'
5050

51-
test_hello_world()

tests/test_reason_again.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Test if the simple hello world program works
2+
import pyreason as pr
3+
import faulthandler
4+
5+
6+
def test_reason_again():
7+
# Reset PyReason
8+
pr.reset()
9+
pr.reset_rules()
10+
pr.reset_settings()
11+
12+
# Modify the paths based on where you've stored the files we made above
13+
graph_path = './tests/friends_graph.graphml'
14+
15+
# Modify pyreason settings to make verbose
16+
pr.settings.verbose = True # Print info to screen
17+
pr.settings.atom_trace = True # Save atom trace
18+
# pr.settings.optimize_rules = False # Disable rule optimization for debugging
19+
20+
# Load all the files into pyreason
21+
pr.load_graphml(graph_path)
22+
pr.add_rule(pr.Rule('popular(x) <-1 popular(y), Friends(x,y), owns(y,z), owns(x,z)', 'popular_rule'))
23+
pr.add_fact(pr.Fact('popular(Mary)', 'popular_fact', 0, 1))
24+
25+
# Run the program for two timesteps to see the diffusion take place
26+
faulthandler.enable()
27+
interpretation = pr.reason(timesteps=1)
28+
29+
# Now reason again
30+
new_fact = pr.Fact('popular(Mary)', 'popular_fact2', 2, 4)
31+
interpretation = pr.reason(timesteps=3, again=True, facts=[new_fact])
32+
pr.save_rule_trace(interpretation)
33+
34+
# Display the changes in the interpretation for each timestep
35+
dataframes = pr.filter_and_sort_nodes(interpretation, ['popular'])
36+
for t, df in enumerate(dataframes):
37+
print(f'TIMESTEP - {t}')
38+
print(df)
39+
print()
40+
41+
assert len(dataframes[2]) == 1, 'At t=0 there should be one popular person'
42+
assert len(dataframes[3]) == 2, 'At t=1 there should be two popular people'
43+
assert len(dataframes[4]) == 3, 'At t=2 there should be three popular people'
44+
45+
# Mary should be popular in all three timesteps
46+
assert 'Mary' in dataframes[2]['component'].values and dataframes[2].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=0 timesteps'
47+
assert 'Mary' in dataframes[3]['component'].values and dataframes[3].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=1 timesteps'
48+
assert 'Mary' in dataframes[4]['component'].values and dataframes[4].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=2 timesteps'
49+
50+
# Justin should be popular in timesteps 1, 2
51+
assert 'Justin' in dataframes[3]['component'].values and dataframes[3].iloc[1].popular == [1, 1], 'Justin should have popular bounds [1,1] for t=1 timesteps'
52+
assert 'Justin' in dataframes[4]['component'].values and dataframes[4].iloc[2].popular == [1, 1], 'Justin should have popular bounds [1,1] for t=2 timesteps'
53+
54+
# John should be popular in timestep 3
55+
assert 'John' in dataframes[4]['component'].values and dataframes[4].iloc[1].popular == [1, 1], 'John should have popular bounds [1,1] for t=2 timesteps'

0 commit comments

Comments
 (0)