hls4ml/converters/keras_v3/merge.py (3 changes: 2 additions & 1 deletion)

@@ -39,7 +39,8 @@ def handle(
         match cls_name:
             case 'Concatenate':
                 rank = len(output_shape)
-                class_name = f'Concatenate{rank}d'
+                class_name = 'Concatenate'
+                op = f'Concatenate{rank}d'
                 config['axis'] = layer.axis
             case 'Dot':
                 msg = (
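
This change decouples the layer's class name from the rank of its output: class_name stays the generic 'Concatenate', and the new op attribute carries the rank-specific operation name. A minimal sketch of the resulting values, assuming a rank-3 output shape (all values illustrative, not from the PR):

    output_shape = (4, 4, 3)    # hypothetical example shape
    rank = len(output_shape)    # 3
    class_name = 'Concatenate'  # rank-independent class name
    op = f'Concatenate{rank}d'  # 'Concatenate3d'; presumably selects the rank-specific implementation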

hls4ml/model/optimizer/passes/bit_exact.py (66 changes: 50 additions & 16 deletions)

@@ -221,8 +221,16 @@ def _produce_kif(layer: Layer) -> KIF_t:

 @_produce_kif.register
 def _(layer: Input):
-    k = np.ones(get_output_shape(layer), dtype=np.int16)
-    i = f = np.full(get_output_shape(layer), 126, dtype=np.int16)
+    shape = get_output_shape(layer)
+    if layer.attributes.get('trusted', False):
+        precision: FixedPrecisionType = layer.get_output_variable().type.precision
+        k, i, f = precision.signed, precision.integer - precision.signed, precision.fractional
+        k = np.full(shape, k, dtype=np.int16)
+        i = np.full(shape, i, dtype=np.int16)
+        f = np.full(shape, f, dtype=np.int16)
+    else:
+        k = np.ones(shape, dtype=np.int16)
+        i = f = np.full(shape, 126, dtype=np.int16)
     return k, i, f


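With the 'trusted' attribute set (done by FixInputPrecision below), an Input's keep/integer/fractional (KIF) arrays are read from its actual fixed-point type instead of the pessimistic 126-bit default. A worked example, assuming a signed fixed-point input with width 16 and 6 integer bits; the values and the convention that 'integer' includes the sign bit are assumptions for illustration:

    import numpy as np

    signed, integer, fractional = True, 6, 10  # hypothetical FixedPrecisionType fields
    shape = (8,)                               # hypothetical output shape

    k = np.full(shape, signed, dtype=np.int16)            # all 1: keep a sign bit
    i = np.full(shape, integer - signed, dtype=np.int16)  # all 5: integer bits above the sign
    f = np.full(shape, fractional, dtype=np.int16)        # all 10: fractional bits
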
@@ -630,8 +638,8 @@ def kif_arrs_to_ints(arr: tuple[np.ndarray, np.ndarray, np.ndarray]):
     return tuple(int(np.max(a)) for a in arr)


-def produce_kif(layer: Layer) -> KIF_t:
-    if layer.attributes.get('_produce_kif'):
+def produce_kif(layer: Layer, force_reset=False) -> KIF_t:
+    if layer.attributes.get('_produce_kif') and not force_reset:
         return layer.attributes['_produce_kif']
     kif = _produce_kif(layer)
     layer.attributes['_produce_kif'] = kif
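
produce_kif memoizes its result in the layer's _produce_kif attribute, so repeated calls during a flow are cheap; the new force_reset flag bypasses a cache entry that may be stale after an upstream type has changed. A sketch of the contract as read from this diff:

    kif = produce_kif(node)                      # computes and caches in node.attributes['_produce_kif']
    same = produce_kif(node)                     # cache hit: returns the stored tuple
    fresh = produce_kif(node, force_reset=True)  # recomputes and overwrites the cache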
@@ -885,7 +893,9 @@ def transform(self, model: 'ModelGraph'):
         for node in model.graph.values():
             if node.attributes.get('bit_exact_transformed'):
                 continue
-            produce_kif(node)  # Shrink FixedPointQuantizer bits when possible to be used in backward flow (requested_kif).
+            produce_kif(
+                node, force_reset=True
+            )  # Shrink FixedPointQuantizer bits when possible to be used in backward flow (requested_kif).

         for node in model.graph.values():
             if node.attributes.get('bit_exact_transformed'):
@@ -894,14 +904,31 @@
             node.attributes['bit_exact_transformed'] = True

         for node in model.graph.values():
-            if node.attributes.get('_produce_kif'):
+            if '_produce_kif' in node.attributes:
                 del node.attributes['_produce_kif']
-            if node.attributes.get('_request_kif'):
+            if '_request_kif' in node.attributes:
                 del node.attributes['_request_kif']

         return True


+def get_output_layers_and_quantizers(
+    node: Layer, layers: list | None = None, quantizers: list | None = None
+) -> tuple[list[Layer], list[FixedPointQuantizer]]:
+
+    layers = layers if layers is not None else []
+    quantizers = quantizers if quantizers is not None else []
+    for _node in get_output_layers(node):
+        if isinstance(_node, FixedPointQuantizer):
+            quantizers.append(_node)
+        elif isinstance(_node, (Reshape, Transpose, Concatenate)):
+            layers.append(_node)
+            get_output_layers_and_quantizers(_node, layers, quantizers)
+        else:
+            raise ValueError(f'Layer {node.name} ({node.class_name}) unexpected input layer chain.')
+    return layers, quantizers
+
+
 class FixInputPrecision(OptimizerPass):
     def match(self, node: Layer):
         if not isinstance(node, Input):
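
The new helper walks forward from a node through layers that only rearrange data (Reshape, Transpose, Concatenate), collecting them together with the FixedPointQuantizers that terminate each branch; any other downstream layer raises ValueError. A hedged usage sketch (input_node is an illustrative name):

    # For a chain Input -> Reshape -> Transpose -> FixedPointQuantizer:
    layers, quantizers = get_output_layers_and_quantizers(input_node)
    # layers     == [<Reshape>, <Transpose>]  (value-preserving intermediates)
    # quantizers == [<FixedPointQuantizer>]   (the precision information source)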
@@ -911,21 +938,17 @@ def match(self, node: Layer):
         return node.get_output_variable().type.precision.width > 100

     def transform(self, model, node: Layer):
-        out_layers: list[FixedPointQuantizer] = get_output_layers(node)  # type: ignore
-        for layer in out_layers:
-            assert isinstance(
-                layer, FixedPointQuantizer
-            ), f'Input {node.name} connected to non-quantizer {layer.name} with non-trivial configuration'
+        layers, out_quantizers = get_output_layers_and_quantizers(node)

-        if len(out_layers) == 0:  # Input connected to nothing
+        if len(out_quantizers) == 0:  # Input connected to nothing
             new_type = to_hls4ml_fixed(0, 0, 1, f'{node.name}_t')
             node.get_output_variable().type = new_type
             node.model.config.layer_name_precision[node.name] = str(new_type)
             return False

-        sat_modes = [l.SAT for l in out_layers]
+        sat_modes = [l.SAT for l in out_quantizers]
         sat_modes_set = set(sat_modes)
-        rnd_modes = [l.RND for l in out_layers]
+        rnd_modes = [l.RND for l in out_quantizers]
         rnd_modes_set = set(rnd_modes)
         illegal_sat_modes = sat_modes_set - {'WRAP', 'SAT', 'SAT_SYM'}
         illegal_rnd_modes = rnd_modes_set - {'TRN', 'RND'}
@@ -936,7 +959,7 @@ def transform(self, model, node: Layer):
         if illegal_rnd_modes:
             warn(f'Saturation mode {illegal_rnd_modes} may compromise bit-exactness. Forcing at maximum 24 fractional bits.')

-        kifs = [_produce_kif(l) for l in out_layers]
+        kifs = [_produce_kif(l) for l in out_quantizers]
         i = np.max([np.max(i) for _, i, _ in kifs])
         k = np.max([np.max(k) for k, _, _ in kifs])
         if illegal_rnd_modes:
@@ -951,4 +974,15 @@
             new_type.precision.saturation_mode = 'SAT'
         node.get_output_variable().type = new_type
         node.model.config.layer_name_precision[node.name] = str(new_type)
+        node.attributes['trusted'] = True
+
+        for layer in layers:
+            produce_kif(layer, force_reset=True)
+        for layer in layers:
+            register_precision(layer)
+        for layer in layers:
+            if '_produce_kif' in layer.attributes:
+                del layer.attributes['_produce_kif']
+            if '_request_kif' in layer.attributes:
+                del layer.attributes['_request_kif']
         return False
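
The tail of the transform ties the pieces together: the Input's new precision is derived from the downstream quantizers' KIF arrays, the node is marked 'trusted' so _produce_kif can use its real bit counts on the next pass, and the intermediate layers get their KIF and registered precision recomputed. The reduction over quantizers can be reproduced in isolation (KIF values hypothetical):

    import numpy as np

    # KIF tuples as returned by _produce_kif for two downstream quantizers:
    kifs = [
        (np.array([1]), np.array([4]), np.array([7])),
        (np.array([0]), np.array([6]), np.array([3])),
    ]
    i = np.max([np.max(i) for _, i, _ in kifs])  # 6: the widest integer part wins
    k = np.max([np.max(k) for k, _, _ in kifs])  # 1: a sign bit anywhere forces one everywhere
    # f is reduced the same way, but capped at 24 fractional bits when a quantizer
    # uses a rounding mode other than TRN/RND (see the warning above).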
hls4ml/model/optimizer/passes/hgq_proxy_model.py (4 changes: 2 additions & 2 deletions)

@@ -6,7 +6,7 @@
 import numpy as np

 from hls4ml.model.attributes import Attribute, TypeAttribute, WeightAttribute
-from hls4ml.model.layers import Activation, Layer, Reshape, register_layer
+from hls4ml.model.layers import Activation, Layer, Reshape, Transpose, register_layer
 from hls4ml.model.optimizer import OptimizerPass, register_pass
 from hls4ml.model.types import FixedPrecisionType, UnspecifiedPrecisionType

@@ -100,7 +100,7 @@ def propagate(self, node: Layer, precision: FixedPrecisionType):
             node.attributes['result_t'].precision = precision
             node.attributes['_result_t_propagated'] = True

-        if not isinstance(node, Reshape):
+        if not isinstance(node, (Reshape, Transpose)):
             return node

         inp_layer = get_input_layers(node)[0]
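
Treating Transpose like Reshape means an enforced result type now propagates upstream through any chain of data-movement layers in the proxy model, not just reshapes. A minimal sketch of the rule, assuming the same isinstance contract as the code above:

    from hls4ml.model.layers import Reshape, Transpose

    def stops_propagation(node) -> bool:
        # Sketch: precision flows backwards only through layers that merely
        # move data; any compute layer ends the propagation.
        return not isinstance(node, (Reshape, Transpose))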