@@ -11793,5 +11793,79 @@ ENTRY entry_computation {
1179311793 }
1179411794}
1179511795
11796+ TEST_F (ShardingPropagationTest, ShardAsWithShardBarrier) {
11797+ const char * const hlo_string = R"(
11798+ HloModule pjit_f
11799+
11800+ ENTRY main.11 {
11801+ Arg_0.1 = bf16[384,1408]{1,0} parameter(0), sharding={devices=[1,16,512]<=[8,16,64]T(1,0,2) last_tile_dim_replicate}
11802+ broadcast.4 = bf16[8,384,1408]{2,1,0} broadcast(Arg_0.1), dimensions={1,2}
11803+ custom-call.5 = bf16[8,384,1408]{2,1,0} custom-call(broadcast.4), custom_call_target="Sharding", custom_call_has_side_effect=true, sharding={unknown shard_as 1}
11804+ broadcast.2 = bf16[8,384,1408]{2,1,0} broadcast(Arg_0.1), dimensions={1,2}
11805+ custom-call.3 = bf16[8,384,1408]{2,1,0} custom-call(broadcast.2), custom_call_target="Sharding", sharding={devices=[8,1,1,1024]<=[8192] last_tile_dim_replicate}, backend_config="unspecified_dims=[1,2]"
11806+ custom-call.6 = bf16[8,384,1408]{2,1,0} custom-call(custom-call.3), custom_call_target="Sharding", custom_call_has_side_effect=true, sharding={unknown shard_as 1}
11807+ %shard-barrier-to = bf16[8,384,1408]{2,1,0} custom-call(%custom-call.6), custom_call_target="ShardBarrierTo", custom_call_has_side_effect=true
11808+ slice.7 = bf16[1,384,1408]{2,1,0} slice(shard-barrier-to), slice={[1:2], [0:384], [0:1408]}
11809+ reshape.8 = bf16[384,1408]{1,0} reshape(slice.7)
11810+ tuple.9 = (bf16[384,1408]{1,0}) tuple(reshape.8)
11811+ get-tuple-element.10 = bf16[384,1408]{1,0} get-tuple-element(tuple.9), index=0, sharding={devices=[16,1,512]<=[8,16,64]T(1,0,2) last_tile_dim_replicate}
11812+ ROOT tuple.13 = (bf16[384,1408]{1,0}, bf16[8,384,1408]{2,1,0}) tuple(get-tuple-element.10, custom-call.5)
11813+ })" ;
11814+ TF_ASSERT_OK_AND_ASSIGN (auto module ,
11815+ ParseAndReturnVerifiedModule (hlo_string));
11816+ TF_ASSERT_OK_AND_ASSIGN (
11817+ bool changed,
11818+ ShardingPropagation (
11819+ /* is_spmd=*/ true , /* propagate_metadata=*/ true ,
11820+ /* allow_spmd_sharding_propagation_to_output=*/ {true },
11821+ /* allow_spmd_sharding_propagation_to_parameters=*/ {false , false })
11822+ .Run (module .get ()));
11823+ EXPECT_TRUE (changed);
11824+
11825+ XLA_VLOG_LINES (1 , module ->ToString ());
11826+ auto * broadcast_4 = FindInstruction (module .get (), " broadcast.4" );
11827+ ASSERT_NE (broadcast_4, nullptr );
11828+ EXPECT_THAT (
11829+ broadcast_4,
11830+ op::Sharding (" {devices=[8,1,16,64]<=[8192] last_tile_dim_replicate}" ));
11831+ auto * copy = FindInstruction (module .get (), " copy" );
11832+ ASSERT_NE (copy, nullptr );
11833+ EXPECT_THAT (
11834+ copy,
11835+ op::Sharding (" {devices=[8,1,16,64]<=[8192] last_tile_dim_replicate}" ));
11836+ }
11837+
11838+ TEST_F (ShardingPropagationTest, ShardAsWithShardBarrier2) {
11839+ const char * const hlo_string = R"(
11840+ HloModule module
11841+ ENTRY %elementwise {
11842+ %param0 = f32[5,7,11,13]{3,2,1,0} parameter(0)
11843+ %custom-call.0 = f32[5,7,11,13]{3,2,1,0} custom-call(param0), custom_call_target="Sharding", sharding={devices=[2,1,1,1,4]<=[8] last_tile_dim_replicate}, backend_config="unspecified_dims=[1,2,3]"
11844+ %shard-barrier-from = f32[5,7,11,13]{3,2,1,0} custom-call(%custom-call.0), custom_call_target="ShardBarrierFrom", custom_call_has_side_effect=true
11845+ %custom-call.2 = f32[5,7,11,13]{3,2,1,0} custom-call(shard-barrier-from), custom_call_target="Sharding", custom_call_has_side_effect=true, sharding={unknown shard_as 1}
11846+ %param1 = f32[5,7,11,13]{3,2,1,0} parameter(1)
11847+ %custom-call.1 = f32[5,7,11,13]{3,2,1,0} custom-call(param1), custom_call_target="Sharding", sharding={devices=[1,2,2,1,2]<=[2,4]T(1,0) last_tile_dim_replicate}, backend_config="unspecified_dims=[0]"
11848+ %custom-call.3 = f32[5,7,11,13]{3,2,1,0} custom-call(custom-call.1), custom_call_target="Sharding", custom_call_has_side_effect=true, sharding={unknown shard_as 1}
11849+ ROOT %tuple = (f32[5,7,11,13]{3,2,1,0}, f32[5,7,11,13]{3,2,1,0}) tuple(%custom-call.0, %custom-call.3)
11850+ })" ;
11851+ TF_ASSERT_OK_AND_ASSIGN (auto module ,
11852+ ParseAndReturnVerifiedModule (hlo_string));
11853+ TF_ASSERT_OK_AND_ASSIGN (
11854+ bool changed,
11855+ ShardingPropagation (
11856+ /* is_spmd=*/ true , /* propagate_metadata=*/ true ,
11857+ /* allow_spmd_sharding_propagation_to_output=*/ {true },
11858+ /* allow_spmd_sharding_propagation_to_parameters=*/ {false , false })
11859+ .Run (module .get ()));
11860+ EXPECT_TRUE (changed);
11861+
11862+ XLA_VLOG_LINES (1 , module ->ToString ());
11863+ EXPECT_THAT (
11864+ module ->entry_computation ()->root_instruction (),
11865+ op::Sharding (
11866+ " {{devices=[2,2,2,1]<=[8]}, {devices=[1,2,2,1,2]<=[2,4]T(1,0) "
11867+ " last_tile_dim_replicate}}" ));
11868+ }
11869+
1179611870} // namespace
1179711871} // namespace xla
0 commit comments