Skip to content

Commit 98b3671

Browse files
committed
use an uninitialized global buffer for CUDA
1 parent 401f1a3 commit 98b3671

File tree

2 files changed

+59
-13
lines changed

2 files changed

+59
-13
lines changed

prover/src/gpu/cuda/mod.rs

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//! This module contains GPU acceleration logic for Nvidia CUDA devices.
22
3-
use std::marker::PhantomData;
3+
use std::{cell::RefCell, marker::PhantomData, mem::MaybeUninit};
44

55
use air::{AuxRandElements, PartitionOptions};
66
use miden_gpu::{
@@ -32,33 +32,40 @@ const DIGEST_SIZE: usize = Rpo256::DIGEST_RANGE.end - Rpo256::DIGEST_RANGE.start
3232
// ================================================================================================
3333

3434
/// Wraps an [ExecutionProver] and provides GPU acceleration for building trace commitments.
35-
pub(crate) struct CudaExecutionProver<H, D, R>
35+
pub(crate) struct CudaExecutionProver<'g, H, D, R>
3636
where
3737
H: Hasher<Digest = D> + ElementHasher<BaseField = R::BaseField>,
3838
D: Digest + From<[Felt; DIGEST_SIZE]> + Into<[Felt; DIGEST_SIZE]>,
3939
R: RandomCoin<BaseField = Felt, Hasher = H> + Send,
4040
{
41+
main: RefCell<&'g mut [MaybeUninit<Felt>]>,
42+
aux: RefCell<&'g mut [MaybeUninit<Felt>]>,
43+
ce: RefCell<&'g mut [MaybeUninit<Felt>]>,
44+
4145
pub execution_prover: ExecutionProver<H, R>,
4246
pub hash_fn: HashFn,
4347
phantom_data: PhantomData<D>,
4448
}
4549

46-
impl<H, D, R> CudaExecutionProver<H, D, R>
50+
impl<'g, H, D, R> CudaExecutionProver<'g, H, D, R>
4751
where
4852
H: Hasher<Digest = D> + ElementHasher<BaseField = R::BaseField>,
4953
D: Digest + From<[Felt; DIGEST_SIZE]> + Into<[Felt; DIGEST_SIZE]>,
5054
R: RandomCoin<BaseField = Felt, Hasher = H> + Send,
5155
{
52-
pub fn new(execution_prover: ExecutionProver<H, R>, hash_fn: HashFn) -> Self {
56+
pub fn new(execution_prover: ExecutionProver<H, R>, hash_fn: HashFn, main: &'g mut [MaybeUninit<Felt>], aux: &'g mut [MaybeUninit<Felt>], ce: &'g mut [MaybeUninit<Felt>]) -> Self {
5357
CudaExecutionProver {
58+
main: RefCell::new(main),
59+
aux: RefCell::new(aux),
60+
ce: RefCell::new(ce),
5461
execution_prover,
5562
hash_fn,
5663
phantom_data: PhantomData,
5764
}
5865
}
5966
}
6067

61-
impl<H, D, R> Prover for CudaExecutionProver<H, D, R>
68+
impl<'g, H, D, R> Prover for CudaExecutionProver<'g, H, D, R>
6269
where
6370
H: Hasher<Digest = D> + ElementHasher<BaseField = R::BaseField> + Send + Sync,
6471
D: Digest + From<[Felt; DIGEST_SIZE]> + Into<[Felt; DIGEST_SIZE]>,
@@ -67,11 +74,11 @@ where
6774
type BaseField = Felt;
6875
type Air = ProcessorAir;
6976
type Trace = ExecutionTrace;
70-
type VC = MerkleTree<Self::HashFn>;
77+
type VC = MerkleTree<'g, Self::HashFn>;
7178
type HashFn = H;
7279
type RandomCoin = R;
73-
type TraceLde<E: FieldElement<BaseField = Felt>> = CudaTraceLde<E, H>;
74-
type ConstraintCommitment<E: FieldElement<BaseField = Felt>> = CudaConstraintCommitment<E, H>;
80+
type TraceLde<E: FieldElement<BaseField = Felt>> = CudaTraceLde<'g, E, H>;
81+
type ConstraintCommitment<E: FieldElement<BaseField = Felt>> = CudaConstraintCommitment<'g, E, H>;
7582
type ConstraintEvaluator<'a, E: FieldElement<BaseField = Felt>> =
7683
DefaultConstraintEvaluator<'a, ProcessorAir, E>;
7784

@@ -90,7 +97,7 @@ where
9097
domain: &StarkDomain<Felt>,
9198
partition_options: PartitionOptions,
9299
) -> (Self::TraceLde<E>, TracePolyTable<E>) {
93-
CudaTraceLde::new(trace_info, main_trace, domain, partition_options, self.hash_fn)
100+
CudaTraceLde::new(self.main.take(), self.aux.take(), trace_info, main_trace, domain, partition_options, self.hash_fn)
94101
}
95102

96103
fn new_evaluator<'a, E: FieldElement<BaseField = Felt>>(
@@ -125,6 +132,7 @@ where
125132
E: FieldElement<BaseField = Self::BaseField>,
126133
{
127134
CudaConstraintCommitment::new(
135+
self.ce.take(),
128136
composition_poly_trace,
129137
num_constraint_composition_columns,
130138
domain,

prover/src/lib.rs

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ extern crate std;
88

99
use core::marker::PhantomData;
1010

11-
use air::{AuxRandElements, PartitionOptions, ProcessorAir, PublicInputs};
11+
use air::{trace::{AUX_TRACE_WIDTH, TRACE_WIDTH}, AuxRandElements, PartitionOptions, ProcessorAir, PublicInputs};
12+
#[cfg(all(target_arch = "x86_64", feature = "cuda"))]
13+
use miden_gpu::cuda::util::{struct_size, CudaStorageOwned};
1214
#[cfg(any(
1315
all(feature = "metal", target_arch = "aarch64", target_os = "macos"),
1416
all(feature = "cuda", target_arch = "x86_64")
@@ -20,7 +22,7 @@ use processor::{
2022
RpxRandomCoin, WinterRandomCoin,
2123
},
2224
math::{Felt, FieldElement},
23-
ExecutionTrace, Program,
25+
ExecutionTrace, Program, QuadExtension,
2426
};
2527
use tracing::instrument;
2628
use winter_maybe_async::{maybe_async, maybe_await};
@@ -48,6 +50,37 @@ pub use winter_prover::{crypto::MerkleTree as MerkleTreeVC, Proof};
4850
// PROVER
4951
// ================================================================================================
5052

53+
#[cfg(all(feature = "cuda", target_arch = "x86_64"))]
54+
#[instrument("allocate_memory", skip_all)]
55+
fn allocate_memory(trace: &ExecutionTrace, options: &ProvingOptions) -> CudaStorageOwned {
56+
use winter_prover::{math::fields::CubeExtension, Air};
57+
58+
let main_columns = TRACE_WIDTH;
59+
let aux_columns = AUX_TRACE_WIDTH;
60+
let rows = trace.get_trace_len();
61+
let options: WinterProofOptions = options.clone().into();
62+
let extension = options.field_extension();
63+
let blowup = options.blowup_factor();
64+
let partitions = options.partition_options();
65+
66+
let main = struct_size::<Felt>(main_columns, rows, blowup, partitions);
67+
let aux = match extension {
68+
FieldExtension::None => struct_size::<Felt>(aux_columns, rows, blowup, partitions),
69+
FieldExtension::Quadratic => struct_size::<QuadExtension<Felt>>(aux_columns, rows, blowup, partitions),
70+
FieldExtension::Cubic => struct_size::<CubeExtension<Felt>>(aux_columns, rows, blowup, partitions),
71+
};
72+
73+
let air = ProcessorAir::new(trace.info().clone(), PublicInputs::new(Default::default(), Default::default(), Default::default()), options);
74+
let ce_columns = air.context().num_constraint_composition_columns();
75+
let ce = match extension {
76+
FieldExtension::None => struct_size::<Felt>(ce_columns, rows, blowup, partitions),
77+
FieldExtension::Quadratic => struct_size::<QuadExtension<Felt>>(ce_columns, rows, blowup, partitions),
78+
FieldExtension::Cubic => struct_size::<CubeExtension<Felt>>(ce_columns, rows, blowup, partitions),
79+
};
80+
81+
CudaStorageOwned::new(main, aux, ce)
82+
}
83+
5184
/// Executes and proves the specified `program` and returns the result together with a STARK-based
5285
/// proof of the program's execution.
5386
///
@@ -84,6 +117,11 @@ pub fn prove(
84117
let stack_outputs = trace.stack_outputs().clone();
85118
let hash_fn = options.hash_fn();
86119

120+
#[cfg(all(feature = "cuda", target_arch = "x86_64"))]
121+
let mut storage = allocate_memory(&trace, &options);
122+
#[cfg(all(feature = "cuda", target_arch = "x86_64"))]
123+
let (main, aux, ce) = storage.borrow_mut();
124+
87125
// generate STARK proof
88126
let proof = match hash_fn {
89127
HashFunction::Blake3_192 => {
@@ -111,7 +149,7 @@ pub fn prove(
111149
#[cfg(all(feature = "metal", target_arch = "aarch64", target_os = "macos"))]
112150
let prover = gpu::metal::MetalExecutionProver::new(prover, HashFn::Rpo256);
113151
#[cfg(all(feature = "cuda", target_arch = "x86_64"))]
114-
let prover = gpu::cuda::CudaExecutionProver::new(prover, HashFn::Rpo256);
152+
let prover = gpu::cuda::CudaExecutionProver::new(prover, HashFn::Rpo256, main, aux, ce);
115153
maybe_await!(prover.prove(trace))
116154
},
117155
HashFunction::Rpx256 => {
@@ -123,7 +161,7 @@ pub fn prove(
123161
#[cfg(all(feature = "metal", target_arch = "aarch64", target_os = "macos"))]
124162
let prover = gpu::metal::MetalExecutionProver::new(prover, HashFn::Rpx256);
125163
#[cfg(all(feature = "cuda", target_arch = "x86_64"))]
126-
let prover = gpu::cuda::CudaExecutionProver::new(prover, HashFn::Rpx256);
164+
let prover = gpu::cuda::CudaExecutionProver::new(prover, HashFn::Rpx256, main, aux, ce);
127165
maybe_await!(prover.prove(trace))
128166
},
129167
}

0 commit comments

Comments
 (0)