@@ -127,6 +127,10 @@ llvm::cl::opt<bool>
127127llvm::cl::opt<bool >
128128 EnzymePrintDiffUse (" enzyme-print-diffuse" , cl::init(false ), cl::Hidden,
129129 cl::desc(" Print differential use analysis" ));
130+
131+ llvm::cl::opt<std::string>
132+ EnzymeRustDeallocName (" rust-dealloc-name" , cl::init(" " ), cl::Hidden,
133+ cl::desc(" Name of Rust deallocation function" ));
130134}
131135
132136SmallVector<unsigned int , 9 > MD_ToCopy = {
@@ -9474,6 +9478,7 @@ llvm::CallInst *freeKnownAllocation(llvm::IRBuilder<> &builder,
94749478 GradientUtils *gutils) {
94759479 assert (isAllocationFunction (allocationfn, TLI));
94769480
9481+ #if LLVM_VERSION_MAJOR >= 17
94779482 std::string demangledName = llvm::demangle (allocationfn);
94789483 if (demangledName == " __rustc::__rust_alloc" || demangledName == " __rustc::__rust_alloc_zeroed" ) {
94799484 Type *VoidTy = Type::getVoidTy (tofree->getContext ());
@@ -9482,10 +9487,26 @@ llvm::CallInst *freeKnownAllocation(llvm::IRBuilder<> &builder,
94829487 Type *inTys[3 ] = {IntPtrTy, RustSz, RustSz};
94839488
94849489 auto FT = FunctionType::get (VoidTy, inTys, false );
9490+ if (EnzymeRustDeallocName == " " ) {
9491+ // Rust's (de)alloc names aren't stable. We expect rustc to set them
9492+ // for us, but if it fails to do so we instead search for it here.
9493+ for (auto &F : *builder.GetInsertBlock ()->getParent ()->getParent ()) {
9494+ auto demangledName = llvm::demangle (F.getName ());
9495+ if (demangledName == " __rustc::__rust_dealloc" ) {
9496+ EnzymeRustDeallocName = F.getName ();
9497+ break ;
9498+ }
9499+ }
9500+ if (EnzymeRustDeallocName == " " ) {
9501+ // If we can't find it, use the raw __rust_dealloc as a fallback.
9502+ // FIXME: Make this a hard error once we pass the right name from rustc.
9503+ EnzymeRustDeallocName = " __rust_dealloc" ;
9504+ }
9505+ }
94859506 Value *freevalue = builder.GetInsertBlock ()
94869507 ->getParent ()
94879508 ->getParent ()
9488- ->getOrInsertFunction (" __rust_dealloc " , FT)
9509+ ->getOrInsertFunction (EnzymeRustDeallocName , FT)
94899510 .getCallee ();
94909511 Value *vals[3 ];
94919512 vals[0 ] = builder.CreatePointerCast (tofree, IntPtrTy);
@@ -9509,6 +9530,7 @@ llvm::CallInst *freeKnownAllocation(llvm::IRBuilder<> &builder,
95099530 builder.Insert (freecall);
95109531 return freecall;
95119532 }
9533+ #endif
95129534 if (allocationfn == " julia.gc_alloc_obj" ||
95139535 allocationfn == " jl_gc_alloc_typed" ||
95149536 allocationfn == " ijl_gc_alloc_typed" ||
0 commit comments