@@ -292,10 +292,10 @@ TensorView* takeAlongAxis(TensorView* inp, TensorView* index, int64_t dim) {
292292 return out_tensor->as <TensorView>();
293293}
294294
295- TensorView* groupedBlockSfLayout (
295+ TensorView* preprocessGroupedMatmulInputSf (
296296 TensorView* input,
297- TensorView* expert_offsets ,
298- TensorView* sf_offsets ,
297+ TensorView* input_offsets ,
298+ TensorView* output_offsets ,
299299 BlockScalingFactorLayout layout) {
300300 // only support input matrix;
301301 auto input_logical_dom =
@@ -312,9 +312,6 @@ TensorView* groupedBlockSfLayout(
312312 });
313313
314314 // Create the logical domain of output.
315- // Note: output logical domain handles potential padding required for the
316- // layout. Since the actual padding size is data-dependent, we allocate for
317- // the maximum padding (reflected on logical/allocation domain).
318315 std::vector<IterDomain*> out_logical;
319316 out_logical.reserve (input_logical_dom.size ());
320317
@@ -325,11 +322,15 @@ TensorView* groupedBlockSfLayout(
325322
326323 auto * one_val = input->fusion ()->oneVal (DataType::Index);
327324 std::vector<IterDomain*> offset_logical_dom =
328- TensorDomain::noReductions (expert_offsets ->getLogicalDomain ());
325+ TensorDomain::noReductions (input_offsets ->getLogicalDomain ());
329326 Val* num_groups =
330327 SimplifyingIrBuilder::subExpr (offset_logical_dom[0 ]->extent (), one_val);
331- // padded row size:
332- // num_groups * (row_multiple - 1) + row_size
328+
329+ // Note: output logical domain handles potential padding required for the
330+ // layout. Since the actual padding size is data-dependent, we allocate for
331+ // the maximum padding (reflected on logical/allocation domain).
332+
333+ // pad row size: num_groups * (row_multiple - 1) + row_size
333334 auto pad_to_max_extent = [&](IterDomain* id, int multiple) -> IterDomain* {
334335 auto * maximum_pad_value_per_group =
335336 IrBuilder::create<Val>(multiple - 1 , DataType::Index);
@@ -340,8 +341,7 @@ TensorView* groupedBlockSfLayout(
340341 };
341342 out_logical.push_back (pad_to_max_extent (out_root[0 ], row_multiple));
342343
343- // padded col size:
344- // (col_size + col_multiple - 1) / col_multiple * col_multiple
344+ // pad col size: (col_size + col_multiple - 1) / col_multiple * col_multiple
345345 auto pad_to_multiple = [&](IterDomain* id, int multiple) -> IterDomain* {
346346 Val* ext = id->extent ();
347347 auto * multiple_val = IrBuilder::create<Val>(multiple, DataType::Index);
@@ -370,11 +370,11 @@ TensorView* groupedBlockSfLayout(
370370 /* skip_checks=*/ true ),
371371 input->getDataType ().value ());
372372
373- IrBuilder::create<GroupedBlockScalingFactorLayoutOp >(
373+ IrBuilder::create<PreprocessGroupedMatmulInputSf >(
374374 out_tv,
375375 input,
376- expert_offsets ,
377- sf_offsets ,
376+ input_offsets ,
377+ output_offsets ,
378378 layout,
379379 input_logical_dom[1 ]->getMaybeExpandedExtent (),
380380 num_groups);
0 commit comments