diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 50796f78b4..77a321c885 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -1886,6 +1886,14 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle } } + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + if(!(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && + arg.ce_grid_desc_m_n_.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + return true; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 07722155fd..6d3ca9d9dd 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -1417,6 +1417,14 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 return false; } + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + if(!(arg.a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + arg.b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && + arg.c_grid_desc_m_n_.GetElementSpaceSize() * sizeof(CDataType) <= TwoGB)) + { + return false; + } + // Gridwise GEMM size return true; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp index 1c3d8431b8..4fcf717e80 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -1359,6 +1359,14 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 } } + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + if(!(karg.M * karg.K * sizeof(ADataType) <= TwoGB && + karg.N * karg.K * sizeof(BDataType) <= TwoGB && + karg.M * karg.N * sizeof(CDataType) <= TwoGB)) + { + return false; + } + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) return true; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index cb841c36ea..422b9afa61 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -581,6 +581,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight return false; } + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_b_k0_m_k1_grid_desc.GetElementSpaceSize() * sizeof(FloatA) <= TwoGB && + b_b_k0_n_k1_grid_desc.GetElementSpaceSize() * sizeof(FloatB) <= TwoGB && + c_m_n_grid_desc.GetElementSpaceSize() * sizeof(FloatC) <= TwoGB)) + { + return false; + } + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) return true; }