diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 804b91705306e..713bc2f8daef0 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -104,6 +104,8 @@
 #include
 #include
 #include
+#include
+#include
 #include
 #include
 #include
@@ -133,6 +135,7 @@
 #include
 #include
 #include
+#include
 
 namespace at {
 
@@ -316,6 +319,87 @@ TORCH_META_FUNC(baddbmm)(const Tensor& self, const Tensor& batch1, const Tensor&
   common_checks_baddbmm_bmm(*this, batch1, batch2, beta, alpha, false, *self_);
 }
 
+/*
+  batch1 and batch2 always have at least 1 dim; a 1-D batch1 acts as a single row, a 1-D batch2 as a single column
+ */
+template <typename Meta>
+void checks_ndmm(Meta& meta, const Tensor& batch1, const Tensor& batch2) {
+  const auto batch1_sizes = batch1.sizes().vec();
+  const auto batch2_sizes = batch2.sizes().vec();
+
+  const auto batch1_dim_len = batch1_sizes.size();
+  const auto batch2_dim_len = batch2_sizes.size();
+
+  const bool is_batch1_vector = (batch1_dim_len == 1);
+  const bool is_batch2_vector = (batch2_dim_len == 1);
+
+  // Make sure the inner (contraction) dimensions match
+  int64_t batch1_innerdim = batch1_sizes.end()[-1];
+  int64_t batch2_innerdim = is_batch2_vector? batch2_sizes.end()[-1] : batch2_sizes.end()[-2]; // If batch2 is a vector, it only has an inner dim
+  TORCH_CHECK(batch1_innerdim == batch2_innerdim,
+              "Expected inner dimension of both tensors to match, but got: ", batch1_innerdim, " and ", batch2_innerdim, ".");
+
+  // Result row and col dims
+  int64_t res_rows = is_batch1_vector? 1 : batch1_sizes.end()[-2]; // If batch1 is a vector, result will have 1 row
+  int64_t res_cols = is_batch2_vector? 1 : batch2_sizes.end()[-1]; // If batch2 is a vector, result will have 1 col
+
+  // Output dims are collected in reverse order and flipped at the end
+  std::vector<int64_t> output_size;
+
+  if(!is_batch2_vector)
+    output_size.push_back(res_cols);
+  if(!is_batch1_vector)
+    output_size.push_back(res_rows);
+
+  // Result batch dims
+  // Batch dims must be equal or broadcastable
+  if(batch1_dim_len >= 3 && batch2_dim_len >= 3) {
+    auto common_batch_dims = std::min(batch1_dim_len, batch2_dim_len) - 2;
+    for(int i = 0; i < common_batch_dims; ++i) {
+      int64_t batch1_curr_batchdim = batch1_sizes.end()[-2-i-1];
+      int64_t batch2_curr_batchdim = batch2_sizes.end()[-2-i-1];
+      TORCH_CHECK((batch2_curr_batchdim == batch1_curr_batchdim) || (batch1_curr_batchdim == 1) || (batch2_curr_batchdim == 1),
+                  "Expected batch dimension of both tensors to match or be broadcastable, but got: ", batch1_curr_batchdim, " and ", batch2_curr_batchdim, ".");
+      output_size.push_back(std::max(batch1_curr_batchdim, batch2_curr_batchdim));
+    }
+    // Append the extra batch dims of the higher-rank tensor to the result
+    auto extra_batch_tensor = (batch1_dim_len >= batch2_dim_len)? batch1_sizes : batch2_sizes;
+    for(int i = common_batch_dims+2; i < std::max(batch1_dim_len, batch2_dim_len); ++i) {
+      output_size.push_back(extra_batch_tensor.end()[-i-1]);
+    }
+  }
+  // Use batch dims of tensor1
+  else if(batch1_dim_len >= 3 && batch2_dim_len < 3) {
+    for(int i = 0; i < batch1_dim_len - 2; ++i) {
+      output_size.push_back(batch1_sizes.end()[-2-i-1]);
+    }
+  }
+  // Use batch dims of tensor2
+  else if(batch1_dim_len < 3 && batch2_dim_len >= 3) {
+    for(int i = 0; i < batch2_dim_len - 2; ++i) {
+      output_size.push_back(batch2_sizes.end()[-2-i-1]);
+    }
+  }
+  std::reverse(output_size.begin(), output_size.end());
+  auto& result = meta.maybe_get_output(0);
+  // 'set_output' does not resize for in-place calls
+  meta.set_output_raw_strided(0, output_size, {}, batch2.options());
+  const auto result_sizes = result.sizes();
+  // Error is raised if called from in-place overload with incorrect shape
+  TORCH_CHECK(result_sizes == output_size,
+              "Expected an output tensor with shape [", output_size, "] but got shape ", result_sizes);
+  std::vector<Dimname> outnames = {};
+  outnames = namedinference::compute_bmm_outnames(result, batch1, batch2);
+  namedinference::propagate_names_if_nonempty(
+      result,
+      outnames
+  );
+}
+
+TORCH_META_FUNC(ndmm)(const Tensor& self, const Tensor& mat2) {
+  checks_ndmm(*this, self, mat2);
+}
+
 } // namespace meta
 
 namespace native {
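For reference, `checks_ndmm` above follows the usual `matmul` broadcasting rule: a 1-D operand acts as a 1xK row (left) or Kx1 column (right), the contraction dims must match, and leading batch dims broadcast when they are equal or 1. A standalone sketch of that shape rule in plain C++ (the `matmul_output_shape` helper below is hypothetical, for illustration only, and not part of this patch):

```cpp
#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Output shape of matmul(a, b) under NumPy-style broadcasting of the batch dims.
std::vector<int64_t> matmul_output_shape(std::vector<int64_t> a, std::vector<int64_t> b) {
  const bool a_vec = (a.size() == 1);
  const bool b_vec = (b.size() == 1);
  if (a_vec) a.insert(a.begin(), 1);   // (K) -> (1, K)
  if (b_vec) b.push_back(1);           // (K) -> (K, 1)
  if (a.back() != b.end()[-2])
    throw std::invalid_argument("inner dimensions do not match");

  // Broadcast the batch dims (everything except the trailing two).
  std::vector<int64_t> out;
  const int64_t a_batch = static_cast<int64_t>(a.size()) - 2;
  const int64_t b_batch = static_cast<int64_t>(b.size()) - 2;
  const int64_t batch = std::max(a_batch, b_batch);
  for (int64_t i = 0; i < batch; ++i) {
    const int64_t da = (i < batch - a_batch) ? 1 : a[i - (batch - a_batch)];
    const int64_t db = (i < batch - b_batch) ? 1 : b[i - (batch - b_batch)];
    if (da != db && da != 1 && db != 1)
      throw std::invalid_argument("batch dimensions are not broadcastable");
    out.push_back(std::max(da, db));
  }

  out.push_back(a.end()[-2]);          // rows
  out.push_back(b.back());             // cols
  // A 1-D operand does not contribute its unit row/col dim to the result.
  if (b_vec) out.pop_back();
  if (a_vec) out.erase(out.end() - (b_vec ? 1 : 2));
  return out;
}
```

For example, `matmul_output_shape({2, 3, 4}, {4, 5})` yields `{2, 3, 5}`, and two 1-D inputs yield an empty shape (a scalar), matching the sizes `checks_ndmm` passes to `set_output_raw_strided`.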
@@ -1885,6 +1969,10 @@ Tensor _matmul_impl(
 
   const bool has_out = out.defined();
 
+  if (tensor1.is_mps() && tensor2.is_mps()) {
+    return has_out ? at::ndmm_out(out, tensor1, tensor2) : tensor1.ndmm(tensor2);
+  }
+
   if (dim_tensor1 == 1 && dim_tensor2 == 1) {
     return has_out ? at::dot_out(out, tensor1, tensor2) : tensor1.dot(tensor2);
   } else if (dim_tensor1 == 2 && dim_tensor2 == 1) {
diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
index b2f873636c0ea..4b92d8ad5895b 100644
--- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
+++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
@@ -502,6 +502,97 @@ Tensor addr_mps(const Tensor& self,
   return output;
 }
 
+Tensor& ndmm_out_mps_impl(
+  const Tensor & batch1,
+  const Tensor & batch2,
+  Tensor & result) {
+
+  using namespace mps;
+  if (batch1.numel() == 0 || batch2.numel() == 0) {
+    return result;
+  }
+  MPSStream* stream = getCurrentMPSStream();
+
+  struct CachedGraph : public mps::MPSCachedGraph
+  {
+    CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
+    MPSGraphTensor *batch1Tensor_ = nil;
+    MPSGraphTensor *batch2Tensor_ = nil;
+    MPSGraphTensor *outputTensor_ = nil;
+  };
+
+  mps::MPSGraphCache *cache_ = mps::MPSGraphCache::getInstance();
+
+  bool expandBatch2 = (batch1.dim() == 4) && (batch2.dim() == 3);
+  bool expandBatch1 = (batch2.dim() == 4) && (batch1.dim() == 3);
+
+  @autoreleasepool {
+    string key = "ndmm_out_mps_impl" + getTensorsStringKey({batch1, batch2});
+
+    CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
+    if(!cachedGraph) {
+
+      mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ mps::MPSCachedGraph * () {
+        CachedGraph *newCachedGraph = nil;
+
+        @autoreleasepool{
+          MPSGraph *mpsGraph = mps::make_mps_graph();
+          newCachedGraph = new CachedGraph(mpsGraph);
+
+          MPSGraphTensor *batch1Tensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, batch1);
+          MPSGraphTensor *batch2Tensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, batch2);
+
+          // Expand 1-D operands to 2-D, and add a leading batch dim for the 3-D/4-D mix
+          MPSGraphTensor *batch1InputTensor = batch1Tensor;
+          if(batch1.dim() == 1 || expandBatch1) {
+            batch1InputTensor = [mpsGraph expandDimsOfTensor:batch1Tensor axis:0 name:nil];
+          }
+          MPSGraphTensor *batch2InputTensor = batch2Tensor;
+          if(batch2.dim() == 1) {
+            batch2InputTensor = [mpsGraph expandDimsOfTensor:batch2Tensor axis:1 name:nil];
+          }
+          if(expandBatch2) {
+            batch2InputTensor = [mpsGraph expandDimsOfTensor:batch2Tensor axis:0 name:nil];
+          }
+
+          MPSGraphTensor* productTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:batch1InputTensor
+                                                                          secondaryTensor:batch2InputTensor
+                                                                                     name:@"MM/(batch1@batch2)"];
+
+          if(batch1.dim() == 1) {
+            productTensor = [mpsGraph squeezeTensor:productTensor axis:-2 name:nil];
+          }
+          if(batch2.dim() == 1) {
+            productTensor = [mpsGraph squeezeTensor:productTensor axis:-1 name:nil];
+          }
+
+          newCachedGraph->batch1Tensor_ = batch1Tensor;
+          newCachedGraph->batch2Tensor_ = batch2Tensor;
+          newCachedGraph->outputTensor_ = productTensor;
+        }
+        return newCachedGraph;
+      });
+      cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
+    }
+
+    Placeholder batch1Placeholder = Placeholder(cachedGraph->batch1Tensor_, batch1);
+    Placeholder batch2Placeholder = Placeholder(cachedGraph->batch2Tensor_, batch2);
+    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);
+
+    NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
+      batch1Placeholder.getMPSGraphTensor() : batch1Placeholder.getMPSGraphTensorData(),
+      batch2Placeholder.getMPSGraphTensor() : batch2Placeholder.getMPSGraphTensorData(),
+    };
+
+    NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
+      outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
+    };
+
+    mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+  }
+
+  return result;
+}
 
 Tensor& bmm_out_mps_impl(
   const Tensor & batch1,
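The graph built in `ndmm_out_mps_impl` only adds and removes unit dims around a single `matrixMultiplicationWithPrimaryTensor:` call. The same transformation written with eager ATen ops looks roughly like the sketch below (illustrative only; `ndmm_reference` is a hypothetical helper, `at::matmul` stands in for the graph's matrix multiply, and only the rank combinations the kernel expands explicitly are assumed):

```cpp
#include <ATen/ATen.h>

// Sketch of what the cached MPSGraph computes, in eager form (evaluate on CPU tensors).
at::Tensor ndmm_reference(const at::Tensor& a, const at::Tensor& b) {
  at::Tensor lhs = a;
  at::Tensor rhs = b;
  if (a.dim() == 1 || (a.dim() == 3 && b.dim() == 4)) {
    lhs = lhs.unsqueeze(0);             // (K) -> (1, K), or add a leading batch dim
  }
  if (b.dim() == 1) {
    rhs = rhs.unsqueeze(1);             // (K) -> (K, 1)
  } else if (b.dim() == 3 && a.dim() == 4) {
    rhs = rhs.unsqueeze(0);             // add a leading batch dim
  }
  // Plays the role of MPSGraph's matrixMultiplication (both operands are now >= 2-D).
  at::Tensor out = at::matmul(lhs, rhs);
  if (a.dim() == 1) {
    out = out.squeeze(-2);              // drop the inserted row dim
  }
  if (b.dim() == 1) {
    out = out.squeeze(-1);              // drop the inserted col dim
  }
  return out;
}
```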
@@ -710,6 +801,220 @@ Tensor addr_mps(const Tensor& self,
   return result;
 }
 
+bool ndmm_gradient_reduction(NSMutableArray<NSNumber *> *reduction_axes, MPSShape* input1_shape, MPSShape* input2_shape, MPSShape* grad_output_shape, bool is_grad_input1) {
+
+  auto input1_dims = [input1_shape count];
+  auto input2_dims = [input2_shape count];
+  auto grad_output_dims = [grad_output_shape count];
+
+  auto input_tensor_matmul_dims = is_grad_input1? input2_dims : input1_dims;
+  auto output_tensor_matmul_dims = is_grad_input1? input1_dims : input2_dims;
+
+  // Batch dimensions of the product
+  auto product_tensor_batch_shape = (input_tensor_matmul_dims > grad_output_dims)? (is_grad_input1? input2_shape : input1_shape) : grad_output_shape;
+  // Batch dimensions of the final gradient
+  auto output_tensor_batch_shape = is_grad_input1? input1_shape : input2_shape;
+  // All batch dims will be present in the product through broadcasting
+  auto product_batch_dims = std::max(input_tensor_matmul_dims - 2, grad_output_dims - 2);
+  // Number of batch dims in the grad_input tensor
+  auto final_batch_dims = output_tensor_matmul_dims - 2;
+
+  // If there are fewer batch dims in the final result than in the product, reduce over the extra leading dims
+  if(product_batch_dims > final_batch_dims) {
+    auto num_dims_to_reduce = product_batch_dims - final_batch_dims;
+    for(int i = 0; i < num_dims_to_reduce; ++i)
+      [reduction_axes addObject:[NSNumber numberWithInteger:i]];
+  }
+  // In the aligned batch dims, broadcast (size-1) dims must be reduced as well
+  auto product_batch_start = product_batch_dims - std::min(product_batch_dims, final_batch_dims);
+  for(int i = 0; i < std::min(product_batch_dims, final_batch_dims); ++i) {
+    auto product_batch_iter = product_batch_start + i;
+    if(output_tensor_batch_shape[i].intValue == 1)
+      [reduction_axes addObject:[NSNumber numberWithInteger:product_batch_iter]];
+  }
+  auto num_reduction_axes = [reduction_axes count];
+  return (num_reduction_axes > 0);
+}
+
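The rule `ndmm_gradient_reduction` encodes: the matmul product that forms a gradient carries the larger of the two batch ranks as leading dims, and every leading dim the target tensor lacks, plus every aligned batch dim where the target has size 1, has to be sum-reduced. A plain C++ sketch over shapes held as `int64_t` vectors (the `gradient_reduction_axes` helper is hypothetical, for illustration only):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// target_shape: full shape (batch dims + 2 matmul dims) of the tensor receiving the gradient.
// product_batch_dims: number of leading batch dims carried by the matmul product.
std::vector<int64_t> gradient_reduction_axes(const std::vector<int64_t>& target_shape,
                                             int64_t product_batch_dims) {
  const int64_t target_batch_dims = static_cast<int64_t>(target_shape.size()) - 2;
  std::vector<int64_t> axes;
  // Extra leading batch dims that exist only in the product.
  for (int64_t i = 0; i < product_batch_dims - target_batch_dims; ++i) {
    axes.push_back(i);
  }
  // Aligned batch dims where the target was broadcast from size 1.
  const int64_t aligned = std::min(product_batch_dims, target_batch_dims);
  const int64_t offset = product_batch_dims - aligned;
  for (int64_t i = 0; i < aligned; ++i) {
    if (target_shape[i] == 1) {
      axes.push_back(offset + i);
    }
  }
  return axes;
}
```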
+Tensor ndmm_backward_out_mps(
+  const Tensor & tensor1,
+  const Tensor & tensor2,
+  const Tensor & grad_output,
+  bool is_grad_input1) {
+  using namespace mps;
+
+  TORCH_CHECK(tensor1.is_mps());
+  TORCH_CHECK(tensor2.is_mps());
+  TORCH_CHECK(grad_output.is_mps());
+
+  MPSShape* input1_final_shape = nil;
+  MPSShape* input2_final_shape = nil;
+  MPSShape* grad_output_shape = getMPSShape(grad_output);
+  MPSShape* grad_output_final_shape = nil;
+  NSMutableArray* grad_output_mutable = [grad_output_shape mutableCopy];
+  MPSShape* grad_input_final_shape = nil;
+
+  if(tensor1.dim() == 1 || tensor2.dim() == 1) {
+    NSUInteger indexToUpdate;
+    // If input1 is a vector, convert it to a row matrix and unsqueeze grad_output before its last dim
+    if(tensor1.dim() == 1) {
+      input1_final_shape = @[[NSNumber numberWithInteger:1], [NSNumber numberWithInteger:tensor1.size(0)]];
+      indexToUpdate = [grad_output_mutable count] - 1;
+    }
+
+    // If input2 is a vector, convert it to a column matrix and unsqueeze grad_output after its last dim
+    if(tensor2.dim() == 1) {
+      input2_final_shape = @[[NSNumber numberWithInteger:tensor2.size(0)], [NSNumber numberWithInteger:1]];
+      indexToUpdate = [grad_output_mutable count];
+    }
+
+    [grad_output_mutable insertObject:[NSNumber numberWithInteger:1] atIndex:indexToUpdate];
+    grad_output_final_shape = [NSArray arrayWithArray:grad_output_mutable];
+  }
+  if(tensor1.dim() == 3 && tensor2.dim() == 4) {
+    input1_final_shape = @[[NSNumber numberWithInteger:1], [NSNumber numberWithInteger:tensor1.size(0)], [NSNumber numberWithInteger:tensor1.size(1)], [NSNumber numberWithInteger:tensor1.size(2)]];
+  }
+  if(tensor2.dim() == 3 && tensor1.dim() == 4) {
+    input2_final_shape = @[[NSNumber numberWithInteger:1], [NSNumber numberWithInteger:tensor2.size(0)], [NSNumber numberWithInteger:tensor2.size(1)], [NSNumber numberWithInteger:tensor2.size(2)]];
+  }
+
+  NSMutableArray<NSNumber *> *reduction_axes = [[NSMutableArray alloc] init];
+  MPSShape* tensor1_dims = input1_final_shape? input1_final_shape : getMPSShape(tensor1);
+  MPSShape* tensor2_dims = input2_final_shape? input2_final_shape : getMPSShape(tensor2);
+  MPSShape* grad_output_dims = grad_output_final_shape? grad_output_final_shape : getMPSShape(grad_output);
+  bool needs_reduction = ndmm_gradient_reduction(reduction_axes, tensor1_dims, tensor2_dims, grad_output_dims, is_grad_input1);
+
+  MPSStream* stream = getCurrentMPSStream();
+
+  struct CachedGraph : public mps::MPSCachedGraph
+  {
+    CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
+    MPSGraphTensor *fwdInputTensor_ = nil;
+    MPSGraphTensor *gradOutputTensor_ = nil;
+    MPSGraphTensor *gradInputTensor_ = nil;
+  };
+
+  IntArrayRef grad_input_size = is_grad_input1? tensor1.sizes() : tensor2.sizes();
+  Tensor grad_input = at::native::empty_mps(grad_input_size,
+                                            grad_output.scalar_type(),
+                                            c10::nullopt,
+                                            kMPS,
+                                            c10::nullopt,
+                                            grad_output.suggest_memory_format());
+  MPSShape* grad_input_shape = getMPSShape(grad_input);
+  TORCH_CHECK(grad_input.is_mps());
+
+  mps::MPSGraphCache *cache_ = mps::MPSGraphCache::getInstance();
+
+  @autoreleasepool {
+    std::string key = "ndmm_backward_out_mps" + getTensorsStringKey({tensor1, tensor2, grad_output}) + std::to_string(is_grad_input1);
+
+    CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
+    if(!cachedGraph) {
+      mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ mps::MPSCachedGraph * () {
+        CachedGraph *newCachedGraph = nil;
+
+        /*
+          Case 1: No bcast in batch dim
+            Forward: A_pmk x B_pkn = C_pmn
+            dA_pmk = dC_pmn x B_pkn^T
+            dB_pkn = A_pmk^T x dC_pmn
+
+          Case 2: With bcast in batch dim
+            Forward: A_pmk x B_kn = C_pmn
+            dA_pmk = dC_pmn x B_kn^T
+            dB_kn = ReduceOverP(A_pmk^T x dC_pmn)
+        */
+        @autoreleasepool{
+          MPSGraph *mpsGraph = mps::make_mps_graph();
+          newCachedGraph = new CachedGraph(mpsGraph);
+          MPSGraphTensor *fwdInputTensor = is_grad_input1? mps::mpsGraphRankedPlaceHolder(mpsGraph, tensor2) : mps::mpsGraphRankedPlaceHolder(mpsGraph, tensor1);
+          MPSGraphTensor *gradOutputTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
+          MPSGraphTensor *gradInputTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, grad_input);
+          MPSGraphTensor* fwdInputTensorExpanded = fwdInputTensor;
+          MPSGraphTensor* gradOutputTensorExpanded = gradOutputTensor;
+
+          if(is_grad_input1) {
+            if(input2_final_shape) {
+              fwdInputTensorExpanded = [mpsGraph reshapeTensor:fwdInputTensorExpanded
+                                                     withShape:input2_final_shape
+                                                          name:nil];
+            }
+            if(grad_output_final_shape) {
+              gradOutputTensorExpanded = [mpsGraph reshapeTensor:gradOutputTensorExpanded
+                                                       withShape:grad_output_final_shape
+                                                            name:nil];
+            }
+            MPSGraphTensor* input2Transpose = [mpsGraph transposeTensor: fwdInputTensorExpanded
+                                                              dimension: -1
+                                                          withDimension: -2
+                                                                   name: nil];
+            gradInputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:gradOutputTensorExpanded
+                                                              secondaryTensor:input2Transpose
+                                                                         name:nil];
+          }
+          else {
+            if(input1_final_shape) {
+              fwdInputTensorExpanded = [mpsGraph reshapeTensor:fwdInputTensorExpanded
+                                                     withShape:input1_final_shape
+                                                          name:nil];
+            }
+            if(grad_output_final_shape) {
+              gradOutputTensorExpanded = [mpsGraph reshapeTensor:gradOutputTensorExpanded
+                                                       withShape:grad_output_final_shape
+                                                            name:nil];
+            }
+            MPSGraphTensor* input1Transpose = [mpsGraph transposeTensor: fwdInputTensorExpanded
+                                                              dimension: -1
+                                                          withDimension: -2
+                                                                   name: nil];
+            gradInputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:input1Transpose
+                                                              secondaryTensor:gradOutputTensorExpanded
+                                                                         name:nil];
+          }
+          if(needs_reduction) {
+            gradInputTensor = [mpsGraph reductionSumWithTensor: gradInputTensor
+                                                          axes: reduction_axes
+                                                          name: nil];
+          }
+
+          gradInputTensor = [mpsGraph reshapeTensor:gradInputTensor
+                                          withShape:grad_input_shape
+                                               name:nil];
+
+          newCachedGraph->fwdInputTensor_ = fwdInputTensor;
+          newCachedGraph->gradOutputTensor_ = gradOutputTensor;
+          newCachedGraph->gradInputTensor_ = gradInputTensor;
+        }
+        return newCachedGraph;
+      });
+      cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
+    }
+
+    Tensor fwd_input = is_grad_input1? tensor2 : tensor1;
+    Placeholder fwdInputPlaceholder = Placeholder(cachedGraph->fwdInputTensor_, fwd_input);
+    Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output);
+    Placeholder gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input);
+
+    NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
+      fwdInputPlaceholder.getMPSGraphTensor() : fwdInputPlaceholder.getMPSGraphTensorData(),
+      gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData()
+    };
+
+    NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
+      gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData(),
+    };
+
+    mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+  }
+
+  return grad_input;
+}
+
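The two cases in the comment above are the familiar matmul gradients, `dA = dC @ B^T` and `dB = A^T @ dC`, each summed back to the input's shape wherever batch dims were broadcast in the forward pass. A short eager-mode reference (hypothetical helper, assumes both operands are already at least 2-D, as they are after the reshapes above; `sum_to_size` performs the same leading-dim / size-1-dim summation that `ndmm_gradient_reduction` plus `reductionSumWithTensor:` implement):

```cpp
#include <ATen/ATen.h>
#include <utility>

// Reference gradients for C = A @ B, evaluated with eager ATen ops (e.g. on CPU).
std::pair<at::Tensor, at::Tensor> matmul_backward_reference(
    const at::Tensor& A, const at::Tensor& B, const at::Tensor& dC) {
  at::Tensor dA = at::matmul(dC, B.transpose(-1, -2)).sum_to_size(A.sizes());
  at::Tensor dB = at::matmul(A.transpose(-1, -2), dC).sum_to_size(B.sizes());
  return {dA, dB};
}
```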
 
 TORCH_IMPL_FUNC(mm_out_mps)(const Tensor& self, const Tensor& mat2, const Tensor& result) {
   mm_out_mps_impl(self, mat2, const_cast<Tensor&>(result));
 }
@@ -726,6 +1031,10 @@ Tensor addr_mps(const Tensor& self,
   addbmm_or_baddbmm_out_mps_impl(self, batch1, batch2, beta, alpha, const_cast<Tensor&>(result), BADDBMM_OP_TYPE);
 }
 
+TORCH_IMPL_FUNC(ndmm_out_mps) (const Tensor & batch1, const Tensor & batch2, const Tensor & result) {
+  ndmm_out_mps_impl(batch1, batch2, const_cast<Tensor&>(result));
+}
+
 Tensor& addbmm_out_mps(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, Tensor& result) {
 
   auto b_self = expand_size(self, {batch1.size(1), batch2.size(2)}, "addbmm_out");
@@ -742,6 +1051,17 @@ Tensor addbmm_mps(const Tensor& self, const Tensor& batch1, const Tensor& batch2
   return addbmm_out_mps(self, batch1, batch2, beta, alpha, self);
 }
 
+std::tuple<Tensor, Tensor> ndmm_backward_mps(
+    const Tensor& input1, const Tensor& input2,
+    const Tensor& grad, std::array<bool, 2> grad_mask) {
+  Tensor grad_input1, grad_input2;
+  if(grad_mask[0])
+    grad_input1 = ndmm_backward_out_mps(input1, input2, grad, true);
+  if(grad_mask[1])
+    grad_input2 = ndmm_backward_out_mps(input1, input2, grad, false);
+  return std::tuple<Tensor, Tensor>(grad_input1, grad_input2);
+}
+
 Tensor& linalg_solve_triangular_mps_impl(
     const Tensor& A, const Tensor& B, bool upper, bool transpose, bool left, bool unitriangular, Tensor& out) {
   using namespace mps;
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index b1b06cba26f9b..aa1c09a62f152 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1260,6 +1260,21 @@
     SparseCUDA: bmm_out_sparse_cuda
     SparseCsrCUDA: bmm_out_sparse_csr_cuda
 
+- func: ndmm(Tensor self, Tensor mat2) -> Tensor
+  structured_delegate: ndmm.out
+  variants: function, method
+  tags: core
+
+- func: ndmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  variants: function
+  dispatch:
+    MPS: ndmm_out_mps
+
+- func: ndmm_backward(Tensor self, Tensor mat2, Tensor grad, bool[2] output_mask) -> (Tensor, Tensor)
+  dispatch:
+    MPS: ndmm_backward_mps
+
 - func: broadcast_tensors(Tensor[] tensors) -> Tensor[]
   device_check: NoCheck
   device_guard: False
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index d9f85cba366b6..6506f6da79760 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -338,6 +338,9 @@
   batch2: maybe_multiply(batch1.transpose(1, 2).conj().bmm(grad), alpha.conj())
   result: maybe_multiply(self_t, beta) + maybe_multiply(batch1_t.bmm(batch2_p), alpha) + maybe_multiply(batch1_p.bmm(batch2_t), alpha)
 
+- name: ndmm(Tensor self, Tensor mat2) -> Tensor
+  self, mat2: ndmm_backward(self, mat2, grad, grad_input_mask)
+
 - name: bernoulli(Tensor self, *, Generator? generator=None) -> Tensor
   self: zeros_like(grad)
   result: auto_element_wise
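With the registration above, `_matmul_impl` routes every `matmul` on MPS tensors through `ndmm`, and the derivatives.yaml entry routes its gradient through `ndmm_backward`, so plain `matmul` calls already exercise the new kernels. A minimal smoke-test sketch (hypothetical, not part of the patch; assumes an MPS-enabled build):

```cpp
#include <ATen/ATen.h>

int main() {
  if (!at::hasMPS()) {
    return 0;  // nothing to exercise without an MPS device
  }
  auto opts = at::TensorOptions().device(at::kMPS).dtype(at::kFloat);
  auto a = at::randn({5, 3, 4}, opts);
  auto b = at::randn({5, 4, 2}, opts);

  // Batched matrix @ matrix: dispatched to ndmm by _matmul_impl.
  auto c = at::matmul(a, b);
  TORCH_CHECK(c.sizes().equals({5, 3, 2}));

  // Vector @ batched matrix: the kernel expands (4) to (1, 4), multiplies, then squeezes.
  auto v = at::matmul(at::randn({4}, opts), b);
  TORCH_CHECK(v.sizes().equals({5, 2}));
  return 0;
}
```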