29 changes: 18 additions & 11 deletions yateto/codegen/common.py
@@ -202,9 +202,9 @@ def __init__(self, kernel: Function, arguments: list[TinytcKernelArgument | Tiny
self.wrapper_args.append(f'{wrapper_type} {arg.name}')
self.wrapper_call_args.append(arg.name)
self.call_args.append(f'const_cast<{wrapper_type}>({arg.call_expr})')
if arg.temporary:
if not arg.constant:
self.wrapper_call_args.append(BatchedOperationsAux.NUM_ELEMENTS_NAME)
elif not arg.constant:
if not arg.temporary and not arg.constant:
offset_name = f'{BatchedOperationsAux.EXTRA_OFFSET_NAME}_{arg.name}'
self.wrapper_args.append(f'long {offset_name}')
self.wrapper_call_args.append(offset_name)
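As a sanity check on the simplified branching: every non-constant argument now receives the batch size, and only non-temporary, non-constant arguments additionally receive a per-tensor extra offset. A toy restatement (function and argument names are illustrative, not part of the patch):

def extra_wrapper_call_args(temporary: bool, constant: bool) -> list[str]:
    # Mirrors the two simplified conditions in the hunk above.
    args = []
    if not constant:
        args.append('num_elements')
    if not temporary and not constant:
        args.append('extra_offset')
    return args

assert extra_wrapper_call_args(temporary=True, constant=False) == ['num_elements']
assert extra_wrapper_call_args(temporary=False, constant=False) == ['num_elements', 'extra_offset']
assert extra_wrapper_call_args(temporary=True, constant=True) == []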
@@ -219,16 +219,23 @@ def definition(self):
"""
make_kernel += self.source
make_kernel += """)tinytc\";
auto source_ctx = tinytc::source_context{};
auto err_log = std::string{};
try {
source_ctx = tinytc::make_source_context();
auto program = tinytc::parse_string(source, source_ctx);
auto bundle = tinytc::make_kernel_bundle(queue.get_context(), queue.get_device(), std::move(program), 0, source_ctx);"""
make_kernel += f' auto kernel = tinytc::make_kernel(bundle, "{self.kernel_name}");\n'
auto ctx = tinytc::create_compiler_context();
tinytc::set_error_reporter(ctx.get(), [](char const *what, const tinytc_location_t *, void *log) {
*static_cast<std::string*>(log) += what;
}, &err_log);
auto program = tinytc::parse_string(source, ctx.get());
auto bundle = tinytc::create_kernel_bundle(queue.get_context(), queue.get_device(), program.get(), 0);"""
make_kernel += f' auto kernel = tinytc::create_kernel(bundle, "{self.kernel_name}");\n'
make_kernel += """ auto group_size = tinytc::get_group_size(kernel);
return {std::move(kernel), std::move(group_size)};
} catch (tinytc::status const& st) {
throw std::runtime_error(source_ctx.get_error_log());
if (!err_log.empty()) {
throw std::runtime_error(err_log);
} else {
throw std::runtime_error(tinytc::to_string(st));
}
}
}""";
make_kernel += f'(*static_cast<::sycl::queue*>({BatchedOperationsAux.STREAM_PTR_NAME}));\n'
@@ -237,7 +244,7 @@ def definition(self):
wrapper += make_kernel
wrapper += f' static_cast<::sycl::queue*>({BatchedOperationsAux.STREAM_PTR_NAME})->submit([&](::sycl::handler &h) {{\n';
wrapper += f' h.set_args({", ".join(self.wrapper_call_args)});\n'
wrapper += f' h.parallel_for(::sycl::nd_range{{tinytc::get_global_size({BatchedOperationsAux.NUM_ELEMENTS_NAME}, k.group_size), k.group_size}}, k.kernel);\n'
wrapper += f' h.parallel_for(::sycl::nd_range{{tinytc::get_global_size({{1,1,static_cast<std::size_t>({BatchedOperationsAux.NUM_ELEMENTS_NAME})}}, k.group_size), k.group_size}}, k.kernel);\n'
wrapper += ' });\n'
wrapper += '}\n\n'

@@ -249,13 +256,13 @@ def call(self):
def prototype(self):
return f'void {self.name}({", ".join(self.wrapper_args)});'

def makeMemrefType(scalarTy, memoryLayout, needsBatchMode: bool):
def makeMemrefType(scalarTy, memoryLayout, needsBatchMode: bool, local: bool=False):
shape = tuple(r.size() for r in memoryLayout.bbox())
stride = memoryLayout.stride()
if needsBatchMode:
shape = shape + (DYNAMIC, )
stride = stride + (memoryLayout.requiredReals(), )
return MemrefType(scalarTy, shape, stride)
return MemrefType(scalarTy, shape, stride, local)

def makeBatchType(scalarTy, memoryLayout, isComputeConstant: bool, isTemporary: bool):
if isComputeConstant:
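The new local flag threads straight through makeMemrefType into MemrefType and ultimately controls the ", local" suffix printed by the dumper (see tiny_tensor_language.py below). A self-contained toy version of that plumbing (a simplified stand-in class, not yateto's real MemrefType):

class ToyMemrefType:
    # Stand-in mirroring MemrefType plus the dump format from
    # visit_MemrefType; the scalar type is passed as a plain string here.
    def __init__(self, ty, shape, stride, local=False):
        self.ty, self.shape, self.stride, self.local = ty, shape, stride, local

    def __str__(self):
        shape = ''.join(f'x{s}' for s in self.shape)
        stride = ','.join(str(s) for s in self.stride)
        suffix = ', local' if self.local else ''
        return f'memref<{self.ty}{shape}, strided<{stride}>{suffix}>'

print(ToyMemrefType('f32', (8, 8), (1, 8), local=True))
# memref<f32x8x8, strided<1,8>, local>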
7 changes: 5 additions & 2 deletions yateto/codegen/copyscaleadd/tinytc.py
@@ -5,6 +5,7 @@
from ..cache import TinytcWriter
from ..tiny_tensor_language import *

import hashlib

class CopyScaleAddTinytc:

@@ -21,7 +22,6 @@ def generate(self, cpp, routineCache):

# Order can be 1 or 2
def MakeLoopOverAxpby(d, order, transpose, A, B):
beta = FloatImmValue(self._ty, d.beta)
A_offset_list = [None] * len(d.term.indices)
A_size_list = [None] * len(d.term.indices)
B_offset_list = [None] * len(d.result.indices)
@@ -49,6 +49,7 @@ def MakeLoopOverAxpby(d, order, transpose, A, B):
csa_bb = RegionBuilder()
a = csa_bb.add(SubviewInst(A, A_offset_list, A_size_list))
b = csa_bb.add(SubviewInst(B, B_offset_list, B_size_list))
beta = csa_bb.add(ConstantInst(FloatImmValue(self._ty, d.beta)))
csa_bb.add(AxpbyInst(trans, alpha, a, beta, b))
csa_region = csa_bb.get_product()

@@ -70,7 +71,7 @@ def MakeLoopOverAxpby(d, order, transpose, A, B):
makeBatchType(self._ty, d.result.memoryLayout,
d.result.is_compute_constant, d.result.is_temporary),
'B')
kernel = Function('copyscaleadd', [alpha, Abatch, Bbatch], None)
kernel = Function(None, [alpha, Abatch, Bbatch], None)

bb = RegionBuilder()
gid = bb.add(GroupIdInst())
@@ -93,6 +94,8 @@ def MakeLoopOverAxpby(d, order, transpose, A, B):

kernel.body = bb.get_product()
AssignIdentifiers().visit(kernel)
hash_ = hashlib.sha256(Dump().visit(kernel.body).encode()).hexdigest()
kernel.name = f'copyscaleadd_{hash_}'

args = [
TinytcScalarKernelArgument('alpha', str(d.alpha)),
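The content-addressed kernel name comes from hashing the dumped kernel body, so structurally identical kernels collapse to the same name and distinct kernels cannot collide. A minimal sketch of the scheme (the prefix and body text are illustrative):

import hashlib

def unique_kernel_name(prefix: str, dumped_body: str) -> str:
    # Same body => same digest => same name; no global counter needed.
    digest = hashlib.sha256(dumped_body.encode()).hexdigest()
    return f'{prefix}_{digest}'

print(unique_kernel_name('copyscaleadd', 'axpby.n %alpha, %a, %beta, %b'))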
10 changes: 6 additions & 4 deletions yateto/codegen/fused_gemms/tinytc.py
@@ -59,7 +59,7 @@ def addVal(var, node):
if res.is_temporary:
res_val = bb.add(
AllocaInst(
makeMemrefType(self._ty, res.memoryLayout(), False)))
makeMemrefType(self._ty, res.memoryLayout(), False, True)))
vals[res] = res_val
else:
modified.add(res)
@@ -87,6 +87,7 @@ def offsetSizeLists(ml, range0, range1):
return ([IntImmValue(IntegerType.index, o) for o in offsets],
[IntImmValue(IntegerType.index, s) for s in sizes])

alpha = bb.add(ConstantInst(FloatImmValue(self._ty, scalar)))
op1_sub = bb.add(
SubviewInst(
op1_val,
@@ -95,20 +96,21 @@ def offsetSizeLists(ml, range0, range1):
SubviewInst(
op2_val,
*offsetSizeLists(node.rightTerm().memoryLayout(), k, n)))
beta = bb.add(ConstantInst(FloatImmValue(self._ty, 1.0 if add else 0.0)))
res_sub = bb.add(
SubviewInst(res_val,
*offsetSizeLists(node.memoryLayout(), m, n)))

trans = lambda t: Transpose.t if t else Transpose.n
alpha = FloatImmValue(self._ty, scalar)
beta = FloatImmValue(self._ty, 1.0 if add else 0.0)
bb.add(
GemmInst(trans(node.transA()), trans(node.transB()), alpha,
op1_sub, op2_sub, beta, res_sub))

flops += 2 * m.size() * n.size() * k.size()

kernel = Function('fused_gemm', args.values(), bb.get_product())
ast = bb.get_product()
hash_ = hashlib.sha256(Dump().visit(ast).encode()).hexdigest()
kernel = Function(f'fused_gemm_{hash_}', args.values(), ast)
AssignIdentifiers().visit(kernel)

wrapper_args = []
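Here, as in copyscaleadd and gemmgen, alpha and beta stop being raw immediates on the gemm/axpby instruction and are first materialized as constant instructions with SSA results. A toy model of the pattern (a simplified region builder, not yateto's real API):

import itertools

_names = itertools.count()
region = []

def constant(value, ty='f32'):
    # Materialize an immediate as an explicit instruction yielding an SSA value.
    name = f'%{next(_names)}'
    region.append(f'{name} = constant {value} : {ty}')
    return name

alpha = constant(1.0)
beta = constant(0.0)
region.append(f'gemm.n.n {alpha}, %A, %B, {beta}, %C')
print('\n'.join(region))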
3 changes: 2 additions & 1 deletion yateto/codegen/gemm/gemmgen.py
@@ -617,7 +617,8 @@ def load_inst(op, batch, gid):
A = bb.add(load_inst(opA, Abatch, gid))
B = bb.add(load_inst(opB, Bbatch, gid))
C = bb.add(load_inst(opC, Cbatch, gid))
bb.add(GemmInst(tA, tB, alpha, A, B, FloatImmValue(scalar_ty, beta), C))
beta = bb.add(ConstantInst(FloatImmValue(scalar_ty, beta)))
bb.add(GemmInst(tA, tB, alpha, A, B, beta, C))
kernel.body = bb.get_product()
AssignIdentifiers().visit(kernel)

47 changes: 32 additions & 15 deletions yateto/codegen/tiny_tensor_language.py
@@ -43,10 +43,11 @@ def __init__(self, ty: Enum):

class MemrefType(DataType):

def __init__(self, ty: ScalarType, shape: tuple[int], stride: tuple[int]):
def __init__(self, ty: ScalarType, shape: tuple[int], stride: tuple[int], local: bool=False):
self.ty = ty
self.shape = shape
self.stride = stride
self.local = local

def order(self):
return len(self.shape)
@@ -175,6 +176,16 @@ def value(self):
return self.result


class ConstantInst(Inst):

def __init__(self, a: Value):
self.a = a
self.result = LocalValue(a.type())

def value(self):
return self.result


class GemmInst(Inst):

def __init__(self,
@@ -273,7 +284,7 @@ def __init__(self, operand: Value, offset_list: list[Value],
self.operand = operand
self.offset_list = offset_list
self.size_list = size_list
self.result = LocalValue(MemrefType(operand.type().ty, shape, stride))
self.result = LocalValue(MemrefType(operand.type().ty, shape, stride, operand.type().local))

def value(self):
return self.result
@@ -333,6 +344,10 @@ def visit_ArithInst(self, node):
self.visit(node.a)
self.visit(node.b)

def visit_ConstantInst(self, node):
self.visit(node.result)
self.visit(node.a)

def visit_LocalValue(self, node):
self.visit(node.type())

@@ -413,10 +428,11 @@ def visit_MemrefType(self, node):
for s in node.shape:
shape_str += f'x{format_mode(s)}'
stride_str = ','.join(format_mode(s) for s in node.stride)
return f'memref<{self.visit(node.ty)}{shape_str}, strided<{stride_str}>>'
local = ', local' if node.local else ''
return f'memref<{self.visit(node.ty)}{shape_str}, strided<{stride_str}>{local}>'

def visit_GroupType(self, node):
return f'group<{self.visit(node.ty)}, offset: {format_mode(node.offset)}>'
return f'group<{self.visit(node.ty)}x?, offset: {format_mode(node.offset)}>'

def visit_IntImmValue(self, node):
return f'{format_mode(node.value)}'
@@ -428,43 +444,44 @@ def visit_LocalValue(self, node):
return f'%{node.name}'

def visit_AllocaInst(self, node):
return f'{self.visit(node.value())} = alloca -> {self.visit(node.value().type())}'
return f'{self.visit(node.value())} = alloca : {self.visit(node.value().type())}'

def visit_AxpbyInst(self, node):
opcode = f'axpby.{node.trans.name}'
if node.atomic:
opcode += '.atomic'
args = (node.alpha, node.a, node.beta, node.b)
args_str = ', '.join(self.visit(arg) for arg in args)
type_str = ', '.join(self.visit(arg.type()) for arg in args)
return f'{opcode} {args_str} : {type_str}'
return f'{opcode} {args_str}'

def visit_ArithInst(self, node):
return f'{self.visit(node.value())} = arith.{node.operation_type.name} {self.visit(node.a)}, {self.visit(node.b)} : {self.visit(node.a.type())}'
return f'{self.visit(node.value())} = arith.{node.operation_type.name} {self.visit(node.a)}, {self.visit(node.b)} : {self.visit(node.value().type())}'

def visit_ConstantInst(self, node):
return f'{self.visit(node.value())} = constant {self.visit(node.a)} : {self.visit(node.value().type())}'

def visit_GemmInst(self, node):
opcode = f'gemm.{node.transA.name}.{node.transB.name}'
if node.atomic:
opcode += '.atomic'
args = (node.alpha, node.a, node.b, node.beta, node.c)
args_str = ', '.join(self.visit(arg) for arg in args)
type_str = ', '.join(self.visit(arg.type()) for arg in args)
return f'{opcode} {args_str} : {type_str}'
return f'{opcode} {args_str}'

def visit_GroupIdInst(self, node):
return f'{self.visit(node.value())} = group_id'
return f'{self.visit(node.value())} = group_id.x : index'

def visit_LoadInst(self, node):
indices = ','.join(self.visit(index) for index in node.index_list)
return f'{self.visit(node.value())} = load {self.visit(node.operand)}[{indices}] : {self.visit(node.operand.type())}'
return f'{self.visit(node.value())} = load {self.visit(node.operand)}[{indices}] : {self.visit(node.value().type())}'

def visit_ForInst(self, node):
loop_range = f'{self.visit(node.loop_var)}={self.visit(node.start)},{self.visit(node.stop)}'
return f'for {loop_range} : {self.visit(node.loop_var.type())} {self.visit(node.body)}'
return f'for {loop_range} {self.visit(node.body)}'

def visit_StoreInst(self, node):
indices = ','.join(self.visit(index) for index in node.index_list)
return f'store {self.visit(node.data)}, {self.visit(node.operand)}[{indices}] : {self.visit(node.operand.type())}'
return f'store {self.visit(node.data)}, {self.visit(node.operand)}[{indices}]'

def visit_SubviewInst(self, node):
slice_list = []
@@ -473,7 +490,7 @@ def visit_SubviewInst(self, node):
slice_list.append(f'{self.visit(offset)}:{self.visit(size)}')
else:
slice_list.append(f'{self.visit(offset)}')
return f'{self.visit(node.value())} = subview {self.visit(node.operand)}[{",".join(slice_list)}] : {self.visit(node.operand.type())}'
return f'{self.visit(node.value())} = subview {self.visit(node.operand)}[{",".join(slice_list)}] : {self.visit(node.value().type())}'

def visit_Region(self, node):
self.level += 1
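Taken together, the dumper changes annotate results rather than operands and drop type suffixes where the instruction is self-describing. A hand-written example of the updated textual form (value names, shapes, and the float formatting are illustrative; visit_FloatImmValue is not shown in this diff):

%0 = alloca : memref<f32x8x8, strided<1,8>, local>
%1 = group_id.x : index
%2 = load %A[%1] : memref<f32x8x8, strided<1,8>>
%3 = load %B[%1] : memref<f32x8x8, strided<1,8>>
%4 = constant 1.0 : f32
%5 = constant 0.0 : f32
gemm.n.n %4, %2, %3, %5, %0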