diff --git a/pyproject.toml b/pyproject.toml index d633dd1..0a5ef8c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,6 +89,7 @@ ignore = [ "F811", "PLR0911", # Too many return statements "PLR0912", # Too many branches + "PLR0915", # Too many statements ] select = [ "E", # pycodestyle diff --git a/src/irx/builders/llvmliteir.py b/src/irx/builders/llvmliteir.py index 6b3863f..281bdc2 100644 --- a/src/irx/builders/llvmliteir.py +++ b/src/irx/builders/llvmliteir.py @@ -38,11 +38,14 @@ def safe_pop(lst: list[ir.Value | ir.Function]) -> ir.Value | ir.Function: class VariablesLLVM: """Store all the LLVM variables that is used for the code generation.""" + BOOLEAN_TYPE: ir.types.Type FLOAT_TYPE: ir.types.Type DOUBLE_TYPE: ir.types.Type INT8_TYPE: ir.types.Type INT32_TYPE: ir.types.Type VOID_TYPE: ir.types.Type + STRING_TYPE: ir.types.Type + INT64_TYPE: ir.types.Type context: ir.context.Context module: ir.module.Module @@ -61,7 +64,7 @@ def get_data_type(self, type_name: str) -> ir.types.Type: ------- ir.Type: The LLVM data type. """ - if type_name == "float": + if type_name == "float32": return self.FLOAT_TYPE elif type_name == "double": return self.DOUBLE_TYPE @@ -73,8 +76,14 @@ def get_data_type(self, type_name: str) -> ir.types.Type: return self.INT8_TYPE elif type_name == "void": return self.VOID_TYPE + elif type_name == "bool": + return self.BOOLEAN_TYPE + elif type_name == "string": + return self.STRING_TYPE + elif type_name == "int64": + return self.INT64_TYPE - raise Exception("[EE]: type_name not valid.") + raise Exception(f"[EE]: type_name : {type_name} not valid. ") class LLVMLiteIRVisitor(BuilderVisitor): @@ -130,6 +139,7 @@ def initialize(self) -> None: self._llvm.INT8_TYPE = ir.IntType(8) self._llvm.INT32_TYPE = ir.IntType(32) self._llvm.VOID_TYPE = ir.VoidType() + self._llvm.BOOLEAN_TYPE = ir.IntType(1) def _add_builtins(self) -> None: # The C++ tutorial adds putchard() simply by defining it in the host @@ -220,10 +230,8 @@ def visit(self, expr: astx.BinaryOp) -> None: # If you build LLVM with RTTI, this can be changed to a # dynamic_cast for automatic error checking. var_lhs = expr.lhs - if not isinstance(var_lhs, astx.VariableExprAST): raise Exception("destination of '=' must be a variable") - # Codegen the rhs. self.visit(expr.rhs) llvm_rhs = safe_pop(self.result_stack) @@ -254,43 +262,103 @@ def visit(self, expr: astx.BinaryOp) -> None: if expr.op_code == "+": # note: it should be according the datatype, # e.g. for float it should be fadd - result = self._llvm.ir_builder.add(llvm_lhs, llvm_rhs, "addtmp") + + # handle float datatype + + if self._llvm.FLOAT_TYPE in (llvm_lhs.type, llvm_rhs.type): + result = self._llvm.ir_builder.fadd( + llvm_lhs, llvm_rhs, "addtmp" + ) + + else: + # there's more conditions to be handled + result = self._llvm.ir_builder.add( + llvm_lhs, llvm_rhs, "addtmp" + ) + self.result_stack.append(result) return + elif expr.op_code == "-": # note: it should be according the datatype, # e.g. for float it should be fsub - result = self._llvm.ir_builder.sub(llvm_lhs, llvm_rhs, "subtmp") + + # handle the float datatype + if self._llvm.FLOAT_TYPE in (llvm_lhs.type, llvm_rhs.type): + result = self._llvm.ir_builder.fsub( + llvm_lhs, llvm_rhs, "subtmp" + ) + else: + # note: be careful you should handle this as INT32 + result = self._llvm.ir_builder.sub( + llvm_lhs, llvm_rhs, "subtmp" + ) + self.result_stack.append(result) return + elif expr.op_code == "*": # note: it should be according the datatype, # e.g. for float it should be fmul - result = self._llvm.ir_builder.mul(llvm_lhs, llvm_rhs, "multmp") + + # handle float datatype + if self._llvm.FLOAT_TYPE in (llvm_lhs.type, llvm_rhs.type): + result = self._llvm.ir_builder.fmul( + llvm_lhs, llvm_rhs, "multmp" + ) + else: + # note: be careful you should handle this + result = self._llvm.ir_builder.mul( + llvm_lhs, llvm_rhs, "multmp" + ) + self.result_stack.append(result) return + elif expr.op_code == "<": # note: it should be according the datatype, # e.g. for float it should be fcmp - cmp_result = self._llvm.ir_builder.cmp_unordered( - "<", llvm_lhs, llvm_rhs, "lttmp" - ) - result = self._llvm.ir_builder.uitofp( - cmp_result, self._llvm.INT32_TYPE, "booltmp" - ) + + # handle float type + if self._llvm.FLOAT_TYPE in (llvm_lhs.type, llvm_rhs.type): + cmp_result = self._llvm.ir_builder.fcmp_ordered( + "<", llvm_lhs, llvm_rhs, "lttmp" + ) + result = self._llvm.ir_builder.uitofp( + cmp_result, self._llvm.FLOAT_TYPE, "booltmp" + ) + else: + # handle it depend on datatype + cmp_result = self._llvm.ir_builder.cmp_unordered( + "<", llvm_lhs, llvm_rhs, "lttmp" + ) + result = self._llvm.ir_builder.uitofp( + cmp_result, self._llvm.INT32_TYPE, "booltmp" + ) self.result_stack.append(result) return + elif expr.op_code == ">": # note: it should be according the datatype, # e.g. for float it should be fcmp - cmp_result = self._llvm.ir_builder.cmp_unordered( - ">", llvm_lhs, llvm_rhs, "gttmp" - ) - result = self._llvm.ir_builder.uitofp( - cmp_result, self._llvm.INT32_TYPE, "booltmp" - ) + if self._llvm.FLOAT_TYPE in (llvm_lhs.type, llvm_rhs.type): + cmp_result = self._llvm.ir_builder.fcmp_ordered( + ">", llvm_lhs, llvm_rhs, "gttmp" + ) + result = self._llvm.ir_builder.uitofp( + cmp_result, self._llvm.FLOAT_TYPE, "booltmp" + ) + else: + # be careful we havn't handled all the conditions + cmp_result = self._llvm.ir_builder.cmp_unordered( + ">", llvm_lhs, llvm_rhs, "gttmp" + ) + result = self._llvm.ir_builder.uitofp( + cmp_result, self._llvm.INT32_TYPE, "booltmp" + ) self.result_stack.append(result) return + elif expr.op_code == "/": # Check the datatype to decide between floating-point and integer # division @@ -604,6 +672,12 @@ def visit(self, expr: astx.LiteralInt32) -> None: result = ir.Constant(self._llvm.INT32_TYPE, expr.value) self.result_stack.append(result) + @dispatch # type: ignore[no-redef] + def visit(self, expr: astx.LiteralFloat32) -> None: + """Translate ASTx LiteralFloat32 to LLVM-IR.""" + result = ir.Constant(self._llvm.FLOAT_TYPE, expr.value) + self.result_stack.append(result) + @dispatch # type: ignore[no-redef] def visit(self, expr: astx.FunctionCall) -> None: """Translate Function FunctionCall.""" @@ -658,13 +732,28 @@ def visit(self, expr: astx.Function) -> None: @dispatch # type: ignore[no-redef] def visit(self, expr: astx.FunctionPrototype) -> None: """Translate ASTx Function Prototype to LLVM-IR.""" - args_type = [self._llvm.INT32_TYPE] * len(expr.args.nodes) - # note: it should be dynamic - return_type = self._llvm.get_data_type("int32") + args_type = [] + for arg in expr.args.nodes: + if isinstance(arg.type_, astx.Float32): + args_type.append(self._llvm.FLOAT_TYPE) + elif isinstance(arg.type_, astx.Int32): + args_type.append(self._llvm.INT32_TYPE) + else: + raise Exception("Unsupported data type") + if isinstance(expr.return_type, astx.Float32): + return_type = self._llvm.FLOAT_TYPE + elif isinstance(expr.return_type, astx.Float64): + return_type = self._llvm.DOUBLE_TYPE + elif isinstance(expr.return_type, astx.Int32): + return_type = self._llvm.INT32_TYPE + elif isinstance(expr.return_type, astx.Int64): + return_type = self._llvm.INT64_TYPE + elif isinstance(expr.return_type, astx.Void): + return_type = self._llvm.VOID_TYPE + else: + raise Exception(f"Unsupported return type: {expr.return_type}") fn_type = ir.FunctionType(return_type, args_type, False) - fn = ir.Function(self._llvm.module, fn_type, expr.name) - # Set names for all arguments. for idx, arg in enumerate(fn.args): arg.name = expr.args[idx].name @@ -689,6 +778,7 @@ def visit(self, expr: astx.FunctionReturn) -> None: @dispatch # type: ignore[no-redef] def visit(self, expr: astx.InlineVariableDeclaration) -> None: """Translate an ASTx InlineVariableDeclaration expression.""" + type = expr.type_ if self.named_values.get(expr.name): raise Exception(f"Variable already declared: {expr.name}") @@ -698,11 +788,24 @@ def visit(self, expr: astx.InlineVariableDeclaration) -> None: init_val = self.result_stack.pop() if init_val is None: raise Exception("Initializer code generation failed.") - else: + # If not specified, use 0 as the initializer. + # note: it should create something according to the defined type + elif isinstance(type, astx.Int32): init_val = ir.Constant(self._llvm.get_data_type("int32"), 0) + elif isinstance(type, astx.Float32): + init_val = ir.Constant(self._llvm.get_data_type("float32"), 0.0) + else: + raise Exception("Unsupported type") - alloca = self.create_entry_block_alloca(expr.name, "int32") + if isinstance(type, astx.Int32): + alloca = self.create_entry_block_alloca(expr.name, "int32") + elif isinstance(type, astx.Float32): + alloca = self.create_entry_block_alloca(expr.name, "float32") + else: + raise Exception("Unsupported type") + # Store the initial value. self._llvm.ir_builder.store(init_val, alloca) + # Remember this binding. self.named_values[expr.name] = alloca self.result_stack.append(init_val) @@ -721,6 +824,7 @@ def visit(self, expr: astx.Variable) -> None: @dispatch # type: ignore[no-redef] def visit(self, expr: astx.VariableDeclaration) -> None: """Translate ASTx Variable to LLVM-IR.""" + type = expr.type_ if self.named_values.get(expr.name): raise Exception(f"Variable already declared: {expr.name}") @@ -730,15 +834,22 @@ def visit(self, expr: astx.VariableDeclaration) -> None: init_val = self.result_stack.pop() if init_val is None: raise Exception("Initializer code generation failed.") - else: - # If not specified, use 0 as the initializer. - # note: it should create something according to the defined type + # If not specified, use 0 as the initializer. + # note: it should create something according to the defined type + elif isinstance(type, astx.Int32): init_val = ir.Constant(self._llvm.get_data_type("int32"), 0) - + elif isinstance(type, astx.Float32): + init_val = ir.Constant(self._llvm.get_data_type("float32"), 0.0) + else: + raise Exception("Unsupported type") # Create an alloca in the entry block. # note: it should create the type according to the defined type - alloca = self.create_entry_block_alloca(expr.name, "int32") - + if isinstance(type, astx.Int32): + alloca = self.create_entry_block_alloca(expr.name, "int32") + elif isinstance(type, astx.Float32): + alloca = self.create_entry_block_alloca(expr.name, "float32") + else: + raise Exception("Unsupported type") # Store the initial value. self._llvm.ir_builder.store(init_val, alloca) diff --git a/tests/test_binary_op.py b/tests/test_binary_op.py index a010ff1..1e450d5 100644 --- a/tests/test_binary_op.py +++ b/tests/test_binary_op.py @@ -61,3 +61,54 @@ def test_binary_op_basic( module.block.append(main_fn) check_result(action, builder, module, expected_file) + + +@pytest.mark.parametrize( + "action,expected_file", + [ + ("build", ""), + ], +) +@pytest.mark.parametrize( + "builder_class", + [ + LLVMLiteIR, + ], +) +def test_binary_op_float( + action: str, expected_file: str, builder_class: Type[Builder] +) -> None: + """Test ASTx Module with float operations.""" + builder = builder_class() + module = builder.module() + + decl_x = astx.VariableDeclaration( + name="x", type_=astx.Float32(), value=astx.LiteralFloat32(1.5) + ) + decl_y = astx.VariableDeclaration( + name="y", type_=astx.Float32(), value=astx.LiteralFloat32(2.5) + ) + decl_z = astx.VariableDeclaration( + name="z", type_=astx.Float32(), value=astx.LiteralFloat32(4.0) + ) + + x = astx.Variable("x") + y = astx.Variable("y") + z = astx.Variable("z") + + lit_2 = astx.LiteralFloat32(1.0) + + basic_op = lit_2 + y - z * z / x + (y - x + z / x) + + main_proto = astx.FunctionPrototype( + name="main", args=astx.Arguments(), return_type=astx.Float32() + ) + main_block = astx.Block() + main_block.append(decl_x) + main_block.append(decl_y) + main_block.append(decl_z) + main_block.append(astx.FunctionReturn(basic_op)) + main_fn = astx.Function(prototype=main_proto, body=main_block) + + module.block.append(main_fn) + check_result(action, builder, module, expected_file)