diff --git a/.gitignore b/.gitignore index 5e7285d5f0af..62c76d14f13d 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ enzyme/benchmarks/ReverseMode/*/*.ll enzyme/benchmarks/ReverseMode/*/*.bc enzyme/benchmarks/ReverseMode/*/*.o enzyme/benchmarks/ReverseMode/*/*.exe +enzyme/benchmarks/ReverseMode/*/target/ enzyme/benchmarks/ReverseMode/*/results.txt enzyme/benchmarks/ReverseMode/*/results.json .cache diff --git a/enzyme/Enzyme/CMakeLists.txt b/enzyme/Enzyme/CMakeLists.txt index 629cecdc578f..6d277511ca54 100644 --- a/enzyme/Enzyme/CMakeLists.txt +++ b/enzyme/Enzyme/CMakeLists.txt @@ -8,8 +8,10 @@ endif() get_target_property(TBL_LINKED_LIBS LLVMSupport INTERFACE_LINK_LIBRARIES) if (NOT TBL_LINKED_LIBS) +message(STATUS "No TBL_LINKED_LIBS found") else() -list(REMOVE_ITEM TBL_LINKED_LIBS "ZLIB::ZLIB") +message(STATUS "TBL_LINKED_LIBS (test): ${TBL_LINKED_LIBS}") +#list(REMOVE_ITEM TBL_LINKED_LIBS "ZLIB::ZLIB") set_property(TARGET LLVMSupport PROPERTY INTERFACE_LINK_LIBRARIES ${TBL_LINKED_LIBS}) endif() @@ -145,7 +147,33 @@ if (${ENZYME_EXTERNAL_SHARED_LIB}) add_dependencies(Enzyme-${LLVM_VERSION_MAJOR} BlasDeclarationsIncGen) add_dependencies(Enzyme-${LLVM_VERSION_MAJOR} BlasTAIncGen) add_dependencies(Enzyme-${LLVM_VERSION_MAJOR} BlasDiffUseIncGen) - target_link_libraries(Enzyme-${LLVM_VERSION_MAJOR} LLVM) + + # This would be the desired way to link against LLVM components, + # however this function is bugged and does not work with `all`, see: + # https://github.com/llvm/llvm-project/issues/46347 + #llvm_map_components_to_libnames(llvm_libraries Passes) + # Therefore, manually invoke llvm-config + if (EXISTS "${LLVM_TOOLS_BINARY_DIR}/llvm-config") + message(STATUS "Using llvm-config from ${LLVM_TOOLS_BINARY_DIR}") + else() + message(SEND_ERROR "llvm-config not found in ${LLVM_TOOLS_BINARY_DIR}") + endif() + execute_process(COMMAND ${LLVM_TOOLS_BINARY_DIR}/llvm-config --libs all + OUTPUT_VARIABLE llvm_libraries) + string(STRIP "${llvm_libraries}" llvm_libraries) + message(STATUS "Linking against LLVM libraries: ${llvm_libraries}") + # In theory, adding --libs should also add all the -l flags, + # but it isn't picked up correctly by clang, so we call target_link_libraries + execute_process(COMMAND ${LLVM_TOOLS_BINARY_DIR}/llvm-config --ldflags + OUTPUT_VARIABLE llvm_ldflags) + string(STRIP "${llvm_ldflags}" llvm_ldflags) + message(STATUS "Linking against LLVM ldflags: ${llvm_ldflags}") + set_target_properties(Enzyme-${LLVM_VERSION_MAJOR} PROPERTIES LINK_FLAGS ${llvm_ldflags}) + target_link_libraries(Enzyme-${LLVM_VERSION_MAJOR} ${llvm_libraries}) + + llvm_map_components_to_libnames(llvm_librariess)# Passes Support) + target_link_libraries(Enzyme-${LLVM_VERSION_MAJOR} ${llvm_librariess}) + install(TARGETS Enzyme-${LLVM_VERSION_MAJOR} EXPORT EnzymeTargets LIBRARY DESTINATION lib COMPONENT shlib diff --git a/enzyme/benchmarks/ReverseMode/adbench/Makefile.config b/enzyme/benchmarks/ReverseMode/adbench/Makefile.config new file mode 100644 index 000000000000..c620d4a3b710 --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/adbench/Makefile.config @@ -0,0 +1,9 @@ +CLANG := /home/manuel/prog/rust-middle/build/x86_64-unknown-linux-gnu/llvm/build/bin/clang++ +OPT := /home/manuel/prog/rust-middle/build/x86_64-unknown-linux-gnu/llvm/build/bin/opt + +PASSES1 := verify,annotation2metadata,forceattrs,inferattrs,coro-early,function(ee-instrument<>,lower-expect,simplifycfg,sroa,early-cse<>,callsite-splitting),openmp-opt,ipsccp,called-value-propagation,globalopt,function(mem2reg,instcombine,simplifycfg),always-inline,require,function(invalidate),require,cgscc(devirt<4>(inline,function-attrs,argpromotion,openmp-opt-cgscc,function(sroa,early-cse,speculative-execution,jump-threading,correlated-propagation,simplifycfg,instcombine,aggressive-instcombine,libcalls-shrinkwrap,tailcallelim,simplifycfg,reassociate,constraint-elimination,loop-mssa(loop-instsimplify,loop-simplifycfg,licm,loop-rotate,licm,simple-loop-unswitch),simplifycfg,instcombine,loop(loop-idiom,indvars,extra-simple-loop-unswitch-passes,loop-deletion,loop-unroll-full),sroa,vector-combine,mldst-motion,gvn<>,sccp,bdce,instcombine,jump-threading,correlated-propagation,adce,memcpyopt,dse,move-auto-init,loop-mssa(licm),coro-elide,simplifycfg,instcombine),function-attrs,function(require),coro-split,coro-annotation-elide)),deadargelim,coro-cleanup,globalopt,globaldce,rpo-function-attrs,recompute-globalsaa,function(float2int,lower-constant-intrinsics,chr,loop(loop-rotate,loop-deletion),loop-distribute,inject-tli-mappings,loop-vectorize,infer-alignment,loop-load-elim,instcombine,simplifycfg,vector-combine,instcombine,loop-unroll,transform-warning,sroa,infer-alignment,instcombine,loop-mssa(licm),alignment-from-assumptions,loop-sink,instsimplify,div-rem-pairs,tailcallelim,simplifycfg),globaldce,constmerge,function(annotation-remarks),canonicalize-aliases,name-anon-globals,verify + +PASSES2 := cross-dso-cfi,openmp-opt,globaldce,inferattrs,function(callsite-splitting),pgo-icall-prom,cgscc(function-attrs,argpromotion,function(sroa)),ipsccp,called-value-propagation,rpo-function-attrs,globalsplit,wholeprogramdevirt,globalopt,function(mem2reg),constmerge,deadargelim,function(instcombine,aggressive-instcombine),expand-variadics,cgscc(inline,inline),globalopt,openmp-opt,globaldce,cgscc(argpromotion),function(instcombine,constraint-elimination,jump-threading,sroa,tailcallelim),cgscc(function-attrs),require,function(invalidate),cgscc(openmp-opt-cgscc),function(loop-mssa(licm),gvn<>,memcpyopt,dse,move-auto-init,mldst-motion,loop(indvars,loop-deletion,loop-unroll-full),loop-distribute,loop-vectorize,infer-alignment,loop-unroll,transform-warning,sroa,instcombine,simplifycfg,sccp,instcombine,bdce,vector-combine,infer-alignment,instcombine,loop-mssa(licm),alignment-from-assumptions,jump-threading),lowertypetests,lowertypetests,function(loop-sink,div-rem-pairs,simplifycfg),elim-avail-extern,globaldce,rel-lookup-table-converter,cg-profile,function(annotation-remarks),canonicalize-aliases,name-anon-globals +#PASSES2 := cross-dso-cfi,openmp-opt,globaldce,inferattrs,function(callsite-splitting),pgo-icall-prom,cgscc(function-attrs,argpromotion,function(sroa)),ipsccp,called-value-propagation,rpo-function-attrs,globalsplit,wholeprogramdevirt,globalopt,function(mem2reg),constmerge,deadargelim,function(instcombine,aggressive-instcombine),expand-variadics,cgscc(inline,inline),globalopt,openmp-opt,globaldce,cgscc(argpromotion),function(instcombine,constraint-elimination,jump-threading,sroa,tailcallelim),cgscc(function-attrs),require,function(invalidate),cgscc(openmp-opt-cgscc),function(loop-mssa(licm),gvn<>,memcpyopt,dse,move-auto-init,mldst-motion,loop(indvars,loop-deletion,loop-unroll-full),loop-distribute,loop-vectorize,infer-alignment,loop-unroll,transform-warning,sroa,instcombine,simplifycfg,sccp,instcombine,bdce,vector-combine,infer-alignment,instcombine,loop-mssa(licm),alignment-from-assumptions,jump-threading),lowertypetests,lowertypetests,function(loop-sink,div-rem-pairs,simplifycfg),elim-avail-extern,globaldce,rel-lookup-table-converter,cg-profile,function(annotation-remarks),canonicalize-aliases,name-anon-globals,EnzymeNewPM + +PASSES3 := cross-dso-cfi,openmp-opt,globaldce,inferattrs,function(callsite-splitting),pgo-icall-prom,cgscc(function-attrs,argpromotion,function(sroa)),ipsccp,called-value-propagation,rpo-function-attrs,globalsplit,wholeprogramdevirt,globalopt,function(mem2reg),constmerge,deadargelim,function(instcombine,aggressive-instcombine),expand-variadics,cgscc(inline,inline),globalopt,openmp-opt,globaldce,cgscc(argpromotion),function(instcombine,constraint-elimination,jump-threading,sroa,tailcallelim),cgscc(function-attrs),require,function(invalidate),cgscc(openmp-opt-cgscc),function(loop-mssa(licm),gvn<>,memcpyopt,dse,move-auto-init,mldst-motion,loop(indvars,loop-deletion,loop-unroll-full),loop-distribute,loop-vectorize,infer-alignment,loop-unroll,transform-warning,sroa,instcombine,simplifycfg,sccp,instcombine,bdce,slp-vectorizer,vector-combine,infer-alignment,instcombine,loop-mssa(licm),alignment-from-assumptions,jump-threading),lowertypetests,lowertypetests,function(loop-sink,div-rem-pairs,simplifycfg),elim-avail-extern,globaldce,mergefunc,rel-lookup-table-converter,cg-profile,function(annotation-remarks),canonicalize-aliases,name-anon-globals diff --git a/enzyme/benchmarks/ReverseMode/adbench/ba.h b/enzyme/benchmarks/ReverseMode/adbench/ba.h index 3ade86a0b7b2..131a5f8ae4d2 100644 --- a/enzyme/benchmarks/ReverseMode/adbench/ba.h +++ b/enzyme/benchmarks/ReverseMode/adbench/ba.h @@ -115,60 +115,68 @@ struct BAOutput { }; extern "C" { - void ba_objective( - int n, - int m, - int p, - double const* cams, - double const* X, - double const* w, - int const* obs, - double const* feats, - double* reproj_err, - double* w_err - ); - - void dcompute_reproj_error( - double const* cam, - double * dcam, - double const* X, - double * dX, - double const* w, - double * wb, - double const* feat, - double *err, - double *derr - ); - - void dcompute_zach_weight_error(double const* w, double* dw, double* err, double* derr); - - void compute_reproj_error_b( - double const* cam, - double * dcam, - double const* X, - double * dX, - double const* w, - double * wb, - double const* feat, - double *err, - double *derr - ); - - void compute_zach_weight_error_b(double const* w, double* dw, double* err, double* derr); - - void adept_compute_reproj_error( - double const* cam, - double * dcam, - double const* X, - double * dX, - double const* w, - double * wb, - double const* feat, - double *err, - double *derr - ); - - void adept_compute_zach_weight_error(double const* w, double* dw, double* err, double* derr); +void ba_objective_restrict(int n, int m, int p, double const *cams, + double const *X, double const *w, int const *obs, + double const *feats, double *reproj_err, + double *w_err); + +void ba_objective(int n, int m, int p, double const *cams, double const *X, + double const *w, int const *obs, double const *feats, + double *reproj_err, double *w_err); + +void rust2_unsafe_ba_objective(int n, int m, int p, double const *cams, + double const *X, double const *w, int const *obs, + double const *feats, double *reproj_err, + double *w_err); + +void rust2_ba_objective(int n, int m, int p, double const *cams, + double const *X, double const *w, int const *obs, + double const *feats, double *reproj_err, double *w_err); + +void dcompute_reproj_error_restrict(double const *cam, double *dcam, + double const *X, double *dX, + double const *w, double *wb, + double const *feat, double *err, + double *derr); + +void dcompute_zach_weight_error_restrict(double const *w, double *dw, + double *err, double *derr); + +void dcompute_reproj_error(double const *cam, double *dcam, double const *X, + double *dX, double const *w, double *wb, + double const *feat, double *err, double *derr); + +void dcompute_zach_weight_error(double const *w, double *dw, double *err, + double *derr); + +void compute_reproj_error_b(double const *cam, double *dcam, double const *X, + double *dX, double const *w, double *wb, + double const *feat, double *err, double *derr); + +void compute_zach_weight_error_b(double const *w, double *dw, double *err, + double *derr); + +void adept_compute_reproj_error(double const *cam, double *dcam, + double const *X, double *dX, double const *w, + double *wb, double const *feat, double *err, + double *derr); + +void adept_compute_zach_weight_error(double const *w, double *dw, double *err, + double *derr); + +void rust_unsafe_dcompute_reproj_error(double const *cam, double *dcam, + double const *X, double *dX, + double const *w, double *wb, + double const *feat, double *err, + double *derr); + +void rust_dcompute_reproj_error(double const *cam, double *dcam, + double const *X, double *dX, double const *w, + double *wb, double const *feat, double *err, + double *derr); + +void rust_dcompute_zach_weight_error(double const *w, double *dw, double *err, + double *derr); } void read_ba_instance(const string& fn, @@ -335,10 +343,22 @@ int main(const int argc, const char* argv[]) { std::string path = "/mnt/Data/git/Enzyme/apps/ADBench/data/ba/ba1_n49_m7776_p31843.txt"; std::vector paths = { - "ba10_n1197_m126327_p563734.txt", "ba14_n356_m226730_p1255268.txt", "ba18_n1936_m649673_p5213733.txt", "ba2_n21_m11315_p36455.txt", "ba6_n539_m65220_p277273.txt", "test.txt", - "ba11_n1723_m156502_p678718.txt", "ba15_n1102_m780462_p4052340.txt", "ba19_n4585_m1324582_p9125125.txt", "ba3_n161_m48126_p182072.txt", "ba7_n93_m61203_p287451.txt", - "ba12_n253_m163691_p899155.txt", "ba16_n1544_m942409_p4750193.txt", "ba1_n49_m7776_p31843.txt", "ba4_n372_m47423_p204472.txt", "ba8_n88_m64298_p383937.txt", - "ba13_n245_m198739_p1091386.txt", "ba17_n1778_m993923_p5001946.txt", "ba20_n13682_m4456117_p2987644.txt", "ba5_n257_m65132_p225911.txt", "ba9_n810_m88814_p393775.txt", + "ba10_n1197_m126327_p563734.txt", + "ba14_n356_m226730_p1255268.txt", // "ba18_n1936_m649673_p5213733.txt", + // "ba2_n21_m11315_p36455.txt", + // "ba6_n539_m65220_p277273.txt", + // "test.txt", + // "ba11_n1723_m156502_p678718.txt", + // "ba15_n1102_m780462_p4052340.txt", + // "ba19_n4585_m1324582_p9125125.txt", + // "ba3_n161_m48126_p182072.txt", "ba7_n93_m61203_p287451.txt", + // "ba12_n253_m163691_p899155.txt", + // "ba16_n1544_m942409_p4750193.txt", "ba1_n49_m7776_p31843.txt", + // "ba4_n372_m47423_p204472.txt", "ba8_n88_m64298_p383937.txt", + // "ba13_n245_m198739_p1091386.txt", + // "ba17_n1778_m993923_p5001946.txt", + // "ba20_n13682_m4456117_p2987644.txt", + // "ba5_n257_m65132_p225911.txt", "ba9_n810_m88814_p393775.txt", }; std::ofstream jsonfile("results.json", std::ofstream::trunc); @@ -358,27 +378,6 @@ int main(const int argc, const char* argv[]) { BASparseMat(input.n, input.m, input.p) }; - //BASparseMat(this->input.n, this->input.m, this->input.p) - - /* - ba_objective( - input.n, - input.m, - input.p, - input.cams.data(), - input.X.data(), - input.w.data(), - input.obs.data(), - input.feats.data(), - result.reproj_err.data(), - result.w_err.data() - ); - - for(unsigned i=0; iinput.n, this->input.m, this->input.p) - - /* - ba_objective( - input.n, - input.m, - input.p, - input.cams.data(), - input.X.data(), - input.w.data(), - input.obs.data(), - input.feats.data(), - result.reproj_err.data(), - result.w_err.data() - ); - - for(unsigned i=0; i(input, result); + calculate_jacobian(input, result); gettimeofday(&end, NULL); printf("Adept combined %0.6f\n", tdiff(&start, &end)); json adept; adept["name"] = "Adept combined"; adept["runtime"] = tdiff(&start, &end); - for(unsigned i=0; i<5; i++) { + for (unsigned i = 0; i < 5; i++) { printf("%f ", result.J.vals[i]); adept["result"].push_back(result.J.vals[i]); } printf("\n"); test_suite["tools"].push_back(adept); } + } + + for (int j=0;j<5;j++) { + + struct BAInput input; + read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams, + input.X, input.w, input.obs, input.feats); + struct BAOutput result = {std::vector(2 * input.p), + std::vector(input.p), + BASparseMat(input.n, input.m, input.p)}; + + { + struct timeval start, end; + gettimeofday(&start, NULL); + calculate_jacobian(input, result); + gettimeofday(&end, NULL); + printf("Enzyme restrict c++ combined %0.6f\n", tdiff(&start, &end)); + json enzyme; + enzyme["name"] = "Enzyme restrict c++ combined"; + enzyme["runtime"] = tdiff(&start, &end); + for (unsigned i = 0; i < 5; i++) { + printf("%f ", result.J.vals[i]); + enzyme["result"].push_back(result.J.vals[i]); + } + printf("\n"); + test_suite["tools"].push_back(enzyme); + } } { + struct BAInput input; + read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams, + input.X, input.w, input.obs, input.feats); + + struct BAOutput result = {std::vector(2 * input.p), + std::vector(input.p), + BASparseMat(input.n, input.m, input.p)}; + + { + struct timeval start, end; + gettimeofday(&start, NULL); + calculate_jacobian( + input, result); + gettimeofday(&end, NULL); + printf("Enzyme aliasing c++ combined %0.6f\n", tdiff(&start, &end)); + json enzyme; + enzyme["name"] = "Enzyme c++ combined"; + enzyme["runtime"] = tdiff(&start, &end); + for (unsigned i = 0; i < 5; i++) { + printf("%f ", result.J.vals[i]); + enzyme["result"].push_back(result.J.vals[i]); + } + printf("\n"); + test_suite["tools"].push_back(enzyme); + } + } + + { + struct BAInput input; + read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams, + input.X, input.w, input.obs, input.feats); + + struct BAOutput result = {std::vector(2 * input.p), + std::vector(input.p), + BASparseMat(input.n, input.m, input.p)}; + + { + struct timeval start, end; + gettimeofday(&start, NULL); + ba_objective_restrict(input.n, input.m, input.p, input.cams.data(), + input.X.data(), input.w.data(), input.obs.data(), + input.feats.data(), result.reproj_err.data(), + result.w_err.data()); + gettimeofday(&end, NULL); + printf("primal restrict c++ t=%0.6f\n", tdiff(&start, &end)); + json enzyme; + enzyme["name"] = "primal restrict c++"; + enzyme["runtime"] = tdiff(&start, &end); + for (unsigned i = 0; i < 5; i++) { + printf("%f ", result.reproj_err[i]); + enzyme["result"].push_back(result.reproj_err[i]); + } + for (unsigned i = 0; i < 5; i++) { + printf("%f ", result.w_err[i]); + enzyme["result"].push_back(result.w_err[i]); + } + printf("\n"); + test_suite["tools"].push_back(enzyme); + } + } + + { + struct BAInput input; + read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams, + input.X, input.w, input.obs, input.feats); + + struct BAOutput result = {std::vector(2 * input.p), + std::vector(input.p), + BASparseMat(input.n, input.m, input.p)}; + + { + struct timeval start, end; + gettimeofday(&start, NULL); + ba_objective(input.n, input.m, input.p, input.cams.data(), input.X.data(), + input.w.data(), input.obs.data(), input.feats.data(), + result.reproj_err.data(), result.w_err.data()); + gettimeofday(&end, NULL); + printf("primal aliasing c++ t=%0.6f\n", tdiff(&start, &end)); + json enzyme; + enzyme["name"] = "primal aliasing c++"; + enzyme["runtime"] = tdiff(&start, &end); + for(unsigned i=0; i<5; i++) { + printf("%f ", result.reproj_err[i]); + enzyme["result"].push_back(result.reproj_err[i]); + } + for(unsigned i=0; i<5; i++) { + printf("%f ", result.w_err[i]); + enzyme["result"].push_back(result.w_err[i]); + } + printf("\n"); + test_suite["tools"].push_back(enzyme); + } + } + + { + struct BAInput input; + read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams, + input.X, input.w, input.obs, input.feats); + + struct BAOutput result = {std::vector(2 * input.p), + std::vector(input.p), + BASparseMat(input.n, input.m, input.p)}; + { + + struct timeval start, end; + gettimeofday(&start, NULL); + rust2_unsafe_ba_objective(input.n, input.m, input.p, input.cams.data(), + input.X.data(), input.w.data(), + input.obs.data(), input.feats.data(), + result.reproj_err.data(), result.w_err.data()); + gettimeofday(&end, NULL); + printf("primal unsafe rust t=%0.6f\n", tdiff(&start, &end)); + json enzyme; + enzyme["name"] = "primal unsafe rust"; + enzyme["runtime"] = tdiff(&start, &end); + for (unsigned i = 0; i < 5; i++) { + printf("%f ", result.reproj_err[i]); + enzyme["result"].push_back(result.reproj_err[i]); + } + for (unsigned i = 0; i < 5; i++) { + printf("%f ", result.w_err[i]); + enzyme["result"].push_back(result.w_err[i]); + } + printf("\n"); + test_suite["tools"].push_back(enzyme); + } + } + + { struct BAInput input; read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams, input.X, input.w, input.obs, input.feats); @@ -459,11 +595,11 @@ int main(const int argc, const char* argv[]) { std::vector(input.p), BASparseMat(input.n, input.m, input.p) }; + { - //BASparseMat(this->input.n, this->input.m, this->input.p) - - /* - ba_objective( + struct timeval start, end; + gettimeofday(&start, NULL); + rust2_ba_objective( input.n, input.m, input.p, @@ -475,20 +611,72 @@ int main(const int argc, const char* argv[]) { result.reproj_err.data(), result.w_err.data() ); + gettimeofday(&end, NULL); + printf("primal rust t=%0.6f\n", tdiff(&start, &end)); + json enzyme; + enzyme["name"] = "primal rust"; + enzyme["runtime"] = tdiff(&start, &end); + for(unsigned i=0; i<5; i++) { + printf("%f ", result.reproj_err[i]); + enzyme["result"].push_back(result.reproj_err[i]); + } + for(unsigned i=0; i<5; i++) { + printf("%f ", result.w_err[i]); + enzyme["result"].push_back(result.w_err[i]); + } + printf("\n"); + test_suite["tools"].push_back(enzyme); + } + } + + { + + struct BAInput input; + read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams, input.X, input.w, input.obs, input.feats); + + struct BAOutput result = { + std::vector(2 * input.p), + std::vector(input.p), + BASparseMat(input.n, input.m, input.p) + }; - for(unsigned i=0; i(input, result); + gettimeofday(&end, NULL); + printf("Enzyme unsafe rust combined %0.6f\n", tdiff(&start, &end)); + json enzyme; + enzyme["name"] = "Enzyme unsafe rust combined"; + enzyme["runtime"] = tdiff(&start, &end); + for (unsigned i = 0; i < 5; i++) { + printf("%f ", result.J.vals[i]); + enzyme["result"].push_back(result.J.vals[i]); + } + printf("\n"); + test_suite["tools"].push_back(enzyme); } - */ + } + + for(int j=0;j<5;j++){ + + struct BAInput input; + read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams, + input.X, input.w, input.obs, input.feats); + + struct BAOutput result = {std::vector(2 * input.p), + std::vector(input.p), + BASparseMat(input.n, input.m, input.p)}; { struct timeval start, end; gettimeofday(&start, NULL); - calculate_jacobian(input, result); + calculate_jacobian(input, result); gettimeofday(&end, NULL); - printf("Enzyme combined %0.6f\n", tdiff(&start, &end)); + printf("Enzyme rust combined %0.6f\n", tdiff(&start, &end)); json enzyme; - enzyme["name"] = "Enzyme combined"; + enzyme["name"] = "Enzyme rust combined"; enzyme["runtime"] = tdiff(&start, &end); for(unsigned i=0; i<5; i++) { printf("%f ", result.J.vals[i]); @@ -497,8 +685,8 @@ int main(const int argc, const char* argv[]) { printf("\n"); test_suite["tools"].push_back(enzyme); } - } + test_suite["llvm-version"] = __clang_version__; test_suite["mode"] = "ReverseMode"; test_suite["batch-size"] = 1; diff --git a/enzyme/benchmarks/ReverseMode/adbench/gmm.h b/enzyme/benchmarks/ReverseMode/adbench/gmm.h index feef3a7d48c6..d994a6e17d66 100644 --- a/enzyme/benchmarks/ReverseMode/adbench/gmm.h +++ b/enzyme/benchmarks/ReverseMode/adbench/gmm.h @@ -18,7 +18,7 @@ using namespace std; using json = nlohmann::json; struct GMMInput { - int d, k, n; + size_t d, k, n; std::vector alphas, means, icf, x; Wishart wishart; }; @@ -33,24 +33,54 @@ struct GMMParameters { }; extern "C" { - void dgmm_objective(int d, int k, int n, const double *alphas, double * - alphasb, const double *means, double *meansb, const double *icf, - double *icfb, const double *x, Wishart wishart, double *err, double * - errb); - - void gmm_objective_b(int d, int k, int n, const double *alphas, double * - alphasb, const double *means, double *meansb, const double *icf, - double *icfb, const double *x, Wishart wishart, double *err, double * - errb); - - void adept_dgmm_objective(int d, int k, int n, const double *alphas, double * - alphasb, const double *means, double *meansb, const double *icf, - double *icfb, const double *x, Wishart wishart, double *err, double * - errb); +void gmm_objective(size_t d, size_t k, size_t n, double const *alphas, + double const *means, double const *icf, double const *x, + Wishart wishart, double *err); +void gmm_objective_restrict(size_t d, size_t k, size_t n, double const *alphas, + double const *means, double const *icf, + double const *x, Wishart wishart, double *err); +void dgmm_objective_restrict(size_t d, size_t k, size_t n, const double *alphas, + double *alphasb, const double *means, + double *meansb, const double *icf, double *icfb, + const double *x, Wishart wishart, double *err, + double *errb); +void dgmm_objective(size_t d, size_t k, size_t n, const double *alphas, double *alphasb, + const double *means, double *meansb, const double *icf, + double *icfb, const double *x, Wishart wishart, double *err, + double *errb); + +void gmm_objective_b(size_t d, size_t k, size_t n, const double *alphas, double *alphasb, + const double *means, double *meansb, const double *icf, + double *icfb, const double *x, Wishart wishart, + double *err, double *errb); + +void adept_dgmm_objective(size_t d, size_t k, size_t n, const double *alphas, + double *alphasb, const double *means, double *meansb, + const double *icf, double *icfb, const double *x, + Wishart wishart, double *err, double *errb); + +void rust_unsafe_dgmm_objective(size_t d, size_t k, size_t n, const double *alphas, + double *alphasb, const double *means, + double *meansb, const double *icf, double *icfb, + const double *x, Wishart &wishart, double *err, + double *errb); + +void rust_unsafe_gmm_objective(size_t d, size_t k, size_t n, const double *alphas, + const double *means, const double *icf, + const double *x, Wishart &wishart, double *err); + +void rust_dgmm_objective(size_t d, size_t k, size_t n, const double *alphas, + double *alphasb, const double *means, double *meansb, + const double *icf, double *icfb, const double *x, + Wishart &wishart, double *err, double *errb); + +void rust_gmm_objective(size_t d, size_t k, size_t n, const double *alphas, + const double *means, const double *icf, const double *x, + Wishart &wishart, double *err); } void read_gmm_instance(const string& fn, - int* d, int* k, int* n, + size_t* d, size_t* k, size_t* n, vector& alphas, vector& means, vector& icf, @@ -65,32 +95,32 @@ void read_gmm_instance(const string& fn, exit(1); } - fscanf(fid, "%i %i %i", d, k, n); + fscanf(fid, "%zu %zu %zu", d, k, n); - int d_ = *d, k_ = *k, n_ = *n; + size_t d_ = *d, k_ = *k, n_ = *n; - int icf_sz = d_ * (d_ + 1) / 2; + size_t icf_sz = d_ * (d_ + 1) / 2; alphas.resize(k_); means.resize(d_ * k_); icf.resize(icf_sz * k_); x.resize(d_ * n_); - for (int i = 0; i < k_; i++) + for (size_t i = 0; i < k_; i++) { fscanf(fid, "%lf", &alphas[i]); } - for (int i = 0; i < k_; i++) + for (size_t i = 0; i < k_; i++) { - for (int j = 0; j < d_; j++) + for (size_t j = 0; j < d_; j++) { fscanf(fid, "%lf", &means[i * d_ + j]); } } - for (int i = 0; i < k_; i++) + for (size_t i = 0; i < k_; i++) { - for (int j = 0; j < icf_sz; j++) + for (size_t j = 0; j < icf_sz; j++) { fscanf(fid, "%lf", &icf[i * icf_sz + j]); } @@ -98,20 +128,20 @@ void read_gmm_instance(const string& fn, if (replicate_point) { - for (int j = 0; j < d_; j++) + for (size_t j = 0; j < d_; j++) { fscanf(fid, "%lf", &x[j]); } - for (int i = 0; i < n_; i++) + for (size_t i = 0; i < n_; i++) { memcpy(&x[i * d_], &x[0], d_ * sizeof(double)); } } else { - for (int i = 0; i < n_; i++) + for (size_t i = 0; i < n_; i++) { - for (int j = 0; j < d_; j++) + for (size_t j = 0; j < d_; j++) { fscanf(fid, "%lf", &x[i * d_ + j]); } @@ -123,10 +153,7 @@ void read_gmm_instance(const string& fn, fclose(fid); } -typedef void(*deriv_t)(int d, int k, int n, const double *alphas, double *alphasb, const double *means, double *meansb, const double *icf, - double *icfb, const double *x, Wishart wishart, double *err, double *errb); - -template +template void calculate_jacobian(struct GMMInput &input, struct GMMOutput &result) { double* alphas_gradient_part = result.gradient.data(); @@ -159,19 +186,38 @@ void calculate_jacobian(struct GMMInput &input, struct GMMOutput &result) ); } +template +double primal(struct GMMInput &input) +{ + double tmp = 0.0; // stores fictive result + // (Tapenade doesn't calculate an original function in reverse mode) + deriv( + input.d, + input.k, + input.n, + input.alphas.data(), + input.means.data(), + input.icf.data(), + input.x.data(), + input.wishart, + &tmp + ); + return tmp; +} + int main(const int argc, const char* argv[]) { printf("starting main\n"); const auto replicate_point = (argc > 9 && string(argv[9]) == "-rep"); const GMMParameters params = { replicate_point }; - std::vector paths;// = { "1k/gmm_d10_K100.txt" }; + std::vector paths = { "10k/gmm_d10_K200.txt" }; getTests(paths, "data/1k", "1k/"); - if (std::getenv("BENCH_LARGE")) { + //if (std::getenv("BENCH_LARGE")) { getTests(paths, "data/2.5k", "2.5k/"); getTests(paths, "data/10k", "10k/"); - } + //} std::ofstream jsonfile("results.json", std::ofstream::trunc); json test_results; @@ -188,7 +234,7 @@ int main(const int argc, const char* argv[]) { read_gmm_instance("data/" + path, &input.d, &input.k, &input.n, input.alphas, input.means, input.icf, input.x, input.wishart, params.replicate_point); - int Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2; + size_t Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2; struct GMMOutput result = { 0, std::vector(Jcols) }; @@ -218,49 +264,141 @@ int main(const int argc, const char* argv[]) { read_gmm_instance("data/" + path, &input.d, &input.k, &input.n, input.alphas, input.means, input.icf, input.x, input.wishart, params.replicate_point); - int Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2; + size_t Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2; + + struct GMMOutput result = { 0, std::vector(Jcols) }; + + if (1) { + try { + struct timeval start, end; + gettimeofday(&start, NULL); + calculate_jacobian(input, result); + gettimeofday(&end, NULL); + printf("Adept combined %0.6f\n", tdiff(&start, &end)); + json adept; + adept["name"] = "Adept combined"; + adept["runtime"] = tdiff(&start, &end); + for (unsigned i = result.gradient.size() - 5; + i < result.gradient.size(); i++) { + printf("%f ", result.gradient[i]); + adept["result"].push_back(result.gradient[i]); + } + printf("\n"); + test_suite["tools"].push_back(adept); + } catch (std::bad_alloc) { + printf("Adept combined 88888888 ooms\n"); + } + } + } + + for (size_t i = 0; i < 5; i++) + { + + struct GMMInput input; + read_gmm_instance("data/" + path, &input.d, &input.k, &input.n, + input.alphas, input.means, input.icf, input.x, input.wishart, params.replicate_point); + + size_t Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2; struct GMMOutput result = { 0, std::vector(Jcols) }; - try { + { + struct timeval start, end; + gettimeofday(&start, NULL); + calculate_jacobian(input, result); + gettimeofday(&end, NULL); + printf("Enzyme c++ restrict combined %0.6f\n", tdiff(&start, &end)); + json enzyme; + enzyme["name"] = "Enzyme restrict combined"; + enzyme["runtime"] = tdiff(&start, &end); + for (unsigned i = result.gradient.size() - 5; i < result.gradient.size(); + i++) { + printf("%f ", result.gradient[i]); + enzyme["result"].push_back(result.gradient[i]); + } + printf("\n"); + test_suite["tools"].push_back(enzyme); + } + } + + { + + struct GMMInput input; + read_gmm_instance("data/" + path, &input.d, &input.k, &input.n, + input.alphas, input.means, input.icf, input.x, + input.wishart, params.replicate_point); + + size_t Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2; + + struct GMMOutput result = {0, std::vector(Jcols)}; + + { struct timeval start, end; gettimeofday(&start, NULL); - calculate_jacobian(input, result); + calculate_jacobian(input, result); gettimeofday(&end, NULL); - printf("Adept combined %0.6f\n", tdiff(&start, &end)); - json adept; - adept["name"] = "Adept combined"; - adept["runtime"] = tdiff(&start, &end); + printf("Enzyme c++ mayalias combined %0.6f\n", tdiff(&start, &end)); + json enzyme; + enzyme["name"] = "Enzyme mayalias combined"; + enzyme["runtime"] = tdiff(&start, &end); for (unsigned i = result.gradient.size() - 5; i < result.gradient.size(); i++) { printf("%f ", result.gradient[i]); - adept["result"].push_back(result.gradient[i]); + enzyme["result"].push_back(result.gradient[i]); } printf("\n"); - test_suite["tools"].push_back(adept); - } catch(std::bad_alloc) { - printf("Adept combined 88888888 ooms\n"); + test_suite["tools"].push_back(enzyme); } - } { struct GMMInput input; read_gmm_instance("data/" + path, &input.d, &input.k, &input.n, - input.alphas, input.means, input.icf, input.x, input.wishart, params.replicate_point); + input.alphas, input.means, input.icf, input.x, + input.wishart, params.replicate_point); - int Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2; + size_t Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2; - struct GMMOutput result = { 0, std::vector(Jcols) }; + struct GMMOutput result = {0, std::vector(Jcols)}; + { + struct timeval start, end; + gettimeofday(&start, NULL); + calculate_jacobian(input, result); + gettimeofday(&end, NULL); + printf("Enzyme unsafe rust combined %0.6f\n", tdiff(&start, &end)); + json enzyme; + enzyme["name"] = "Rust unsafe Enzyme combined"; + enzyme["runtime"] = tdiff(&start, &end); + for (unsigned i = result.gradient.size() - 5; i < result.gradient.size(); + i++) { + printf("%f ", result.gradient[i]); + enzyme["result"].push_back(result.gradient[i]); + } + printf("\n"); + test_suite["tools"].push_back(enzyme); + } + } + + for (size_t i = 0; i < 5; i++) + { + + struct GMMInput input; + read_gmm_instance("data/" + path, &input.d, &input.k, &input.n, + input.alphas, input.means, input.icf, input.x, + input.wishart, params.replicate_point); + size_t Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2; + + struct GMMOutput result = {0, std::vector(Jcols)}; { struct timeval start, end; gettimeofday(&start, NULL); - calculate_jacobian(input, result); + calculate_jacobian(input, result); gettimeofday(&end, NULL); + printf("Enzyme rust combined %0.6f\n", tdiff(&start, &end)); json enzyme; - enzyme["name"] = "Enzyme combined"; + enzyme["name"] = "Rust Enzyme combined"; enzyme["runtime"] = tdiff(&start, &end); for (unsigned i = result.gradient.size() - 5; i < result.gradient.size(); i++) { @@ -270,8 +408,73 @@ int main(const int argc, const char* argv[]) { printf("\n"); test_suite["tools"].push_back(enzyme); } + } + + { + struct GMMInput input; + read_gmm_instance("data/" + path, &input.d, &input.k, &input.n, + input.alphas, input.means, input.icf, input.x, input.wishart, params.replicate_point); + + size_t Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2; + + struct GMMOutput result = { 0, std::vector(Jcols) }; + + { + struct timeval start, end; + gettimeofday(&start, NULL); + auto res = primal(input); + gettimeofday(&end, NULL); + printf("c++ primal mayalias combined t=%0.6f, err=%f\n", + tdiff(&start, &end), res); + + json primal; + primal["name"] = "C++ primal mayalias"; + primal["runtime"] = tdiff(&start, &end); + primal["result"].push_back(res); + test_suite["tools"].push_back(primal); + } + { + struct timeval start, end; + gettimeofday(&start, NULL); + auto res = primal(input); + gettimeofday(&end, NULL); + printf("c++ primal restrict combined t=%0.6f, err=%f\n", + tdiff(&start, &end), res); + + json primal; + primal["name"] = "C++ primal restrict"; + primal["runtime"] = tdiff(&start, &end); + primal["result"].push_back(res); + test_suite["tools"].push_back(primal); + } + { + struct timeval start, end; + gettimeofday(&start, NULL); + auto res = primal(input); + gettimeofday(&end, NULL); + printf("rust unsafe primal combined t=%0.6f, err=%f\n", + tdiff(&start, &end), res); + json primal; + primal["name"] = "Rust unsafe primal"; + primal["runtime"] = tdiff(&start, &end); + primal["result"].push_back(res); + test_suite["tools"].push_back(primal); } + { + struct timeval start, end; + gettimeofday(&start, NULL); + auto res = primal(input); + gettimeofday(&end, NULL); + printf("rust primal combined t=%0.6f, err=%f\n", tdiff(&start, &end), res); + json primal; + primal["name"] = "Rust primal"; + primal["runtime"] = tdiff(&start, &end); + primal["result"].push_back(res); + test_suite["tools"].push_back(primal); + } + } + test_suite["llvm-version"] = __clang_version__; test_suite["mode"] = "ReverseMode"; test_suite["batch-size"] = 1; diff --git a/enzyme/benchmarks/ReverseMode/adbench/lstm.h b/enzyme/benchmarks/ReverseMode/adbench/lstm.h index e6d13303d1f8..0648b6692803 100644 --- a/enzyme/benchmarks/ReverseMode/adbench/lstm.h +++ b/enzyme/benchmarks/ReverseMode/adbench/lstm.h @@ -34,37 +34,56 @@ struct LSTMOutput { }; extern "C" { - void dlstm_objective( - int l, - int c, - int b, - double const* main_params, - double* dmain_params, - double const* extra_params, - double* dextra_params, - double* state, - double const* sequence, - double* loss, - double* dloss - ); - - void lstm_objective_b(int l, int c, int b, const double *main_params, double * - main_paramsb, const double *extra_params, double *extra_paramsb, - double *state, const double *sequence, double *loss, double *lossb); - - void adept_dlstm_objective( - int l, - int c, - int b, - double const* main_params, - double* dmain_params, - double const* extra_params, - double* dextra_params, - double* state, - double const* sequence, - double* loss, - double* dloss - ); +void rust_unsafe_dlstm_objective(int l, int c, int b, double const *main_params, + double *dmain_params, + double const *extra_params, + double *dextra_params, double *state, + double const *sequence, double *loss, + double *dloss); + +void rust_unsafe_lstm_objective(int l, int c, int b, double const *main_params, + double const *extra_params, double *state, + double const *sequence, double *loss); + +void rust_safe_lstm_objective(int l, int c, int b, double const *main_params, + double const *extra_params, double *state, + double const *sequence, double *loss); + +void cxx_restrict_lstm_objective(int l, int c, int b, double const *main_params, + double const *extra_params, double *state, + double const *sequence, double *loss); + +void cxx_mayalias_lstm_objective(int l, int c, int b, double const *main_params, + double const *extra_params, double *state, + double const *sequence, double *loss); + +void rust_safe_dlstm_objective(int l, int c, int b, double const *main_params, + double *dmain_params, double const *extra_params, + double *dextra_params, double *state, + double const *sequence, double *loss, + double *dloss); + +void dlstm_objective_mayalias(int l, int c, int b, double const *main_params, + double *dmain_params, double const *extra_params, + double *dextra_params, double *state, + double const *sequence, double *loss, + double *dloss); + +void dlstm_objective_restrict(int l, int c, int b, double const *main_params, + double *dmain_params, double const *extra_params, + double *dextra_params, double *state, + double const *sequence, double *loss, + double *dloss); + +void lstm_objective_b(int l, int c, int b, const double *main_params, + double *main_paramsb, const double *extra_params, + double *extra_paramsb, double *state, + const double *sequence, double *loss, double *lossb); + +void adept_dlstm_objective(int l, int c, int b, double const *main_params, + double *dmain_params, double const *extra_params, + double *dextra_params, double *state, + double const *sequence, double *loss, double *dloss); } void read_lstm_instance(const string& fn, @@ -177,11 +196,56 @@ void calculate_jacobian(struct LSTMInput &input, struct LSTMOutput &result) } } +double calculate_mayalias_primal(struct LSTMInput &input) { + double loss = 0.0; + for (int i = 0; i < 100; i++) { + cxx_mayalias_lstm_objective( + input.l, input.c, input.b, input.main_params.data(), + input.extra_params.data(), input.state.data(), + input.sequence.data(), &loss); + } + return loss; +} + +double calculate_restrict_primal(struct LSTMInput &input) { + double loss = 0.0; + for (int i = 0; i < 100; i++) { + cxx_restrict_lstm_objective( + input.l, input.c, input.b, input.main_params.data(), + input.extra_params.data(), input.state.data(), + input.sequence.data(), &loss); + } + return loss; +} + +double calculate_unsafe_primal(struct LSTMInput &input) { + double loss = 0.0; + for (int i = 0; i < 100; i++) { + rust_unsafe_lstm_objective( + input.l, input.c, input.b, input.main_params.data(), + input.extra_params.data(), input.state.data(), + input.sequence.data(), &loss); + } + return loss; +} + +double calculate_safe_primal(struct LSTMInput &input) { + double loss = 0.0; + for (int i = 0; i < 100; i++) { + rust_safe_lstm_objective(input.l, input.c, input.b, + input.main_params.data(), + input.extra_params.data(), input.state.data(), + input.sequence.data(), &loss); + } + return loss; +} + int main(const int argc, const char* argv[]) { printf("starting main\n"); std::vector paths = { "lstm_l2_c1024.txt", "lstm_l4_c1024.txt", "lstm_l2_c4096.txt", "lstm_l4_c4096.txt" }; - + //std::vector paths = { "lstm_l4_c4096.txt" }; + std::ofstream jsonfile("results.json", std::ofstream::trunc); json test_results; @@ -225,18 +289,19 @@ int main(const int argc, const char* argv[]) { } - { + if (1){ - struct LSTMInput input = {}; + struct LSTMInput input = {}; // Read instance - read_lstm_instance("data/" + path, &input.l, &input.c, &input.b, input.main_params, input.extra_params, input.state, - input.sequence); + read_lstm_instance("data/" + path, &input.l, &input.c, &input.b, + input.main_params, input.extra_params, input.state, + input.sequence); - std::vector state = std::vector(input.state.size()); + std::vector state = std::vector(input.state.size()); - int Jcols = 8 * input.l * input.b + 3 * input.b; - struct LSTMOutput result = { 0, std::vector(Jcols) }; + int Jcols = 8 * input.l * input.b + 3 * input.b; + struct LSTMOutput result = { 0, std::vector(Jcols) }; { struct timeval start, end; @@ -258,10 +323,77 @@ int main(const int argc, const char* argv[]) { } + for (int j=0; j<5; j++){ + + struct LSTMInput input = {}; + + // Read instance + read_lstm_instance("data/" + path, &input.l, &input.c, &input.b, input.main_params, input.extra_params, input.state, + input.sequence); + + std::vector state = std::vector(input.state.size()); + + int Jcols = 8 * input.l * input.b + 3 * input.b; + struct LSTMOutput result = { 0, std::vector(Jcols) }; + + { + struct timeval start, end; + gettimeofday(&start, NULL); + calculate_jacobian(input, result); + gettimeofday(&end, NULL); + printf("Enzyme restrict combined %0.6f\n", tdiff(&start, &end)); + json enzyme; + enzyme["name"] = "Enzyme restrict combined"; + enzyme["runtime"] = tdiff(&start, &end); + for (unsigned i = result.gradient.size() - 5; i < result.gradient.size(); + i++) { + printf("%f ", result.gradient[i]); + enzyme["result"].push_back(result.gradient[i]); + } + test_suite["tools"].push_back(enzyme); + + printf("\n"); + } + } + { struct LSTMInput input = {}; + // Read instance + read_lstm_instance("data/" + path, &input.l, &input.c, &input.b, + input.main_params, input.extra_params, input.state, + input.sequence); + + std::vector state = std::vector(input.state.size()); + + int Jcols = 8 * input.l * input.b + 3 * input.b; + struct LSTMOutput result = {0, std::vector(Jcols)}; + + { + struct timeval start, end; + gettimeofday(&start, NULL); + calculate_jacobian(input, result); + gettimeofday(&end, NULL); + printf("Enzyme mayalias combined %0.6f\n", tdiff(&start, &end)); + json enzyme; + enzyme["name"] = "Enzyme mayalias combined"; + enzyme["runtime"] = tdiff(&start, &end); + for (unsigned i = result.gradient.size() - 5; i < result.gradient.size(); + i++) { + printf("%f ", result.gradient[i]); + enzyme["result"].push_back(result.gradient[i]); + } + test_suite["tools"].push_back(enzyme); + + printf("\n"); + } + } + + for (int j=0; j<5; j++){ + + struct LSTMInput input = {}; + // Read instance read_lstm_instance("data/" + path, &input.l, &input.c, &input.b, input.main_params, input.extra_params, input.state, input.sequence); @@ -274,23 +406,178 @@ int main(const int argc, const char* argv[]) { { struct timeval start, end; gettimeofday(&start, NULL); - calculate_jacobian(input, result); + calculate_jacobian(input, result); gettimeofday(&end, NULL); - printf("Enzyme combined %0.6f\n", tdiff(&start, &end)); + printf("Enzyme (safe Rust) combined %0.6f\n", tdiff(&start, &end)); json enzyme; - enzyme["name"] = "Enzyme combined"; + enzyme["name"] = "Enzyme (safe Rust) combined"; + enzyme["runtime"] = tdiff(&start, &end); + for (unsigned i = result.gradient.size() - 5; i < result.gradient.size(); + i++) { + printf("%f ", result.gradient[i]); + enzyme["result"].push_back(result.gradient[i]); + } + test_suite["tools"].push_back(enzyme); + + printf("\n"); + } + + } + + { + + struct LSTMInput input = {}; + + // Read instance + read_lstm_instance("data/" + path, &input.l, &input.c, &input.b, + input.main_params, input.extra_params, input.state, + input.sequence); + + std::vector state = std::vector(input.state.size()); + + int Jcols = 8 * input.l * input.b + 3 * input.b; + struct LSTMOutput result = {0, std::vector(Jcols)}; + + { + struct timeval start, end; + gettimeofday(&start, NULL); + calculate_jacobian(input, result); + gettimeofday(&end, NULL); + printf("Enzyme (unsafe Rust) combined %0.6f\n", tdiff(&start, &end)); + json enzyme; + enzyme["name"] = "Enzyme (unsafe Rust) combined"; enzyme["runtime"] = tdiff(&start, &end); - for (unsigned i = result.gradient.size() - 5; - i < result.gradient.size(); i++) { + for (unsigned i = result.gradient.size() - 5; i < result.gradient.size(); + i++) { printf("%f ", result.gradient[i]); enzyme["result"].push_back(result.gradient[i]); } test_suite["tools"].push_back(enzyme); - + + printf("\n"); + } + } + { + + struct LSTMInput input = {}; + + // Read instance + read_lstm_instance("data/" + path, &input.l, &input.c, &input.b, + input.main_params, input.extra_params, input.state, + input.sequence); + + std::vector state = std::vector(input.state.size()); + + int Jcols = 8 * input.l * input.b + 3 * input.b; + struct LSTMOutput result = {0, std::vector(Jcols)}; + + { + struct timeval start, end; + gettimeofday(&start, NULL); + double res = calculate_mayalias_primal(input); + gettimeofday(&end, NULL); + printf("C++ mayalias primal %0.6f\n", tdiff(&start, &end)); + json enzyme; + enzyme["name"] = "C++ mayalias primal"; + enzyme["runtime"] = tdiff(&start, &end); + printf("%f ", res); + enzyme["result"].push_back(res); + test_suite["tools"].push_back(enzyme); + + printf("\n"); + } + } + { + + struct LSTMInput input = {}; + + // Read instance + read_lstm_instance("data/" + path, &input.l, &input.c, &input.b, + input.main_params, input.extra_params, input.state, + input.sequence); + + std::vector state = std::vector(input.state.size()); + + int Jcols = 8 * input.l * input.b + 3 * input.b; + struct LSTMOutput result = {0, std::vector(Jcols)}; + + { + struct timeval start, end; + gettimeofday(&start, NULL); + double res = calculate_restrict_primal(input); + gettimeofday(&end, NULL); + printf("C++ restrict primal %0.6f\n", tdiff(&start, &end)); + json enzyme; + enzyme["name"] = "C++ restrict primal"; + enzyme["runtime"] = tdiff(&start, &end); + printf("%f ", res); + enzyme["result"].push_back(res); + test_suite["tools"].push_back(enzyme); + printf("\n"); } + } + { + + struct LSTMInput input = {}; + + // Read instance + read_lstm_instance("data/" + path, &input.l, &input.c, &input.b, + input.main_params, input.extra_params, input.state, + input.sequence); + std::vector state = std::vector(input.state.size()); + + int Jcols = 8 * input.l * input.b + 3 * input.b; + struct LSTMOutput result = {0, std::vector(Jcols)}; + + { + struct timeval start, end; + gettimeofday(&start, NULL); + double res =calculate_unsafe_primal(input); + gettimeofday(&end, NULL); + printf("Enzyme (unsafe Rust) primal %0.6f\n", tdiff(&start, &end)); + json enzyme; + enzyme["name"] = "Enzyme (unsafe Rust) primal"; + enzyme["runtime"] = tdiff(&start, &end); + printf("%f ", res); + enzyme["result"].push_back(res); + test_suite["tools"].push_back(enzyme); + + printf("\n"); } + } + { + + struct LSTMInput input = {}; + + // Read instance + read_lstm_instance("data/" + path, &input.l, &input.c, &input.b, + input.main_params, input.extra_params, input.state, + input.sequence); + + std::vector state = std::vector(input.state.size()); + + int Jcols = 8 * input.l * input.b + 3 * input.b; + struct LSTMOutput result = {0, std::vector(Jcols)}; + + { + struct timeval start, end; + gettimeofday(&start, NULL); + double res = calculate_safe_primal(input); + gettimeofday(&end, NULL); + printf("Enzyme (safe Rust) primal %0.6f\n", tdiff(&start, &end)); + json enzyme; + enzyme["name"] = "Enzyme (safe Rust) primal"; + enzyme["runtime"] = tdiff(&start, &end); + printf("%f ", res); + enzyme["result"].push_back(res); + test_suite["tools"].push_back(enzyme); + + printf("\n"); + } + } + test_suite["llvm-version"] = __clang_version__; test_suite["mode"] = "ReverseMode"; test_suite["batch-size"] = 1; diff --git a/enzyme/benchmarks/ReverseMode/ba/Cargo.lock b/enzyme/benchmarks/ReverseMode/ba/Cargo.lock new file mode 100644 index 000000000000..74e2768e7cd4 --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/ba/Cargo.lock @@ -0,0 +1,16 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "bars" +version = "0.1.0" +dependencies = [ + "libm", +] + +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" diff --git a/enzyme/benchmarks/ReverseMode/ba/Cargo.toml b/enzyme/benchmarks/ReverseMode/ba/Cargo.toml new file mode 100644 index 000000000000..9f577370661c --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/ba/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "bars" +version = "0.1.0" +edition = "2021" + +[lib] +crate-type = ["cdylib"] + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[profile.release] +lto = "fat" +opt-level = 3 +codegen-units = 1 +unwind = "abort" +strip = true +#overflow-checks = false + +[profile.dev] +lto = "fat" + +[dependencies] +libm = { version = "0.2.8", optional = true } + +[workspace] diff --git a/enzyme/benchmarks/ReverseMode/ba/Makefile.make b/enzyme/benchmarks/ReverseMode/ba/Makefile.make index b7f013dc4b57..c6182c0d03fb 100644 --- a/enzyme/benchmarks/ReverseMode/ba/Makefile.make +++ b/enzyme/benchmarks/ReverseMode/ba/Makefile.make @@ -1,23 +1,46 @@ -# RUN: cd %S && LD_LIBRARY_PATH="%bldpath:$LD_LIBRARY_PATH" PTR="%ptr" BENCH="%bench" BENCHLINK="%blink" LOAD="%loadEnzyme" LOADCLANG="%loadClangEnzyme" ENZYME="%enzyme" make -B ba-raw.ll results.json -f %s +# RUN: cd %S && LD_LIBRARY_PATH="%bldpath:$LD_LIBRARY_PATH" PTR="%ptr" BENCH="%bench" BENCHLINK="%blink" LOAD="%loadEnzyme" LOADCLANG="%loadClangEnzyme" ENZYME="%enzyme" make -B results.json -f %s .PHONY: clean dir := $(abspath $(lastword $(MAKEFILE_LIST))/../../../..) +include $(dir)/benchmarks/ReverseMode/adbench/Makefile.config + +ifeq ($(strip $(CLANG)),) +$(error PASSES1 is not set) +endif + +ifeq ($(strip $(PASSES1)),) +$(error PASSES1 is not set) +endif + +ifeq ($(strip $(PASSES2)),) +$(error PASSES2 is not set) +endif + +ifeq ($(strip $(PASSES3)),) +$(error PASSES3 is not set) +endif + +ifneq ($(strip $(PASSES4)),) +$(error PASSES4 is set) +endif + clean: rm -f *.ll *.o results.txt results.json + cargo +enzyme clean -%-unopt.ll: %.cpp - clang++ $(BENCH) $(PTR) $^ -pthread -O2 -fno-vectorize -fno-slp-vectorize -ffast-math -fno-unroll-loops -o $@ -S -emit-llvm +$(dir)/benchmarks/ReverseMode/ba/target/release/libbars.a: src/lib.rs Cargo.toml + RUSTFLAGS="-Z autodiff=Enable" cargo +enzyme rustc --release --lib --crate-type=staticlib --features=libm -%-raw.ll: %-unopt.ll - opt $^ $(LOAD) $(ENZYME) -o $@ -S +%-unopt.ll: %.cpp + $(CLANG) $(BENCH) $^ -pthread -O3 -fno-vectorize -fno-slp-vectorize -fno-unroll-loops -o $@ -S -emit-llvm -%-opt.ll: %-raw.ll - opt $^ -o $@ -S +%-opt.ll: %-unopt.ll + $(OPT) $^ $(LOAD) -passes="$(PASSES2),enzyme" -o $@ -S -ba.o: ba-opt.ll - clang++ $(BENCH) -pthread -O2 $^ -I /usr/include/c++/11 -I/usr/include/x86_64-linux-gnu/c++/11 -O2 -o $@ $(BENCHLINK) -lpthread -lm -L /usr/lib/gcc/x86_64-linux-gnu/11 +ba.o: ba-opt.ll $(dir)/benchmarks/ReverseMode/ba/target/release/libbars.a + $(CLANG) -pthread -O3 -fno-math-errno $^ -o $@ $(BENCHLINK) -lm results.json: ba.o - ./$^ + numactl -C 1 ./$^ diff --git a/enzyme/benchmarks/ReverseMode/ba/ba.cpp b/enzyme/benchmarks/ReverseMode/ba/ba.cpp index b71e05a0a011..c9b29ec4cf78 100644 --- a/enzyme/benchmarks/ReverseMode/ba/ba.cpp +++ b/enzyme/benchmarks/ReverseMode/ba/ba.cpp @@ -43,17 +43,13 @@ double sqsum(int n, double const* x) return res; } - - -void cross(double const* a, double const* b, double* out) -{ +void cross_restrict(double const *__restrict a, double const *__restrict b, + double *__restrict out) { out[0] = a[1] * b[2] - a[2] * b[1]; out[1] = a[2] * b[0] - a[0] * b[2]; out[2] = a[0] * b[1] - a[1] * b[0]; } - - /* ===================================================================== */ /* MAIN LOGIC */ /* ===================================================================== */ @@ -68,8 +64,9 @@ void cross(double const* a, double const* b, double* out) // n = w / theta; // n_x = au_cross_matrix(n); // R = eye(3) + n_x*sin(theta) + n_x*n_x*(1 - cos(theta)); -void rodrigues_rotate_point(double const* __restrict rot, double const* __restrict pt, double *__restrict rotatedPt) -{ +void rodrigues_rotate_point_restrict(double const *__restrict rot, + double const *__restrict pt, + double *__restrict rotatedPt) { int i; double sqtheta = sqsum(3, rot); if (sqtheta != 0) @@ -87,7 +84,7 @@ void rodrigues_rotate_point(double const* __restrict rot, double const* __restri w[i] = rot[i] * theta_inverse; } - cross(w, pt, w_cross_pt); + cross_restrict(w, pt, w_cross_pt); tmp = (w[0] * pt[0] + w[1] * pt[1] + w[2] * pt[2]) * (1. - costheta); @@ -100,7 +97,7 @@ void rodrigues_rotate_point(double const* __restrict rot, double const* __restri else { double rot_cross_pt[3]; - cross(rot, pt, rot_cross_pt); + cross_restrict(rot, pt, rot_cross_pt); for (i = 0; i < 3; i++) { @@ -109,8 +106,6 @@ void rodrigues_rotate_point(double const* __restrict rot, double const* __restri } } - - void radial_distort(double const* rad_params, double *proj) { double rsq, L; @@ -120,10 +115,17 @@ void radial_distort(double const* rad_params, double *proj) proj[1] = proj[1] * L; } - - -void project(double const* __restrict cam, double const* __restrict X, double* __restrict proj) +void radial_distort_restrict(double const *__restrict rad_params, double *__restrict proj) { + double rsq, L; + rsq = sqsum(2, proj); + L = 1. + rad_params[0] * rsq + rad_params[1] * rsq * rsq; + proj[0] = proj[0] * L; + proj[1] = proj[1] * L; +} + +void project_restrict(double const *__restrict cam, double const *__restrict X, + double *__restrict proj) { double const* C = &cam[3]; double Xo[3], Xcam[3]; @@ -131,19 +133,17 @@ void project(double const* __restrict cam, double const* __restrict X, double* _ Xo[1] = X[1] - C[1]; Xo[2] = X[2] - C[2]; - rodrigues_rotate_point(&cam[0], Xo, Xcam); + rodrigues_rotate_point_restrict(&cam[0], Xo, Xcam); proj[0] = Xcam[0] / Xcam[2]; proj[1] = Xcam[1] / Xcam[2]; - radial_distort(&cam[9], proj); + radial_distort_restrict(&cam[9], proj); proj[0] = proj[0] * cam[6] + cam[7]; proj[1] = proj[1] * cam[6] + cam[8]; } - - // cam: 11 camera in format [r1 r2 r3 C1 C2 C3 f u0 v0 k1 k2] // r1, r2, r3 are angle - axis rotation parameters(Rodrigues) // [C1 C2 C3]' is the camera center @@ -158,30 +158,23 @@ void project(double const* __restrict cam, double const* __restrict X, double* _ // distorted = radial_distort(projective2euclidean(Xcam), radial_parameters) // proj = distorted * f + principal_point // err = sqsum(proj - measurement) -void compute_reproj_error( - double const* __restrict cam, - double const* __restrict X, - double const* __restrict w, - double const* __restrict feat, - double * __restrict err -) -{ +void compute_reproj_error_restrict(double const *__restrict cam, + double const *__restrict X, + double const *__restrict w, + double const *__restrict feat, + double *__restrict err) { double proj[2]; - project(cam, X, proj); + project_restrict(cam, X, proj); err[0] = (*w)*(proj[0] - feat[0]); err[1] = (*w)*(proj[1] - feat[1]); } - - -void compute_zach_weight_error(double const* w, double* err) -{ +void compute_zach_weight_error_restrict(double const *__restrict w, + double *__restrict err) { *err = 1 - (*w)*(*w); } - - // n number of cameras // m number of points // p number of observations @@ -196,36 +189,23 @@ void compute_zach_weight_error(double const* w, double* err) // feats: 2*p features (x,y coordinates corresponding to observations) // reproj_err: 2*p errors of observations // w_err: p weight "error" terms -void ba_objective( - int n, - int m, - int p, - double const* cams, - double const* X, - double const* w, - int const* obs, - double const* feats, - double* reproj_err, - double* w_err -) -{ +void ba_objective_restrict(int n, int m, int p, double const *cams, + double const *X, double const *w, int const *obs, + double const *feats, double *reproj_err, + double *w_err) { int i; for (i = 0; i < p; i++) { int camIdx = obs[i * 2 + 0]; int ptIdx = obs[i * 2 + 1]; - compute_reproj_error( - &cams[camIdx * BA_NCAMPARAMS], - &X[ptIdx * 3], - &w[i], - &feats[i * 2], - &reproj_err[2 * i] - ); + compute_reproj_error_restrict(&cams[camIdx * BA_NCAMPARAMS], + &X[ptIdx * 3], &w[i], &feats[i * 2], + &reproj_err[2 * i]); } for (i = 0; i < p; i++) { - compute_zach_weight_error(&w[i], &w_err[i]); + compute_zach_weight_error_restrict(&w[i], &w_err[i]); } } @@ -234,32 +214,21 @@ extern int enzyme_dup; extern int enzyme_dupnoneed; void __enzyme_autodiff(...) noexcept; -void dcompute_reproj_error( - double const* cam, - double * dcam, - double const* X, - double * dX, - double const* w, - double * wb, - double const* feat, - double *err, - double *derr -) -{ - __enzyme_autodiff(compute_reproj_error, - enzyme_dup, cam, dcam, - enzyme_dup, X, dX, - enzyme_dup, w, wb, - enzyme_const, feat, - enzyme_dupnoneed, err, derr); +void dcompute_reproj_error_restrict(double const *cam, double *dcam, + double const *X, double *dX, + double const *w, double *wb, + double const *feat, double *err, + double *derr) { + __enzyme_autodiff(compute_reproj_error_restrict, enzyme_dup, cam, dcam, + enzyme_dup, X, dX, enzyme_dup, w, wb, enzyme_const, feat, + enzyme_dupnoneed, err, derr); } -void dcompute_zach_weight_error(double const* w, double* dw, double* err, double* derr) { - __enzyme_autodiff(compute_zach_weight_error, - enzyme_dup, w, dw, - enzyme_dupnoneed, err, derr); +void dcompute_zach_weight_error_restrict(double const *w, double *dw, + double *err, double *derr) { + __enzyme_autodiff(compute_zach_weight_error_restrict, enzyme_dup, w, dw, + enzyme_dupnoneed, err, derr); } - } @@ -911,3 +880,5 @@ void adept_compute_zach_weight_error(double const* w, double* dw, double* err, d *dw = aw.get_gradient(); } + +#include "ba_mayalias.h" diff --git a/enzyme/benchmarks/ReverseMode/ba/ba_mayalias.h b/enzyme/benchmarks/ReverseMode/ba/ba_mayalias.h new file mode 100644 index 000000000000..25197b52d7b2 --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/ba/ba_mayalias.h @@ -0,0 +1,198 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +extern "C" { + +/* ===================================================================== */ +/* UTILS */ +/* ===================================================================== */ + +void cross(double const *a, double const *b, double *out) { + out[0] = a[1] * b[2] - a[2] * b[1]; + out[1] = a[2] * b[0] - a[0] * b[2]; + out[2] = a[0] * b[1] - a[1] * b[0]; +} + +/* ===================================================================== */ +/* MAIN LOGIC */ +/* ===================================================================== */ + +void compute_zach_weight_error(double const *w, double *err) { + *err = 1 - (*w) * (*w); +} + +// rot: 3 rotation parameters +// pt: 3 point to be rotated +// rotatedPt: 3 rotated point +// this is an efficient evaluation (part of +// the Ceres implementation) +// easy to understand calculation in matlab: +// theta = sqrt(sum(w. ^ 2)); +// n = w / theta; +// n_x = au_cross_matrix(n); +// R = eye(3) + n_x*sin(theta) + n_x*n_x*(1 - cos(theta)); +void rodrigues_rotate_point(double const *rot, double const *pt, + double *rotatedPt) { + int i; + double sqtheta = sqsum(3, rot); + if (sqtheta != 0) + { + double theta, costheta, sintheta, theta_inverse; + double w[3], w_cross_pt[3], tmp; + + theta = sqrt(sqtheta); + costheta = cos(theta); + sintheta = sin(theta); + theta_inverse = 1.0 / theta; + + for (i = 0; i < 3; i++) + { + w[i] = rot[i] * theta_inverse; + } + + cross(w, pt, w_cross_pt); + + tmp = (w[0] * pt[0] + w[1] * pt[1] + w[2] * pt[2]) * + (1. - costheta); + + for (i = 0; i < 3; i++) + { + rotatedPt[i] = pt[i] * costheta + w_cross_pt[i] * sintheta + w[i] * tmp; + } + } + else + { + double rot_cross_pt[3]; + cross(rot, pt, rot_cross_pt); + + for (i = 0; i < 3; i++) + { + rotatedPt[i] = pt[i] + rot_cross_pt[i]; + } + } +} + +void project(double const *cam, double const *X, double *proj) { + double const* C = &cam[3]; + double Xo[3], Xcam[3]; + + Xo[0] = X[0] - C[0]; + Xo[1] = X[1] - C[1]; + Xo[2] = X[2] - C[2]; + + rodrigues_rotate_point(&cam[0], Xo, Xcam); + + proj[0] = Xcam[0] / Xcam[2]; + proj[1] = Xcam[1] / Xcam[2]; + + radial_distort(&cam[9], proj); + + proj[0] = proj[0] * cam[6] + cam[7]; + proj[1] = proj[1] * cam[6] + cam[8]; +} + +// cam: 11 camera in format [r1 r2 r3 C1 C2 C3 f u0 v0 k1 k2] +// r1, r2, r3 are angle - axis rotation parameters(Rodrigues) +// [C1 C2 C3]' is the camera center +// f is the focal length in pixels +// [u0 v0]' is the principal point +// k1, k2 are radial distortion parameters +// X: 3 point +// feats: 2 feature (x,y coordinates) +// reproj_err: 2 +// projection function: +// Xcam = R * (X - C) +// distorted = radial_distort(projective2euclidean(Xcam), radial_parameters) +// proj = distorted * f + principal_point +// err = sqsum(proj - measurement) +void compute_reproj_error(double const *cam, double const *X, double const *w, + double const *feat, double *err) { + double proj[2]; + project(cam, X, proj); + + err[0] = (*w)*(proj[0] - feat[0]); + err[1] = (*w)*(proj[1] - feat[1]); +} + + + + +// n number of cameras +// m number of points +// p number of observations +// cams: 11*n cameras in format [r1 r2 r3 C1 C2 C3 f u0 v0 k1 k2] +// r1, r2, r3 are angle - axis rotation parameters(Rodrigues) +// [C1 C2 C3]' is the camera center +// f is the focal length in pixels +// [u0 v0]' is the principal point +// k1, k2 are radial distortion parameters +// X: 3*m points +// obs: 2*p observations (pairs cameraIdx, pointIdx) +// feats: 2*p features (x,y coordinates corresponding to observations) +// reproj_err: 2*p errors of observations +// w_err: p weight "error" terms +void ba_objective( + int n, + int m, + int p, + double const* cams, + double const* X, + double const* w, + int const* obs, + double const* feats, + double* reproj_err, + double* w_err +) +{ + int i; + for (i = 0; i < p; i++) + { + int camIdx = obs[i * 2 + 0]; + int ptIdx = obs[i * 2 + 1]; + compute_reproj_error( + &cams[camIdx * BA_NCAMPARAMS], + &X[ptIdx * 3], + &w[i], + &feats[i * 2], + &reproj_err[2 * i] + ); + } + + for (i = 0; i < p; i++) + { + compute_zach_weight_error(&w[i], &w_err[i]); + } +} + +extern int enzyme_const; +extern int enzyme_dup; +extern int enzyme_dupnoneed; +void __enzyme_autodiff(...) noexcept; + +void dcompute_reproj_error( + double const* cam, + double * dcam, + double const* X, + double * dX, + double const* w, + double * wb, + double const* feat, + double *err, + double *derr +) +{ + __enzyme_autodiff(compute_reproj_error, + enzyme_dup, cam, dcam, + enzyme_dup, X, dX, + enzyme_dup, w, wb, + enzyme_const, feat, + enzyme_dupnoneed, err, derr); +} + +void dcompute_zach_weight_error(double const* w, double* dw, double* err, double* derr) { + __enzyme_autodiff(compute_zach_weight_error, + enzyme_dup, w, dw, + enzyme_dupnoneed, err, derr); +} + +} diff --git a/enzyme/benchmarks/ReverseMode/ba/src/lib.rs b/enzyme/benchmarks/ReverseMode/ba/src/lib.rs new file mode 100644 index 000000000000..7efd43ff2b28 --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/ba/src/lib.rs @@ -0,0 +1,25 @@ +#![feature(autodiff)] +#![allow(non_snake_case)] + +use std::autodiff::autodiff; +pub mod safe; +pub mod r#unsafe; + +static BA_NCAMPARAMS: usize = 11; + +#[no_mangle] +pub extern "C" fn rust_dcompute_zach_weight_error( + w: *const f64, + dw: *mut f64, + err: *mut f64, + derr: *mut f64, +) { + dcompute_zach_weight_error(w, dw, err, derr); +} + +#[autodiff(dcompute_zach_weight_error, Reverse, Duplicated, Duplicated)] +pub fn compute_zach_weight_error(w: *const f64, err: *mut f64) { + let w = unsafe { *w }; + unsafe { *err = 1. - w * w; } +} + diff --git a/enzyme/benchmarks/ReverseMode/ba/src/main.rs b/enzyme/benchmarks/ReverseMode/ba/src/main.rs new file mode 100644 index 000000000000..13f221be69c1 --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/ba/src/main.rs @@ -0,0 +1,26 @@ +use bars::{dcompute_reproj_error, dcompute_zach_weight_error}; +fn main() { + let cam = [0.0; 11]; + let mut dcam = [0.0; 11]; + let x = [0.0; 3]; + let mut dx = [0.0; 3]; + let w = [0.0; 1]; + let mut dw = [0.0; 1]; + let feat = [0.0; 2]; + let mut err = [0.0; 2]; + let mut derr = [0.0; 2]; + dcompute_reproj_error( + &cam as *const [f64;11], + &mut dcam as *mut [f64;11], + &x as *const [f64;3], + &mut dx as *mut [f64;3], + &w as *const [f64;1], + &mut dw as *mut [f64;1], + &feat as *const [f64;2], + &mut err as *mut [f64;2], + &mut derr as *mut [f64;2], + ); + + let mut wb = 0.0; + dcompute_zach_weight_error(&w as *const f64, &mut dw as *mut f64, &mut err as *mut f64, &mut derr as *mut f64); +} diff --git a/enzyme/benchmarks/ReverseMode/ba/src/safe.rs b/enzyme/benchmarks/ReverseMode/ba/src/safe.rs new file mode 100644 index 000000000000..dd8bf88b9265 --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/ba/src/safe.rs @@ -0,0 +1,207 @@ +use crate::BA_NCAMPARAMS; +use crate::compute_zach_weight_error; +use std::autodiff::autodiff; + +fn sqsum(x: &[f64]) -> f64 { + x.iter().map(|&v| v * v).sum() +} + +#[inline] +fn cross(a: &[f64; 3], b: &[f64; 3]) -> [f64; 3] { + [ + a[1] * b[2] - a[2] * b[1], + a[2] * b[0] - a[0] * b[2], + a[0] * b[1] - a[1] * b[0], + ] +} + +fn radial_distort(rad_params: &[f64], proj: &mut [f64]) { + let rsq = sqsum(proj); + let l = 1. + rad_params[0] * rsq + rad_params[1] * rsq * rsq; + proj[0] = proj[0] * l; + proj[1] = proj[1] * l; +} + +fn rodrigues_rotate_point(rot: &[f64; 3], pt: &[f64; 3], rotated_pt: &mut [f64; 3]) { + let sqtheta = sqsum(rot); + if sqtheta != 0. { + let theta = sqtheta.sqrt(); + let costheta = theta.cos(); + let sintheta = theta.sin(); + let theta_inverse = 1. / theta; + let mut w = [0.; 3]; + for i in 0..3 { + w[i] = rot[i] * theta_inverse; + } + let w_cross_pt = cross(&w, &pt); + let tmp = (w[0] * pt[0] + w[1] * pt[1] + w[2] * pt[2]) * (1. - costheta); + for i in 0..3 { + rotated_pt[i] = pt[i] * costheta + w_cross_pt[i] * sintheta + w[i] * tmp; + } + } else { + let rot_cross_pt = cross(&rot, &pt); + for i in 0..3 { + rotated_pt[i] = pt[i] + rot_cross_pt[i]; + } + } +} + +fn project(cam: &[f64; 11], X: &[f64; 3], proj: &mut [f64; 2]) { + let C = &cam[3..6]; + let mut Xo = [0.; 3]; + let mut Xcam = [0.; 3]; + + Xo[0] = X[0] - C[0]; + Xo[1] = X[1] - C[1]; + Xo[2] = X[2] - C[2]; + + rodrigues_rotate_point(cam.first_chunk::<3>().unwrap(), &Xo, &mut Xcam); + + proj[0] = Xcam[0] / Xcam[2]; + proj[1] = Xcam[1] / Xcam[2]; + + radial_distort(&cam[9..], proj); + + proj[0] = proj[0] * cam[6] + cam[7]; + proj[1] = proj[1] * cam[6] + cam[8]; +} + +#[no_mangle] +pub extern "C" fn rust_dcompute_reproj_error( + cam: *const [f64; 11], + dcam: *mut [f64; 11], + x: *const [f64; 3], + dx: *mut [f64; 3], + w: *const [f64; 1], + wb: *mut [f64; 1], + feat: *const [f64; 2], + err: *mut [f64; 2], + derr: *mut [f64; 2], +) { + unsafe {dcompute_reproj_error(cam, dcam, x, dx, w, wb, feat, err, derr)}; +} + +#[autodiff( + dcompute_reproj_error, + Reverse, + Duplicated, + Duplicated, + Duplicated, + Const, + DuplicatedOnly +)] +pub fn compute_reproj_error( + cam: *const [f64; 11], + x: *const [f64; 3], + w: *const [f64; 1], + feat: *const [f64; 2], + err: *mut [f64; 2], +) { + let cam = unsafe { &*cam }; + let w = unsafe { *(*w).get_unchecked(0) }; + let x = unsafe { &*x }; + let feat = unsafe { &*feat }; + let err = unsafe { &mut *err }; + let mut proj = [0.; 2]; + project(cam, x, &mut proj); + err[0] = w * (proj[0] - feat[0]); + err[1] = w * (proj[1] - feat[1]); +} + +// n number of cameras +// m number of points +// p number of observations +// cams: 11*n cameras in format [r1 r2 r3 C1 C2 C3 f u0 v0 k1 k2] +// r1, r2, r3 are angle - axis rotation parameters(Rodrigues) +// [C1 C2 C3]' is the camera center +// f is the focal length in pixels +// [u0 v0]' is the principal point +// k1, k2 are radial distortion parameters +// X: 3*m points +// obs: 2*p observations (pairs cameraIdx, pointIdx) +// feats: 2*p features (x,y coordinates corresponding to observations) +// reproj_err: 2*p errors of observations +// w_err: p weight "error" terms +fn rust_ba_objective( + n: usize, + m: usize, + p: usize, + cams: &[f64], + x: &[f64], + w: &[f64], + obs: &[i32], + feats: &[f64], + reproj_err: &mut [f64], + w_err: &mut [f64], +) { + assert_eq!(cams.len(), n * 11); + assert_eq!(x.len(), m * 3); + assert_eq!(w.len(), p); + assert_eq!(obs.len(), p * 2); + assert_eq!(feats.len(), p * 2); + assert_eq!(reproj_err.len(), p * 2); + assert_eq!(w_err.len(), p); + + for i in 0..p { + let cam_idx = obs[i * 2 + 0] as usize; + let pt_idx = obs[i * 2 + 1] as usize; + let start = cam_idx * BA_NCAMPARAMS; + let cam: &[f64; 11] = unsafe { + cams[start..] + .get_unchecked(..11) + .try_into() + .unwrap_unchecked() + }; + let x: &[f64; 3] = unsafe { + x[pt_idx * 3..] + .get_unchecked(..3) + .try_into() + .unwrap_unchecked() + }; + let w: &[f64; 1] = unsafe { w[i..].get_unchecked(..1).try_into().unwrap_unchecked() }; + let feat: &[f64; 2] = unsafe { + feats[i * 2..] + .get_unchecked(..2) + .try_into() + .unwrap_unchecked() + }; + let reproj_err: &mut [f64; 2] = unsafe { + reproj_err[i * 2..] + .get_unchecked_mut(..2) + .try_into() + .unwrap_unchecked() + }; + compute_reproj_error(cam, x, w, feat, reproj_err); + } + + for i in 0..p { + let w_err: &mut f64 = unsafe { w_err.get_unchecked_mut(i) }; + compute_zach_weight_error(w[i..].as_ptr(), w_err as *mut f64); + } +} + +#[no_mangle] +extern "C" fn rust2_ba_objective( + n: i32, + m: i32, + p: i32, + cams: *const f64, + x: *const f64, + w: *const f64, + obs: *const i32, + feats: *const f64, + reproj_err: *mut f64, + w_err: *mut f64, +) { + let n = n as usize; + let m = m as usize; + let p = p as usize; + let cams = unsafe { std::slice::from_raw_parts(cams, n * 11) }; + let x = unsafe { std::slice::from_raw_parts(x, m * 3) }; + let w = unsafe { std::slice::from_raw_parts(w, p) }; + let obs = unsafe { std::slice::from_raw_parts(obs, p * 2) }; + let feats = unsafe { std::slice::from_raw_parts(feats, p * 2) }; + let reproj_err = unsafe { std::slice::from_raw_parts_mut(reproj_err, p * 2) }; + let w_err = unsafe { std::slice::from_raw_parts_mut(w_err, p) }; + rust_ba_objective(n, m, p, cams, x, w, obs, feats, reproj_err, w_err); +} diff --git a/enzyme/benchmarks/ReverseMode/ba/src/unsafe.rs b/enzyme/benchmarks/ReverseMode/ba/src/unsafe.rs new file mode 100644 index 000000000000..467a7cb27d7d --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/ba/src/unsafe.rs @@ -0,0 +1,143 @@ +use crate::BA_NCAMPARAMS; +use crate::compute_zach_weight_error; +use std::autodiff::autodiff; + +unsafe fn sqsum(x: *const f64, n: usize) -> f64 { + let mut sum = 0.; + for i in 0..n { + let v = unsafe { *x.add(i) }; + sum += v * v; + } + sum +} + +#[inline] +unsafe fn cross(a: *const f64, b: *const f64, out: *mut f64) { + *out.add(0) = *a.add(1) * *b.add(2) - *a.add(2) * *b.add(1); + *out.add(1) = *a.add(2) * *b.add(0) - *a.add(0) * *b.add(2); + *out.add(2) = *a.add(0) * *b.add(1) - *a.add(1) * *b.add(0); +} + +unsafe fn radial_distort(rad_params: *const f64, proj: *mut f64) { + let rsq = sqsum(proj, 2); + let l = 1. + *rad_params.add(0) * rsq + *rad_params.add(1) * rsq * rsq; + *proj.add(0) = *proj.add(0) * l; + *proj.add(1) = *proj.add(1) * l; +} + +unsafe fn rodrigues_rotate_point(rot: *const f64, pt: *const f64, rotated_pt: *mut f64) { + let sqtheta = sqsum(rot, 3); + if sqtheta != 0. { + let theta = sqtheta.sqrt(); + let costheta = theta.cos(); + let sintheta = theta.sin(); + let theta_inverse = 1. / theta; + let mut w = [0.; 3]; + for i in 0..3 { + w[i] = *rot.add(i) * theta_inverse; + } + let mut w_cross_pt = [0.; 3]; + cross(w.as_ptr(), pt, w_cross_pt.as_mut_ptr()); + let tmp = (w[0] * *pt.add(0) + w[1] * *pt.add(1) + w[2] * *pt.add(2)) * (1. - costheta); + for i in 0..3 { + *rotated_pt.add(i) = *pt.add(i) * costheta + w_cross_pt[i] * sintheta + w[i] * tmp; + } + } else { + let mut rot_cross_pt = [0.; 3]; + cross(rot, pt, rot_cross_pt.as_mut_ptr()); + for i in 0..3 { + *rotated_pt.add(i) = *pt.add(i) + rot_cross_pt[i]; + } + } +} + +unsafe fn project(cam: *const f64, X: *const f64, proj: *mut f64) { + let C = cam.add(3); + let mut Xo = [0.; 3]; + let mut Xcam = [0.; 3]; + + Xo[0] = *X.add(0) - *C.add(0); + Xo[1] = *X.add(1) - *C.add(1); + Xo[2] = *X.add(2) - *C.add(2); + + rodrigues_rotate_point(cam, Xo.as_ptr(), Xcam.as_mut_ptr()); + + *proj.add(0) = Xcam[0] / Xcam[2]; + *proj.add(1) = Xcam[1] / Xcam[2]; + + radial_distort(cam.add(9), proj); + *proj.add(0) = *proj.add(0) * *cam.add(6) + *cam.add(7); + *proj.add(1) = *proj.add(1) * *cam.add(6) + *cam.add(8); +} + +#[no_mangle] +pub unsafe extern "C" fn rust_unsafe_dcompute_reproj_error( + cam: *const f64, + dcam: *mut f64, + x: *const f64, + dx: *mut f64, + w: *const f64, + wb: *mut f64, + feat: *const f64, + err: *mut f64, + derr: *mut f64, +) { + unsafe {dcompute_reproj_error(cam, dcam, x, dx, w, wb, feat, err, derr)}; +} + + +#[autodiff( + dcompute_reproj_error, + Reverse, + Duplicated, + Duplicated, + Duplicated, + Const, + DuplicatedOnly +)] +pub unsafe fn compute_reproj_error( + cam: *const f64, + x: *const f64, + w: *const f64, + feat: *const f64, + err: *mut f64, +) { + let mut proj = [0.; 2]; + project(cam, x, proj.as_mut_ptr()); + *err.add(0) = *w * (proj[0] - *feat.add(0)); + *err.add(1) = *w * (proj[1] - *feat.add(1)); +} + +#[no_mangle] +unsafe extern "C" fn rust2_unsafe_ba_objective( + n: i32, + m: i32, + p: i32, + cams: *const f64, + x: *const f64, + w: *const f64, + obs: *const i32, + feats: *const f64, + reproj_err: *mut f64, + w_err: *mut f64, +) { + let n = n as usize; + let m = m as usize; + let p = p as usize; + for i in 0..p { + let cam_idx = *obs.add(i * 2 + 0) as usize; + let pt_idx = *obs.add(i * 2 + 1) as usize; + let start = cam_idx * BA_NCAMPARAMS; + + let cam: *const f64 = cams.add(start); + let x: *const f64 = x.add(pt_idx * 3); + let w: *const f64 = w.add(i); + let feat: *const f64 = feats.add(i * 2); + let reproj_err: *mut f64 = reproj_err.add(i * 2); + compute_reproj_error(cam, x, w, feat, reproj_err); + } + + for i in 0..p { + compute_zach_weight_error(w.add(i), w_err.add(i)); + } +} diff --git a/enzyme/benchmarks/ReverseMode/fft/Cargo.lock b/enzyme/benchmarks/ReverseMode/fft/Cargo.lock new file mode 100644 index 000000000000..44847eca60f6 --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/fft/Cargo.lock @@ -0,0 +1,7 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "fft" +version = "0.1.0" diff --git a/enzyme/benchmarks/ReverseMode/fft/Cargo.toml b/enzyme/benchmarks/ReverseMode/fft/Cargo.toml new file mode 100644 index 000000000000..630506d03fed --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/fft/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "fft" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] + +[lib] +crate-type = ["lib"] + +[profile.release] +lto = "fat" +opt-level = 3 +codegen-units = 1 +unwind = "abort" +strip = true +#overflow-checks = false + +[profile.dev] +lto = "fat" + +[workspace] diff --git a/enzyme/benchmarks/ReverseMode/fft/Makefile.make b/enzyme/benchmarks/ReverseMode/fft/Makefile.make index 17ea03aaa5ae..9ed3daaa26b6 100644 --- a/enzyme/benchmarks/ReverseMode/fft/Makefile.make +++ b/enzyme/benchmarks/ReverseMode/fft/Makefile.make @@ -4,21 +4,50 @@ dir := $(abspath $(lastword $(MAKEFILE_LIST))/../../../..) +include $(dir)/benchmarks/ReverseMode/adbench/Makefile.config + +ifeq ($(strip $(CLANG)),) +$(error PASSES1 is not set) +endif + +ifeq ($(strip $(PASSES1)),) +$(error PASSES1 is not set) +endif + +ifeq ($(strip $(PASSES2)),) +$(error PASSES2 is not set) +endif + +ifeq ($(strip $(PASSES3)),) +$(error PASSES3 is not set) +endif + +ifneq ($(strip $(PASSES4)),) +$(error PASSES4 is set) +endif + clean: rm -f *.ll *.o results.txt results.json +$(dir)/benchmarks/ReverseMode/fft/target/release/libfft.a: src/lib.rs Cargo.toml + RUSTFLAGS="-Z autodiff=Enable" cargo +enzyme rustc --release --lib --crate-type=staticlib + %-unopt.ll: %.cpp - clang++ $(BENCH) $(PTR) $^ -pthread -O2 -fno-use-cxa-atexit -fno-vectorize -fno-slp-vectorize -ffast-math -fno-unroll-loops -o $@ -S -emit-llvm + $(CLANG) $(BENCH) $^ -DCPP=1 -fno-math-errno -fno-plt -pthread -O3 -fno-vectorize -fno-slp-vectorize -fno-unroll-loops -o $@ -S -emit-llvm #-fno-use-cxa-atexit +%-unoptr.ll: %.cpp + $(CLANG) $(BENCH) $^ -fno-math-errno -fno-plt -pthread -O3 -fno-vectorize -fno-slp-vectorize -fno-unroll-loops -o $@ -S -emit-llvm #-fno-use-cxa-atexit -%-raw.ll: %-unopt.ll - opt $^ $(LOAD) $(ENZYME) -o $@ -S -%-opt.ll: %-raw.ll - opt $^ -o $@ -S +%-opt.ll: %-unopt.ll + $(OPT) $^ $(LOAD) -passes="$(PASSES2),enzyme" -o $@ -S +%-optr.ll: %-unoptr.ll + $(OPT) $^ $(LOAD) -passes="$(PASSES2),enzyme" -o $@ -S -fft.o: fft-opt.ll - clang++ $(BENCH) -pthread -O2 $^ -o $@ $(BENCHLINK) -lpthread -lm -L /usr/lib/gcc/x86_64-linux-gnu/11 - #clang++ $(LOAD) $(BENCH) fft.cpp -I /usr/include/c++/11 -I/usr/include/x86_64-linux-gnu/c++/11 -O2 -o fft.o -lpthread $(BENCHLINK) -lm -L /usr/lib/gcc/x86_64-linux-gnu/11 +fft.o: fft-opt.ll $(dir)/benchmarks/ReverseMode/fft/target/release/libfft.a + $(CLANG) -DCPP=1 -pthread -O3 -fno-math-errno -fno-plt -lpthread -lm $^ -o $@ $(BENCHLINK) -lm +fftr.o: fft-optr.ll $(dir)/benchmarks/ReverseMode/fft/target/release/libfft.a + $(CLANG) -pthread -O3 -fno-math-errno -fno-plt -lpthread -lm $^ -o $@ $(BENCHLINK) -lm -results.json: fft.o - ./$^ 1048576 | tee $@ +results.json: fftr.o fft.o + numactl -C 1 ./fft.o 1048576 | tee results.json + numactl -C 1 ./fftr.o 1048576 | tee resultsr.json diff --git a/enzyme/benchmarks/ReverseMode/fft/fft.cpp b/enzyme/benchmarks/ReverseMode/fft/fft.cpp index cf9459b9597a..799b7b16c1b9 100644 --- a/enzyme/benchmarks/ReverseMode/fft/fft.cpp +++ b/enzyme/benchmarks/ReverseMode/fft/fft.cpp @@ -1,237 +1,368 @@ +#include +#include +#include #include #include #include -#include -#include -#include -#include #include -#include -#include +#include using adept::adouble; -template -Return __enzyme_autodiff(T...); +template Return __enzyme_autodiff(T...); float tdiff(struct timeval *start, struct timeval *end) { - return (end->tv_sec-start->tv_sec) + 1e-6*(end->tv_usec-start->tv_usec); + return (end->tv_sec - start->tv_sec) + 1e-6 * (end->tv_usec - start->tv_usec); } #include "fft.h" -void foobar(double* data, unsigned len) { +void foobar(double *data, size_t len) { fft(data, len); ifft(data, len); } -void afoobar(aVector& data, unsigned len) { +void afoobar(aVector &data, size_t len) { fft(data, len); ifft(data, len); } extern "C" { - int enzyme_dupnoneed; +int enzyme_dupnoneed; } -static double foobar_and_gradient(unsigned len) { - double *inp = new double[2*len]; - for(int i=0; i<2*len; i++) inp[i] = 2.0; - double *dinp = new double[2*len]; - for(int i=0; i<2*len; i++) dinp[i] = 1.0; - __enzyme_autodiff(foobar, enzyme_dupnoneed, inp, dinp, len); - double res = dinp[0]; - delete[] dinp; - delete[] inp; - return res; +extern "C" void rust_unsafe_dfoobar(size_t n, double *data, double *ddata); +extern "C" void rust_unsafe_foobar(size_t n, double *data); +extern "C" void rust_dfoobar(size_t n, double *data, double *ddata); +extern "C" void rust_foobar(size_t n, double *data); + +static double rust_unsafe_foobar_and_gradient(size_t len) { + double *inp = new double[2 * len]; + for (size_t i = 0; i < 2 * len; i++) + inp[i] = 2.0; + double *dinp = new double[2 * len]; + for (size_t i = 0; i < 2 * len; i++) + dinp[i] = 1.0; + rust_unsafe_dfoobar(len, inp, dinp); + double res = dinp[0]; + delete[] dinp; + delete[] inp; + return res; } -static double afoobar_and_gradient(unsigned len) { - adept::Stack stack; - - aVector x(2*len); - for(int i=0; i<2*len; i++) x(i) = 2.0; - stack.new_recording(); - afoobar(x, len); - for(int i=0; i<2*len; i++) - x(i).set_gradient(1.0); - stack.compute_adjoint(); +static double rust_foobar_and_gradient(size_t len) { + double *inp = new double[2 * len]; + for (size_t i = 0; i < 2 * len; i++) + inp[i] = 2.0; + double *dinp = new double[2 * len]; + for (size_t i = 0; i < 2 * len; i++) + dinp[i] = 1.0; + rust_dfoobar(len, inp, dinp); + double res = dinp[0]; + delete[] dinp; + delete[] inp; + return res; +} - double *dinp = new double[2*len]; - for(int i=0; i<2*len; i++) - dinp[i] = x(i).get_gradient(); - double res = dinp[0]; - delete[] dinp; - return res; +__attribute__((noinline)) static double foobar_and_gradient(size_t len) { + double *inp = new double[2 * len]; + for (size_t i = 0; i < 2 * len; i++) + inp[i] = 2.0; + double *dinp = new double[2 * len]; + for (size_t i = 0; i < 2 * len; i++) + dinp[i] = 1.0; + __enzyme_autodiff(foobar, enzyme_dupnoneed, inp, dinp, len); + double res = dinp[0]; + delete[] dinp; + delete[] inp; + return res; } +static double afoobar_and_gradient(size_t len) { + adept::Stack stack; -static double tfoobar_and_gradient(unsigned len) { - double *inp = new double[2*len]; - for(int i=0; i<2*len; i++) inp[i] = 2.0; - double *dinp = new double[2*len]; - for(int i=0; i<2*len; i++) dinp[i] = 1.0; - foobar_b(inp, dinp, len); - double res = dinp[0]; - delete[] dinp; - delete[] inp; - return res; + aVector x(2 * len); + for (size_t i = 0; i < 2 * len; i++) + x(i) = 2.0; + stack.new_recording(); + afoobar(x, len); + for (size_t i = 0; i < 2 * len; i++) + x(i).set_gradient(1.0); + stack.compute_adjoint(); + + double *dinp = new double[2 * len]; + for (size_t i = 0; i < 2 * len; i++) + dinp[i] = x(i).get_gradient(); + double res = dinp[0]; + delete[] dinp; + return res; } -static void adept_sincos(double inp, unsigned len) { - { - struct timeval start, end; - gettimeofday(&start, NULL); - - double *x = new double[2*len]; - for(int i=0; i<2*len; i++) x[i] = 2.0; - foobar(x, len); - double res = x[0]; +static double tfoobar_and_gradient(size_t len) { + double *inp = new double[2 * len]; + for (size_t i = 0; i < 2 * len; i++) + inp[i] = 2.0; + double *dinp = new double[2 * len]; + for (size_t i = 0; i < 2 * len; i++) + dinp[i] = 1.0; + foobar_b(inp, dinp, len); + double res = dinp[0]; + delete[] dinp; + delete[] inp; + return res; +} - gettimeofday(&end, NULL); - printf("Adept real %0.6f res=%f\n", tdiff(&start, &end), res); - delete[] x; +static void adept_sincos(double inp, size_t len) { + { + struct timeval start, end; + gettimeofday(&start, NULL); + + double *x = new double[2 * len]; + for (size_t i = 0; i < 2 * len; i++) + x[i] = 2.0; + foobar(x, len); + double res = x[0]; + + gettimeofday(&end, NULL); + printf("Adept real %0.6f res=%f\n", tdiff(&start, &end), res); + delete[] x; } { - struct timeval start, end; - gettimeofday(&start, NULL); + struct timeval start, end; + gettimeofday(&start, NULL); - adept::Stack stack; + adept::Stack stack; - aVector x(2*len); - for(int i=0; i<2*len; i++) x[i] = 2.0; - // stack.new_recording(); - afoobar(x, len); - double res = x(0).value(); + aVector x(2 * len); + for (size_t i = 0; i < 2 * len; i++) + x[i] = 2.0; + // stack.new_recording(); + afoobar(x, len); + double res = x(0).value(); - gettimeofday(&end, NULL); - printf("Adept forward %0.6f res=%f\n", tdiff(&start, &end), res); + gettimeofday(&end, NULL); + printf("Adept forward %0.6f res=%f\n", tdiff(&start, &end), res); } { - struct timeval start, end; - gettimeofday(&start, NULL); + struct timeval start, end; + gettimeofday(&start, NULL); - double res2 = afoobar_and_gradient(len); + double res2 = afoobar_and_gradient(len); - gettimeofday(&end, NULL); - printf("Adept combined %0.6f res'=%f\n", tdiff(&start, &end), res2); + gettimeofday(&end, NULL); + printf("Adept combined %0.6f res'=%f\n", tdiff(&start, &end), res2); } } +static void tapenade_sincos(double inp, size_t len) { -static void tapenade_sincos(double inp, unsigned len) { + // { + // struct timeval start, end; + // gettimeofday(&start, NULL); - { - struct timeval start, end; - gettimeofday(&start, NULL); + // double *x = new double[2*len]; + // for(size_t i=0; i<2*len; i++) x[i] = 2.0; + // foobar(x, len); + // double res = x[0]; - double *x = new double[2*len]; - for(int i=0; i<2*len; i++) x[i] = 2.0; - foobar(x, len); - double res = x[0]; + // gettimeofday(&end, NULL); + // printf("Tapenade real %0.6f res=%f\n", tdiff(&start, &end), res); + // delete[] x; + // } - gettimeofday(&end, NULL); - printf("Tapenade real %0.6f res=%f\n", tdiff(&start, &end), res); - delete[] x; - } + // { + // struct timeval start, end; + // gettimeofday(&start, NULL); - { - struct timeval start, end; - gettimeofday(&start, NULL); + // double* x = new double[2*len]; + // for(size_t i=0; i<2*len; i++) x[i] = 2.0; + // foobar(x, len); + // double res = x[0]; + + // gettimeofday(&end, NULL); + // printf("Tapenade forward %0.6f res=%f\n", tdiff(&start, &end), res); + // delete[] x; + // } + + // { + // struct timeval start, end; + // gettimeofday(&start, NULL); + + // double res2 = tfoobar_and_gradient(len); - double* x = new double[2*len]; - for(int i=0; i<2*len; i++) x[i] = 2.0; - foobar(x, len); - double res = x[0]; + // gettimeofday(&end, NULL); + // printf("Tapenade combined %0.6f res'=%f\n", tdiff(&start, &end), res2); + // } +} + +static void enzyme_sincos(double inp, size_t len) { + + { + struct timeval start, end; + gettimeofday(&start, NULL); + + double *x = new double[2 * len]; + for (size_t i = 0; i < 2 * len; i++) + x[i] = 2.0; + foobar(x, len); + double res = x[0]; + + gettimeofday(&end, NULL); + printf("Enzyme real %0.6f res=%f\n", tdiff(&start, &end), res); + delete[] x; + } - gettimeofday(&end, NULL); - printf("Tapenade forward %0.6f res=%f\n", tdiff(&start, &end), res); - delete[] x; + { + struct timeval start, end; + gettimeofday(&start, NULL); + + double *x = new double[2 * len]; + for (size_t i = 0; i < 2 * len; i++) + x[i] = 2.0; + foobar(x, len); + double res = x[0]; + + gettimeofday(&end, NULL); + printf("Enzyme forward %0.6f res=%f\n", tdiff(&start, &end), res); + delete[] x; } { - struct timeval start, end; - gettimeofday(&start, NULL); + struct timeval start, end; + gettimeofday(&start, NULL); - double res2 = tfoobar_and_gradient(len); + double res2 = foobar_and_gradient(len); - gettimeofday(&end, NULL); - printf("Tapenade combined %0.6f res'=%f\n", tdiff(&start, &end), res2); + gettimeofday(&end, NULL); + printf("Enzyme combined %0.6f res'=%f\n", tdiff(&start, &end), res2); } } -static void enzyme_sincos(double inp, unsigned len) { +static void enzyme_unsafe_rust_sincos(double inp, size_t len) { { - struct timeval start, end; - gettimeofday(&start, NULL); - - double *x = new double[2*len]; - for(int i=0; i<2*len; i++) x[i] = 2.0; - foobar(x, len); - double res = x[0]; + struct timeval start, end; + gettimeofday(&start, NULL); + + double *x = new double[2 * len]; + for (size_t i = 0; i < 2 * len; i++) + x[i] = 2.0; + rust_unsafe_foobar(len, x); + double res = x[0]; + + gettimeofday(&end, NULL); + printf("Enzyme (unsafe Rust) real %0.6f res=%f\n", tdiff(&start, &end), + res); + delete[] x; + } - gettimeofday(&end, NULL); - printf("Enzyme real %0.6f res=%f\n", tdiff(&start, &end), res); - delete[] x; + { + struct timeval start, end; + gettimeofday(&start, NULL); + + double *x = new double[2 * len]; + for (size_t i = 0; i < 2 * len; i++) + x[i] = 2.0; + rust_unsafe_foobar(len, x); + double res = x[0]; + + gettimeofday(&end, NULL); + printf("Enzyme (unsafe Rust) forward %0.6f res=%f\n", tdiff(&start, &end), + res); + delete[] x; } { - struct timeval start, end; - gettimeofday(&start, NULL); + struct timeval start, end; + gettimeofday(&start, NULL); + + double res2 = rust_unsafe_foobar_and_gradient(len); + + gettimeofday(&end, NULL); + printf("Enzyme (unsafe Rust) combined %0.6f res'=%f\n", tdiff(&start, &end), + res2); + } +} + +static void enzyme_rust_sincos(double inp, size_t len) { - double *x = new double[2*len]; - for(int i=0; i<2*len; i++) x[i] = 2.0; - foobar(x, len); - double res = x[0]; + { + struct timeval start, end; + gettimeofday(&start, NULL); + + double *x = new double[2 * len]; + for (size_t i = 0; i < 2 * len; i++) + x[i] = 2.0; + rust_foobar(len, x); + double res = x[0]; + + gettimeofday(&end, NULL); + printf("Enzyme (Rust) real %0.6f res=%f\n", tdiff(&start, &end), res); + delete[] x; + } - gettimeofday(&end, NULL); - printf("Enzyme forward %0.6f res=%f\n", tdiff(&start, &end), res); - delete[] x; + { + struct timeval start, end; + gettimeofday(&start, NULL); + + double *x = new double[2 * len]; + for (size_t i = 0; i < 2 * len; i++) + x[i] = 2.0; + rust_foobar(len, x); + double res = x[0]; + + gettimeofday(&end, NULL); + printf("Enzyme (Rust) forward %0.6f res=%f\n", tdiff(&start, &end), res); + delete[] x; } { - struct timeval start, end; - gettimeofday(&start, NULL); + struct timeval start, end; + gettimeofday(&start, NULL); - double res2 = foobar_and_gradient(len); + double res2 = rust_foobar_and_gradient(len); - gettimeofday(&end, NULL); - printf("Enzyme combined %0.6f res'=%f\n", tdiff(&start, &end), res2); + gettimeofday(&end, NULL); + printf("Enzyme (Rust) combined %0.6f res'=%f\n", tdiff(&start, &end), res2); } } - /* Function to check if x is power of 2*/ -bool isPowerOfTwo (int x) -{ - /* First x in the below expression is for the case when x is 0 */ - return x && (!(x&(x-1))); +bool isPowerOfTwo(size_t x) { + /* First x in the below expression is for the case when x is 0 */ + return x && (!(x & (x - 1))); } -unsigned max(unsigned A, unsigned B){ - if (A>B) return A; +size_t max(size_t A, size_t B) { + if (A > B) + return A; return B; } -int main(int argc, char** argv) { +int main(int argc, char **argv) { if (argc < 2) { printf("usage %s n [must be power of 2]\n", argv[0]); return 1; } - unsigned N = atoi(argv[1]); + size_t N = atol(argv[1]); if (!isPowerOfTwo(N)) { printf("usage %s n [must be power of 2]\n", argv[0]); return 1; } double inp = -2.1; - for(unsigned iters=max(1, N>>5); iters <= N; iters*=2) { - printf("iters=%d\n", iters); + size_t iters = max(1, N >> 0); + for (size_t i = 0; i < 5; i++) { + printf("iters=%zu\n", iters); +#if CPP adept_sincos(inp, iters); tapenade_sincos(inp, iters); enzyme_sincos(inp, iters); +#else + enzyme_rust_sincos(inp, iters); + enzyme_unsafe_rust_sincos(inp, iters); +#endif } } diff --git a/enzyme/benchmarks/ReverseMode/fft/fft.h b/enzyme/benchmarks/ReverseMode/fft/fft.h index 809196b76cc3..fad3c7dad145 100644 --- a/enzyme/benchmarks/ReverseMode/fft/fft.h +++ b/enzyme/benchmarks/ReverseMode/fft/fft.h @@ -1,71 +1,75 @@ #ifndef _fft_h_ #define _fft_h_ -#include #include #include +#include using adept::adouble; using adept::aVector; - /* A classy FFT and Inverse FFT C++ class library Author: Tim Molteno, tim@physics.otago.ac.nz - Based on the article "A Simple and Efficient FFT Implementation in C++" by Volodymyr Myrnyy - with just a simple Inverse FFT modification. + Based on the article "A Simple and Efficient FFT Implementation in C++" by + Volodymyr Myrnyy with just a simple Inverse FFT modification. Licensed under the GPL v3. */ - #include -inline void swap(double* a, double* b) { - double temp=*a; +inline void swap(double *a, double *b) { + double temp = *a; *a = *b; *b = temp; } -static void recursiveApply(double* data, int iSign, unsigned N) { - if (N == 1) return; - recursiveApply(data, iSign, N/2); - recursiveApply(data+N, iSign, N/2); +static void recursiveApply(double *__restrict data, size_t N, int iSign) { + if (N == 1) + return; + recursiveApply(data, N / 2, iSign); + recursiveApply(data + N, N / 2, iSign); - double wtemp = iSign*sin(M_PI/N); - double wpi = -iSign*sin(2*M_PI/N); - double wpr = -2.0*wtemp*wtemp; + double wtemp = iSign * sin(M_PI / N); + double wpi = -iSign * sin(2 * (M_PI / N)); + double wpr = -2.0 * wtemp * wtemp; double wr = 1.0; double wi = 0.0; - for (unsigned i=0; ii) { - swap(&data[j-1], &data[i-1]); +static void scramble(double *data, size_t N) { + size_t j = 1; + for (size_t ii = 0; ii < N; ii++) { + size_t i = 2 * ii + 1; + if (j > i) { + swap(&data[j - 1], &data[i - 1]); swap(&data[j], &data[i]); } - int m = N; - while (m>=2 && j>m) { + size_t m = N; + while (m >= 2 && j > m) { j -= m; m >>= 1; } @@ -73,69 +77,71 @@ static void scramble(double* data, unsigned N) { } } -static void rescale(double* data, unsigned N) { - double scale = ((double)1)/N; - for (unsigned i=0; i<2*N; i++) { +static void rescale(double *data, size_t N) { + double scale = ((double)1) / N; + for (size_t i = 0; i < 2 * N; i++) { data[i] *= scale; } } -static void fft(double* data, unsigned N) { +static void fft(double *data, size_t N) { scramble(data, N); - recursiveApply(data,1, N); + recursiveApply(data, N, 1); } -static void ifft(double* data, unsigned N) { +static void ifft(double *data, size_t N) { scramble(data, N); - recursiveApply(data,-1, N); + recursiveApply(data, N, -1); rescale(data, N); } - - -inline void swapad(adept::ActiveReference a, adept::ActiveReference b) { - adouble temp=a; +inline void swapad(adept::ActiveReference a, + adept::ActiveReference b) { + adouble temp = a; a = b; b = temp; } -static void recursiveApply(aVector data, int iSign, unsigned N) { - if (N == 1) return; - recursiveApply(data, iSign, N/2); - recursiveApply(data(adept::range(N,adept::end)), iSign, N/2); +static void recursiveApply(aVector data, size_t N, int iSign) { + if (N == 1) + return; + recursiveApply(data, N / 2, iSign); + recursiveApply(data(adept::range(N, adept::end)), N / 2, iSign); - adouble wtemp = iSign*std::sin(M_PI/N); - adouble wpi = -iSign*std::sin(2*M_PI/N); - adouble wpr = -2.0*wtemp*wtemp; + adouble wtemp = iSign * std::sin(M_PI / N); + adouble wpi = -iSign * std::sin(2 * (M_PI / N)); + adouble wpr = -2.0 * wtemp * wtemp; adouble wr = 1.0; adouble wi = 0.0; - for (unsigned i=0; ii) { - swapad(data(j-1), data(i-1)); +static void scramble(aVector data, size_t N) { + size_t j = 1; + for (size_t ii = 0; ii < N; ii++) { + size_t i = 2 * ii + 1; + if (j > i) { + swapad(data(j - 1), data(i - 1)); swapad(data(j), data(i)); } - int m = N; - while (m>=2 && j>m) { + size_t m = N; + while (m >= 2 && j > m) { j -= m; m >>= 1; } @@ -143,21 +149,21 @@ static void scramble(aVector data, unsigned N) { } } -static void rescale(aVector data, unsigned N) { - adouble scale = ((double)1)/N; - for (unsigned i=0; i<2*N; i++) { +static void rescale(aVector data, size_t N) { + adouble scale = ((double)1) / N; + for (size_t i = 0; i < 2 * N; i++) { data[i] *= scale; } } -static void fft(aVector data, unsigned N) { +static void fft(aVector data, size_t N) { scramble(data, N); - recursiveApply(data,1, N); + recursiveApply(data, N, 1); } -static void ifft(aVector data, unsigned N) { +static void ifft(aVector data, size_t N) { scramble(data, N); - recursiveApply(data,-1, N); + recursiveApply(data, N, -1); rescale(data, N); } @@ -165,260 +171,308 @@ static void ifft(aVector data, unsigned N) { extern "C" { /* Generated by TAPENADE (INRIA, Ecuador team) - Tapenade 3.15 (master) - 15 Apr 2020 11:54 + Tapenade 3.16 (bugfix_servletAD) - 4 Jan 2024 17:44 */ #include +#include +#include /* - Differentiation of recursiveApply in reverse (adjoint) mode (with options context): - gradient of useful results: *data - with respect to varying inputs: *data - Plus diff mem management of: data:in + Differentiation of swap in reverse (adjoint) mode: + gradient of useful results: *a *b + with respect to varying inputs: *a *b + Plus diff mem management of: a:in b:in */ -static void recursiveApply_b(double *data, double *datab, int iSign, unsigned - int N) { - int arg1; - double *arg10; - double *arg10b; - int arg2; - if (N != 1) { - arg1 = N/2; - arg10b = datab + N; - arg10 = data + N; - arg2 = N/2; - double wtemp = iSign*sin(3.1415926536/N); - double wpi = -iSign*sin(2*3.1415926536/N); - double wpr = -2.0*wtemp*wtemp; - double wr = 1.0; - double wi = 0.0; - for (int i = 0; i <= N-1; i += 2) { - int iN = i + N; - double tempr = data[iN]*wr - data[iN+1]*wi; - double tempi = data[iN]*wi + data[iN+1]*wr; - double tmp; - double tmp0; - wtemp = wr; - pushReal8(wr); - wr = wr + (wr*wpr - wi*wpi); - pushReal8(wi); - wi = wi + (wi*wpr + wtemp*wpi); - pushInteger4(iN); - } - pushPointer8(arg10b); - pushInteger4(arg2); - pushInteger4(arg1); - popInteger4(&arg1); - popInteger4(&arg2); - popPointer8((void **)&arg10b); - for (int i = N-(N-1)%2-1; i >= 0; i -= 2) { - int iN; - double tempr; - double temprb = 0.0; - double tempi; - double tempib = 0.0; - double tmpb; - double tmpb0; - popInteger4(&iN); - tmpb0 = datab[iN + 1]; - popReal8(&wi); - popReal8(&wr); - tempib = datab[i + 1] - tmpb0; - temprb = datab[i]; - datab[iN + 1] = 0.0; - datab[i + 1] = datab[i + 1] + tmpb0; - tmpb = datab[iN]; - datab[iN] = 0.0; - datab[i] = datab[i] + tmpb; - temprb = temprb - tmpb; - datab[iN + 1] = datab[iN + 1] + wr*tempib - wi*temprb; - datab[iN] = datab[iN] + wi*tempib + wr*temprb; - } - recursiveApply_b(arg10, arg10b, iSign, arg2); - recursiveApply_b(data, datab, iSign, arg1); - } +inline void swap_b(double *a, double *ab, double *b, double *bb) { + double temp = *a; + double tempb = 0.0; + *a = *b; + *b = temp; + tempb = *bb; + *bb = *ab; + *ab = tempb; +} + +inline void swap_c(double *a, double *b) { + double temp = *a; + *a = *b; + *b = temp; } -static void recursiveApply_nodiff(double *data, int iSign, unsigned int N) { - int arg1; - double *arg10; - int arg2; - if (N == 1) - return; - else { - arg1 = N/2; - recursiveApply_nodiff(data, iSign, arg1); - arg10 = data + N; - arg2 = N/2; - recursiveApply_nodiff(arg10, iSign, arg2); - double wtemp = iSign*sin(3.1415926536/N); - double wpi = -iSign*sin(2*3.1415926536/N); - double wpr = -2.0*wtemp*wtemp; - double wr = 1.0; - double wi = 0.0; - for (int i = 0; i <= N-1; i += 2) { - int iN = i + N; - double tempr = data[iN]*wr - data[iN+1]*wi; - double tempi = data[iN]*wi + data[iN+1]*wr; - data[iN] = data[i] - tempr; - data[iN + 1] = data[i + 1] - tempi; - data[i] += tempr; - data[i + 1] += tempi; - wtemp = wr; - wr += wr*wpr - wi*wpi; - wi += wi*wpr + wtemp*wpi; - } +static void recursiveApply_c(double *data, int iSign, size_t N) { + size_t arg1; + double *arg10; + size_t arg2; + if (N == 1) + return; + else { + arg1 = N / 2; + recursiveApply_c(data, iSign, arg1); + arg10 = data + N; + arg2 = N / 2; + recursiveApply_c(arg10, iSign, arg2); + double wtemp = iSign * sin(3.14 / N); + double wpi = -iSign * sin(2 * 3.14 / N); + double wpr = -2.0 * wtemp * wtemp; + double wr = 1.0; + double wi = 0.0; + for (size_t ii = 0; ii < N / 2; ii++) { + size_t i = 2 * ii; + size_t iN = i + N; + double tempr = data[iN] * wr - data[iN + 1] * wi; + double tempi = data[iN] * wi + data[iN + 1] * wr; + data[iN] = data[i] - tempr; + data[iN + 1] = data[i + 1] - tempi; + data[i] += tempr; + data[i + 1] += tempi; + wtemp = wr; + wr += wr * wpr - wi * wpi; + wi += wi * wpr + wtemp * wpi; } + } } /* - Differentiation of swap in reverse (adjoint) mode (with options context): - gradient of useful results: *a *b - with respect to varying inputs: *a *b - Plus diff mem management of: a:in b:in + Differentiation of recursiveApply in reverse (adjoint) mode: + gradient of useful results: *data + with respect to varying inputs: *data + Plus diff mem management of: data:in */ -static void swap_b(double *a, double *ab, double *b, double *bb) { - double temp = *a; - double tempb = 0.0; - tempb = *bb; - *bb = *ab; - *ab = tempb; -} - -static void swap_nodiff(double *a, double *b) { - double temp = *a; - *a = *b; - *b = temp; +static void recursiveApply_b(double *data, double *datab, int iSign, size_t N) { + size_t arg1; + double *arg10; + double *arg10b; + size_t arg2; + int branch; + if (N != 1) { + arg1 = N / 2; + pushReal8(*data); + recursiveApply_c(data, iSign, arg1); + arg10b = datab + N; + arg10 = data + N; + arg2 = N / 2; + if (arg10) { + pushReal8(*arg10); + pushControl1b(1); + } else + pushControl1b(0); + recursiveApply_c(arg10, iSign, arg2); + double wtemp = iSign * sin(3.14 / N); + double wpi = -iSign * sin(2 * 3.14 / N); + double wpr = -2.0 * wtemp * wtemp; + double wr = 1.0; + double wi = 0.0; + for (size_t ii = 0; ii < N / 2; ii++) { + size_t i = 2 * ii; + int iN = i + N; + double tempr = data[iN] * wr - data[iN + 1] * wi; + double tempi = data[iN] * wi + data[iN + 1] * wr; + double temprb; + double tempib; + double tmp; + double tmp0; + tmp = data[i] - tempr; + data[iN] = tmp; + tmp0 = data[i + 1] - tempi; + data[iN + 1] = tmp0; + data[i] = data[i] + tempr; + data[i + 1] = data[i + 1] + tempi; + wtemp = wr; + pushReal8(wr); + wr = wr + (wr * wpr - wi * wpi); + pushReal8(wi); + wi = wi + (wi * wpr + wtemp * wpi); + pushInteger4(iN); + } + for (size_t i = N - (N - 1) % 2 - 1; i >= 0; i -= 2) { + int iN; + double tempr; + double temprb = 0.0; + double tempi; + double tempib = 0.0; + double tmpb; + double tmpb0; + popInteger4(&iN); + tmpb0 = datab[iN + 1]; + popReal8(&wi); + popReal8(&wr); + tempib = datab[i + 1] - tmpb0; + temprb = datab[i]; + datab[iN + 1] = 0.0; + datab[i + 1] = datab[i + 1] + tmpb0; + tmpb = datab[iN]; + datab[iN] = 0.0; + datab[i] = datab[i] + tmpb; + temprb = temprb - tmpb; + datab[iN + 1] = datab[iN + 1] + wr * tempib - wi * temprb; + datab[iN] = datab[iN] + wi * tempib + wr * temprb; + } + popControl1b(&branch); + if (branch == 1) + popReal8(arg10); + recursiveApply_b(arg10, arg10b, iSign, arg2); + popReal8(data); + recursiveApply_b(data, datab, iSign, arg1); + } } /* - Differentiation of scramble in reverse (adjoint) mode (with options context): + Differentiation of scramble in reverse (adjoint) mode: gradient of useful results: *data with respect to varying inputs: *data Plus diff mem management of: data:in */ -static void scramble_b(double *data, double *datab, unsigned int N) { - int j = 1; - int branch; - for (int i = 1; i <= 2*N-1; i += 2) { - int adCount; - if (j > i) - pushControl1b(0); - else - pushControl1b(1); - int m = N; - adCount = 0; - while(m >= 2 && j > m) { - pushInteger4(j); - j = j - m; - m = m >> 1; - adCount = adCount + 1; - } - pushInteger4(adCount); - pushInteger4(j); - j = j + m; +static void scramble_b(double *data, double *datab, size_t N) { + int j = 1; + int branch; + for (size_t ii = 0; ii < N; ii++) { + size_t i = 2 * ii + 1; + int adCount; + if (j > i) { + pushReal8(data[i - 1]); + pushReal8(data[j - 1]); + swap_c(&(data[j - 1]), &(data[i - 1])); + pushReal8(data[i]); + pushReal8(data[j]); + swap_c(&(data[j]), &(data[i])); + pushControl1b(0); + } else + pushControl1b(1); + size_t m = N; + adCount = 0; + while (m >= 2 && j > m) { + pushInteger4(j); + j = j - m; + m = m >> 1; + adCount = adCount + 1; } - for (int i = 2*N-(2*N-2)%2-1; i >= 1; i -= 2) { - int m; - int adCount; - int i0; - popInteger4(&j); - popInteger4(&adCount); - for (i0 = 1; i0 < adCount+1; ++i0) - popInteger4(&j); - popControl1b(&branch); - if (branch == 0) { - swap_b(&(data[j]), &(datab[j]), &(data[i]), &(datab[i])); - swap_b(&(data[j - 1]), &(datab[j - 1]), &(data[i - 1]), &(datab[i - - 1])); - } + pushInteger4(adCount); + pushInteger4(j); + j = j + m; + } + for (size_t i = 2 * N - (2 * N - 2) % 2 - 1; i >= 1; i -= 2) { + size_t m; + int adCount; + size_t i0; + popInteger4(&j); + popInteger4(&adCount); + for (i0 = 1; i0 < adCount + 1; ++i0) + popInteger4(&j); + popControl1b(&branch); + if (branch == 0) { + popReal8(&(data[j])); + popReal8(&(data[i])); + swap_b(&(data[j]), &(datab[j]), &(data[i]), &(datab[i])); + popReal8(&(data[j - 1])); + popReal8(&(data[i - 1])); + swap_b(&(data[j - 1]), &(datab[j - 1]), &(data[i - 1]), &(datab[i - 1])); } + } } -static void scramble_nodiff(double *data, unsigned int N) { - int j = 1; - for (int i = 1; i <= 2*N-1; i += 2) { - if (j > i) { - swap_nodiff(&(data[j - 1]), &(data[i - 1])); - swap_nodiff(&(data[j]), &(data[i])); - } - int m = N; - while(m >= 2 && j > m) { - j -= m; - m >>= 1; - } - j += m; +static void scramble_c(double *data, size_t N) { + size_t j = 1; + for (size_t ii = 0; ii < N; ii++) { + size_t i = 2 * ii + 1; + if (j > i) { + swap_c(&(data[j - 1]), &(data[i - 1])); + swap_c(&(data[j]), &(data[i])); + } + size_t m = N; + while (m >= 2 && j > m) { + j -= m; + m >>= 1; } + j += m; + } } /* - Differentiation of rescale in reverse (adjoint) mode (with options context): + Differentiation of rescale in reverse (adjoint) mode: gradient of useful results: *data with respect to varying inputs: *data Plus diff mem management of: data:in */ -static void rescale_b(double *data, double *datab, unsigned int N) { - double scale = (double)1/N; - pushReal8(scale); - popReal8(&scale); - for (int i = 2*N-1; i > -1; --i) - datab[i] = scale*datab[i]; +static void rescale_b(double *data, double *datab, size_t N) { + double scale = (double)1 / N; + for (size_t i = 0; i < 2 * N; ++i) + data[i] = data[i] * scale; + for (size_t i = 2 * N - 1; i > -1; --i) + datab[i] = scale * datab[i]; } -static void rescale_nodiff(double *data, unsigned int N) { - double scale = (double)1/N; - for (int i = 0; i < 2*N; ++i) - data[i] *= scale; +static void rescale_c(double *data, size_t N) { + double scale = (double)1 / N; + for (size_t i = 0; i < 2 * N; ++i) + data[i] *= scale; } /* - Differentiation of fft in reverse (adjoint) mode (with options context): + Differentiation of fiveft in reverse (adjoint) mode: gradient of useful results: *data with respect to varying inputs: *data Plus diff mem management of: data:in */ -static void fft_b(double *data, double *datab, unsigned int N) { - recursiveApply_b(data, datab, 1, N); - scramble_b(data, datab, N); +void fiveft_b(double *data, double *datab, size_t N) { + pushReal8(*data); + scramble_c(data, N); + pushReal8(*data); + recursiveApply_c(data, 1, N); + popReal8(data); + recursiveApply_b(data, datab, 1, N); + popReal8(data); + scramble_b(data, datab, N); } -static void fft_nodiff(double *data, unsigned int N) { - scramble_nodiff(data, N); - recursiveApply_nodiff(data, 1, N); +void fiveft_c(double *data, size_t N) { + scramble_c(data, N); + recursiveApply_c(data, 1, N); } /* - Differentiation of ifft in reverse (adjoint) mode (with options context): + Differentiation of ifiveft in reverse (adjoint) mode: gradient of useful results: *data with respect to varying inputs: *data Plus diff mem management of: data:in */ -static void ifft_b(double *data, double *datab, unsigned int N) { - rescale_b(data, datab, N); - recursiveApply_b(data, datab, -1, N); - scramble_b(data, datab, N); +void ifiveft_b(double *data, double *datab, size_t N) { + pushReal8(*data); + scramble_c(data, N); + pushReal8(*data); + recursiveApply_c(data, -1, N); + pushReal8(*data); + rescale_c(data, N); + popReal8(data); + rescale_b(data, datab, N); + popReal8(data); + recursiveApply_b(data, datab, -1, N); + popReal8(data); + scramble_b(data, datab, N); } -static void ifft_nodiff(double *data, unsigned int N) { - scramble_nodiff(data, N); - recursiveApply_nodiff(data, -1, N); - rescale_nodiff(data, N); +void ifiveft_c(double *data, size_t N) { + scramble_c(data, N); + recursiveApply_c(data, -1, N); + rescale_c(data, N); } /* - Differentiation of foobar in reverse (adjoint) mode (with options context): + Differentiation of foobar in reverse (adjoint) mode: gradient of useful results: *data with respect to varying inputs: *data - RW status of diff variables: *data:in-out + RW status of diff variables: data:(loc) *data:in-out Plus diff mem management of: data:in */ -void foobar_b(double *data, double *datab, unsigned int len) { - double chksum = 0.0; - int i; - ifft_b(data, datab, len); - fft_b(data, datab, len); +void foobar_b(double *data, double *datab, size_t len) { + pushReal8(*data); + fiveft_c(data, len); + pushReal8(*data); + ifiveft_c(data, len); + popReal8(data); + ifiveft_b(data, datab, len); + popReal8(data); + fiveft_b(data, datab, len); } - } - #endif /* _fft_h_ */ diff --git a/enzyme/benchmarks/ReverseMode/fft/src/lib.rs b/enzyme/benchmarks/ReverseMode/fft/src/lib.rs new file mode 100644 index 000000000000..3b49cb61fd47 --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/fft/src/lib.rs @@ -0,0 +1,6 @@ +#![feature(slice_swap_unchecked)] +#![feature(autodiff)] +#![feature(slice_as_chunks)] + +pub mod safe; +pub mod unsf; diff --git a/enzyme/benchmarks/ReverseMode/fft/src/main.rs b/enzyme/benchmarks/ReverseMode/fft/src/main.rs new file mode 100644 index 000000000000..5f76ad96243e --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/fft/src/main.rs @@ -0,0 +1,22 @@ +use core::mem; +use fft::safe;//::dfoobar; +use fft::unsf;//::dfoobar; + +fn main() { + let len = 16; + let mut data = vec![1.0; 2*len]; + for i in 0..len { + data[i] = 2.0; + } + let mut data_d = vec![1.0; 2*len]; + + //unsafe {safe::rust_dfoobar(len, data.as_mut_ptr(), data_d.as_mut_ptr());} + //unsafe {safe::rust_foobar(len, data.as_mut_ptr());} + unsafe {unsf::unsafe_dfoobar(len, data.as_mut_ptr(), data_d.as_mut_ptr());} + unsafe {unsf::unsafe_foobar(len, data.as_mut_ptr());} + + dbg!(&data_d); + dbg!(&data); + //mem::forget(data); + //mem::forget(data_d); +} diff --git a/enzyme/benchmarks/ReverseMode/fft/src/safe.rs b/enzyme/benchmarks/ReverseMode/fft/src/safe.rs new file mode 100644 index 000000000000..cbca5abb8484 --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/fft/src/safe.rs @@ -0,0 +1,104 @@ +use std::autodiff::autodiff; +use std::f64::consts::PI; +use std::slice; + +fn bitreversal_perm(data: &mut [T]) { + let len = data.len() / 2; + let mut j = 1; + + for i in (1..data.len()).step_by(2) { + if j > i { + //dbg!(&i, &j); + data.swap(j-1, i-1); + data.swap(j, i); + //unsafe { + // data.swap_unchecked(j - 1, i - 1); + // data.swap_unchecked(j, i); + //} + } + + let mut m = len; + while m >= 2 && j > m { + j -= m; + m >>= 1; + } + + j += m; + } +} + +fn radix2(data: &mut [f64], i_sign: i32) { + let n = data.len() / 2; + if n == 1 { + return; + } + + let (a, b) = data.split_at_mut(n); + // assert_eq!(a.len(), b.len()); + radix2(a, i_sign); + radix2(b, i_sign); + + let wtemp = i_sign as f64 * (PI / n as f64).sin(); + let wpi = -i_sign as f64 * (2.0 * (PI / n as f64)).sin(); + let wpr = -2.0 * wtemp * wtemp; + let mut wr = 1.0; + let mut wi = 0.0; + + let (achunks, _) = a.as_chunks_mut(); + let (bchunks, _) = b.as_chunks_mut(); + for ([ax, ay], [bx, by]) in achunks.iter_mut().zip(bchunks.iter_mut()) { + let tempr = *bx * wr - *by * wi; + let tempi = *bx * wi + *by * wr; + + *bx = *ax - tempr; + *by = *ay - tempi; + *ax += tempr; + *ay += tempi; + + let wtemp_new = wr; + wr = wr * (wpr + 1.0) - wi * wpi; + wi = wi * (wpr + 1.0) + wtemp_new * wpi; + } +} + +fn rescale(data: &mut [f64], scale: usize) { + let scale = 1. / scale as f64; + for elm in data { + *elm *= scale; + } +} + +fn fft(data: &mut [f64]) { + bitreversal_perm(data); + radix2(data, 1); +} + +fn ifft(data: &mut [f64]) { + bitreversal_perm(data); + radix2(data, -1); + rescale(data, data.len() / 2); +} + +#[autodiff(dfoobar, Reverse, DuplicatedOnly)] +pub fn foobar(data: &mut [f64]) { + fft(data); + ifft(data); +} + +#[no_mangle] +pub extern "C" fn rust_dfoobar(n: usize, data: *mut f64, ddata: *mut f64) { + let (data, ddata) = unsafe { + ( + slice::from_raw_parts_mut(data, n * 2), + slice::from_raw_parts_mut(ddata, n * 2), + ) + }; + + unsafe { dfoobar(data, ddata) }; +} + +#[no_mangle] +pub extern "C" fn rust_foobar(n: usize, data: *mut f64) { + let data = unsafe { slice::from_raw_parts_mut(data, n * 2) }; + foobar(data); +} diff --git a/enzyme/benchmarks/ReverseMode/fft/src/unsf.rs b/enzyme/benchmarks/ReverseMode/fft/src/unsf.rs new file mode 100644 index 000000000000..29c8ceb1187d --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/fft/src/unsf.rs @@ -0,0 +1,92 @@ +use std::autodiff::autodiff; +use std::f64::consts::PI; + +unsafe fn bitreversal_perm(data: *mut f64, len: usize) { + let mut j = 1; + + for i in (1..2 * len).step_by(2) { + if j > i { + std::ptr::swap(data.add(j - 1), data.add(i - 1)); + std::ptr::swap(data.add(j), data.add(i)); + } + + let mut m = len; + while m >= 2 && j > m { + j -= m; + m >>= 1; + } + + j += m; + } +} + +unsafe fn radix2(data: *mut f64, n: usize, i_sign: i32) { + if n == 1 { + return; + } + radix2(data, n / 2, i_sign); + radix2(data.add(n), n / 2, i_sign); + + let wtemp = i_sign as f64 * (PI / n as f64).sin(); + let wpi = -i_sign as f64 * (2.0 * (PI / n as f64)).sin(); + let wpr = -2.0 * wtemp * wtemp; + let mut wr = 1.0; + let mut wi = 0.0; + + for i in (0..n).step_by(2) { + let in_n = i + n; + let ax = &mut *data.add(i); + let ay = &mut *data.add(i + 1); + let bx = &mut *data.add(in_n); + let by = &mut *data.add(in_n + 1); + let tempr = *bx * wr - *by * wi; + let tempi = *bx * wi + *by * wr; + + *bx = *ax - tempr; + *by = *ay - tempi; + *ax += tempr; + *ay += tempi; + + let wtemp_new = wr; + wr = wr * (wpr + 1.0) - wi * wpi; + wi = wi * (wpr + 1.0) + wtemp_new * wpi; + } +} + +unsafe fn rescale(data: *mut f64, n: usize) { + let scale = 1. / n as f64; + for i in 0..2 * n { + *data.add(i) = *data.add(i) * scale; + } +} + +unsafe fn fft(data: *mut f64, n: usize) { + bitreversal_perm(data, n); + radix2(data, n, 1); +} + +unsafe fn ifft(data: *mut f64, n: usize) { + bitreversal_perm(data, n); + radix2(data, n, -1); + rescale(data, n); +} + +#[autodiff(unsafe_dfoobar, Reverse, Const, DuplicatedOnly)] +pub unsafe fn unsafe_foobar(n: usize, data: *mut f64) { + fft(data, n); + ifft(data, n); +} + +#[no_mangle] +pub extern "C" fn rust_unsafe_dfoobar(n: usize, data: *mut f64, ddata: *mut f64) { + unsafe { + unsafe_dfoobar(n, data, ddata); + } +} + +#[no_mangle] +pub extern "C" fn rust_unsafe_foobar(n: usize, data: *mut f64) { + unsafe { + unsafe_foobar(n, data); + } +} diff --git a/enzyme/benchmarks/ReverseMode/gmm/Cargo.lock b/enzyme/benchmarks/ReverseMode/gmm/Cargo.lock new file mode 100644 index 000000000000..cfdab95b3d9c --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/gmm/Cargo.lock @@ -0,0 +1,16 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "gmmrs" +version = "0.1.0" +dependencies = [ + "libm", +] + +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" diff --git a/enzyme/benchmarks/ReverseMode/gmm/Cargo.toml b/enzyme/benchmarks/ReverseMode/gmm/Cargo.toml new file mode 100644 index 000000000000..eaaf749b996f --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/gmm/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "gmmrs" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[lib] +crate-type = ["lib"] + +[features] +libm = ["dep:libm"] + +[profile.release] +lto = "fat" +opt-level = 3 +codegen-units = 1 +panic = "abort" +strip = true +#overflow-checks = false + +[profile.dev] +lto = "fat" + +[dependencies] +libm = { version = "0.2.8", optional = true } + +[workspace] diff --git a/enzyme/benchmarks/ReverseMode/gmm/Makefile.make b/enzyme/benchmarks/ReverseMode/gmm/Makefile.make index 1e8e711da1ba..f5f6de4fb06a 100644 --- a/enzyme/benchmarks/ReverseMode/gmm/Makefile.make +++ b/enzyme/benchmarks/ReverseMode/gmm/Makefile.make @@ -1,24 +1,46 @@ -# RUN: cd %S && LD_LIBRARY_PATH="%bldpath:$LD_LIBRARY_PATH" PTR="%ptr" BENCH="%bench" BENCHLINK="%blink" LOAD="%loadEnzyme" LOADCLANG="%loadClangEnzyme" ENZYME="%enzyme" make -B gmm-raw.ll results.json -f %s +# RUN: cd %S && LD_LIBRARY_PATH="%bldpath:$LD_LIBRARY_PATH" BENCH="%bench" BENCHLINK="%blink" LOAD="%loadEnzyme" LOADCLANG="%loadClangEnzyme" ENZYME="%enzyme" make -B gmm-raw.ll results.json -f %s .PHONY: clean dir := $(abspath $(lastword $(MAKEFILE_LIST))/../../../..) +include $(dir)/benchmarks/ReverseMode/adbench/Makefile.config + +ifeq ($(strip $(CLANG)),) +$(error PASSES1 is not set) +endif + +ifeq ($(strip $(PASSES1)),) +$(error PASSES1 is not set) +endif + +ifeq ($(strip $(PASSES2)),) +$(error PASSES2 is not set) +endif + +ifeq ($(strip $(PASSES3)),) +$(error PASSES3 is not set) +endif + +ifneq ($(strip $(PASSES4)),) +$(error PASSES4 is set) +endif + clean: rm -f *.ll *.o results.txt results.json + cargo +enzyme clean -%-unopt.ll: %.cpp - clang++ $(BENCH) $(PTR) $^ -pthread -O2 -fno-vectorize -fno-slp-vectorize -ffast-math -fno-unroll-loops -o $@ -S -emit-llvm +$(dir)/benchmarks/ReverseMode/gmm/target/release/libgmmrs.a: src/lib.rs Cargo.toml + RUSTFLAGS="-Z autodiff=Enable,PrintPasses,LooseTypes" cargo +enzyme rustc --release --lib --crate-type=staticlib -%-raw.ll: %-unopt.ll - opt $^ $(LOAD) $(ENZYME) -o $@ -S +%-unopt.ll: %.cpp + $(CLANG) $(BENCH) $^ -pthread -O3 -fno-vectorize -fno-slp-vectorize -fno-unroll-loops -o $@ -S -emit-llvm -%-opt.ll: %-raw.ll - opt $^ -o $@ -S +%-opt.ll: %-unopt.ll + $(OPT) $^ $(LOAD) -passes="$(PASSES2),enzyme" -o $@ -S -gmm.o: gmm-opt.ll - clang++ -pthread -O2 $^ -o $@ $(BENCHLINK) -lm - #clang++ $(LOADCLANG) $(BENCH) gmm.cpp -I /usr/include/c++/11 -I/usr/include/x86_64-linux-gnu/c++/11 -O2 -o gmm.o -lpthread $(BENCHLINK) -lm -L /usr/lib/gcc/x86_64-linux-gnu/11 +gmm.o: gmm-opt.ll $(dir)/benchmarks/ReverseMode/gmm/target/release/libgmmrs.a + $(CLANG) -pthread -O3 -fno-math-errno $^ -o $@ $(BENCHLINK) -lm results.json: gmm.o - ./$^ + numactl -C 1 ./$^ diff --git a/enzyme/benchmarks/ReverseMode/gmm/gmm.cpp b/enzyme/benchmarks/ReverseMode/gmm/gmm.cpp index 866059217b96..cb7e864eca48 100644 --- a/enzyme/benchmarks/ReverseMode/gmm/gmm.cpp +++ b/enzyme/benchmarks/ReverseMode/gmm/gmm.cpp @@ -13,7 +13,7 @@ * typedef struct * { * double gamma; - * int m; + * size_t m; * } Wishart; * * After Tapenade CLI installing use the next command to generate a file: @@ -39,9 +39,9 @@ extern "C" { /* ==================================================================== */ // This throws error on n<1 -double arr_max(int n, double const* x) +double arr_max(size_t n, double const* x) { - int i; + size_t i; double m = x[0]; for (i = 1; i < n; i++) { @@ -57,9 +57,9 @@ double arr_max(int n, double const* x) // sum of component squares -double sqnorm(int n, double const* x) +double sqnorm(size_t n, double const* x) { - int i; + size_t i; double res = x[0] * x[0]; for (i = 1; i < n; i++) { @@ -73,13 +73,13 @@ double sqnorm(int n, double const* x) // out = a - b void subtract( - int d, + size_t d, double const* x, double const* y, double* out ) { - int id; + size_t id; for (id = 0; id < d; id++) { out[id] = x[id] - y[id]; @@ -87,9 +87,9 @@ void subtract( } -double log_sum_exp(int n, double const* x) +double log_sum_exp(size_t n, double const* x) { - int i; + size_t i; double mx = arr_max(n, x); double semx = 0.0; @@ -105,7 +105,7 @@ double log_sum_exp(int n, double const* x) __attribute__((const)) double log_gamma_distrib(double a, double p) { - int j; + int64_t j; double out = 0.25 * p * (p - 1) * log(PI); for (j = 1; j <= p; j++) @@ -123,17 +123,17 @@ double log_gamma_distrib(double a, double p) /* ======================================================================== */ double log_wishart_prior( - int p, - int k, + size_t p, + size_t k, Wishart wishart, double const* sum_qs, double const* Qdiags, double const* icf ) { - int ik; - int n = p + wishart.m + 1; - int icf_sz = p * (p + 1) / 2; + size_t ik; + size_t n = p + wishart.m + 1; + size_t icf_sz = p * (p + 1) / 2; double C = n * p * (log(wishart.gamma) - 0.5 * log(2)) - log_gamma_distrib(0.5 * n, p); @@ -150,15 +150,15 @@ double log_wishart_prior( void preprocess_qs( - int d, - int k, + size_t d, + size_t k, double const* icf, double* sum_qs, double* Qdiags ) { - int ik, id; - int icf_sz = d * (d + 1) / 2; + size_t ik, id; + size_t icf_sz = d * (d + 1) / 2; for (ik = 0; ik < k; ik++) { sum_qs[ik] = 0.; @@ -174,14 +174,14 @@ void preprocess_qs( void Qtimesx( - int d, + size_t d, double const* Qdiag, double const* ltri, // strictly lower triangular part double const* x, double* out ) { - int i, j; + size_t i, j; for (i = 0; i < d; i++) { out[i] = Qdiag[i] * x[i]; @@ -189,10 +189,10 @@ void Qtimesx( //caching lparams as scev doesn't replicate index calculation // todo note changing to strengthened form - //int Lparamsidx = 0; + //size_t Lparamsidx = 0; for (i = 0; i < d; i++) { - int Lparamsidx = i*(2*d-i-1)/2; + size_t Lparamsidx = i*(2*d-i-1)/2; for (j = i + 1; j < d; j++) { // and this x @@ -202,24 +202,15 @@ void Qtimesx( } } - - -void gmm_objective( - int d, - int k, - int n, - double const* __restrict alphas, - double const* __restrict means, - double const* __restrict icf, - double const* __restrict x, - Wishart wishart, - double* __restrict err -) -{ - #define int int64_t - int ix, ik; - const double CONSTANT = -n * d * 0.5 * log(2 * PI); - int icf_sz = d * (d + 1) / 2; +void gmm_objective_restrict(size_t d, size_t k, size_t n, + double const *__restrict alphas, + double const *__restrict means, + double const *__restrict icf, + double const *__restrict x, Wishart wishart, + double *__restrict err) { + int64_t ix, ik; + const double CONSTANT = -(double)n * d * 0.5 * log(2 * PI); + int64_t icf_sz = d * (d + 1) / 2; double* Qdiags = (double*)malloc(d * k * sizeof(double)); double* sum_qs = (double*)malloc(k * sizeof(double)); @@ -256,7 +247,6 @@ void gmm_objective( free(xcentered); free(Qxcentered); free(main_term); - #undef int } extern int enzyme_const; @@ -265,23 +255,16 @@ extern int enzyme_dupnoneed; void __enzyme_autodiff(...) noexcept; // * tapenade -b -o gmm_tapenade -head "gmm_objective(err)/(alphas means icf)" gmm.c -void dgmm_objective(int d, int k, int n, const double *alphas, double * - alphasb, const double *means, double *meansb, const double *icf, - double *icfb, const double *x, Wishart wishart, double *err, double * - errb) { - __enzyme_autodiff( - gmm_objective, - enzyme_const, d, - enzyme_const, k, - enzyme_const, n, - enzyme_dup, alphas, alphasb, - enzyme_dup, means, meansb, - enzyme_dup, icf, icfb, - enzyme_const, x, - enzyme_const, wishart, - enzyme_dupnoneed, err, errb); +void dgmm_objective_restrict(size_t d, size_t k, size_t n, const double *alphas, + double *alphasb, const double *means, + double *meansb, const double *icf, double *icfb, + const double *x, Wishart wishart, double *err, + double *errb) { + __enzyme_autodiff(gmm_objective_restrict, enzyme_const, d, enzyme_const, k, + enzyme_const, n, enzyme_dup, alphas, alphasb, enzyme_dup, + means, meansb, enzyme_dup, icf, icfb, enzyme_const, x, + enzyme_const, wishart, enzyme_dupnoneed, err, errb); } - } @@ -300,20 +283,19 @@ extern "C" { UTILS ==================================================================== */ // This throws error on n<1 -void arr_max_b(int n, const double *x, double *xb, double arr_maxb) { - int i; +void arr_max_b(size_t n, const double *x, double *xb, double arr_maxb) { double m = x[0]; double mb = 0.0; int branch; double arr_max; - for (i = 1; i < n; ++i) + for (int64_t i = 1; i < n; ++i) if (m < x[i]) { m = x[i]; pushControl1b(1); } else pushControl1b(0); mb = arr_maxb; - for (i = n-1; i > 0; --i) { + for (int64_t i = (int64_t)n-1; i > 0; --i) { popControl1b(&branch); if (branch != 0) { xb[i] = xb[i] + mb; @@ -327,8 +309,8 @@ void arr_max_b(int n, const double *x, double *xb, double arr_maxb) { UTILS ==================================================================== */ // This throws error on n<1 -double arr_max_nodiff(int n, const double *x) { - int i; +double arr_max_nodiff(size_t n, const double *x) { + size_t i; double m = x[0]; for (i = 1; i < n; ++i) if (m < x[i]) @@ -343,20 +325,19 @@ double arr_max_nodiff(int n, const double *x) { Plus diff mem management of: x:in */ // sum of component squares -void sqnorm_b(int n, const double *x, double *xb, double sqnormb) { - int i; +void sqnorm_b(size_t n, const double *x, double *xb, double sqnormb) { double res = x[0]*x[0]; double resb = 0.0; double sqnorm; resb = sqnormb; - for (i = n-1; i > 0; --i) + for (int64_t i = (int64_t)n-1; i > 0; --i) xb[i] = xb[i] + 2*x[i]*resb; xb[0] = xb[0] + 2*x[0]*resb; } // sum of component squares -double sqnorm_nodiff(int n, const double *x) { - int i; +double sqnorm_nodiff(size_t n, const double *x) { + size_t i; double res = x[0]*x[0]; for (i = 1; i < n; ++i) res = res + x[i]*x[i]; @@ -370,18 +351,17 @@ double sqnorm_nodiff(int n, const double *x) { Plus diff mem management of: out:in y:in */ // out = a - b -void subtract_b(int d, const double *x, const double *y, double *yb, double * +void subtract_b(size_t d, const double *x, const double *y, double *yb, double * out, double *outb) { - int id; - for (id = d-1; id > -1; --id) { + for (int64_t id = (int64_t)d-1; id > -1; --id) { yb[id] = yb[id] - outb[id]; outb[id] = 0.0; } } // out = a - b -void subtract_nodiff(int d, const double *x, const double *y, double *out) { - int id; +void subtract_nodiff(size_t d, const double *x, const double *y, double *out) { + size_t id; for (id = 0; id < d; ++id) out[id] = x[id] - y[id]; } @@ -392,8 +372,7 @@ void subtract_nodiff(int d, const double *x, const double *y, double *out) { with respect to varying inputs: *x Plus diff mem management of: x:in */ -void log_sum_exp_b(int n, const double *x, double *xb, double log_sum_expb) { - int i; +void log_sum_exp_b(size_t n, const double *x, double *xb, double log_sum_expb) { double mx; double mxb; double tempb; @@ -401,11 +380,11 @@ void log_sum_exp_b(int n, const double *x, double *xb, double log_sum_expb) { mx = arr_max_nodiff(n, x); double semx = 0.0; double semxb = 0.0; - for (i = 0; i < n; ++i) + for (int64_t i = 0; i < n; ++i) semx = semx + exp(x[i] - mx); semxb = log_sum_expb/semx; mxb = log_sum_expb; - for (i = n-1; i > -1; --i) { + for (int64_t i = (int64_t)n-1; i > -1; --i) { tempb = exp(x[i]-mx)*semxb; xb[i] = xb[i] + tempb; mxb = mxb - tempb; @@ -413,8 +392,8 @@ void log_sum_exp_b(int n, const double *x, double *xb, double log_sum_expb) { arr_max_b(n, x, xb, mxb); } -double log_sum_exp_nodiff(int n, const double *x) { - int i; +double log_sum_exp_nodiff(size_t n, const double *x) { + size_t i; double mx; mx = arr_max_nodiff(n, x); double semx = 0.0; @@ -424,7 +403,7 @@ double log_sum_exp_nodiff(int n, const double *x) { } double log_gamma_distrib_nodiff(double a, double p) { - int j; + size_t j; /* TFIX */ double out = 0.25*p*(p-1)*log(PI); double arg1; @@ -446,12 +425,12 @@ double log_gamma_distrib_nodiff(double a, double p) { ======================================================================== MAIN LOGIC ======================================================================== */ -void log_wishart_prior_b(int p, int k, Wishart wishart, const double *sum_qs, +void log_wishart_prior_b(size_t p, size_t k, Wishart wishart, const double *sum_qs, double *sum_qsb, const double *Qdiags, double *Qdiagsb, const double * icf, double *icfb, double log_wishart_priorb) { - int ik; - int n = p + wishart.m + 1; - int icf_sz = p*(p+1)/2; + int64_t ik; + size_t n = p + wishart.m + 1; + size_t icf_sz = p*(p+1)/2; double C; float arg1; double result1; @@ -461,7 +440,7 @@ void log_wishart_prior_b(int p, int k, Wishart wishart, const double *sum_qs, for (ik = 0; ik < k; ++ik) { double frobenius; double result1; - int arg1; + size_t arg1; double result2; } outb = log_wishart_priorb; @@ -471,12 +450,12 @@ void log_wishart_prior_b(int p, int k, Wishart wishart, const double *sum_qs, sum_qsb[ik] = 0.0; for (ik = 0; ik < k * icf_sz; ik++) /* TFIX */ icfb[ik] = 0.0; - for (ik = k-1; ik > -1; --ik) { + for (ik = (int64_t)k-1; ik > -1; --ik) { double frobenius; double frobeniusb; double result1; double result1b; - int arg1; + size_t arg1; double result2; double result2b; frobeniusb = wishart.gamma*wishart.gamma*0.5*outb; @@ -493,11 +472,11 @@ void log_wishart_prior_b(int p, int k, Wishart wishart, const double *sum_qs, /* ======================================================================== MAIN LOGIC ======================================================================== */ -double log_wishart_prior_nodiff(int p, int k, Wishart wishart, const double * +double log_wishart_prior_nodiff(size_t p, size_t k, Wishart wishart, const double * sum_qs, const double *Qdiags, const double *icf) { - int ik; - int n = p + wishart.m + 1; - int icf_sz = p*(p+1)/2; + size_t ik; + size_t n = p + wishart.m + 1; + size_t icf_sz = p*(p+1)/2; double C; float arg1; double result1; @@ -508,7 +487,7 @@ double log_wishart_prior_nodiff(int p, int k, Wishart wishart, const double * for (ik = 0; ik < k; ++ik) { double frobenius; double result1; - int arg1; + size_t arg1; double result2; result1 = sqnorm_nodiff(p, &(Qdiags[ik*p])); arg1 = icf_sz - p; @@ -526,17 +505,17 @@ double log_wishart_prior_nodiff(int p, int k, Wishart wishart, const double * with respect to varying inputs: *icf Plus diff mem management of: Qdiags:in sum_qs:in icf:in */ -void preprocess_qs_b(int d, int k, const double *icf, double *icfb, double * +void preprocess_qs_b(size_t d, size_t k, const double *icf, double *icfb, double * sum_qs, double *sum_qsb, double *Qdiags, double *Qdiagsb) { - int ik, id; - int icf_sz = d*(d+1)/2; + int64_t ik, id; + size_t icf_sz = d*(d+1)/2; for (ik = 0; ik < k; ++ik) for (id = 0; id < d; ++id) { double q = icf[ik*icf_sz + id]; pushReal8(q); } - for (ik = k-1; ik > -1; --ik) { - for (id = d-1; id > -1; --id) { + for (ik = (int64_t)k-1; ik > -1; --ik) { + for (id = (int64_t)d-1; id > -1; --id) { double q; double qb = 0.0; popReal8(&q); @@ -549,13 +528,12 @@ void preprocess_qs_b(int d, int k, const double *icf, double *icfb, double * } } -void preprocess_qs_nodiff(int d, int k, const double *icf, double *sum_qs, +void preprocess_qs_nodiff(size_t d, size_t k, const double *icf, double *sum_qs, double *Qdiags) { - int ik, id; - int icf_sz = d*(d+1)/2; - for (ik = 0; ik < k; ++ik) { + size_t icf_sz = d*(d+1)/2; + for (size_t ik = 0; ik < k; ++ik) { sum_qs[ik] = 0.; - for (id = 0; id < d; ++id) { + for (size_t id = 0; id < d; ++id) { double q = icf[ik*icf_sz + id]; sum_qs[ik] = sum_qs[ik] + q; Qdiags[ik*d + id] = exp(q); @@ -569,41 +547,41 @@ void preprocess_qs_nodiff(int d, int k, const double *icf, double *sum_qs, with respect to varying inputs: *out *Qdiag *x *ltri Plus diff mem management of: out:in Qdiag:in x:in ltri:in */ -void Qtimesx_b(int d, const double *Qdiag, double *Qdiagb, const double *ltri, +void Qtimesx_b(size_t d, const double *Qdiag, double *Qdiagb, const double *ltri, double *ltrib, const double *x, double *xb, double *out, double *outb) { // strictly lower triangular part - int i, j; + int64_t i, j; int adFrom; - int Lparamsidx = 0; + size_t Lparamsidx = 0; for (i = 0; i < d; ++i) { adFrom = i + 1; for (j = adFrom; j < d; ++j) Lparamsidx++; pushInteger4(adFrom); } - for (i = d-1; i > -1; --i) { + for (i = (int64_t)d-1; i > -1; --i) { popInteger4(&adFrom); - for (j = d-1; j > adFrom-1; --j) { + for (j = (int64_t)d-1; j > adFrom-1; --j) { --Lparamsidx; ltrib[Lparamsidx] = ltrib[Lparamsidx] + x[i]*outb[j]; xb[i] = xb[i] + ltri[Lparamsidx]*outb[j]; } } - for (i = d-1; i > -1; --i) { + for (i = (int64_t)d-1; i > -1; --i) { Qdiagb[i] = Qdiagb[i] + x[i]*outb[i]; xb[i] = xb[i] + Qdiag[i]*outb[i]; outb[i] = 0.0; } } -void Qtimesx_nodiff(int d, const double *Qdiag, const double *ltri, const +void Qtimesx_nodiff(size_t d, const double *Qdiag, const double *ltri, const double *x, double *out) { // strictly lower triangular part - int i, j; + size_t i, j; for (i = 0; i < d; ++i) out[i] = Qdiag[i]*x[i]; - int Lparamsidx = 0; + size_t Lparamsidx = 0; for (i = 0; i < d; ++i) for (j = i+1; j < d; ++j) { out[j] = out[j] + ltri[Lparamsidx]*x[i]; @@ -619,19 +597,19 @@ void Qtimesx_nodiff(int d, const double *Qdiag, const double *ltri, const *alphas:out Plus diff mem management of: err:in means:in icf:in alphas:in */ -void gmm_objective_b(int d, int k, int n, const double *alphas, double * +void gmm_objective_b(size_t d, size_t k, size_t n, const double *alphas, double * alphasb, const double *means, double *meansb, const double *icf, double *icfb, const double *x, Wishart wishart, double *err, double * errb) { - int ix, ik; + int64_t ix, ik; /* TFIX */ - const double CONSTANT = -n*d*0.5*log(2*PI); - int icf_sz = d*(d+1)/2; + const double CONSTANT = -(double)n*d*0.5*log(2*PI); + size_t icf_sz = d*(d+1)/2; double *Qdiags; double *Qdiagsb; double result1; double result1b; - int ii1; + size_t ii1; Qdiagsb = (double *)malloc(d*k*sizeof(double)); for (ii1 = 0; ii1 < d*k; ++ii1) Qdiagsb[ii1] = 0.0; @@ -687,10 +665,10 @@ void gmm_objective_b(int d, int k, int n, const double *alphas, double * log_sum_exp_b(k, alphas, alphasb, lse_alphasb); for (ii1 = 0; ii1 < d * k; ii1++) /* TFIX */ meansb[ii1] = 0.0; - for (ix = n-1; ix > -1; --ix) { + for (ix = (int64_t)n-1; ix > -1; --ix) { result1b = slseb; log_sum_exp_b(k, &(main_term[0]), &(main_termb[0]), result1b); - for (ik = k-1; ik > -1; --ik) { + for (ik = (int64_t)k-1; ik > -1; --ik) { popReal8(&(main_term[ik])); alphasb[ik] = alphasb[ik] + main_termb[ik]; sum_qsb[ik] = sum_qsb[ik] + main_termb[ik]; @@ -733,32 +711,32 @@ namespace adeptTest { // out = a - b template -void subtract(int d, +void subtract(size_t d, const T1* const x, const T2* const y, T3* out) { - for (int id = 0; id < d; id++) + for (size_t id = 0; id < d; id++) { out[id] = x[id] - y[id]; } } template -T sqnorm(int n, const T* const x) +T sqnorm(size_t n, const T* const x) { T res = x[0] * x[0]; - for (int i = 1; i < n; i++) + for (size_t i = 1; i < n; i++) res = res + x[i] * x[i]; return res; } // This throws error on n<1 template -T arr_max(int n, const T* const x) +T arr_max(size_t n, const T* const x) { T m = x[0]; - for (int i = 1; i < n; i++) + for (size_t i = 1; i < n; i++) { if (m < x[i]) m = x[i]; @@ -767,12 +745,12 @@ T arr_max(int n, const T* const x) } template -void gmm_objective(int d, int k, int n, const T* const alphas, const T* const means, +void gmm_objective(size_t d, size_t k, size_t n, const T* const alphas, const T* const means, const T* const icf, const double* const x, Wishart wishart, T* err); // split of the outer loop over points template -void gmm_objective_split_inner(int d, int k, +void gmm_objective_split_inner(size_t d, size_t k, const T* const alphas, const T* const means, const T* const icf, @@ -781,7 +759,7 @@ void gmm_objective_split_inner(int d, int k, T* err); // other terms which are outside the loop template -void gmm_objective_split_other(int d, int k, int n, +void gmm_objective_split_other(size_t d, size_t k, size_t n, const T* const alphas, const T* const means, const T* const icf, @@ -789,7 +767,7 @@ void gmm_objective_split_other(int d, int k, int n, T* err); template -T logsumexp(int n, const T* const x); +T logsumexp(size_t n, const T* const x); // p: dim // k: number of components @@ -798,20 +776,20 @@ T logsumexp(int n, const T* const x); // Qdiags: d*k // icf: (p*(p+1)/2)*k inverse covariance factors template -T log_wishart_prior(int p, int k, +T log_wishart_prior(size_t p, size_t k, Wishart wishart, const T* const sum_qs, const T* const Qdiags, const T* const icf); template -void preprocess_qs(int d, int k, +void preprocess_qs(size_t d, size_t k, const T* const icf, T* sum_qs, T* Qdiags); template -void Qtimesx(int d, +void Qtimesx(size_t d, const T* const Qdiag, const T* const ltri, // strictly lower triangular part const T* const x, @@ -822,11 +800,11 @@ void Qtimesx(int d, //////////////////////////////////////////////////////////// template -T logsumexp(int n, const T* const x) +T logsumexp(size_t n, const T* const x) { T mx = arr_max(n, x); T semx = 0.; - for (int i = 0; i < n; i++) + for (size_t i = 0; i < n; i++) { semx = semx + exp(x[i] - mx); } @@ -834,19 +812,19 @@ T logsumexp(int n, const T* const x) } template -T log_wishart_prior(int p, int k, +T log_wishart_prior(size_t p, size_t k, Wishart wishart, const T* const sum_qs, const T* const Qdiags, const T* const icf) { - int n = p + wishart.m + 1; - int icf_sz = p * (p + 1) / 2; + size_t n = p + wishart.m + 1; + size_t icf_sz = p * (p + 1) / 2; double C = n * p * (log(wishart.gamma) - 0.5 * log(2)) - log_gamma_distrib(0.5 * n, p); T out = 0; - for (int ik = 0; ik < k; ik++) + for (size_t ik = 0; ik < k; ik++) { T frobenius = sqnorm(p, &Qdiags[ik * p]) + sqnorm(icf_sz - p, &icf[ik * icf_sz + p]); out = out + 0.5 * wishart.gamma * wishart.gamma * (frobenius) @@ -857,16 +835,16 @@ T log_wishart_prior(int p, int k, } template -void preprocess_qs(int d, int k, +void preprocess_qs(size_t d, size_t k, const T* const icf, T* sum_qs, T* Qdiags) { - int icf_sz = d * (d + 1) / 2; - for (int ik = 0; ik < k; ik++) + size_t icf_sz = d * (d + 1) / 2; + for (size_t ik = 0; ik < k; ik++) { sum_qs[ik] = 0.; - for (int id = 0; id < d; id++) + for (size_t id = 0; id < d; id++) { T q = icf[ik * icf_sz + id]; sum_qs[ik] = sum_qs[ik] + q; @@ -876,19 +854,19 @@ void preprocess_qs(int d, int k, } template -void Qtimesx(int d, +void Qtimesx(size_t d, const T* const Qdiag, const T* const ltri, // strictly lower triangular part const T* const x, T* out) { - for (int id = 0; id < d; id++) + for (size_t id = 0; id < d; id++) out[id] = Qdiag[id] * x[id]; - int Lparamsidx = 0; - for (int i = 0; i < d; i++) + size_t Lparamsidx = 0; + for (size_t i = 0; i < d; i++) { - for (int j = i + 1; j < d; j++) + for (size_t j = i + 1; j < d; j++) { out[j] = out[j] + ltri[Lparamsidx] * x[i]; Lparamsidx++; @@ -897,7 +875,7 @@ void Qtimesx(int d, } template -void gmm_objective(int d, int k, int n, +void gmm_objective(size_t d, size_t k, size_t n, const T* const alphas, const T* const means, const T* const icf, @@ -905,8 +883,8 @@ void gmm_objective(int d, int k, int n, Wishart wishart, T* err) { - const double CONSTANT = -n * d * 0.5 * log(2 * PI); - int icf_sz = d * (d + 1) / 2; + const double CONSTANT = -(double)n * d * 0.5 * log(2 * PI); + size_t icf_sz = d * (d + 1) / 2; vector Qdiags(d * k); vector sum_qs(k); @@ -917,9 +895,9 @@ void gmm_objective(int d, int k, int n, preprocess_qs(d, k, icf, &sum_qs[0], &Qdiags[0]); T slse = 0.; - for (int ix = 0; ix < n; ix++) + for (size_t ix = 0; ix < n; ix++) { - for (int ik = 0; ik < k; ik++) + for (size_t ik = 0; ik < k; ik++) { subtract(d, &x[ix * d], &means[ik * d], &xcentered[0]); Qtimesx(d, &Qdiags[ik * d], &icf[ik * icf_sz + d], &xcentered[0], &Qxcentered[0]); @@ -937,7 +915,7 @@ void gmm_objective(int d, int k, int n, } template -void gmm_objective_split_inner(int d, int k, +void gmm_objective_split_inner(size_t d, size_t k, const T* const alphas, const T* const means, const T* const icf, @@ -945,39 +923,39 @@ void gmm_objective_split_inner(int d, int k, Wishart wishart, T* err) { - int icf_sz = d * (d + 1) / 2; + size_t icf_sz = d * (d + 1) / 2; T* Ldiag = new T[d]; T* xcentered = new T[d]; T* mahal = new T[d]; T* lse = new T[k]; - for (int ik = 0; ik < k; ik++) + for (size_t ik = 0; ik < k; ik++) { - int icf_off = ik * icf_sz; + size_t icf_off = ik * icf_sz; T sumlog_Ldiag(0.); - for (int id = 0; id < d; id++) + for (size_t id = 0; id < d; id++) { sumlog_Ldiag = sumlog_Ldiag + icf[icf_off + id]; Ldiag[id] = exp(icf[icf_off + id]); } - for (int id = 0; id < d; id++) + for (size_t id = 0; id < d; id++) { xcentered[id] = x[id] - means[ik * d + id]; mahal[id] = Ldiag[id] * xcentered[id]; } - int Lparamsidx = d; - for (int i = 0; i < d; i++) + size_t Lparamsidx = d; + for (size_t i = 0; i < d; i++) { - for (int j = i + 1; j < d; j++) + for (size_t j = i + 1; j < d; j++) { mahal[j] = mahal[j] + icf[icf_off + Lparamsidx] * xcentered[i]; Lparamsidx++; } } T sqsum_mahal(0.); - for (int id = 0; id < d; id++) + for (size_t id = 0; id < d; id++) { sqsum_mahal = sqsum_mahal + mahal[id] * mahal[id]; } @@ -994,14 +972,14 @@ void gmm_objective_split_inner(int d, int k, } template -void gmm_objective_split_other(int d, int k, int n, +void gmm_objective_split_other(size_t d, size_t k, size_t n, const T* const alphas, const T* const means, const T* const icf, Wishart wishart, T* err) { - const double CONSTANT = -n * d * 0.5 * log(2 * PI); + const double CONSTANT = -(double)n * d * 0.5 * log(2 * PI); T lse_alphas = logsumexp(k, alphas); @@ -1015,14 +993,14 @@ void gmm_objective_split_other(int d, int k, int n, }; -void adept_dgmm_objective(int d, int k, int n, const double *alphas, double * +void adept_dgmm_objective(size_t d, size_t k, size_t n, const double *alphas, double * alphasb, const double *means, double *meansb, const double *icf, double *icfb, const double *x, Wishart wishart, double *err, double * errb) { - int icf_sz = d*(d + 1) / 2; - int Jrows = 1; - int Jcols = (k*(d + 1)*(d + 2)) / 2; + size_t icf_sz = d*(d + 1) / 2; + size_t Jrows = 1; + size_t Jcols = (k*(d + 1)*(d + 2)) / 2; adept::Stack stack; adouble *aalphas = new adouble[k]; @@ -1050,3 +1028,5 @@ void adept_dgmm_objective(int d, int k, int n, const double *alphas, double * delete[] ameans; delete[] aicf; } + +#include "gmm_mayalias.h" diff --git a/enzyme/benchmarks/ReverseMode/gmm/gmm.h b/enzyme/benchmarks/ReverseMode/gmm/gmm.h index eb189afed44b..5dc5bc0b1edb 100644 --- a/enzyme/benchmarks/ReverseMode/gmm/gmm.h +++ b/enzyme/benchmarks/ReverseMode/gmm/gmm.h @@ -28,9 +28,9 @@ extern "C" { // wishart: wishart distribution parameters // err: 1 output void gmm_objective( - int d, - int k, - int n, + size_t d, + size_t k, + size_t n, double const* alphas, double const* means, double const* icf, diff --git a/enzyme/benchmarks/ReverseMode/gmm/gmm_mayalias.h b/enzyme/benchmarks/ReverseMode/gmm/gmm_mayalias.h new file mode 100644 index 000000000000..4bcba4fb0900 --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/gmm/gmm_mayalias.h @@ -0,0 +1,62 @@ +void gmm_objective(size_t d, size_t k, size_t n, double const *alphas, + double const *means, double const *icf, double const *x, + Wishart wishart, double *err) { + size_t ix, ik; + const double CONSTANT = -(double)n * d * 0.5 * log(2 * PI); + size_t icf_sz = d * (d + 1) / 2; + + double *Qdiags = (double *)malloc(d * k * sizeof(double)); + double *sum_qs = (double *)malloc(k * sizeof(double)); + double *xcentered = (double *)malloc(d * sizeof(double)); + double *Qxcentered = (double *)malloc(d * sizeof(double)); + double *main_term = (double *)malloc(k * sizeof(double)); + + preprocess_qs(d, k, icf, &sum_qs[0], &Qdiags[0]); + + double slse = 0.; + for (ix = 0; ix < n; ix++) { + for (ik = 0; ik < k; ik++) { + subtract(d, &x[ix * d], &means[ik * d], &xcentered[0]); + Qtimesx(d, &Qdiags[ik * d], &icf[ik * icf_sz + d], &xcentered[0], + &Qxcentered[0]); + // two caches for qxcentered at idx 0 and at arbitrary index + main_term[ik] = alphas[ik] + sum_qs[ik] - 0.5 * sqnorm(d, &Qxcentered[0]); + } + + // storing cmp for max of main_term + // 2 x (0 and arbitrary) storing sub to exp + // storing sum for use in log + slse = slse + log_sum_exp(k, &main_term[0]); + } + + // storing cmp of alphas + double lse_alphas = log_sum_exp(k, alphas); + + *err = CONSTANT + slse - n * lse_alphas + + log_wishart_prior(d, k, wishart, &sum_qs[0], &Qdiags[0], icf); + + free(Qdiags); + free(sum_qs); + free(xcentered); + free(Qxcentered); + free(main_term); +} + +// * tapenade -b -o gmm_tapenade -head "gmm_objective(err)/(alphas means icf)" gmm.c +void dgmm_objective(size_t d, size_t k, size_t n, const double *alphas, double * + alphasb, const double *means, double *meansb, const double *icf, + double *icfb, const double *x, Wishart wishart, double *err, double * + errb) { + __enzyme_autodiff( + gmm_objective, + enzyme_const, d, + enzyme_const, k, + enzyme_const, n, + enzyme_dup, alphas, alphasb, + enzyme_dup, means, meansb, + enzyme_dup, icf, icfb, + enzyme_const, x, + enzyme_const, wishart, + enzyme_dupnoneed, err, errb); +} + diff --git a/enzyme/benchmarks/ReverseMode/gmm/src/lib.rs b/enzyme/benchmarks/ReverseMode/gmm/src/lib.rs new file mode 100644 index 000000000000..4f9fc5336e8e --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/gmm/src/lib.rs @@ -0,0 +1,10 @@ +#![feature(autodiff)] +pub mod safe; +pub mod r#unsafe; + +#[derive(Clone, Copy)] +#[repr(C)] +pub struct Wishart { + pub gamma: f64, + pub m: i32, +} diff --git a/enzyme/benchmarks/ReverseMode/gmm/src/main.rs b/enzyme/benchmarks/ReverseMode/gmm/src/main.rs new file mode 100644 index 000000000000..e7ebf74d0aa2 --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/gmm/src/main.rs @@ -0,0 +1,24 @@ +#![feature(autodiff)] +use gmmrs::{Wishart, r#unsafe::dgmm_objective}; + +fn main() { + let d = 2; + let k = 2; + let n = 2; + let alphas = vec![0.5, 0.5]; + let means = vec![0., 0., 1., 1.]; + let icf = vec![1., 0., 1.]; + let x = vec![0., 0., 1., 1.]; + let wishart = Wishart { gamma: 1., m: 1 }; + let mut err = 0.; + let mut d_alphas = vec![0.; alphas.len()]; + let mut d_means = vec![0.; means.len()]; + let mut d_icf = vec![0.; icf.len()]; + let mut d_x = vec![0.; x.len()]; + let mut d_err = 0.; + let mut err2 = &mut err; + let mut d_err2 = &mut d_err; + let wishart2 = &wishart; + // pass as raw ptr: + unsafe {dgmm_objective(d, k, n, alphas.as_ptr(), d_alphas.as_mut_ptr(), means.as_ptr(), d_means.as_mut_ptr(), icf.as_ptr(), d_icf.as_mut_ptr(), x.as_ptr(), wishart2 as *const Wishart, err2 as *mut f64, d_err2 as *mut f64);} +} diff --git a/enzyme/benchmarks/ReverseMode/gmm/src/safe.rs b/enzyme/benchmarks/ReverseMode/gmm/src/safe.rs new file mode 100644 index 000000000000..e365246908e9 --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/gmm/src/safe.rs @@ -0,0 +1,297 @@ +use crate::Wishart; +use std::f64::consts::PI; +use std::autodiff::autodiff; + +#[cfg(feature = "libm")] +use libm::lgamma; + +#[cfg(not(feature = "libm"))] +mod cmath { + extern "C" { + pub fn lgamma(x: f64) -> f64; + } +} +#[cfg(not(feature = "libm"))] +#[inline] +fn lgamma(x: f64) -> f64 { + unsafe { cmath::lgamma(x) } +} + +#[no_mangle] +pub extern "C" fn rust_dgmm_objective( + d: usize, + k: usize, + n: usize, + alphas: *const f64, + dalphas: *mut f64, + means: *const f64, + dmeans: *mut f64, + icf: *const f64, + dicf: *mut f64, + x: *const f64, + wishart: *const Wishart, + err: *mut f64, + derr: *mut f64, +) { + let alphas = unsafe { std::slice::from_raw_parts(alphas, k) }; + let means = unsafe { std::slice::from_raw_parts(means, k * d) }; + let icf = unsafe { std::slice::from_raw_parts(icf, k * d * (d + 1) / 2) }; + let x = unsafe { std::slice::from_raw_parts(x, n * d) }; + let wishart: Wishart = unsafe { *wishart }; + let mut my_err = unsafe { *err }; + + let d_alphas = unsafe { std::slice::from_raw_parts_mut(dalphas, k) }; + let d_means = unsafe { std::slice::from_raw_parts_mut(dmeans, k * d) }; + let d_icf = unsafe { std::slice::from_raw_parts_mut(dicf, k * d * (d + 1) / 2) }; + let mut my_derr = unsafe { *derr }; + let (mut qdiags, mut sum_qs, mut xcentered, mut qxcentered, mut main_term) = + get_workspace(d, k); + let (mut bqdiags, mut bsum_qs, mut bxcentered, mut bqxcentered, mut bmain_term) = + get_workspace(d, k); + + unsafe { dgmm_objective( + d, + k, + n, + alphas, + d_alphas, + means, + d_means, + icf, + d_icf, + x, + wishart.gamma, + wishart.m, + &mut my_err, + &mut my_derr, + &mut qdiags, + &mut bqdiags, + &mut sum_qs, + &mut bsum_qs, + &mut xcentered, + &mut bxcentered, + &mut qxcentered, + &mut bqxcentered, + &mut main_term, + &mut bmain_term, + )}; + + unsafe { *err = my_err }; + unsafe { *derr = my_derr }; +} + +#[no_mangle] +pub extern "C" fn rust_gmm_objective( + d: usize, + k: usize, + n: usize, + alphas: *const f64, + means: *const f64, + icf: *const f64, + x: *const f64, + wishart: *const Wishart, + err: *mut f64, +) { + let alphas = unsafe { std::slice::from_raw_parts(alphas, k) }; + let means = unsafe { std::slice::from_raw_parts(means, k * d) }; + let icf = unsafe { std::slice::from_raw_parts(icf, k * d * (d + 1) / 2) }; + let x = unsafe { std::slice::from_raw_parts(x, n * d) }; + let wishart: Wishart = unsafe { *wishart }; + let mut my_err = unsafe { *err }; + let (mut qdiags, mut sum_qs, mut xcentered, mut qxcentered, mut main_term) = + get_workspace(d, k); + gmm_objective( + d, + k, + n, + alphas, + means, + icf, + x, + wishart.gamma, + wishart.m, + &mut my_err, + &mut qdiags, + &mut sum_qs, + &mut xcentered, + &mut qxcentered, + &mut main_term, + ); + unsafe { *err = my_err }; +} + +fn get_workspace(d: usize, k: usize) -> (Vec, Vec, Vec, Vec, Vec) { + let qdiags = vec![0.; d * k]; + let sum_qs = vec![0.; k]; + let xcentered = vec![0.; d]; + let qxcentered = vec![0.; d]; + let main_term = vec![0.; k]; + (qdiags, sum_qs, xcentered, qxcentered, main_term) +} + +#[autodiff( + dgmm_objective, + Reverse, + Const, + Const, + Const, + Duplicated, + Duplicated, + Duplicated, + Const, + Const, + Const, + DuplicatedOnly, + Duplicated, + Duplicated, + Duplicated, + Duplicated, + Duplicated +)] +pub fn gmm_objective( + d: usize, + k: usize, + n: usize, + alphas: &[f64], + means: &[f64], + icf: &[f64], + x: &[f64], + gamma: f64, + m: i32, + err: &mut f64, + qdiags: &mut [f64], + sum_qs: &mut [f64], + xcentered: &mut [f64], + qxcentered: &mut [f64], + main_term: &mut [f64], +) { + let wishart: Wishart = Wishart { gamma, m }; + let constant = -(n as f64) * d as f64 * 0.5 * (2.0 * PI).ln(); + let icf_sz = d * (d + 1) / 2; + + // Let the compiler know sizes so it can eliminate bounds checks + assert_eq!(qdiags.len(), d * k); + assert_eq!(sum_qs.len(), k); + assert_eq!(xcentered.len(), d); + assert_eq!(qxcentered.len(), d); + assert_eq!(main_term.len(), k); + + preprocess_qs(d, k, icf, sum_qs, qdiags); + + let mut slse = 0.; + for ix in 0..n { + for ik in 0..k { + subtract( + d, + &x[ix as usize * d as usize..], + &means[ik as usize * d as usize..], + xcentered, + ); + qtimesx( + d, + &qdiags[ik as usize * d as usize..], + &icf[ik as usize * icf_sz as usize + d as usize..], + &*xcentered, + qxcentered, + ); + main_term[ik as usize] = + alphas[ik as usize] + sum_qs[ik as usize] - 0.5 * sqnorm(&*qxcentered); + } + + slse = slse + log_sum_exp(k, &main_term); + } + + let lse_alphas = log_sum_exp(k, alphas); + + *err = constant + slse - n as f64 * lse_alphas + + log_wishart_prior(d, k, wishart, &sum_qs, &*qdiags, icf); +} + +fn arr_max(n: usize, x: &[f64]) -> f64 { + let mut max = f64::NEG_INFINITY; + for i in 0..n { + if max < x[i] { + max = x[i]; + } + } + max +} + +fn preprocess_qs(d: usize, k: usize, icf: &[f64], sum_qs: &mut [f64], qdiags: &mut [f64]) { + let icf_sz = d * (d + 1) / 2; + for ik in 0..k { + sum_qs[ik as usize] = 0.; + for id in 0..d { + let q = icf[ik as usize * icf_sz as usize + id as usize]; + sum_qs[ik as usize] = sum_qs[ik as usize] + q; + qdiags[ik as usize * d as usize + id as usize] = q.exp(); + } + } +} +fn subtract(d: usize, x: &[f64], y: &[f64], out: &mut [f64]) { + assert!(x.len() >= d); + assert!(y.len() >= d); + assert!(out.len() >= d); + for i in 0..d { + out[i] = x[i] - y[i]; + } +} + +fn qtimesx(d: usize, q_diag: &[f64], ltri: &[f64], x: &[f64], out: &mut [f64]) { + assert!(out.len() >= d); + assert!(q_diag.len() >= d); + assert!(x.len() >= d); + for i in 0..d { + out[i] = q_diag[i] * x[i]; + } + + for i in 0..d { + let mut lparamsidx = i * (2 * d - i - 1) / 2; + for j in i + 1..d { + out[j] = out[j] + ltri[lparamsidx] * x[i]; + lparamsidx += 1; + } + } +} + +fn log_sum_exp(n: usize, x: &[f64]) -> f64 { + let mx = arr_max(n, x); + let semx: f64 = x.iter().map(|x| (x - mx).exp()).sum(); + semx.ln() + mx +} +fn log_gamma_distrib(a: f64, p: f64) -> f64 { + 0.25 * p * (p - 1.) * PI.ln() + + (1..=p as usize) + .map(|j| lgamma(a + 0.5 * (1. - j as f64))) + .sum::() +} + +fn log_wishart_prior( + p: usize, + k: usize, + wishart: Wishart, + sum_qs: &[f64], + qdiags: &[f64], + icf: &[f64], +) -> f64 { + let n = p + wishart.m as usize + 1; + let icf_sz = p * (p + 1) / 2; + + let c = n as f64 * p as f64 * (wishart.gamma.ln() - 0.5 * 2f64.ln()) + - log_gamma_distrib(0.5 * n as f64, p as f64); + + let out = (0..k) + .map(|ik| { + let frobenius = sqnorm(&qdiags[ik * p as usize..][..p]) + + sqnorm(&icf[ik * icf_sz as usize + p as usize..][..icf_sz - p]); + 0.5 * wishart.gamma * wishart.gamma * (frobenius) + - (wishart.m as f64) * sum_qs[ik as usize] + }) + .sum::(); + + out - k as f64 * c +} + +fn sqnorm(x: &[f64]) -> f64 { + x.iter().map(|x| x * x).sum() +} diff --git a/enzyme/benchmarks/ReverseMode/gmm/src/unsafe.rs b/enzyme/benchmarks/ReverseMode/gmm/src/unsafe.rs new file mode 100644 index 000000000000..aa91938565ab --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/gmm/src/unsafe.rs @@ -0,0 +1,148 @@ +use std::f64::consts::PI; +use crate::Wishart; +use std::autodiff::autodiff; + +#[cfg(feature = "libm")] +use libm::lgamma; + +#[cfg(not(feature = "libm"))] +mod cmath { + extern "C" { + pub fn lgamma(x: f64) -> f64; + } +} +#[cfg(not(feature = "libm"))] +#[inline] +fn lgamma(x: f64) -> f64 { + unsafe { cmath::lgamma(x) } +} + +#[no_mangle] +pub extern "C" fn rust_unsafe_dgmm_objective(d: i32, k: i32, n: i32, alphas: *const f64, dalphas: *mut f64, means: *const f64, dmeans: *mut f64, icf: *const f64, dicf: *mut f64, x: *const f64, wishart: *const Wishart, err: *mut f64, derr: *mut f64) { + let k = k as usize; + let n = n as usize; + let d = d as usize; + unsafe { dgmm_objective(d, k, n, alphas, dalphas, means, dmeans, icf, dicf, x, wishart, err, derr); } +} + +#[no_mangle] +pub extern "C" fn rust_unsafe_gmm_objective(d: i32, k: i32, n: i32, alphas: *const f64, means: *const f64, icf: *const f64, x: *const f64, wishart: *const Wishart, err: *mut f64) { + let k = k as usize; + let n = n as usize; + let d = d as usize; + unsafe {gmm_objective(d, k, n, alphas, means, icf, x, wishart, err); } +} + +//#[autodiff(dgmm_objective, Reverse, Const, Const, Const, Duplicated, Duplicated, Duplicated, Const, Const, Duplicated)] +//pub unsafe fn gmm_objective(d: usize, k: usize, n: usize, alphas: &[f64], means: &[f64], icf: &[f64], x: &[f64], gamma: f64, m: i32, err: &mut f64) { +// gmm_objective(d, k, n, alphas, means, icf, x, wishart, &mut my_err); +//} + +#[autodiff(dgmm_objective, Reverse, Const, Const, Const, Duplicated, Duplicated, Duplicated, Const, Const, DuplicatedOnly)] +pub unsafe fn gmm_objective(d: usize, k: usize, n: usize, alphas: *const f64, means: *const f64, icf: *const f64, x: *const f64, wishart: *const Wishart, err: *mut f64) { + let constant = -(n as f64) * d as f64 * 0.5 * (2.0 * PI).ln(); + let icf_sz = d * (d + 1) / 2; + let mut qdiags = vec![0.; d * k]; + let mut sum_qs = vec![0.; k]; + let mut xcentered = vec![0.; d]; + let mut qxcentered = vec![0.; d]; + let mut main_term = vec![0.; k]; + + preprocess_qs(d, k, icf, sum_qs.as_mut_ptr(), qdiags.as_mut_ptr()); + + let mut slse = 0.; + for ix in 0..n { + for ik in 0..k { + subtract(d, x.add(ix * d), means.add(ik * d), xcentered.as_mut_ptr()); + qtimesx(d, qdiags.as_mut_ptr().add(ik * d), icf.add(ik * icf_sz + d), xcentered.as_ptr(), qxcentered.as_mut_ptr()); + main_term[ik] = *alphas.add(ik) + sum_qs[ik] - 0.5 * sqnorm(d, qxcentered.as_ptr()); + //main_term[ik] = alphas[ik] + sum_qs[ik] - 0.5 * sqnorm(d, &Qxcentered[0]); + } + + slse = slse + log_sum_exp(k, main_term.as_ptr()); + } + + let lse_alphas = log_sum_exp(k, alphas); + + *err = constant + slse - n as f64 * lse_alphas + log_wishart_prior(d, k, *wishart, sum_qs.as_ptr(), qdiags.as_ptr(), icf); +} + +unsafe fn arr_max(n: usize, x: *const f64) -> f64 { + let mut max = f64::NEG_INFINITY; + for i in 0..n { + if max < *x.add(i) { + max = *x.add(i); + } + } + max +} + +unsafe fn preprocess_qs(d: usize, k: usize, icf: *const f64, sum_qs: *mut f64, qdiags: *mut f64) { + let icf_sz = d * (d + 1) / 2; + for ik in 0..k { + *sum_qs.add(ik) = 0.; + for id in 0..d { + let q = *icf.add(ik * icf_sz + id); + *sum_qs.add(ik) = *sum_qs.add(ik) + q; + *qdiags.add(ik * d + id) = q.exp(); + } + } +} + +unsafe fn subtract(d: usize, x: *const f64, y: *const f64, out: *mut f64) { + for i in 0..d { + *out.add(i) = *x.add(i) - *y.add(i); + } +} + +unsafe fn qtimesx(d: usize, q_diag: *const f64, ltri: *const f64, x: *const f64, out: *mut f64) { + for i in 0..d { + *out.add(i) = *q_diag.add(i) * *x.add(i); + } + + for i in 0..d { + let mut lparamsidx = i*(2*d-i-1)/2; + for j in i + 1..d { + *out.add(j) = *out.add(j) + *ltri.add(lparamsidx) * *x.add(i); + lparamsidx += 1; + } + } +} + +unsafe fn log_sum_exp(n: usize, x: *const f64) -> f64 { + let mx = arr_max(n, x); + let mut semx: f64 = 0.0; + + for i in 0..n { + semx = semx + (*x.add(i) - mx).exp(); + } + semx.ln() + mx +} + +fn log_gamma_distrib(a: f64, p: f64) -> f64 { + 0.25 * p * (p - 1.) * PI.ln() + (1..=p as usize).map(|j| lgamma(a + 0.5 * (1. - j as f64))).sum::() +} + +unsafe fn log_wishart_prior(p: usize, k: usize, wishart: Wishart, sum_qs: *const f64, qdiags: *const f64, icf: *const f64) -> f64 { + let n = p + wishart.m as usize + 1; + let icf_sz = p * (p + 1) / 2; + + let c = n as f64 * p as f64 * (wishart.gamma.ln() - 0.5 * 2f64.ln()) - log_gamma_distrib(0.5 * n as f64, p as f64); + + let mut out = 0.; + + for ik in 0..k { + let frobenius = sqnorm(p, qdiags.add(ik * p)) + sqnorm(icf_sz - p, icf.add(ik * icf_sz + p)); + out = out + 0.5 * wishart.gamma * wishart.gamma * (frobenius) - wishart.m as f64 * *sum_qs.add(ik); + } + + out - k as f64 * c +} + +unsafe fn sqnorm(n: usize, x: *const f64) -> f64 { + let mut sum = 0.; + for i in 0..n { + sum += *x.add(i) * *x.add(i); + } + sum +} diff --git a/enzyme/benchmarks/ReverseMode/lstm/Cargo.lock b/enzyme/benchmarks/ReverseMode/lstm/Cargo.lock new file mode 100644 index 000000000000..270bf4367433 --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/lstm/Cargo.lock @@ -0,0 +1,7 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "lstm" +version = "0.1.0" diff --git a/enzyme/benchmarks/ReverseMode/lstm/Cargo.toml b/enzyme/benchmarks/ReverseMode/lstm/Cargo.toml new file mode 100644 index 000000000000..0b5fd981b7cf --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/lstm/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "lstm" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] + +[lib] +crate-type = ["lib"] + +[profile.release] +lto = "fat" +opt-level = 3 +codegen-units = 1 +unwind = "abort" +strip = true +#overflow-checks = false + +[profile.dev] +lto = "fat" + +[workspace] diff --git a/enzyme/benchmarks/ReverseMode/lstm/Makefile.make b/enzyme/benchmarks/ReverseMode/lstm/Makefile.make index 276c5df7b450..65a1e930eeed 100644 --- a/enzyme/benchmarks/ReverseMode/lstm/Makefile.make +++ b/enzyme/benchmarks/ReverseMode/lstm/Makefile.make @@ -1,23 +1,46 @@ -# RUN: cd %S && LD_LIBRARY_PATH="%bldpath:$LD_LIBRARY_PATH" PTR="%ptr" BENCH="%bench" BENCHLINK="%blink" LOAD="%loadEnzyme" LOADCLANG="%loadClangEnzyme" ENZYME="%enzyme" make -B lstm-raw.ll results.json -f %s +# RUN: cd %S && LD_LIBRARY_PATH="%bldpath:$LD_LIBRARY_PATH" PTR="%ptr" BENCH="%bench" BENCHLINK="%blink" LOAD="%loadEnzyme" LOADCLANG="%loadClangEnzyme" ENZYME="%enzyme" make -B results.json -f %s .PHONY: clean dir := $(abspath $(lastword $(MAKEFILE_LIST))/../../../..) +include $(dir)/benchmarks/ReverseMode/adbench/Makefile.config + +ifeq ($(strip $(CLANG)),) +$(error PASSES1 is not set) +endif + +ifeq ($(strip $(PASSES1)),) +$(error PASSES1 is not set) +endif + +ifeq ($(strip $(PASSES2)),) +$(error PASSES2 is not set) +endif + +ifeq ($(strip $(PASSES3)),) +$(error PASSES3 is not set) +endif + +ifneq ($(strip $(PASSES4)),) +$(error PASSES4 is set) +endif + clean: rm -f *.ll *.o results.txt results.json + cargo +enzyme clean -%-unopt.ll: %.cpp - clang++ $(BENCH) $(PTR) $^ -pthread -O2 -fno-vectorize -fno-slp-vectorize -ffast-math -fno-unroll-loops -o $@ -S -emit-llvm +$(dir)/benchmarks/ReverseMode/lstm/target/release/liblstm.a: src/lib.rs Cargo.toml + RUSTFLAGS="-Z autodiff=Enable,PrintPasses" cargo +enzyme rustc --release --lib --crate-type=staticlib -%-raw.ll: %-unopt.ll - opt $^ $(LOAD) $(ENZYME) -o $@ -S +%-unopt.ll: %.cpp + $(CLANG) $(BENCH) $^ -pthread -O3 -fno-vectorize -fno-slp-vectorize -fno-unroll-loops -o $@ -S -emit-llvm -%-opt.ll: %-raw.ll - opt $^ -o $@ -S +%-opt.ll: %-unopt.ll + $(OPT) $^ $(LOAD) -passes="$(PASSES2),enzyme" -o $@ -S -lstm.o: lstm-opt.ll - clang++ -pthread -O2 $^ -o $@ $(BENCHLINK) -lm +lstm.o: lstm-opt.ll $(dir)/benchmarks/ReverseMode/lstm/target/release/liblstm.a + $(CLANG) -pthread -O3 $^ -o $@ $(BENCHLINK) -lm results.json: lstm.o - ./$^ + numactl -C 1 ./$^ diff --git a/enzyme/benchmarks/ReverseMode/lstm/lstm.cpp b/enzyme/benchmarks/ReverseMode/lstm/lstm.cpp index dbbc9929a7cc..ade0b2237510 100644 --- a/enzyme/benchmarks/ReverseMode/lstm/lstm.cpp +++ b/enzyme/benchmarks/ReverseMode/lstm/lstm.cpp @@ -50,15 +50,10 @@ double logsumexp(double const* vect, int sz) // LSTM OBJECTIVE // The LSTM model -void lstm_model( - int hsize, - double const* __restrict weight, - double const* __restrict bias, - double* __restrict hidden, - double* __restrict cell, - double const* __restrict input -) -{ +void lstm_model_restrict(int hsize, double const *__restrict weight, + double const *__restrict bias, + double *__restrict hidden, double *__restrict cell, + double const *__restrict input) { // TODO NOTE THIS //__builtin_assume(hsize > 0); @@ -94,16 +89,9 @@ void lstm_model( } // Predict LSTM output given an input -void lstm_predict( - int l, - int b, - double const* __restrict w, - double const* __restrict w2, - double* __restrict s, - double const* __restrict x, - double* __restrict x2 -) -{ +void lstm_predict_restrict(int l, int b, double const *__restrict w, + double const *__restrict w2, double *__restrict s, + double const *__restrict x, double *__restrict x2) { int i; for (i = 0; i < b; i++) { @@ -113,7 +101,8 @@ void lstm_predict( double* xp = x2; for (i = 0; i <= 2 * l * b - 1; i += 2 * b) { - lstm_model(b, &(w[i * 4]), &(w[(i + b) * 4]), &(s[i]), &(s[i + b]), xp); + lstm_model_restrict(b, &(w[i * 4]), &(w[(i + b) * 4]), &(s[i]), + &(s[i + b]), xp); xp = &(s[i]); } @@ -124,17 +113,12 @@ void lstm_predict( } // LSTM objective (loss function) -void lstm_objective( - int l, - int c, - int b, - double const* __restrict main_params, - double const* __restrict extra_params, - double* __restrict state, - double const* __restrict sequence, - double* __restrict loss -) -{ +void cxx_restrict_lstm_objective(int l, int c, int b, + double const *__restrict main_params, + double const *__restrict extra_params, + double *__restrict state, + double const *__restrict sequence, + double *__restrict loss) { int i, t; double total = 0.0; int count = 0; @@ -147,7 +131,8 @@ void lstm_objective( __builtin_assume(b>0); for (t = 0; t <= (c - 1) * b - 1; t += b) { - lstm_predict(l, b, main_params, extra_params, state, input, ypred); + lstm_predict_restrict(l, b, main_params, extra_params, state, input, + ypred); lse = logsumexp(ypred, b); for (i = 0; i < b; i++) { @@ -177,32 +162,17 @@ void __enzyme_autodiff(...) noexcept; // * tapenade -b -o lstm_tapenade -head "lstm_objective(loss)/(main_params extra_params)" lstm.c -void dlstm_objective( - int l, - int c, - int b, - double const* main_params, - double* dmain_params, - double const* extra_params, - double* dextra_params, - double* state, - double const* sequence, - double* loss, - double* dloss -) -{ - __enzyme_autodiff(lstm_objective, - enzyme_const, l, - enzyme_const, c, - enzyme_const, b, - enzyme_dup, main_params, dmain_params, - enzyme_dup, extra_params, dextra_params, - enzyme_const, state, - enzyme_const, sequence, - enzyme_dupnoneed, loss, dloss - ); +void dlstm_objective_restrict(int l, int c, int b, double const *main_params, + double *dmain_params, double const *extra_params, + double *dextra_params, double *state, + double const *sequence, double *loss, + double *dloss) { + __enzyme_autodiff(cxx_restrict_lstm_objective, enzyme_const, l, + enzyme_const, c, enzyme_const, b, enzyme_dup, main_params, + dmain_params, enzyme_dup, extra_params, dextra_params, + enzyme_const, state, enzyme_const, sequence, + enzyme_dupnoneed, loss, dloss); } - } @@ -728,3 +698,5 @@ void adept_dlstm_objective(int l, int c, int b, const double *main_params, doubl } #endif + +#include "lstm_mayalias.h" diff --git a/enzyme/benchmarks/ReverseMode/lstm/lstm_mayalias.h b/enzyme/benchmarks/ReverseMode/lstm/lstm_mayalias.h new file mode 100644 index 000000000000..06401ff35a66 --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/lstm/lstm_mayalias.h @@ -0,0 +1,160 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +/* + * File "lstm_b_tapenade_generated.c" is generated by Tapenade 3.14 (r7259) from this file. + * To reproduce such a generation you can use Tapenade CLI + * (can be downloaded from http://www-sop.inria.fr/tropics/tapenade/downloading.html) + * + * After installing use the next command to generate a file: + * + * tapenade -b -o lstm_tapenade -head "lstm_objective(loss)/(main_params extra_params)" lstm.c + * + * This will produce a file "lstm_tapenade_b.c" which content will be the same as the content of the file "lstm_b_tapenade_generated.c", + * except one-line header. Moreover a log-file "lstm_tapenade_b.msg" will be produced. + * + * NOTE: the code in "lstm_b_tapenade_generated.c" is wrong and won't work. + * REPAIRED SOURCE IS STORED IN THE FILE "lstm_b.c". + * You can either use diff tool or read "lstm_b.c" header to figure out what changes was performed to fix the code. + * + * NOTE: you can also use Tapenade web server (http://tapenade.inria.fr:8080/tapenade/index.jsp) + * for generating but the result can be slightly different. + */ + +// #include "../adbench/lstm.h" + +extern "C" { +// #include "lstm.h" + +// UTILS +// Sigmoid on scalar +// double sigmoid(double x) +//{ +// return 1.0 / (1.0 + exp(-x)); +//} +// +//// log(sum(exp(x), 2)) +// double logsumexp(double const* vect, int sz) +//{ +// double sum = 0.0; +// int i; +// +// for (i = 0; i < sz; i++) +// { +// sum += exp(vect[i]); +// } +// +// sum += 2; +// return log(sum); +// } + +// LSTM OBJECTIVE +// The LSTM model +void lstm_model(int hsize, double const *weight, double const *bias, + double *hidden, double *cell, double const *input) { + // TODO NOTE THIS + //__builtin_assume(hsize > 0); + + double *gates = (double *)malloc(4 * hsize * sizeof(double)); + double *forget = &(gates[0]); + double *ingate = &(gates[hsize]); + double *outgate = &(gates[2 * hsize]); + double *change = &(gates[3 * hsize]); + + int i; + // caching input + // hidden (needed) + for (i = 0; i < hsize; i++) { + forget[i] = sigmoid(input[i] * weight[i] + bias[i]); + ingate[i] = sigmoid(hidden[i] * weight[hsize + i] + bias[hsize + i]); + outgate[i] = + sigmoid(input[i] * weight[2 * hsize + i] + bias[2 * hsize + i]); + change[i] = tanh(hidden[i] * weight[3 * hsize + i] + bias[3 * hsize + i]); + } + + // caching cell (needed) + for (i = 0; i < hsize; i++) { + cell[i] = cell[i] * forget[i] + ingate[i] * change[i]; + } + + for (i = 0; i < hsize; i++) { + hidden[i] = outgate[i] * tanh(cell[i]); + } + + free(gates); +} + +// Predict LSTM output given an input +void lstm_predict(int l, int b, double const *w, double const *w2, double *s, + double const *x, double *x2) { + int i; + for (i = 0; i < b; i++) { + x2[i] = x[i] * w2[i]; + } + + double *xp = x2; + for (i = 0; i <= 2 * l * b - 1; i += 2 * b) { + lstm_model(b, &(w[i * 4]), &(w[(i + b) * 4]), &(s[i]), &(s[i + b]), xp); + xp = &(s[i]); + } + + for (i = 0; i < b; i++) { + x2[i] = xp[i] * w2[b + i] + w2[2 * b + i]; + } +} + +// LSTM objective (loss function) +void cxx_mayalias_lstm_objective(int l, int c, int b, double const *main_params, + double const *extra_params, double *state, + double const *sequence, double *loss) { + int i, t; + double total = 0.0; + int count = 0; + const double *input = &(sequence[0]); + double *ypred = (double *)malloc(b * sizeof(double)); + double *ynorm = (double *)malloc(b * sizeof(double)); + const double *ygold; + double lse; + + __builtin_assume(b > 0); + for (t = 0; t <= (c - 1) * b - 1; t += b) { + lstm_predict(l, b, main_params, extra_params, state, input, ypred); + lse = logsumexp(ypred, b); + for (i = 0; i < b; i++) { + ynorm[i] = ypred[i] - lse; + } + + ygold = &(sequence[t + b]); + for (i = 0; i < b; i++) { + total += ygold[i] * ynorm[i]; + } + + count += b; + input = ygold; + } + + *loss = -total / count; + + free(ypred); + free(ynorm); +} + +extern int enzyme_const; +extern int enzyme_dup; +extern int enzyme_dupnoneed; +void __enzyme_autodiff(...) noexcept; + +// * tapenade -b -o lstm_tapenade -head "lstm_objective(loss)/(main_params extra_params)" lstm.c + +void dlstm_objective_mayalias(int l, int c, int b, double const *main_params, + double *dmain_params, double const *extra_params, + double *dextra_params, double *state, + double const *sequence, double *loss, + double *dloss) { + __enzyme_autodiff(cxx_mayalias_lstm_objective, enzyme_const, l, enzyme_const, + c, enzyme_const, b, enzyme_dup, main_params, dmain_params, + enzyme_dup, extra_params, dextra_params, enzyme_const, + state, enzyme_const, sequence, enzyme_dupnoneed, loss, + dloss); +} +} diff --git a/enzyme/benchmarks/ReverseMode/lstm/src/lib.rs b/enzyme/benchmarks/ReverseMode/lstm/src/lib.rs new file mode 100644 index 000000000000..937460f3cee3 --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/lstm/src/lib.rs @@ -0,0 +1,56 @@ +#![feature(autodiff)] + +pub (crate) mod unsf; +pub (crate) mod safe; +use std::slice; + + +#[no_mangle] +pub extern "C" fn rust_unsafe_lstm_objective(l: i32, c: i32, b: i32, main_params: *const f64, extra_params: *const f64, state: *mut f64, sequence: *const f64, loss: *mut f64) { + let l = l as usize; + let c = c as usize; + let b = b as usize; + unsafe {unsf::lstm_unsafe_objective(l,c,b,main_params,extra_params,state,sequence, loss);} +} +#[no_mangle] +pub extern "C" fn rust_safe_lstm_objective(l: i32, c: i32, b: i32, main_params: *const f64, extra_params: *const f64, state: *mut f64, sequence: *const f64, loss: *mut f64) { + let l = l as usize; + let c = c as usize; + let b = b as usize; + let (main_params, extra_params, state, sequence) = unsafe {( + slice::from_raw_parts(main_params, 2*l*4*b), + slice::from_raw_parts(extra_params, 3*b), + slice::from_raw_parts_mut(state, 2*l*b), + slice::from_raw_parts(sequence, c*b) + )}; + + unsafe { + safe::lstm_objective(l,c,b,main_params,extra_params,state,sequence, &mut *loss); + } +} + +#[no_mangle] +pub extern "C" fn rust_unsafe_dlstm_objective(l: i32, c: i32, b: i32, main_params: *const f64, d_main_params: *mut f64, extra_params: *const f64, d_extra_params: *mut f64, state: *mut f64, sequence: *const f64, res: *mut f64, d_res: *mut f64) { + let l = l as usize; + let c = c as usize; + let b = b as usize; + unsafe {unsf::d_lstm_unsafe_objective(l,c,b,main_params,d_main_params, extra_params,d_extra_params, state,sequence, res, d_res);} +} +#[no_mangle] +pub extern "C" fn rust_safe_dlstm_objective(l: i32, c: i32, b: i32, main_params: *const f64, d_main_params: *mut f64, extra_params: *const f64, d_extra_params: *mut f64, state: *mut f64, sequence: *const f64, res: *mut f64, d_res: *mut f64) { + let l = l as usize; + let c = c as usize; + let b = b as usize; + let (main_params, d_main_params, extra_params, d_extra_params, state, sequence) = unsafe {( + slice::from_raw_parts(main_params, 2*l*4*b), + slice::from_raw_parts_mut(d_main_params, 2*l*4*b), + slice::from_raw_parts(extra_params, 3*b), + slice::from_raw_parts_mut(d_extra_params, 3*b), + slice::from_raw_parts_mut(state, 2*l*b), + slice::from_raw_parts(sequence, c*b) + )}; + + unsafe { + safe::d_lstm_objective(l,c,b,main_params,d_main_params, extra_params,d_extra_params, state,sequence, &mut *res, &mut *d_res); + } +} diff --git a/enzyme/benchmarks/ReverseMode/lstm/src/safe.rs b/enzyme/benchmarks/ReverseMode/lstm/src/safe.rs new file mode 100644 index 000000000000..3329ebb2c6ae --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/lstm/src/safe.rs @@ -0,0 +1,238 @@ +use std::slice; +use std::autodiff::autodiff; +use std::hint::assert_unchecked; + +// Sigmoid on scalar +fn sigmoid(x: f64) -> f64 { + 1.0 / (1.0 + (-x).exp()) +} + +// log(sum(exp(x), 2)) +#[inline] +fn logsumexp(vect: &[f64]) -> f64 { + let mut sum = 0.0; + for &val in vect { + sum += val.exp(); + } + sum += 2.0; // Adding 2 to sum + sum.ln() +} + +// LSTM OBJECTIVE +// The LSTM model +fn lstm_model( + hsize: usize, + weight: &[f64], + bias: &[f64], + hidden: &mut [f64], + cell: &mut [f64], + input: &[f64], +) { + let mut gates = vec![0.0; 4 * hsize]; + let gates = &mut gates[..4 * hsize]; + let (a, b) = gates.split_at_mut(2 * hsize); + let ((forget, ingate), (outgate, change)) = (a.split_at_mut(hsize), b.split_at_mut(hsize)); + + // unsafe {assert_unchecked(weight.len()== 4 * hsize)}; + // unsafe {assert_unchecked(bias.len()== 4 * hsize)}; + // unsafe {assert_unchecked(hidden.len()== hsize)}; + // unsafe {assert_unchecked(cell.len() >= hsize)}; + // unsafe {assert_unchecked(input.len() >= hsize)}; + // caching input + for i in 0..hsize { + forget[i] = sigmoid(input[i] * weight[i] + bias[i]); + ingate[i] = sigmoid(hidden[i] * weight[hsize + i] + bias[hsize + i]); + outgate[i] = sigmoid(input[i] * weight[2 * hsize + i] + bias[2 * hsize + i]); + change[i] = (hidden[i] * weight[3 * hsize + i] + bias[3 * hsize + i]).tanh(); + } + + // caching cell + for i in 0..hsize { + cell[i] = cell[i] * forget[i] + ingate[i] * change[i]; + } + + for i in 0..hsize { + hidden[i] = outgate[i] * cell[i].tanh(); + } +} + +// Predict LSTM output given an input +fn lstm_predict( + l: usize, + b: usize, + w: &[f64], + w2: &[f64], + s: &mut [f64], + x: &[f64], + x2: &mut [f64], +) { + for i in 0..b { + x2[i] = x[i] * w2[i]; + } + + let mut i = 0; + while i <= 2 * l * b - 1 { + // make borrow-checker happy with non-overlapping mutable references + let (xp, s1, s2) = if i == 0 { + let (s1, s2) = s.split_at_mut(b); + (x2.as_mut(), s1, s2) + } else { + let tmp = &mut s[i - 2 * b..]; + let (a, d) = tmp.split_at_mut(2 * b); + let (d, c) = d.split_at_mut(b); + + (a, d, c) + }; + + lstm_model( + b, + &w[i * 4..(i + b) * 4], + &w[(i + b) * 4..(i + 2 * b) * 4], + s1, + s2, + xp, + ); + + i += 2 * b; + } + + let xp = &s[i - 2 * b..]; + + for i in 0..b { + x2[i] = xp[i] * w2[b + i] + w2[2 * b + i]; + } +} + +// LSTM objective (loss function) +#[autodiff( + d_lstm_objective, + Reverse, + Const, + Const, + Const, + Duplicated, + Duplicated, + Const, + Const, + DuplicatedOnly +)] +pub(crate) fn lstm_objective( + l: usize, + c: usize, + b: usize, + main_params: &[f64], + extra_params: &[f64], + state: &mut [f64], + sequence: &[f64], + loss: &mut f64, +) { + let mut total = 0.0; + + let mut input = &sequence[..b]; + let mut ypred = vec![0.0; b]; + let mut ynorm = vec![0.0; b]; + + // unsafe{assert_unchecked(b > 0)}; + + let limit = (c - 1) * b; + for j in 0..(c - 1) { + let t = j * b; + lstm_predict(l, b, main_params, extra_params, state, input, &mut ypred); + let lse = logsumexp(&ypred); + for i in 0..b { + ynorm[i] = ypred[i] - lse; + } + + let ygold = &sequence[t + b..]; + for i in 0..b { + total += ygold[i] * ynorm[i]; + } + + input = ygold; + } + let count = (c - 1) * b; + + *loss = -total / count as f64; +} + +#[no_mangle] +pub extern "C" fn rust_lstm_objective( + l: i32, + c: i32, + b: i32, + main_params: *const f64, + extra_params: *const f64, + state: *mut f64, + sequence: *const f64, + loss: *mut f64, +) { + let l = l as usize; + let c = c as usize; + let b = b as usize; + let (main_params, extra_params, state, sequence) = unsafe { + ( + slice::from_raw_parts(main_params, 2 * l * 4 * b), + slice::from_raw_parts(extra_params, 3 * b), + slice::from_raw_parts_mut(state, 2 * l * b), + slice::from_raw_parts(sequence, c * b), + ) + }; + + unsafe { + lstm_objective( + l, + c, + b, + main_params, + extra_params, + state, + sequence, + &mut *loss, + ); + } +} + +#[no_mangle] +pub extern "C" fn rust_dlstm_objective( + l: i32, + c: i32, + b: i32, + main_params: *const f64, + d_main_params: *mut f64, + extra_params: *const f64, + d_extra_params: *mut f64, + state: *mut f64, + sequence: *const f64, + res: *mut f64, + d_res: *mut f64, +) { + let l = l as usize; + let c = c as usize; + let b = b as usize; + let (main_params, d_main_params, extra_params, d_extra_params, state, sequence) = unsafe { + ( + slice::from_raw_parts(main_params, 2 * l * 4 * b), + slice::from_raw_parts_mut(d_main_params, 2 * l * 4 * b), + slice::from_raw_parts(extra_params, 3 * b), + slice::from_raw_parts_mut(d_extra_params, 3 * b), + slice::from_raw_parts_mut(state, 2 * l * b), + slice::from_raw_parts(sequence, c * b), + ) + }; + + unsafe { + d_lstm_objective( + l, + c, + b, + main_params, + d_main_params, + extra_params, + d_extra_params, + state, + sequence, + &mut *res, + &mut *d_res, + ); + } +} diff --git a/enzyme/benchmarks/ReverseMode/lstm/src/unsf.rs b/enzyme/benchmarks/ReverseMode/lstm/src/unsf.rs new file mode 100644 index 000000000000..498bf96a9983 --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/lstm/src/unsf.rs @@ -0,0 +1,116 @@ +use std::autodiff::autodiff; + +// Sigmoid on scalar +fn sigmoid(x: f64) -> f64 { + 1.0 / (1.0 + (-x).exp()) +} + +// log(sum(exp(x), 2)) +unsafe fn logsumexp(vect: *const f64, sz: usize) -> f64 { + let mut sum: f64 = 0.0; + for i in 0..sz { + sum += (*vect.add(i)).exp(); + } + sum += 2.0; // Adding 2 to sum + sum.ln() +} + +// LSTM OBJECTIVE +// The LSTM model +unsafe fn lstm_model( + hsize: usize, + weight: *const f64, + bias: *const f64, + hidden: *mut f64, + cell: *mut f64, + input: *const f64, +) { +// // TODO NOTE THIS +// //__builtin_assume(hsize > 0); + let mut gates = vec![0.0; 4 * hsize]; + let forget: *mut f64 = gates.as_mut_ptr(); + let ingate: *mut f64 = gates[hsize..].as_mut_ptr(); + let outgate: *mut f64 = gates[2 * hsize..].as_mut_ptr(); + let change: *mut f64 = gates[3 * hsize..].as_mut_ptr(); + //let (a,b) = gates.split_at_mut(2*hsize); + //let ((forget, ingate), (outgate, change)) = ( + // a.split_at_mut(hsize), b.split_at_mut(hsize)); + + // caching input + for i in 0..hsize { + *forget.add(i) = sigmoid(*input.add(i) * *weight.add(i) + *bias.add(i)); + *ingate.add(i) = sigmoid(*hidden.add(i) * *weight.add(hsize + i) + *bias.add(hsize + i)); + *outgate.add(i) = sigmoid(*input.add(i) * *weight.add(2 * hsize + i) + *bias.add(2 * hsize + i)); + *change.add(i) = (*hidden.add(i) * *weight.add(3 * hsize + i) + *bias.add(3 * hsize + i)).tanh(); + } + + // caching cell + for i in 0..hsize { + *cell.add(i) = *cell.add(i) * *forget.add(i) + *ingate.add(i) * *change.add(i); + } + + for i in 0..hsize { + *hidden.add(i) = *outgate.add(i) * (*cell.add(i)).tanh(); + } +} + +// Predict LSTM output given an input +unsafe fn lstm_predict( + l: usize, + b: usize, + w: *const f64, + w2: *const f64, + s: *mut f64, + x: *const f64, + x2: *mut f64, +) { + for i in 0..b { + *x2.add(i) = *x.add(i) * *w2.add(i); + } + + let mut xp = x2; + let stop = 2 * l * b; + for i in (0..=stop - 1).step_by(2 * b) { + lstm_model(b, w.add(i * 4), w.add((i + b) * 4), s.add(i), s.add(i + b), xp); + xp = s.add(i); + } + + for i in 0..b { + *x2.add(i) = *xp.add(i) * *w2.add(b + i) + *w2.add(2 * b + i); + } +} + +// LSTM objective (loss function) +#[autodiff(d_lstm_unsafe_objective, Reverse, Const, Const, Const, Duplicated, Duplicated, Const, Const, DuplicatedOnly)] +pub (crate) unsafe fn lstm_unsafe_objective(l: usize, c: usize, b: usize, main_params: *const f64, extra_params: *const f64, state: *mut f64, sequence: *const f64, loss: *mut f64) { + let mut total = 0.0; + let mut count = 0; + + //const double* input = &(sequence[0]); + let mut input = sequence; + let mut ypred = vec![0.0; b]; + let mut ynorm = vec![0.0; b]; + let mut lse; + + assert!(b > 0); + + let stop = (c - 1) * b; + for t in (0..=stop - 1).step_by(b) { + lstm_predict(l, b, main_params, extra_params, state, input, ypred.as_mut_ptr()); + lse = logsumexp(ypred.as_mut_ptr(), b); + for i in 0..b { + ynorm[i] = ypred[i] - lse; + } + + //let ygold = &sequence[t + b..]; + let ygold = sequence.add(t + b); + for i in 0..b { + total += *ygold.add(i) * ynorm[i]; + } + + count += b; + input = ygold; + } + + *loss = -total / count as f64; +} diff --git a/enzyme/benchmarks/ReverseMode/ode-real/Cargo.lock b/enzyme/benchmarks/ReverseMode/ode-real/Cargo.lock new file mode 100644 index 000000000000..93dcf6a53b60 --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/ode-real/Cargo.lock @@ -0,0 +1,7 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "ode" +version = "0.1.0" diff --git a/enzyme/benchmarks/ReverseMode/ode-real/Cargo.toml b/enzyme/benchmarks/ReverseMode/ode-real/Cargo.toml new file mode 100644 index 000000000000..96952f9d9e08 --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/ode-real/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "ode" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] + +[lib] +crate-type = ["lib"] + +[profile.release] +lto = "fat" +opt-level = 3 +codegen-units = 1 +unwind = "abort" +strip = true +#overflow-checks = false + +[profile.dev] +lto = "fat" + +[workspace] diff --git a/enzyme/benchmarks/ReverseMode/ode-real/Makefile.make b/enzyme/benchmarks/ReverseMode/ode-real/Makefile.make index 5abb283600e4..4f484097271d 100644 --- a/enzyme/benchmarks/ReverseMode/ode-real/Makefile.make +++ b/enzyme/benchmarks/ReverseMode/ode-real/Makefile.make @@ -1,32 +1,55 @@ -# RUN: cd %S && LD_LIBRARY_PATH="%bldpath:$LD_LIBRARY_PATH" PTR="%ptr" BENCH="%bench" BENCHLINK="%blink" LOAD="%loadEnzyme" ENZYME="%enzyme" make -B ode-raw.ll ode-opt.ll results.json VERBOSE=1 -f %s +# RUN: cd %S && LD_LIBRARY_PATH="%bldpath:$LD_LIBRARY_PATH" PTR="%ptr" BENCH="%bench" BENCHLINK="%blink" LOAD="%loadEnzyme" ENZYME="%enzyme" make -B results.json VERBOSE=1 -f %s .PHONY: clean dir := $(abspath $(lastword $(MAKEFILE_LIST))/../../../..) +include $(dir)/benchmarks/ReverseMode/adbench/Makefile.config + +ifeq ($(strip $(CLANG)),) +$(error PASSES1 is not set) +endif + +ifeq ($(strip $(PASSES1)),) +$(error PASSES1 is not set) +endif + +ifeq ($(strip $(PASSES2)),) +$(error PASSES2 is not set) +endif + +ifeq ($(strip $(PASSES3)),) +$(error PASSES3 is not set) +endif + +ifneq ($(strip $(PASSES4)),) +$(error PASSES4 is set) +endif + clean: rm -f *.ll *.o results.txt results.json + cargo +enzyme clean -%-unopt.ll: %.cpp - clang++ $(BENCH) $(PTR) $^ -O2 -fno-use-cxa-atexit -fno-vectorize -fno-slp-vectorize -ffast-math -fno-unroll-loops -o $@ -S -emit-llvm +$(dir)/benchmarks/ReverseMode/ode-real/target/release/libode.a: src/lib.rs Cargo.toml + RUSTFLAGS="-Z autodiff=Enable,LooseTypes" cargo +enzyme rustc --release --lib --crate-type=staticlib -%-raw.ll: %-unopt.ll - opt $^ $(LOAD) $(ENZYME) -o $@ -S +%-unopt.ll: %.cpp + $(CLANG) $(BENCH) $^ -pthread -O3 -fno-use-cxa-atexit -fno-vectorize -fno-slp-vectorize -fno-unroll-loops -o $@ -S -emit-llvm -%-opt.ll: %-raw.ll - opt $^ -o $@ -S +%-opt.ll: %-unopt.ll + $(OPT) $^ $(LOAD) -passes="$(PASSES2),enzyme" -o $@ -S -ode.o: ode-opt.ll - clang++ $(BENCH) -O2 $^ -o $@ $(BENCHLINK) +ode.o: ode-opt.ll $(dir)/benchmarks/ReverseMode/ode-real/target/release/libode.a + $(CLANG) -pthread -O3 -fno-math-errno $^ -o $@ $(BENCHLINK) results.json: ode.o - ./$^ 1000 | tee $@ - ./$^ 1000 >> $@ - ./$^ 1000 >> $@ - ./$^ 1000 >> $@ - ./$^ 1000 >> $@ - ./$^ 1000 >> $@ - ./$^ 1000 >> $@ - ./$^ 1000 >> $@ - ./$^ 1000 >> $@ - ./$^ 1000 >> $@ + numactl -C 1 ./$^ 1000 | tee $@ + numactl -C 1 ./$^ 1000 >> $@ + numactl -C 1 ./$^ 1000 >> $@ + numactl -C 1 ./$^ 1000 >> $@ + numactl -C 1 ./$^ 1000 >> $@ + numactl -C 1 ./$^ 1000 >> $@ + numactl -C 1 ./$^ 1000 >> $@ + numactl -C 1 ./$^ 1000 >> $@ + numactl -C 1 ./$^ 1000 >> $@ + numactl -C 1 ./$^ 1000 >> $@ diff --git a/enzyme/benchmarks/ReverseMode/ode-real/ode.cpp b/enzyme/benchmarks/ReverseMode/ode-real/ode.cpp index 7c7113df9641..17007c8de727 100644 --- a/enzyme/benchmarks/ReverseMode/ode-real/ode.cpp +++ b/enzyme/benchmarks/ReverseMode/ode-real/ode.cpp @@ -24,20 +24,8 @@ float tdiff(struct timeval *start, struct timeval *end) { return (end->tv_sec-start->tv_sec) + 1e-6*(end->tv_usec-start->tv_usec); } -#define BOOST_MATH_NO_LONG_DOUBLE_MATH_FUNCTIONS -#define BOOST_NO_EXCEPTIONS #include -#include - -#include - -#include -void boost::throw_exception(std::exception const & e){ - //do nothing -} - using namespace std; -using namespace boost::numeric::odeint; #define N 32 #define xmin 0. @@ -76,7 +64,39 @@ void init_brusselator(double* __restrict u, double* __restrict v) { } __attribute__((noinline)) -void brusselator_2d_loop(double* __restrict du, double* __restrict dv, const double* __restrict u, const double* __restrict v, const double* __restrict p, double t) { +void brusselator_2d_loop_restrict(double* __restrict du, double* __restrict dv, const double* __restrict u, const double* __restrict v, const double* __restrict p, double t) { + double A = p[0]; + double B = p[1]; + double alpha = p[2]; + double dx = (double)1/(N-1); + + alpha = alpha/(dx*dx); + + for(int i=0; i state_type; +typedef double state_type[2*N*N]; + +void lorenz_norestrict( const state_type &x, state_type &dxdt, double t ) +{ + // Extract the parameters + double p[3] = { /*A*/ 3.4, /*B*/ 1, /*alpha*/10. }; + brusselator_2d_loop_norestrict(dxdt, dxdt + N * N, x, x + N * N, p, t); +} -void lorenz( const state_type &x , state_type &dxdt , double t ) +void lorenz_restrict( const state_type &x, state_type &dxdt, double t ) { // Extract the parameters double p[3] = { /*A*/ 3.4, /*B*/ 1, /*alpha*/10. }; - brusselator_2d_loop(dxdt.c_array(), dxdt.c_array() + N * N, x.data(), x.data() + N * N, p, t); + brusselator_2d_loop_restrict(dxdt, dxdt + N * N, x, x + N * N, p, t); +} + +extern "C" void rust_lorenz_safe(const double* x, double* dxdt, double t); +extern "C" void rust_dbrusselator_2d_loop_safe(double* adjoint, const double* x, double* dx, const double* p, double* dp, double t); +extern "C" void rust_lorenz_unsf(const double* x, double* dxdt, double t); +extern "C" void rust_dbrusselator_2d_loop_unsf(double* adjoint, const double* x, double* dx, const double* p, double* dp, double t); + +double rustfoobar_unsf(const double *p, const state_type x, const state_type adjoint, double t) { + double dp[3] = { 0. }; + + state_type dx = { 0. }; + + state_type dadjoint_inp;// = adjoint + for (int i = 0; i < N * N; i++) { + dadjoint_inp[i] = adjoint[i]; + } + + rust_dbrusselator_2d_loop_unsf(dadjoint_inp, x, dx, p, dp, t); + return dx[0]; } -// init_brusselator(x.c_array(), x.c_array() + N*N) +double rustfoobar_safe(const double *p, const state_type x, const state_type adjoint, double t) { + double dp[3] = { 0. }; -double foobar(const double* p, const state_type x, const state_type adjoint, double t) { + state_type dx = { 0. }; + + state_type dadjoint_inp;// = adjoint + for (int i = 0; i < N * N; i++) { + dadjoint_inp[i] = adjoint[i]; + } + + rust_dbrusselator_2d_loop_safe(dadjoint_inp, x, dx, p, dp, t); + return dx[0]; +} + +double foobar_restrict(const double* p, const state_type x, const state_type adjoint, double t) { double dp[3] = { 0. }; state_type dx = { 0. }; - state_type dadjoint_inp = adjoint; + state_type dadjoint_inp;// = adjoint + for (int i = 0; i < N * N; i++) { + dadjoint_inp[i] = adjoint[i]; + } state_type dxdu; - __enzyme_autodiff(brusselator_2d_loop, -// enzyme_dup, dxdu.c_array(), dadjoint_inp.c_array(), -// enzyme_dup, dxdu.c_array() + N * N, dadjoint_inp.c_array() + N * N, - enzyme_dupnoneed, nullptr, dadjoint_inp.data(), - enzyme_dupnoneed, nullptr, dadjoint_inp.data() + N * N, - enzyme_dup, x.data(), dx.data(), - enzyme_dup, x.data() + N * N, dx.data() + N * N, + __enzyme_autodiff(brusselator_2d_loop_restrict, + enzyme_dup, dxdu, dadjoint_inp, + enzyme_dup, dxdu + N * N, dadjoint_inp + N * N, + // enzyme_dupnoneed, nullptr, dadjoint_inp, + // enzyme_dupnoneed, nullptr, dadjoint_inp + N * N, + enzyme_dup, x, dx, + enzyme_dup, x + N * N, dx + N * N, + enzyme_dup, p, dp, + enzyme_const, t); + + return dx[0]; +} + +double foobar_norestrict(const double* p, const state_type x, const state_type adjoint, double t) { + double dp[3] = { 0. }; + + state_type dx = { 0. }; + + state_type dadjoint_inp;// = adjoint + for (int i = 0; i < N * N; i++) { + dadjoint_inp[i] = adjoint[i]; + } + + state_type dxdu; + + __enzyme_autodiff(brusselator_2d_loop_norestrict, + enzyme_dup, dxdu, dadjoint_inp, + enzyme_dup, dxdu + N * N, dadjoint_inp + N * N, + // enzyme_dupnoneed, nullptr, dadjoint_inp, + // enzyme_dupnoneed, nullptr, dadjoint_inp + N * N, + enzyme_dup, x, dx, + enzyme_dup, x + N * N, dx + N * N, enzyme_dup, p, dp, enzyme_const, t); @@ -486,14 +572,17 @@ double tfoobar(const double* p, const state_type x, const state_type adjoint, do state_type dx = { 0. }; - state_type dadjoint_inp = adjoint; + state_type dadjoint_inp;// = adjoint + for (int i = 0; i < N * N; i++) { + dadjoint_inp[i] = adjoint[i]; + } state_type dxdu; - brusselator_2d_loop_b(nullptr, dadjoint_inp.data(), - nullptr, dadjoint_inp.data() + N * N, - x.data(), dx.data(), - x.data() + N * N, dx.data() + N * N, + brusselator_2d_loop_b(nullptr, dadjoint_inp, + nullptr, dadjoint_inp + N * N, + x, dx, + x + N * N, dx + N * N, p, dp, t); @@ -505,10 +594,10 @@ int main(int argc, char** argv) { const double p[3] = { /*A*/ 3.4, /*B*/ 1, /*alpha*/10. }; state_type x; - init_brusselator(x.data(), x.data() + N * N); + init_brusselator(x, x + N * N); state_type adjoint; - init_brusselator(adjoint.data(), adjoint.data() + N * N); + init_brusselator(adjoint, adjoint + N * N); double t = 2.1; @@ -542,174 +631,97 @@ int main(int argc, char** argv) { double res; for(int i=0; i<10000; i++) - res = foobar(p, x, adjoint, t); + res = foobar_norestrict(p, x, adjoint, t); gettimeofday(&end, NULL); - printf("Enzyme combined %0.6f res=%f\n", tdiff(&start, &end), res); + printf("C++ Enzyme combined mayalias %0.6f res=%f\n", tdiff(&start, &end), res); } - //printf("res=%f\n", foobar(1000)); -} - - -#if 0 - -typedef boost::array< double , 6 > state_type; - -void lorenz( const state_type &x , state_type &dxdt , double t ) -{ - // Extract the parameters - double k1 = x[3]; - double k2 = x[4]; - double k3 = x[5]; - - dxdt[0] = -k1 * x[0] + k3 * x[1] * x[2]; - dxdt[1] = k1 * x[0] - k2 * x[1] * x[1] - k3 * x[1] * x[2]; - dxdt[2] = k2 * x[1] * x[1]; - - // Don't change the parameters p - dxdt[3] = 0; - dxdt[4] = 0; - dxdt[5] = 0; -} - -double foobar(double* p, uint64_t iters) { - state_type x = { 1.0, 0, 0, p[0], p[1], p[2] }; // initial conditions - double t = 1e5; - typedef controlled_runge_kutta< runge_kutta_dopri5< state_type , typename state_type::value_type , state_type , double > > stepper_type; - //typedef euler< state_type , typename state_type::value_type , state_type , double > stepper_type; - integrate_const( stepper_type(), lorenz , x , 0.0 , t, t/iters ); - - return x[0]; -} - -typedef boost::array< adouble , 6 > astate_type; - -void alorenz( const astate_type &x , astate_type &dxdt , adouble t ) -{ - // Extract the parameters - adouble k1 = x[3]; - adouble k2 = x[4]; - adouble k3 = x[5]; - - dxdt[0] = -k1 * x[0] + k3 * x[1] * x[2]; - dxdt[1] = k1 * x[0] - k2 * x[1] * x[1] - k3 * x[1] * x[2]; - dxdt[2] = k2 * x[1] * x[1]; - - // Don't change the parameters p - dxdt[3] = 0; - dxdt[4] = 0; - dxdt[5] = 0; -} - -adouble afoobar(adouble* p, uint64_t iters) { - astate_type x = { 1.0, 0, 0, p[0], p[1], p[2] }; // initial conditions - double t = 1e5; - typedef controlled_runge_kutta< runge_kutta_dopri5< astate_type , typename astate_type::value_type , astate_type , adouble > > stepper_type; - //typedef euler< astate_type , typename astate_type::value_type , astate_type , adouble > stepper_type; - integrate_const( stepper_type(), alorenz , x , 0.0 , t, t/iters ); - - return x[0]; -} - -static -double afoobar_and_gradient(double* p_in, double* dp_out, uint64_t iters) { - adept::Stack stack; - adouble x[3] = { p_in[0], p_in[1], p_in[2] }; - stack.new_recording(); - adouble y = afoobar(x, iters); - y.set_gradient(1.0); - stack.compute_adjoint(); - for(int i=0; i<3; i++) - dp_out[i] = x[i].get_gradient(); - return y.value(); -} - -static void adept_sincos(uint64_t iters) { + { struct timeval start, end; gettimeofday(&start, NULL); - double p[3] = { 0.04,3e7,1e4 }; - double res = foobar(p, iters); + double res; + for(int i=0; i<10000; i++) + res = foobar_restrict(p, x, adjoint, t); gettimeofday(&end, NULL); - printf("Adept real %0.6f res=%f\n", tdiff(&start, &end), res); + printf("C++ Enzyme combined restrict %0.6f res=%f\n", tdiff(&start, &end), res); } - + { - struct timeval start, end; - gettimeofday(&start, NULL); + struct timeval start, end; + gettimeofday(&start, NULL); - adept::Stack stack; - adouble p[3] = { 0.04,3e7,1e4 }; - // stack.new_recording(); - adouble resa = afoobar(p, iters); - double res = resa.value(); + double res; + for(int i=0; i<10000; i++) + res = rustfoobar_safe(p, x, adjoint, t); - gettimeofday(&end, NULL); - printf("Adept forward %0.6f res=%f\n", tdiff(&start, &end), res); + gettimeofday(&end, NULL); + printf("Rust Enzyme combined safe %0.6f res=%f\n", tdiff(&start, &end), res); } { - struct timeval start, end; - gettimeofday(&start, NULL); + struct timeval start, end; + gettimeofday(&start, NULL); - double p[3] = { 0.04,3e7,1e4 }; - double dp[3] = { 0 }; - afoobar_and_gradient(p, dp, iters); + double res; + for(int i=0; i<10000; i++) + res = rustfoobar_unsf(p, x, adjoint, t); - gettimeofday(&end, NULL); - printf("Adept combined %0.6f res'=%f\n", tdiff(&start, &end), dp[0]); + gettimeofday(&end, NULL); + printf("Rust Enzyme combined unsf %0.6f res=%f\n", tdiff(&start, &end), res); } -} - -static void enzyme_sincos(double inp, uint64_t iters) { { - struct timeval start, end; - gettimeofday(&start, NULL); + struct timeval start, end; + gettimeofday(&start, NULL); + state_type x2; - double p[3] = { 0.04,3e7,1e4 }; - double res = foobar(p, iters); + for(int i=0; i<10000; i++) { + lorenz_norestrict(x, x2, t); + } - gettimeofday(&end, NULL); - printf("Enzyme real %0.6f res=%f\n", tdiff(&start, &end), res); + gettimeofday(&end, NULL); + printf("C++ fwd mayalias %0.6f res=%f\n", tdiff(&start, &end), x2[0]); } { - struct timeval start, end; - gettimeofday(&start, NULL); + struct timeval start, end; + gettimeofday(&start, NULL); + state_type x2; - double p[3] = { 0.04,3e7,1e4 }; - double res = foobar(p, iters); + for(int i=0; i<10000; i++) { + lorenz_restrict(x, x2, t); + } - gettimeofday(&end, NULL); - printf("Enzyme forward %0.6f res=%f\n", tdiff(&start, &end), res); + gettimeofday(&end, NULL); + printf("C++ fwd restrict %0.6f res=%f\n", tdiff(&start, &end), x2[0]); } { - struct timeval start, end; - gettimeofday(&start, NULL); + struct timeval start, end; + gettimeofday(&start, NULL); + state_type x2; - double p[3] = { 0.04,3e7,1e4 }; - double dp[3] = { 0 }; - __enzyme_autodiff(foobar, p, dp, iters); + for(int i=0; i<10000; i++) + rust_lorenz_safe(x, x2, t); - gettimeofday(&end, NULL); - printf("Enzyme combined %0.6f res'=%f\n", tdiff(&start, &end), dp[0]); + gettimeofday(&end, NULL); + printf("Rust fwd safe %0.6f res=%f\n\n", tdiff(&start, &end), x2[0]); } -} -int main(int argc, char** argv) { + { + struct timeval start, end; + gettimeofday(&start, NULL); + state_type x2; - int max_iters = atoi(argv[1]) ; - double inp = 2.1; + for(int i=0; i<10000; i++) + rust_lorenz_unsf(x, x2, t); - //for(int iters=max_iters/20; iters<=max_iters; iters+=max_iters/20) { - auto iters = max_iters; - printf("iters=%d\n", iters); - adept_sincos(inp, iters); - enzyme_sincos(inp, iters); - //} + gettimeofday(&end, NULL); + printf("Rust fwd unsf %0.6f res=%f\n\n", tdiff(&start, &end), x2[0]); + } + + //printf("res=%f\n", foobar(1000)); } -#endif diff --git a/enzyme/benchmarks/ReverseMode/ode-real/src/lib.rs b/enzyme/benchmarks/ReverseMode/ode-real/src/lib.rs new file mode 100644 index 000000000000..4fbc7e75f054 --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/ode-real/src/lib.rs @@ -0,0 +1,100 @@ +#![feature(autodiff)] +#![feature(slice_as_chunks)] +#![feature(iter_next_chunk)] +#![feature(array_ptr_get)] +#![allow(non_snake_case)] +#![allow(non_camel_case_types)] +#![allow(non_upper_case_globals)] + +pub mod safe; +pub mod unsf; + +type StateType = [f64; 2 * N * N]; + +const N: usize = 32; + + +#[no_mangle] +pub extern "C" fn rust_lorenz_unsf(x: *const StateType, dxdt: *mut StateType, t: f64) { + let x: &StateType = unsafe { &*x }; + let dxdt: &mut StateType = unsafe { &mut *dxdt }; + unsafe {unsf::lorenz(x, dxdt, t)}; +} + + +#[no_mangle] +pub extern "C" fn rust_lorenz_safe(x: *const StateType, dxdt: *mut StateType, t: f64) { + let x: &StateType = unsafe { &*x }; + let dxdt: &mut StateType = unsafe { &mut *dxdt }; + safe::lorenz(x, dxdt, t); +} + +#[no_mangle] +pub extern "C" fn rust_dbrusselator_2d_loop_unsf(adjoint: *mut StateType, x: *const StateType, dx: *mut StateType, p: *const [f64;3], dp: *mut [f64;3], t: f64) { + let mut null1 = [0.; 1 * N * N]; + let mut null2 = [0.; 1 * N * N]; + let dx1: *mut f64 = dx.as_mut_ptr(); + let dx2: *mut f64 = unsafe { dx.as_mut_ptr().add(N*N) }; + let dadj1: *mut f64 = adjoint.as_mut_ptr(); + let dadj2: *mut f64 = unsafe { adjoint.as_mut_ptr().add(N*N) }; + let x1: *const f64 = x.as_ptr(); + let x2: *const f64 = unsafe { x.as_ptr().add(N*N) }; + + unsafe {unsf::dbrusselator_2d_loop_unsf(null1.as_mut_ptr(), dadj1, + null2.as_mut_ptr(), dadj2, + x1, dx1, + x2, dx2, + p as *mut f64, dp as *mut f64, t)}; +} + +#[no_mangle] +pub extern "C" fn rust_dbrusselator_2d_loop_safe(adjoint: *mut StateType, x: *const StateType, dx: *mut StateType, p: *const [f64;3], dp: *mut [f64;3], t: f64) { + let x: &StateType = unsafe { &*x }; + let dx: &mut StateType = unsafe { &mut *dx }; + let adjoint: &mut StateType = unsafe { &mut *adjoint }; + + let p: &[f64;3] = unsafe { &*p }; + let dp: &mut [f64;3] = unsafe { &mut *dp }; + + assert!(p[0] == 3.4); + assert!(p[1] == 1.); + assert!(p[2] == 10.); + assert!(t == 2.1); + + //let mut x1 = [0.; 2 * N * N]; + //let mut dx1 = [0.; 2 *N * N]; + //let (tmp1, tmp2) = x1.split_at_mut(N * N); + //let mut x1: [f64; N * N] = tmp1.try_into().unwrap(); + //let mut x2: [f64; N * N] = tmp2.try_into().unwrap(); + //init_brusselator(&mut x1, &mut x2); + //for i in 0..N*N { + // let tmp = (x1[i] - x[i]).abs(); + // if (tmp / x[i] > 1e-5) { + // dbg!(tmp); + // dbg!(tmp / x[i]); + // dbg!(i); + // dbg!(x1[i]); + // dbg!(x[i]); + // println!("x1[{}] = {} != x[{}] = {}", i, x1[i], i, x[i]); + // panic!(); + // } + //} + + // Alternative ways to split the inputs + //let [ mut dx1, mut dx2]: [[f64; N*N]; 2] = unsafe { *std::mem::transmute::<*mut StateType, &mut [[f64; N*N]; 2]>(dx) }; + //let [dx1, dx2]: &mut [[f64; N*N];2] = unsafe { dx.cast::<[[f64; N*N]; 2]>().as_mut().unwrap() }; + + // https://discord.com/channels/273534239310479360/273541522815713281/1236945105601040446 + let ([dx1, dx2], []): (&mut [[f64; N*N]], &mut [f64]) = dx.as_chunks_mut() else { unreachable!() }; + let ([dadj1, dadj2], []): (&mut [[f64; N*N]], &mut [f64])= adjoint.as_chunks_mut() else { unreachable!() }; + let ([x1, x2], []): (&[[f64; N*N]], &[f64])= x.as_chunks() else { unreachable!() }; + + let mut null1 = [0.; 1 * N * N]; + let mut null2 = [0.; 1 * N * N]; + safe::dbrusselator_2d_loop(&mut null1, dadj1, + &mut null2, dadj2, + x1, dx1, + x2, dx2, + p, dp, t); + return; +} diff --git a/enzyme/benchmarks/ReverseMode/ode-real/src/safe.rs b/enzyme/benchmarks/ReverseMode/ode-real/src/safe.rs new file mode 100644 index 000000000000..ddf36851b09c --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/ode-real/src/safe.rs @@ -0,0 +1,75 @@ +use std::autodiff::autodiff; + +const N: usize = 32; +const xmin: f64 = 0.; +const xmax: f64 = 1.; +const ymin: f64 = 0.; +const ymax: f64 = 1.; + +#[inline(always)] +fn range(min: f64, max: f64, i: usize, N_var: usize) -> f64 { + (max - min) / (N_var as f64 - 1.) * i as f64 + min +} + +fn brusselator_f(x: f64, y: f64, t: f64) -> f64 { + let eq1 = (x - 0.3) * (x - 0.3) + (y - 0.6) * (y - 0.6) <= 0.1 * 0.1; + let eq2 = t >= 1.1; + if eq1 && eq2 { + 5.0 + } else { + 0.0 + } +} + +#[expect(unused)] +fn init_brusselator(u: &mut [f64], v: &mut [f64]) { + assert!(u.len() == N * N); + assert!(v.len() == N * N); + for i in 0..N { + for j in 0..N { + let x = range(xmin, xmax, i, N); + let y = range(ymin, ymax, j, N); + u[N * i + j] = 22.0 * (y * (1.0 - y)) * (y * (1.0 - y)).sqrt(); + v[N * i + j] = 27.0 * (x * (1.0 - x)) * (x * (1.0 - x)).sqrt(); + } + } +} + +#[no_mangle] +#[autodiff(dbrusselator_2d_loop, Reverse, Duplicated, Duplicated, Duplicated, Duplicated, Duplicated, Const)] +pub fn brusselator_2d_loop(d_u: &mut [f64;N*N], d_v: &mut [f64;N*N], u: &[f64;N*N], v: &[f64;N*N], p: &[f64;3], t: f64) { + let A = p[0]; + let B = p[1]; + let alpha = p[2]; + let dx = 1. / (N - 1) as f64; + let alpha = alpha / (dx * dx); + for i in 0..N { + for j in 0..N { + let x = range(xmin, xmax, i, N); + let y = range(ymin, ymax, j, N); + let ip1 = if i == N - 1 { i } else { i + 1 }; + let im1 = if i == 0 { i } else { i - 1 }; + let jp1 = if j == N - 1 { j } else { j + 1 }; + let jm1 = if j == 0 { j } else { j - 1 }; + let u2v = u[N * i + j] * u[N * i + j] * v[N * i + j]; + d_u[N * i + j] = alpha * (u[N * im1 + j] + u[N * ip1 + j] + u[N * i + jp1] + u[N * i + jm1] - 4. * u[N * i + j]) + + B + u2v - (A + 1.) * u[N * i + j] + brusselator_f(x, y, t); + d_v[N * i + j] = alpha * (v[N * im1 + j] + v[N * ip1 + j] + v[N * i + jp1] + v[N * i + jm1] - 4. * v[N * i + j]) + + A * u[N * i + j] - u2v; + } + } +} + +pub type StateType = [f64; 2 * N * N]; + +pub fn lorenz(x: &StateType, dxdt: &mut StateType, t: f64) { + let p = [3.4, 1., 10.]; + let (tmp1, tmp2) = dxdt.split_at_mut(N * N); + let mut dxdt1: [f64; N * N] = tmp1.try_into().unwrap(); + let mut dxdt2: [f64; N * N] = tmp2.try_into().unwrap(); + let (tmp1, tmp2) = x.split_at(N * N); + let u: [f64; N * N] = tmp1.try_into().unwrap(); + let v: [f64; N * N] = tmp2.try_into().unwrap(); + brusselator_2d_loop(&mut dxdt1, &mut dxdt2, &u, &v, &p, t); +} + diff --git a/enzyme/benchmarks/ReverseMode/ode-real/src/unsf.rs b/enzyme/benchmarks/ReverseMode/ode-real/src/unsf.rs new file mode 100644 index 000000000000..9f1e4006b80e --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/ode-real/src/unsf.rs @@ -0,0 +1,79 @@ +use std::autodiff::autodiff; + +const N: usize = 32; +const xmin: f64 = 0.; +const xmax: f64 = 1.; +const ymin: f64 = 0.; +const ymax: f64 = 1.; + +#[inline(always)] +fn range(min: f64, max: f64, i: usize, N_var: usize) -> f64 { + (max - min) / (N_var as f64 - 1.) * i as f64 + min +} + +fn brusselator_f(x: f64, y: f64, t: f64) -> f64 { + let eq1 = (x - 0.3) * (x - 0.3) + (y - 0.6) * (y - 0.6) <= 0.1 * 0.1; + let eq2 = t >= 1.1; + if eq1 && eq2 { + 5.0 + } else { + 0.0 + } +} + +#[expect(unused)] +unsafe fn init_brusselator(u: *mut f64, v: *mut f64) { + for i in 0..N { + for j in 0..N { + let x = range(xmin, xmax, i, N); + let y = range(ymin, ymax, j, N); + *u.add(N * i + j) = 22.0 * (y * (1.0 - y)) * (y * (1.0 - y)).sqrt(); + *v.add(N * i + j) = 27.0 * (x * (1.0 - x)) * (x * (1.0 - x)).sqrt(); + } + } +} + +#[no_mangle] +#[autodiff(dbrusselator_2d_loop_unsf, Reverse, Duplicated, Duplicated, Duplicated, Duplicated, Duplicated, Const)] +pub unsafe fn brusselator_2d_loop_unsf(d_u: *mut f64, d_v: *mut f64, u: *const f64, v: *const f64, p: *const f64, t: f64) { + let A = *p.add(0); + let B = *p.add(1); + let alpha = *p.add(2); + let dx = 1. / (N - 1) as f64; + let alpha = alpha / (dx * dx); + for i in 0..N { + for j in 0..N { + let x = range(xmin, xmax, i, N); + let y = range(ymin, ymax, j, N); + let ip1 = if i == N - 1 { i } else { i + 1 }; + let im1 = if i == 0 { i } else { i - 1 }; + let jp1 = if j == N - 1 { j } else { j + 1 }; + let jm1 = if j == 0 { j } else { j - 1 }; + let u2v = *u.add(N * i + j) * *u.add(N * i + j) * *v.add(N * i + j); + *d_u.add(N * i + j) = alpha * (*u.add(N * im1 + j) + *u.add(N * ip1 + j) + *u.add(N * i + jp1) + *u.add(N * i + jm1) - 4. * *u.add(N * i + j)) + + B + u2v - (A + 1.) * *u.add(N * i + j) + brusselator_f(x, y, t); + *d_v.add(N * i + j) = alpha * (*v.add(N * im1 + j) + *v.add(N * ip1 + j) + *v.add(N * i + jp1) + *v.add(N * i + jm1) - 4. * *v.add(N * i + j)) + + A * *u.add(N * i + j) - u2v; + } + } +} + +type StateType = [f64; 2 * N * N]; + +pub unsafe fn lorenz(x: *const StateType, dxdt: *mut StateType, t: f64) { + let p = [3.4, 1., 10.]; + let x = x as *const f64; + let dxdt = dxdt as *mut f64; + let dxdt1: *mut f64 = dxdt as *mut f64; + let dxdt2: *mut f64 = unsafe {dxdt.add(N * N)} as *mut f64; + //let (tmp1, tmp2) = dxdt.split_at_mut(N * N); + //let mut dxdt1: [f64; N * N] = tmp1.try_into().unwrap(); + //let mut dxdt2: [f64; N * N] = tmp2.try_into().unwrap(); + let u: *const f64 = x as *const f64; + let v: *const f64 = unsafe{x.add(N * N)} as *const f64; + //let (tmp1, tmp2) = x.split_at(N * N); + //let u: [f64; N * N] = tmp1.try_into().unwrap(); + //let v: [f64; N * N] = tmp2.try_into().unwrap(); + unsafe {brusselator_2d_loop_unsf(dxdt1 as *mut f64, dxdt2 as *mut f64, u as *const f64, v as *const f64, p.as_ptr(), t)}; +} +