diff --git a/.gitignore b/.gitignore index 22c8ff685b..8fb87927ce 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ docs/autodoc/* hls4mlprj_* *~ *.ipynb_checkpoints/ +*.bak diff --git a/example-models b/example-models index c6bb3c0686..e7a9dee394 160000 --- a/example-models +++ b/example-models @@ -1 +1 @@ -Subproject commit c6bb3c0686d52439d8c53d7407903bf78e852562 +Subproject commit e7a9dee394b6c1f6e0eb23178d34e55f077297fe diff --git a/hls4ml/converters/onnx/core.py b/hls4ml/converters/onnx/core.py index 8ad851426d..6cb11fe20e 100644 --- a/hls4ml/converters/onnx/core.py +++ b/hls4ml/converters/onnx/core.py @@ -105,7 +105,7 @@ def parse_batchnorm_layer(node, input_names, input_shapes, graph): return layer -@onnx_handler('Quant') +@onnx_handler('Quant', 'IntQuant') def parse_quant_layer(node, input_names, input_shapes, graph): layer = {} @@ -120,3 +120,14 @@ def parse_quant_layer(node, input_names, input_shapes, graph): layer['signed'] = bool(get_onnx_attribute(node, 'signed')) return layer + + +@onnx_handler('BipolarQuant') +def parse_bipolar_quant_layer(node, input_names, input_shapes, graph): + layer = {} + + layer['class_name'] = 'BipolarQuant' + layer['name'] = node.name + layer['inputs'] = input_names + layer['outputs'] = list(node.output) + return layer diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index e3c293dd46..d7610c011a 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -693,6 +693,11 @@ def replace_node(self, old_node, new_node): repl = {old_name: new_name for old_name, new_name in zip(old_node.outputs, new_node.outputs)} repl.update({old_name: new_name for old_name, new_name in zip(old_node.inputs, new_node.inputs)}) + for old_output in old_node.outputs: + if old_output in self.outputs: + new_output = repl[old_output] + self.outputs = [new_output if name == old_output else name for name in self.outputs] + for node in self.graph.values(): for i, n in enumerate(node.inputs): if n in repl: @@ -703,11 +708,6 @@ def replace_node(self, old_node, new_node): self.graph = OrderedDict((new_node.name, new_node) if k == old_node.name else (k, v) for k, v in self.graph.items()) - old_name = old_node.name - if old_name in self.outputs: - new_name = new_node.name - self.outputs = [new_name if name == old_name else name for name in self.outputs] - def split_node(self, old_node, new_node1, new_node2): """Replace an existing node in the graph with two nodes in sequence. 
@@ -728,6 +728,11 @@ def split_node(self, old_node, new_node1, new_node2): repl = {old_name: new_name for old_name, new_name in zip(old_node.outputs, new_node2.outputs)} repl.update({old_name: new_name for old_name, new_name in zip(old_node.inputs, new_node1.inputs)}) + for old_output in old_node.outputs: + if old_output in self.outputs: + new_output = repl[old_output] + self.outputs = [new_output if name == old_output else name for name in self.outputs] + for node in self.graph.values(): for i, n in enumerate(node.inputs): if n in repl: @@ -745,9 +750,6 @@ def split_node(self, old_node, new_node1, new_node2): new_graph[key] = value self.graph = new_graph - if old_node.name in self.outputs: - self.outputs = [new_node2.name if name == old_node.name else name for name in self.outputs] - def next_layer(self): self.index += 1 return self.index diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index f72d3595ed..a47dca285a 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -432,6 +432,20 @@ def initialize(self): self.add_output_variable(shape) +class BipolarQuant(Layer): # The QONNX quantization layer + """ + This is a QONNX quantization layer. Optimizations should convert it + before HLS is produced. + """ + + _expected_attributes = [] + + def initialize(self): + inp = self.get_input_variable(self.inputs[0]) + shape = inp.shape + self.add_output_variable(shape) + + class Reshape(Layer): _expected_attributes = [ Attribute('target_shape', value_type=typing.Sequence), @@ -945,7 +959,7 @@ class Activation(Layer): def initialize(self): inp = self.get_input_variable() shape = inp.shape - self.add_output_variable(shape) + self.add_output_variable(shape, precision=self.get_attr('quantizer_precision')) # for xor precision if 'n_in' not in self.attributes: self.set_attr('n_in', self.get_input_variable().size()) @@ -1826,6 +1840,8 @@ def initialize(self): 'GarNet': GarNet, 'GarNetStack': GarNetStack, 'Quant': Quant, + 'IntQuant': Quant, + 'BipolarQuant': BipolarQuant, 'ApplyAlpha': ApplyAlpha, 'BatchNormOnnx': BatchNormOnnx, 'LayerGroup': LayerGroup, diff --git a/hls4ml/model/optimizer/__init__.py b/hls4ml/model/optimizer/__init__.py index 8fc6876942..c0a99a66c5 100644 --- a/hls4ml/model/optimizer/__init__.py +++ b/hls4ml/model/optimizer/__init__.py @@ -36,10 +36,15 @@ 'reshape_constant', 'resize_remove_constants', 'quant_constant_parameters', - 'quant_to_activation', + 'bipolar_quant_constant_parameters', 'fuse_quant_with_constant', + 'fuse_bipolar_quant_with_constant', + 'quant_to_activation', + 'bipolar_quant_to_activation', 'const_quant_to_const_alpha', + 'const_bipolar_quant_to_const_alpha', 'quant_to_alpha_activation_alpha', + 'bipolar_quant_to_alpha_activation_alpha', 'batch_norm_onnx_constant_parameters', 'constant_batch_norm_fusion', 'merge_two_constants', diff --git a/hls4ml/model/optimizer/passes/bipolar_quant_opt.py b/hls4ml/model/optimizer/passes/bipolar_quant_opt.py new file mode 100644 index 0000000000..59a78f947b --- /dev/null +++ b/hls4ml/model/optimizer/passes/bipolar_quant_opt.py @@ -0,0 +1,263 @@ +""" +This file includes optimizations related to BipolarQuant nodes. 
+
+"""
+
+import copy
+
+import numpy as np
+
+from hls4ml.model.layers import Activation, ApplyAlpha, BipolarQuant, Constant
+from hls4ml.model.optimizer import OptimizerPass
+from hls4ml.model.quantizers import BinaryQuantizer
+from hls4ml.model.types import XnorPrecisionType
+
+
+class BipolarQuantConstantParameters(OptimizerPass):
+    """Remove Constant from the BipolarQuant node parameters (but not input[0])"""
+
+    def match(self, node):
+        is_match = (
+            isinstance(node, BipolarQuant)
+            and len(node.inputs) == 2
+            and (node.get_input_node(node.inputs[1]) and isinstance(node.get_input_node(node.inputs[1]), Constant))
+        )
+
+        return is_match
+
+    def transform(self, model, node):
+        """
+        Remove Constant from the BipolarQuant node parameters (but not input[0])
+        """
+        if node.get_input_node(node.inputs[1]):
+            scale_node = node.get_input_node(node.inputs[1])
+            if isinstance(scale_node, Constant):
+                node.set_attr('scale', scale_node.get_attr('value'))
+                node.inputs[1] = ''
+                model.remove_node(scale_node)
+
+        node.inputs = [inp for inp in node.inputs if inp]
+        if len(node.inputs) != 1:
+            raise RuntimeError("hls4ml only supports constant scale")
+
+        return True
+
+
+class BipolarQuantToActivation(OptimizerPass):
+    """
+    This is for the case when scale is 1. It is a 1:1 transformation of a BipolarQuant to an Activation.
+    This is not called when the input is constant.
+    """
+
+    def match(self, node):
+        # only matches after the other inputs are already folded
+        is_match = (
+            isinstance(node, BipolarQuant)
+            and len(node.inputs) == 1
+            and not isinstance(node.get_input_node(node.inputs[0]), Constant)
+        )
+
+        # Only match if the scale is 1
+        if is_match:  # to make sure this is a quant node with inputs
+            scale = node.get_attr('scale')
+            is_match = (scale == 1.0).all()
+
+        return is_match
+
+    def transform(self, model, node):
+        """
+        Change BipolarQuant node to Activation
+        """
+        precision = XnorPrecisionType()
+        quantizer = BinaryQuantizer(bits=1)
+
+        attributes = {'activation': 'binary_tanh', 'quantizer': quantizer, 'quantizer_precision': precision}
+
+        # update the configuration (not setting the precision since can't specify xnor type)
+        config = model.config.get_layer_config(node)
+        new_name = f'{node.name}_act'
+        model.config.set_name_config(new_name, config)
+        model.config.parse_name_config(new_name, config)
+
+        new_node = model.make_node(Activation, new_name, attributes, [node.inputs[0]], list(node.outputs))
+        model.replace_node(node, new_node)
+        return True
+
+
+class FuseBipolarQuantWithConstant(OptimizerPass):
+    """
+    This is for the case when scale is 1 and the input is a constant
+    """
+
+    def match(self, node):
+
+        # only matches after the other inputs are already folded
+        # and scale is unit
+        is_match = (
+            isinstance(node, BipolarQuant)
+            and len(node.inputs) == 1
+            and isinstance(node.get_input_node(node.inputs[0]), Constant)
+        )
+
+        # Only match if the scale is 1
+        if is_match:  # to make sure this is a quant node with inputs
+            scale = node.get_attr('scale')
+            is_match = (scale == 1.0).all()
+
+        return is_match
+
+    def transform(self, model, node):
+        """
+        Fuse BipolarQuant with Constant.
+        """
+        precision = XnorPrecisionType()
+        quantizer = BinaryQuantizer(bits=1)
+
+        const_node = node.get_input_node(node.inputs[0])
+        const_node.set_attr('quantizer', quantizer)
+        const_node.get_output_variable().type.precision = precision
+
+        # remove the Quant node
+        model.remove_node(node)
+        return True
+
+
+class BipolarQuantToAlphaActivationAlpha(OptimizerPass):
+    """
+    This is for the case when scale is not 1. It is a 1:3 transformation of
+    a BipolarQuant to an ApplyAlpha (to scale), Activation, ApplyAlpha (to rescale).
+
+    NOTE: It needs to be scheduled after BipolarQuantToActivation (or we need to make the match criteria stricter)
+    """
+
+    def match(self, node):
+        # only matches after the other inputs are already folded
+        is_match = (
+            isinstance(node, BipolarQuant)
+            and len(node.inputs) == 1
+            and not isinstance(node.get_input_node(node.inputs[0]), Constant)
+        )
+        return is_match
+
+    def transform(self, model, node):
+        """
+        Change quant node to ApplyAlpha, Activation, ApplyAlpha
+        """
+
+        # Do the Activation as in the simple case
+
+        precision = XnorPrecisionType()
+        quantizer = BinaryQuantizer(bits=1)
+
+        activation_attributes = {'activation': 'binary_tanh', 'quantizer': quantizer, 'quantizer_precision': precision}
+
+        # update the configuration (not setting the precision since can't specify xnor type)
+        config = model.config.get_layer_config(node)
+        act_config = copy.deepcopy(config)
+        act_name = f'{node.name}_act'
+        model.config.set_name_config(act_name, act_config)
+        model.config.parse_name_config(act_name, act_config)
+
+        new_node = model.make_node(Activation, act_name, activation_attributes, [node.inputs[0]], [x for x in node.outputs])
+        model.replace_node(node, new_node)
+
+        # but now add the ApplyAlphas before and after
+
+        inshape = node.get_input_variable().shape
+
+        scale = node.get_attr('scale')
+        bias = np.array(0)
+
+        attributes_scale = {'n_filt': -1}
+        attributes_rescale = {'n_filt': -1}
+
+        scale_config = copy.deepcopy(config)
+        scale_name = f'{node.name}_scale'
+        model.config.set_name_config(scale_name, scale_config)
+        model.config.parse_name_config(scale_name, scale_config)
+
+        rescale_config = config  # no need to deep copy the last
+        rescale_name = f'{node.name}_rescale'
+        model.config.set_name_config(rescale_name, rescale_config)
+        model.config.parse_name_config(rescale_name, rescale_config)
+
+        firstscale = 1 / scale
+        firstbias = bias
+        attributes_scale['scale_data'] = np.broadcast_to(firstscale, inshape)
+        attributes_scale['bias_data'] = np.broadcast_to(firstbias, inshape)
+
+        scale_node = model.make_node(ApplyAlpha, scale_name, attributes_scale, [node.inputs[0]])
+        model.insert_node(scale_node)
+
+        rescale = scale
+        rebias = -bias * scale
+        attributes_rescale['scale_data'] = np.broadcast_to(rescale, inshape)
+        attributes_rescale['bias_data'] = np.broadcast_to(rebias, inshape)
+
+        rescale_node = model.make_node(ApplyAlpha, rescale_name, attributes_rescale, [new_node.outputs[0]])
+        model.insert_node(rescale_node)
+
+        return True
+
+
+class ConstBipolarQuantToConstAlpha(OptimizerPass):
+    """
+    This is for the case when scale is not 1. It is a 1:3 transformation of
+    a BipolarQuant to an ApplyAlpha (to scale), Activation, ApplyAlpha (to unscale), but a constant
+    input allows for optimization, so the ApplyAlpha (to scale) and Activation are
+    optimized away right away.
+ """ + + def match(self, node): + # only matches after the other inputs are already folded + is_match = ( + isinstance(node, BipolarQuant) + and len(node.inputs) == 1 + and isinstance(node.get_input_node(node.inputs[0]), Constant) + ) + + if is_match: # to make sure this is a quant node with inputs + scale = node.get_attr('scale') + is_match = is_match and ((scale != np.ones_like(scale)).any()) + return is_match + + def transform(self, model, node): + """ + Change Constant + Quant node to Constant, ApplyAlpha + """ + + precision = XnorPrecisionType() + quantizer = BinaryQuantizer(bits=1) + + const_node = node.get_input_node(node.inputs[0]) + + scale = node.get_attr('scale') + bias = np.array(0) # zeropt not defined for bipolar quants + + # caclucate the new value + new_val = const_node.get_attr('value') / scale + bias + const_node.set_attr('value', new_val) + const_node.set_attr('quantizer', quantizer) + + const_node.get_output_variable().type.precision = precision + + inshape = node.get_input_variable().shape + + attributes_rescale = {'n_filt': -1} + + rescale_config = copy.deepcopy(model.config.get_layer_config(node)) + rescale_name = f'{node.name}_rescale' + model.config.set_name_config(rescale_name, rescale_config) + model.config.parse_name_config(rescale_name, rescale_config) + + rescale = scale + rebias = -bias * scale + attributes_rescale['scale_data'] = np.broadcast_to(rescale, inshape) + attributes_rescale['bias_data'] = np.broadcast_to(rebias, inshape) + + rescale_node = model.make_node( + ApplyAlpha, rescale_name, attributes_rescale, [x for x in node.inputs], [x for x in node.outputs] + ) + model.replace_node(node, rescale_node) + + return True diff --git a/hls4ml/model/optimizer/passes/bn_fuse.py b/hls4ml/model/optimizer/passes/bn_fuse.py index 0702033deb..c73e06d66e 100644 --- a/hls4ml/model/optimizer/passes/bn_fuse.py +++ b/hls4ml/model/optimizer/passes/bn_fuse.py @@ -2,7 +2,7 @@ from hls4ml.model.layers import BatchNormalization, Conv1D, Conv2D, Dense from hls4ml.model.optimizer import OptimizerPass -from hls4ml.model.types import FixedPrecisionType, IntegerPrecisionType, UnspecifiedPrecisionType +from hls4ml.model.types import FixedPrecisionType, IntegerPrecisionType, UnspecifiedPrecisionType, XnorPrecisionType class FuseBatchNormalization(OptimizerPass): @@ -25,6 +25,11 @@ def match(self, node): and isinstance(prev_node.get_output_variable().type.precision, UnspecifiedPrecisionType) ) if basic_match: + # don't merge if binary weights. + # (Can technically merge if only one of weight or input is not binary, be more conservative here) + weight_t = prev_node.get_attr('weight_t') + if weight_t and isinstance(weight_t.precision, XnorPrecisionType): + return False s0 = prev_node.weights['weight'].data_unquantized b0 = prev_node.weights['bias'].data_unquantized s1 = node.weights['scale'].data_unquantized diff --git a/hls4ml/model/optimizer/passes/quant_opt.py b/hls4ml/model/optimizer/passes/quant_opt.py index 6c9badd832..e8114e68a2 100644 --- a/hls4ml/model/optimizer/passes/quant_opt.py +++ b/hls4ml/model/optimizer/passes/quant_opt.py @@ -83,7 +83,7 @@ class QuantToActivation(OptimizerPass): This is for the case when scale is a (positive) power of 2 and zeropt is 0. It is a a 1:1 transformation of a Quant to an Activation. - As an optimization, this is not called when the input is constant. + This is not called when the input is constant. 
""" def match(self, node): @@ -148,7 +148,7 @@ def transform(self, model, node): class FuseQuantWithConstant(OptimizerPass): """ - This is for the case when scale is a positive power of 2 and zeropt is 0. + This is for the case when scale is a positive power of 2 and zeropt is 0, and when the input is a constant. """ def match(self, node): @@ -207,7 +207,7 @@ def transform(self, model, node): class QuantToAlphaActivationAlpha(OptimizerPass): """ This is for the case when scale is not power-of-2 or zeropt is not 0. It is a a 1:3 transformation of - a Quant to an ApplyAlpha (to scale), Activatio, ApplyAlpho (to rescale). + a Quant to an ApplyAlpha (to scale), Activation, ApplyAlpho (to rescale). NOTE: It needs to be scheduled after QuantToActivation (or we need to make the match criteria stricter) """ @@ -294,6 +294,8 @@ class ConstQuantToConstAlpha(OptimizerPass): a Quant to an ApplyAlpha (to scale), Activation, ApplyAlpho (to unscale), but an input consts allows for optimization, so the ApplyAlpha (to scale), Activation are optimized away right away. + + NOTE: It needs to be scheduled after FuseQuantWithConstant (or we need to make the match criteria stricter) """ def match(self, node): @@ -302,10 +304,6 @@ def match(self, node): isinstance(node, Quant) and len(node.inputs) == 1 and isinstance(node.get_input_node(node.inputs[0]), Constant) ) - if is_match: # to make sure this is a quant node with inputs - scale = node.get_attr('scale') - bias = node.get_attr('zeropt') - is_match = is_match and ((scale != np.ones_like(scale)).any() or (bias != np.zeros_like(bias)).any()) return is_match def transform(self, model, node): diff --git a/hls4ml/model/quantizers.py b/hls4ml/model/quantizers.py index cedee0c984..833d27a0d2 100644 --- a/hls4ml/model/quantizers.py +++ b/hls4ml/model/quantizers.py @@ -67,9 +67,9 @@ def __call__(self, data): ones = np.ones_like(data) quant_data = data if self.bits == 1: - quant_data = np.where(data > 0, ones, zeros).astype('int') + quant_data = np.where(data >= 0, ones, zeros).astype('int') if self.bits == 2: - quant_data = np.where(data > 0, ones, -ones) + quant_data = np.where(data >= 0, ones, -ones) return quant_data def serialize_state(self): diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_activation.h b/hls4ml/templates/catapult/nnet_utils/nnet_activation.h index fb72460b96..2f7c4cc8c9 100644 --- a/hls4ml/templates/catapult/nnet_utils/nnet_activation.h +++ b/hls4ml/templates/catapult/nnet_utils/nnet_activation.h @@ -1059,23 +1059,35 @@ void prelu(data_T data[CONFIG_T::n_in], param_T alpha[CONFIG_T::n_in], res_T res } } +template +inline typename std::enable_if<(!std::is_same>::value), res_T>::type binary_cast(data_T data) { + return static_cast(data); +} + +// should choose this via function overloading +template +inline typename std::enable_if<(std::is_same>::value), res_T>::type binary_cast(data_T data) { + return (data > 0) ? 
static_cast(data) : static_cast(0); +} + // ************************************************* // Binary TanH Activation // ************************************************* template void binary_tanh(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { //#pragma HLS PIPELINE + using cache_T = ac_int<2, true>; data_T datareg; - res_T cache; + cache_T cache; for (int ii = 0; ii < CONFIG_T::n_in; ii++) { datareg = data[ii]; - if (datareg > 0) + if (datareg >= 0) cache = 1; else cache = -1; - res[ii] = (res_T)cache; + res[ii] = binary_cast(cache); } } diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_activation_stream.h b/hls4ml/templates/catapult/nnet_utils/nnet_activation_stream.h index 82570dbe51..5778c0629c 100644 --- a/hls4ml/templates/catapult/nnet_utils/nnet_activation_stream.h +++ b/hls4ml/templates/catapult/nnet_utils/nnet_activation_stream.h @@ -871,6 +871,7 @@ void prelu(ac_channel &data, const param_T alpha[CONFIG_T::n_in], ac_cha // Binary TanH Activation // ************************************************* template void binary_tanh(ac_channel &data, ac_channel &res) { + using cache_T = ac_int<2, true>; PReLUActLoop: for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { //#pragma HLS PIPELINE @@ -881,11 +882,14 @@ template void binary_tanh(ac_chan PReLUPackLoop: for (int j = 0; j < res_T::size; j++) { - //#pragma HLS UNROLL - if (in_data[j] > 0) - out_data[j] = (typename res_T::value_type)1; + cache_T cache; + + if (in_data[j] >= 0) + cache = 1; else - out_data[j] = (typename res_T::value_type) - 1; + cache = -1; + + out_data[j] = binary_cast(cache); } res.write(out_data); } diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_batchnorm_stream.h b/hls4ml/templates/catapult/nnet_utils/nnet_batchnorm_stream.h index 48085f82dc..0d8c26e42f 100644 --- a/hls4ml/templates/catapult/nnet_utils/nnet_batchnorm_stream.h +++ b/hls4ml/templates/catapult/nnet_utils/nnet_batchnorm_stream.h @@ -68,7 +68,7 @@ void normalize_binary_tanh(ac_channel &data, ac_channel threshold[i * data_T::size + j]) ? 1 : 0; + out_data[j] = (in_data[j] >= threshold[i * data_T::size + j]) ? 1 : 0; } res.write(out_data); diff --git a/hls4ml/templates/catapult/nnet_utils/nnet_mult.h b/hls4ml/templates/catapult/nnet_utils/nnet_mult.h index 7379eec489..f73d9c91f1 100755 --- a/hls4ml/templates/catapult/nnet_utils/nnet_mult.h +++ b/hls4ml/templates/catapult/nnet_utils/nnet_mult.h @@ -105,7 +105,7 @@ inline typename std::enable_if>::value && std::is_same>::value, ac_int>::type cast(typename CONFIG_T::accum_t x) { - return (ac_int)(x - CONFIG_T::n_in / 2) * 2; + return static_cast>((x * 2 - CONFIG_T::n_in).to_ac_int()); } template diff --git a/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_activation.h b/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_activation.h index ab1874ec10..f118ecb05c 100644 --- a/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_activation.h +++ b/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_activation.h @@ -458,20 +458,33 @@ void prelu(const data_T &data, const typename CONFIG_T::param_t &alpha, res_T &r } } +template +inline typename std::enable_if<(!std::is_same>::value), res_T>::type binary_cast(data_T data) { + return static_cast(data); +} + +// should choose this via function overloading +template +inline typename std::enable_if<(std::is_same>::value), res_T>::type binary_cast(data_T data) { + return (data > 0) ? 
static_cast(data) : static_cast(0); +} + // ************************************************* // Binary TanH Activation // ************************************************* template void binary_tanh(const data_T &data, res_T &res) { + using cache_T = ac_int<2, true>; + #pragma unroll for (int ii = 0; ii < CONFIG_T::n_in; ii++) { auto datareg = data[ii]; - typename res_T::value_type cache; - if (datareg > 0) + cache_T cache; + if (datareg >= 0) cache = 1; else cache = -1; - res[ii] = cache; + res[ii] = binary_cast(cache); } } diff --git a/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_activation_stream.h b/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_activation_stream.h index 13de5ab3bb..e860c38988 100644 --- a/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_activation_stream.h +++ b/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_activation_stream.h @@ -1,6 +1,7 @@ #ifndef NNET_ACTIVATION_STREAM_H_ #define NNET_ACTIVATION_STREAM_H_ +#include "nnet_activation.h" #include "nnet_common.h" #include "nnet_types.h" @@ -661,6 +662,7 @@ template void hard_tanh_str // Binary TanH Activation // ************************************************* template void binary_tanh_stream() { + using cache_T = ac_int<2, true>; BinaryTanHActLoop: [[intel::initiation_interval( 1)]] for (int i = 0; i < CONFIG_T::n_in / std::tuple_size::value_type>{}; i++) { @@ -671,10 +673,14 @@ template void binary_tanh_s BinaryTanHPackLoop: #pragma unroll for (int j = 0; j < std::tuple_size::value_type>{}; j++) { - if (in_data[j] > 0) - out_data[j] = static_cast::value_type::value_type>(1); + cache_T cache; + + if (in_data[j] >= 0) + cache = 1; else - out_data[j] = static_cast::value_type::value_type>(-1); + cache = -1; + + out_data[j] = binary_cast::value_type::value_type>(cache); } res_pipe::write(out_data); diff --git a/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_mult.h b/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_mult.h index c7dfc2d7c5..ff3f34ebb6 100644 --- a/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_mult.h +++ b/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_mult.h @@ -88,15 +88,15 @@ template class weight_exponential : public Product { // TO-DO: These may need extra variants if ac_int types are used in more places template inline typename std::enable_if>::value && - std::is_same>::value, - ac_int>::type + std::is_same>::value, + ac_int::val + 2, true>>::type cast(typename CONFIG_T::accum_t x) { - return static_cast>(((x - CONFIG_T::n_in / 2) * 2).to_ac_int()); + return static_cast::val + 2, true>>((x * 2 - CONFIG_T::n_in).to_ac_int()); } template inline typename std::enable_if>::value && - !std::is_same>::value, + !std::is_same>::value, res_T>::type cast(typename CONFIG_T::accum_t x) { return static_cast(x); diff --git a/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_printf.h b/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_printf.h index 5fec90d1aa..570534bdc6 100644 --- a/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_printf.h +++ b/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_printf.h @@ -1,6 +1,8 @@ #ifndef NNET_PRINTF_H_ #define NNET_PRINTF_H_ +namespace nnet { + #ifdef __SYCL_DEVICE_ONLY__ #define CL_CONSTANT __attribute__((opencl_constant)) #else @@ -15,4 +17,5 @@ using namespace sycl; ext::oneapi::experimental::printf(_format, ##__VA_ARGS__); \ } +} // namespace nnet #endif diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_activation.h b/hls4ml/templates/vivado/nnet_utils/nnet_activation.h index 1edf9e6641..ac85e0b2cc 100644 --- 
a/hls4ml/templates/vivado/nnet_utils/nnet_activation.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_activation.h @@ -785,23 +785,34 @@ void prelu(data_T data[CONFIG_T::n_in], param_T alpha[CONFIG_T::n_in], res_T res } } +template +inline typename std::enable_if<(!std::is_same>::value), res_T>::type binary_cast(data_T data) { + return static_cast(data); +} + +// should choose this via function overloading +template +inline typename std::enable_if<(std::is_same>::value), res_T>::type binary_cast(data_T data) { + return (data > 0) ? static_cast(data) : static_cast(0); +} + // ************************************************* // Binary TanH Activation // ************************************************* template void binary_tanh(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { #pragma HLS PIPELINE - + using cache_T = ap_int<2>; data_T datareg; - res_T cache; + cache_T cache; for (int ii = 0; ii < CONFIG_T::n_in; ii++) { datareg = data[ii]; - if (datareg > 0) + if (datareg >= 0) cache = 1; else cache = -1; - res[ii] = (res_T)cache; + res[ii] = binary_cast(cache); } } diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h index c5a9ae609b..50c6c4068c 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h @@ -750,21 +750,25 @@ void prelu(hls::stream &data, const param_T alpha[CONFIG_T::n_in], hls:: // ************************************************* template void binary_tanh(hls::stream &data, hls::stream &res) { + using cache_T = ap_int<2>; PReLUActLoop: for (int i = 0; i < CONFIG_T::n_in / res_T::size; i++) { #pragma HLS PIPELINE data_T in_data = data.read(); + cache_T cache; res_T out_data; PRAGMA_DATA_PACK(out_data) PReLUPackLoop: for (int j = 0; j < res_T::size; j++) { #pragma HLS UNROLL - if (in_data[j] > 0) - out_data[j] = (typename res_T::value_type)1; + if (in_data[j] >= 0) + cache = 1; else - out_data[j] = (typename res_T::value_type) - 1; + cache = -1; + + out_data[j] = binary_cast(cache); } res.write(out_data); } diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_mult.h b/hls4ml/templates/vivado/nnet_utils/nnet_mult.h index 00d1c6d12b..10ada63b2b 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_mult.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_mult.h @@ -96,7 +96,7 @@ inline typename std::enable_if>::value && std::is_same>::value, ap_int>::type cast(typename CONFIG_T::accum_t x) { - return (ap_int)(x - CONFIG_T::n_in / 2) * 2; + return static_cast>(x * 2 - CONFIG_T::n_in); } template diff --git a/test/pytest/test_qonnx.py b/test/pytest/test_qonnx.py index bfa6e0a49c..7e675e1d94 100644 --- a/test/pytest/test_qonnx.py +++ b/test/pytest/test_qonnx.py @@ -1,8 +1,10 @@ +import copy import os import urllib from pathlib import Path import numpy as np +import onnx import pytest import qonnx.core.onnx_exec as oxe import qonnx.util.cleanup @@ -12,6 +14,7 @@ from qonnx.core.modelwrapper import ModelWrapper from qonnx.transformation.channels_last import ConvertToChannelsLastAndClean from qonnx.transformation.gemm_to_matmul import GemmToMatMul +from qonnx.util.cleanup import cleanup_model import hls4ml @@ -231,6 +234,69 @@ def conv2d_small_mp_keras_model(): return model +@pytest.fixture(scope='module') +def bnn_fc_small_qonnx_model(): + """ + Load a small binarized model of a single fully connected layer. 
+ """ + dl_file = str(example_model_path / "onnx/bnn_model_fc_1layer.onnx") + assert os.path.isfile(dl_file) + + model = ModelWrapper(dl_file) + model = cleanup_model(model) + model = model.transform(GemmToMatMul()) # ishape = (1, 3) + model = qonnx.util.cleanup.cleanup_model(model) + return model + + +@pytest.fixture(scope='module') +def bnn_fc_small_qonnx_model_scale_nonunit(bnn_fc_small_qonnx_model): + """ + Use scale factors of 0.5 to see if that works. + This is done by modifying the bnn_fc_small_qonnx_model, which has unit scale factors. + """ + + model = copy.deepcopy(bnn_fc_small_qonnx_model) # is copying neccessary? + new_iscale = onnx.helper.make_tensor("BipolarQuant_0_param0", 1, [1], [0.5]) + new_wscale = onnx.helper.make_tensor("BipolarQuant_1_param1", 1, [1], [0.5]) + old_iscale = old_wscale = None + for init in model.graph.initializer: + if init.name == "BipolarQuant_0_param0": + old_iscale = init + elif init.name == "BipolarQuant_1_param1": + old_wscale = init + model.graph.initializer.remove(old_iscale) + model.graph.initializer.remove(old_wscale) + model.graph.initializer.append(new_iscale) + model.graph.initializer.append(new_wscale) + model = qonnx.util.cleanup.cleanup_model(model) + return model + + +@pytest.fixture(scope='module') +def bnn_fc_small_qonnx_model_scale_nonunit2(bnn_fc_small_qonnx_model): + """ + Use po2 scale factors to see if that works. + This is done by modifying the bnn_fc_small_qonnx_model, which has unit scale factors. + """ + + model = copy.deepcopy(bnn_fc_small_qonnx_model) # is copying neccessary? + new_iscale = onnx.helper.make_tensor("BipolarQuant_0_param0", 1, [1], [2]) + new_wscale = onnx.helper.make_tensor("BipolarQuant_1_param1", 1, [1], [4]) + old_iscale = old_wscale = None + for init in model.graph.initializer: + if init.name == "BipolarQuant_0_param0": + old_iscale = init + elif init.name == "BipolarQuant_1_param1": + old_wscale = init + model.graph.initializer.remove(old_iscale) + model.graph.initializer.remove(old_wscale) + model.graph.initializer.append(new_iscale) + model.graph.initializer.append(new_wscale) + model = qonnx.util.cleanup.cleanup_model(model) + return model + + # The actual tests @@ -428,3 +494,61 @@ def test_simple_model(model_name, io_type, backend, request): y_hls4ml = hls_model.predict(X) np.testing.assert_allclose(y_qonnx.ravel(), y_hls4ml.ravel(), atol=1e-2, rtol=1) + + +# @pytest.mark.parametrize( +# 'model_name', +# ['bnn_fc_small_qonnx_model', 'bnn_fc_small_qonnx_model_scale_nonunit', 'bnn_fc_small_qonnx_model_scale_nonunit2'], +# ) +@pytest.mark.parametrize( + 'model_name', + ['bnn_fc_small_qonnx_model'], +) +@pytest.mark.parametrize( + 'backend,strategy', + [ + ('Catapult', 'Resource'), + ('Catapult', 'Latency'), + ('Vitis', 'Resource'), + ('Vitis', 'Latency'), + ('oneAPI', 'Resource'), + ], +) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +def test_bnn(model_name, io_type, backend, strategy, request): + "Checks if a basic binarized model works correctly." 
+ qonnx_model = request.getfixturevalue(model_name) + + config = hls4ml.utils.config.config_from_onnx_model( + qonnx_model, granularity='name', backend=backend, default_precision='fixed<16,6>' + ) + config['Model']['Strategy'] = strategy + hls_model = hls4ml.converters.convert_from_onnx_model( + qonnx_model, + output_dir=str(test_root_path / f'hls4mlprj_onnx_{model_name}_{io_type}_{backend}_{strategy}'), + io_type=io_type, + backend=backend, + hls_config=config, + ) + hls_model.compile() + + data_x = np.array( + [ + [[+1, +1, +1]], + [[+1, +1, -1]], + [[+1, -1, +1]], + [[-1, -1, -1]], + [[-1, +1, +1]], + [[-1, +1, -1]], + [[-1, -1, +1]], + [[-1, -1, -1]], + ], + dtype=np.float32, + ) + for x in data_x: + idict = {qonnx_model.graph.input[0].name: x} + y_qonnx = oxe.execute_onnx(qonnx_model, idict)[qonnx_model.graph.output[0].name] + y_hls4ml = hls_model.predict(x[0]) + # note, y_hls4ml returns xnor type, so let's interpret it + y_hls4ml_logical = 2 * y_hls4ml - 1 + np.testing.assert_array_equal(y_qonnx.ravel(), y_hls4ml_logical.ravel())
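
For reference, a minimal sketch of how the new BipolarQuant path is exercised end to end, condensed from test_bnn above. The model path and output directory are placeholders; every call mirrors the test added in this patch.

# Sketch only: assumes the example-models checkout provides onnx/bnn_model_fc_1layer.onnx
import numpy as np
import qonnx.core.onnx_exec as oxe
import qonnx.util.cleanup
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.gemm_to_matmul import GemmToMatMul

import hls4ml

qonnx_model = ModelWrapper("example-models/onnx/bnn_model_fc_1layer.onnx")
qonnx_model = qonnx.util.cleanup.cleanup_model(qonnx_model)
qonnx_model = qonnx_model.transform(GemmToMatMul())  # ishape becomes (1, 3)
qonnx_model = qonnx.util.cleanup.cleanup_model(qonnx_model)

config = hls4ml.utils.config.config_from_onnx_model(
    qonnx_model, granularity='name', backend='Vitis', default_precision='fixed<16,6>'
)
hls_model = hls4ml.converters.convert_from_onnx_model(
    qonnx_model,
    output_dir='hls4mlprj_bnn_sketch',  # placeholder project directory
    io_type='io_parallel',
    backend='Vitis',
    hls_config=config,
)
hls_model.compile()

x = np.array([[+1, -1, +1]], dtype=np.float32)
idict = {qonnx_model.graph.input[0].name: x}
y_qonnx = oxe.execute_onnx(qonnx_model, idict)[qonnx_model.graph.output[0].name]
y_hls4ml = hls_model.predict(x[0])
# hls4ml returns the xnor encoding {0, 1}; map it back to {-1, +1} before comparing
y_hls4ml_logical = 2 * y_hls4ml - 1
np.testing.assert_array_equal(y_qonnx.ravel(), y_hls4ml_logical.ravel())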