@@ -195,6 +195,9 @@ def __init__(
195
195
self .encode_path_length = encode_path_length
196
196
self .edge_position_vars = {}
197
197
198
+ self .edges_set_to_zero = {}
199
+ self .edges_set_to_one = {}
200
+
198
201
self .solver_options = solver_options
199
202
if self .solver_options is None :
200
203
self .solver_options = {}
@@ -288,6 +291,8 @@ def create_solver_and_paths(self):
288
291
289
292
self ._encode_paths ()
290
293
294
+ self ._apply_safety_optimizations_fix_zero_edges ()
295
+
291
296
def _encode_paths (self ):
292
297
293
298
# Encodes the paths in the graph by creating variables for edges and subpaths.
@@ -447,49 +452,108 @@ def _encode_paths(self):
447
452
name = f"path_length_constr_i={ i } "
448
453
)
449
454
450
- ########################################
451
- # #
452
- # Fixing variables based on safe lists #
453
- # #
454
- ########################################
455
+ def _apply_safety_optimizations (self ):
455
456
456
457
if self .safe_lists is not None :
457
- paths_to_fix = self ._get_paths_to_fix_from_safe_lists ()
458
-
459
- if not self .optimize_with_safety_as_subpath_constraints :
460
- # iterating over safe lists
461
- for i in range (min (len (paths_to_fix ), self .k )):
462
- # print("Fixing variables for safe list #", i)
463
- # iterate over the edges in the safe list to fix variables to 1
464
- for u , v in paths_to_fix [i ]:
465
- self .solver .add_constraint (
466
- self .edge_vars [(u , v , i )] == 1 ,
467
- name = f"safe_list_u={ u } _v={ v } _i={ i } " ,
468
- )
458
+ self .paths_to_fix = self ._get_paths_to_fix_from_safe_lists ()
459
+
460
+ if not self .optimize_with_safety_as_subpath_constraints :
461
+ # iterating over safe lists
462
+ for i in range (min (len (self .paths_to_fix ), self .k )):
463
+ # print("Fixing variables for safe list #", i)
464
+ # iterate over the edges in the safe list to fix variables to 1
465
+ for u , v in self .paths_to_fix [i ]:
466
+ self .solver .add_constraint (
467
+ self .edge_vars [(u , v , i )] == 1 ,
468
+ name = f"safe_list_u={ u } _v={ v } _i={ i } " ,
469
+ )
470
+ self .edges_set_to_one [(u , v , i )] = True
471
+
472
+ self ._apply_safety_optimizations_fix_zero_edges ()
473
+
474
+ def _apply_safety_optimizations_fix_zero_edges (self ):
475
+ """
476
+ Prune layer-edge variables to zero using safe-walk reachability while
477
+ preserving edges that can be part of the walk or its connectors.
478
+
479
+ For each walk i in `walks_to_fix` we build a protection set of edges that
480
+ must not be fixed to 0 for layer i:
481
+ 1) Protect all edges that appear in the walk itself.
482
+ 2) Whole-walk reachability: let first_node be the first node of the walk
483
+ and last_node the last node. Protect any edge (u,v) such that
484
+ - u is reachable (forward) from last_node, OR
485
+ - v can reach (backward) first_node.
486
+ 3) Gap-bridging between consecutive edges: for every pair of consecutive
487
+ edges whose endpoints do not match (a gap), let
488
+ - current_last = end node of the first edge, and
489
+ - current_start = start node of the next edge.
490
+ Protect any edge (u,v) such that
491
+ - u is reachable (forward) from current_last, AND
492
+ - v can reach (backward) current_start.
493
+
494
+ All remaining edges (u,v) not in the protection set are fixed to 0 in
495
+ layer i.
496
+
497
+ Notes:
498
+ - Requires `self.paths_to_fix` already computed and `self.edge_vars` created.
499
+ """
500
+ if not hasattr (self , "paths_to_fix" ) or self .paths_to_fix is None :
501
+ return
502
+
503
+ fixed_zero_count = 0
504
+ # Ensure we don't go beyond k layers
505
+ for i in range (min (len (self .paths_to_fix ), self .k )):
506
+ path = self .paths_to_fix [i ]
507
+ if not path or len (path ) == 0 :
508
+ continue
509
+
510
+ # Build the set of edges that should NOT be fixed to 0 for this layer i
511
+ # Start by protecting all edges in the path itself
512
+ protected_edges = set ((u , v ) for (u , v ) in path if self .G .has_edge (u , v ))
513
+
514
+ # Also protect edges that are reachable from the last node of the path
515
+ # or that can reach the first node of the path
516
+ first_node = path [0 ][0 ]
517
+ last_node = path [- 1 ][1 ]
518
+ for (u , v ) in self .G .edges :
519
+ if (u in self .G .reachable_nodes_from [last_node ]) or (v in self .G .nodes_reaching (first_node )):
520
+ protected_edges .add ((u , v ))
521
+
522
+ # Collect pairs of non-contiguous consecutive edges (gaps)
523
+ gap_pairs = []
524
+ for idx in range (len (path ) - 1 ):
525
+ end_prev = path [idx ][1 ]
526
+ start_next = path [idx + 1 ][0 ]
527
+ # We consider all consecutive edges as gap pairs, because there could be a cycle
528
+ # formed between them (this is not the case in DAGs)
529
+ if end_prev != start_next :
530
+ gap_pairs .append ((end_prev , start_next ))
531
+
532
+ # For each gap, add edges that can lie on some path bridging the gap
533
+ for (current_last , current_start ) in gap_pairs :
534
+ for (u , v ) in self .G .edges :
535
+ if (u in self .G .nodes_reachable (current_last )) and (v in self .G .nodes_reaching (current_start )):
536
+ # if (u in reachable_from_last) and (v in can_reach_start):
537
+ protected_edges .add ((u , v ))
538
+
539
+ # Now fix every other edge to 0 for this layer i
540
+ for (u , v ) in self .G .edges :
541
+ if (u , v ) in protected_edges :
542
+ continue
543
+ # Queue zero-fix for batch bounds update
544
+ # self.solver.queue_fix_variable(self.edge_vars[(u, v, i)], int(0))
545
+ self .solver .add_constraint (
546
+ self .edge_vars [(u , v , i )] == 0 ,
547
+ name = f"i={ i } _u={ u } _v={ v } _fix0" ,
548
+ )
549
+ self .edges_set_to_zero [(u , v , i )] = True
550
+ fixed_zero_count += 1
551
+
552
+ if fixed_zero_count :
553
+ # Accumulate into solve statistics
554
+ self .solve_statistics ["edge_variables=0" ] = self .solve_statistics .get ("edge_variables=0" , 0 ) + fixed_zero_count
555
+ utils .logger .debug (f"{ __name__ } : Fixed { fixed_zero_count } edge variables to 0 via reachability pruning." )
469
556
470
- if self .optimize_with_safe_zero_edges :
471
- # get the endpoints of the longest safe path in the sequence
472
- first_node , last_node = (
473
- safetypathcovers .get_endpoints_of_longest_safe_path_in (paths_to_fix [i ])
474
- )
475
- # get the reachable nodes from the last node
476
- reachable_nodes = self .G .reachable_nodes_from [last_node ]
477
- # get the backwards reachable nodes from the first node
478
- reachable_nodes_reverse = self .G .reachable_nodes_rev_from [first_node ]
479
- # get the edges in the path
480
- path_edges = set ((u , v ) for (u , v ) in paths_to_fix [i ])
481
-
482
- for u , v in self .G .base_graph .edges ():
483
- if (
484
- (u , v ) not in path_edges
485
- and u not in reachable_nodes
486
- and v not in reachable_nodes_reverse
487
- ):
488
- # print(f"Adding zero constraint for edge ({u}, {v}) in path {i}")
489
- self .solver .add_constraint (
490
- self .edge_vars [(u , v , i )] == 0 ,
491
- name = f"safe_list_zero_edge_u={ u } _v={ v } _i={ i } " ,
492
- )
493
557
494
558
495
559
def _get_paths_to_fix_from_safe_lists (self ) -> list :
0 commit comments