Skip to content

Commit 830a109

Browse files
Some refactors, tile size bugfix
1 parent 8e8aa56 commit 830a109

File tree

5 files changed

+27
-17
lines changed

5 files changed

+27
-17
lines changed

pytimeloop/fastfusion/mapper/per_einsum_mapper_snowcat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def per_einsum_mapper_snowcat(
215215
)
216216
data = {einsum_id: defaultdict(list) for einsum_id in einsums_to_explore}
217217

218-
for einsum_id, result in parallel(jobs, return_as="generator_unordered", pbar="Generating data for Einsums"):
218+
for einsum_id, result in parallel(jobs, return_as="generator_unordered", pbar="Generating Single-Einsum Mappings"):
219219
d = data[einsum_id]
220220
for k, v in result.items():
221221
d[k[0]] += v

pytimeloop/fastfusion/mapper/simexplore.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -144,21 +144,23 @@ def fuse_sims(
144144
shared_tensors=shared_tensors,
145145
)
146146

147-
# right = consolidate(right, left=False, **args)
148-
# left = consolidate(left, left=True, **args)
149-
150147
left = SIM.combine_combineable(left, live_tensors | right_tensors)
151148
right = SIM.combine_combineable(right, live_tensors | left_tensors)
152149

150+
print_time("Combining")
151+
153152
left = sorted(left, key=lambda x: len(x.mapping.data), reverse=True)
154153
right = sorted(right, key=lambda x: len(x.mapping.data), reverse=True)
155154

156155
left = parallel([delayed(lambda l: l.left_consolidate(live_tensors, resource2capacity, shared_tensors))(l) for l in left], pbar="Left consolidate")
157156
right = parallel([delayed(lambda l: l.consolidate(live_tensors, resource2capacity, shared_tensors))(l) for l in right], pbar="Right consolidate")
158157

158+
print_time("Consolidating")
159+
159160
# Group left and right into buckets
160161
right = SIM.group_right(right, left_tensors)
161162
left = SIM.group_left(left, right_tensors)
163+
162164
print_time("Grouping")
163165

164166
for v in list(left.values()) + list(right.values()):
@@ -167,11 +169,6 @@ def fuse_sims(
167169
if t not in live_tensors:
168170
del s.tensors[t]
169171

170-
# left = {k: SIM.combine_combineable(v, live_tensors | right_tensors) for k, v in left.items()}
171-
# right = {k: SIM.combine_combineable(v, live_tensors | left_tensors) for k, v in right.items()}
172-
173-
print_time("Consolidating")
174-
175172
DO_PRINT = False
176173
DELAY_MERGE = not debugger_active()
177174

pytimeloop/fastfusion/pareto.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -370,10 +370,10 @@ def einsum_ids(self):
370370
return fzs(self.data[LOGSTRING].iloc[0].keys())
371371

372372
@staticmethod
373-
def concat(paretos: list["Pareto"]) -> "Pareto":
373+
def concat(paretos: list["Pareto"], skip_pareto: bool=False) -> "Pareto":
374374
return Pareto(
375375
pd.concat([p.data for p in paretos]).fillna(0),
376-
skip_pareto=len(paretos) == 1,
376+
skip_pareto=len(paretos) == 1 or skip_pareto,
377377
)
378378

379379
def merge_next(

pytimeloop/fastfusion/plot/looptree.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ def tilings2looptree(mappings: dict[str, Tiling], stats: dict[str, Any], skip_ba
8484
prev_tilings = []
8585
root = Node()
8686
einsum_ids = list(mappings.keys())
87-
88-
assert set(einsum_ids) == set(stats.keys())
87+
if stats is not None:
88+
assert set(einsum_ids) == set(stats.keys())
8989

9090

9191

@@ -133,8 +133,8 @@ def tilings2looptree(mappings: dict[str, Tiling], stats: dict[str, Any], skip_ba
133133
# TODO if tensor not in n.this_level or tensor not in backers:
134134
if tensor not in n.this_level or tensor not in backers:
135135
n.this_level.append(tensor)
136-
137-
root.add_stats(stats[einsum_id])
136+
if stats is not None:
137+
root.add_stats(stats[einsum_id])
138138
# for k, v in partial_stats[einsum_id].items():
139139
# last_level.append(f"_PARTIAL {k}: {expfmt(v)}")
140140
prev_tilings.append(tiling)

pytimeloop/fastfusion/sim.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,18 @@ class TensorStorage:
7878
tensor_id: str
7979
above_loop_index: int
8080
backer_id: str
81+
# NOTE: Tile size is not included in hash or equality functions. This is
82+
# because inter-Einsum comparisons care about the loops and locations of
83+
# backing storages, and the tile sizes are derived from these. We don't want
84+
# rounding errors in the tile size to affect our inter-Einsum comparisons.
8185
tile_size: int
8286
# n_repetitions: int = 1
8387

8488
def __tuple__(self):
8589
return (self.tensor_id, self.backer_id, self.above_loop_index, self.tile_size)
90+
91+
def __hash__(self):
92+
return hash((self.tensor_id, self.backer_id, self.above_loop_index))
8693

8794
@property
8895
def ts(self):
@@ -127,7 +134,7 @@ def get_backing_stores(all_tensors: set["TensorStorage"]) -> list["TensorStorage
127134
def __eq__(self, value):
128135
if not isinstance(value, TensorStorage):
129136
return False
130-
for to_check in ["tensor_id", "backer_id", "above_loop_index", "tile_size"]:
137+
for to_check in ["tensor_id", "backer_id", "above_loop_index"]:#$, "tile_size"]:
131138
a, b = getattr(self, to_check), getattr(value, to_check)
132139
if a != "*" and b != "*" and a != b:
133140
return False
@@ -369,7 +376,13 @@ def _group(
369376

370377
@staticmethod
371378
def combine_combineable(sims: list["SIM"], live_tensors: set[str], allow_different_tilings: bool=False) -> list["SIM"]:
372-
return parallel(delayed(SIM.concat)(s, allow_different_tilings) for s in SIM._group(sims, live_tensors).values())
379+
groups = list(SIM._group(sims, live_tensors).values())
380+
groups_with_one = [g[0] for g in groups if len(g) == 1]
381+
others = parallel(
382+
[delayed(SIM.concat)(g, allow_different_tilings) for g in groups if len(g) > 1],
383+
pbar="Combining SIMs"
384+
)
385+
return groups_with_one + others
373386

374387
@staticmethod
375388
def filter_by_tensor_storages(

0 commit comments

Comments
 (0)