arbos/merkle_accumulator/
mod.rs

1use alloy_primitives::{keccak256, B256};
2use revm::Database;
3
4use arb_storage::{Storage, StorageBackedUint64};
5
6/// Event emitted when a Merkle tree node is updated during append.
7#[derive(Debug, Clone)]
8pub struct MerkleTreeNodeEvent {
9    pub level: u64,
10    pub num_leaves: u64,
11    pub hash: B256,
12}
13
14/// Storage-backed Merkle accumulator.
15pub struct MerkleAccumulator<D> {
16    backing_storage: Storage<D>,
17    size: StorageBackedUint64<D>,
18}
19
20pub fn initialize_merkle_accumulator<D: Database>(_sto: &Storage<D>) {
21    // no-op
22}
23
24pub fn open_merkle_accumulator<D: Database>(sto: Storage<D>) -> MerkleAccumulator<D> {
25    let size = StorageBackedUint64::new(sto.state_ptr(), sto.base_key(), 0);
26    MerkleAccumulator {
27        backing_storage: sto,
28        size,
29    }
30}
31
/// Number of partial-hash levels an accumulator holding `size` leaves uses.
///
/// Equals the bit length of `size`: `0` for an empty accumulator, otherwise
/// `floor(log2(size)) + 1`.
pub fn calc_num_partials(size: u64) -> u64 {
    // `leading_zeros` reports 64 for zero, so the empty case falls out of the
    // subtraction without a dedicated branch.
    u64::from(u64::BITS - size.leading_zeros())
}
40
41impl<D: Database> MerkleAccumulator<D> {
42    fn get_partial(&self, level: u64) -> Result<B256, ()> {
43        self.backing_storage.get_by_uint64(2 + level)
44    }
45
46    fn set_partial(&self, level: u64, val: B256) -> Result<(), ()> {
47        self.backing_storage.set_by_uint64(2 + level, val)
48    }
49
50    pub fn append(&self, item_hash: B256) -> Result<Vec<MerkleTreeNodeEvent>, ()> {
51        let current_size = self.size.get()?;
52        let new_size = current_size + 1;
53        self.size.set(new_size)?;
54
55        let mut events = Vec::new();
56        let mut level = 0u64;
57        let mut so_far = keccak256(item_hash.as_slice());
58
59        loop {
60            if level == calc_num_partials(current_size) {
61                self.set_partial(level, so_far)?;
62                return Ok(events);
63            }
64
65            let this_level = self.get_partial(level)?;
66            if this_level == B256::ZERO {
67                self.set_partial(level, so_far)?;
68                return Ok(events);
69            }
70
71            let mut combined = Vec::with_capacity(64);
72            combined.extend_from_slice(this_level.as_slice());
73            combined.extend_from_slice(so_far.as_slice());
74            so_far = keccak256(&combined);
75
76            self.set_partial(level, B256::ZERO)?;
77
78            level += 1;
79            events.push(MerkleTreeNodeEvent {
80                level,
81                num_leaves: new_size - 1,
82                hash: so_far,
83            });
84        }
85    }
86
87    pub fn size(&self) -> Result<u64, ()> {
88        self.size.get()
89    }
90
91    pub fn root(&self) -> Result<B256, ()> {
92        let size = self.size.get()?;
93        if size == 0 {
94            return Ok(B256::ZERO);
95        }
96
97        let mut hash_so_far: Option<B256> = None;
98        let mut capacity_in_hash = 0u64;
99        let mut capacity = 1u64;
100
101        for level in 0..calc_num_partials(size) {
102            let partial = self.get_partial(level)?;
103            if partial != B256::ZERO {
104                if let Some(ref mut current) = hash_so_far {
105                    while capacity_in_hash < capacity {
106                        let mut combined = Vec::with_capacity(64);
107                        combined.extend_from_slice(current.as_slice());
108                        combined.extend_from_slice(&[0u8; 32]);
109                        *current = keccak256(&combined);
110                        capacity_in_hash *= 2;
111                    }
112
113                    let mut combined = Vec::with_capacity(64);
114                    combined.extend_from_slice(partial.as_slice());
115                    combined.extend_from_slice(current.as_slice());
116                    *current = keccak256(&combined);
117                    capacity_in_hash = 2 * capacity;
118                } else {
119                    hash_so_far = Some(partial);
120                    capacity_in_hash = capacity;
121                }
122            }
123            capacity *= 2;
124        }
125
126        Ok(hash_so_far.unwrap_or(B256::ZERO))
127    }
128
129    pub fn get_partials(&self) -> Result<Vec<B256>, ()> {
130        let size = self.size.get()?;
131        let num = calc_num_partials(size);
132        let mut partials = Vec::with_capacity(num as usize);
133        for i in 0..num {
134            partials.push(self.get_partial(i)?);
135        }
136        Ok(partials)
137    }
138
139    pub fn state_for_export(&self) -> Result<(u64, B256, Vec<B256>), ()> {
140        let root = self.root()?;
141        let size = self.size.get()?;
142        let partials = self.get_partials()?;
143        Ok((size, root, partials))
144    }
145}
146
147/// In-memory (non-persistent) Merkle accumulator for export/import and testing.
148pub struct InMemoryMerkleAccumulator {
149    size: u64,
150    partials: Vec<B256>,
151}
152
153impl InMemoryMerkleAccumulator {
154    pub fn new() -> Self {
155        Self {
156            size: 0,
157            partials: Vec::new(),
158        }
159    }
160
161    pub fn from_partials(partials: Vec<B256>) -> Self {
162        let mut size = 0u64;
163        let mut level_size = 1u64;
164        for p in &partials {
165            if *p != B256::ZERO {
166                size += level_size;
167            }
168            level_size *= 2;
169        }
170        Self { size, partials }
171    }
172
173    pub fn size(&self) -> u64 {
174        self.size
175    }
176
177    fn get_partial(&self, level: u64) -> B256 {
178        self.partials
179            .get(level as usize)
180            .copied()
181            .unwrap_or(B256::ZERO)
182    }
183
184    fn set_partial(&mut self, level: u64, val: B256) {
185        let idx = level as usize;
186        if idx >= self.partials.len() {
187            self.partials.resize(idx + 1, B256::ZERO);
188        }
189        self.partials[idx] = val;
190    }
191
192    pub fn append(&mut self, item_hash: B256) -> Vec<MerkleTreeNodeEvent> {
193        let current_size = self.size;
194        self.size += 1;
195        let new_size = self.size;
196
197        let mut events = Vec::new();
198        let mut level = 0u64;
199        let mut so_far = keccak256(item_hash.as_slice());
200
201        loop {
202            if level == calc_num_partials(current_size) {
203                self.set_partial(level, so_far);
204                return events;
205            }
206
207            let this_level = self.get_partial(level);
208            if this_level == B256::ZERO {
209                self.set_partial(level, so_far);
210                return events;
211            }
212
213            let mut combined = Vec::with_capacity(64);
214            combined.extend_from_slice(this_level.as_slice());
215            combined.extend_from_slice(so_far.as_slice());
216            so_far = keccak256(&combined);
217
218            self.set_partial(level, B256::ZERO);
219
220            level += 1;
221            events.push(MerkleTreeNodeEvent {
222                level,
223                num_leaves: new_size - 1,
224                hash: so_far,
225            });
226        }
227    }
228
229    pub fn root(&self) -> B256 {
230        if self.size == 0 {
231            return B256::ZERO;
232        }
233
234        let mut hash_so_far: Option<B256> = None;
235        let mut capacity_in_hash = 0u64;
236        let mut capacity = 1u64;
237
238        for level in 0..calc_num_partials(self.size) {
239            let partial = self.get_partial(level);
240            if partial != B256::ZERO {
241                if let Some(ref mut current) = hash_so_far {
242                    while capacity_in_hash < capacity {
243                        let mut combined = Vec::with_capacity(64);
244                        combined.extend_from_slice(current.as_slice());
245                        combined.extend_from_slice(&[0u8; 32]);
246                        *current = keccak256(&combined);
247                        capacity_in_hash *= 2;
248                    }
249
250                    let mut combined = Vec::with_capacity(64);
251                    combined.extend_from_slice(partial.as_slice());
252                    combined.extend_from_slice(current.as_slice());
253                    *current = keccak256(&combined);
254                    capacity_in_hash = 2 * capacity;
255                } else {
256                    hash_so_far = Some(partial);
257                    capacity_in_hash = capacity;
258                }
259            }
260            capacity *= 2;
261        }
262
263        hash_so_far.unwrap_or(B256::ZERO)
264    }
265
266    pub fn partials(&self) -> &[B256] {
267        &self.partials
268    }
269}
270
271impl Default for InMemoryMerkleAccumulator {
272    fn default() -> Self {
273        Self::new()
274    }
275}