Skip to content

Commit dd0f4f1

Browse files
committed
Add gc.py
1 parent fe70bcc commit dd0f4f1

File tree

2 files changed

+131
-3
lines changed

2 files changed

+131
-3
lines changed

tools/circle2circle/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,9 @@ Selectively removes operators from a Circle model based on their index range. Th
121121

122122
##
123123

124-
### `remove.unused_tensors.py`
124+
### `gc.py`
125125

126-
Identifies and removes unused tensors from all subgraphs within a Circle model. A tensor is considered "unused" if it is not an input to any operator and not an output of its containing subgraph. This helps in cleaning up the model and potentially reducing its size. The script can either list unused tensors or modify the model to remove them.
126+
Performs garbage collection by removing unreachable tensors and buffers, reducing model size and memory consumption.
127127

128128
##
129129

tools/circle2circle/remove.unused_tensors.py renamed to tools/circle2circle/gc.py

Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,63 @@ def find_unused_tensors_in_subgraph(subgraph):
5656
return unused_indices
5757

5858

59+
def find_unused_buffers(model):
60+
"""
61+
Finds and returns the indices of unused buffers in the model.
62+
This function works with both Native API (read-only) and Object API (mutable) model objects.
63+
64+
Args:
65+
model: The Circle model object (read-only or mutable).
66+
67+
Returns:
68+
list: A list of integer indices representing unused buffers.
69+
"""
70+
# Handle both Native API and Object API
71+
if hasattr(model, 'BuffersLength'):
72+
# Native API
73+
if not model.BuffersLength():
74+
return []
75+
76+
used_buffer_indices = set()
77+
78+
# Collect buffer indices from all tensors in all subgraphs
79+
for i in range(model.SubgraphsLength()):
80+
subgraph = model.Subgraphs(i)
81+
if subgraph:
82+
for j in range(subgraph.TensorsLength()):
83+
tensor = subgraph.Tensors(j)
84+
if tensor and tensor.Buffer() != -1: # -1 indicates no buffer
85+
used_buffer_indices.add(tensor.Buffer())
86+
87+
# A buffer is unused if it's not referenced by any tensor
88+
unused_indices = []
89+
for i in range(model.BuffersLength()):
90+
if i not in used_buffer_indices:
91+
unused_indices.append(i)
92+
93+
return unused_indices
94+
else:
95+
# Object API
96+
if not model.buffers:
97+
return []
98+
99+
used_buffer_indices = set()
100+
101+
# Collect buffer indices from all tensors in all subgraphs
102+
for subgraph in model.subgraphs:
103+
for tensor in subgraph.tensors:
104+
if tensor.buffer != -1: # -1 indicates no buffer
105+
used_buffer_indices.add(tensor.buffer)
106+
107+
# A buffer is unused if it's not referenced by any tensor
108+
unused_indices = []
109+
for i in range(len(model.buffers)):
110+
if i not in used_buffer_indices:
111+
unused_indices.append(i)
112+
113+
return unused_indices
114+
115+
59116
def remove_tensors_and_update_model(model, subgraph_index_to_modify,
60117
tensor_indices_to_remove):
61118
"""
@@ -158,6 +215,61 @@ def remove_tensors_and_update_model(model, subgraph_index_to_modify,
158215
return sorted(removed_indices)
159216

160217

218+
def remove_buffers_and_update_model(model, buffer_indices_to_remove):
219+
"""
220+
Removes specified buffers from the model and updates all tensor references.
221+
This function uses the Object API for mutable model objects.
222+
223+
Args:
224+
model: The mutable Circle model object (ModelT).
225+
buffer_indices_to_remove (list): A list of buffer indices to remove.
226+
Must be sorted in descending order.
227+
228+
Returns:
229+
list: The list of buffer indices that were actually removed.
230+
"""
231+
if not model.buffers:
232+
o2o.log("Model has no buffers to remove.")
233+
return []
234+
235+
removed_indices = []
236+
237+
# Sort in descending order to avoid index shifting issues during removal
238+
for buffer_idx in sorted(buffer_indices_to_remove, reverse=True):
239+
if 0 <= buffer_idx < len(model.buffers):
240+
o2o.log(f" Removing buffer at index {buffer_idx}")
241+
del model.buffers[buffer_idx]
242+
removed_indices.append(buffer_idx)
243+
else:
244+
o2o.log(f" Warning: Buffer index {buffer_idx} out of bounds, skipping.")
245+
246+
if not removed_indices:
247+
return []
248+
249+
# Create a map for old index to new index after removal
250+
new_indices_map = {}
251+
current_new_idx = 0
252+
# Iterate over original buffer count
253+
original_buffer_count = len(model.buffers) + len(removed_indices)
254+
for old_idx in range(original_buffer_count):
255+
if old_idx not in buffer_indices_to_remove:
256+
new_indices_map[old_idx] = current_new_idx
257+
current_new_idx += 1
258+
259+
# Update tensor buffer references in all subgraphs
260+
for subgraph_idx, subgraph in enumerate(model.subgraphs):
261+
for tensor_idx, tensor in enumerate(subgraph.tensors):
262+
if tensor.buffer != -1: # -1 indicates no buffer
263+
if tensor.buffer in new_indices_map:
264+
old_buffer_idx = tensor.buffer
265+
tensor.buffer = new_indices_map[old_buffer_idx]
266+
# If tensor.buffer was removed, set to -1 (no buffer)
267+
elif tensor.buffer in buffer_indices_to_remove:
268+
tensor.buffer = -1
269+
270+
return sorted(removed_indices)
271+
272+
161273
def main():
162274
# Read the entire model from stdin
163275
data = sys.stdin.buffer.read()
@@ -203,11 +315,27 @@ def main():
203315
f"\nTotal unused tensors found across all subgraphs: {total_unused_tensors_count}"
204316
)
205317

318+
# After removing tensors, now process unused buffers
319+
# Use the mutable model directly since find_unused_buffers now supports both APIs
320+
unused_buffers = find_unused_buffers(model)
321+
if unused_buffers:
322+
o2o.log(
323+
f"Found {len(unused_buffers)} unused buffer(s): {', '.join(map(str, sorted(unused_buffers)))}"
324+
)
325+
actually_removed_buffers = remove_buffers_and_update_model(model, unused_buffers)
326+
if actually_removed_buffers:
327+
o2o.log(f"Removed {len(actually_removed_buffers)} buffer(s).")
328+
model_changed = True
329+
else:
330+
o2o.log("No buffers were actually removed during the process.")
331+
else:
332+
o2o.log("No unused buffers found.")
333+
206334
if model_changed:
207335
o2o.log("\nSaving modified model to stdout...")
208336
else:
209337
o2o.log(
210-
"\nNo tensors were actually removed from any subgraph. Saving original model to stdout."
338+
"\nNo tensors or buffers were actually removed. Saving original model to stdout."
211339
)
212340
o2o.save_model_to_stdout(model)
213341

0 commit comments

Comments
 (0)