Skip to content

Add support for float32 datatype #20

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ ignore = [
"F811",
"PLR0911", # Too many return statements
"PLR0912", # Too many branches
"PLR0915", # Too many statements
]
select = [
"E", # pycodestyle
Expand Down
175 changes: 143 additions & 32 deletions src/irx/builders/llvmliteir.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,14 @@ def safe_pop(lst: list[ir.Value | ir.Function]) -> ir.Value | ir.Function:
class VariablesLLVM:
"""Store all the LLVM variables that is used for the code generation."""

BOOLEAN_TYPE: ir.types.Type
FLOAT_TYPE: ir.types.Type
DOUBLE_TYPE: ir.types.Type
INT8_TYPE: ir.types.Type
INT32_TYPE: ir.types.Type
VOID_TYPE: ir.types.Type
STRING_TYPE: ir.types.Type
INT64_TYPE: ir.types.Type

context: ir.context.Context
module: ir.module.Module
Expand All @@ -61,7 +64,7 @@ def get_data_type(self, type_name: str) -> ir.types.Type:
-------
ir.Type: The LLVM data type.
"""
if type_name == "float":
if type_name == "float32":
return self.FLOAT_TYPE
elif type_name == "double":
return self.DOUBLE_TYPE
Expand All @@ -73,8 +76,14 @@ def get_data_type(self, type_name: str) -> ir.types.Type:
return self.INT8_TYPE
elif type_name == "void":
return self.VOID_TYPE
elif type_name == "bool":
return self.BOOLEAN_TYPE
elif type_name == "string":
return self.STRING_TYPE
elif type_name == "int64":
return self.INT64_TYPE

raise Exception("[EE]: type_name not valid.")
raise Exception(f"[EE]: type_name : {type_name} not valid. ")


class LLVMLiteIRVisitor(BuilderVisitor):
Expand Down Expand Up @@ -130,6 +139,7 @@ def initialize(self) -> None:
self._llvm.INT8_TYPE = ir.IntType(8)
self._llvm.INT32_TYPE = ir.IntType(32)
self._llvm.VOID_TYPE = ir.VoidType()
self._llvm.BOOLEAN_TYPE = ir.IntType(1)

def _add_builtins(self) -> None:
# The C++ tutorial adds putchard() simply by defining it in the host
Expand Down Expand Up @@ -220,10 +230,8 @@ def visit(self, expr: astx.BinaryOp) -> None:
# If you build LLVM with RTTI, this can be changed to a
# dynamic_cast for automatic error checking.
var_lhs = expr.lhs

if not isinstance(var_lhs, astx.VariableExprAST):
raise Exception("destination of '=' must be a variable")

# Codegen the rhs.
self.visit(expr.rhs)
llvm_rhs = safe_pop(self.result_stack)
Expand Down Expand Up @@ -254,43 +262,103 @@ def visit(self, expr: astx.BinaryOp) -> None:
if expr.op_code == "+":
# note: it should be according the datatype,
# e.g. for float it should be fadd
result = self._llvm.ir_builder.add(llvm_lhs, llvm_rhs, "addtmp")

# handle float datatype

if self._llvm.FLOAT_TYPE in (llvm_lhs.type, llvm_rhs.type):
result = self._llvm.ir_builder.fadd(
llvm_lhs, llvm_rhs, "addtmp"
)

else:
# there's more conditions to be handled
result = self._llvm.ir_builder.add(
llvm_lhs, llvm_rhs, "addtmp"
)

self.result_stack.append(result)
return

elif expr.op_code == "-":
# note: it should be according the datatype,
# e.g. for float it should be fsub
result = self._llvm.ir_builder.sub(llvm_lhs, llvm_rhs, "subtmp")

# handle the float datatype
if self._llvm.FLOAT_TYPE in (llvm_lhs.type, llvm_rhs.type):
result = self._llvm.ir_builder.fsub(
llvm_lhs, llvm_rhs, "subtmp"
)
else:
# note: be careful you should handle this as INT32
result = self._llvm.ir_builder.sub(
llvm_lhs, llvm_rhs, "subtmp"
)

self.result_stack.append(result)
return

elif expr.op_code == "*":
# note: it should be according the datatype,
# e.g. for float it should be fmul
result = self._llvm.ir_builder.mul(llvm_lhs, llvm_rhs, "multmp")

# handle float datatype
if self._llvm.FLOAT_TYPE in (llvm_lhs.type, llvm_rhs.type):
result = self._llvm.ir_builder.fmul(
llvm_lhs, llvm_rhs, "multmp"
)
else:
# note: be careful you should handle this
result = self._llvm.ir_builder.mul(
llvm_lhs, llvm_rhs, "multmp"
)

self.result_stack.append(result)
return

elif expr.op_code == "<":
# note: it should be according the datatype,
# e.g. for float it should be fcmp
cmp_result = self._llvm.ir_builder.cmp_unordered(
"<", llvm_lhs, llvm_rhs, "lttmp"
)
result = self._llvm.ir_builder.uitofp(
cmp_result, self._llvm.INT32_TYPE, "booltmp"
)

# handle float type
if self._llvm.FLOAT_TYPE in (llvm_lhs.type, llvm_rhs.type):
cmp_result = self._llvm.ir_builder.fcmp_ordered(
"<", llvm_lhs, llvm_rhs, "lttmp"
)
result = self._llvm.ir_builder.uitofp(
cmp_result, self._llvm.FLOAT_TYPE, "booltmp"
)
else:
# handle it depend on datatype
cmp_result = self._llvm.ir_builder.cmp_unordered(
"<", llvm_lhs, llvm_rhs, "lttmp"
)
result = self._llvm.ir_builder.uitofp(
cmp_result, self._llvm.INT32_TYPE, "booltmp"
)
self.result_stack.append(result)
return

elif expr.op_code == ">":
# note: it should be according the datatype,
# e.g. for float it should be fcmp
cmp_result = self._llvm.ir_builder.cmp_unordered(
">", llvm_lhs, llvm_rhs, "gttmp"
)
result = self._llvm.ir_builder.uitofp(
cmp_result, self._llvm.INT32_TYPE, "booltmp"
)
if self._llvm.FLOAT_TYPE in (llvm_lhs.type, llvm_rhs.type):
cmp_result = self._llvm.ir_builder.fcmp_ordered(
">", llvm_lhs, llvm_rhs, "gttmp"
)
result = self._llvm.ir_builder.uitofp(
cmp_result, self._llvm.FLOAT_TYPE, "booltmp"
)
else:
# be careful we havn't handled all the conditions
cmp_result = self._llvm.ir_builder.cmp_unordered(
">", llvm_lhs, llvm_rhs, "gttmp"
)
result = self._llvm.ir_builder.uitofp(
cmp_result, self._llvm.INT32_TYPE, "booltmp"
)
self.result_stack.append(result)
return

elif expr.op_code == "/":
# Check the datatype to decide between floating-point and integer
# division
Expand Down Expand Up @@ -604,6 +672,12 @@ def visit(self, expr: astx.LiteralInt32) -> None:
result = ir.Constant(self._llvm.INT32_TYPE, expr.value)
self.result_stack.append(result)

@dispatch # type: ignore[no-redef]
def visit(self, expr: astx.LiteralFloat32) -> None:
"""Translate ASTx LiteralFloat32 to LLVM-IR."""
result = ir.Constant(self._llvm.FLOAT_TYPE, expr.value)
self.result_stack.append(result)

@dispatch # type: ignore[no-redef]
def visit(self, expr: astx.FunctionCall) -> None:
"""Translate Function FunctionCall."""
Expand Down Expand Up @@ -658,13 +732,28 @@ def visit(self, expr: astx.Function) -> None:
@dispatch # type: ignore[no-redef]
def visit(self, expr: astx.FunctionPrototype) -> None:
"""Translate ASTx Function Prototype to LLVM-IR."""
args_type = [self._llvm.INT32_TYPE] * len(expr.args.nodes)
# note: it should be dynamic
return_type = self._llvm.get_data_type("int32")
args_type = []
for arg in expr.args.nodes:
if isinstance(arg.type_, astx.Float32):
args_type.append(self._llvm.FLOAT_TYPE)
elif isinstance(arg.type_, astx.Int32):
args_type.append(self._llvm.INT32_TYPE)
else:
raise Exception("Unsupported data type")
if isinstance(expr.return_type, astx.Float32):
return_type = self._llvm.FLOAT_TYPE
elif isinstance(expr.return_type, astx.Float64):
return_type = self._llvm.DOUBLE_TYPE
elif isinstance(expr.return_type, astx.Int32):
return_type = self._llvm.INT32_TYPE
elif isinstance(expr.return_type, astx.Int64):
return_type = self._llvm.INT64_TYPE
elif isinstance(expr.return_type, astx.Void):
return_type = self._llvm.VOID_TYPE
else:
raise Exception(f"Unsupported return type: {expr.return_type}")
fn_type = ir.FunctionType(return_type, args_type, False)

fn = ir.Function(self._llvm.module, fn_type, expr.name)

# Set names for all arguments.
for idx, arg in enumerate(fn.args):
arg.name = expr.args[idx].name
Expand All @@ -689,6 +778,7 @@ def visit(self, expr: astx.FunctionReturn) -> None:
@dispatch # type: ignore[no-redef]
def visit(self, expr: astx.InlineVariableDeclaration) -> None:
"""Translate an ASTx InlineVariableDeclaration expression."""
type = expr.type_
if self.named_values.get(expr.name):
raise Exception(f"Variable already declared: {expr.name}")

Expand All @@ -698,11 +788,24 @@ def visit(self, expr: astx.InlineVariableDeclaration) -> None:
init_val = self.result_stack.pop()
if init_val is None:
raise Exception("Initializer code generation failed.")
else:
# If not specified, use 0 as the initializer.
# note: it should create something according to the defined type
elif isinstance(type, astx.Int32):
init_val = ir.Constant(self._llvm.get_data_type("int32"), 0)
elif isinstance(type, astx.Float32):
init_val = ir.Constant(self._llvm.get_data_type("float32"), 0.0)
else:
raise Exception("Unsupported type")

alloca = self.create_entry_block_alloca(expr.name, "int32")
if isinstance(type, astx.Int32):
alloca = self.create_entry_block_alloca(expr.name, "int32")
elif isinstance(type, astx.Float32):
alloca = self.create_entry_block_alloca(expr.name, "float32")
else:
raise Exception("Unsupported type")
# Store the initial value.
self._llvm.ir_builder.store(init_val, alloca)
# Remember this binding.
self.named_values[expr.name] = alloca

self.result_stack.append(init_val)
Expand All @@ -721,6 +824,7 @@ def visit(self, expr: astx.Variable) -> None:
@dispatch # type: ignore[no-redef]
def visit(self, expr: astx.VariableDeclaration) -> None:
"""Translate ASTx Variable to LLVM-IR."""
type = expr.type_
if self.named_values.get(expr.name):
raise Exception(f"Variable already declared: {expr.name}")

Expand All @@ -730,15 +834,22 @@ def visit(self, expr: astx.VariableDeclaration) -> None:
init_val = self.result_stack.pop()
if init_val is None:
raise Exception("Initializer code generation failed.")
else:
# If not specified, use 0 as the initializer.
# note: it should create something according to the defined type
# If not specified, use 0 as the initializer.
# note: it should create something according to the defined type
elif isinstance(type, astx.Int32):
init_val = ir.Constant(self._llvm.get_data_type("int32"), 0)

elif isinstance(type, astx.Float32):
init_val = ir.Constant(self._llvm.get_data_type("float32"), 0.0)
else:
raise Exception("Unsupported type")
# Create an alloca in the entry block.
# note: it should create the type according to the defined type
alloca = self.create_entry_block_alloca(expr.name, "int32")

if isinstance(type, astx.Int32):
alloca = self.create_entry_block_alloca(expr.name, "int32")
elif isinstance(type, astx.Float32):
alloca = self.create_entry_block_alloca(expr.name, "float32")
else:
raise Exception("Unsupported type")
# Store the initial value.
self._llvm.ir_builder.store(init_val, alloca)

Expand Down
51 changes: 51 additions & 0 deletions tests/test_binary_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,54 @@ def test_binary_op_basic(

module.block.append(main_fn)
check_result(action, builder, module, expected_file)


@pytest.mark.parametrize(
"action,expected_file",
[
("build", ""),
],
)
@pytest.mark.parametrize(
"builder_class",
[
LLVMLiteIR,
],
)
def test_binary_op_float(
action: str, expected_file: str, builder_class: Type[Builder]
) -> None:
"""Test ASTx Module with float operations."""
builder = builder_class()
module = builder.module()

decl_x = astx.VariableDeclaration(
name="x", type_=astx.Float32(), value=astx.LiteralFloat32(1.5)
)
decl_y = astx.VariableDeclaration(
name="y", type_=astx.Float32(), value=astx.LiteralFloat32(2.5)
)
decl_z = astx.VariableDeclaration(
name="z", type_=astx.Float32(), value=astx.LiteralFloat32(4.0)
)

x = astx.Variable("x")
y = astx.Variable("y")
z = astx.Variable("z")

lit_2 = astx.LiteralFloat32(1.0)

basic_op = lit_2 + y - z * z / x + (y - x + z / x)

main_proto = astx.FunctionPrototype(
name="main", args=astx.Arguments(), return_type=astx.Float32()
)
main_block = astx.Block()
main_block.append(decl_x)
main_block.append(decl_y)
main_block.append(decl_z)
main_block.append(astx.FunctionReturn(basic_op))
main_fn = astx.Function(prototype=main_proto, body=main_block)

module.block.append(main_fn)
check_result(action, builder, module, expected_file)
Loading