4 | 4 | #include <ATen/native/AdaptivePooling.h>
5 | 5 | #include <ATen/native/Pool.h>
6 | 6 | #include <ATen/native/xpu/sycl/LaunchUtils.h>
| 7 | +#include <ATen/native/xpu/sycl/MemoryAccess.h> |
7 | 8 | #include <comm/MemoryFormat.h>
8 | 9 | #include <comm/xpu_aten.h>
9 | 10 | #include <vector>
@@ -627,6 +628,220 @@ struct AdaptiveAvgPool2dKernelFunctor {
627 | 628 | PackedTensorAccessor64<scalar_t, 4> output_;
628 | 629 | };
629 | 630 |
| 631 | +template <typename scalar_t, typename opmath_t, typename vec_t, int vec_size> |
| 632 | +struct AdaptiveAvgPool2dKernelFunctor_cl { |
| 633 | + void operator()(sycl::nd_item<1> item) const { |
| 634 | + int64_t index = item.get_global_linear_id(); |
| 635 | + if (index < numel_) { |
| 636 | + int _ow, _oh, _oc, _ob; |
| 637 | + int oc_vec_ = oc_ / vec_size; |
| 638 | + |
| 639 | + _oc = index % oc_vec_; |
| 640 | + _ow = index / oc_vec_ % ow_; |
| 641 | + _oh = index / oc_vec_ / ow_ % oh_; |
| 642 | + _ob = index / oc_vec_ / ow_ / oh_; |
| 643 | + |
| 644 | + int64_t _ih0 = native::start_index(_oh, oh_, ih_); |
| 645 | + int64_t _ih1 = native::end_index(_oh, oh_, ih_); |
| 646 | + int64_t _iw0 = native::start_index(_ow, ow_, iw_); |
| 647 | + int64_t _iw1 = native::end_index(_ow, ow_, iw_); |
| 648 | + int64_t kh = _ih1 - _ih0; |
| 649 | + int64_t kw = _iw1 - _iw0; |
| 650 | + int64_t _ib = _ob; |
| 651 | + int64_t _ic = _oc; |
| 652 | + |
| 653 | + opmath_t sum[vec_size] = {static_cast<opmath_t>(0)}; |
| 654 | + for (int _ih = _ih0; _ih < _ih1; _ih++) { |
| 655 | + for (int _iw = _iw0; _iw < _iw1; _iw++) { |
| 656 | + auto read = input_ |
| 657 | + [_ic + _iw * oc_vec_ + _ih * oc_vec_ * iw_ + |
| 658 | + _ib * ih_ * iw_ * oc_vec_]; |
| 659 | +#pragma unroll |
| 660 | + for (int v = 0; v < vec_size; v++) { |
| 661 | + sum[v] += opmath_t(read[v]); |
| 662 | + } |
| 663 | + } |
| 664 | + } |
| 665 | +#pragma unroll |
| 666 | + for (int v = 0; v < vec_size; v++) { |
| 667 | + sum[v] /= kh * kw; |
| 668 | + } |
| 669 | + vec_t output_value; |
| 670 | +#pragma unroll |
| 671 | + for (int v = 0; v < vec_size; v++) { |
| 672 | + output_value[v] = static_cast<scalar_t>(sum[v]); |
| 673 | + } |
| 674 | + output_[index] = output_value; |
| 675 | + } |
| 676 | + } |
| 677 | + AdaptiveAvgPool2dKernelFunctor_cl( |
| 678 | + vec_t* output, |
| 679 | + const vec_t* input, |
| 680 | + int ih, |
| 681 | + int iw, |
| 682 | + int ob, |
| 683 | + int oc, |
| 684 | + int oh, |
| 685 | + int ow, |
| 686 | + int64_t numel) |
| 687 | + : output_(output), |
| 688 | + input_(input), |
| 689 | + ih_(ih), |
| 690 | + iw_(iw), |
| 691 | + ob_(ob), |
| 692 | + oc_(oc), |
| 693 | + oh_(oh), |
| 694 | + ow_(ow), |
| 695 | + numel_(numel) {} |
| 696 | + |
| 697 | + private: |
| 698 | + int ih_; |
| 699 | + int iw_; |
| 700 | + int ob_; |
| 701 | + int oc_; |
| 702 | + int oh_; |
| 703 | + int ow_; |
| 704 | + int64_t numel_; |
| 705 | + const vec_t* input_; |
| 706 | + vec_t* output_; |
| 707 | +}; |
| 708 | + |
| 709 | +#define LAUNCH_AVGPOOL_CHANNEL_LAST_VEC( \ |
| 710 | + scalar_t, \ |
| 711 | + opmath_t, \ |
| 712 | + vec_size, \ |
| 713 | + num_wg, \ |
| 714 | + wg_size, \ |
| 715 | + queue, \ |
| 716 | + output, \ |
| 717 | + input, \ |
| 718 | + ih, \ |
| 719 | + iw, \ |
| 720 | + ob, \ |
| 721 | + oc, \ |
| 722 | + oh, \ |
| 723 | + ow, \ |
| 724 | + numel) \ |
| 725 | + { \ |
| 726 | + using vec_t = memory::aligned_vector<scalar_t, vec_size>; \ |
| 727 | + vec_t* output_vec = \ |
| 728 | + reinterpret_cast<vec_t*>(output.mutable_data_ptr<scalar_t>()); \ |
| 729 | + const vec_t* input_vec = \ |
| 730 | + reinterpret_cast<const vec_t*>(input.const_data_ptr<scalar_t>()); \ |
| 731 | + auto kfn = AdaptiveAvgPool2dKernelFunctor_cl< \ |
| 732 | + scalar_t, \ |
| 733 | + opmath_t, \ |
| 734 | + vec_t, \ |
| 735 | + vec_size>(output_vec, input_vec, ih, iw, ob, oc, oh, ow, numel); \ |
| 736 | + sycl_kernel_submit(num_wg* wg_size, wg_size, queue, kfn); \ |
| 737 | + } |
| 738 | + |
| 739 | +template <typename scalar_t, typename opmath_t> |
| 740 | +void launch_adaptive_avg_pool2d_kernel_cl(const Tensor& input, Tensor& output) { |
| 741 | + int ih = input.size(2); |
| 742 | + int iw = input.size(3); |
| 743 | + int ob = output.size(0); |
| 744 | + int oc = output.size(1); |
| 745 | + int oh = output.size(2); |
| 746 | + int ow = output.size(3); |
| 747 | + |
| 748 | + int64_t numel = ob * oc * oh * ow; |
| 749 | + int vec_size = 1; |
| 750 | + for (vec_size = std::min( |
| 751 | + 8, |
| 752 | + memory::can_vectorize_up_to<scalar_t>( |
| 753 | + (char*)output.mutable_data_ptr<scalar_t>())); |
| 754 | + vec_size > 1; |
| 755 | + vec_size /= 2) { |
| 756 | + if (oc % vec_size != 0) |
| 757 | + continue; |
| 758 | + if (2 * numel / vec_size > syclMaxWorkItemsPerTile()) { |
| 759 | + numel /= vec_size; |
| 760 | + break; |
| 761 | + } |
| 762 | + } |
| 763 | + |
| 764 | + auto wg_size = syclDeviceMaxWorkGroupSize(); |
| 765 | + int64_t num_wg = (numel + wg_size - 1) / wg_size; |
| 766 | + switch (vec_size) { |
| 767 | + case 8: |
| 768 | + LAUNCH_AVGPOOL_CHANNEL_LAST_VEC( |
| 769 | + scalar_t, |
| 770 | + opmath_t, |
| 771 | + 8, |
| 772 | + num_wg, |
| 773 | + wg_size, |
| 774 | + at::xpu::getCurrentSYCLQueue(), |
| 775 | + output, |
| 776 | + input, |
| 777 | + ih, |
| 778 | + iw, |
| 779 | + ob, |
| 780 | + oc, |
| 781 | + oh, |
| 782 | + ow, |
| 783 | + numel); |
| 784 | + return; |
| 785 | + case 4: |
| 786 | + LAUNCH_AVGPOOL_CHANNEL_LAST_VEC( |
| 787 | + scalar_t, |
| 788 | + opmath_t, |
| 789 | + 4, |
| 790 | + num_wg, |
| 791 | + wg_size, |
| 792 | + at::xpu::getCurrentSYCLQueue(), |
| 793 | + output, |
| 794 | + input, |
| 795 | + ih, |
| 796 | + iw, |
| 797 | + ob, |
| 798 | + oc, |
| 799 | + oh, |
| 800 | + ow, |
| 801 | + numel); |
| 802 | + return; |
| 803 | + case 2: |
| 804 | + LAUNCH_AVGPOOL_CHANNEL_LAST_VEC( |
| 805 | + scalar_t, |
| 806 | + opmath_t, |
| 807 | + 2, |
| 808 | + num_wg, |
| 809 | + wg_size, |
| 810 | + at::xpu::getCurrentSYCLQueue(), |
| 811 | + output, |
| 812 | + input, |
| 813 | + ih, |
| 814 | + iw, |
| 815 | + ob, |
| 816 | + oc, |
| 817 | + oh, |
| 818 | + ow, |
| 819 | + numel); |
| 820 | + return; |
| 821 | + case 1: |
| 822 | + LAUNCH_AVGPOOL_CHANNEL_LAST_VEC( |
| 823 | + scalar_t, |
| 824 | + opmath_t, |
| 825 | + 1, |
| 826 | + num_wg, |
| 827 | + wg_size, |
| 828 | + at::xpu::getCurrentSYCLQueue(), |
| 829 | + output, |
| 830 | + input, |
| 831 | + ih, |
| 832 | + iw, |
| 833 | + ob, |
| 834 | + oc, |
| 835 | + oh, |
| 836 | + ow, |
| 837 | + numel); |
| 838 | + return; |
| 839 | + default: |
| 840 | + TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size"); |
| 841 | + } |
| 842 | +} |
| 843 | +#undef LAUNCH_AVGPOOL_CHANNEL_LAST_VEC |
| 844 | + |
630 | 845 | template <typename scalar_t, typename opmath_t, bool is_channels_last>
631 | 846 | void launch_adaptive_avg_pool2d_kernel(
632 | 847 | PackedTensorAccessor64<const scalar_t, 4> input,
@@ -724,8 +939,13 @@ void adaptive_avg_pool2d_kernel(
724 | 939 | auto iacc = input_.packed_accessor64<const scalar_t, 4>();
725 | 940 | auto oacc = output.packed_accessor64<scalar_t, 4>();
726 | 941 | if (is_smf_channels_last(output)) {
727 |     | -    launch_adaptive_avg_pool2d_kernel<scalar_t, opmath_t, true>( |
728 |     | -        iacc, oacc); |
| 942 | + if (input_.is_contiguous(at::MemoryFormat::ChannelsLast)) { |
| 943 | + launch_adaptive_avg_pool2d_kernel_cl<scalar_t, opmath_t>( |
| 944 | + input_, output); |
| 945 | + } else { |
| 946 | + launch_adaptive_avg_pool2d_kernel<scalar_t, opmath_t, true>( |
| 947 | + iacc, oacc); |
| 948 | + } |
729 | 949 | } else {
730 | 950 | launch_adaptive_avg_pool2d_kernel<scalar_t, opmath_t, false>(
731 | 951 | iacc, oacc);
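For readers of this diff: the new channels-last path can be summarized as one work item per (batch, output row, output column, channel group), loading `vec_size` contiguous channels at a time from the NHWC input and averaging over the adaptive window given by `start_index`/`end_index`. The plain C++ sketch below is an illustrative reference only, not part of the patch; `start_idx`, `end_idx`, and `adaptive_avg_pool2d_nhwc_ref` are names invented for this sketch, and the bound helpers assume the usual floor/ceil adaptive-pooling window formulas.

```cpp
#include <cstdint>
using std::int64_t;

// Adaptive-pooling window bounds; for non-negative arguments these match
// what native::start_index / native::end_index compute.
static int64_t start_idx(int64_t o, int64_t osize, int64_t isize) {
  return (o * isize) / osize; // floor(o * isize / osize)
}
static int64_t end_idx(int64_t o, int64_t osize, int64_t isize) {
  return ((o + 1) * isize + osize - 1) / osize; // ceil((o + 1) * isize / osize)
}

// Host-side reference of the channels-last kernel: one "work item" per
// (batch, oh, ow, channel-group) output, with vec_size contiguous channels
// accumulated together, mirroring the flat-index decomposition and NHWC
// offsets used by AdaptiveAvgPool2dKernelFunctor_cl.
template <typename scalar_t, typename opmath_t, int vec_size>
void adaptive_avg_pool2d_nhwc_ref(
    scalar_t* out,
    const scalar_t* in,
    int64_t nb,
    int64_t nc, // must be divisible by vec_size, as the launcher guarantees
    int64_t ih,
    int64_t iw,
    int64_t oh,
    int64_t ow) {
  const int64_t c_groups = nc / vec_size;
  const int64_t numel = nb * c_groups * oh * ow;
  for (int64_t index = 0; index < numel; ++index) {
    // Channel group varies fastest, then output width, height, batch.
    const int64_t oc = index % c_groups;
    const int64_t x = index / c_groups % ow;
    const int64_t y = index / c_groups / ow % oh;
    const int64_t b = index / c_groups / ow / oh;

    const int64_t ih0 = start_idx(y, oh, ih), ih1 = end_idx(y, oh, ih);
    const int64_t iw0 = start_idx(x, ow, iw), iw1 = end_idx(x, ow, iw);

    opmath_t sum[vec_size] = {}; // zero-initialized accumulators
    for (int64_t r = ih0; r < ih1; ++r) {
      for (int64_t c = iw0; c < iw1; ++c) {
        // NHWC offset of the first channel in this group.
        const int64_t base = ((b * ih + r) * iw + c) * nc + oc * vec_size;
        for (int v = 0; v < vec_size; ++v) {
          sum[v] += static_cast<opmath_t>(in[base + v]);
        }
      }
    }
    const opmath_t area = static_cast<opmath_t>((ih1 - ih0) * (iw1 - iw0));
    const int64_t obase = ((b * oh + y) * ow + x) * nc + oc * vec_size;
    for (int v = 0; v < vec_size; ++v) {
      out[obase + v] = static_cast<scalar_t>(sum[v] / area);
    }
  }
}
```

With NHWC-contiguous buffers and a channel count divisible by `vec_size`, this loop should produce the same values the vectorized kernel writes; the launcher's `vec_size` selection only changes how many channels each work item owns, not the arithmetic.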