Skip to content

Commit 49aa83a

Browse files
committed
updated the way we pass facts to reason again to be same as reasoning for the first time
1 parent 956ef89 commit 49aa83a

File tree

5 files changed

+50
-23
lines changed

5 files changed

+50
-23
lines changed

pyreason/pyreason.py

+11-20
Original file line numberDiff line numberDiff line change
@@ -624,7 +624,7 @@ def add_annotation_function(function: Callable) -> None:
624624
__annotation_functions.append(function)
625625

626626

627-
def reason(timesteps: int = -1, convergence_threshold: int = -1, convergence_bound_threshold: float = -1, queries: List[Query] = None, again: bool = False, facts: List[Fact] = None):
627+
def reason(timesteps: int = -1, convergence_threshold: int = -1, convergence_bound_threshold: float = -1, queries: List[Query] = None, again: bool = False):
628628
"""Function to start the main reasoning process. Graph and rules must already be loaded.
629629
630630
:param timesteps: Max number of timesteps to run. -1 specifies run till convergence. If reasoning again, this is the number of timesteps to reason for extra (no zero timestep), defaults to -1
@@ -653,10 +653,10 @@ def reason(timesteps: int = -1, convergence_threshold: int = -1, convergence_bou
653653
else:
654654
if settings.memory_profile:
655655
start_mem = mp.memory_usage(max_usage=True)
656-
mem_usage, interp = mp.memory_usage((_reason_again, [timesteps, convergence_threshold, convergence_bound_threshold, facts]), max_usage=True, retval=True)
656+
mem_usage, interp = mp.memory_usage((_reason_again, [timesteps, convergence_threshold, convergence_bound_threshold]), max_usage=True, retval=True)
657657
print(f"\nProgram used {mem_usage-start_mem} MB of memory")
658658
else:
659-
interp = _reason_again(timesteps, convergence_threshold, convergence_bound_threshold, facts)
659+
interp = _reason_again(timesteps, convergence_threshold, convergence_bound_threshold)
660660

661661
return interp
662662

@@ -742,34 +742,25 @@ def _reason(timesteps, convergence_threshold, convergence_bound_threshold, queri
742742
# Run Program and get final interpretation
743743
interpretation = __program.reason(timesteps, convergence_threshold, convergence_bound_threshold, settings.verbose)
744744

745+
# Clear facts after reasoning, so that reasoning again is possible with any added facts
746+
__node_facts = None
747+
__edge_facts = None
748+
745749
return interpretation
746750

747751

748-
def _reason_again(timesteps, convergence_threshold, convergence_bound_threshold, facts):
752+
def _reason_again(timesteps, convergence_threshold, convergence_bound_threshold):
749753
# Globals
750754
global __graph, __rules, __node_facts, __edge_facts, __ipl, __specific_node_labels, __specific_edge_labels, __graphml_parser
751755
global settings, __timestamp, __program
752756

753757
assert __program is not None, 'To run `reason_again` you need to have reasoned once before'
754758

755-
# Parse new facts and Extend current set of facts with the new facts supplied
759+
# Extend facts
756760
all_node_facts = numba.typed.List.empty_list(fact_node.fact_type)
757761
all_edge_facts = numba.typed.List.empty_list(fact_edge.fact_type)
758-
fact_cnt = 1
759-
for fact in facts:
760-
if fact.type == 'node':
761-
print(fact.name)
762-
if fact.name is None:
763-
fact.name = f'fact_{len(__node_facts)+len(__edge_facts)+fact_cnt}'
764-
f = fact_node.Fact(fact.name, fact.component, fact.pred, fact.bound, fact.start_time, fact.end_time, fact.static)
765-
all_node_facts.append(f)
766-
fact_cnt += 1
767-
else:
768-
if fact.name is None:
769-
fact.name = f'fact_{len(__node_facts)+len(__edge_facts)+fact_cnt}'
770-
f = fact_edge.Fact(fact.name, fact.component, fact.pred, fact.bound, fact.start_time, fact.end_time, fact.static)
771-
all_edge_facts.append(f)
772-
fact_cnt += 1
762+
all_node_facts.extend(numba.typed.List(__node_facts))
763+
all_edge_facts.extend(numba.typed.List(__edge_facts))
773764

774765
# Run Program and get final interpretation
775766
interpretation = __program.reason_again(timesteps, convergence_threshold, convergence_bound_threshold, all_node_facts, all_edge_facts, settings.verbose)

pyreason/scripts/interpretation/interpretation.py

+18
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,8 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi
228228
timestep_loop = True
229229
facts_to_be_applied_node_new = numba.typed.List.empty_list(facts_to_be_applied_node_type)
230230
facts_to_be_applied_edge_new = numba.typed.List.empty_list(facts_to_be_applied_edge_type)
231+
facts_to_be_applied_node_trace_new = numba.typed.List.empty_list(numba.types.string)
232+
facts_to_be_applied_edge_trace_new = numba.typed.List.empty_list(numba.types.string)
231233
rules_to_remove_idx = set()
232234
rules_to_remove_idx.add(-1)
233235
while timestep_loop:
@@ -260,6 +262,7 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi
260262
# Start by applying facts
261263
# Nodes
262264
facts_to_be_applied_node_new.clear()
265+
facts_to_be_applied_node_trace_new.clear()
263266
nodes_set = set(nodes)
264267
for i in range(len(facts_to_be_applied_node)):
265268
if facts_to_be_applied_node[i][0] == t:
@@ -317,17 +320,25 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi
317320

318321
if static:
319322
facts_to_be_applied_node_new.append((numba.types.uint16(facts_to_be_applied_node[i][0]+1), comp, l, bnd, static, graph_attribute))
323+
if atom_trace:
324+
facts_to_be_applied_node_trace_new.append(facts_to_be_applied_node_trace[i])
320325

321326
# If time doesn't match, fact to be applied later
322327
else:
323328
facts_to_be_applied_node_new.append(facts_to_be_applied_node[i])
329+
if atom_trace:
330+
facts_to_be_applied_node_trace_new.append(facts_to_be_applied_node_trace[i])
324331

325332
# Update list of facts with ones that have not been applied yet (delete applied facts)
326333
facts_to_be_applied_node[:] = facts_to_be_applied_node_new.copy()
334+
if atom_trace:
335+
facts_to_be_applied_node_trace[:] = facts_to_be_applied_node_trace_new.copy()
327336
facts_to_be_applied_node_new.clear()
337+
facts_to_be_applied_node_trace_new.clear()
328338

329339
# Edges
330340
facts_to_be_applied_edge_new.clear()
341+
facts_to_be_applied_edge_trace_new.clear()
331342
edges_set = set(edges)
332343
for i in range(len(facts_to_be_applied_edge)):
333344
if facts_to_be_applied_edge[i][0]==t:
@@ -383,14 +394,21 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi
383394

384395
if static:
385396
facts_to_be_applied_edge_new.append((numba.types.uint16(facts_to_be_applied_edge[i][0]+1), comp, l, bnd, static, graph_attribute))
397+
if atom_trace:
398+
facts_to_be_applied_edge_trace_new.append(facts_to_be_applied_edge_trace[i])
386399

387400
# Time doesn't match, fact to be applied later
388401
else:
389402
facts_to_be_applied_edge_new.append(facts_to_be_applied_edge[i])
403+
if atom_trace:
404+
facts_to_be_applied_edge_trace_new.append(facts_to_be_applied_edge_trace[i])
390405

391406
# Update list of facts with ones that have not been applied yet (delete applied facts)
392407
facts_to_be_applied_edge[:] = facts_to_be_applied_edge_new.copy()
408+
if atom_trace:
409+
facts_to_be_applied_edge_trace[:] = facts_to_be_applied_edge_trace_new.copy()
393410
facts_to_be_applied_edge_new.clear()
411+
facts_to_be_applied_edge_trace_new.clear()
394412

395413
in_loop = True
396414
while in_loop:

pyreason/scripts/interpretation/interpretation_parallel.py

+18
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,8 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi
228228
timestep_loop = True
229229
facts_to_be_applied_node_new = numba.typed.List.empty_list(facts_to_be_applied_node_type)
230230
facts_to_be_applied_edge_new = numba.typed.List.empty_list(facts_to_be_applied_edge_type)
231+
facts_to_be_applied_node_trace_new = numba.typed.List.empty_list(numba.types.string)
232+
facts_to_be_applied_edge_trace_new = numba.typed.List.empty_list(numba.types.string)
231233
rules_to_remove_idx = set()
232234
rules_to_remove_idx.add(-1)
233235
while timestep_loop:
@@ -260,6 +262,7 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi
260262
# Start by applying facts
261263
# Nodes
262264
facts_to_be_applied_node_new.clear()
265+
facts_to_be_applied_node_trace_new.clear()
263266
nodes_set = set(nodes)
264267
for i in range(len(facts_to_be_applied_node)):
265268
if facts_to_be_applied_node[i][0] == t:
@@ -317,17 +320,25 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi
317320

318321
if static:
319322
facts_to_be_applied_node_new.append((numba.types.uint16(facts_to_be_applied_node[i][0]+1), comp, l, bnd, static, graph_attribute))
323+
if atom_trace:
324+
facts_to_be_applied_node_trace_new.append(facts_to_be_applied_node_trace[i])
320325

321326
# If time doesn't match, fact to be applied later
322327
else:
323328
facts_to_be_applied_node_new.append(facts_to_be_applied_node[i])
329+
if atom_trace:
330+
facts_to_be_applied_node_trace_new.append(facts_to_be_applied_node_trace[i])
324331

325332
# Update list of facts with ones that have not been applied yet (delete applied facts)
326333
facts_to_be_applied_node[:] = facts_to_be_applied_node_new.copy()
334+
if atom_trace:
335+
facts_to_be_applied_node_trace[:] = facts_to_be_applied_node_trace_new.copy()
327336
facts_to_be_applied_node_new.clear()
337+
facts_to_be_applied_node_trace_new.clear()
328338

329339
# Edges
330340
facts_to_be_applied_edge_new.clear()
341+
facts_to_be_applied_edge_trace_new.clear()
331342
edges_set = set(edges)
332343
for i in range(len(facts_to_be_applied_edge)):
333344
if facts_to_be_applied_edge[i][0]==t:
@@ -383,14 +394,21 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi
383394

384395
if static:
385396
facts_to_be_applied_edge_new.append((numba.types.uint16(facts_to_be_applied_edge[i][0]+1), comp, l, bnd, static, graph_attribute))
397+
if atom_trace:
398+
facts_to_be_applied_edge_trace_new.append(facts_to_be_applied_edge_trace[i])
386399

387400
# Time doesn't match, fact to be applied later
388401
else:
389402
facts_to_be_applied_edge_new.append(facts_to_be_applied_edge[i])
403+
if atom_trace:
404+
facts_to_be_applied_edge_trace_new.append(facts_to_be_applied_edge_trace[i])
390405

391406
# Update list of facts with ones that have not been applied yet (delete applied facts)
392407
facts_to_be_applied_edge[:] = facts_to_be_applied_edge_new.copy()
408+
if atom_trace:
409+
facts_to_be_applied_edge_trace[:] = facts_to_be_applied_edge_trace_new.copy()
393410
facts_to_be_applied_edge_new.clear()
411+
facts_to_be_applied_edge_trace_new.clear()
394412

395413
in_loop = True
396414
while in_loop:

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
setup(
1010
name='pyreason',
11-
version='3.0.2',
11+
version='3.0.3',
1212
author='Dyuman Aditya',
1313
author_email='[email protected]',
1414
description='An explainable inference software supporting annotated, real valued, graph based and temporal logic',

tests/test_reason_again.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ def test_reason_again():
2828

2929
# Now reason again
3030
new_fact = pr.Fact('popular(Mary)', 'popular_fact2', 2, 4)
31-
interpretation = pr.reason(timesteps=3, again=True, facts=[new_fact])
32-
pr.save_rule_trace(interpretation)
31+
pr.add_fact(new_fact)
32+
interpretation = pr.reason(timesteps=3, again=True)
3333

3434
# Display the changes in the interpretation for each timestep
3535
dataframes = pr.filter_and_sort_nodes(interpretation, ['popular'])

0 commit comments

Comments
 (0)