Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions aten/src/ATen/native/LinearAlgebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@
#include <ATen/ops/max.h>
#include <ATen/ops/mm.h>
#include <ATen/ops/mm_native.h>
#include <ATen/ops/ndmm.h>
#include <ATen/ops/ndmm_native.h>
#include <ATen/ops/movedim.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/mv.h>
Expand Down Expand Up @@ -133,6 +135,7 @@
#include <string>
#include <tuple>
#include <utility>
#include <iostream>

namespace at {

Expand Down Expand Up @@ -316,6 +319,87 @@ TORCH_META_FUNC(baddbmm)(const Tensor& self, const Tensor& batch1, const Tensor&
common_checks_baddbmm_bmm(*this, batch1, batch2, beta, alpha, false, *self_);
}

/*
 * Shape validation and output setup for ndmm (N-dimensional matrix multiply).
 *
 * Both operands must have at least one dimension:
 *   - a 1-D `batch1` is treated as a row vector; the row dim is dropped from
 *     the result,
 *   - a 1-D `batch2` is treated as a column vector; the column dim is dropped
 *     from the result,
 *   - every dimension before the trailing two of an operand is a batch
 *     dimension. Batch dimensions of the two operands are aligned from the
 *     right and must either match or be 1 (broadcast).
 *
 * On success, output 0 of `meta` is set to the computed shape
 * [broadcast batch dims..., rows (if batch1 is not 1-D), cols (if batch2 is
 * not 1-D)] using `batch2.options()`, and dimension names are propagated.
 *
 * Raises (via TORCH_CHECK) when an operand is 0-D, when the contracted (inner)
 * dimensions differ, when batch dimensions are not broadcastable, or when a
 * pre-existing output tensor does not have the expected shape.
 */
template <typename Meta>
void checks_ndmm(Meta& meta, const Tensor& batch1, const Tensor& batch2) {
  const auto batch1_sizes = batch1.sizes();
  const auto batch2_sizes = batch2.sizes();

  const auto batch1_dim_len = static_cast<int64_t>(batch1_sizes.size());
  const auto batch2_dim_len = static_cast<int64_t>(batch2_sizes.size());

  // A 0-D operand has no inner dimension to contract over; reject it up front
  // instead of indexing before the start of an empty size list.
  TORCH_CHECK(batch1_dim_len >= 1 && batch2_dim_len >= 1,
      "both arguments to ndmm need to be at least 1D, but they are ",
      batch1_dim_len, "D and ", batch2_dim_len, "D");

  const bool is_batch1_vector = (batch1_dim_len == 1);
  const bool is_batch2_vector = (batch2_dim_len == 1);

  // Make sure inner dimension matches.
  // batch1 contracts over its last dim; batch2 contracts over its
  // second-to-last dim, or over its only dim when it is a vector.
  const int64_t batch1_innerdim = batch1_sizes[batch1_dim_len - 1];
  const int64_t batch2_innerdim = is_batch2_vector
      ? batch2_sizes[batch2_dim_len - 1]
      : batch2_sizes[batch2_dim_len - 2];
  TORCH_CHECK(batch1_innerdim == batch2_innerdim,
      "Expected inner dimension of both tensors to match, but got: ", batch1_innerdim, " and ", batch2_innerdim, ".");

  // Number of leading batch dims on each operand (0 for 1-D and 2-D inputs).
  const int64_t batch1_num_batch = std::max<int64_t>(batch1_dim_len - 2, 0);
  const int64_t batch2_num_batch = std::max<int64_t>(batch2_dim_len - 2, 0);
  const int64_t out_num_batch = std::max(batch1_num_batch, batch2_num_batch);

  std::vector<int64_t> output_size;
  output_size.reserve(static_cast<size_t>(out_num_batch + 2));

  // Result batch dims: align from the right; a missing leading dim behaves
  // like a dim of size 1.
  for (int64_t out_idx = 0; out_idx < out_num_batch; ++out_idx) {
    const int64_t batch1_idx = out_idx - (out_num_batch - batch1_num_batch);
    const int64_t batch2_idx = out_idx - (out_num_batch - batch2_num_batch);
    const int64_t batch1_curr_batchdim = batch1_idx >= 0 ? batch1_sizes[batch1_idx] : 1;
    const int64_t batch2_curr_batchdim = batch2_idx >= 0 ? batch2_sizes[batch2_idx] : 1;
    TORCH_CHECK((batch2_curr_batchdim == batch1_curr_batchdim) || (batch1_curr_batchdim == 1) || (batch2_curr_batchdim == 1),
        "Expected batch dimension of both tensors to match or be broadcastable, but got: ", batch1_curr_batchdim, " and ", batch2_curr_batchdim, ".");
    // Pick the non-1 side rather than the max, so that a zero-sized batch dim
    // paired with 1 broadcasts to 0 instead of 1.
    output_size.push_back(batch1_curr_batchdim == 1 ? batch2_curr_batchdim : batch1_curr_batchdim);
  }

  // Result row and col dims; each is dropped when the matching operand is a vector.
  if (!is_batch1_vector) {
    output_size.push_back(batch1_sizes[batch1_dim_len - 2]);
  }
  if (!is_batch2_vector) {
    output_size.push_back(batch2_sizes[batch2_dim_len - 1]);
  }

  auto& result = meta.maybe_get_output(0);
  // 'set_output' does not resize for in-place calls
  meta.set_output_raw_strided(0, output_size, {}, batch2.options());
  const auto result_sizes = result.sizes();
  // Error is raised if called from in-place overload with incorrect shape
  TORCH_CHECK(result_sizes == output_size,
      "Expected an output tensor with shape [", output_size, "] but got shape ", result_sizes);

  // NOTE(review): names are computed with the bmm helper even though ndmm
  // accepts operands of any rank >= 1 -- confirm it handles non-3-D inputs.
  const std::vector<Dimname> outnames =
      namedinference::compute_bmm_outnames(result, batch1, batch2);
  namedinference::propagate_names_if_nonempty(result, outnames);
}

// Meta function for ndmm: forwards `self` and `mat2` to checks_ndmm, which
// validates the operand shapes and sets up output 0 on this meta object.
TORCH_META_FUNC(ndmm)(const Tensor& self, const Tensor& mat2) {
checks_ndmm(*this, self, mat2);
}

} // namespace meta
namespace native {

Expand Down Expand Up @@ -1885,6 +1969,10 @@ Tensor _matmul_impl(

const bool has_out = out.defined();

if (tensor1.is_mps() && tensor2.is_mps()) {
return has_out? at::ndmm_out(out, tensor1, tensor2) : tensor1.ndmm(tensor2);
}

if (dim_tensor1 == 1 && dim_tensor2 == 1) {
return has_out ? at::dot_out(out, tensor1, tensor2) : tensor1.dot(tensor2);
} else if (dim_tensor1 == 2 && dim_tensor2 == 1) {
Expand Down
Loading