Skip to content

Commit a5671d2

Browse files
jianyizh and Copilot
authored
optimize adaptive avg pool (#2012)
follows #1883, shape [4096,256,6,6] channel last with output shape [6,6] in torchbench alexnet can get ~4x improvement on bmg --------- Co-authored-by: Copilot <[email protected]>
1 parent fa9212b commit a5671d2

File tree

2 files changed

+223
-3
lines changed

2 files changed

+223
-3
lines changed

src/ATen/native/xpu/sycl/AdaptiveAveragePooling2dKernels.cpp

Lines changed: 222 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <ATen/native/AdaptivePooling.h>
55
#include <ATen/native/Pool.h>
66
#include <ATen/native/xpu/sycl/LaunchUtils.h>
7+
#include <ATen/native/xpu/sycl/MemoryAccess.h>
78
#include <comm/MemoryFormat.h>
89
#include <comm/xpu_aten.h>
910
#include <vector>
@@ -627,6 +628,220 @@ struct AdaptiveAvgPool2dKernelFunctor {
627628
PackedTensorAccessor64<scalar_t, 4> output_;
628629
};
629630

631+
template <typename scalar_t, typename opmath_t, typename vec_t, int vec_size>
632+
struct AdaptiveAvgPool2dKernelFunctor_cl {
633+
void operator()(sycl::nd_item<1> item) const {
634+
int64_t index = item.get_global_linear_id();
635+
if (index < numel_) {
636+
int _ow, _oh, _oc, _ob;
637+
int oc_vec_ = oc_ / vec_size;
638+
639+
_oc = index % oc_vec_;
640+
_ow = index / oc_vec_ % ow_;
641+
_oh = index / oc_vec_ / ow_ % oh_;
642+
_ob = index / oc_vec_ / ow_ / oh_;
643+
644+
int64_t _ih0 = native::start_index(_oh, oh_, ih_);
645+
int64_t _ih1 = native::end_index(_oh, oh_, ih_);
646+
int64_t _iw0 = native::start_index(_ow, ow_, iw_);
647+
int64_t _iw1 = native::end_index(_ow, ow_, iw_);
648+
int64_t kh = _ih1 - _ih0;
649+
int64_t kw = _iw1 - _iw0;
650+
int64_t _ib = _ob;
651+
int64_t _ic = _oc;
652+
653+
opmath_t sum[vec_size] = {static_cast<opmath_t>(0)};
654+
for (int _ih = _ih0; _ih < _ih1; _ih++) {
655+
for (int _iw = _iw0; _iw < _iw1; _iw++) {
656+
auto read = input_
657+
[_ic + _iw * oc_vec_ + _ih * oc_vec_ * iw_ +
658+
_ib * ih_ * iw_ * oc_vec_];
659+
#pragma unroll
660+
for (int v = 0; v < vec_size; v++) {
661+
sum[v] += opmath_t(read[v]);
662+
}
663+
}
664+
}
665+
#pragma unroll
666+
for (int v = 0; v < vec_size; v++) {
667+
sum[v] /= kh * kw;
668+
}
669+
vec_t output_value;
670+
#pragma unroll
671+
for (int v = 0; v < vec_size; v++) {
672+
output_value[v] = static_cast<scalar_t>(sum[v]);
673+
}
674+
output_[index] = output_value;
675+
}
676+
}
677+
AdaptiveAvgPool2dKernelFunctor_cl(
678+
vec_t* output,
679+
const vec_t* input,
680+
int ih,
681+
int iw,
682+
int ob,
683+
int oc,
684+
int oh,
685+
int ow,
686+
int64_t numel)
687+
: output_(output),
688+
input_(input),
689+
ih_(ih),
690+
iw_(iw),
691+
ob_(ob),
692+
oc_(oc),
693+
oh_(oh),
694+
ow_(ow),
695+
numel_(numel) {}
696+
697+
private:
698+
int ih_;
699+
int iw_;
700+
int ob_;
701+
int oc_;
702+
int oh_;
703+
int ow_;
704+
int64_t numel_;
705+
const vec_t* input_;
706+
vec_t* output_;
707+
};
708+
709+
// Reinterprets the input/output buffers as aligned vectors of `vec_size`
// scalars and launches the channels-last kernel instantiated for that
// compile-time vec_size. Wrapped in do/while(0) so an invocation followed by
// a semicolon behaves as exactly one statement (safe in if/else branches and
// switch cases).
#define LAUNCH_AVGPOOL_CHANNEL_LAST_VEC( \
    scalar_t, \
    opmath_t, \
    vec_size, \
    num_wg, \
    wg_size, \
    queue, \
    output, \
    input, \
    ih, \
    iw, \
    ob, \
    oc, \
    oh, \
    ow, \
    numel) \
  do { \
    using vec_t = memory::aligned_vector<scalar_t, vec_size>; \
    vec_t* output_vec = \
        reinterpret_cast<vec_t*>(output.mutable_data_ptr<scalar_t>()); \
    const vec_t* input_vec = \
        reinterpret_cast<const vec_t*>(input.const_data_ptr<scalar_t>()); \
    auto kfn = AdaptiveAvgPool2dKernelFunctor_cl< \
        scalar_t, \
        opmath_t, \
        vec_t, \
        vec_size>(output_vec, input_vec, ih, iw, ob, oc, oh, ow, numel); \
    sycl_kernel_submit(num_wg * wg_size, wg_size, queue, kfn); \
  } while (0)
