|  | 
|  | 1 | +use crate::gas::stats::GasStats; | 
|  | 2 | +use cheatnet::trace_data::{CallTrace, CallTraceNode}; | 
|  | 3 | +use debugging::ContractName as DebuggingContractName; | 
|  | 4 | +use debugging::ContractsDataStore; | 
|  | 5 | +use starknet_api::core::{ClassHash, EntryPointSelector}; | 
|  | 6 | +use starknet_api::execution_resources::GasVector; | 
|  | 7 | +use std::collections::BTreeMap; | 
|  | 8 | + | 
|  | 9 | +type ContractName = String; | 
|  | 10 | +type Selector = String; | 
|  | 11 | + | 
|  | 12 | +#[derive(Debug, Clone)] | 
|  | 13 | +pub struct GasSingleTestInfo { | 
|  | 14 | +    pub gas_used: GasVector, | 
|  | 15 | +    pub report_data: BTreeMap<ContractName, ContractInfo>, | 
|  | 16 | +} | 
|  | 17 | + | 
|  | 18 | +#[derive(Debug, Clone, Default)] | 
|  | 19 | +pub struct ContractInfo { | 
|  | 20 | +    pub gas_used: GasVector, | 
|  | 21 | +    pub functions: BTreeMap<Selector, SelectorReportData>, | 
|  | 22 | +} | 
|  | 23 | + | 
|  | 24 | +#[derive(Debug, Clone, Default)] | 
|  | 25 | +pub struct SelectorReportData { | 
|  | 26 | +    pub gas_stats: GasStats, | 
|  | 27 | +    pub n_calls: u64, | 
|  | 28 | +    pub records: Vec<u64>, | 
|  | 29 | +} | 
|  | 30 | + | 
|  | 31 | +impl GasSingleTestInfo { | 
|  | 32 | +    #[must_use] | 
|  | 33 | +    pub fn new(gas_used: GasVector) -> Self { | 
|  | 34 | +        Self { | 
|  | 35 | +            gas_used, | 
|  | 36 | +            report_data: BTreeMap::new(), | 
|  | 37 | +        } | 
|  | 38 | +    } | 
|  | 39 | + | 
|  | 40 | +    #[must_use] | 
|  | 41 | +    pub fn new_with_report( | 
|  | 42 | +        gas_used: GasVector, | 
|  | 43 | +        call_trace: &CallTrace, | 
|  | 44 | +        contracts_data: &ContractsDataStore, | 
|  | 45 | +    ) -> Self { | 
|  | 46 | +        Self { | 
|  | 47 | +            gas_used, | 
|  | 48 | +            report_data: BTreeMap::new(), | 
|  | 49 | +        } | 
|  | 50 | +        .collect_gas_data(call_trace, contracts_data) | 
|  | 51 | +    } | 
|  | 52 | + | 
|  | 53 | +    fn collect_gas_data(mut self, trace: &CallTrace, contracts_data: &ContractsDataStore) -> Self { | 
|  | 54 | +        let mut stack = trace.nested_calls.clone(); | 
|  | 55 | + | 
|  | 56 | +        while let Some(call_trace_node) = stack.pop() { | 
|  | 57 | +            if let CallTraceNode::EntryPointCall(call) = call_trace_node { | 
|  | 58 | +                let call = call.borrow(); | 
|  | 59 | +                let class_hash = call.entry_point.class_hash.expect( | 
|  | 60 | +                    "class_hash should be set in `fn execute_call_entry_point` in cheatnet", | 
|  | 61 | +                ); | 
|  | 62 | + | 
|  | 63 | +                let contract_name = get_contract_name(contracts_data, class_hash); | 
|  | 64 | +                let selector = get_selector(contracts_data, call.entry_point.entry_point_selector); | 
|  | 65 | +                let gas = call | 
|  | 66 | +                    .gas_report_data | 
|  | 67 | +                    .as_ref() | 
|  | 68 | +                    .expect("Gas report data must be updated after test execution") | 
|  | 69 | +                    .get_gas(); | 
|  | 70 | + | 
|  | 71 | +                self.update_entry(contract_name, selector, gas); | 
|  | 72 | +                stack.extend(call.nested_calls.clone()); | 
|  | 73 | +            } | 
|  | 74 | +        } | 
|  | 75 | +        self.finalize(); | 
|  | 76 | +        self | 
|  | 77 | +    } | 
|  | 78 | + | 
|  | 79 | +    fn update_entry( | 
|  | 80 | +        &mut self, | 
|  | 81 | +        contract_name: ContractName, | 
|  | 82 | +        selector: Selector, | 
|  | 83 | +        gas_used: GasVector, | 
|  | 84 | +    ) { | 
|  | 85 | +        let contract_info = self.report_data.entry(contract_name).or_default(); | 
|  | 86 | + | 
|  | 87 | +        if let Some(gas) = contract_info.gas_used.checked_add(gas_used) { | 
|  | 88 | +            contract_info.gas_used = gas; | 
|  | 89 | +        } | 
|  | 90 | + | 
|  | 91 | +        let entry = contract_info.functions.entry(selector).or_default(); | 
|  | 92 | +        entry.records.push(gas_used.l2_gas.0); | 
|  | 93 | +        entry.n_calls += 1; | 
|  | 94 | +    } | 
|  | 95 | + | 
|  | 96 | +    fn finalize(&mut self) { | 
|  | 97 | +        for contract_info in self.report_data.values_mut() { | 
|  | 98 | +            for gas_info in contract_info.functions.values_mut() { | 
|  | 99 | +                gas_info.gas_stats = GasStats::new(&gas_info.records); | 
|  | 100 | +            } | 
|  | 101 | +        } | 
|  | 102 | +    } | 
|  | 103 | +} | 
|  | 104 | + | 
|  | 105 | +fn get_contract_name(contracts_data: &ContractsDataStore, class_hash: ClassHash) -> ContractName { | 
|  | 106 | +    contracts_data | 
|  | 107 | +        .get_contract_name(&class_hash) | 
|  | 108 | +        .cloned() | 
|  | 109 | +        .unwrap_or_else(|| DebuggingContractName("forked contract".to_string())) | 
|  | 110 | +        .0 | 
|  | 111 | +        .clone() | 
|  | 112 | +} | 
|  | 113 | + | 
|  | 114 | +fn get_selector(contracts_data: &ContractsDataStore, selector: EntryPointSelector) -> Selector { | 
|  | 115 | +    contracts_data | 
|  | 116 | +        .get_selector(&selector) | 
|  | 117 | +        .expect("`Selector` should be present") | 
|  | 118 | +        .0 | 
|  | 119 | +        .clone() | 
|  | 120 | +} | 
0 commit comments