-
Notifications
You must be signed in to change notification settings - Fork 69
Closed
Copy link
Description
Follow-up issue from #4742 (#4742 (comment))
Repro:
// Repro: when the extra add(tv0, tv1) output is present (MAP set), the
// logical and loop domains of the scatter output tv4 become mapped, and the
// __syncthreads() calls required around the shared-memory scatter are no
// longer inserted by the lowering.
TEST_F(ScatterTest, MappedLogicalAndLoop) {
auto fusion_ptr = std::make_unique<Fusion>();
Fusion& fusion = *fusion_ptr.get();
FusionGuard fg(&fusion);
// Problem size: every tensor here is 1-D of extent m.
const int64_t m = 8;
auto tv0 = makeContigConcreteTensor({m}, DataType::Int);
fusion.addInput(tv0);
auto tv1 = makeContigConcreteTensor({m}, DataType::Int);
fusion.addInput(tv1);
auto tv2 = set(tv1);
// tv3 = [0, 1, ..., 7]; used as the source values of the scatter.
auto tv3 = arange(IrBuilder::create<Val>(8));
// Scatter tv3 into tv2 along dim 0 at the positions given by tv0.
auto tv4 = scatter(tv2, 0, tv0, tv3);
auto tv5 = set(tv4);
fusion.addOutput(tv5);
// Maps the iter domains of tv0 and tv1, which in turn maps the loop
// domain of tv4 with its logical domain
if (getenv("MAP")) {
auto tv6 = add(tv0, tv1);
fusion.addOutput(tv6);
}
// Parallelize the sole axis of every tensor over threadIdx.x.
for (auto tv : fusion.allTvs()) {
tv->axis(0)->parallelize(ParallelType::TIDx);
}
// Place the scatter input copy and the scatter output in shared memory;
// the scattered writes cross thread boundaries, so block-level syncs are
// needed between producer and consumer (this is what goes missing).
tv2->setMemoryType(MemoryType::Shared);
tv2->setAllocationDomain(tv2->getLogicalDomain(), true);
tv4->setMemoryType(MemoryType::Shared);
tv4->setAllocationDomain(tv4->getLogicalDomain(), true);
auto options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
// randperm yields unique indices, so the scatter is a permutation write.
auto t0 = at::randperm(m, options);
auto t1 = at::zeros({m}, options);
KernelExecutor ke;
ke.compile(&fusion, {t0, t1});
auto outputs = ke.run({t0, t1});
testValidate(&fusion, outputs, {t0, t1}, __LINE__, __FILE__);
}
If tv0 and tv1 are used together (the MAP branch above), the logical and loop
domains of tv4 end up mapped, which means no sync is inserted.
Without mapping:
// Correct kernel generated WITHOUT the mapping: note the two __syncthreads()
// calls guarding the shared-memory scatter buffer between producer,
// scattered write, and consumer.
__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<int64_t, 1, 1> T0, Tensor<int64_t, 1, 1> T1, Tensor<int64_t, 1, 1> T5) {
alignas(16) extern __shared__ char array[];
const unsigned smem_offset = 0;
// T2 (the copy of input T1) lives in dynamic shared memory.
int64_t* T2 = reinterpret_cast<int64_t*>(array + smem_offset + 0);
// Alias Allocation - shared
auto& T4 = T2;
int64_t i0;
i0 = 8LL - 0LL;
int64_t i1;
i1 = abs(i0);
int64_t i2;
i2 = abs(1LL);
int64_t i3;
i3 = ceilDiv(i1, i2);
// Each thread copies its own element of T1 into shared memory.
T2[((nvfuser_index_t)threadIdx.x)]
= T1[((nvfuser_index_t)threadIdx.x)];
Array<int64_t, 1, 1> T3;
T3[0] = ((nvfuser_index_t)threadIdx.x);
// Sync: the scatter below may write to a slot another thread populated.
__syncthreads();
// Scatter: the destination slot comes from the index tensor T0.
T4[__to_index(T0[((nvfuser_index_t)threadIdx.x)])] = T3[0];
// Sync: make the scattered writes visible before reading T4 back.
__syncthreads();
T5[((nvfuser_index_t)threadIdx.x)]
= T4[((nvfuser_index_t)threadIdx.x)];
}
With mapping:
// Codegen generated code
// Buggy kernel generated WITH the mapping: no __syncthreads() is emitted at
// all, and the scatter destination index below looks wrong as well.
__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<int64_t, 1, 1> T0, Tensor<int64_t, 1, 1> T1, Tensor<int64_t, 1, 1> T5, Tensor<int64_t, 1, 1> T6) {
alignas(16) extern __shared__ char array[];
const unsigned smem_offset = 0;
int64_t* T2 = reinterpret_cast<int64_t*>(array + smem_offset + 0);
// Alias Allocation - shared
auto& T4 = T2;
int64_t i0;
i0 = 8LL - 0LL;
int64_t i1;
i1 = abs(i0);
int64_t i2;
i2 = abs(1LL);
int64_t i3;
i3 = ceilDiv(i1, i2);
// Extra output T6 = T0 + T1 from the MAP branch; its presence triggers the
// domain mapping that suppresses the syncs.
T6[((nvfuser_index_t)threadIdx.x)]
= T0[((nvfuser_index_t)threadIdx.x)]
+ T1[((nvfuser_index_t)threadIdx.x)];
T2[((nvfuser_index_t)threadIdx.x)]
= T1[((nvfuser_index_t)threadIdx.x)];
Array<int64_t, 1, 1> T3;
T3[0] = ((nvfuser_index_t)threadIdx.x);
// BUG: no __syncthreads() here, and the index is threadIdx.x rather than
// T0[threadIdx.x] as in the correct kernel — root cause not yet known.
T4[((nvfuser_index_t)threadIdx.x)] = T3[0];
// BUG: no __syncthreads() before the consumer read either.
T5[((nvfuser_index_t)threadIdx.x)]
= T4[((nvfuser_index_t)threadIdx.x)];
}
Notice the lack of __syncthreads() calls. The indexing of T4 is also wrong —
it uses threadIdx.x directly instead of the scatter index from T0. Not sure
why yet:
T4[((nvfuser_index_t)threadIdx.x)] = T3[0];