@@ -134,16 +134,17 @@ func.func @reduce_multiple_results_unreduced(
134
134
%arg0: tensor <2 x64 x13 xf32 > {sdy.sharding = #sdy.sharding <@mesh , [{" x" , " y" }, {}, {}]>},
135
135
%arg1: tensor <2 x64 x13 xi32 > {sdy.sharding = #sdy.sharding <@mesh , [{" x" , " y" }, {}, {}]>})
136
136
-> (tensor <64 xf32 > {sdy.sharding = #sdy.sharding <@mesh , [{}], unreduced ={" x" }>},
137
- tensor <64 xi32 > {sdy.sharding = #sdy.sharding <@mesh , [{}], unreduced ={" y " }>}) {
137
+ tensor <64 xi32 > {sdy.sharding = #sdy.sharding <@mesh , [{}], unreduced ={" x " :( 1 ) 2 }>}) {
138
138
%0 = stablehlo.constant dense <0.000000e+00 > : tensor <f32 >
139
139
%1 = stablehlo.constant dense <0 > : tensor <i32 >
140
140
// CHECK: %[[REDUCE:.*]]:2 = stablehlo.reduce(%arg0 init: %cst), (%arg1 init: %c) across dimensions = [0, 2]
141
- // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}], unreduced={"x"}>, <@mesh, [{}], unreduced={"y "}>]>}
141
+ // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}], unreduced={"x"}>, <@mesh, [{}], unreduced={"x "}>]>}
142
142
// CHECK: %[[ALL_REDUCE1:.*]] = sdy.all_reduce {"y"} %[[REDUCE]]#0 out_sharding=<@mesh, [{}], unreduced={"x"}> : tensor<64xf32>
143
- // CHECK-NEXT: %[[ALL_REDUCE2:.*]] = sdy.all_reduce {"x"} %[[REDUCE]]#1 out_sharding=<@mesh, [{}], unreduced={"y"}> : tensor<64xi32>
144
- // CHECK-NEXT: return %[[ALL_REDUCE1]], %[[ALL_REDUCE2]] : tensor<64xf32>, tensor<64xi32>
143
+ // CHECK-NEXT: %[[ALL_REDUCE2:.*]] = sdy.all_reduce {"y"} %[[REDUCE]]#1 out_sharding=<@mesh, [{}], unreduced={"x"}> : tensor<64xi32>
144
+ // CHECK-NEXT: %[[ALL_REDUCE3:.*]] = sdy.all_reduce {"x":(2)2} %[[ALL_REDUCE2]] out_sharding=<@mesh, [{}], unreduced={"x":(1)2}> : tensor<64xi32>
145
+ // CHECK-NEXT: return %[[ALL_REDUCE1]], %[[ALL_REDUCE3]] : tensor<64xf32>, tensor<64xi32>
145
146
%2:2 = stablehlo.reduce (%arg0 init : %0 ), (%arg1 init : %1 ) across dimensions = [0 , 2 ]
146
- {sdy.sharding = #sdy.sharding_per_value <[<@mesh , [{}], unreduced ={" x" }>, <@mesh , [{}], unreduced ={" y " }>]>} :
147
+ {sdy.sharding = #sdy.sharding_per_value <[<@mesh , [{}], unreduced ={" x" }>, <@mesh , [{}], unreduced ={" x " }>]>} :
147
148
(tensor <2 x64 x13 xf32 >, tensor <2 x64 x13 xi32 >, tensor <f32 >, tensor <i32 >) -> (tensor <64 xf32 >, tensor <64 xi32 >)
148
149
reducer (%arg2: tensor <f32 >, %arg4: tensor <f32 >) (%arg3: tensor <i32 >, %arg5: tensor <i32 >) {
149
150
%3 = stablehlo.add %arg2 , %arg4 : tensor <f32 >
0 commit comments