diff --git a/difflogic/compiled_model.py b/difflogic/compiled_model.py index 307e578..e64e765 100644 --- a/difflogic/compiled_model.py +++ b/difflogic/compiled_model.py @@ -139,18 +139,16 @@ def get_gate_code(self, var1, var2, gate_op): def get_layer_code(self, layer_a, layer_b, layer_op, layer_id, prefix_sums): code = [] for var_id, (gate_a, gate_b, gate_op) in enumerate(zip(layer_a, layer_b, layer_op)): - if self.device == 'cpu' and layer_id == len(prefix_sums) - 1: + if layer_id == 0: + a = f"inp[{gate_a}]" + b = f"inp[{gate_b}]" + else: a = f"v{prefix_sums[layer_id - 1] + gate_a}" b = f"v{prefix_sums[layer_id - 1] + gate_b}" + if self.device == 'cpu' and layer_id == len(prefix_sums) - 1: code.append(f"\tout[{var_id}] = {self.get_gate_code(a, b, gate_op)};") else: assert not (self.device == 'cpu' and layer_id >= len(prefix_sums) - 1), (layer_id, len(prefix_sums)) - if layer_id == 0: - a = f"inp[{gate_a}]" - b = f"inp[{gate_b}]" - else: - a = f"v{prefix_sums[layer_id - 1] + gate_a}" - b = f"v{prefix_sums[layer_id - 1] + gate_b}" code.append( f"\tconst {BITS_TO_DTYPE[self.num_bits]} v{prefix_sums[layer_id] + var_id} = {self.get_gate_code(a, b, gate_op)};" )