738+
739+
template <typename scalar_t, typename opmath_t>
740+
void launch_adaptive_avg_pool2d_kernel_cl(const Tensor& input, Tensor& output) {
741+
int ih = input.size(2);
742+
int iw = input.size(3);
743+
int ob = output.size(0);
744+
int oc = output.size(1);
745+
int oh = output.size(2);
746+
int ow = output.size(3);
747+
748+
int64_t numel = ob * oc * oh * ow;
749+
int vec_size = 1;
750+
for (vec_size = std::min(
751+
8,
752+
memory::can_vectorize_up_to<scalar_t>(
753+
(char*)output.mutable_data_ptr<scalar_t>()));
754+
vec_size > 1;
755+
vec_size /= 2) {
756+
if (oc % vec_size != 0)
757+
continue;
758+
if (2 * numel / vec_size > syclMaxWorkItemsPerTile()) {
759+
numel /= vec_size;
760+
break;
761+
}
762+
}
763+
764+
auto wg_size = syclDeviceMaxWorkGroupSize();
765+
int64_t num_wg = (numel + wg_size - 1) / wg_size;
766+
switch (vec_size) {
767+
case 8:
768+
LAUNCH_AVGPOOL_CHANNEL_LAST_VEC(
769+
scalar_t,
770+
opmath_t,
771+
8,
772+
num_wg,
773+
wg_size,
774+
at::xpu::getCurrentSYCLQueue(),
775+
output,
776+
input,
777+
ih,
778+
iw,
779+
ob,
780+
oc,
781+
oh,
782+
ow,
783+
numel);
784+
return;
785+
case 4:
786+
LAUNCH_AVGPOOL_CHANNEL_LAST_VEC(
787+
scalar_t,
788+
opmath_t,
789+
4,
790+
num_wg,
791+
wg_size,
792+
at::xpu::getCurrentSYCLQueue(),
793+
output,
794+
input,
795+
ih,
796+
iw,
797+
ob,
798+
oc,
799+
oh,
800+
ow,
801+
numel);
802+
return;
803+
case 2:
804+
LAUNCH_AVGPOOL_CHANNEL_LAST_VEC(
805+
scalar_t,
806+
opmath_t,
807+
2,
808+
num_wg,
809+
wg_size,
810+
at::xpu::getCurrentSYCLQueue(),
811+
output,
812+
input,
813+
ih,
814+
iw,
815+
ob,
816+
oc,
817+
oh,
818+
ow,
819+
numel);
820+
return;
821+
case 1:
822+
LAUNCH_AVGPOOL_CHANNEL_LAST_VEC(
823+
scalar_t,
824+
opmath_t,
825+
1,
826+
num_wg,
827+
wg_size,
828+
at::xpu::getCurrentSYCLQueue(),
829+
output,
830+
input,
831+
ih,
832+
iw,
833+
ob,
834+
oc,
835+
oh,
836+
ow,
837+
numel);
838+
return;
839+
default:
840+
TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size");
841+
}
842+
}
843+
#undef LAUNCH_AVGPOOL_CHANNEL_LAST_VEC
844+
630845
template <typename scalar_t, typename opmath_t, bool is_channels_last>
631846
void launch_adaptive_avg_pool2d_kernel(
632847
PackedTensorAccessor64<const scalar_t, 4> input,
@@ -724,8 +939,13 @@ void adaptive_avg_pool2d_kernel(
724939
auto iacc = input_.packed_accessor64<const scalar_t, 4>();
725940
auto oacc = output.packed_accessor64<scalar_t, 4>();
726941
if (is_smf_channels_last(output)) {
727-
launch_adaptive_avg_pool2d_kernel<scalar_t, opmath_t, true>(
728-
iacc, oacc);
942+
if (input_.is_contiguous(at::MemoryFormat::ChannelsLast)) {
943+
launch_adaptive_avg_pool2d_kernel_cl<scalar_t, opmath_t>(
944+
input_, output);
945+
} else {
946+
launch_adaptive_avg_pool2d_kernel<scalar_t, opmath_t, true>(
947+
iacc, oacc);
948+
}
729949
} else {
730950
launch_adaptive_avg_pool2d_kernel<scalar_t, opmath_t, false>(
731951
iacc, oacc);

src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,7 @@ void launch_max_pool2d_kernel(
545545
if constexpr (is_channels_last) {
546546
for (vec_size =
547547
std::min(8, memory::can_vectorize_up_to<scalar_t>((char*)input));
548-
vec_size >= 1;
548+
vec_size > 1;
549549
vec_size /= 2) {
550550
if (numPlane % vec_size != 0) {
551551
continue;

0 commit comments

Comments
 (0)