14 changes: 9 additions & 5 deletions src/ATen/native/xpu/Nonzero.cpp
@@ -2,15 +2,15 @@
#include <ATen/xpu/EmptyTensor.h>

#include <ATen/native/xpu/sycl/NonzeroKernel.h>
#include <ATen/native/xpu/sycl/OffsetCalculator.h>
#include <comm/TensorInfo.h>

namespace at {
namespace native {
Tensor& nonzero_out_xpu(const Tensor& self, Tensor& out) {
TORCH_CHECK(
self.numel() < std::numeric_limits<int>::max(),
"nonzero is not supported for tensors with more than INT_MAX elements, \
See https://github.com/pytorch/pytorch/issues/51871");
See https://github.com/pytorch/pytorch/issues/51871");
TORCH_CHECK(
out.dtype() == at::kLong,
"Expected object of scalar type ",
@@ -24,11 +24,15 @@ Tensor& nonzero_out_xpu(const Tensor& self, Tensor& out) {
" and self on ",
self.device());
TORCH_CHECK(
self.dim() <= MAX_DIMS,
self.dim() <= XPU_MAX_TENSORINFO_DIMS,
"nonzero is not supported for tensor with more than ",
MAX_DIMS,
XPU_MAX_TENSORINFO_DIMS,
" dimensions");

if (self.numel() == 0) {
out = at::detail::empty_xpu({0, self.dim()}, out.options());
return out;
}
xpu::nonzero_kernel(self, out);
return out;
}
@@ -39,4 +43,4 @@ Tensor nonzero_xpu(const Tensor& self) {
return out;
}
} // namespace native
} // namespace at
} // namespace at
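
A minimal usage sketch of the new empty-input fast path (not part of the diff; it assumes an XPU-enabled ATen build and an available XPU device): for a tensor with zero elements, `nonzero_out_xpu` now allocates a `{0, self.dim()}` Long result and returns before the SYCL kernel is ever submitted.

```cpp
#include <ATen/ATen.h>

int main() {
  // Hypothetical input: a 0x3 float tensor on the XPU device.
  at::Tensor self = at::zeros({0, 3}, at::device(at::kXPU).dtype(at::kFloat));
  at::Tensor idx = at::nonzero(self);
  // With the early return above, the result is an empty LongTensor of shape
  // {0, self.dim()} == {0, 2}; no kernel launch happens for this case.
  TORCH_CHECK(idx.dim() == 2 && idx.size(0) == 0 && idx.size(1) == self.dim());
  return 0;
}
```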
129 changes: 60 additions & 69 deletions src/ATen/native/xpu/sycl/NonzeroKernel.cpp
@@ -17,22 +17,22 @@ struct FlattenIdxtoRealIdxKernelFunctor {
if (global_id < N_) {
auto dim = global_id / num_nonzeros_;
auto index = global_id % num_nonzeros_;
tensor_begin_[global_id] =
out_begin_[global_id] =
idx_flat_begin_[index] / divisor_[dim] % sizes_[dim];
}
}
FlattenIdxtoRealIdxKernelFunctor(
int64_t N,
const int64_t num_dim,
const int64_t num_nonzeros,
int64_t* tensor_begin,
int64_t* out_begin,
int64_t* idx_flat_begin,
int64_t* divisor,
int64_t* sizes)
: N_(N),
num_dim_(num_dim),
num_nonzeros_(num_nonzeros),
tensor_begin_(tensor_begin),
out_begin_(out_begin),
idx_flat_begin_(idx_flat_begin) {
for (auto dim = num_dim - 1; dim >= 0; dim--) {
sizes_[dim] = sizes[dim];
@@ -44,7 +44,7 @@ struct FlattenIdxtoRealIdxKernelFunctor {
int64_t N_;
const int64_t num_dim_;
const int64_t num_nonzeros_;
int64_t* tensor_begin_;
int64_t* out_begin_;
int64_t* idx_flat_begin_;
int64_t divisor_[XPU_MAX_TENSORINFO_DIMS];
int64_t sizes_[XPU_MAX_TENSORINFO_DIMS];
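
For reference, a small host-side sketch (illustrative only, not code from this PR) of the arithmetic `FlattenIdxtoRealIdxKernelFunctor` performs on device: a flat index into a contiguous tensor is turned into its coordinate along dimension `dim` via `flat / divisor[dim] % sizes[dim]`, where `divisor[dim]` is the product of all trailing dimension sizes.

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int64_t num_dim = 3;
  int64_t sizes[num_dim] = {2, 3, 4};  // shape of a hypothetical 2x3x4 tensor
  int64_t divisor[num_dim];
  divisor[num_dim - 1] = 1;
  for (int64_t dim = num_dim - 2; dim >= 0; dim--) {
    divisor[dim] = sizes[dim + 1] * divisor[dim + 1];  // yields {12, 4, 1}
  }
  const int64_t flat = 17;  // flat index into the contiguous buffer
  for (int64_t dim = 0; dim < num_dim; dim++) {
    // Same expression as in the functor; prints "1 1 1" since 17 = 1*12 + 1*4 + 1.
    std::printf("%lld ", static_cast<long long>(flat / divisor[dim] % sizes[dim]));
  }
  std::printf("\n");
  return 0;
}
```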
@@ -79,77 +79,68 @@ struct CopyIfFunc<bool> {
};

template <typename scalar_t>
void nonzero_template(const Tensor& self_, Tensor& tensor) {
void nonzero_template(const Tensor& self_, Tensor& out) {
Tensor self = self_.contiguous();

const int64_t num_dim = self.dim();
TORCH_CHECK(num_dim <= XPU_MAX_TENSORINFO_DIMS, "dim exceed max allowed dim");

int64_t N = self.numel();

if (N > 0) {
Tensor idx_flat = at::empty(
{N}, tensor.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT));
Tensor range = at::empty(
{N}, tensor.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT));

const scalar_t* self_begin = self.const_data_ptr<scalar_t>();
int64_t* idx_flat_begin = idx_flat.data_ptr<int64_t>();
int64_t* range_begin = nullptr;

CopyIfFunc<scalar_t> f(self_begin);
auto idx_flat_end =
pstl::copy_if<int64_t>(range_begin, range_begin + N, idx_flat_begin, f);

auto num_nonzeros = std::distance(idx_flat_begin, idx_flat_end);

bool need_to_copy = tensor.dim() == 2 &&
tensor.sizes()[0] == num_nonzeros && tensor.sizes()[1] == self_.dim() &&
!tensor.t().is_contiguous();
at::Tensor tensor_ = need_to_copy
? Tensor(at::detail::empty_xpu(
{self_.dim(), num_nonzeros}, tensor.options()))
: tensor.resize_({self_.dim(), num_nonzeros});

if (num_nonzeros > 0 && num_dim > 0) {
int64_t* tensor_begin = tensor_.data_ptr<int64_t>();

// preload sizes tensor for index calculation
int64_t sizes[XPU_MAX_TENSORINFO_DIMS];
int64_t divisor[XPU_MAX_TENSORINFO_DIMS];
sizes[num_dim - 1] = self.size(num_dim - 1);
divisor[num_dim - 1] = 1;
for (auto dim = num_dim - 2; dim >= 0; dim--) {
sizes[dim] = self.size(dim);
divisor[dim] = sizes[dim + 1] * divisor[dim + 1];
}

const int64_t N = num_nonzeros * num_dim;
// restore flatten idx to indices
FlattenIdxtoRealIdxKernelFunctor kfn(
N,
num_dim,
num_nonzeros,
tensor_begin,
idx_flat_begin,
divisor,
sizes);

const auto wg_sz = std::min(syclMaxWorkGroupSize(kfn), N);
const auto num_wg = (N + wg_sz - 1) / wg_sz;

sycl_kernel_submit(wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), kfn);
}
if (need_to_copy) {
tensor.copy_(tensor_.t());
} else {
// transpose out so it is correct size
Tensor tensor_temp = tensor_.t();
tensor.set_(tensor_temp);
const int64_t N = self.numel();

Tensor idx_flat = at::empty(
{N}, out.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT));

const scalar_t* self_begin = self.const_data_ptr<scalar_t>();
int64_t* idx_flat_begin = idx_flat.data_ptr<int64_t>();
int64_t* range_begin = nullptr;

CopyIfFunc<scalar_t> f(self_begin);
auto idx_flat_end =
pstl::copy_if<int64_t>(range_begin, range_begin + N, idx_flat_begin, f);

auto num_nonzeros = std::distance(idx_flat_begin, idx_flat_end);

bool need_to_copy = out.dim() == 2 &&
out.sizes()[0] == num_nonzeros && out.sizes()[1] == num_dim &&
!out.t().is_contiguous();
Tensor out_ = need_to_copy
? Tensor(at::detail::empty_xpu(
{num_dim, num_nonzeros}, out.options()))
: out.resize_({num_dim, num_nonzeros});

if (num_nonzeros > 0 && num_dim > 0) {
int64_t* out_begin = out_.data_ptr<int64_t>();

// preload sizes tensor for index calculation
int64_t sizes[XPU_MAX_TENSORINFO_DIMS];
int64_t divisor[XPU_MAX_TENSORINFO_DIMS];
sizes[num_dim - 1] = self.size(num_dim - 1);
divisor[num_dim - 1] = 1;
for (auto dim = num_dim - 2; dim >= 0; dim--) {
sizes[dim] = self.size(dim);
divisor[dim] = sizes[dim + 1] * divisor[dim + 1];
}

const int64_t N = num_nonzeros * num_dim;
// restore flatten idx to indices
FlattenIdxtoRealIdxKernelFunctor kfn(
N,
num_dim,
num_nonzeros,
out_begin,
idx_flat_begin,
divisor,
sizes);

const auto wg_sz = std::min(syclMaxWorkGroupSize(kfn), N);
const auto num_wg = (N + wg_sz - 1) / wg_sz;

sycl_kernel_submit(wg_sz * num_wg, wg_sz, getCurrentSYCLQueue(), kfn);
}
if (need_to_copy) {
out.copy_(out_.t());
} else {
tensor = tensor.resize_({num_dim, N}).contiguous().t();
// transpose out so it is correct size
Tensor out_temp = out_.t();
out.set_(out_temp);
}
}
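
A short sketch of the output-layout handling in `nonzero_template` (an illustrative assumption, not code from the PR): coordinates are written dimension-major into a `{num_dim, num_nonzeros}` buffer so each dimension's values are contiguous for the kernel, and the caller receives that buffer's transpose, the `{num_nonzeros, num_dim}` shape `nonzero` is expected to return.

```cpp
#include <ATen/ATen.h>

int main() {
  const int64_t num_dim = 2, num_nonzeros = 3;
  // Stand-in for the kernel's output: one row of coordinates per dimension.
  at::Tensor buf = at::arange(num_dim * num_nonzeros, at::kLong)
                       .reshape({num_dim, num_nonzeros});
  at::Tensor out = buf.t();  // {num_nonzeros, num_dim}; row i is one index tuple
  // `out` is a view aliasing `buf`, so the common path can hand it back via
  // set_(); the need_to_copy branch instead materializes a copy when the caller
  // passed an `out` of the final shape whose transpose is not contiguous.
  TORCH_CHECK(out.size(0) == num_nonzeros && out.size(1) == num_dim);
  return 0;
}
```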
