diff --git a/aten/src/ATen/native/mps/operations/BinaryKernel.mm b/aten/src/ATen/native/mps/operations/BinaryKernel.mm
index 94c3b3e59af9..892c8e115c4e 100644
--- a/aten/src/ATen/native/mps/operations/BinaryKernel.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryKernel.mm
@@ -16,12 +16,18 @@
   Strided_Tensor
 };
 
+struct BinaryKernelOpInfo {
+  std::string op_name;
+  std::string metal_operator;
+  bool canVectorize;
+};
+
 static char* BINARY_OP_TEMPLATE_TENSOR = R"METAL_BINARY(
 kernel void {3}_kernel(uint tid [[thread_position_in_grid]],
                        const device {1} * input [[buffer(0)]],
                        const device {2} * other [[buffer(1)]],
                        device {0} * output [[buffer(2)]]) {{
-  output[tid] = ({5})input[tid] {4} ({5})other[tid];
+  output[tid] = ({5})input[tid] {4} ({6})other[tid];
 }}
 )METAL_BINARY";
@@ -40,7 +46,7 @@
   const device {1}* input = (const device {1}*)((const device uint8_t*)input_ + offsets.y);
   const device {2}* other = (const device {2}*)((const device uint8_t*)other_ + offsets.z);
 
-  *output = ({5})*input {4} ({5})*other;
+  *output = ({5})*input {4} ({6})*other;
 }}
 )METAL_BINARY";
@@ -49,7 +55,7 @@
                        const device {1} & input [[buffer(0)]],
                        const device {2} * other [[buffer(1)]],
                        device {0} * output [[buffer(2)]]) {{
-  output[tid] = ({5})input {4} ({5})other[tid];
+  output[tid] = ({5})input {4} ({6})other[tid];
 }}
 )METAL_BINARY";
@@ -58,7 +64,7 @@
                        const device {1} * input [[buffer(0)]],
                        const device {2} & other [[buffer(1)]],
                        device {0} * output [[buffer(2)]]) {{
-  output[tid] = ({5})input[tid] {4} ({5})other;
+  output[tid] = ({5})input[tid] {4} ({6})other;
 }}
 )METAL_BINARY";
@@ -67,7 +73,7 @@
                        const device {1} & input [[buffer(0)]],
                        const device {2} & other [[buffer(1)]],
                        device {0} & output [[buffer(2)]]) {{
-  output = ({5})input {4} ({5})other;
+  output = ({5})input {4} ({6})other;
 }}
 )METAL_BINARY";
@@ -85,7 +91,7 @@
   device {0}* output = (device {0}*)((device uint8_t*)output_ + offsets.x);
   const device {1}* input = (const device {1}*)((const device uint8_t*)input_ + offsets.y);
 
-  *output = ({5})*input {4} ({5})other;
+  *output = ({5})*input {4} ({6})other;
 }}
 )METAL_BINARY";
@@ -103,19 +109,28 @@
   device {0}* output = (device {0}*)((device uint8_t*)output_ + offsets.x);
   const device {2}* other = (const device {2}*)((const device uint8_t*)other_ + offsets.z);
 
-  *output = ({5})input {4} ({5})*other;
+  *output = ({5})input {4} ({6})*other;
 }}
 )METAL_BINARY";
 
+static uint8_t getVectorType(int64_t input_numel) {
+  if (input_numel % 4 == 0) return 4;
+  if (input_numel % 3 == 0) return 3;
+  if (input_numel % 2 == 0) return 2;
+
+  return 0;
+}
+
 static id<MTLLibrary> compileBinaryOpsLibrary(id<MTLDevice> device,
                                               const std::string& t1,
                                               const std::string& t2,
                                               const std::string& t3,
-                                              const std::string& common_dtype,
+                                              const std::string& cast_dtype_input,
+                                              const std::string& cast_dtype_other,
                                               const std::string& op,
                                               const std::string& kernel_operator,
                                               BinaryKernelType binaryKernelType) {
-  auto key = op + t1 + t2 + t3 + common_dtype + std::to_string(int(binaryKernelType));
+  auto key = op + t1 + t2 + t3 + cast_dtype_input + std::to_string(int(binaryKernelType));
   static std::unordered_map<std::string, id<MTLLibrary>> libMap;
   auto it = libMap.find(key);
   if (it != libMap.end()) {
     return it->second;
   }
@@ -158,7 +173,7 @@
       TORCH_CHECK(false, "Unknown binary template");
   }
 
-  auto rc = [device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(str, t1, t2, t3, op, kernel_operator, common_dtype).c_str()]
+  auto rc = [device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(str, t1, t2, t3, op, kernel_operator, cast_dtype_input, cast_dtype_other).c_str()]
                                  options:options
                                    error:&error];
   TORCH_CHECK(rc != nil && error == nil, "Failed to compile library: ", [[error localizedDescription] UTF8String]);
@@ -170,19 +185,20 @@
                                                  const std::string& t1,
                                                  const std::string& t2,
                                                  const std::string& t3,
-                                                 const std::string& common_dtype,
+                                                 const std::string& cast_dtype_input,
+                                                 const std::string& cast_dtype_other,
                                                  const std::string& fname,
                                                  const std::string& op,
                                                  const std::string& kernel_operator,
                                                  BinaryKernelType binaryKernelType) {
-  auto key = t1 + t2 + t3 + common_dtype + fname;
+  auto key = t1 + t2 + t3 + cast_dtype_input + fname;
   static std::unordered_map<std::string, id<MTLComputePipelineState>> cplMap;
   auto it = cplMap.find(key);
   if (it != cplMap.end()) {
     return it->second;
   }
   NSError *error = nil;
-  auto library = compileBinaryOpsLibrary(device, t1, t2, t3, common_dtype, op, kernel_operator, binaryKernelType);
+  auto library = compileBinaryOpsLibrary(device, t1, t2, t3, cast_dtype_input, cast_dtype_other, op, kernel_operator, binaryKernelType);
   id<MTLFunction> func = [library newFunctionWithName:[NSString stringWithUTF8String:fname.c_str()]];
   TORCH_CHECK(func != nil, "Can't get function ", fname);
   auto rc = [device newComputePipelineStateWithFunction:func error:&error];
@@ -192,11 +208,14 @@
 }
 
 static
-void dispatch_binary_kernel_mps_(TensorIteratorBase& iter, const std::string& op, const std::string& kernel_operator) {
+void dispatch_binary_kernel_mps_(TensorIteratorBase& iter, const BinaryKernelOpInfo& binaryKernelOpInfo) {
   Tensor inputTensor;
   Tensor otherTensor;
   BinaryKernelType type;
 
+  const std::string& op = binaryKernelOpInfo.op_name;
+  const std::string& kernel_operator = binaryKernelOpInfo.metal_operator;
+
   int scalar_pos = 0;
   bool all_scalar = false;
   const Tensor& outputTensor = iter.tensor(0);
@@ -264,44 +283,80 @@ void dispatch_binary_kernel_mps_(TensorIteratorBase& iter, const std::string& op
   MPSStream* mpsStream = getCurrentMPSStream();
   id<MTLDevice> device = MPSDevice::getInstance()->device();
 
+  std::string outputStringType = getMetalScalarType(outputDataType);
+  std::string inputStringType = getMetalScalarType(inputDataType);
+  std::string otherStringType = getMetalScalarType(otherDataType);
+  std::string inputCastType = getMetalScalarType(common_dtype);
+  std::string otherCastType = getMetalScalarType(common_dtype);
+
   id<MTLBuffer> inputBuffer = mps::getMTLBufferStorage(inputTensor);
   id<MTLBuffer> otherBuffer = mps::getMTLBufferStorage(otherTensor);
   id<MTLBuffer> outputBuffer = mps::getMTLBufferStorage(outputTensor);
   uint32_t inputTensorStorage = inputTensor.storage_offset() * inputTensor.element_size();
   uint32_t otherTensorStorage = otherTensor.storage_offset() * otherTensor.element_size();
   mps::MPSScalar scalar;
 
+  uint32_t numThreads = iter.numel();
+
   if (all_scalar) {
     type = BinaryKernelType::Scalar;
-    if (iter.is_cpu_scalar(1)) {
-      scalar = mps::getMPSScalar(inputTensor.item(), inputTensor.scalar_type());
-      inputBuffer = (id<MTLBuffer>)getIMPSAllocator()->allocScalarBufferWithValue(&scalar.value, scalar.size).get();
-      inputTensorStorage = 0;
-    }
-    if (iter.is_cpu_scalar(2)) {
-      scalar = mps::getMPSScalar(otherTensor.item(), otherTensor.scalar_type());
-      otherBuffer = (id<MTLBuffer>)getIMPSAllocator()->allocScalarBufferWithValue(&scalar.value, scalar.size).get();
-      otherTensorStorage = 0;
-    }
   } else if (scalar_pos) {
     if (allContiguous) {
-      type = scalar_pos == 1 ? BinaryKernelType::LHS_Scalar : BinaryKernelType::RHS_Scalar;
-    } else {
-      type = scalar_pos == 1 ? BinaryKernelType::Strided_LHS_Scalar : BinaryKernelType::Strided_RHS_Scalar;
-    }
-
-    if (iter.is_cpu_scalar(scalar_pos)) {
+      uint8_t vecType = 0;
+      std::string vecStringType = "";
       if (scalar_pos == 1) {
-        scalar = mps::getMPSScalar(inputTensor.item(), inputTensor.scalar_type());
-        inputBuffer = (id<MTLBuffer>)getIMPSAllocator()->allocScalarBufferWithValue(&scalar.value, scalar.size).get();
-        inputTensorStorage = 0;
+        type = BinaryKernelType::LHS_Scalar;
+        if (binaryKernelOpInfo.canVectorize) {
+          vecType = getVectorType(otherTensor.numel());
+          vecStringType = vecType >= 2 ? std::to_string(vecType) : "";
+          otherStringType += vecStringType;
+          otherCastType += vecStringType;
+        }
       } else {
-        scalar = mps::getMPSScalar(otherTensor.item(), otherTensor.scalar_type());
-        otherBuffer = (id<MTLBuffer>)getIMPSAllocator()->allocScalarBufferWithValue(&scalar.value, scalar.size).get();
-        otherTensorStorage = 0;
+        type = BinaryKernelType::RHS_Scalar;
+        if (binaryKernelOpInfo.canVectorize) {
+          vecType = getVectorType(inputTensor.numel());
+          vecStringType = vecType >= 2 ? std::to_string(vecType) : "";
+          inputStringType += vecStringType;
+          inputCastType += vecStringType;
+        }
       }
+      if (vecType >= 2) {
+        numThreads /= vecType;
+        outputStringType += vecStringType;
+      }
+    } else {
+      type = scalar_pos == 1 ? BinaryKernelType::Strided_LHS_Scalar : BinaryKernelType::Strided_RHS_Scalar;
     }
   } else {
-    type = allContiguous ? BinaryKernelType::Tensor : BinaryKernelType::Strided_Tensor;
+    if (allContiguous) {
+      type = BinaryKernelType::Tensor;
+      if (binaryKernelOpInfo.canVectorize) {
+        uint8_t inputVecType = getVectorType(inputTensor.numel());
+        uint8_t otherVecType = getVectorType(otherTensor.numel());
+        if (inputVecType >= 2 && inputVecType == otherVecType) {
+          std::string vecType = std::to_string(inputVecType);
+          inputStringType += vecType;
+          inputCastType += vecType;
+          otherStringType += vecType;
+          otherCastType += vecType;
+          outputStringType += vecType;
+          numThreads /= inputVecType;
+        }
+      }
+    } else {
+      type = BinaryKernelType::Strided_Tensor;
+    }
+  }
+
+  if (iter.is_cpu_scalar(1)) {
+    scalar = mps::getMPSScalar(inputTensor.item(), inputTensor.scalar_type());
+    inputBuffer = (id<MTLBuffer>)getIMPSAllocator()->allocScalarBufferWithValue(&scalar.value, scalar.size).get();
+    inputTensorStorage = 0;
+  }
+  if (iter.is_cpu_scalar(2)) {
+    scalar = mps::getMPSScalar(otherTensor.item(), otherTensor.scalar_type());
+    otherBuffer = (id<MTLBuffer>)getIMPSAllocator()->allocScalarBufferWithValue(&scalar.value, scalar.size).get();
+    otherTensorStorage = 0;
   }
 
   const uint32_t nDim = iter.ndim();
@@ -309,7 +364,6 @@ void dispatch_binary_kernel_mps_(TensorIteratorBase& iter, const std::string& op
 
   dispatch_sync(mpsStream->queue(), ^(){
     @autoreleasepool {
-      uint32_t numThreads = iter.numel();
       id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
       MTLSize gridSize = MTLSizeMake(numThreads, 1, 1);
       const IntArrayRef& iterShape = iter.shape();
@@ -347,14 +401,15 @@ void dispatch_binary_kernel_mps_(TensorIteratorBase& iter, const std::string& op
       }
 
       id<MTLComputePipelineState> binaryPSO = mps::getBinaryPSO(device,
-                                                                getMetalScalarType(outputDataType),
-                                                                getMetalScalarType(inputDataType),
-                                                                getMetalScalarType(otherDataType),
-                                                                getMetalScalarType(common_dtype),
-                                                                kernel,
-                                                                op,
-                                                                kernel_operator,
-                                                                type);
+                                                                outputStringType,
+                                                                inputStringType,
+                                                                otherStringType,
+                                                                inputCastType,
+                                                                otherCastType,
+                                                                kernel,
+                                                                op,
+                                                                kernel_operator,
+                                                                type);
       getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {inputTensor, otherTensor, outputTensor});
       [computeEncoder setComputePipelineState:binaryPSO];
       [computeEncoder setBuffer:inputBuffer offset:inputTensorStorage atIndex:0];
@@ -380,38 +435,39 @@ void dispatch_binary_kernel_mps_(TensorIteratorBase& iter, const std::string& op
 }
 
 static
-void dispatch_binary_kernel_mps(const Tensor& self, const Tensor& other, const Tensor& output, const std::string& op, const std::string& kernel_operator) {
+void dispatch_binary_kernel_mps(const Tensor& self, const Tensor& other, const Tensor& output, const BinaryKernelOpInfo& binaryKernelOpInfo) {
   TensorIterator iter;
 
+  const std::string& op = binaryKernelOpInfo.op_name;
   if (op == "lt" || op == "le" || op == "gt" || op == "ge" || op == "ne" || op == "logical_or" || op == "logical_and" || op == "eq") {
     iter = TensorIterator::comparison_op(const_cast<Tensor&>(output), self, other);
   } else {
     iter = TensorIterator::borrowing_binary_op(output, self, other);
  }
-  dispatch_binary_kernel_mps_(iter, op, kernel_operator);
+  dispatch_binary_kernel_mps_(iter, binaryKernelOpInfo);
 }
 
-bool getBinaryKernelOperator(const std::string& op_name, std::pair<std::string, std::string>& kernel_operator) {
+bool getBinaryKernelOpInfo(const std::string& op_name, BinaryKernelOpInfo& binaryKernelOpInfo) {
   static bool macOS13_0_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_0_PLUS);
   if (!macOS13_0_plus) {
     return false;
   }
 
-  static std::unordered_map<std::string, std::pair<std::string, std::string>> opToKernelOperator = {
-    {"multiplication",       {"mul", "*" }},
-    {"div_out_mps:",         {"div", "/" }},
-    {"add_out_mps:",         {"add", "+" }},
-    {"sub_out_mps:",         {"sub", "-" }},
+  static std::unordered_map<std::string, BinaryKernelOpInfo> opToKernelOperator = {
+    {"multiplication",       {"mul", "*", true}},
+    {"div_out_mps:",         {"div", "/", true}},
+    {"add_out_mps:",         {"add", "+", true}},
+    {"sub_out_mps:",         {"sub", "-", true}},
     // comparison ops
-    {"lessThan",             {"lt", "<" }},
-    {"lessThanOrEqualTo",    {"le", "<="}},
-    {"greaterThan",          {"gt", ">" }},
-    {"greaterThanOrEqualTo", {"ge", ">="}},
-    {"notEqual",             {"ne", "!="}},
-    {"logicalOR",            {"logical_or", "||"}},
-    {"logicalAND",           {"logical_and", "&&"}},
-    {"equal",                {"eq", "=="}},
+    {"lessThan",             {"lt", "<" , true}},
+    {"lessThanOrEqualTo",    {"le", "<=", true}},
+    {"greaterThan",          {"gt", ">" , true}},
+    {"greaterThanOrEqualTo", {"ge", ">=", true}},
+    {"notEqual",             {"ne", "!=", true}},
+    {"logicalOR",            {"logical_or", "||", false}},
+    {"logicalAND",           {"logical_and", "&&", false}},
+    {"equal",                {"eq", "==", true}},
   };
 
   auto it = opToKernelOperator.find(op_name);
@@ -419,7 +475,7 @@ bool getBinaryKernelOperator(const std::string& op_name, std::pair<std::string, std::string>& kernel_operator) {
   if (it == opToKernelOperator.end()) {
     return false;
   }
 
-  kernel_operator = it->second;
+  binaryKernelOpInfo = it->second;
   return true;
 }
@@ -430,8 +486,9 @@ bool dispatchNativeBinaryKernel(const Tensor& self,
                                 const std::string& op_name) {
   if (alpha.toFloat() == 1.0) {
     std::pair<std::string, std::string> kernel_operator;
-    if (getBinaryKernelOperator(op_name, kernel_operator)) {
-      dispatch_binary_kernel_mps(self, other, output, kernel_operator.first, kernel_operator.second);
+    BinaryKernelOpInfo binaryKernelOpInfo;
+    if (getBinaryKernelOpInfo(op_name, binaryKernelOpInfo)) {
+      dispatch_binary_kernel_mps(self, other, output, binaryKernelOpInfo);
       return true;
     }
   }
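
Note: the vector-width rule introduced by this patch can be exercised in isolation. The sketch below is not part of the patch; it is a standalone C++ illustration (hypothetical names, assuming a contiguous tensor of 1024 float elements) of how getVectorType picks a width, how the Metal type string gains a suffix such as "float4", and how the dispatched thread count shrinks by the same factor.

#include <cstdint>
#include <cstdio>
#include <string>

// Same rule as getVectorType in the patch: widest of 4/3/2 that divides numel, else 0 (scalar).
static uint8_t vector_width(int64_t numel) {
  if (numel % 4 == 0) return 4;
  if (numel % 3 == 0) return 3;
  if (numel % 2 == 0) return 2;
  return 0;
}

int main() {
  const int64_t numel = 1024;          // hypothetical contiguous float tensor
  std::string metal_type = "float";    // stand-in for what getMetalScalarType returns
  uint32_t num_threads = static_cast<uint32_t>(numel);

  const uint8_t w = vector_width(numel);
  if (w >= 2) {
    metal_type += std::to_string(w);   // "float" -> "float4"
    num_threads /= w;                  // 1024 scalar lanes -> 256 float4 lanes
  }
  std::printf("%s kernel, %u threads\n", metal_type.c_str(), num_threads);  // prints: float4 kernel, 256 threads
  return 0;
}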