Skip to content

Commit b04201b

Browse files
author
Sanggyu Lee
committed
retype.input_ids.py: support multiple subgraphs and use pattern-based detection
1 parent e1cbbd0 commit b04201b

File tree

2 files changed

+30
-20
lines changed

2 files changed

+30
-20
lines changed

tools/o2o/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ Performs garbage collection by removing unreachable tensors and buffers, reducin
146146

147147
### `retype.input_ids.py`
148148

149-
Finds tensors named `input_ids` and changes their data type from int64 to int32. This filter is useful for models that need to be compatible with hardware or frameworks that expect input_ids to be 32-bit integers instead of 64-bit integers.
149+
Identifies `input_ids` tensors from the graph structure rather than by name: a tensor is retyped from `int64` to `int32` when it is an `INT64` tensor that is both the indices input of a `GATHER` operator and a subgraph input. Because detection is structural, it works regardless of the tensor name. This filter is useful for models that need to be compatible with hardware or frameworks that expect `input_ids` to be 32-bit integers.
150150

151151
##
152152

tools/o2o/retype.input_ids.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,25 +18,35 @@ def process_subgraph(subgraph):
1818
o2o.log(f"Processing subgraph with {len(subgraph.tensors)} tensors")
1919

2020
retyped_count = 0
21-
for tensor in subgraph.tensors:
22-
tensor_name = o2o.get_tensor_name(tensor)
23-
24-
# Check if this is the input_ids tensor
25-
if tensor_name == "tico::input_ids":
26-
# Check if current type is int64
27-
if tensor.type == circle.TensorType.INT64:
28-
old_type = "int64"
29-
new_type = "int32"
30-
31-
# Change type to int32
32-
tensor.type = circle.TensorType.INT32
33-
34-
o2o.log(f"Retyped tensor: {tensor_name} {old_type}{new_type}")
35-
retyped_count += 1
36-
else:
37-
o2o.log(
38-
f"Found input_ids tensor but type is not int64 (current type: {tensor.type})"
39-
)
21+
22+
# Collect subgraph inputs for quick lookup
23+
subgraph_inputs = set(subgraph.inputs)
24+
25+
for op_idx, op in enumerate(subgraph.operators):
26+
opcode = model.operatorCodes[op.opcodeIndex]
27+
28+
if opcode.builtinCode == circle.BuiltinOperator.GATHER:
29+
# GATHER input 1 is the indices tensor (params, indices)
30+
if op.inputs is not None and len(op.inputs) > 1:
31+
input_tensor_idx = op.inputs[1]
32+
33+
# Check if this input is a subgraph input
34+
if input_tensor_idx in subgraph_inputs:
35+
tensor = subgraph.tensors[input_tensor_idx]
36+
37+
# Check if type is INT64
38+
if tensor.type == circle.TensorType.INT64:
39+
tensor_name = o2o.get_tensor_name(tensor)
40+
old_type = "int64"
41+
new_type = "int32"
42+
43+
# Change type to int32
44+
tensor.type = circle.TensorType.INT32
45+
46+
o2o.log(
47+
f"Retyped tensor: {tensor_name} (Index: {input_tensor_idx}) {old_type}{new_type}"
48+
)
49+
retyped_count += 1
4050

4151
if retyped_count > 0:
4252
o2o.log(f"Retyped {retyped_count} input_ids tensors in this subgraph")

0 commit comments

Comments
 (0)