@@ -78,11 +78,18 @@ class TensorStorage:
     tensor_id: str
     above_loop_index: int
     backer_id: str
+    # NOTE: Tile size is not included in the hash or equality functions. This is
+    # because inter-Einsum comparisons care about the loops and locations of
+    # backing storages, and the tile sizes are derived from these. We don't want
+    # rounding errors in the tile size to affect our inter-Einsum comparisons.
     tile_size: int
     # n_repititions: int = 1

     def __tuple__(self):
         return (self.tensor_id, self.backer_id, self.above_loop_index, self.tile_size)
+
+    def __hash__(self):
+        return hash((self.tensor_id, self.backer_id, self.above_loop_index))

     @property
     def ts(self):
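The NOTE above says the tile size is derived from loop placement and may carry rounding error, so it is kept out of the hash. A minimal sketch of that rule, using a hypothetical _Storage dataclass as a stand-in for TensorStorage: two storages that differ only in a rounded tile size hash identically, so they can still be grouped together for inter-Einsum comparisons.

    from dataclasses import dataclass

    @dataclass
    class _Storage:
        tensor_id: str
        above_loop_index: int
        backer_id: str
        tile_size: int

        def __hash__(self):
            # Mirror the diff above: hash only the placement-defining fields.
            return hash((self.tensor_id, self.backer_id, self.above_loop_index))

    a = _Storage("A", 2, "DRAM", 128)
    b = _Storage("A", 2, "DRAM", 127)  # same placement, tile size off by rounding
    assert hash(a) == hash(b)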
@@ -127,7 +134,7 @@ def get_backing_stores(all_tensors: set["TensorStorage"]) -> list["TensorStorage
     def __eq__(self, value):
         if not isinstance(value, TensorStorage):
             return False
-        for to_check in ["tensor_id", "backer_id", "above_loop_index", "tile_size"]:
+        for to_check in ["tensor_id", "backer_id", "above_loop_index"]:  #$ , "tile_size"]:
             a, b = getattr(self, to_check), getattr(value, to_check)
             if a != "*" and b != "*" and a != b:
                 return False
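A hedged sketch of the comparison rule above, with a hypothetical storages_match helper and _S tuple standing in for TensorStorage: "*" acts as a wildcard for any of the three compared fields, and tile_size never participates, consistent with the NOTE next to the hash.

    from collections import namedtuple

    _S = namedtuple("_S", "tensor_id above_loop_index backer_id tile_size")

    def storages_match(a, b) -> bool:
        # Same rule as __eq__ in the diff: "*" is a wildcard, tile_size is ignored.
        for to_check in ("tensor_id", "backer_id", "above_loop_index"):
            x, y = getattr(a, to_check), getattr(b, to_check)
            if x != "*" and y != "*" and x != y:
                return False
        return True

    # Differ only in tile_size -> still equal for inter-Einsum purposes.
    assert storages_match(_S("A", 2, "DRAM", 128), _S("A", 2, "DRAM", 64))
    # "*" matches any backer_id.
    assert storages_match(_S("A", 2, "*", 128), _S("A", 2, "GLB", 128))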
@@ -369,7 +376,13 @@ def _group(

     @staticmethod
     def combine_combineable(sims: list["SIM"], live_tensors: set[str], allow_different_tilings: bool = False) -> list["SIM"]:
-        return parallel(delayed(SIM.concat)(s, allow_different_tilings) for s in SIM._group(sims, live_tensors).values())
+        groups = list(SIM._group(sims, live_tensors).values())
+        groups_with_one = [g[0] for g in groups if len(g) == 1]
+        others = parallel(
+            [delayed(SIM.concat)(g, allow_different_tilings) for g in groups if len(g) > 1],
+            pbar="Combining SIMs"
+        )
+        return groups_with_one + others

     @staticmethod
     def filter_by_tensor_storages(
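The rewrite of combine_combineable above keeps single-element groups out of the worker pool and labels the parallel pass with a progress bar. A minimal sketch of the same pattern, assuming joblib's Parallel/delayed in place of the repository's parallel/delayed helpers (whose pbar argument is assumed to drive a progress-bar label); combine_groups and combine are hypothetical names.

    from joblib import Parallel, delayed

    def combine_groups(groups, combine):
        # Size-1 groups need no concatenation; dispatching them to a worker
        # only adds serialization overhead, so pass their element through.
        singletons = [g[0] for g in groups if len(g) == 1]
        # Only multi-element groups are worth sending to the pool.
        combined = Parallel(n_jobs=-1)(
            delayed(combine)(g) for g in groups if len(g) > 1
        )
        return singletons + combined

In the diff, the groups come from SIM._group(sims, live_tensors).values() and the combine step is SIM.concat with allow_different_tilings forwarded.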