@@ -8,7 +8,9 @@ extern crate std;
8
8
9
9
use core:: marker:: PhantomData ;
10
10
11
- use air:: { AuxRandElements , PartitionOptions , ProcessorAir , PublicInputs } ;
11
+ use air:: { trace:: { AUX_TRACE_WIDTH , TRACE_WIDTH } , AuxRandElements , PartitionOptions , ProcessorAir , PublicInputs } ;
12
+ #[ cfg( all( target_arch = "x86_64" , feature = "cuda" ) ) ]
13
+ use miden_gpu:: cuda:: util:: { struct_size, CudaStorageOwned } ;
12
14
#[ cfg( any(
13
15
all( feature = "metal" , target_arch = "aarch64" , target_os = "macos" ) ,
14
16
all( feature = "cuda" , target_arch = "x86_64" )
@@ -20,7 +22,7 @@ use processor::{
20
22
RpxRandomCoin , WinterRandomCoin ,
21
23
} ,
22
24
math:: { Felt , FieldElement } ,
23
- ExecutionTrace , Program ,
25
+ ExecutionTrace , Program , QuadExtension ,
24
26
} ;
25
27
use tracing:: instrument;
26
28
use winter_maybe_async:: { maybe_async, maybe_await} ;
@@ -48,6 +50,37 @@ pub use winter_prover::{crypto::MerkleTree as MerkleTreeVC, Proof};
48
50
// PROVER
49
51
// ================================================================================================
50
52
53
+ #[ cfg( all( feature = "cuda" , target_arch = "x86_64" ) ) ]
54
+ #[ instrument( "allocate_memory" , skip_all) ]
55
+ fn allocate_memory ( trace : & ExecutionTrace , options : & ProvingOptions ) -> CudaStorageOwned {
56
+ use winter_prover:: { math:: fields:: CubeExtension , Air } ;
57
+
58
+ let main_columns = TRACE_WIDTH ;
59
+ let aux_columns = AUX_TRACE_WIDTH ;
60
+ let rows = trace. get_trace_len ( ) ;
61
+ let options: WinterProofOptions = options. clone ( ) . into ( ) ;
62
+ let extension = options. field_extension ( ) ;
63
+ let blowup = options. blowup_factor ( ) ;
64
+ let partitions = options. partition_options ( ) ;
65
+
66
+ let main = struct_size :: < Felt > ( main_columns, rows, blowup, partitions) ;
67
+ let aux = match extension {
68
+ FieldExtension :: None => struct_size :: < Felt > ( aux_columns, rows, blowup, partitions) ,
69
+ FieldExtension :: Quadratic => struct_size :: < QuadExtension < Felt > > ( aux_columns, rows, blowup, partitions) ,
70
+ FieldExtension :: Cubic => struct_size :: < CubeExtension < Felt > > ( aux_columns, rows, blowup, partitions) ,
71
+ } ;
72
+
73
+ let air = ProcessorAir :: new ( trace. info ( ) . clone ( ) , PublicInputs :: new ( Default :: default ( ) , Default :: default ( ) , Default :: default ( ) ) , options) ;
74
+ let ce_columns = air. context ( ) . num_constraint_composition_columns ( ) ;
75
+ let ce = match extension {
76
+ FieldExtension :: None => struct_size :: < Felt > ( ce_columns, rows, blowup, partitions) ,
77
+ FieldExtension :: Quadratic => struct_size :: < QuadExtension < Felt > > ( ce_columns, rows, blowup, partitions) ,
78
+ FieldExtension :: Cubic => struct_size :: < CubeExtension < Felt > > ( ce_columns, rows, blowup, partitions) ,
79
+ } ;
80
+
81
+ CudaStorageOwned :: new ( main, aux, ce)
82
+ }
83
+
51
84
/// Executes and proves the specified `program` and returns the result together with a STARK-based
52
85
/// proof of the program's execution.
53
86
///
@@ -84,6 +117,11 @@ pub fn prove(
84
117
let stack_outputs = trace. stack_outputs ( ) . clone ( ) ;
85
118
let hash_fn = options. hash_fn ( ) ;
86
119
120
+ #[ cfg( all( feature = "cuda" , target_arch = "x86_64" ) ) ]
121
+ let mut storage = allocate_memory ( & trace, & options) ;
122
+ #[ cfg( all( feature = "cuda" , target_arch = "x86_64" ) ) ]
123
+ let ( main, aux, ce) = storage. borrow_mut ( ) ;
124
+
87
125
// generate STARK proof
88
126
let proof = match hash_fn {
89
127
HashFunction :: Blake3_192 => {
@@ -111,7 +149,7 @@ pub fn prove(
111
149
#[ cfg( all( feature = "metal" , target_arch = "aarch64" , target_os = "macos" ) ) ]
112
150
let prover = gpu:: metal:: MetalExecutionProver :: new ( prover, HashFn :: Rpo256 ) ;
113
151
#[ cfg( all( feature = "cuda" , target_arch = "x86_64" ) ) ]
114
- let prover = gpu:: cuda:: CudaExecutionProver :: new ( prover, HashFn :: Rpo256 ) ;
152
+ let prover = gpu:: cuda:: CudaExecutionProver :: new ( prover, HashFn :: Rpo256 , main , aux , ce ) ;
115
153
maybe_await ! ( prover. prove( trace) )
116
154
} ,
117
155
HashFunction :: Rpx256 => {
@@ -123,7 +161,7 @@ pub fn prove(
123
161
#[ cfg( all( feature = "metal" , target_arch = "aarch64" , target_os = "macos" ) ) ]
124
162
let prover = gpu:: metal:: MetalExecutionProver :: new ( prover, HashFn :: Rpx256 ) ;
125
163
#[ cfg( all( feature = "cuda" , target_arch = "x86_64" ) ) ]
126
- let prover = gpu:: cuda:: CudaExecutionProver :: new ( prover, HashFn :: Rpx256 ) ;
164
+ let prover = gpu:: cuda:: CudaExecutionProver :: new ( prover, HashFn :: Rpx256 , main , aux , ce ) ;
127
165
maybe_await ! ( prover. prove( trace) )
128
166
} ,
129
167
}
0 commit comments