@@ -347,11 +347,14 @@ llvm::Value* IRBuilder::create_inbounds_gep(const std::string& var_name, llvm::V
347
347
348
348
// Since we index through the pointer, we need an extra 0 index in the indices list for GEP.
349
349
ValueVector indices{llvm::ConstantInt::get (get_i64_type (), 0 ), index};
350
- return builder.CreateInBoundsGEP (variable_ptr, indices);
350
+ llvm::Type* variable_type = variable_ptr->getType ()->getPointerElementType ();
351
+ return builder.CreateInBoundsGEP (variable_type, variable_ptr, indices);
351
352
}
352
353
353
354
/// Emit an in-bounds GEP that offsets `variable` by the single index `index`.
///
/// \param variable  pointer value used as the GEP base.
/// \param index     the only index passed to the GEP.
/// \return the value produced by `CreateInBoundsGEP`.
llvm::Value* IRBuilder::create_inbounds_gep(llvm::Value* variable, llvm::Value* index) {
    // The explicit-type GEP overload takes the pointee type of the base pointer.
    llvm::Type* pointee_type = variable->getType()->getPointerElementType();

    ValueVector gep_indices{index};
    return builder.CreateInBoundsGEP(pointee_type, variable, gep_indices);
}
356
359
357
360
llvm::Value* IRBuilder::create_index (llvm::Value* value) {
@@ -378,23 +381,25 @@ llvm::Value* IRBuilder::create_index(llvm::Value* value) {
378
381
379
382
/// Load the value pointed to by the variable registered under `name`.
///
/// \param name    identifier resolved to a pointer through `lookup_value()`.
/// \param masked  when true, emit a masked load predicated on the current `mask`.
/// \return the masked load when `masked` is set; otherwise the plain load, which is
///         also pushed onto `value_stack`.
llvm::Value* IRBuilder::create_load(const std::string& name, bool masked) {
    llvm::Value* ptr = lookup_value(name);
    llvm::Type* loaded_type = ptr->getType()->getPointerElementType();

    // Check if the generated IR is vectorized and masked.
    if (masked) {
        // The masked load must be returned here. Without the `return`, its result is
        // discarded and an unmasked load of the same pointer is emitted and returned
        // instead. As before the switch to the typed overload, the masked result is not
        // pushed onto `value_stack`.
        return builder.CreateMaskedLoad(loaded_type, ptr, llvm::Align(), mask);
    }

    llvm::Value* loaded = builder.CreateLoad(loaded_type, ptr);
    value_stack.push_back(loaded);
    return loaded;
}
391
394
392
395
/// Load the value pointed to by `ptr`.
///
/// \param ptr     pointer value to load from; its pointee type is the loaded type.
/// \param masked  when true, emit a masked load predicated on the current `mask`.
/// \return the masked load when `masked` is set; otherwise the plain load, which is
///         also pushed onto `value_stack`.
llvm::Value* IRBuilder::create_load(llvm::Value* ptr, bool masked) {
    llvm::Type* loaded_type = ptr->getType()->getPointerElementType();

    // Check if the generated IR is vectorized and masked.
    if (masked) {
        // The masked load must be returned here. Without the `return`, its result is
        // discarded and an unmasked load of the same pointer is emitted and returned
        // instead. As before the switch to the typed overload, the masked result is not
        // pushed onto `value_stack`.
        return builder.CreateMaskedLoad(loaded_type, ptr, llvm::Align(), mask);
    }

    llvm::Value* loaded = builder.CreateLoad(loaded_type, ptr);
    value_stack.push_back(loaded);
    return loaded;
@@ -466,7 +471,9 @@ llvm::Value* IRBuilder::get_struct_member_ptr(llvm::Value* struct_variable, int
466
471
ValueVector indices;
467
472
indices.push_back (llvm::ConstantInt::get (get_i32_type (), 0 ));
468
473
indices.push_back (llvm::ConstantInt::get (get_i32_type (), member_index));
469
- return builder.CreateInBoundsGEP (struct_variable, indices);
474
+
475
+ llvm::Type* type = struct_variable->getType ()->getPointerElementType ();
476
+ return builder.CreateInBoundsGEP (type, struct_variable, indices);
470
477
}
471
478
472
479
void IRBuilder::invert_mask () {
@@ -491,14 +498,23 @@ llvm::Value* IRBuilder::load_to_or_store_from_array(const std::string& id_name,
491
498
bool generating_vector_ir = vector_width > 1 && vectorize;
492
499
493
500
// If the vector code is generated, we need to distinguish between two cases. If the array is
494
- // indexed indirectly (i.e. not by an induction variable `kernel_id`), create a gather
495
- // instruction .
501
+ // indexed indirectly (i.e. not by an induction variable `kernel_id`), create gather/scatter
502
+ // instructions .
496
503
if (id_name != kernel_id && generating_vector_ir) {
497
- return maybe_value_to_store ? builder.CreateMaskedScatter (maybe_value_to_store,
498
- element_ptr,
499
- llvm::Align (),
500
- mask)
501
- : builder.CreateMaskedGather (element_ptr, llvm::Align (), mask);
504
+ if (maybe_value_to_store) {
505
+ return builder.CreateMaskedScatter (maybe_value_to_store,
506
+ element_ptr,
507
+ llvm::Align (),
508
+ mask);
509
+ } else {
510
+ // Construct the loaded vector type.
511
+ auto * ptrs = llvm::cast<llvm::VectorType>(element_ptr->getType ());
512
+ llvm::ElementCount element_count = ptrs->getElementCount ();
513
+ llvm::Type* element_type = ptrs->getElementType ()->getPointerElementType ();
514
+ llvm::Type* loaded_type = llvm::VectorType::get (element_type, element_count);
515
+
516
+ return builder.CreateMaskedGather (loaded_type, element_ptr, llvm::Align (), mask);
517
+ }
502
518
}
503
519
504
520
llvm::Value* ptr;
0 commit comments