From 6faeca9be08c5c036bc2d0120467f48215bac3b7 Mon Sep 17 00:00:00 2001 From: Abhishikth Mallampalli Date: Thu, 14 Aug 2025 00:41:00 -0500 Subject: [PATCH] fixing the error when reshape present next to input layer --- hls4ml/model/optimizer/__init__.py | 1 + .../model/optimizer/passes/absorb_reshape.py | 28 +++++++++++++++++++ 2 files changed, 29 insertions(+) create mode 100644 hls4ml/model/optimizer/passes/absorb_reshape.py diff --git a/hls4ml/model/optimizer/__init__.py b/hls4ml/model/optimizer/__init__.py index 8fc6876942..0a0b5617f6 100644 --- a/hls4ml/model/optimizer/__init__.py +++ b/hls4ml/model/optimizer/__init__.py @@ -73,6 +73,7 @@ 'fuse_batch_normalization', 'replace_multidimensional_dense_with_conv', 'enforce_proxy_model_embedded_config', + 'absorb_reshape_into_input', 'bit_exact', 'fuse_fixed_point_quantizer', 'fix_input_precision', diff --git a/hls4ml/model/optimizer/passes/absorb_reshape.py b/hls4ml/model/optimizer/passes/absorb_reshape.py new file mode 100644 index 0000000000..a1fb4353c4 --- /dev/null +++ b/hls4ml/model/optimizer/passes/absorb_reshape.py @@ -0,0 +1,28 @@ +from hls4ml.model.layers import Input, Reshape +from hls4ml.model.optimizer import OptimizerPass + + +class AbsorbReshapeIntoInput(OptimizerPass): + def match(self, node): + # Looking for a Reshape layer whose input is an Input layer + if isinstance(node, Reshape): + inp_node = node.get_input_node() + if isinstance(inp_node, Input): + return True + return False + + def transform(self, model, node): + # node == Reshape layer that matched + inp_node = node.get_input_node() + out_nodes = node.get_output_nodes() + if len(out_nodes) > 1: + raise Exception('Reshape node has multiple outputs') + + target_shape = node.get_attr('target_shape') + + input_output_var = inp_node.get_output_variable() + input_output_var.shape = target_shape + + model.remove_node(node) + + return True # because we modified the graph