diff --git a/src/cpu/x64/jit_brgemm_inner_product_utils.cpp b/src/cpu/x64/jit_brgemm_inner_product_utils.cpp index 34619786d7e..6cb70049041 100644 --- a/src/cpu/x64/jit_brgemm_inner_product_utils.cpp +++ b/src/cpu/x64/jit_brgemm_inner_product_utils.cpp @@ -1437,7 +1437,11 @@ status_t jit_brgemm_ip_conf_t::init_conf_base(cpu_isa_t isa, if (jbgp.weights_decompression) { jbgp.src_quant_group_size = jbgp.ic; jbgp.src_sum_group_size = jbgp.ic; + const size_t simd_width = 16; if (!attr.src_dyn_quant_params_.has_default_values()) { + if (jbgp.ic < static_cast(simd_width)) { + return status::unimplemented; + } jbgp.with_src_dynamic_quant = true; } @@ -1490,15 +1494,18 @@ status_t jit_brgemm_ip_conf_t::init_conf_base(cpu_isa_t isa, one_of(jbgp.wei_decomp_zero_points_dt, u8, data_type::undef))) return status::unimplemented; - const size_t simd_width = 16; if (jbgp.src_quant_group_size == 0 || jbgp.src_quant_group_size % simd_width) return status::unimplemented; jbgp.orig_src_dt = jbgp.src_dt; jbgp.src_dt = s8; - size_t rd_unroll = jbgp.src_quant_group_size; - jbgp.src_sum_group_size = nstl::min(rd_unroll, min_group_size); + if (jbgp.src_quant_group_size < min_group_size) + min_group_size = jbgp.src_quant_group_size; + jbgp.src_sum_group_size = min_group_size; + + if (jbgp.src_quant_group_size == 0 || jbgp.src_quant_group_size % simd_width) + return status::unimplemented; if (jbgp.wei_scales_ic_group_size != static_cast(jbgp.ic) && jbgp.wei_scales_ic_group_size % jbgp.src_sum_group_size) return status::unimplemented;