diff --git a/.cargo/config.toml b/.cargo/config.toml index f26d8449a..b693b0887 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,3 +1,11 @@ [build] # Configuration for these lints should be placed in `.clippy.toml` at the crate root. rustflags = ["-Dwarnings"] + +[target.wasm32-unknown-unknown] +rustflags = [ + "-C", "target-feature=+simd128,+atomics,+bulk-memory,+mutable-globals", + "-C", "link-arg=--max-memory=2576941056", +] +[unstable] +build-std = ["panic_abort", "std"] diff --git a/Cargo.lock b/Cargo.lock index 90dae99c1..29630acae 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20,6 +20,15 @@ dependencies = [ "as-slice", ] +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "anes" version = "0.1.6" @@ -161,6 +170,15 @@ dependencies = [ "stable_deref_trait", ] +[[package]] +name = "ash" +version = "0.38.0+1.3.281" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bb44936d800fea8f016d7f2311c6a4f97aebd5dc86f09906139ec848cf3a46f" +dependencies = [ + "libloading", +] + [[package]] name = "autocfg" version = "1.4.0" @@ -179,6 +197,36 @@ dependencies = [ "serde", ] +[[package]] +name = "bit-set" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "bitflags" 
+version = "2.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" +dependencies = [ + "serde", +] + [[package]] name = "blake2" version = "0.10.6" @@ -201,6 +249,12 @@ dependencies = [ "constant_time_eq", ] +[[package]] +name = "block" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d8c1fef690941d3e7788d328517591fecc684c084084702d6ff1641e993699a" + [[package]] name = "block-buffer" version = "0.10.4" @@ -263,6 +317,12 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "ciborium" version = "0.2.2" @@ -315,6 +375,16 @@ version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" +[[package]] +name = "codespan-reporting" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3538270d33cc669650c4b093848450d380def10c331d38c768e34cac80576e6e" +dependencies = [ + "termcolor", + "unicode-width", +] + [[package]] name = "colorchoice" version = "1.0.3" @@ -327,6 +397,33 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "core-graphics-types" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45390e6114f68f718cc7a830514a96f903cccd70d02a8f6d9f643ac4ba45afaf" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "libc", +] + [[package]] name = "cpufeatures" version = "0.2.16" @@ -446,6 +543,15 @@ dependencies = [ "subtle", ] +[[package]] +name = "document-features" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95249b50c6c185bee49034bcb378a49dc2b5dff0be90ff6616d31d64febab05d" +dependencies = [ + "litrs", +] + [[package]] name = "educe" version = "0.5.11" @@ -511,6 +617,140 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "flume" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da0e4dd2a88388a1f4ccc7c9ce104604dab68d9f408dc34cd45823d5a9069095" +dependencies = [ + "futures-core", + "futures-sink", + "nanorand", + "spin", +] + +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "foreign-types" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965" +dependencies = [ + "foreign-types-macros", + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-macros" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" +dependencies = [ + "proc-macro2", + "quote", + 
"syn 2.0.96", +] + +[[package]] +name = "foreign-types-shared" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa9a19cbb55df58761df49b23516a86d432839add4af60fc256da840f66ed35b" + +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" + +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" + +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.96", +] + +[[package]] +name = "futures-sink" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" + 
+[[package]] +name = "futures-task" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" + +[[package]] +name = "futures-util" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -534,6 +774,89 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "gl_generator" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a95dfc23a2b4a9a2f5ab41d194f8bfda3cabec42af4e39f08c339eb2a0c124d" +dependencies = [ + "khronos_api", + "log", + "xml-rs", +] + +[[package]] +name = "glow" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5e5ea60d70410161c8bf5da3fdfeaa1c72ed2c15f8bbb9d19fe3a4fad085f08" +dependencies = [ + "js-sys", + "slotmap", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "glutin_wgl_sys" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c4ee00b289aba7a9e5306d57c2d05499b2e5dc427f84ac708bd2c090212cf3e" +dependencies = [ + "gl_generator", +] + +[[package]] +name = "gpu-alloc" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbcd2dba93594b227a1f57ee09b8b9da8892c34d55aa332e034a228d0fe6a171" +dependencies = [ + "bitflags 2.9.1", + "gpu-alloc-types", +] + +[[package]] +name = "gpu-alloc-types" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "98ff03b468aa837d70984d55f5d3f846f6ec31fe34bbb97c4f85219caeee1ca4" +dependencies = [ + "bitflags 2.9.1", +] + 
+[[package]] +name = "gpu-allocator" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c151a2a5ef800297b4e79efa4f4bec035c5f51d5ae587287c9b952bdf734cacd" +dependencies = [ + "log", + "presser", + "thiserror 1.0.69", + "windows", +] + +[[package]] +name = "gpu-descriptor" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b89c83349105e3732062a895becfc71a8f921bb71ecbbdd8ff99263e3b53a0ca" +dependencies = [ + "bitflags 2.9.1", + "gpu-descriptor-types", + "hashbrown", +] + +[[package]] +name = "gpu-descriptor-types" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdf242682df893b86f33a73828fb09ca4b2d3bb6cc95249707fc684d27484b91" +dependencies = [ + "bitflags 2.9.1", +] + [[package]] name = "half" version = "2.4.1" @@ -549,6 +872,15 @@ name = "hashbrown" version = "0.15.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "84b26c544d002229e640969970a2e74021aadf6e2f96372b9c58eff97de08eb3" +dependencies = [ + "foldhash", +] + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] name = "hermit-abi" @@ -562,6 +894,12 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hexf-parse" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfa686283ad6dd069f105e5ab091b04c62850d3e4cf5d67debad1933f55023df" + [[package]] name = "hmac" version = "0.12.1" @@ -631,6 +969,12 @@ version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" +[[package]] +name = "jni-sys" +version = "0.3.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" + [[package]] name = "js-sys" version = "0.3.77" @@ -641,6 +985,23 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "khronos-egl" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6aae1df220ece3c0ada96b8153459b67eebe9ae9212258bb0134ae60416fdf76" +dependencies = [ + "libc", + "libloading", + "pkg-config", +] + +[[package]] +name = "khronos_api" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2db585e1d738fc771bf08a151420d3ed193d9d895a36df7f6f8a9456b911ddc" + [[package]] name = "lazy_static" version = "1.5.0" @@ -653,12 +1014,47 @@ version = "0.2.169" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" +[[package]] +name = "libloading" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" +dependencies = [ + "cfg-if", + "windows-targets", +] + +[[package]] +name = "litrs" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ce301924b7887e9d637144fdade93f9dfff9b60981d4ac161db09720d39aa5" + +[[package]] +name = "lock_api" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765" +dependencies = [ + "autocfg", + "scopeguard", +] + [[package]] name = "log" version = "0.4.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" +[[package]] +name = "malloc_buf" +version = "0.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"62bb907fe88d54d8d9ce32a3cceab4218ed2f6b7d35617cafe9adf84e43919cb" +dependencies = [ + "libc", +] + [[package]] name = "matchers" version = "0.1.0" @@ -684,6 +1080,21 @@ version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" +[[package]] +name = "metal" +version = "0.31.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f569fb946490b5743ad69813cb19629130ce9374034abe31614a36402d18f99e" +dependencies = [ + "bitflags 2.9.1", + "block", + "core-graphics-types", + "foreign-types", + "log", + "objc", + "paste", +] + [[package]] name = "minicov" version = "0.3.7" @@ -694,6 +1105,37 @@ dependencies = [ "walkdir", ] +[[package]] +name = "naga" +version = "24.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e380993072e52eef724eddfcde0ed013b0c023c3f0417336ed041aa9f076994e" +dependencies = [ + "arrayvec", + "bit-set", + "bitflags 2.9.1", + "cfg_aliases", + "codespan-reporting", + "hexf-parse", + "indexmap", + "log", + "rustc-hash", + "spirv", + "strum", + "termcolor", + "thiserror 2.0.12", + "unicode-xid", +] + +[[package]] +name = "nanorand" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3" +dependencies = [ + "getrandom", +] + [[package]] name = "ndarray" version = "0.16.1" @@ -709,6 +1151,15 @@ dependencies = [ "rawpointer", ] +[[package]] +name = "ndk-sys" +version = "0.5.0+25.2.9519653" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c196769dd60fd4f363e11d948139556a344e79d451aeb2fa2fd040738ef7691" +dependencies = [ + "jni-sys", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -756,6 +1207,15 @@ dependencies = [ "autocfg", ] +[[package]] +name = "objc" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"915b1b472bc21c53464d6c8461c9d3af805ba1ef837e1cac254428f4a77177b1" +dependencies = [ + "malloc_buf", +] + [[package]] name = "once_cell" version = "1.20.2" @@ -768,12 +1228,44 @@ version = "11.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" +[[package]] +name = "ordered-float" +version = "4.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bb71e1b3fa6ca1c61f383464aaf2bb0e2f8e772a1f01d486832464de363b951" +dependencies = [ + "num-traits", +] + [[package]] name = "overload" version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "parking_lot" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70d58bf43669b5795d1576d0641cfb6fbb2057bf629506267a92807158584a13" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets", +] + [[package]] name = "paste" version = "1.0.15" @@ -786,6 +1278,18 @@ version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "pkg-config" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" + [[package]] name = "plotters" 
version = "0.3.7" @@ -814,6 +1318,12 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "pollster" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f3a9f18d041e6d0e102a0a46750538147e5e8992d3b4873aaafee2520b00ce3" + [[package]] name = "portable-atomic" version = "1.11.1" @@ -838,6 +1348,12 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "presser" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8cf8e6a8aa66ce33f63993ffc4ea4271eb5b0530a9002db8455ea6050c77bfa" + [[package]] name = "proc-macro2" version = "1.0.93" @@ -847,6 +1363,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "profiling" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3eb8486b569e12e2c32ad3e204dbaba5e4b5b216e9367044f25f1dba42341773" + [[package]] name = "quote" version = "1.0.38" @@ -882,6 +1404,18 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +[[package]] +name = "range-alloc" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3d6831663a5098ea164f89cff59c6284e95f4e3c76ce9848d4529f5ccca9bde" + +[[package]] +name = "raw-window-handle" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20675572f6f24e9e76ef639bc5552774ed45f1c30e2951e1e99c59888861c539" + [[package]] name = "rawpointer" version = "0.2.1" @@ -908,6 +1442,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "redox_syscall" +version = "0.5.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d04b7d0ee6b4a0207a0a7adb104d23ecb0b47d6beae7152d0fa34b692b29fd6" +dependencies = [ + "bitflags 2.9.1", +] + [[package]] name = "regex" version = "1.11.1" @@ -952,6 +1495,12 @@ version = "0.8.5" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" +[[package]] +name = "renderdoc-sys" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19b30a45b0cd0bcca8037f3d0dc3421eaf95327a17cad11964fb8179b4fc4832" + [[package]] name = "rfc6979" version = "0.4.0" @@ -962,6 +1511,12 @@ dependencies = [ "subtle", ] +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = "rustc_version" version = "0.4.1" @@ -992,6 +1547,12 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + [[package]] name = "semver" version = "1.0.25" @@ -1056,12 +1617,45 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "slab" +version = "0.4.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04dc19736151f35336d325007ac991178d504a119863a2fcb3758cdb5e52c50d" + +[[package]] +name = "slotmap" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbff4acf519f630b3a3ddcfaea6c06b42174d9a44bc70c620e9ed1649d58b82a" +dependencies = [ + "version_check", +] + [[package]] name = "smallvec" version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" 
+dependencies = [ + "lock_api", +] + +[[package]] +name = "spirv" +version = "0.3.0+sdk-1.3.268.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eda41003dc44290527a59b13432d4a0379379fa074b70174882adfbdfd917844" +dependencies = [ + "bitflags 2.9.1", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -1122,6 +1716,34 @@ dependencies = [ "serde", ] +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + +[[package]] +name = "strum" +version = "0.26.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.96", +] + [[package]] name = "stwo-air-utils" version = "0.1.1" @@ -1153,6 +1775,7 @@ dependencies = [ "rayon", "stwo-prover", "tracing", + "web-sys", ] [[package]] @@ -1161,9 +1784,11 @@ version = "0.1.1" dependencies = [ "criterion", "educe", + "flume", "itertools 0.12.1", "ndarray", "num-traits", + "pollster", "rand", "rayon", "serde", @@ -1172,6 +1797,12 @@ dependencies = [ "test-log", "tracing", "tracing-subscriber", + "wasm-bindgen", + "wasm-bindgen-futures", + "wasm-bindgen-test", + "wasm_thread", + "web-sys", + "wgpu", ] [[package]] @@ -1185,20 +1816,23 @@ dependencies = [ "cfg-if", "criterion", "educe", + "flume", "hex", "indexmap", "itertools 0.12.1", "num-traits", + "pollster", "rand", "rayon", "serde", "starknet-crypto", "starknet-ff", "test-log", - "thiserror", + "thiserror 1.0.69", "tracing", "tracing-subscriber", "wasm-bindgen-test", + "wgpu", ] 
[[package]] @@ -1229,6 +1863,15 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "termcolor" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" +dependencies = [ + "winapi-util", +] + [[package]] name = "test-log" version = "0.2.17" @@ -1257,7 +1900,16 @@ version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" dependencies = [ - "thiserror-impl", + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" +dependencies = [ + "thiserror-impl 2.0.12", ] [[package]] @@ -1271,6 +1923,17 @@ dependencies = [ "syn 2.0.96", ] +[[package]] +name = "thiserror-impl" +version = "2.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.96", +] + [[package]] name = "thread_local" version = "1.1.8" @@ -1364,6 +2027,18 @@ version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" +[[package]] +name = "unicode-width" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" + +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + [[package]] name = "utf8parse" version = "0.2.2" @@ -1493,6 +2168,17 @@ dependencies = [ "syn 2.0.96", ] +[[package]] +name = "wasm_thread" +version = "0.3.3" 
+source = "git+https://github.com/jaehunkim/wasm_thread?branch=main#36cddbcc73c87c3851a7a3fedd80f8ae64bad271" +dependencies = [ + "futures", + "js-sys", + "wasm-bindgen", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.77" @@ -1503,6 +2189,115 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "wgpu" +version = "24.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b0b3436f0729f6cdf2e6e9201f3d39dc95813fad61d826c1ed07918b4539353" +dependencies = [ + "arrayvec", + "bitflags 2.9.1", + "cfg_aliases", + "document-features", + "js-sys", + "log", + "naga", + "parking_lot", + "profiling", + "raw-window-handle", + "smallvec", + "static_assertions", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "wgpu-core", + "wgpu-hal", + "wgpu-types", +] + +[[package]] +name = "wgpu-core" +version = "24.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f0aa306497a238d169b9dc70659105b4a096859a34894544ca81719242e1499" +dependencies = [ + "arrayvec", + "bit-vec", + "bitflags 2.9.1", + "cfg_aliases", + "document-features", + "indexmap", + "log", + "naga", + "once_cell", + "parking_lot", + "profiling", + "raw-window-handle", + "rustc-hash", + "smallvec", + "thiserror 2.0.12", + "wgpu-hal", + "wgpu-types", +] + +[[package]] +name = "wgpu-hal" +version = "24.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f112f464674ca69f3533248508ee30cb84c67cf06c25ff6800685f5e0294e259" +dependencies = [ + "android_system_properties", + "arrayvec", + "ash", + "bit-set", + "bitflags 2.9.1", + "block", + "bytemuck", + "cfg_aliases", + "core-graphics-types", + "glow", + "glutin_wgl_sys", + "gpu-alloc", + "gpu-allocator", + "gpu-descriptor", + "js-sys", + "khronos-egl", + "libc", + "libloading", + "log", + "metal", + "naga", + "ndk-sys", + "objc", + "once_cell", + "ordered-float", + "parking_lot", + "profiling", + "range-alloc", + "raw-window-handle", + "renderdoc-sys", + 
"rustc-hash", + "smallvec", + "thiserror 2.0.12", + "wasm-bindgen", + "web-sys", + "wgpu-types", + "windows", + "windows-core", +] + +[[package]] +name = "wgpu-types" +version = "24.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50ac044c0e76c03a0378e7786ac505d010a873665e2d51383dcff8dd227dc69c" +dependencies = [ + "bitflags 2.9.1", + "js-sys", + "log", + "web-sys", +] + [[package]] name = "winapi" version = "0.3.9" @@ -1534,6 +2329,70 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd04d41d93c4992d421894c18c8b43496aa748dd4c081bac0dc93eb0489272b6" +dependencies = [ + "windows-core", + "windows-targets", +] + +[[package]] +name = "windows-core" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba6d44ec8c2591c134257ce647b7ea6b20335bf6379a27dac5f1641fcf59f99" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-result", + "windows-strings", + "windows-targets", +] + +[[package]] +name = "windows-implement" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bbd5b46c938e506ecbce286b6628a02171d56153ba733b6c741fc627ec9579b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.96", +] + +[[package]] +name = "windows-interface" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053c4c462dc91d3b1504c6fe5a726dd15e216ba718e84a0e46a88fbe5ded3515" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.96", +] + +[[package]] +name = "windows-result" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e" +dependencies = [ + 
"windows-targets", +] + +[[package]] +name = "windows-strings" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" +dependencies = [ + "windows-result", + "windows-targets", +] + [[package]] name = "windows-sys" version = "0.52.0" @@ -1616,6 +2475,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "xml-rs" +version = "0.8.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a62ce76d9b56901b19a74f19431b0d8b3bc7ca4ad685a746dfd78ca8f4fc6bda" + [[package]] name = "zerocopy" version = "0.7.35" diff --git a/Cargo.toml b/Cargo.toml index 3cd638623..006b8195a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,9 @@ num-traits = "0.2.17" thiserror = "1.0.56" bytemuck = "1.14.3" tracing = "0.1.40" +wgpu = "24.0" +flume = "0.11.1" +pollster = "0.4.0" tracing-subscriber = "0.3.18" rayon = { version = "1.10.0", optional = false } rand = { version = "0.8.5", default-features = false, features = ["small_rng"] } diff --git a/crates/constraint_framework/Cargo.toml b/crates/constraint_framework/Cargo.toml index 0832f8f80..6baef17c4 100644 --- a/crates/constraint_framework/Cargo.toml +++ b/crates/constraint_framework/Cargo.toml @@ -4,7 +4,10 @@ version.workspace = true edition.workspace = true [features] -parallel = ["dep:rayon"] +parallel = [ + "dep:rayon", + "stwo-prover/parallel" +] [dependencies] rayon = { workspace = true, optional = true } @@ -13,3 +16,5 @@ itertools.workspace = true tracing.workspace = true stwo-prover = { path = "../prover" } rand.workspace = true +[target.'cfg(all(target_family = "wasm", not(target_os = "wasi")))'.dependencies] +web-sys = { version = "0.3", features = ["console", "Performance", "Window"] } \ No newline at end of file diff --git 
a/crates/constraint_framework/src/component.rs b/crates/constraint_framework/src/component.rs index 4408b436e..0dcf99b89 100644 --- a/crates/constraint_framework/src/component.rs +++ b/crates/constraint_framework/src/component.rs @@ -18,6 +18,7 @@ use stwo_prover::core::backend::simd::very_packed_m31::{ VeryPackedBaseField, LOG_N_VERY_PACKED_ELEMS, }; use stwo_prover::core::backend::simd::SimdBackend; +use stwo_prover::core::backend::web::WebBackend; use stwo_prover::core::circle::CirclePoint; use stwo_prover::core::constraints::coset_vanishing; use stwo_prover::core::fields::m31::BaseField; @@ -29,9 +30,12 @@ use stwo_prover::core::poly::BitReversedOrder; use stwo_prover::core::secure_column::SecureColumnByCoords; use stwo_prover::core::ColumnVec; use tracing::{span, Level}; +#[cfg(all(target_arch = "wasm32", not(target_os = "wasi")))] +use web_sys::console; use super::cpu_domain::CpuDomainEvaluator; use super::preprocessed_columns::PreProcessedColumnId; +use super::web_domain::WebDomainEvaluator; use super::{ EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator, PREPROCESSED_TRACE_IDX, }; @@ -306,6 +310,8 @@ impl<E: FrameworkEval + Sync> ComponentProver<SimdBackend> for FrameworkComponen .map(|idx| &trace.evals[PREPROCESSED_TRACE_IDX][*idx]) .collect(); + #[cfg(all(target_arch = "wasm32", not(target_os = "wasi")))] + console::time_with_label("simd-work-timer"); // Extend trace if necessary. // TODO: Don't extend when eval_size < committed_size. Instead, pick a good // subdomain. (For larger blowup factors). 
@@ -368,7 +374,7 @@ impl ComponentProver for FrameworkComponen let denom_inv = denom_inv[row >> trace_domain.log_size()]; col.set(row, col.at(row) + row_res * denom_inv) } - let col = SecureColumnByCoords::from_cpu(col); + let col = SecureColumnByCoords::::from_cpu(col); *accum.col = col; return; } @@ -421,6 +427,80 @@ impl ComponentProver for FrameworkComponen } } }); + + #[cfg(all(target_arch = "wasm32", not(target_os = "wasi")))] + console::time_end_with_label("simd-work-timer"); + } +} + +impl ComponentProver for FrameworkComponent { + fn evaluate_constraint_quotients_on_domain( + &self, + trace: &Trace<'_, WebBackend>, + evaluation_accumulator: &mut DomainEvaluationAccumulator, + ) { + if self.n_constraints() == 0 { + return; + } + + let eval_domain = CanonicCoset::new(self.max_constraint_log_degree_bound()).circle_domain(); + let trace_domain = CanonicCoset::new(self.eval.log_size()); + + let mut component_polys = trace.polys.sub_tree(&self.trace_locations); + component_polys[PREPROCESSED_TRACE_IDX] = self + .preprocessed_column_indices + .iter() + .map(|idx| &trace.polys[PREPROCESSED_TRACE_IDX][*idx]) + .collect(); + + let mut component_evals = trace.evals.sub_tree(&self.trace_locations); + component_evals[PREPROCESSED_TRACE_IDX] = self + .preprocessed_column_indices + .iter() + .map(|idx| &trace.evals[PREPROCESSED_TRACE_IDX][*idx]) + .collect(); + + // Extend trace if necessary. + // TODO: Don't extend when eval_size < committed_size. Instead, pick a good + // subdomain. (For larger blowup factors). + let need_to_extend = component_evals + .iter() + .flatten() + .any(|c| c.domain != eval_domain); + + // Denom inverses. + let log_expand = eval_domain.log_size() - trace_domain.log_size(); + let mut denom_inv = (0..1 << log_expand) + .map(|i| coset_vanishing(trace_domain.coset(), eval_domain.at(i)).inverse()) + .collect_vec(); + bit_reverse(&mut denom_inv); + + // Accumulator. 
+ let [mut accum] = + evaluation_accumulator.columns([(eval_domain.log_size(), self.n_constraints())]); + accum.random_coeff_powers.reverse(); + + let _span = span!(Level::INFO, "Constraint point-wise eval").entered(); + + let component_polys = component_polys.as_cols_ref().map_cols(|c| c.as_ref()); + let component_evals = component_evals.as_cols_ref().map_cols(|c| c.as_ref()); + let col = + unsafe { VeryPackedSecureColumnByCoords::transform_under_mut(accum.col.as_mut()) }; + + // Use WebGPU for heavy computations + let eval = WebDomainEvaluator::new( + &component_polys, + &component_evals, + need_to_extend, + col, + accum.random_coeff_powers.clone(), + eval_domain, + trace_domain.log_size(), + denom_inv, + self.eval.log_size(), + self.claimed_sum, + ); + self.eval.evaluate(eval); } } diff --git a/crates/constraint_framework/src/lib.rs b/crates/constraint_framework/src/lib.rs index 9fa8d4cde..47ac734e9 100644 --- a/crates/constraint_framework/src/lib.rs +++ b/crates/constraint_framework/src/lib.rs @@ -10,6 +10,7 @@ mod point; pub mod preprocessed_columns; pub mod relation_tracker; mod simd_domain; +mod web_domain; use std::array; use std::fmt::Debug; @@ -26,6 +27,7 @@ use stwo_prover::core::fields::m31::BaseField; use stwo_prover::core::fields::qm31::{SecureField, SECURE_EXTENSION_DEGREE}; use stwo_prover::core::fields::FieldExpOps; use stwo_prover::core::lookups::utils::Fraction; +pub use web_domain::WebDomainEvaluator; #[rustfmt::skip] pub use stwo_prover::core::prover::PREPROCESSED_TRACE_IDX; @@ -295,7 +297,7 @@ impl<'a, F: Clone, EF: RelationEFTraitBound, R: Relation> RelationEntr macro_rules! 
relation { ($name:tt, $size:tt) => { #[derive(Clone, Debug, PartialEq)] - pub struct $name($crate::logup::LookupElements<$size>); + pub struct $name(pub $crate::logup::LookupElements<$size>); #[allow(dead_code)] impl $name { diff --git a/crates/constraint_framework/src/web_domain.rs b/crates/constraint_framework/src/web_domain.rs new file mode 100644 index 000000000..cfc509502 --- /dev/null +++ b/crates/constraint_framework/src/web_domain.rs @@ -0,0 +1,88 @@ +use std::ops::Mul; + +use stwo_prover::core::backend::simd::column::VeryPackedSecureColumnByCoords; +use stwo_prover::core::backend::simd::very_packed_m31::{ + VeryPackedBaseField, VeryPackedSecureField, +}; +use stwo_prover::core::backend::web::WebBackend; +use stwo_prover::core::fields::m31::{BaseField, M31}; +use stwo_prover::core::fields::qm31::{SecureField, SECURE_EXTENSION_DEGREE}; +use stwo_prover::core::lookups::utils::Fraction; +use stwo_prover::core::pcs::TreeVec; +use stwo_prover::core::poly::circle::{CircleDomain, CircleEvaluation, CirclePoly}; +use stwo_prover::core::poly::BitReversedOrder; + +use super::logup::LogupAtRow; +use super::{EvalAtRow, INTERACTION_TRACE_IDX}; + +/// Dummy evaluator for WebGPU. 
+pub struct WebDomainEvaluator<'a> { + pub trace_poly: &'a TreeVec>>, + pub trace_eval: &'a TreeVec>>, + pub needs_to_extend: bool, + pub col: &'a mut VeryPackedSecureColumnByCoords, + pub random_coeff_powers: Vec, + pub eval_domain: CircleDomain, + pub trace_domain_log_size: u32, + pub denom_inv: Vec, + pub claimed_sum: SecureField, + pub log_size: u32, + pub logup: LogupAtRow, +} + +impl<'a> WebDomainEvaluator<'a> { + pub fn new( + trace_poly: &'a TreeVec>>, + trace_eval: &'a TreeVec>>, + needs_to_extend: bool, + col: &'a mut VeryPackedSecureColumnByCoords, + random_coeff_powers: Vec, + eval_domain: CircleDomain, + trace_domain_log_size: u32, + denom_inv: Vec, + log_size: u32, + claimed_sum: SecureField, + ) -> Self { + Self { + trace_poly, + trace_eval, + needs_to_extend, + col, + random_coeff_powers, + eval_domain, + trace_domain_log_size, + denom_inv, + claimed_sum, + log_size, + logup: LogupAtRow::new(INTERACTION_TRACE_IDX, claimed_sum, log_size), + } + } +} + +/// Dummy implementation for WebGPU. These methods will be implemented as WGSL code, so they don't +/// need to be implemented here. 
+#[allow(unused_variables)] +impl EvalAtRow for WebDomainEvaluator<'_> { + type F = VeryPackedBaseField; + type EF = VeryPackedSecureField; + + fn next_interaction_mask( + &mut self, + interaction: usize, + offsets: [isize; N], + ) -> [Self::F; N] { + unimplemented!() + } + fn add_constraint(&mut self, constraint: G) + where + Self::EF: Mul + From, + { + unimplemented!() + } + + fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF { + unimplemented!() + } + + super::logup_proxy!(); +} diff --git a/crates/examples/Cargo.toml b/crates/examples/Cargo.toml index ebb5beb40..0798d0c00 100644 --- a/crates/examples/Cargo.toml +++ b/crates/examples/Cargo.toml @@ -4,7 +4,10 @@ version.workspace = true edition.workspace = true [features] -parallel = ["dep:rayon"] +parallel = [ + "dep:rayon", + "stwo-constraint-framework/parallel" +] tracing = [] slow-tests = [] @@ -22,6 +25,14 @@ stwo-prover = { path = "../prover" } stwo-constraint-framework = { path = "../constraint_framework" } rand.workspace = true educe.workspace = true +wgpu.workspace = true +flume.workspace = true +pollster.workspace = true +[target.'cfg(all(target_family = "wasm", not(target_os = "wasi")))'.dependencies] +web-sys = { version = "0.3", features = ["console", "Performance", "Window"] } +wasm-bindgen = "0.2" +wasm-bindgen-futures = "0.4.50" +wasm_thread = { git = "https://github.com/jaehunkim/wasm_thread", branch = "main" } [dev-dependencies] criterion = { default-features = false, features = [ @@ -29,6 +40,8 @@ criterion = { default-features = false, features = [ ], version = "0.5.1" } ndarray = "0.16.1" test-log = { version = "0.2.15", features = ["trace"] } +[target.'cfg(all(target_family = "wasm", not(target_os = "wasi")))'.dev-dependencies] +wasm-bindgen-test = "0.3.43" [[bench]] harness = false diff --git a/crates/examples/src/poseidon/mod.rs b/crates/examples/src/poseidon/mod.rs index 7409f3ad8..fc1a044fb 100644 --- a/crates/examples/src/poseidon/mod.rs +++ 
b/crates/examples/src/poseidon/mod.rs @@ -1,5 +1,6 @@ //! AIR for Poseidon2 hash function from . +use std::any::type_name; use std::ops::{Add, AddAssign, Mul, Sub}; use itertools::Itertools; @@ -9,12 +10,13 @@ use rayon::prelude::*; use stwo_constraint_framework::logup::LogupTraceGenerator; use stwo_constraint_framework::{ relation, EvalAtRow, FrameworkComponent, FrameworkEval, Relation, RelationEntry, - TraceLocationAllocator, + TraceLocationAllocator, WebDomainEvaluator, }; use stwo_prover::core::backend::simd::column::BaseColumn; -use stwo_prover::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; +use stwo_prover::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES, N_LANES}; use stwo_prover::core::backend::simd::qm31::PackedSecureField; use stwo_prover::core::backend::simd::SimdBackend; +use stwo_prover::core::backend::web::WebBackend; use stwo_prover::core::backend::{Col, Column}; use stwo_prover::core::channel::Blake2sChannel; use stwo_prover::core::fields::m31::BaseField; @@ -27,6 +29,17 @@ use stwo_prover::core::prover::{prove, StarkProof}; use stwo_prover::core::vcs::blake2_merkle::{Blake2sMerkleChannel, Blake2sMerkleHasher}; use stwo_prover::core::ColumnVec; use tracing::{info, span, Level}; +#[cfg(all(target_family = "wasm", not(target_os = "wasi")))] +use {crate::poseidon::web::gpu_channels, web_sys::console}; + +use crate::poseidon::web::eval_composition_poly::build_gpu_input; +#[cfg(not(target_family = "wasm"))] +use crate::poseidon::web::eval_composition_poly::{ + compute_composition_polynomial_wgpu, GpuContext, +}; +use crate::poseidon::web::ComputeCompositionPolynomialOutput; + +mod web; const N_LOG_INSTANCES_PER_ROW: usize = 3; const N_INSTANCES_PER_ROW: usize = 1 << N_LOG_INSTANCES_PER_ROW; @@ -60,12 +73,35 @@ impl FrameworkEval for PoseidonEval { fn max_constraint_log_degree_bound(&self) -> u32 { self.log_n_rows + LOG_EXPAND } - fn evaluate(&self, mut eval: E) -> E { - eval_poseidon_constraints(&mut eval, &self.lookup_elements); + fn 
evaluate(&self, mut eval: E) -> E + where + E: HasDomainTypeName, + { + let web_name = type_name::>(); + + if eval.type_name() == web_name { + eval_poseidon_constraints_web(&mut eval, &self.lookup_elements); + } else { + eval_poseidon_constraints(&mut eval, &self.lookup_elements); + } eval } } +pub trait HasDomainTypeName { + fn type_name(&self) -> &str; + fn as_ptr(&self) -> *const (); +} + +impl HasDomainTypeName for T { + fn type_name(&self) -> &str { + type_name::() + } + fn as_ptr(&self) -> *const () { + self as *const T as *const () + } +} + #[inline(always)] /// Applies the M4 MDS matrix described in 5.1. fn apply_m4(x: [F; 4]) -> [F; 4] @@ -139,6 +175,64 @@ fn pow5(x: F) -> F { x4 * x.clone() } +pub fn eval_poseidon_constraints_web( + eval: &mut E, + lookup_elements: &PoseidonElements, +) { + let web: &mut WebDomainEvaluator<'_> = + unsafe { &mut *(eval as *mut E as *mut WebDomainEvaluator<'_>) }; + + #[cfg(all(target_arch = "wasm32", not(target_os = "wasi")))] + console::time_with_label("wgpu-work-timer"); + + #[cfg(not(all(target_arch = "wasm32", not(target_os = "wasi"))))] + let start = std::time::Instant::now(); + + let output = compute_composition_polynomial_gpu(web, lookup_elements); + + for (chunk_idx, chunk) in output.poly.iter().enumerate() { + for (lane_idx, &qm) in chunk.iter().enumerate() { + let idx = chunk_idx * N_LANES as usize + lane_idx; + web.col.columns[0].set(idx, qm.0[0].into()); + web.col.columns[1].set(idx, qm.0[1].into()); + web.col.columns[2].set(idx, qm.0[2].into()); + web.col.columns[3].set(idx, qm.0[3].into()); + } + } + + #[cfg(all(target_arch = "wasm32", not(target_os = "wasi")))] + console::time_end_with_label("wgpu-work-timer"); + + #[cfg(not(all(target_arch = "wasm32", not(target_os = "wasi"))))] + println!("wgpu-work-timer: {:?}", start.elapsed()); +} + +#[cfg(not(target_family = "wasm"))] +fn compute_composition_polynomial_gpu( + web: &mut WebDomainEvaluator<'_>, + lookup_elements: &PoseidonElements, +) -> Box { + let gpu 
= pollster::block_on(GpuContext::new()); + let web_input = build_gpu_input(web, lookup_elements); + let out = pollster::block_on(compute_composition_polynomial_wgpu(web_input, &gpu)); + + out +} + +#[cfg(all(target_arch = "wasm32", not(target_os = "wasi")))] +fn compute_composition_polynomial_gpu( + web: &mut WebDomainEvaluator<'_>, + lookup_elements: &PoseidonElements, +) -> Box { + let web_input = build_gpu_input(web, lookup_elements); + + gpu_channels::with(|c| c.tx.send(web_input).unwrap()); + + let output = gpu_channels::with(|c| c.rx.recv().unwrap()); + + output +} + pub fn eval_poseidon_constraints(eval: &mut E, lookup_elements: &PoseidonElements) { for _ in 0..N_INSTANCES_PER_ROW { let mut state: [_; N_STATE] = std::array::from_fn(|_| eval.next_trace_mask()); @@ -392,6 +486,74 @@ pub fn prove_poseidon( (component, proof) } +pub fn prove_poseidon_web( + log_n_instances: u32, + config: PcsConfig, +) -> (PoseidonComponent, StarkProof) { + assert!(log_n_instances >= N_LOG_INSTANCES_PER_ROW as u32); + let log_n_rows = log_n_instances - N_LOG_INSTANCES_PER_ROW as u32; + + // Precompute twiddles. + let span = span!(Level::INFO, "Precompute twiddles").entered(); + let twiddles = WebBackend::precompute_twiddles( + CanonicCoset::new(log_n_rows + LOG_EXPAND + config.fri_config.log_blowup_factor) + .circle_domain() + .half_coset, + ); + span.exit(); + + // Setup protocol. + let channel = &mut Blake2sChannel::default(); + let mut commitment_scheme = + CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles); + + // Preprocessed trace. + let span = span!(Level::INFO, "Constant").entered(); + let mut tree_builder = commitment_scheme.tree_builder(); + let constant_trace = vec![]; + tree_builder.extend_evals(constant_trace); + tree_builder.commit(channel); + span.exit(); + + // Trace. 
+ let span = span!(Level::INFO, "Trace").entered(); + let (trace, lookup_data) = gen_trace(log_n_rows); + let trace: Vec> = + trace.into_iter().map(|eval| eval.into()).collect(); + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(trace); + tree_builder.commit(channel); + span.exit(); + + // Draw lookup elements. + let lookup_elements = PoseidonElements::draw(channel); + + // Interaction trace. + let span = span!(Level::INFO, "Interaction").entered(); + let (trace, claimed_sum) = gen_interaction_trace(log_n_rows, lookup_data, &lookup_elements); + let trace: Vec> = + trace.into_iter().map(|eval| eval.into()).collect(); + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(trace); + tree_builder.commit(channel); + span.exit(); + + // Prove constraints. + let component = PoseidonComponent::new( + &mut TraceLocationAllocator::default(), + PoseidonEval { + log_n_rows, + lookup_elements, + claimed_sum, + }, + claimed_sum, + ); + info!("Poseidon component info:\n{}", component); + let proof = prove(&[&component], channel, commitment_scheme).unwrap(); + + (component, proof) +} + #[cfg(test)] mod tests { use std::{array, env}; @@ -406,23 +568,39 @@ mod tests { use stwo_prover::core::poly::circle::CanonicCoset; use stwo_prover::core::prover::verify; use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel; + #[cfg(all(target_arch = "wasm32", not(target_os = "wasi")))] + use { + crate::poseidon::web::gpu_channels, + crate::poseidon::web::runner::runner_eval_composition_polynomial, + crate::poseidon::web::{ + ComputeCompositionPolynomialInput, ComputeCompositionPolynomialOutput, + }, + wasm_bindgen_futures::spawn_local, + wasm_thread as thread, + web_sys::console, + }; use crate::poseidon::{ apply_internal_round_matrix, apply_m4, eval_poseidon_constraints, gen_interaction_trace, - gen_trace, prove_poseidon, PoseidonElements, + gen_trace, prove_poseidon, prove_poseidon_web, PoseidonElements, }; + 
#[cfg(all(target_family = "wasm", not(target_os = "wasi")))] + wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); + #[cfg(all(target_family = "wasm", not(target_os = "wasi")))] #[wasm_bindgen_test::wasm_bindgen_test] fn test_poseidon_prove_wasm() { - const LOG_N_INSTANCES: u32 = 10; + const LOG_N_INSTANCES: u32 = 17; let config = PcsConfig { pow_bits: 10, fri_config: FriConfig::new(5, 1, 64), }; // Prove. + console::time_with_label("simd_poseidon_prove_wasm"); prove_poseidon(LOG_N_INSTANCES, config); + console::time_end_with_label("simd_poseidon_prove_wasm"); } #[test] @@ -556,4 +734,82 @@ mod tests { println!("{}", csv); } + + fn web_poseidon_prove(log_n_instances: u32, config: PcsConfig) { + // Prove.; + #[cfg(all(target_family = "wasm", not(target_os = "wasi")))] + console::time_with_label("web_poseidon_prove"); + let (component, proof) = prove_poseidon_web(log_n_instances, config); + #[cfg(all(target_family = "wasm", not(target_os = "wasi")))] + console::time_end_with_label("web_poseidon_prove"); + + // Verify. + // TODO: Create Air instance independently. + let channel = &mut Blake2sChannel::default(); + let commitment_scheme = + &mut CommitmentSchemeVerifier::::new(proof.config); + + // Decommit. + // Retrieve the expected column sizes in each commitment interaction, from the AIR. + let sizes = component.trace_log_degree_bounds(); + + // Preprocessed columns. + commitment_scheme.commit(proof.commitments[0], &sizes[0], channel); + // Trace columns. + commitment_scheme.commit(proof.commitments[1], &sizes[1], channel); + // Draw lookup element. + let lookup_elements = PoseidonElements::draw(channel); + assert_eq!(lookup_elements, component.lookup_elements); + // Interaction columns. 
+ commitment_scheme.commit(proof.commitments[2], &sizes[2], channel); + + verify(&[&component], channel, commitment_scheme, proof).unwrap(); + } + + #[test_log::test] + fn test_web_poseidon_prove() { + let log_n_instances = 17; + let config = PcsConfig { + pow_bits: 10, + fri_config: FriConfig::new(5, 1, 64), + }; + + web_poseidon_prove(log_n_instances, config); + } + + #[cfg(all(target_family = "wasm", not(target_os = "wasi")))] + #[wasm_bindgen_test::wasm_bindgen_test] + #[allow(dead_code)] + async fn test_web_poseidon_prove_runner() { + let (request_tx, request_rx) = flume::bounded::>(1); + let (response_tx, response_rx) = + flume::bounded::>(1); + let (job_tx, job_rx) = flume::bounded(1); + + // spawn runner + thread::spawn(move || { + spawn_local(async move { + runner_eval_composition_polynomial(request_rx, response_tx).await; + }); + }); + + // spawn caller + thread::spawn(move || { + gpu_channels::init_gpu_channels(request_tx, response_rx); + + let log_n_instances = 17; + let config = PcsConfig { + pow_bits: 10, + fri_config: FriConfig::new(5, 1, 64), + }; + + web_poseidon_prove(log_n_instances, config); + + job_tx.send(()).unwrap(); + }); + + // wait for worker to finish + let _ = job_rx.recv_async().await.unwrap(); + wasm_thread::terminate_all_workers(); + } } diff --git a/crates/examples/src/poseidon/web/constants.rs b/crates/examples/src/poseidon/web/constants.rs new file mode 100644 index 000000000..81532f65c --- /dev/null +++ b/crates/examples/src/poseidon/web/constants.rs @@ -0,0 +1,21 @@ +// TODO: refactor this to const generics +pub const N_ROWS: u32 = 1024; +pub const N_CONSTRAINTS: u32 = 1144; + +pub const N_STATE: u32 = 16; +pub const N_LOG_INSTANCES_PER_ROW: u32 = 3; +pub const N_INSTANCES_PER_ROW: u32 = 1 << N_LOG_INSTANCES_PER_ROW; +pub const N_LANES: u32 = 16; +pub const N_EXTENDED_ROWS: u32 = N_ROWS * 4; +pub const N_ORIGINAL_ROWS: u32 = N_ROWS; +pub const N_COLUMNS: u32 = 1264; +pub const N_INTERACTION_COLUMNS: u32 = N_INSTANCES_PER_ROW * 
4; +#[allow(dead_code)] +pub const N_WORKGROUPS: u32 = N_EXTENDED_ROWS * N_LANES / THREADS_PER_WORKGROUP; +#[allow(dead_code)] +pub const THREADS_PER_WORKGROUP: u32 = 256; + +pub const N_LINE_TWIDDLES_SIZE: u32 = 32; +pub const N_LINE_TWIDDLES_FLAT_SIZE: u32 = N_EXTENDED_ROWS * N_LANES / 2; +pub const N_CIRCLE_TWIDDLES_SIZE: u32 = N_LINE_TWIDDLES_FLAT_SIZE + 1; +pub const N_ORIGINAL_TRACE_COLUMNS: u32 = N_COLUMNS + N_INTERACTION_COLUMNS; diff --git a/crates/examples/src/poseidon/web/constants.wgsl b/crates/examples/src/poseidon/web/constants.wgsl new file mode 100644 index 000000000..f56c6e1ca --- /dev/null +++ b/crates/examples/src/poseidon/web/constants.wgsl @@ -0,0 +1,32 @@ +const N_ROWS: u32 = ${N_ROWS}; +const N_CONSTRAINTS: u32 = ${N_CONSTRAINTS}; + +const N_EXTENDED_ROWS: u32 = N_ROWS * 4; +const N_STATE: u32 = 16; +const N_INSTANCES_PER_ROW: u32 = 8; +const N_TOTAL_FRACS: u32 = N_INSTANCES_PER_ROW * 2; +const N_COLUMNS: u32 = N_INSTANCES_PER_ROW * N_COLUMNS_PER_REP; +const N_INTERACTION_COLUMNS: u32 = N_INSTANCES_PER_ROW * 4; +const N_HALF_FULL_ROUNDS: u32 = 4; +const FULL_ROUNDS: u32 = 2u * N_HALF_FULL_ROUNDS; +const N_PARTIAL_ROUNDS: u32 = 14; +const N_LANES: u32 = 16; +const N_COLUMNS_PER_REP: u32 = N_STATE * (1 + FULL_ROUNDS) + N_PARTIAL_ROUNDS; +const N_WORKGROUPS: u32 = N_EXTENDED_ROWS * N_LANES / THREADS_PER_WORKGROUP; +const THREADS_PER_WORKGROUP: u32 = 256; + +const R: CM31 = CM31(M31(2u), M31(1u)); +const ONE = QM31(CM31(M31(1u), M31(0u)), CM31(M31(0u), M31(0u))); +const N_ORIGINAL_COLUMN_SIZE: u32 = N_LANES * N_ROWS; +const N_EXTENDED_COLUMN_SIZE: u32 = N_LANES * N_EXTENDED_ROWS; + +const N_LINE_TWIDDLES_SIZE: u32 = 32; +const N_LINE_TWIDDLES_FLAT_SIZE: u32 = N_EXTENDED_ROWS * N_LANES / 2; +const N_CIRCLE_TWIDDLES_SIZE: u32 = N_LINE_TWIDDLES_FLAT_SIZE + 1; +const N_ORIGINAL_TRACE_COLUMNS: u32 = N_COLUMNS + N_INTERACTION_COLUMNS; + +const N_PREPROCESSED_TRACE_OFFSET: u32 = 0u; +const N_EXTENDED_TRACE_OFFSET: u32 = N_PREPROCESSED_TRACE_OFFSET; 
+const N_INTERACTION_TRACE_OFFSET: u32 = N_EXTENDED_TRACE_OFFSET + N_COLUMNS; + +const N_MAX_WORKGROUP_STORAGE_SIZE: u32 = 32 << 10; \ No newline at end of file diff --git a/crates/examples/src/poseidon/web/eval_composition_poly.rs b/crates/examples/src/poseidon/web/eval_composition_poly.rs new file mode 100644 index 000000000..1e6dc84cf --- /dev/null +++ b/crates/examples/src/poseidon/web/eval_composition_poly.rs @@ -0,0 +1,339 @@ +use std::collections::HashMap; +use std::mem::MaybeUninit; +use std::ptr; + +use itertools::Itertools; +use stwo_constraint_framework::WebDomainEvaluator; +use stwo_prover::core::backend::cpu::circle::circle_twiddles_from_line_twiddles; +use stwo_prover::core::backend::web::webgpu::qm31::{GpuM31, GpuQM31}; +use stwo_prover::core::backend::web::webgpu::ByteSerialize; +use stwo_prover::core::backend::{Column, CpuBackend}; +use stwo_prover::core::fields::m31::BaseField; +use stwo_prover::core::poly::circle::PolyOps; +use stwo_prover::core::poly::utils::domain_line_twiddles_from_tree; + +use crate::poseidon::web::*; +use crate::poseidon::PoseidonElements; + +#[allow(dead_code)] +pub struct GpuContext { + pub instance: wgpu::Instance, + pub adapter: wgpu::Adapter, + pub device: wgpu::Device, + pub queue: wgpu::Queue, + + pub input_buffer: wgpu::Buffer, + pub output_buffer: wgpu::Buffer, + pub staging_buffer: wgpu::Buffer, + + pub bind_group: wgpu::BindGroup, + + pub extend_trace_pipeline: wgpu::ComputePipeline, + pub composition_polynomial_pipeline: wgpu::ComputePipeline, +} + +impl GpuContext { + /// Create a fully‑initialised GPU context ready for compute work. 
+ pub async fn new() -> Self { + // ── adapter / device ──────────────────────────────────────────────── + let instance = wgpu::Instance::default(); + let adapter = instance + .request_adapter(&wgpu::RequestAdapterOptions { + power_preference: wgpu::PowerPreference::HighPerformance, + compatible_surface: None, + force_fallback_adapter: false, + }) + .await + .unwrap(); + + let mut limits = wgpu::Limits::default(); + let ext_size = std::mem::size_of::() as u64; + limits.max_storage_buffer_binding_size = limits + .max_storage_buffer_binding_size + .max(ext_size as u32 + 1); + limits.max_buffer_size = limits.max_buffer_size.max(ext_size + 1); + limits.max_compute_workgroup_storage_size = 32 << 10; // 32 KiB. + + let (device, queue) = adapter + .request_device( + &wgpu::DeviceDescriptor { + label: Some("stwo‑device"), + required_features: wgpu::Features::empty(), + required_limits: limits, + memory_hints: wgpu::MemoryHints::Performance, + }, + None, + ) + .await + .unwrap(); + + // ── shaders ───────────────────────────────────────────────────────── + let constants = include_str!("constants.wgsl") + .replace("${N_ROWS}", &N_ROWS.to_string()) + .replace("${N_CONSTRAINTS}", &N_CONSTRAINTS.to_string()); + let qm31 = include_str!("../../../../prover/src/core/backend/web/webgpu/qm31.wgsl"); + let utils = include_str!("../../../../prover/src/core/backend/web/webgpu/utils.wgsl"); + let extend = include_str!("extend_trace.wgsl"); + let comp_poly = include_str!("eval_composition_poly.wgsl"); + + let mk_shader = |label: &str, src: String| -> wgpu::ShaderModule { + device.create_shader_module(wgpu::ShaderModuleDescriptor { + label: Some(label), + source: wgpu::ShaderSource::Wgsl(src.into()), + }) + }; + + let extend_mod = mk_shader( + "extend-trace-shader", + format!("{constants}\n{qm31}\n{utils}\n{extend}"), + ); + let comp_mod = mk_shader( + "comp-poly-shader", + format!("{constants}\n{qm31}\n{utils}\n{comp_poly}"), + ); + + // ── buffers 
───────────────────────────────────────────────────────── + let input = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("input-buffer"), + size: std::mem::size_of::() as u64, + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + let output = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("output-buffer"), + size: std::mem::size_of::() as u64, + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, + mapped_at_creation: false, + }); + let staging = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("staging-buffer"), + size: output.size(), + usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + + let extend_buf = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("extend-trace-buffer"), + size: ext_size, + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, + mapped_at_creation: false, + }); + + // ── bind group & layout ───────────────────────────────────────────── + let layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + label: Some("compute-layout"), + entries: &[ + Self::storage_entry(0, true), // input + Self::storage_entry(1, false), // output + Self::storage_entry(2, false), // extend trace + ], + }); + + let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { + label: Some("compute-bind-group"), + layout: &layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: input.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 1, + resource: output.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 2, + resource: extend_buf.as_entire_binding(), + }, + ], + }); + + // ── pipelines ─────────────────────────────────────────────────────── + let pl_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: Some("compute-pl-layout"), + bind_group_layouts: &[&layout], + push_constant_ranges: &[], + 
}); + + let opts = wgpu::PipelineCompilationOptions { + constants: &HashMap::new(), + zero_initialize_workgroup_memory: true, + }; + let extend_trace_pipeline = + device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: Some("extend-trace-pipeline"), + layout: Some(&pl_layout), + module: &extend_mod, + entry_point: Some("evaluate_line_twiddle_per_poly32"), + cache: None, + compilation_options: opts.clone(), + }); + let composition_polynomial_pipeline = + device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: Some("composition-polynomial-pipeline"), + layout: Some(&pl_layout), + module: &comp_mod, + entry_point: Some("compute_composition_polynomial"), + cache: None, + compilation_options: opts, + }); + + Self { + instance, + adapter, + device, + queue, + input_buffer: input, + output_buffer: output, + staging_buffer: staging, + bind_group, + extend_trace_pipeline, + composition_polynomial_pipeline, + } + } + + #[inline] + fn storage_entry(binding: u32, read_only: bool) -> wgpu::BindGroupLayoutEntry { + wgpu::BindGroupLayoutEntry { + binding, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + } + } + + /// Encode both compute passes and copy result into CPU‑visible staging buffer. 
+ fn encode_compute(&self) -> wgpu::CommandEncoder { + let mut enc = self + .device + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("comp-poly-encoder"), + }); + + { + let mut pass = enc.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("comp-poly-pass"), + timestamp_writes: None, + }); + pass.set_bind_group(0, &self.bind_group, &[]); + + // Pass 1: extend trace + pass.set_pipeline(&self.extend_trace_pipeline); + pass.dispatch_workgroups(1, (N_ORIGINAL_TRACE_COLUMNS + 15) / 16, 1); + + // Pass 2: composition polynomial + pass.set_pipeline(&self.composition_polynomial_pipeline); + pass.dispatch_workgroups(N_WORKGROUPS, 1, 1); + } + + enc.copy_buffer_to_buffer( + &self.output_buffer, + 0, + &self.staging_buffer, + 0, + self.staging_buffer.size(), + ); + enc + } +} + +#[allow(dead_code)] +pub async fn compute_composition_polynomial_wgpu( + input: Box, + gpu: &GpuContext, +) -> Box { + // Upload input + gpu.queue + .write_buffer(&gpu.input_buffer, 0, &input.as_bytes()); + + let encoder = gpu.encode_compute(); + gpu.queue.submit(Some(encoder.finish())); + + // Wait for GPU completion and map the staging buffer for readback. 
+ let slice = gpu.staging_buffer.slice(..); + let (tx, rx) = flume::bounded(1); + slice.map_async(wgpu::MapMode::Read, move |v| tx.send(v).unwrap()); + gpu.device.poll(wgpu::Maintain::wait()).panic_on_timeout(); + + let _ = rx.recv_async().await.unwrap(); + let data = slice.get_mapped_range(); + let output = ComputeCompositionPolynomialOutput::from_bytes_box(&data); + drop(data); + gpu.staging_buffer.unmap(); + + output +} + +fn alloc_default_gpu_input() -> Box { + let mut boxed = Box::>::new_uninit(); + let out: *mut ComputeCompositionPolynomialInput = boxed.as_mut_ptr().cast(); + unsafe { + ptr::write_bytes(out, 0, 1); + Box::from_raw(Box::into_raw(boxed).cast()) + } +} + +pub fn build_gpu_input( + eval: &mut WebDomainEvaluator<'_>, + lookup_elements: &PoseidonElements, +) -> Box { + let mut inp = alloc_default_gpu_input(); + + eval.trace_poly + .iter() + .flatten() + .enumerate() + .for_each(|(col, poly)| { + inp.original_trace[col] + .coeffs + .iter_mut() + .zip(poly.coeffs.to_cpu()) + .for_each(|(dst, src)| *dst = src.into()); + }); + + let tw = CpuBackend::precompute_twiddles(eval.eval_domain.half_coset); + let line_tw = domain_line_twiddles_from_tree(eval.eval_domain, &tw.twiddles); + + inp.twiddles.line_twiddles_layer_count = line_tw.len() as u32; + let mut offset = 0usize; + for (layer_idx, layer) in line_tw.iter().enumerate() { + inp.twiddles.line_twiddles_sizes[layer_idx] = layer.len() as u32; + inp.twiddles.line_twiddles_offsets[layer_idx] = offset as u32; + + for (j, &el) in layer.iter().enumerate() { + inp.twiddles.line_twiddles_flat[offset + j] = GpuM31::from(el); + } + offset += layer.len(); + } + + let circle = circle_twiddles_from_line_twiddles(line_tw[0]); + let circle_len = circle.try_len().unwrap(); + for (i, tw) in circle.enumerate() { + inp.twiddles.circle_twiddles[i] = GpuM31::from(tw); + } + inp.twiddles.circle_twiddles_size = circle_len as u32; + + for i in 0..4 { + inp.denom_inv[i] = GpuM31::from(eval.denom_inv[i]); + } + for (i, &p) in 
eval + .random_coeff_powers + .iter() + .enumerate() + .take(N_CONSTRAINTS as usize) + { + inp.random_coeff_powers[i] = GpuQM31::from(p); + } + + inp.lookup_elements = GpuLookupElements::from(lookup_elements); + inp.trace_domain_log_size = eval.trace_domain_log_size; + inp.eval_domain_log_size = eval.eval_domain.log_size(); + inp.cumsum_shift = + (eval.claimed_sum / BaseField::from_u32_unchecked(1 << eval.log_size)).into(); + + inp +} diff --git a/crates/examples/src/poseidon/web/eval_composition_poly.wgsl b/crates/examples/src/poseidon/web/eval_composition_poly.wgsl new file mode 100644 index 000000000..31db9c634 --- /dev/null +++ b/crates/examples/src/poseidon/web/eval_composition_poly.wgsl @@ -0,0 +1,332 @@ +// Note: depends on qm31.wgsl, utils.wgsl + +// Initialize EXTERNAL_ROUND_CONSTS with explicit values +const EXTERNAL_ROUND_CONSTS: array, FULL_ROUNDS> = array, FULL_ROUNDS>( + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), +); + +// Initialize INTERNAL_ROUND_CONSTS with explicit values +const INTERNAL_ROUND_CONSTS: array = array( + 1234, 1234, 1234, 1234, 
1234, 1234, 1234, 1234, 1234, 1234, 1234, 1234, 1234, 1234 +); + +struct OriginalColumn { + data: array, +} + +struct Extended1DColumn { + data: array, +} + +struct LookupElements { + z: QM31, + alpha: QM31, + alpha_powers: array, +} + +struct Twiddles { + circle_twiddles: array, + circle_twiddles_size: u32, + line_twiddles_flat: array, + line_twiddles_layer_count: u32, + line_twiddles_sizes: array, + line_twiddles_offsets: array, +} + +struct ComputeCompositionPolynomialInput { + original_trace: array, + twiddles: Twiddles, + denom_inv: array, + random_coeff_powers: array, + lookup_elements: LookupElements, + trace_domain_log_size: u32, + eval_domain_log_size: u32, + cumsum_shift: QM31, +} + +struct ComputeCompositionPolynomialOutput { + poly: array, N_EXTENDED_ROWS>, +} + +struct RelationEntry { + multiplicity: QM31, + values: array, +} + +struct ExtendTraceOutput { + extended_trace: array, +} + +struct State16 { + data: array +} + +@group(0) @binding(0) +var input: ComputeCompositionPolynomialInput; + +@group(0) @binding(1) +var output: ComputeCompositionPolynomialOutput; + +@group(0) @binding(2) +var extend_trace_output: ExtendTraceOutput; + +var constraint_index: u32 = 0u; + +var fracs_index: u32 = 0u; + +var fracs: array = array( + ZERO_FRACTION, ZERO_FRACTION, ZERO_FRACTION, ZERO_FRACTION, + ZERO_FRACTION, ZERO_FRACTION, ZERO_FRACTION, ZERO_FRACTION, + ZERO_FRACTION, ZERO_FRACTION, ZERO_FRACTION, ZERO_FRACTION, + ZERO_FRACTION, ZERO_FRACTION, ZERO_FRACTION, ZERO_FRACTION +); + +var is_finalized: bool = false; + +@compute @workgroup_size(THREADS_PER_WORKGROUP) +fn compute_composition_polynomial( + @builtin(workgroup_id) workgroup_id: vec3, + @builtin(local_invocation_id) local_invocation_id: vec3, + @builtin(global_invocation_id) global_invocation_id: vec3, + @builtin(local_invocation_index) local_invocation_index: u32, + @builtin(num_workgroups) num_workgroups: vec3, +) { + let workgroup_index = + workgroup_id.x + + workgroup_id.y * num_workgroups.x + + 
workgroup_id.z * num_workgroups.x * num_workgroups.y; + + let global_invocation_index = workgroup_index * THREADS_PER_WORKGROUP + local_invocation_index; // [0, 512) + + var vec_index = global_invocation_index / N_LANES; + var inner_vec_index = global_invocation_index % N_LANES; + var col_index = 0u; + + for (var rep_i = 0u; rep_i < N_INSTANCES_PER_ROW; rep_i++) { + var state: State16 = State16(array()); + for (var j = 0u; j < N_STATE; j++) { + state.data[j] = next_trace_mask(col_index, vec_index, inner_vec_index); + col_index += 1u; + } + var initial_state = state; + + // 4 full rounds + for (var i = 0u; i < N_HALF_FULL_ROUNDS; i++) { + for (var j = 0u; j < N_STATE; j++) { + state.data[j] = m31_add(state.data[j], M31(EXTERNAL_ROUND_CONSTS[i][j])); + } + state = apply_external_round_matrix_state16(state); + for (var j = 0u; j < N_STATE; j++) { + state.data[j] = m31_pow5(state.data[j]); + } + for (var j = 0u; j < N_STATE; j++) { + var m_1 = next_trace_mask(col_index, vec_index, inner_vec_index); + let constraint = m31_sub(state.data[j], m_1); + add_constraint(constraint, vec_index, inner_vec_index); + + state.data[j] = m_1; + col_index += 1u; + } + } + // Partial rounds + for (var i = 0u; i < N_PARTIAL_ROUNDS; i++) { + state.data[0] = m31_add(state.data[0], M31(INTERNAL_ROUND_CONSTS[i])); + state = apply_internal_round_matrix_state16(state); + state.data[0] = m31_pow5(state.data[0]); + var m_1 = next_trace_mask(col_index, vec_index, inner_vec_index); + let constraint = m31_sub(state.data[0], m_1); + add_constraint(constraint, vec_index, inner_vec_index); + + state.data[0] = m_1; + col_index += 1u; + } + // 4 full rounds + for (var i = 0u; i < N_HALF_FULL_ROUNDS; i++) { + for (var j = 0u; j < N_STATE; j++) { + state.data[j] = m31_add(state.data[j], M31(EXTERNAL_ROUND_CONSTS[i + N_HALF_FULL_ROUNDS][j])); + } + state = apply_external_round_matrix_state16(state); + for (var j = 0u; j < N_STATE; j++) { + state.data[j] = m31_pow5(state.data[j]); + } + for (var j = 0u; j < 
N_STATE; j++) { + var m_1 = next_trace_mask(col_index, vec_index, inner_vec_index); + let constraint = m31_sub(state.data[j], m_1); + add_constraint(constraint, vec_index, inner_vec_index); + state.data[j] = m_1; + col_index += 1u; + } + } + add_to_relation_single(RelationEntry(ONE, initial_state.data)); + add_to_relation_single(RelationEntry(qm31_neg(ONE), state.data)); + } + finalize_logup_in_pairs(vec_index, inner_vec_index); + + let row = vec_index * N_STATE + inner_vec_index; + let denom_inv = input.denom_inv[row >> input.trace_domain_log_size]; + output.poly[vec_index][inner_vec_index] = qm31_mul(output.poly[vec_index][inner_vec_index], QM31(CM31(denom_inv, M31(0u)), CM31(M31(0u), M31(0u)))); +} + +fn flatten_idx(vec_index: u32, inner_vec_index: u32) -> u32 { + return vec_index * N_LANES + inner_vec_index; +} + +fn add_constraint(constraint: M31, vec_index: u32, inner_vec_index: u32) { + add_constraint_qm31(QM31(CM31(constraint, M31(0u)), CM31(M31(0u), M31(0u))), vec_index, inner_vec_index); +} + +fn add_constraint_qm31(constraint: QM31, vec_index: u32, inner_vec_index: u32) { + var new_add = qm31_mul(constraint, input.random_coeff_powers[constraint_index]); + output.poly[vec_index][inner_vec_index] = qm31_add(output.poly[vec_index][inner_vec_index], new_add); + constraint_index += 1u; +} + +fn add_to_relation_single(entry: RelationEntry) { + var combined_value = QM31(CM31(M31(0u), M31(0u)), CM31(M31(0u), M31(0u))); + for (var j = 0u; j < N_STATE; j++) { + let value = QM31(CM31(entry.values[j], M31(0u)), CM31(M31(0u), M31(0u))); + combined_value = qm31_add(combined_value, qm31_mul(input.lookup_elements.alpha_powers[j], value)); + } + + combined_value = qm31_sub(combined_value, input.lookup_elements.z); + var frac = Fraction(entry.multiplicity, combined_value); + write_logup_frac_single(frac); +} + +fn write_logup_frac_single(frac: Fraction) { + if (fracs_index == 0u) { + is_finalized = false; + } + fracs[fracs_index] = frac; + fracs_index += 1u; +} + +fn 
finalize_logup_in_pairs(vec_index: u32, inner_vec_index: u32) { + if (is_finalized) { + return; + } + + var prev_col_cumsum = QM31(CM31(M31(0u), M31(0u)), CM31(M31(0u), M31(0u))); + var last_interaction_col_index = 0u; + + // All batches except the last are cumulatively summed in new interaction columns. + for (var i = 0u; i < fracs_index - 2u; i += 2u) { + var cur_frac = fraction_add(fracs[i], fracs[i + 1u]); + + var cur_cumsum = next_interaction_trace_mask(last_interaction_col_index, vec_index, inner_vec_index); + var diff = qm31_sub(cur_cumsum, prev_col_cumsum); + prev_col_cumsum = cur_cumsum; + var constraint = qm31_sub(qm31_mul(diff, cur_frac.denominator), cur_frac.numerator); + add_constraint_qm31(constraint, vec_index, inner_vec_index); + last_interaction_col_index += 4u; + } + + // last batch + let frac = fraction_add(fracs[fracs_index - 2u], fracs[fracs_index - 1u]); + + var cur_cumsum = next_interaction_trace_mask(last_interaction_col_index, vec_index, inner_vec_index); + var prev_row_cumsum = next_interaction_trace_mask_offset(last_interaction_col_index, vec_index, inner_vec_index, -1); + + var diff = qm31_sub(qm31_sub(cur_cumsum, prev_row_cumsum), prev_col_cumsum); + var fixed_diff = qm31_add(diff, input.cumsum_shift); + + var constraint = qm31_sub(qm31_mul(fixed_diff, frac.denominator), frac.numerator); + add_constraint_qm31(constraint, vec_index, inner_vec_index); + is_finalized = true; +} + +fn next_trace_mask(col_index: u32, vec_index: u32, inner_vec_index: u32) -> M31 { + let v0: M31 = extend_trace_output.extended_trace[col_index + N_EXTENDED_TRACE_OFFSET].data[flatten_idx(vec_index, inner_vec_index)]; + + return v0; +} + +fn next_interaction_trace_mask(col_index: u32, vec_index: u32, inner_vec_index: u32) -> QM31 { + let base = col_index + N_INTERACTION_TRACE_OFFSET; + let i = flatten_idx(vec_index, inner_vec_index); + let v0: M31 = extend_trace_output.extended_trace[base].data[i]; + let v1: M31 = extend_trace_output.extended_trace[base + 
1].data[i]; + let v2: M31 = extend_trace_output.extended_trace[base + 2].data[i]; + let v3: M31 = extend_trace_output.extended_trace[base + 3].data[i]; + + return qm31_4(v0, v1, v2, v3); +} + +fn next_interaction_trace_mask_offset(col_index: u32, vec_index: u32, inner_vec_index: u32, offset: i32) -> QM31 { + var curr_row = vec_index * N_STATE + inner_vec_index; + + var row = offset_bit_reversed_circle_domain_index(curr_row, input.trace_domain_log_size, input.eval_domain_log_size, offset); + + var new_vec_index = row / N_LANES; + var new_inner_vec_index = row % N_LANES; + + let base = col_index + N_INTERACTION_TRACE_OFFSET; + let i = flatten_idx(new_vec_index, new_inner_vec_index); + let v0: M31 = extend_trace_output.extended_trace[base].data[i]; + let v1: M31 = extend_trace_output.extended_trace[base + 1].data[i]; + let v2: M31 = extend_trace_output.extended_trace[base + 2].data[i]; + let v3: M31 = extend_trace_output.extended_trace[base + 3].data[i]; + + let ret_val = QM31(CM31(v0, v1), CM31(v2, v3)); + return ret_val; +} + +fn apply_external_round_matrix_state16(state: State16) -> State16 { + var modified_state = state.data; + for (var i = 0u; i < 4u; i++) { + var x = array( + state.data[4 * i], + state.data[4 * i + 1], + state.data[4 * i + 2], + state.data[4 * i + 3], + ); + + let t0 = m31_add(x[0], x[1]); + let t02 = m31_add(t0, t0); + let t1 = m31_add(x[2], x[3]); + let t12 = m31_add(t1, t1); + let t2 = m31_add(m31_add(x[1], x[1]), t1); + let t3 = m31_add(m31_add(x[3], x[3]), t0); + let t4 = m31_add(m31_add(t12, t12), t3); + let t5 = m31_add(m31_add(t02, t02), t2); + let t6 = m31_add(t3, t5); + let t7 = m31_add(t2, t4); + + modified_state[4 * i] = t6; + modified_state[4 * i + 1] = t5; + modified_state[4 * i + 2] = t7; + modified_state[4 * i + 3] = t4; + } + for (var j = 0u; j < 4u; j++) { + let s = m31_add(m31_add(modified_state[j], modified_state[j + 4]), m31_add(modified_state[j + 8], modified_state[j + 12])); + for (var i = 0u; i < 4u; i++) { + 
modified_state[4 * i + j] = m31_add(modified_state[4 * i + j], s); + } + } + return State16(modified_state); +} + +// Applies the internal round matrix. +// mu_i = 2^{i+1} + 1. +// See 5.2. +fn apply_internal_round_matrix_state16(state: State16) -> State16 { + var sum = state.data[0]; + for (var i = 1u; i < N_STATE; i++) { + sum = m31_add(sum, state.data[i]); + } + + var result = State16(array()); + for (var i = 0u; i < N_STATE; i++) { + let factor = partial_reduce(1u << (i + 1)); + result.data[i] = m31_add(m31_mul(M31(factor), state.data[i]), sum); + } + + return result; +} diff --git a/crates/examples/src/poseidon/web/extend_trace.wgsl b/crates/examples/src/poseidon/web/extend_trace.wgsl new file mode 100644 index 000000000..ac71a9350 --- /dev/null +++ b/crates/examples/src/poseidon/web/extend_trace.wgsl @@ -0,0 +1,206 @@ +// Note: depends on qm31.wgsl, utils.wgsl constants.wgsl + +fn butterfly(v0: ptr, v1: ptr, twid: M31) { + let tmp = m31_mul(*v1, twid); + * v1 = m31_sub(*v0, tmp); + * v0 = m31_add(*v0, tmp); +} + +struct OriginalColumn { + data: array, +} + +struct Extended1DColumn { + data: array, +} + +struct LookupElements { + z: QM31, + alpha: QM31, + alpha_powers: array, +} + +struct ComputeCompositionPolynomialInput { + original_trace: array, + twiddles: Twiddles, + denom_inv: array, + random_coeff_powers: array, + lookup_elements: LookupElements, + trace_domain_log_size: u32, + eval_domain_log_size: u32, + total_sum: QM31, +} + +struct ComputeCompositionPolynomialOutput { + poly: array, N_EXTENDED_ROWS>, +} + +struct Twiddles { + circle_twiddles: array, + circle_twiddles_size: u32, + line_twiddles_flat: array, + line_twiddles_layer_count: u32, + line_twiddles_sizes: array, + line_twiddles_offsets: array, +} + +struct ExtendTraceOutput { + extended_trace: array, +} + +@group(0) @binding(0) +var trace_input: ComputeCompositionPolynomialInput; + +@group(0) @binding(1) +var composition_polynomial_output: ComputeCompositionPolynomialOutput; + +@group(0) 
@binding(2) +var trace_output: ExtendTraceOutput; + +//------------------------------------------------------------------------------ +// evaluate_line_twiddle_per_poly32 +// +// One work-item (thread) processes **one polynomial column**. +// +// Phase 1 – “large layers” : operate directly in device storage +// Phase 2 – “small layers + circle” : copy 256-u32 chunks into a +// thread-local scratch array, +// finish the remaining line-twiddle +// layers *and* the final circle-twiddle +// (step = 1), then write the chunk back. +//------------------------------------------------------------------------------ + +// Size of a thread-local scratch tile (256 u32 = 1 KiB). +// 32 threads × 1 KiB ≈ 32 KiB, matching Metal's per-threadgroup LDS budget. +const CHUNK_SIZE: u32 = 128u; + +@compute @workgroup_size(16) +fn evaluate_line_twiddle_per_poly32(@builtin(global_invocation_id) global_id: vec3) { + let poly_id = global_id.x + global_id.y * 16u; + if (poly_id >= N_ORIGINAL_TRACE_COLUMNS) { + return; + } + + //------------------------------------------------------------------------ + // 1. Copy coeffs → output evals (storage → storage, 1 : 1) + //------------------------------------------------------------------------ + for (var j: u32 = 0u; j < N_ORIGINAL_COLUMN_SIZE; j = j + 1u) { + trace_output.extended_trace[poly_id].data[j] = trace_input.original_trace[poly_id].data[j]; + } + + //------------------------------------------------------------------------ + // 2. Line-twiddle “large layers” – operate in place in storage + // Stop when a single butterfly block (step*2) fits in CHUNK_SIZE. 
+ //------------------------------------------------------------------------ + let num_layers = trace_input.twiddles.line_twiddles_layer_count; + var layer = num_layers - 1u; + loop { + let step = 1u << (layer + 1u); + if (step * 2u <= CHUNK_SIZE) { + break; + } + let layer_size = trace_input.twiddles.line_twiddles_sizes[layer]; + let layer_offset = trace_input.twiddles.line_twiddles_offsets[layer]; + + // Iterate over all butterfly blocks in this layer + for (var h: u32 = 0u; h < layer_size; h = h + 1u) { + let t = trace_input.twiddles.line_twiddles_flat[layer_offset + h]; + let base_idx = h << (layer + 2u); + + // Plain Cooley–Tukey butterfly within the block + for (var l: u32 = 0u; l < step; l = l + 1u) { + let idx0 = base_idx + l; + let idx1 = idx0 + step; + + var v0 = trace_output.extended_trace[poly_id].data[idx0]; + var v1 = trace_output.extended_trace[poly_id].data[idx1]; + butterfly(&v0, &v1, t); + trace_output.extended_trace[poly_id].data[idx0] = v0; + trace_output.extended_trace[poly_id].data[idx1] = v1; + } + } + if (layer == 0u) { + break; + } + layer = layer - 1u; + } + + //------------------------------------------------------------------------ + // 3. Scratch-tile phase – finish remaining line layers + circle twiddle + //------------------------------------------------------------------------ + let num_chunks = (N_EXTENDED_COLUMN_SIZE + CHUNK_SIZE - 1u) / CHUNK_SIZE; + var scratch: array; + for (var chunk_id: u32 = 0u; chunk_id < num_chunks; chunk_id = chunk_id + 1u) { + let base = chunk_id * CHUNK_SIZE; + let real_size = min(CHUNK_SIZE, N_EXTENDED_COLUMN_SIZE - base); + + //-------------------------------------------------------------------- + // 3-A. 
Copy current chunk from storage → scratch + //-------------------------------------------------------------------- + for (var i: u32 = 0u; i < real_size; i = i + 1u) { + scratch[i] = trace_output.extended_trace[poly_id].data[base + i]; + } + + //-------------------------------------------------------------------- + // 3-B. Remaining (small) line-twiddle layers in scratch + //-------------------------------------------------------------------- + var l = layer; + loop { + let step = 1u << (l + 1u); + let layer_size = trace_input.twiddles.line_twiddles_sizes[l]; + let layer_offset = trace_input.twiddles.line_twiddles_offsets[l]; + + let h_start = base >> (l + 2u); + let h_end = (base + real_size - 1u) >> (l + 2u); + let h_count = h_end - h_start + 1u; + + for (var h_local: u32 = 0u; h_local < h_count; h_local = h_local + 1u) { + let h_global = h_start + h_local; + if (h_global >= layer_size) { + continue; + } + let t = trace_input.twiddles.line_twiddles_flat[layer_offset + h_global]; + + // Convert global base index → scratch-local index + let base_idx_local = (h_global << (l + 2u)) - base; + + for (var s: u32 = 0u; s < step; s = s + 1u) { + let idx0 = base_idx_local + s; + let idx1 = idx0 + step; + var v0 = scratch[idx0]; + var v1 = scratch[idx1]; + butterfly(&v0, &v1, t); + scratch[idx0] = v0; + scratch[idx1] = v1; + } + } + if (l == 0u) { + break; + } + l = l - 1u; + } + + //-------------------------------------------------------------------- + // 3-C. 
Circle-twiddle (step = 1) in scratch + //-------------------------------------------------------------------- + // Treat (idx0, idx1) as even/odd pair; global pair index = (base+idx0)/2 + for (var idx0: u32 = 0u; idx0 + 1u < real_size; idx0 = idx0 + 2u) { + let idx1 = idx0 + 1u; + let pair_global = (base + idx0) >> 1u; + let t = trace_input.twiddles.circle_twiddles[pair_global]; + + var v0 = scratch[idx0]; + var v1 = scratch[idx1]; + butterfly(&v0, &v1, t); + scratch[idx0] = v0; + scratch[idx1] = v1; + } + + //-------------------------------------------------------------------- + // 3-D. Copy scratch → storage + //-------------------------------------------------------------------- + for (var i: u32 = 0u; i < real_size; i = i + 1u) { + trace_output.extended_trace[poly_id].data[base + i] = scratch[i]; + } + } +} diff --git a/crates/examples/src/poseidon/web/gpu_channels.rs b/crates/examples/src/poseidon/web/gpu_channels.rs new file mode 100644 index 000000000..f5c6eca01 --- /dev/null +++ b/crates/examples/src/poseidon/web/gpu_channels.rs @@ -0,0 +1,37 @@ +use std::sync::OnceLock; // ← 표준 라이브러리만! + +use flume::{Receiver, Sender}; + +use crate::poseidon::web::{ComputeCompositionPolynomialInput, ComputeCompositionPolynomialOutput}; + +#[derive(Debug)] +pub struct Channels { + pub tx: Sender>, + pub rx: Receiver>, +} + +thread_local! 
{ + static CHANNELS: OnceLock = OnceLock::new(); +} + +#[allow(dead_code)] +pub fn init_gpu_channels( + tx: Sender>, + rx: Receiver>, +) { + CHANNELS.with(|cell| { + cell.set(Channels { tx, rx }) + .expect("gpu_channels::init called twice"); + }); +} + +#[allow(dead_code)] +pub fn with(f: F) -> R +where + F: FnOnce(&Channels) -> R, +{ + CHANNELS.with(|cell| { + let chans = cell.get().expect("gpu channels not initialised"); + f(chans) + }) +} diff --git a/crates/examples/src/poseidon/web/gpu_types.rs b/crates/examples/src/poseidon/web/gpu_types.rs new file mode 100644 index 000000000..70962b19e --- /dev/null +++ b/crates/examples/src/poseidon/web/gpu_types.rs @@ -0,0 +1,59 @@ +use stwo_prover::core::backend::web::webgpu::qm31::{GpuM31, GpuQM31}; + +use super::constants::*; + +#[derive(Debug, Clone, Copy)] +#[repr(C)] +pub struct GpuOriginalColumn { + pub coeffs: [GpuM31; (N_LANES * N_ORIGINAL_ROWS) as usize], +} + +#[derive(Debug, Clone, Copy)] +#[repr(C)] +pub struct GpuExtendedColumn { + pub data: [GpuM31; (N_LANES * N_EXTENDED_ROWS) as usize], +} + +#[derive(Debug, Clone, Copy)] +#[repr(C)] +pub struct Twiddles { + pub circle_twiddles: [GpuM31; N_CIRCLE_TWIDDLES_SIZE as usize], + pub circle_twiddles_size: u32, + pub line_twiddles_flat: [GpuM31; N_LINE_TWIDDLES_FLAT_SIZE as usize], + pub line_twiddles_layer_count: u32, + pub line_twiddles_sizes: [u32; N_LINE_TWIDDLES_SIZE as usize], + pub line_twiddles_offsets: [u32; N_LINE_TWIDDLES_SIZE as usize], +} + +#[derive(Debug, Clone, Copy)] +#[repr(C)] +pub struct GpuLookupElements { + pub z: GpuQM31, + pub alpha: GpuQM31, + pub alpha_powers: [GpuQM31; N_STATE as usize], +} + +#[derive(Debug, Clone, Copy)] +#[repr(C, align(16))] +pub struct ComputeCompositionPolynomialInput { + pub original_trace: [GpuOriginalColumn; N_ORIGINAL_TRACE_COLUMNS as usize], + pub twiddles: Twiddles, + pub denom_inv: [GpuM31; 4], + pub random_coeff_powers: [GpuQM31; N_CONSTRAINTS as usize], + pub lookup_elements: GpuLookupElements, + pub 
trace_domain_log_size: u32, + pub eval_domain_log_size: u32, + pub cumsum_shift: GpuQM31, +} + +#[derive(Debug, Clone, Copy)] +#[repr(C, align(16))] +pub struct ComputeCompositionPolynomialOutput { + pub poly: [[GpuQM31; N_LANES as usize]; N_EXTENDED_ROWS as usize], +} + +#[allow(dead_code)] +#[derive(Debug, Clone, Copy)] +pub struct ExtendTraceOutput { + pub extended_trace: [GpuExtendedColumn; N_ORIGINAL_TRACE_COLUMNS as usize], +} diff --git a/crates/examples/src/poseidon/web/mod.rs b/crates/examples/src/poseidon/web/mod.rs new file mode 100644 index 000000000..66972ad75 --- /dev/null +++ b/crates/examples/src/poseidon/web/mod.rs @@ -0,0 +1,12 @@ +pub mod constants; +pub mod eval_composition_poly; + +#[cfg(all(target_family = "wasm", not(target_os = "wasi")))] +pub mod gpu_channels; +pub mod gpu_types; +#[cfg(all(target_family = "wasm", not(target_os = "wasi")))] +pub mod runner; +pub mod serialization; + +use constants::*; +pub use gpu_types::*; diff --git a/crates/examples/src/poseidon/web/runner.rs b/crates/examples/src/poseidon/web/runner.rs new file mode 100644 index 000000000..0648de9a5 --- /dev/null +++ b/crates/examples/src/poseidon/web/runner.rs @@ -0,0 +1,20 @@ +use web_sys::console; + +use super::eval_composition_poly::{compute_composition_polynomial_wgpu, GpuContext}; +use super::{ComputeCompositionPolynomialInput, ComputeCompositionPolynomialOutput}; + +#[allow(dead_code)] +pub async fn runner_eval_composition_polynomial( + request_rx: flume::Receiver>, + response_tx: flume::Sender>, +) { + let gpu = GpuContext::new().await; + + let input_data = request_rx.recv_async().await.unwrap(); + + console::time_with_label("wgpu-runner-timer"); + let output_data = compute_composition_polynomial_wgpu(input_data, &gpu).await; + + response_tx.send(output_data).unwrap(); + console::time_end_with_label("wgpu-runner-timer"); +} diff --git a/crates/examples/src/poseidon/web/serialization.rs b/crates/examples/src/poseidon/web/serialization.rs new file mode 100644 
index 000000000..55f174220 --- /dev/null +++ b/crates/examples/src/poseidon/web/serialization.rs @@ -0,0 +1,55 @@ +use std::mem::MaybeUninit; +use std::ptr; + +use stwo_prover::core::backend::web::webgpu::ByteSerialize; + +use crate::poseidon::web::{ + ComputeCompositionPolynomialInput, ComputeCompositionPolynomialOutput, GpuExtendedColumn, + GpuLookupElements, GpuOriginalColumn, +}; +use crate::poseidon::PoseidonElements; + +impl ByteSerialize for GpuExtendedColumn {} +impl ByteSerialize for GpuOriginalColumn {} +impl ByteSerialize for ComputeCompositionPolynomialOutput {} +impl ByteSerialize for ComputeCompositionPolynomialInput {} + +#[allow(dead_code)] +impl ComputeCompositionPolynomialInput { + pub fn from_bytes(bytes: &[u8]) -> Self { + assert_eq!(bytes.len(), std::mem::size_of::()); + unsafe { std::ptr::read_unaligned(bytes.as_ptr() as *const Self) } + } +} + +#[allow(dead_code)] +impl ComputeCompositionPolynomialOutput { + /// Create directly on the heap without an intermediate stack copy. 
+ pub fn from_bytes_box(bytes: &[u8]) -> Box { + assert_eq!(bytes.len(), core::mem::size_of::()); + + let boxed_uninit = Box::>::new_uninit(); + let raw_ptr = Box::into_raw(boxed_uninit) as *mut Self as *mut u8; + unsafe { + ptr::copy_nonoverlapping(bytes.as_ptr(), raw_ptr, bytes.len()); + Box::from_raw(raw_ptr as *mut Self) + } + } +} + +impl From<&PoseidonElements> for GpuLookupElements { + fn from(value: &PoseidonElements) -> Self { + GpuLookupElements { + z: value.0.z.into(), + alpha: value.0.alpha.into(), + alpha_powers: value + .0 + .alpha_powers + .iter() + .map(|&x| x.into()) + .collect::>() + .try_into() + .unwrap(), + } + } +} diff --git a/crates/prover/Cargo.toml b/crates/prover/Cargo.toml index 48e57d514..b5b6e8218 100644 --- a/crates/prover/Cargo.toml +++ b/crates/prover/Cargo.toml @@ -28,6 +28,9 @@ tracing.workspace = true rayon = { workspace = true, optional = true } serde.workspace = true tracing-subscriber.workspace = true +wgpu.workspace = true +flume.workspace = true +pollster.workspace = true [dev-dependencies] aligned = "0.4.2" diff --git a/crates/prover/src/core/backend/cpu/circle.rs b/crates/prover/src/core/backend/cpu/circle.rs index d28180bf7..46d3b91f9 100644 --- a/crates/prover/src/core/backend/cpu/circle.rs +++ b/crates/prover/src/core/backend/cpu/circle.rs @@ -206,7 +206,7 @@ fn fft_layer_loop( /// Computes the circle twiddles layer (layer 0) from the first line twiddles layer (layer 1). /// /// Only works for line twiddles generated from a domain with size `>4`. -fn circle_twiddles_from_line_twiddles( +pub fn circle_twiddles_from_line_twiddles( first_line_twiddles: &[BaseField], ) -> impl Iterator + '_ { // The twiddles for layer 0 can be computed from the twiddles for layer 1. 
diff --git a/crates/prover/src/core/backend/mod.rs b/crates/prover/src/core/backend/mod.rs index a8f5f717f..6837abfd3 100644 --- a/crates/prover/src/core/backend/mod.rs +++ b/crates/prover/src/core/backend/mod.rs @@ -15,6 +15,7 @@ use super::vcs::ops::MerkleOps; pub mod cpu; pub mod simd; +pub mod web; pub trait Backend: Copy diff --git a/crates/prover/src/core/backend/simd/bit_reverse.rs b/crates/prover/src/core/backend/simd/bit_reverse.rs index fe27a149c..4c487a530 100644 --- a/crates/prover/src/core/backend/simd/bit_reverse.rs +++ b/crates/prover/src/core/backend/simd/bit_reverse.rs @@ -110,7 +110,7 @@ pub fn bit_reverse_m31(data: &mut [PackedBaseField]) { } /// Bit reverses 256 M31 values, packed in 16 words of 16 elements each. -fn bit_reverse16(mut data: [PackedBaseField; 16]) -> [PackedBaseField; 16] { +pub fn bit_reverse16(mut data: [PackedBaseField; 16]) -> [PackedBaseField; 16] { // Denote the index of each element in the 16 packed M31 words as abcd:0123, // where abcd is the index of the packed word and 0123 is the index of the element in the word. // Bit reversal is achieved by applying the following permutation to the index for 4 times: diff --git a/crates/prover/src/core/backend/simd/blake2s.rs b/crates/prover/src/core/backend/simd/blake2s.rs index 38cf77ad8..d767b6f19 100644 --- a/crates/prover/src/core/backend/simd/blake2s.rs +++ b/crates/prover/src/core/backend/simd/blake2s.rs @@ -307,7 +307,7 @@ pub fn round(v: &mut [u32x16; 16], m: [u32x16; 16], r: usize) { /// Transposes input chunks (16 chunks of 16 `u32`s each), to get 16 `u32x16`, each /// representing 16 packed instances of a message word. 
-fn transpose_msgs(mut data: [u32x16; 16]) -> [u32x16; 16] { +pub fn transpose_msgs(mut data: [u32x16; 16]) -> [u32x16; 16] { // Index abcd:xyzw, refers to a specific word in data as follows: // abcd - chunk index (in base 2) // xyzw - word offset (in base 2) @@ -331,7 +331,7 @@ fn transpose_msgs(mut data: [u32x16; 16]) -> [u32x16; 16] { data } -fn untranspose_states(mut states: [u32x16; 8]) -> [u32x16; 8] { +pub fn untranspose_states(mut states: [u32x16; 8]) -> [u32x16; 8] { // Index abc:xyzw, refers to a specific word in data as follows: // abc - chunk index (in base 2) // xyzw - word offset (in base 2) diff --git a/crates/prover/src/core/backend/simd/circle.rs b/crates/prover/src/core/backend/simd/circle.rs index 0538a0f95..a46f8c6bd 100644 --- a/crates/prover/src/core/backend/simd/circle.rs +++ b/crates/prover/src/core/backend/simd/circle.rs @@ -386,7 +386,7 @@ fn compute_coset_twiddles(coset: Coset, twiddles: &mut Vec) { } } -fn slow_eval_at_point( +pub fn slow_eval_at_point( poly: &CirclePoly, point: CirclePoint, ) -> SecureField { diff --git a/crates/prover/src/core/backend/simd/fri.rs b/crates/prover/src/core/backend/simd/fri.rs index 7b1fa2923..0da7b3404 100644 --- a/crates/prover/src/core/backend/simd/fri.rs +++ b/crates/prover/src/core/backend/simd/fri.rs @@ -71,7 +71,7 @@ impl FriOps for SimdBackend { fold_circle_into_line(&mut cpu_dst, &src.to_cpu(), alpha); *dst = LineEvaluation::new( cpu_dst.domain(), - SecureColumnByCoords::from_cpu(cpu_dst.values), + SecureColumnByCoords::::from_cpu(cpu_dst.values), ); return; } diff --git a/crates/prover/src/core/backend/web/accumulation.rs b/crates/prover/src/core/backend/web/accumulation.rs new file mode 100644 index 000000000..67935270e --- /dev/null +++ b/crates/prover/src/core/backend/web/accumulation.rs @@ -0,0 +1,39 @@ +use super::WebBackend; +use crate::core::air::accumulation::AccumulationOps; +use crate::core::backend::simd::SimdBackend; +use crate::core::fields::qm31::SecureField; +use 
crate::core::secure_column::SecureColumnByCoords; + +impl AccumulationOps for WebBackend { + fn accumulate(column: &mut SecureColumnByCoords, other: &SecureColumnByCoords) { + SimdBackend::accumulate(column.as_mut(), other.as_ref()); + } + + fn generate_secure_powers(felt: SecureField, n_powers: usize) -> Vec { + SimdBackend::generate_secure_powers(felt, n_powers) + } +} + +#[cfg(test)] +mod tests { + use crate::core::air::accumulation::AccumulationOps; + use crate::core::backend::cpu::CpuBackend; + use crate::core::backend::simd::SimdBackend; + use crate::qm31; + + #[test] + fn test_generate_secure_powers_simd() { + let felt = qm31!(1, 2, 3, 4); + let n_powers_vec = [0, 16, 100]; + + n_powers_vec.iter().for_each(|&n_powers| { + let expected = ::generate_secure_powers(felt, n_powers); + let actual = ::generate_secure_powers(felt, n_powers); + assert_eq!( + expected, actual, + "Error generating secure powers in n_powers = {}.", + n_powers + ); + }); + } +} diff --git a/crates/prover/src/core/backend/web/bit_reverse.rs b/crates/prover/src/core/backend/web/bit_reverse.rs new file mode 100644 index 000000000..683815843 --- /dev/null +++ b/crates/prover/src/core/backend/web/bit_reverse.rs @@ -0,0 +1,99 @@ +use super::WebBackend; +use crate::core::backend::cpu::bit_reverse as cpu_bit_reverse; +use crate::core::backend::simd::bit_reverse::bit_reverse_m31; +use crate::core::backend::simd::column::{BaseColumn, SecureColumn}; +use crate::core::backend::ColumnOps; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; + +const VEC_BITS: u32 = 4; + +const W_BITS: u32 = 3; + +pub const MIN_LOG_SIZE: u32 = 2 * W_BITS + VEC_BITS; + +impl ColumnOps for WebBackend { + type Column = BaseColumn; + + fn bit_reverse_column(column: &mut Self::Column) { + // Fallback to cpu bit_reverse. 
+ if column.data.len().ilog2() < MIN_LOG_SIZE { + cpu_bit_reverse(column.as_mut_slice()); + return; + } + + bit_reverse_m31(&mut column.data); + } +} + +impl ColumnOps for WebBackend { + type Column = SecureColumn; + + fn bit_reverse_column(_column: &mut SecureColumn) { + todo!() + } +} + +#[cfg(test)] +mod tests { + use itertools::Itertools; + + use super::MIN_LOG_SIZE; + use crate::core::backend::cpu::bit_reverse as cpu_bit_reverse; + use crate::core::backend::simd::bit_reverse::{bit_reverse16, bit_reverse_m31}; + use crate::core::backend::simd::column::BaseColumn; + use crate::core::backend::simd::m31::{PackedM31, N_LANES}; + use crate::core::backend::web::WebBackend; + use crate::core::backend::{Column, ColumnOps}; + use crate::core::fields::m31::BaseField; + + #[test] + fn test_bit_reverse16() { + let values: BaseColumn = (0..N_LANES * 16).map(BaseField::from).collect(); + let mut expected = values.to_cpu(); + cpu_bit_reverse(&mut expected); + + let res = bit_reverse16(values.data.try_into().unwrap()); + + assert_eq!(res.map(PackedM31::to_array).as_flattened(), expected); + } + + #[test] + fn bit_reverse_m31_works() { + const SIZE: usize = 1 << 15; + let data: Vec<_> = (0..SIZE).map(BaseField::from).collect(); + let mut expected = data.clone(); + cpu_bit_reverse(&mut expected); + + let mut res: BaseColumn = data.into_iter().collect(); + bit_reverse_m31(&mut res.data[..]); + + assert_eq!(res.to_cpu(), expected); + } + + #[test] + fn bit_reverse_small_column_works() { + const LOG_SIZE: u32 = MIN_LOG_SIZE - 1; + let column = (0..1 << LOG_SIZE).map(BaseField::from).collect_vec(); + let mut expected = column.clone(); + cpu_bit_reverse(&mut expected); + + let mut res = column.iter().copied().collect::(); + >::bit_reverse_column(&mut res); + + assert_eq!(res.to_cpu(), expected); + } + + #[test] + fn bit_reverse_large_column_works() { + const LOG_SIZE: u32 = MIN_LOG_SIZE; + let column = (0..1 << LOG_SIZE).map(BaseField::from).collect_vec(); + let mut expected = 
column.clone(); + cpu_bit_reverse(&mut expected); + + let mut res = column.iter().copied().collect::(); + >::bit_reverse_column(&mut res); + + assert_eq!(res.to_cpu(), expected); + } +} diff --git a/crates/prover/src/core/backend/web/blake2s.rs b/crates/prover/src/core/backend/web/blake2s.rs new file mode 100644 index 000000000..d49c23e95 --- /dev/null +++ b/crates/prover/src/core/backend/web/blake2s.rs @@ -0,0 +1,105 @@ +use super::utils::transmute_col_refs; +use super::WebBackend; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::{Col, ColumnOps}; +use crate::core::fields::m31::BaseField; +use crate::core::vcs::blake2_hash::Blake2sHash; +use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher; +use crate::core::vcs::ops::MerkleOps; + +impl ColumnOps for WebBackend { + type Column = Vec; + + fn bit_reverse_column(_column: &mut Self::Column) { + unimplemented!() + } +} + +impl MerkleOps for WebBackend { + fn commit_on_layer( + log_size: u32, + prev_layer: Option<&Vec>, + columns: &[&Col], + ) -> Vec { + >::commit_on_layer( + log_size, + prev_layer, + transmute_col_refs(columns), + ) + } +} + +#[cfg(test)] +mod tests { + use std::array; + use std::mem::transmute; + use std::simd::u32x16; + + use aligned::{Aligned, A64}; + + use crate::core::backend::simd::blake2s::{compress16, transpose_msgs, untranspose_states}; + use crate::core::vcs::blake2s_ref::compress; + + #[test] + fn compress16_works() { + let states: Aligned = + Aligned(array::from_fn(|i| array::from_fn(|j| (i + j) as u32))); + let msgs: Aligned = + Aligned(array::from_fn(|i| array::from_fn(|j| (i + j + 20) as u32))); + let count_low = 1; + let count_high = 2; + let lastblock = 3; + let lastnode = 4; + let res_unvectorized = array::from_fn(|i| { + compress( + states[i], msgs[i], count_low, count_high, lastblock, lastnode, + ) + }); + + let res_vectorized: [[u32; 8]; 16] = unsafe { + transmute(untranspose_states(compress16( + transpose_states(transmute::, [u32x16; 8]>( + states, + 
)), + transpose_msgs(transmute::, [u32x16; 16]>( + msgs, + )), + u32x16::splat(count_low), + u32x16::splat(count_high), + u32x16::splat(lastblock), + u32x16::splat(lastnode), + ))) + }; + + assert_eq!(res_vectorized, res_unvectorized); + } + + #[test] + fn untranspose_states_is_transpose_states_inverse() { + let states = array::from_fn(|i| u32x16::from(array::from_fn(|j| (i + j) as u32))); + let transposed_states = transpose_states(states); + + let untrasponsed_transposed_states = untranspose_states(transposed_states); + + assert_eq!(untrasponsed_transposed_states, states) + } + + /// Transposes states, from 8 packed words, to get 16 results, each of size 32B. + fn transpose_states(mut states: [u32x16; 8]) -> [u32x16; 8] { + // Index abc:xyzw, refers to a specific word in data as follows: + // abc - chunk index (in base 2) + // xyzw - word offset (in base 2) + // Transpose by applying 3 times the index permutation: + // abc:xyzw => wab:cxyz + // In other words, rotate the index to the right by 1. 
+ for _ in 0..3 { + let (s0, s4) = states[0].deinterleave(states[1]); + let (s1, s5) = states[2].deinterleave(states[3]); + let (s2, s6) = states[4].deinterleave(states[5]); + let (s3, s7) = states[6].deinterleave(states[7]); + states = [s0, s1, s2, s3, s4, s5, s6, s7]; + } + + states + } +} diff --git a/crates/prover/src/core/backend/web/circle.rs b/crates/prover/src/core/backend/web/circle.rs new file mode 100644 index 000000000..2564d9b43 --- /dev/null +++ b/crates/prover/src/core/backend/web/circle.rs @@ -0,0 +1,172 @@ +use super::WebBackend; +use crate::core::backend::simd::SimdBackend; +use crate::core::circle::{CirclePoint, Coset}; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::poly::circle::{CircleDomain, CircleEvaluation, CirclePoly, PolyOps}; +use crate::core::poly::twiddles::TwiddleTree; +use crate::core::poly::BitReversedOrder; + +impl PolyOps for WebBackend { + // The twiddles type is i32, and not BaseField. This is because the fast AVX mul implementation + // requires one of the numbers to be shifted left by 1 bit. This is not a reduced + // representation of the field. 
+ type Twiddles = Vec; + + fn interpolate( + eval: CircleEvaluation, + twiddles: &TwiddleTree, + ) -> CirclePoly { + SimdBackend::interpolate(eval.into(), twiddles.as_ref()).into() + } + + fn eval_at_point(poly: &CirclePoly, point: CirclePoint) -> SecureField { + SimdBackend::eval_at_point(poly.as_ref(), point) + } + + fn extend(poly: &CirclePoly, log_size: u32) -> CirclePoly { + SimdBackend::extend(poly.as_ref(), log_size).into() + } + + fn evaluate( + poly: &CirclePoly, + domain: CircleDomain, + twiddles: &TwiddleTree, + ) -> CircleEvaluation { + SimdBackend::evaluate(poly.as_ref(), domain, twiddles.as_ref()).into() + } + + fn precompute_twiddles(coset: Coset) -> TwiddleTree { + SimdBackend::precompute_twiddles(coset).into() + } +} + +#[cfg(test)] +mod tests { + use itertools::Itertools; + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + + use super::WebBackend; + use crate::core::backend::simd::circle::slow_eval_at_point; + use crate::core::backend::simd::fft::{CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE}; + use crate::core::backend::{Column, CpuBackend}; + use crate::core::circle::CirclePoint; + use crate::core::fields::m31::BaseField; + use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, CirclePoly, PolyOps}; + use crate::core::poly::{BitReversedOrder, NaturalOrder}; + + #[test] + fn test_interpolate_and_eval() { + for log_size in MIN_FFT_LOG_SIZE..CACHED_FFT_LOG_SIZE + 4 { + let domain = CanonicCoset::new(log_size).circle_domain(); + let evaluation = CircleEvaluation::::new( + domain, + (0..1 << log_size).map(BaseField::from).collect(), + ); + + let poly = evaluation.clone().interpolate(); + let evaluation2 = poly.evaluate(domain); + + assert_eq!(evaluation.values.to_cpu(), evaluation2.values.to_cpu()); + } + } + + #[test] + fn test_eval_extension() { + for log_size in MIN_FFT_LOG_SIZE..CACHED_FFT_LOG_SIZE + 2 { + let domain = CanonicCoset::new(log_size).circle_domain(); + let domain_ext = CanonicCoset::new(log_size + 2).circle_domain(); + 
let evaluation = CircleEvaluation::::new( + domain, + (0..1 << log_size).map(BaseField::from).collect(), + ); + let poly = evaluation.clone().interpolate(); + + let evaluation2 = poly.evaluate(domain_ext); + + assert_eq!( + poly.extend(log_size + 2).coeffs.to_cpu(), + evaluation2.interpolate().coeffs.to_cpu() + ); + } + } + + #[test] + fn test_eval_at_point() { + for log_size in MIN_FFT_LOG_SIZE + 1..CACHED_FFT_LOG_SIZE + 4 { + let domain = CanonicCoset::new(log_size).circle_domain(); + let evaluation = CircleEvaluation::::new( + domain, + (0..1 << log_size).map(BaseField::from).collect(), + ); + let poly = evaluation.bit_reverse().interpolate(); + for i in [0, 1, 3, 1 << (log_size - 1), 1 << (log_size - 2)] { + let p = domain.at(i); + + let eval = poly.eval_at_point(p.into_ef()); + + assert_eq!( + eval, + BaseField::from(i).into(), + "log_size={log_size}, i={i}" + ); + } + } + } + + #[test] + fn test_circle_poly_extend() { + for log_size in MIN_FFT_LOG_SIZE..CACHED_FFT_LOG_SIZE + 2 { + let poly = + CirclePoly::::new((0..1 << log_size).map(BaseField::from).collect()); + let eval0 = poly.evaluate(CanonicCoset::new(log_size + 2).circle_domain()); + + let eval1 = poly + .extend(log_size + 2) + .evaluate(CanonicCoset::new(log_size + 2).circle_domain()); + + assert_eq!(eval0.values.to_cpu(), eval1.values.to_cpu()); + } + } + + #[test] + fn test_eval_securefield() { + let mut rng = SmallRng::seed_from_u64(0); + for log_size in MIN_FFT_LOG_SIZE..CACHED_FFT_LOG_SIZE + 2 { + let domain = CanonicCoset::new(log_size).circle_domain(); + let evaluation = CircleEvaluation::::new( + domain, + (0..1 << log_size).map(BaseField::from).collect(), + ); + let poly = evaluation.bit_reverse().interpolate(); + let x = rng.gen(); + let y = rng.gen(); + let p = CirclePoint { x, y }; + + let eval = PolyOps::eval_at_point(&poly, p); + + assert_eq!( + eval, + slow_eval_at_point(&poly.as_ref(), p), + "log_size = {log_size}" + ); + } + } + + #[test] + fn test_optimized_precompute_twiddles() { + 
let coset = CanonicCoset::new(10).half_coset(); + let twiddles = WebBackend::precompute_twiddles(coset); + let expected_twiddles = CpuBackend::precompute_twiddles(coset); + + assert_eq!( + twiddles.twiddles, + expected_twiddles + .twiddles + .iter() + .map(|x| x.0 * 2) + .collect_vec() + ); + } +} diff --git a/crates/prover/src/core/backend/web/column.rs b/crates/prover/src/core/backend/web/column.rs new file mode 100644 index 000000000..370173b88 --- /dev/null +++ b/crates/prover/src/core/backend/web/column.rs @@ -0,0 +1,12 @@ +use super::WebBackend; +use crate::core::backend::simd::column::BaseColumn; +use crate::core::backend::CpuBackend; +use crate::core::secure_column::SecureColumnByCoords; + +impl SecureColumnByCoords { + pub fn from_cpu(cpu: SecureColumnByCoords) -> Self { + Self { + columns: cpu.columns.map(BaseColumn::from_cpu), + } + } +} diff --git a/crates/prover/src/core/backend/web/fri.rs b/crates/prover/src/core/backend/web/fri.rs new file mode 100644 index 000000000..418afc271 --- /dev/null +++ b/crates/prover/src/core/backend/web/fri.rs @@ -0,0 +1,141 @@ +use super::WebBackend; +use crate::core::backend::simd::SimdBackend; +use crate::core::fields::qm31::SecureField; +use crate::core::fri::FriOps; +use crate::core::poly::circle::SecureEvaluation; +use crate::core::poly::line::LineEvaluation; +use crate::core::poly::twiddles::TwiddleTree; +use crate::core::poly::BitReversedOrder; + +impl FriOps for WebBackend { + fn fold_line( + eval: &LineEvaluation, + alpha: SecureField, + twiddles: &TwiddleTree, + ) -> LineEvaluation { + SimdBackend::fold_line(eval.as_ref(), alpha, twiddles.as_ref()).into() + } + + fn fold_circle_into_line( + dst: &mut LineEvaluation, + src: &SecureEvaluation, + alpha: SecureField, + twiddles: &TwiddleTree, + ) { + SimdBackend::fold_circle_into_line(dst.as_mut(), src.as_ref(), alpha, twiddles.as_ref()); + } + + fn decompose( + eval: &SecureEvaluation, + ) -> (SecureEvaluation, SecureField) { + let (g, lambda) = 
SimdBackend::decompose(eval.as_ref()); + (g.into(), lambda) + } +} + +#[cfg(test)] +mod tests { + use itertools::Itertools; + use num_traits::One; + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + + use crate::core::backend::simd::column::BaseColumn; + use crate::core::backend::simd::SimdBackend; + use crate::core::backend::{Column, CpuBackend}; + use crate::core::fields::m31::BaseField; + use crate::core::fields::qm31::SecureField; + use crate::core::fri::FriOps; + use crate::core::poly::circle::{CanonicCoset, CirclePoly, PolyOps, SecureEvaluation}; + use crate::core::poly::line::{LineDomain, LineEvaluation}; + use crate::core::poly::BitReversedOrder; + use crate::core::secure_column::SecureColumnByCoords; + use crate::qm31; + + #[test] + fn test_fold_line() { + const LOG_SIZE: u32 = 7; + let mut rng = SmallRng::seed_from_u64(0); + let values = (0..1 << LOG_SIZE).map(|_| rng.gen()).collect_vec(); + let alpha = qm31!(1, 3, 5, 7); + let domain = LineDomain::new(CanonicCoset::new(LOG_SIZE + 1).half_coset()); + let cpu_fold = CpuBackend::fold_line( + &LineEvaluation::new(domain, values.iter().copied().collect()), + alpha, + &CpuBackend::precompute_twiddles(domain.coset()), + ); + + let avx_fold = SimdBackend::fold_line( + &LineEvaluation::new(domain, values.iter().copied().collect()), + alpha, + &SimdBackend::precompute_twiddles(domain.coset()), + ); + + assert_eq!(cpu_fold.values.to_vec(), avx_fold.values.to_vec()); + } + + #[test] + fn test_fold_circle_into_line() { + const LOG_SIZE: u32 = 7; + let values: Vec = (0..(1 << LOG_SIZE)) + .map(|i| qm31!(4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3)) + .collect(); + let alpha = qm31!(1, 3, 5, 7); + let circle_domain = CanonicCoset::new(LOG_SIZE).circle_domain(); + let line_domain = LineDomain::new(circle_domain.half_coset); + let mut cpu_fold = LineEvaluation::new( + line_domain, + SecureColumnByCoords::zeros(1 << (LOG_SIZE - 1)), + ); + CpuBackend::fold_circle_into_line( + &mut cpu_fold, + 
&SecureEvaluation::new(circle_domain, values.iter().copied().collect()), + alpha, + &CpuBackend::precompute_twiddles(line_domain.coset()), + ); + + let mut simd_fold = LineEvaluation::new( + line_domain, + SecureColumnByCoords::zeros(1 << (LOG_SIZE - 1)), + ); + SimdBackend::fold_circle_into_line( + &mut simd_fold, + &SecureEvaluation::new(circle_domain, values.iter().copied().collect()), + alpha, + &SimdBackend::precompute_twiddles(line_domain.coset()), + ); + + assert_eq!(cpu_fold.values.to_vec(), simd_fold.values.to_vec()); + } + + #[test] + fn decomposition_test() { + const DOMAIN_LOG_SIZE: u32 = 5; + const DOMAIN_LOG_HALF_SIZE: u32 = DOMAIN_LOG_SIZE - 1; + let s = CanonicCoset::new(DOMAIN_LOG_SIZE); + let domain = s.circle_domain(); + let mut coeffs = BaseColumn::zeros(1 << DOMAIN_LOG_SIZE); + // Polynomial is out of FFT space. + coeffs.as_mut_slice()[1 << DOMAIN_LOG_HALF_SIZE] = BaseField::one(); + let poly = CirclePoly::::new(coeffs); + let values = poly.evaluate(domain); + let avx_column = SecureColumnByCoords:: { + columns: [ + values.values.clone(), + values.values.clone(), + values.values.clone(), + values.values.clone(), + ], + }; + let avx_eval = SecureEvaluation::new(domain, avx_column.clone()); + let cpu_eval = + SecureEvaluation::::new(domain, avx_eval.values.to_cpu()); + let (cpu_g, cpu_lambda) = CpuBackend::decompose(&cpu_eval); + let (avx_g, avx_lambda) = SimdBackend::decompose(&avx_eval); + + assert_eq!(avx_lambda, cpu_lambda); + for i in 0..1 << DOMAIN_LOG_SIZE { + assert_eq!(avx_g.values.at(i), cpu_g.values.at(i)); + } + } +} diff --git a/crates/prover/src/core/backend/web/grind.rs b/crates/prover/src/core/backend/web/grind.rs new file mode 100644 index 000000000..4d94e9112 --- /dev/null +++ b/crates/prover/src/core/backend/web/grind.rs @@ -0,0 +1,28 @@ +use super::WebBackend; +use crate::core::backend::simd::SimdBackend; +use crate::core::channel::Blake2sChannel; +#[cfg(not(target_arch = "wasm32"))] +use crate::core::channel::{Channel, 
Poseidon252Channel}; +use crate::core::proof_of_work::GrindOps; + +impl GrindOps for WebBackend { + fn grind(channel: &Blake2sChannel, pow_bits: u32) -> u64 { + SimdBackend::grind(channel, pow_bits) + } +} + +// TODO(shahars): This is a naive implementation. Optimize it. +#[cfg(not(target_arch = "wasm32"))] +impl GrindOps for WebBackend { + fn grind(channel: &Poseidon252Channel, pow_bits: u32) -> u64 { + let mut nonce = 0; + loop { + let mut channel = channel.clone(); + channel.mix_u64(nonce); + if channel.trailing_zeros() >= pow_bits { + return nonce; + } + nonce += 1; + } + } +} diff --git a/crates/prover/src/core/backend/web/lookups/gkr.rs b/crates/prover/src/core/backend/web/lookups/gkr.rs new file mode 100644 index 000000000..2771a4edc --- /dev/null +++ b/crates/prover/src/core/backend/web/lookups/gkr.rs @@ -0,0 +1,202 @@ +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::web::WebBackend; +use crate::core::fields::qm31::SecureField; +use crate::core::lookups::gkr_prover::{GkrMultivariatePolyOracle, GkrOps, Layer}; +use crate::core::lookups::mle::Mle; +use crate::core::lookups::utils::UnivariatePoly; + +impl GkrOps for WebBackend { + #[allow(clippy::uninit_vec)] + fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle { + SimdBackend::gen_eq_evals(y, v).into() + } + + fn next_layer(layer: &Layer) -> Layer { + SimdBackend::next_layer(layer.as_ref()).into() + } + + fn sum_as_poly_in_first_variable( + h: &GkrMultivariatePolyOracle<'_, Self>, + claim: SecureField, + ) -> UnivariatePoly { + SimdBackend::sum_as_poly_in_first_variable(h.as_ref(), claim) + } +} + +#[cfg(test)] +mod tests { + use std::iter::zip; + + use num_traits::One; + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + + use super::WebBackend; + // use crate::core::backend::simd::SimdBackend; + use crate::core::backend::{Column, CpuBackend}; + use crate::core::channel::Channel; + use crate::core::fields::m31::BaseField; + use crate::core::fields::qm31::SecureField; + 
use crate::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer}; + use crate::core::lookups::gkr_verifier::{partially_verify_batch, Gate, GkrArtifact, GkrError}; + use crate::core::lookups::mle::Mle; + use crate::core::lookups::utils::Fraction; + use crate::core::test_utils::test_channel; + + #[test] + fn gen_eq_evals_matches_cpu() { + let two = BaseField::from(2).into(); + let y = [7, 3, 5, 6, 1, 1, 9].map(|v| BaseField::from(v).into()); + let eq_evals_cpu = CpuBackend::gen_eq_evals(&y, two); + + let eq_evals_simd = WebBackend::gen_eq_evals(&y, two); + + assert_eq!(eq_evals_simd.to_cpu(), *eq_evals_cpu); + } + + #[test] + fn gen_eq_evals_with_small_assignment_matches_cpu() { + let two = BaseField::from(2).into(); + let y = [7, 3, 5].map(|v| BaseField::from(v).into()); + let eq_evals_cpu = CpuBackend::gen_eq_evals(&y, two); + + let eq_evals_simd = WebBackend::gen_eq_evals(&y, two); + + assert_eq!(eq_evals_simd.to_cpu(), *eq_evals_cpu); + } + + #[test] + fn grand_product_works() -> Result<(), GkrError> { + const N: usize = 1 << 8; + let values = test_channel().draw_secure_felts(N); + let product = values.iter().product(); + let col = Mle::::new(values.into_iter().collect()); + let input_layer = Layer::GrandProduct(col.clone()); + let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]); + + let GkrArtifact { + ood_point, + claims_to_verify_by_instance, + n_variables_by_instance: _, + } = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?; + + assert_eq!(proof.output_claims_by_instance, [vec![product]]); + assert_eq!( + claims_to_verify_by_instance, + [vec![col.eval_at_point(&ood_point)]] + ); + Ok(()) + } + + #[test] + fn logup_with_generic_trace_works() -> Result<(), GkrError> { + const N: usize = 1 << 8; + let mut rng = SmallRng::seed_from_u64(0); + let numerators = (0..N).map(|_| rng.gen()).collect::>(); + let denominators = (0..N).map(|_| rng.gen()).collect::>(); + let sum = zip(&numerators, &denominators) + .map(|(&n, 
&d)| Fraction::new(n, d)) + .sum::>(); + let numerators = Mle::::new(numerators.into_iter().collect()); + let denominators = Mle::::new(denominators.into_iter().collect()); + let input_layer = Layer::LogUpGeneric { + numerators: numerators.clone(), + denominators: denominators.clone(), + }; + let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]); + + let GkrArtifact { + ood_point, + claims_to_verify_by_instance, + n_variables_by_instance: _, + } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; + + assert_eq!(claims_to_verify_by_instance.len(), 1); + assert_eq!(proof.output_claims_by_instance.len(), 1); + assert_eq!( + claims_to_verify_by_instance[0], + [ + numerators.eval_at_point(&ood_point), + denominators.eval_at_point(&ood_point) + ] + ); + assert_eq!( + proof.output_claims_by_instance[0], + [sum.numerator, sum.denominator] + ); + Ok(()) + } + + #[test] + fn logup_with_multiplicities_trace_works() -> Result<(), GkrError> { + const N: usize = 1 << 8; + let mut rng = SmallRng::seed_from_u64(0); + let numerators = (0..N).map(|_| rng.gen()).collect::>(); + let denominators = (0..N).map(|_| rng.gen()).collect::>(); + let sum = zip(&numerators, &denominators) + .map(|(&n, &d)| Fraction::new(n.into(), d)) + .sum::>(); + let numerators = Mle::::new(numerators.into_iter().collect()); + let denominators = Mle::::new(denominators.into_iter().collect()); + let input_layer = Layer::LogUpMultiplicities { + numerators: numerators.clone(), + denominators: denominators.clone(), + }; + let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]); + + let GkrArtifact { + ood_point, + claims_to_verify_by_instance, + n_variables_by_instance: _, + } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; + + assert_eq!(claims_to_verify_by_instance.len(), 1); + assert_eq!(proof.output_claims_by_instance.len(), 1); + assert_eq!( + claims_to_verify_by_instance[0], + [ + numerators.eval_at_point(&ood_point), + 
denominators.eval_at_point(&ood_point) + ] + ); + assert_eq!( + proof.output_claims_by_instance[0], + [sum.numerator, sum.denominator] + ); + Ok(()) + } + + #[test] + fn logup_with_singles_trace_works() -> Result<(), GkrError> { + const N: usize = 1 << 8; + let mut rng = SmallRng::seed_from_u64(0); + let denominators = (0..N).map(|_| rng.gen()).collect::>(); + let sum = denominators + .iter() + .map(|&d| Fraction::new(SecureField::one(), d)) + .sum::>(); + let denominators = Mle::::new(denominators.into_iter().collect()); + let input_layer = Layer::LogUpSingles { + denominators: denominators.clone(), + }; + let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]); + + let GkrArtifact { + ood_point, + claims_to_verify_by_instance, + n_variables_by_instance: _, + } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; + + assert_eq!(claims_to_verify_by_instance.len(), 1); + assert_eq!(proof.output_claims_by_instance.len(), 1); + assert_eq!( + claims_to_verify_by_instance[0], + [SecureField::one(), denominators.eval_at_point(&ood_point)] + ); + assert_eq!( + proof.output_claims_by_instance[0], + [sum.numerator, sum.denominator] + ); + Ok(()) + } +} diff --git a/crates/prover/src/core/backend/web/lookups/mle.rs b/crates/prover/src/core/backend/web/lookups/mle.rs new file mode 100644 index 000000000..fd3e9f230 --- /dev/null +++ b/crates/prover/src/core/backend/web/lookups/mle.rs @@ -0,0 +1,64 @@ +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::web::WebBackend; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::lookups::mle::{Mle, MleOps}; + +impl MleOps for WebBackend { + fn fix_first_variable( + mle: Mle, + assignment: SecureField, + ) -> Mle { + SimdBackend::fix_first_variable(mle.into(), assignment).into() + } +} + +impl MleOps for WebBackend { + fn fix_first_variable( + mle: Mle, + assignment: SecureField, + ) -> Mle { + 
SimdBackend::fix_first_variable(mle.into(), assignment).into() + } +} + +#[cfg(test)] +mod tests { + use itertools::Itertools; + + use crate::core::backend::web::WebBackend; + use crate::core::backend::{Column, CpuBackend}; + use crate::core::channel::Channel; + use crate::core::fields::m31::BaseField; + use crate::core::fields::qm31::SecureField; + use crate::core::lookups::mle::Mle; + use crate::core::test_utils::test_channel; + + #[test] + fn fix_first_variable_with_secure_field_mle_matches_cpu() { + const N_VARIABLES: u32 = 8; + let values = test_channel().draw_secure_felts(1 << N_VARIABLES); + let mle_simd = Mle::::new(values.iter().copied().collect()); + let mle_cpu = Mle::::new(values); + let random_assignment = SecureField::from_u32_unchecked(7, 12, 3, 2); + let mle_fixed_cpu = mle_cpu.fix_first_variable(random_assignment); + + let mle_fixed_simd = mle_simd.fix_first_variable(random_assignment); + + assert_eq!(mle_fixed_simd.into_evals().to_cpu(), *mle_fixed_cpu) + } + + #[test] + fn fix_first_variable_with_base_field_mle_matches_cpu() { + const N_VARIABLES: u32 = 8; + let values = (0..1 << N_VARIABLES).map(BaseField::from).collect_vec(); + let mle_simd = Mle::::new(values.iter().copied().collect()); + let mle_cpu = Mle::::new(values); + let random_assignment = SecureField::from_u32_unchecked(7, 12, 3, 2); + let mle_fixed_cpu = mle_cpu.fix_first_variable(random_assignment); + + let mle_fixed_simd = mle_simd.fix_first_variable(random_assignment); + + assert_eq!(mle_fixed_simd.into_evals().to_cpu(), *mle_fixed_cpu) + } +} diff --git a/crates/prover/src/core/backend/web/lookups/mod.rs b/crates/prover/src/core/backend/web/lookups/mod.rs new file mode 100644 index 000000000..01b03bd89 --- /dev/null +++ b/crates/prover/src/core/backend/web/lookups/mod.rs @@ -0,0 +1,3 @@ +mod gkr; +mod mle; +mod utils; diff --git a/crates/prover/src/core/backend/web/lookups/utils.rs b/crates/prover/src/core/backend/web/lookups/utils.rs new file mode 100644 index 
000000000..6b30ece64 --- /dev/null +++ b/crates/prover/src/core/backend/web/lookups/utils.rs @@ -0,0 +1,69 @@ +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::web::WebBackend; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::QM31; +use crate::core::lookups::gkr_prover::{GkrMultivariatePolyOracle, Layer}; +use crate::core::lookups::mle::Mle; + +// WARNING: This works because they are literally the same object layout. +// +// The only difference is the backend methods. +// When we implement all methods for WebGPU, +// we will no longer need this to convert back/forth. +impl AsRef> for Layer { + fn as_ref(&self) -> &Layer { + assert_eq!(std::mem::size_of::(), 0); + assert_eq!(std::mem::size_of::(), 0); + unsafe { std::mem::transmute(self) } + } +} + +impl<'a> AsRef> + for GkrMultivariatePolyOracle<'a, WebBackend> +{ + fn as_ref(&self) -> &GkrMultivariatePolyOracle<'a, SimdBackend> { + assert_eq!(std::mem::size_of::(), 0); + assert_eq!(std::mem::size_of::(), 0); + unsafe { std::mem::transmute(self) } + } +} + +impl Into> for Layer { + fn into(self) -> Layer { + assert_eq!(std::mem::size_of::(), 0); + assert_eq!(std::mem::size_of::(), 0); + unsafe { std::mem::transmute(self) } + } +} + +impl Into> for Mle { + fn into(self) -> Mle { + assert_eq!(std::mem::size_of::(), 0); + assert_eq!(std::mem::size_of::(), 0); + unsafe { std::mem::transmute(self) } + } +} + +impl Into> for Mle { + fn into(self) -> Mle { + assert_eq!(std::mem::size_of::(), 0); + assert_eq!(std::mem::size_of::(), 0); + unsafe { std::mem::transmute(self) } + } +} + +impl Into> for Mle { + fn into(self) -> Mle { + assert_eq!(std::mem::size_of::(), 0); + assert_eq!(std::mem::size_of::(), 0); + unsafe { std::mem::transmute(self) } + } +} + +impl Into> for Mle { + fn into(self) -> Mle { + assert_eq!(std::mem::size_of::(), 0); + assert_eq!(std::mem::size_of::(), 0); + unsafe { std::mem::transmute(self) } + } +} diff --git 
a/crates/prover/src/core/backend/web/mod.rs b/crates/prover/src/core/backend/web/mod.rs new file mode 100644 index 000000000..4094152a4 --- /dev/null +++ b/crates/prover/src/core/backend/web/mod.rs @@ -0,0 +1,37 @@ +use serde::{Deserialize, Serialize}; + +use super::{Backend, BackendForChannel}; +use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; +#[cfg(not(target_arch = "wasm32"))] +use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleChannel; + +pub mod accumulation; +pub mod bit_reverse; +pub mod blake2s; +pub mod circle; +// pub mod cm31; +pub mod column; +// pub mod conversion; +// pub mod domain; +// pub mod fft; +pub mod fri; +pub mod grind; +pub mod lookups; +// pub mod m31; +#[cfg(not(target_arch = "wasm32"))] +pub mod poseidon252; +// pub mod prefix_sum; +// pub mod qm31; +pub mod quotients; +pub mod utils; +// pub mod very_packed_m31; +// pub mod prove_poseidon; +pub mod webgpu; + +#[derive(Copy, Clone, Debug, Deserialize, Serialize)] +pub struct WebBackend; + +impl Backend for WebBackend {} +impl BackendForChannel for WebBackend {} +#[cfg(not(target_arch = "wasm32"))] +impl BackendForChannel for WebBackend {} diff --git a/crates/prover/src/core/backend/web/poseidon252.rs b/crates/prover/src/core/backend/web/poseidon252.rs new file mode 100644 index 000000000..b2f4c7bca --- /dev/null +++ b/crates/prover/src/core/backend/web/poseidon252.rs @@ -0,0 +1,31 @@ +use starknet_ff::FieldElement as FieldElement252; + +use super::utils::transmute_col_refs; +use super::WebBackend; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::{Col, ColumnOps}; +use crate::core::fields::m31::BaseField; +use crate::core::vcs::ops::MerkleOps; +use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleHasher; + +impl ColumnOps for WebBackend { + type Column = Vec; + + fn bit_reverse_column(_column: &mut Self::Column) { + unimplemented!() + } +} + +impl MerkleOps for WebBackend { + fn commit_on_layer( + log_size: u32, + prev_layer: Option<&Vec>, 
+ columns: &[&Col], + ) -> Vec { + >::commit_on_layer( + log_size, + prev_layer, + transmute_col_refs(columns), + ) + } +} diff --git a/crates/prover/src/core/backend/web/quotients.rs b/crates/prover/src/core/backend/web/quotients.rs new file mode 100644 index 000000000..dba2bbd80 --- /dev/null +++ b/crates/prover/src/core/backend/web/quotients.rs @@ -0,0 +1,96 @@ +use super::WebBackend; +use crate::core::backend::simd::SimdBackend; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::pcs::quotients::{ColumnSampleBatch, QuotientOps}; +use crate::core::poly::circle::{CircleDomain, CircleEvaluation, SecureEvaluation}; +use crate::core::poly::BitReversedOrder; + +impl QuotientOps for WebBackend { + fn accumulate_quotients( + domain: CircleDomain, + columns: &[&CircleEvaluation], + random_coeff: SecureField, + sample_batches: &[ColumnSampleBatch], + log_blowup_factor: u32, + ) -> SecureEvaluation { + let columns_simd: Vec<&CircleEvaluation> = + columns.iter().map(|c| (*c).as_ref()).collect(); + + SimdBackend::accumulate_quotients( + domain, + columns_simd.as_slice(), + random_coeff, + sample_batches, + log_blowup_factor, + ) + .into() + } +} + +#[cfg(test)] +mod tests { + use itertools::Itertools; + + use super::WebBackend; + use crate::core::backend::simd::column::BaseColumn; + use crate::core::backend::{Column, CpuBackend}; + use crate::core::circle::SECURE_FIELD_CIRCLE_GEN; + use crate::core::fields::m31::BaseField; + use crate::core::pcs::quotients::{ColumnSampleBatch, QuotientOps}; + use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; + use crate::core::poly::BitReversedOrder; + use crate::qm31; + + #[test] + fn test_accumulate_quotients() { + const LOG_SIZE: u32 = 8; + const LOG_BLOWUP_FACTOR: u32 = 1; + let small_domain = CanonicCoset::new(LOG_SIZE).circle_domain(); + let domain = CanonicCoset::new(LOG_SIZE + LOG_BLOWUP_FACTOR).circle_domain(); + let e0: BaseColumn = 
(0..small_domain.size()).map(BaseField::from).collect(); + let e1: BaseColumn = (0..small_domain.size()) + .map(|i| BaseField::from(2 * i)) + .collect(); + let polys = [ + CircleEvaluation::::new(small_domain, e0) + .interpolate(), + CircleEvaluation::::new(small_domain, e1) + .interpolate(), + ]; + let columns = [polys[0].evaluate(domain), polys[1].evaluate(domain)]; + let random_coeff = qm31!(1, 2, 3, 4); + let a = polys[0].eval_at_point(SECURE_FIELD_CIRCLE_GEN); + let b = polys[1].eval_at_point(SECURE_FIELD_CIRCLE_GEN); + let samples = vec![ColumnSampleBatch { + point: SECURE_FIELD_CIRCLE_GEN, + columns_and_values: vec![(0, a), (1, b)], + }]; + let cpu_columns = columns + .iter() + .map(|c| CircleEvaluation::new(c.domain, c.values.to_cpu())) + .collect_vec(); + let cpu_result = CpuBackend::accumulate_quotients( + domain, + &cpu_columns.iter().collect_vec(), + random_coeff, + &samples, + LOG_BLOWUP_FACTOR, + ) + .values + .to_vec(); + + let res = WebBackend::accumulate_quotients( + domain, + &columns.iter().collect_vec(), + random_coeff, + &samples, + LOG_BLOWUP_FACTOR, + ) + .values + .to_cpu() + .to_vec(); + + assert_eq!(res, cpu_result); + } +} diff --git a/crates/prover/src/core/backend/web/utils.rs b/crates/prover/src/core/backend/web/utils.rs new file mode 100644 index 000000000..d7e11c5da --- /dev/null +++ b/crates/prover/src/core/backend/web/utils.rs @@ -0,0 +1,167 @@ +use super::WebBackend; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::{BaseField, Col}; +use crate::core::poly::circle::{CircleEvaluation, CirclePoly, SecureEvaluation}; +use crate::core::poly::line::LineEvaluation; +use crate::core::poly::twiddles::TwiddleTree; +use crate::core::poly::BitReversedOrder; +use crate::core::secure_column::SecureColumnByCoords; + +// WARNING: This works because they are literally the same object layout. +// +// The only difference is the backend methods. 
+// When we implement all methods for WebGPU, +// we will no longer need this to convert back/forth. +pub fn transmute_col_refs<'a>( + input: &'a [&Col], +) -> &'a [&'a Col] { + assert_eq!(std::mem::size_of::(), 0); + assert_eq!(std::mem::size_of::(), 0); + unsafe { + std::mem::transmute::<&'a [&Col], &'a [&Col]>( + input, + ) + } +} + +impl AsRef> for LineEvaluation { + fn as_ref(&self) -> &LineEvaluation { + assert_eq!(std::mem::size_of::(), 0); + assert_eq!(std::mem::size_of::(), 0); + unsafe { std::mem::transmute(self) } + } +} + +impl Into> for LineEvaluation { + fn into(self) -> LineEvaluation { + assert_eq!(std::mem::size_of::(), 0); + assert_eq!(std::mem::size_of::(), 0); + unsafe { std::mem::transmute(self) } + } +} + +impl AsMut> for LineEvaluation { + fn as_mut(&mut self) -> &mut LineEvaluation { + assert_eq!(std::mem::size_of::(), 0); + assert_eq!(std::mem::size_of::(), 0); + unsafe { std::mem::transmute(self) } + } +} + +impl AsRef> + for SecureEvaluation +{ + fn as_ref(&self) -> &SecureEvaluation { + assert_eq!(std::mem::size_of::(), 0); + assert_eq!(std::mem::size_of::(), 0); + unsafe { std::mem::transmute(self) } + } +} + +impl Into> + for SecureEvaluation +{ + fn into(self) -> SecureEvaluation { + assert_eq!(std::mem::size_of::(), 0); + assert_eq!(std::mem::size_of::(), 0); + unsafe { std::mem::transmute(self) } + } +} + +impl AsRef> + for CircleEvaluation +{ + fn as_ref(&self) -> &CircleEvaluation { + assert_eq!(std::mem::size_of::(), 0); + assert_eq!(std::mem::size_of::(), 0); + unsafe { std::mem::transmute(self) } + } +} + +impl AsMut> for SecureColumnByCoords { + fn as_mut(&mut self) -> &mut SecureColumnByCoords { + assert_eq!(std::mem::size_of::(), 0); + assert_eq!(std::mem::size_of::(), 0); + unsafe { std::mem::transmute(self) } + } +} + +impl AsRef> for SecureColumnByCoords { + fn as_ref(&self) -> &SecureColumnByCoords { + assert_eq!(std::mem::size_of::(), 0); + assert_eq!(std::mem::size_of::(), 0); + unsafe { std::mem::transmute(self) } + } 
+} + +pub fn convert_web_to_simd_column(col: Col) -> Col { + assert_eq!(std::mem::size_of::(), 0); + assert_eq!(std::mem::size_of::(), 0); + unsafe { std::mem::transmute(col) } +} + +impl Into> + for CircleEvaluation +{ + fn into(self) -> CircleEvaluation { + assert_eq!(std::mem::size_of::(), 0); + assert_eq!(std::mem::size_of::(), 0); + unsafe { std::mem::transmute(self) } + } +} + +impl Into> + for CircleEvaluation +{ + fn into(self) -> CircleEvaluation { + assert_eq!(std::mem::size_of::(), 0); + assert_eq!(std::mem::size_of::(), 0); + unsafe { std::mem::transmute(self) } + } +} + +impl AsRef> for TwiddleTree { + fn as_ref(&self) -> &TwiddleTree { + assert_eq!(std::mem::size_of::(), 0); + assert_eq!(std::mem::size_of::(), 0); + unsafe { std::mem::transmute(self) } + } +} + +impl AsRef> for CirclePoly { + fn as_ref(&self) -> &CirclePoly { + assert_eq!(std::mem::size_of::(), 0); + assert_eq!(std::mem::size_of::(), 0); + unsafe { std::mem::transmute(self) } + } +} + +impl<'a> AsRef> for &'a CirclePoly { + fn as_ref(&self) -> &CirclePoly { + self + } +} + +impl<'a> AsRef> + for &'a CircleEvaluation +{ + fn as_ref(&self) -> &CircleEvaluation { + self + } +} + +impl Into> for CirclePoly { + fn into(self) -> CirclePoly { + assert_eq!(std::mem::size_of::(), 0); + assert_eq!(std::mem::size_of::(), 0); + unsafe { std::mem::transmute(self) } + } +} + +impl Into> for TwiddleTree { + fn into(self) -> TwiddleTree { + assert_eq!(std::mem::size_of::(), 0); + assert_eq!(std::mem::size_of::(), 0); + unsafe { std::mem::transmute(self) } + } +} diff --git a/crates/prover/src/core/backend/web/webgpu/gpu_common.rs b/crates/prover/src/core/backend/web/webgpu/gpu_common.rs new file mode 100644 index 000000000..9e7eb0781 --- /dev/null +++ b/crates/prover/src/core/backend/web/webgpu/gpu_common.rs @@ -0,0 +1,422 @@ +use std::borrow::Cow; + +use wgpu::util::DeviceExt; + +/// Common trait for GPU input/output types +pub trait ByteSerialize: Sized { + fn as_bytes(&self) -> &[u8] { + unsafe { 
+ std::slice::from_raw_parts( + (self as *const Self) as *const u8, + std::mem::size_of::(), + ) + } + } + + fn from_bytes(bytes: &[u8]) -> Self { + assert!(bytes.len() >= std::mem::size_of::()); + // SAFETY: the assert above guarantees enough readable bytes, and + // `read_unaligned` places no alignment requirement on `bytes`. + // Implementors are assumed to be valid for any bit pattern (TODO confirm). + unsafe { std::ptr::read_unaligned(bytes.as_ptr() as *const Self) } + } +} + +/// Base GPU instance for field computations +pub struct GpuComputeInstance { + pub device: wgpu::Device, + pub queue: wgpu::Queue, + pub input_buffer: wgpu::Buffer, + pub output_buffer: wgpu::Buffer, + pub staging_buffer: wgpu::Buffer, + pub debug_buffer: wgpu::Buffer, + pub staging_buffer_debug: wgpu::Buffer, +} + +impl GpuComputeInstance { + pub async fn new(input_data: &T, output_size: usize) -> Self { + let instance = wgpu::Instance::default(); + let adapter = instance + .request_adapter(&wgpu::RequestAdapterOptions { + power_preference: wgpu::PowerPreference::HighPerformance, + compatible_surface: None, + force_fallback_adapter: false, + }) + .await + .unwrap(); + + let (device, queue) = adapter + .request_device( + &wgpu::DeviceDescriptor { + label: Some("Field Operations Device"), + required_features: wgpu::Features::SHADER_INT64, + required_limits: wgpu::Limits::default(), + memory_hints: wgpu::MemoryHints::Performance, + }, + None, + ) + .await + .unwrap(); + + // Create input buffer + let input_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor { + label: Some("Field Input Buffer"), + contents: input_data.as_bytes(), + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST, + }); + + // Create output buffer + let output_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Field Output Buffer"), + size: output_size as u64, + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, + mapped_at_creation: false, + }); + + // Create staging buffer for reading results + let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Field Staging Buffer"), + size: output_size as u64, + usage: wgpu::BufferUsages::MAP_READ |
wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + + // Create debug buffer + let debug_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Field Debug Buffer"), + size: output_size as u64, + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, + mapped_at_creation: false, + }); + + // Create staging buffer for debug buffer + let staging_buffer_debug = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Field Staging Buffer Debug"), + size: output_size as u64, + usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + + Self { + device, + queue, + input_buffer, + output_buffer, + staging_buffer, + debug_buffer, + staging_buffer_debug, + } + } + + pub fn create_pipeline( + &self, + shader_source: &str, + entry_point: &str, + ) -> (wgpu::ComputePipeline, wgpu::BindGroup) { + let shader = self + .device + .create_shader_module(wgpu::ShaderModuleDescriptor { + label: Some("Field Operations Shader"), + source: wgpu::ShaderSource::Wgsl(shader_source.into()), + }); + + let bind_group_layout = + self.device + .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + entries: &[ + wgpu::BindGroupLayoutEntry { + binding: 0, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: true }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + wgpu::BindGroupLayoutEntry { + binding: 1, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + ], + label: Some("Field Operations Bind Group Layout"), + }); + + let pipeline_layout = self + .device + .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: Some("Field Operations Pipeline Layout"), + bind_group_layouts: &[&bind_group_layout], + 
push_constant_ranges: &[], + }); + + let pipeline = self + .device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: Some("Field Operations Pipeline"), + layout: Some(&pipeline_layout), + module: &shader, + entry_point: Some(entry_point), + cache: None, + compilation_options: Default::default(), + }); + + let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor { + layout: &bind_group_layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: self.input_buffer.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 1, + resource: self.output_buffer.as_entire_binding(), + }, + ], + label: Some("Field Operations Bind Group"), + }); + + (pipeline, bind_group) + } + + pub fn create_pipeline_debug( + &self, + shader_source: &str, + entry_point: &str, + ) -> (wgpu::ComputePipeline, wgpu::BindGroup) { + let shader = self + .device + .create_shader_module(wgpu::ShaderModuleDescriptor { + label: Some("Field Operations Shader"), + source: wgpu::ShaderSource::Wgsl(shader_source.into()), + }); + + let bind_group_layout = + self.device + .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + entries: &[ + wgpu::BindGroupLayoutEntry { + binding: 0, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: true }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + wgpu::BindGroupLayoutEntry { + binding: 1, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + wgpu::BindGroupLayoutEntry { + binding: 2, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + ], + label: Some("Field Operations 
Bind Group Layout"), + }); + + let pipeline_layout = self + .device + .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: Some("Field Operations Pipeline Layout"), + bind_group_layouts: &[&bind_group_layout], + push_constant_ranges: &[], + }); + + let pipeline = self + .device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: Some("Field Operations Pipeline"), + layout: Some(&pipeline_layout), + module: &shader, + entry_point: Some(entry_point), + cache: None, + compilation_options: Default::default(), + }); + + let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor { + layout: &bind_group_layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: self.input_buffer.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 1, + resource: self.output_buffer.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 2, + resource: self.debug_buffer.as_entire_binding(), + }, + ], + label: Some("Field Operations Bind Group"), + }); + + (pipeline, bind_group) + } + + pub async fn run_computation( + &self, + pipeline: &wgpu::ComputePipeline, + bind_group: &wgpu::BindGroup, + workgroup_count: (u32, u32, u32), + ) -> T { + let mut encoder = self + .device + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("Field Operations Encoder"), + }); + + { + let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("Field Operations Compute Pass"), + timestamp_writes: None, + }); + compute_pass.set_pipeline(pipeline); + compute_pass.set_bind_group(0, bind_group, &[]); + compute_pass.dispatch_workgroups( + workgroup_count.0, + workgroup_count.1, + workgroup_count.2, + ); + } + + encoder.copy_buffer_to_buffer( + &self.output_buffer, + 0, + &self.staging_buffer, + 0, + self.staging_buffer.size(), + ); + + self.queue.submit(Some(encoder.finish())); + + let buffer_slice = self.staging_buffer.slice(..); + let (tx, rx) = flume::bounded(1); + 
buffer_slice.map_async(wgpu::MapMode::Read, move |result| { + tx.send(result).unwrap(); + }); + + self.device.poll(wgpu::Maintain::wait()); + + let result = async { + rx.recv_async().await.unwrap().unwrap(); + let data = buffer_slice.get_mapped_range(); + let result = T::from_bytes(&data); + drop(data); + self.staging_buffer.unmap(); + result + }; + let result = result.await; + + result + } + + pub async fn run_computation_debug( + &self, + pipeline: &wgpu::ComputePipeline, + bind_group: &wgpu::BindGroup, + workgroup_count: (u32, u32, u32), + ) -> (T, D) { + let mut encoder = self + .device + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("Field Operations Encoder"), + }); + + { + let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("Field Operations Compute Pass"), + timestamp_writes: None, + }); + compute_pass.set_pipeline(pipeline); + compute_pass.set_bind_group(0, bind_group, &[]); + compute_pass.dispatch_workgroups( + workgroup_count.0, + workgroup_count.1, + workgroup_count.2, + ); + } + + encoder.copy_buffer_to_buffer( + &self.output_buffer, + 0, + &self.staging_buffer, + 0, + self.staging_buffer.size(), + ); + + encoder.copy_buffer_to_buffer( + &self.debug_buffer, + 0, + &self.staging_buffer_debug, + 0, + self.staging_buffer_debug.size(), + ); + + self.queue.submit(Some(encoder.finish())); + + let buffer_slice = self.staging_buffer.slice(..); + let (tx, rx) = flume::bounded(1); + buffer_slice.map_async(wgpu::MapMode::Read, move |result| { + tx.send(result).unwrap(); + }); + + let buffer_slice_debug = self.staging_buffer_debug.slice(..); + let (tx_debug, rx_debug) = flume::bounded(1); + buffer_slice_debug.map_async(wgpu::MapMode::Read, move |result| { + tx_debug.send(result).unwrap(); + }); + + self.device.poll(wgpu::Maintain::wait()); + + let debug_result = async { + rx_debug.recv_async().await.unwrap().unwrap(); + let data = buffer_slice_debug.get_mapped_range(); + let result = 
D::from_bytes(&data); + drop(data); + self.staging_buffer_debug.unmap(); + result + }; + let debug_result = debug_result.await; + + let result = async { + rx.recv_async().await.unwrap().unwrap(); + let data = buffer_slice.get_mapped_range(); + let result = T::from_bytes(&data); + drop(data); + self.staging_buffer.unmap(); + result + }; + let result = result.await; + + (result, debug_result) + } +} + +/// Trait for GPU operations that require shader source generation +pub trait GpuOperation { + fn shader_source(&self) -> Cow<'static, str>; + + fn entry_point(&self) -> &'static str { + "main" + } +} diff --git a/crates/prover/src/core/backend/web/webgpu/mod.rs b/crates/prover/src/core/backend/web/webgpu/mod.rs new file mode 100644 index 000000000..5b777d931 --- /dev/null +++ b/crates/prover/src/core/backend/web/webgpu/mod.rs @@ -0,0 +1,6 @@ +pub mod gpu_common; +pub mod qm31; +pub mod serialization; +pub mod utils; + +pub use serialization::ByteSerialize; diff --git a/crates/prover/src/core/backend/web/webgpu/qm31.rs b/crates/prover/src/core/backend/web/webgpu/qm31.rs new file mode 100644 index 000000000..189bda5a1 --- /dev/null +++ b/crates/prover/src/core/backend/web/webgpu/qm31.rs @@ -0,0 +1,469 @@ +use bytemuck::{Pod, Zeroable}; + +use super::gpu_common::ByteSerialize; +use crate::core::fields::cm31::CM31; +use crate::core::fields::m31::M31; +use crate::core::fields::qm31::QM31; + +#[derive(Debug, Clone, Copy, PartialEq, Pod, Zeroable)] +#[repr(C, align(4))] +pub struct GpuM31(pub u32); // alias M31 = u32 + +#[derive(Debug, Clone, Copy, PartialEq, Pod, Zeroable)] +#[repr(C, align(8))] +pub struct GpuCM31(pub [u32; 2]); // alias CM31 = vec2 + +#[derive(Debug, Clone, Copy, PartialEq, Pod, Zeroable)] +#[repr(C, align(16))] +pub struct GpuQM31(pub [u32; 4]); // alias QM31 = vec4 + +impl Default for GpuQM31 { + fn default() -> Self { + Self::zeroed() + } +} + +impl Default for GpuM31 { + fn default() -> Self { + Self::zeroed() + } +} + +impl From for GpuQM31 { + fn 
from(value: QM31) -> Self { + GpuQM31([ + value.0 .0.into(), + value.0 .1.into(), + value.1 .0.into(), + value.1 .1.into(), + ]) + } +} + +impl From for QM31 { + fn from(value: GpuQM31) -> Self { + QM31( + CM31::from_m31(value.0[0].into(), value.0[1].into()), + CM31::from_m31(value.0[2].into(), value.0[3].into()), + ) + } +} + +impl From for GpuM31 { + fn from(value: M31) -> Self { + GpuM31 { 0: value.into() } + } +} + +impl From for M31 { + fn from(value: GpuM31) -> Self { + M31::from(value.0) + } +} + +#[repr(C)] +#[derive(Copy, Clone, Debug, PartialEq, Pod, Zeroable)] +pub struct ComputeInput { + pub first: GpuQM31, + pub second: GpuQM31, +} + +#[repr(C)] +#[derive(Copy, Clone, Debug, PartialEq, Pod, Zeroable)] +pub struct ComputeOutput { + pub result: GpuQM31, +} + +impl ByteSerialize for ComputeInput {} +impl ByteSerialize for ComputeOutput {} + +pub enum QM31Operation { + Add, + Subtract, + Multiply, + Negate, + Inverse, + Square, + Pow5, +} + +#[cfg(test)] +mod tests { + use std::borrow::Cow; + + use num_traits::Zero; + + use super::*; + use crate::core::backend::web::webgpu::gpu_common::{GpuComputeInstance, GpuOperation}; + use crate::core::fields::cm31::CM31; + use crate::core::fields::m31::{M31, P}; + use crate::core::fields::qm31::QM31; + use crate::core::fields::FieldExpOps; + use crate::{cm31, qm31}; + + impl GpuOperation for QM31Operation { + fn shader_source(&self) -> Cow<'static, str> { + let base_source = include_str!("qm31.wgsl"); + + let inputs = r#" + struct ComputeInput { + first: QM31, + second: QM31, + } + + @group(0) @binding(0) var input: ComputeInput; + "#; + + let output = r#" + struct ComputeOutput { + result: QM31, + } + + @group(0) @binding(1) var output: ComputeOutput; + "#; + + let operation = match self { + QM31Operation::Add => { + r#" + @compute @workgroup_size(1) + fn main() { + output.result = qm31_add(input.first, input.second); + } + "# + } + QM31Operation::Multiply => { + r#" + @compute @workgroup_size(1) + fn main() { + 
output.result = qm31_mul(input.first, input.second); + } + "# + } + QM31Operation::Subtract => { + r#" + @compute @workgroup_size(1) + fn main() { + output.result = qm31_sub(input.first, input.second); + } + "# + } + QM31Operation::Negate => { + r#" + @compute @workgroup_size(1) + fn main() { + output.result = qm31_neg(input.first); + } + "# + } + QM31Operation::Inverse => { + r#" + @compute @workgroup_size(1) + fn main() { + output.result = qm31_inverse(input.first); + } + "# + } + QM31Operation::Square => { + r#" + @compute @workgroup_size(1) + fn main() { + output.result = qm31_square(input.first); + } + "# + } + QM31Operation::Pow5 => { + r#" + @compute @workgroup_size(1) + fn main() { + output.result = qm31_pow5(input.first); + } + "# + } + }; + + format!("{base_source}\n{inputs}\n{output}\n{operation}").into() + } + } + + pub async fn compute_field_operation( + operation: QM31Operation, + first: QM31, + second: QM31, + ) -> QM31 { + let input = ComputeInput { + first: first.into(), + second: second.into(), + }; + + let instance = GpuComputeInstance::new(&input, std::mem::size_of::()).await; + let (pipeline, bind_group) = + instance.create_pipeline(&operation.shader_source(), operation.entry_point()); + + let output = instance + .run_computation::(&pipeline, &bind_group, (1, 1, 1)) + .await; + + output.result.into() + } + + #[test] + fn test_gpu_field_values() { + let qm0 = qm31!(1, 2, 3, 4); + let qm1 = qm31!(4, 5, 6, 7); + + // Test round-trip conversion CPU -> GPU -> CPU + let gpu_qm0 = GpuQM31::from(qm0); + let gpu_qm1 = GpuQM31::from(qm1); + + let cpu_qm0 = QM31( + CM31(gpu_qm0.0[0].into(), gpu_qm0.0[1].into()), + CM31(gpu_qm0.0[2].into(), gpu_qm0.0[3].into()), + ); + + let cpu_qm1 = QM31( + CM31(gpu_qm1.0[0].into(), gpu_qm1.0[1].into()), + CM31(gpu_qm1.0[2].into(), gpu_qm1.0[3].into()), + ); + + assert_eq!( + qm0, cpu_qm0, + "Round-trip conversion should preserve values for qm0" + ); + assert_eq!( + qm1, cpu_qm1, + "Round-trip conversion should preserve 
values for qm1" + ); + } + + #[test] + fn test_gpu_m31_field_arithmetic() { + // Test M31 field operations + let m = M31::from(19u32); + let one = M31::from(1u32); + let zero = M31::zero(); + + // Create QM31 values for GPU computation + let m_qm = QM31(CM31(m, zero), CM31::zero()); + let one_qm = QM31(CM31(one, zero), CM31::zero()); + let zero_qm = QM31(CM31(zero, zero), CM31::zero()); + + // Test addition + let cpu_add = m + one; + let gpu_add = pollster::block_on(compute_field_operation(QM31Operation::Add, m_qm, one_qm)); + assert_eq!(gpu_add.0 .0, cpu_add, "M31 addition failed"); + + // Test subtraction + let cpu_sub = m - one; + let gpu_sub = pollster::block_on(compute_field_operation( + QM31Operation::Subtract, + m_qm, + one_qm, + )); + assert_eq!(gpu_sub.0 .0, cpu_sub, "M31 subtraction failed"); + + // Test multiplication + let cpu_mul = m * one; + let gpu_mul = pollster::block_on(compute_field_operation( + QM31Operation::Multiply, + m_qm, + one_qm, + )); + assert_eq!(gpu_mul.0 .0, cpu_mul, "M31 multiplication failed"); + + // Test negation + let cpu_neg = -m; + let gpu_neg = pollster::block_on(compute_field_operation( + QM31Operation::Negate, + m_qm, + zero_qm, + )); + assert_eq!(gpu_neg.0 .0, cpu_neg, "M31 negation failed"); + + // Test inverse + let cpu_inv = m.inverse(); + let gpu_inv = pollster::block_on(compute_field_operation( + QM31Operation::Inverse, + m_qm, + zero_qm, + )); + assert_eq!(gpu_inv.0 .0, cpu_inv, "M31 inverse failed"); + + // Test square + let cpu_square = m.square(); + let gpu_square = pollster::block_on(compute_field_operation( + QM31Operation::Square, + m_qm, + zero_qm, + )); + assert_eq!(gpu_square.0 .0, cpu_square, "M31 square operation failed"); + + // Test pow5 + let cpu_pow5 = m.square().square() * m; + let gpu_pow5 = + pollster::block_on(compute_field_operation(QM31Operation::Pow5, m_qm, zero_qm)); + assert_eq!(gpu_pow5.0 .0, cpu_pow5, "M31 pow5 operation failed"); + + // Test with large numbers (near P) + let large = 
M31::from(P - 1); + let large_qm = QM31(CM31(large, zero), CM31::zero()); + + // Test large number multiplication + let cpu_large_mul = large * m; + let gpu_large_mul = pollster::block_on(compute_field_operation( + QM31Operation::Multiply, + large_qm, + m_qm, + )); + assert_eq!( + gpu_large_mul.0 .0, cpu_large_mul, + "M31 large number multiplication failed" + ); + + // Test large number inverse + let cpu_large_inv = one / large; + let gpu_large_inv = pollster::block_on(compute_field_operation( + QM31Operation::Inverse, + large_qm, + zero_qm, + )); + assert_eq!( + gpu_large_inv.0 .0, cpu_large_inv, + "M31 large number inverse failed" + ); + } + + #[test] + fn test_gpu_cm31_field_arithmetic() { + let cm0 = cm31!(1, 2); + let cm1 = cm31!(4, 5); + let zero = CM31::zero(); + + // Test addition + let cpu_add = cm0 + cm1; + let gpu_add = pollster::block_on(compute_field_operation( + QM31Operation::Add, + QM31(cm0, zero), + QM31(cm1, zero), + )); + assert_eq!(gpu_add.0, cpu_add, "CM31 addition failed"); + + // Test subtraction + let cpu_sub = cm0 - cm1; + let gpu_sub = pollster::block_on(compute_field_operation( + QM31Operation::Subtract, + QM31(cm0, zero), + QM31(cm1, zero), + )); + assert_eq!(gpu_sub.0, cpu_sub, "CM31 subtraction failed"); + + // Test multiplication + let cpu_mul = cm0 * cm1; + let gpu_mul = pollster::block_on(compute_field_operation( + QM31Operation::Multiply, + QM31(cm0, zero), + QM31(cm1, zero), + )); + assert_eq!(gpu_mul.0, cpu_mul, "CM31 multiplication failed"); + + // Test negation + let cpu_neg = -cm0; + let gpu_neg = pollster::block_on(compute_field_operation( + QM31Operation::Negate, + QM31(cm0, zero), + QM31(zero, zero), + )); + assert_eq!(gpu_neg.0, cpu_neg, "CM31 negation failed"); + + // Test inverse + let cpu_inv = cm0.inverse(); + let gpu_inv = pollster::block_on(compute_field_operation( + QM31Operation::Inverse, + QM31(cm0, zero), + QM31(zero, zero), + )); + assert_eq!(gpu_inv.0, cpu_inv, "CM31 inverse failed"); + + // Test with large 
numbers (near P) + let large = cm31!(P - 1, P - 2); + let large_qm = QM31(large, zero); + + // Test large number multiplication + let cpu_large_mul = large * cm1; + let gpu_large_mul = pollster::block_on(compute_field_operation( + QM31Operation::Multiply, + large_qm, + QM31(cm1, zero), + )); + assert_eq!( + gpu_large_mul.0, cpu_large_mul, + "CM31 large number multiplication failed" + ); + + // Test large number inverse + let cpu_large_inv = large.inverse(); + let gpu_large_inv = pollster::block_on(compute_field_operation( + QM31Operation::Inverse, + large_qm, + QM31(zero, zero), + )); + assert_eq!( + gpu_large_inv.0, cpu_large_inv, + "CM31 large number inverse failed" + ); + } + + #[test] + fn test_gpu_qm31_field_arithmetic() { + let qm0 = qm31!(1, 2, 3, 4); + let qm1 = qm31!(4, 5, 6, 7); + let zero = QM31::zero(); + + // Test addition + let cpu_add = qm0 + qm1; + let gpu_add = pollster::block_on(compute_field_operation(QM31Operation::Add, qm0, qm1)); + assert_eq!(gpu_add, cpu_add, "QM31 addition failed"); + + // Test subtraction + let cpu_sub = qm0 - qm1; + let gpu_sub = + pollster::block_on(compute_field_operation(QM31Operation::Subtract, qm0, qm1)); + assert_eq!(gpu_sub, cpu_sub, "QM31 subtraction failed"); + + // Test multiplication + let cpu_mul = qm0 * qm1; + let gpu_mul = + pollster::block_on(compute_field_operation(QM31Operation::Multiply, qm0, qm1)); + assert_eq!(gpu_mul, cpu_mul, "QM31 multiplication failed"); + + // Test negation + let cpu_neg = -qm0; + let gpu_neg = pollster::block_on(compute_field_operation(QM31Operation::Negate, qm0, zero)); + assert_eq!(gpu_neg, cpu_neg, "QM31 negation failed"); + + // Test inverse + let cpu_inv = qm0.inverse(); + let gpu_inv = pollster::block_on(compute_field_operation(QM31Operation::Inverse, qm0, qm1)); + assert_eq!(gpu_inv, cpu_inv, "QM31 inverse failed"); + + // Test with large numbers (near P) + let large = qm31!(P - 1, P - 2, P - 3, P - 4); + + // Test large number multiplication + let cpu_large_mul = large * 
qm1; + let gpu_large_mul = + pollster::block_on(compute_field_operation(QM31Operation::Multiply, large, qm1)); + assert_eq!( + gpu_large_mul, cpu_large_mul, + "QM31 large number multiplication failed" + ); + + // Test large number inverse + let cpu_large_inv = large.inverse(); + let gpu_large_inv = + pollster::block_on(compute_field_operation(QM31Operation::Inverse, large, zero)); + assert_eq!( + gpu_large_inv, cpu_large_inv, + "QM31 large number inverse failed" + ); + } +} diff --git a/crates/prover/src/core/backend/web/webgpu/qm31.wgsl b/crates/prover/src/core/backend/web/webgpu/qm31.wgsl new file mode 100644 index 000000000..d8ae5a2a4 --- /dev/null +++ b/crates/prover/src/core/backend/web/webgpu/qm31.wgsl @@ -0,0 +1,278 @@ +// This shader contains implementations for QM31/CM31/M31 operations. +// It is stateless, i.e. it does not contain any storage variables, and also it does not include +// any entrypoint functions, which means that it can be used as a library in other shaders. +// Note that the variable names that are used in this shader cannot be used in other shaders.
+const P: u32 = 0x7FFFFFFF; // 2^31 - 1 +const MODULUS_BITS: u32 = 31u; +const HALF_BITS: u32 = 16u; + +alias M31 = u32; +alias CM31 = vec2; +alias QM31 = vec4; + +fn m31_add(a: M31, b: M31) -> M31 { + return M31(partial_reduce(a + b)); +} + +fn m31_sub(a: M31, b: M31) -> M31 { + return m31_add(a, m31_neg(b)); +} + +fn m31_mul(a: M31, b: M31) -> M31 { + // Split into 16-bit parts + let a1 = a >> HALF_BITS; + let a0 = a & 0xFFFFu; + let b1 = b >> HALF_BITS; + let b0 = b & 0xFFFFu; + + // Compute partial products + let m0 = partial_reduce(a0 * b0); + let m1 = partial_reduce(a0 * b1); + let m2 = partial_reduce(a1 * b0); + let m3 = partial_reduce(a1 * b1); + + // Combine middle terms with reduction + let mid = partial_reduce(m1 + m2); + + // Combine parts with partial reduction + let shifted_mid = partial_reduce(mid << HALF_BITS); + let low = partial_reduce(m0 + shifted_mid); + + let high_part = partial_reduce(m3 + (mid >> HALF_BITS)); + + // Final combination using Mersenne prime property + let result = partial_reduce( + partial_reduce((high_part << 1u)) + + partial_reduce((low >> MODULUS_BITS)) + + partial_reduce(low & P) + ); + return M31(result); +} + +fn m31_neg(a: M31) -> M31 { + return M31(partial_reduce(P - a)); +} + +fn m31_square(x: M31) -> M31 { + return m31_mul(x, x); +} + +fn m31_pow3(x: M31) -> M31 { + let x2 = m31_square(x); + return m31_mul(x, x2); +} + +fn m31_pow5(x: M31) -> M31 { + let x2 = m31_square(x); + let x4 = m31_square(x2); + return m31_mul(x4, x); +} + +fn m31_pow8(x: M31) -> M31 { + let x2 = m31_square(x); + let x4 = m31_square(x2); + return m31_square(x4); +} + +fn m31_pow128(x: M31) -> M31 { + let x8 = m31_pow8(x); + let x64 = m31_pow8(x8); + return m31_square(x64); +} + +fn m31_pow256(x: M31) -> M31 { + let x8 = m31_pow8(x); + let x64 = m31_pow8(x8); + let x128 = m31_square(x64); + return m31_square(x128); +} + +fn m31_inverse(x: M31) -> M31 { + // Computes x^(2^31-3) using the same sequence as pow2147483645 + // This is equivalent to
x^(P-2) where P = 2^31-1 + + // t0 = x^5 + let t0 = m31_pow5(x); + + // t1 = x^15 + let t1 = m31_pow3(t0); + + // t2 = x^125 + let t2 = m31_mul(m31_pow8(t1), t0); + + // t3 = x^255 + let t3 = m31_mul(m31_square(t2), t0); + + // t4 = x^65535 + let t4 = m31_mul(m31_pow256(t3), t3); + + // t5 = x^16777215 + let t5 = m31_mul(m31_pow256(t4), t3); + + // result = x^2147483520 + var result = m31_pow128(t5); + result = m31_mul(result, t2); + + return result; +} + +fn cm31(a0: M31, b0: M31) -> CM31 { + return vec2(a0, b0); +} + +// Complex field operations for CM31 +fn cm31_add(a: CM31, b: CM31) -> CM31 { + return vec2( + m31_add(a.x, b.x), + m31_add(a.y, b.y) + ); +} + +fn cm31_sub(a: CM31, b: CM31) -> CM31 { + return vec2( + m31_sub(a.x, b.x), + m31_sub(a.y, b.y) + ); +} + +fn cm31_mul(a: CM31, b: CM31) -> CM31 { + // (a + bi)(c + di) = (ac - bd) + (ad + bc)i + let ac = m31_mul(a.x, b.x); + let bd = m31_mul(a.y, b.y); + let ad = m31_mul(a.x, b.y); + let bc = m31_mul(a.y, b.x); + + return vec2( + m31_sub(ac, bd), + m31_add(ad, bc) + ); +} + +fn cm31_neg(a: CM31) -> CM31 { + return vec2(m31_neg(a.x), m31_neg(a.y)); +} + +fn cm31_square(x: CM31) -> CM31 { + return cm31_mul(x, x); +} + +fn cm31_pow5(x: CM31) -> CM31 { + // x^5 = (x^2)^2 * x, same sequence as m31_pow5 / qm31_pow5 + let x2 = cm31_square(x); + let x4 = cm31_square(x2); + return cm31_mul(x4, x); +} + +fn cm31_inverse(x: CM31) -> CM31 { + // 1/(a + bi) = (a - bi)/(a² + b²) + let a2 = m31_mul(x.x, x.x); + let b2 = m31_mul(x.y, x.y); + let denom = m31_add(a2, b2); + let denomInv = m31_inverse(denom); + return vec2( + m31_mul(x.x, denomInv), + m31_neg(m31_mul(x.y, denomInv)) + ); +} + +fn qm31(a: CM31, b: CM31) -> QM31 { + return vec4(a.x, a.y, b.x, b.y); +} + +fn qm31_4(a: M31, b: M31, c: M31, d: M31) -> QM31 { + return vec4(a, b, c, d); +} + +// Quadratic extension field operations for QM31 +fn qm31_add(u: QM31, v: QM31) -> QM31 { + let a = cm31_add(vec2(u.x, u.y), vec2(v.x, v.y)); + let b = cm31_add(vec2(u.z, u.w), vec2(v.z, v.w)); + return qm31(a, b); +} + +fn qm31_sub(u: QM31, v: QM31) -> QM31 { + let a =
cm31_sub(vec2(u.x, u.y), vec2(v.x, v.y)); + let b = cm31_sub(vec2(u.z, u.w), vec2(v.z, v.w)); + return qm31(a, b); +} + +fn qm31_mul(u: QM31, v: QM31) -> QM31 { + // (a + bu)(c + du) = (ac + rbd) + (ad + bc)u + // where r = 2 + i is the irreducible polynomial coefficient + let ua = vec2(u.x, u.y); + let ub = vec2(u.z, u.w); + let va = vec2(v.x, v.y); + let vb = vec2(v.z, v.w); + + let ac = cm31_mul(ua, va); + let bd = cm31_mul(ub, vb); + let ad = cm31_mul(ua, vb); + let bc = cm31_mul(ub, va); + + // r = 2 + i + let r = vec2(2u, 1u); + let rbd = cm31_mul(r, bd); + + let real = cm31_add(ac, rbd); + let imag = cm31_add(ad, bc); + return qm31(real, imag); +} + +fn qm31_neg(q: QM31) -> QM31 { + return qm31( + cm31_neg(vec2(q.x, q.y)), + cm31_neg(vec2(q.z, q.w)) + ); +} + +fn qm31_square(q: QM31) -> QM31 { + return qm31_mul(q, q); +} + +fn qm31_pow5(q: QM31) -> QM31 { + let q2 = qm31_square(q); + let q4 = qm31_square(q2); + return qm31_mul(q4, q); +} + +fn qm31_inverse(q: QM31) -> QM31 { + let a = vec2(q.x, q.y); + let b = vec2(q.z, q.w); + let b2 = cm31_square(b); + let r = vec2(2u, 1u); // 2 + i + let rb2 = cm31_mul(r, b2); + let a2 = cm31_square(a); + let denomInv = cm31_inverse(cm31_sub(a2, rb2)); + + let neg_b = cm31_neg(b); + return qm31( + cm31_mul(a, denomInv), + cm31_mul(neg_b, denomInv) + ); +} + +// Utility functions +fn partial_reduce(val: u32) -> u32 { + let reduced = val - P; + return select(val, reduced, reduced < val); +} + +const ZERO_FRACTION: Fraction = + Fraction(vec4(0u), vec4(1u, 0u, 0u, 0u)); + +struct Fraction { + numerator : QM31, // vec4 + denominator: QM31, +} + +// Add two fractions: (a/b + c/d) = (ad + bc)/(bd) +fn fraction_add(x: Fraction, y: Fraction) -> Fraction { + let num = qm31_add( + qm31_mul(x.numerator, y.denominator), + qm31_mul(y.numerator, x.denominator) + ); + let den = qm31_mul(x.denominator, y.denominator); + return Fraction(num, den); +} + +fn fraction_eq(x: Fraction, y: Fraction) -> bool { + return all(x.numerator == 
y.numerator) && + all(x.denominator == y.denominator); +} diff --git a/crates/prover/src/core/backend/web/webgpu/serialization.rs b/crates/prover/src/core/backend/web/webgpu/serialization.rs new file mode 100644 index 000000000..eaec79dd8 --- /dev/null +++ b/crates/prover/src/core/backend/web/webgpu/serialization.rs @@ -0,0 +1,15 @@ +pub trait ByteSerialize: Sized { + fn as_bytes(&self) -> &[u8] { + unsafe { + std::slice::from_raw_parts( + (self as *const Self) as *const u8, + std::mem::size_of::(), + ) + } + } + + fn from_bytes(bytes: &[u8]) -> &Self { + assert!(bytes.len() >= std::mem::size_of::()); + unsafe { &*(bytes.as_ptr() as *const Self) } + } +} diff --git a/crates/prover/src/core/backend/web/webgpu/utils.rs b/crates/prover/src/core/backend/web/webgpu/utils.rs new file mode 100644 index 000000000..1c4840277 --- /dev/null +++ b/crates/prover/src/core/backend/web/webgpu/utils.rs @@ -0,0 +1,334 @@ +use wgpu::util::DeviceExt; + +use super::ByteSerialize; + +/// Input data for the GPU computation +#[repr(C)] +#[derive(Copy, Clone, Debug, bytemuck::Pod, bytemuck::Zeroable)] +pub struct ComputeInput { + pub i: u32, + pub domain_log_size: u32, + pub eval_log_size: u32, + pub offset: i32, +} + +/// Output data from the GPU computation +#[repr(C)] +#[derive(Copy, Clone, Debug, bytemuck::Pod, bytemuck::Zeroable)] +pub struct ComputeOutput { + pub result: u32, +} + +impl From for usize { + fn from(output: ComputeOutput) -> Self { + output.result as usize + } +} + +impl ByteSerialize for ComputeInput {} +impl ByteSerialize for ComputeOutput {} + +/// GPU instance for utility computations +pub struct GpuUtilsInstance { + device: wgpu::Device, + queue: wgpu::Queue, + input_buffer: wgpu::Buffer, + output_buffer: wgpu::Buffer, + staging_buffer: wgpu::Buffer, +} + +impl GpuUtilsInstance { + pub async fn new(input_data: &T, output_size: usize) -> Self { + let instance = wgpu::Instance::default(); + let adapter = instance + .request_adapter(&wgpu::RequestAdapterOptions { + 
power_preference: wgpu::PowerPreference::HighPerformance, + compatible_surface: None, + force_fallback_adapter: false, + }) + .await + .unwrap(); + + let (device, queue) = adapter + .request_device( + &wgpu::DeviceDescriptor { + label: Some("Field Operations Device"), + required_features: wgpu::Features::SHADER_INT64, + required_limits: wgpu::Limits::default(), + memory_hints: wgpu::MemoryHints::Performance, + }, + None, + ) + .await + .unwrap(); + + // Create input buffer + let input_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor { + label: Some("Field Input Buffer"), + contents: input_data.as_bytes(), + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST, + }); + + // Create output buffer + let output_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Field Output Buffer"), + size: output_size as u64, + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, + mapped_at_creation: false, + }); + + // Create staging buffer for reading results + let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Field Staging Buffer"), + size: output_size as u64, + usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + + Self { + device, + queue, + input_buffer, + output_buffer, + staging_buffer, + } + } + + /// Creates a compute pipeline for the operation + pub fn create_pipeline( + &self, + shader_source: &str, + entry_point: &str, + ) -> (wgpu::ComputePipeline, wgpu::BindGroup) { + let shader = self + .device + .create_shader_module(wgpu::ShaderModuleDescriptor { + label: Some("Field Operations Shader"), + source: wgpu::ShaderSource::Wgsl(shader_source.into()), + }); + + let bind_group_layout = + self.device + .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + entries: &[ + wgpu::BindGroupLayoutEntry { + binding: 0, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: 
wgpu::BufferBindingType::Storage { read_only: true }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + wgpu::BindGroupLayoutEntry { + binding: 1, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + ], + label: Some("Field Operations Bind Group Layout"), + }); + + let pipeline_layout = self + .device + .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: Some("Field Operations Pipeline Layout"), + bind_group_layouts: &[&bind_group_layout], + push_constant_ranges: &[], + }); + + let pipeline = self + .device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: Some("Field Operations Pipeline"), + layout: Some(&pipeline_layout), + module: &shader, + entry_point: Some(entry_point), + cache: None, + compilation_options: Default::default(), + }); + + let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor { + layout: &bind_group_layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: self.input_buffer.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 1, + resource: self.output_buffer.as_entire_binding(), + }, + ], + label: Some("Field Operations Bind Group"), + }); + + (pipeline, bind_group) + } + + /// Runs the computation on the GPU + async fn run_computation( + &self, + pipeline: &wgpu::ComputePipeline, + bind_group: &wgpu::BindGroup, + workgroup_count: (u32, u32, u32), + ) -> T { + // Create command encoder and compute pass + let mut encoder = self + .device + .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + { + let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: None, + timestamp_writes: None, + }); + compute_pass.set_pipeline(&pipeline); + compute_pass.set_bind_group(0, bind_group, &[]); + compute_pass.dispatch_workgroups( + 
workgroup_count.0, + workgroup_count.1, + workgroup_count.2, + ); + } + + // Copy results to staging buffer + encoder.copy_buffer_to_buffer( + &self.output_buffer, + 0, + &self.staging_buffer, + 0, + self.staging_buffer.size(), + ); + + // Submit command buffer and wait for results + self.queue.submit(Some(encoder.finish())); + + // Read results from staging buffer + let slice = self.staging_buffer.slice(..); + let (sender, receiver) = flume::bounded(1); + slice.map_async(wgpu::MapMode::Read, move |result| { + sender.send(result).unwrap(); + }); + self.device.poll(wgpu::Maintain::Wait); + + receiver.recv_async().await.unwrap().unwrap(); + let data = slice.get_mapped_range(); + let result = *T::from_bytes(&data); + drop(data); + self.staging_buffer.unmap(); + + result + } +} + +#[derive(Debug)] +pub enum GpuUtilsOperation { + OffsetBitReversedCircleDomainIndex, +} + +impl GpuUtilsOperation { + pub fn shader_source(&self) -> String { + let base_source = include_str!("utils.wgsl"); + let qm31_source = include_str!("qm31.wgsl"); + + let inputs = r#" + struct Inputs { + i: u32, + domain_log_size: u32, + eval_log_size: u32, + offset: i32, + } + + @group(0) @binding(0) var inputs: Inputs; + "#; + + let output = r#" + struct Output { + result: u32, + } + + @group(0) @binding(1) var output: Output; + "#; + + let operation = match self { + GpuUtilsOperation::OffsetBitReversedCircleDomainIndex => { + r#" + @compute @workgroup_size(1) + fn main() {{ + let i = inputs.i; + let domain_log_size = inputs.domain_log_size; + let eval_log_size = inputs.eval_log_size; + let offset = inputs.offset; + + let result = offset_bit_reversed_circle_domain_index(i, domain_log_size, eval_log_size, offset); + output.result = result; + }} + "# + } + }; + + format!("{base_source}\n{qm31_source}\n{inputs}\n{output}\n{operation}") + } +} + +/// Computes the offset bit reversed circle domain index using the GPU +pub async fn compute_offset_bit_reversed_circle_domain_index( + i: usize, + 
domain_log_size: u32, + eval_log_size: u32, + offset: i32, +) -> usize { + let input = ComputeInput { + i: i as u32, + domain_log_size, + eval_log_size, + offset, + }; + + let instance = GpuUtilsInstance::new(&input, std::mem::size_of::()).await; + + let shader_source = GpuUtilsOperation::OffsetBitReversedCircleDomainIndex.shader_source(); + let (pipeline, bind_group) = instance.create_pipeline(&shader_source, "main"); + + let gpu_result = instance + .run_computation::(&pipeline, &bind_group, (1, 1, 1)) + .await; + gpu_result.into() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::utils::offset_bit_reversed_circle_domain_index as cpu_offset_bit_reversed_circle_domain_index; + + #[test] + fn test_offset_bit_reversed_circle_domain_index() { + // Test parameters from the CPU test + let domain_log_size = 3; + let eval_log_size = 6; + let initial_index = 5; + let offset = -2; + + let gpu_result = pollster::block_on(compute_offset_bit_reversed_circle_domain_index( + initial_index, + domain_log_size, + eval_log_size, + offset, + )); + + let cpu_result = cpu_offset_bit_reversed_circle_domain_index( + initial_index, + domain_log_size, + eval_log_size, + offset as isize, + ); + + assert_eq!(gpu_result, cpu_result, "GPU and CPU results should match"); + } +} diff --git a/crates/prover/src/core/backend/web/webgpu/utils.wgsl b/crates/prover/src/core/backend/web/webgpu/utils.wgsl new file mode 100644 index 000000000..27b6fbef7 --- /dev/null +++ b/crates/prover/src/core/backend/web/webgpu/utils.wgsl @@ -0,0 +1,68 @@ +// This shader contains utility functions for bit manipulation and index transformations. +// It is stateless and can be used as a library in other shaders. + +/// Returns the bit reversed index of `i` which is represented by `log_size` bits. 
+fn bit_reverse_index(i: u32, log_size: u32) -> u32 { + if (log_size == 0u) { + return i; + } + let bits = reverse_bits_u32(i); + return bits >> (32u - log_size); +} + +fn reverse_bits_u32(x: u32) -> u32 { + var x_mut = x; + var result = 0u; + + for (var i = 0u; i < 32u; i = i + 1u) { + result = (result << 1u) | (x_mut & 1u); + x_mut = x_mut >> 1u; + } + + return result; +} + +/// Returns the index of the offset element in a bit reversed circle evaluation +/// of log size `eval_log_size` relative to a smaller domain of size `domain_log_size`. +fn offset_bit_reversed_circle_domain_index( + i: u32, + domain_log_size: u32, + eval_log_size: u32, + offset: i32, +) -> u32 { + var prev_index = bit_reverse_index(i, eval_log_size); + let half_size = 1u << (eval_log_size - 1u); + let step_size = i32(1u << (eval_log_size - domain_log_size - 1u)) * offset; + + if (prev_index < half_size) { + let temp = i32(prev_index) + step_size; + // Implement rem_euclid for positive modulo + let m = i32(half_size); + let rem = temp % m; + prev_index = u32(select(rem + m, rem, rem >= 0)); + } else { + let temp = i32(prev_index - half_size) - step_size; + // Implement rem_euclid for positive modulo + let m = i32(half_size); + let rem = temp % m; + prev_index = u32(select(rem + m, rem, rem >= 0)) + half_size; + } + + return bit_reverse_index(prev_index, eval_log_size); +} + +fn circle_domain_index_to_coset_index(i: u32, n: u32) -> u32 { + if (i < (n / 2u)) { + return 2u * i; + } else { + return 2u * (n - 1u - i) + 1u; + } +} + +fn coset_index_to_circle_domain_index(coset_index: u32, log_domain_size: u32) -> u32 { + if (coset_index % 2u == 0u) { + return coset_index / 2u; + } else { + return ((2u << log_domain_size) - coset_index) / 2u; + } +} \ No newline at end of file diff --git a/crates/prover/src/core/fields/m31.rs b/crates/prover/src/core/fields/m31.rs index 2a50806ac..e364bf833 100644 --- a/crates/prover/src/core/fields/m31.rs +++ b/crates/prover/src/core/fields/m31.rs @@ -69,6 +69,12 @@ 
impl M31 { } } +impl Into for M31 { + fn into(self) -> u32 { + self.0 + } +} + impl Display for M31 { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0)