1use alloy_primitives::{Address, B256, U256};
2use revm::Database;
3
4use arb_storage::{Storage, StorageBackedAddress, StorageBackedUint64};
5
/// A set of addresses persisted in contract storage.
///
/// Storage layout, as established by `open_address_set` and `add`:
/// - `backing_storage` slot 0 holds the member count (wrapped by `size`);
/// - `backing_storage` slots `1..=size` hold the members (unordered: `remove`
///   moves the last member into the vacated slot);
/// - `by_address` maps each member's left-padded 32-byte address word to its
///   1-based slot index, encoded as a big-endian word. A zero value means
///   "not a member".
pub struct AddressSet<D> {
    // Root storage: slot 0 is the size, slots 1..=size are the member list.
    backing_storage: Storage<D>,
    // Typed view of backing_storage slot 0 (the member count).
    size: StorageBackedUint64<D>,
    // Sub-storage opened with key `[0]`: address word -> slot index.
    by_address: Storage<D>,
}
15
16pub fn initialize_address_set<D: Database>(sto: &Storage<D>) -> Result<(), ()> {
17 sto.set_by_uint64(0, B256::ZERO)
18}
19
20pub fn open_address_set<D: Database>(sto: Storage<D>) -> AddressSet<D> {
21 let size = StorageBackedUint64::new(sto.state_ptr(), sto.base_key(), 0);
22 let by_address = sto.open_sub_storage(&[0u8]);
23 AddressSet {
24 backing_storage: sto,
25 size,
26 by_address,
27 }
28}
29
30impl<D: Database> AddressSet<D> {
31 pub fn size(&self) -> Result<u64, ()> {
32 self.size.get()
33 }
34
35 pub fn is_member(&self, addr: Address) -> Result<bool, ()> {
36 let addr_hash = address_to_hash(addr);
37 let value = self.by_address.get(addr_hash)?;
38 Ok(value != B256::ZERO)
39 }
40
41 pub fn get_any_member(&self) -> Result<Option<Address>, ()> {
42 let size = self.size.get()?;
43 if size == 0 {
44 return Ok(None);
45 }
46 let sba = StorageBackedAddress::new(
47 self.backing_storage.state_ptr(),
48 self.backing_storage.base_key(),
49 1,
50 );
51 sba.get().map(Some)
52 }
53
54 pub fn clear(&self) -> Result<(), ()> {
55 let size = self.size.get()?;
56 if size == 0 {
57 return Ok(());
58 }
59 for i in 1..=size {
60 let contents = self.backing_storage.get_by_uint64(i)?;
61 self.backing_storage.set_by_uint64(i, B256::ZERO)?;
62 self.by_address.set(contents, B256::ZERO)?;
63 }
64 self.size.set(0)
65 }
66
67 pub fn all_members(&self, max_num: u64) -> Result<Vec<Address>, ()> {
68 let mut size = self.size.get()?;
69 if size > max_num {
70 size = max_num;
71 }
72 let mut ret = Vec::with_capacity(size as usize);
73 for i in 0..size {
74 let sba = StorageBackedAddress::new(
75 self.backing_storage.state_ptr(),
76 self.backing_storage.base_key(),
77 i + 1,
78 );
79 ret.push(sba.get()?);
80 }
81 Ok(ret)
82 }
83
84 pub fn clear_list(&self) -> Result<(), ()> {
85 let size = self.size.get()?;
86 if size == 0 {
87 return Ok(());
88 }
89 for i in 1..=size {
90 self.backing_storage.set_by_uint64(i, B256::ZERO)?;
91 }
92 self.size.set(0)
93 }
94
95 pub fn rectify_mapping(&self, addr: Address) -> Result<(), ()> {
96 let is_owner = self.is_member(addr)?;
97 if !is_owner {
98 return Err(());
99 }
100
101 let addr_as_hash = address_to_hash(addr);
102 let slot = hash_to_uint64(self.by_address.get(addr_as_hash)?);
103 let at_slot = self.backing_storage.get_by_uint64(slot)?;
104 let size = self.size.get()?;
105
106 if at_slot == addr_as_hash && slot <= size {
107 return Err(());
108 }
109
110 self.by_address.set(addr_as_hash, B256::ZERO)?;
111 self.add(addr)
112 }
113
114 pub fn add(&self, addr: Address) -> Result<(), ()> {
115 let present = self.is_member(addr)?;
116 if present {
117 return Ok(());
118 }
119
120 let size = self.size.get()?;
121 let slot = uint_to_hash(1 + size);
122 let addr_as_hash = address_to_hash(addr);
123
124 self.by_address.set(addr_as_hash, slot)?;
125
126 let sba = StorageBackedAddress::new(
127 self.backing_storage.state_ptr(),
128 self.backing_storage.base_key(),
129 1 + size,
130 );
131 sba.set(addr)?;
132
133 self.size.set(size + 1)
134 }
135
136 pub fn remove(&self, addr: Address, arbos_version: u64) -> Result<(), ()> {
137 let addr_as_hash = address_to_hash(addr);
138 let slot_hash = self.by_address.get(addr_as_hash)?;
139 let slot = hash_to_uint64(slot_hash);
140
141 if slot == 0 {
142 return Ok(());
143 }
144
145 self.by_address.set(addr_as_hash, B256::ZERO)?;
146
147 let size = self.size.get()?;
148 if slot < size {
149 let at_size = self.backing_storage.get_by_uint64(size)?;
150 self.backing_storage.set_by_uint64(slot, at_size)?;
151
152 if arbos_version >= 11 {
153 self.by_address.set(at_size, uint_to_hash(slot))?;
154 }
155 }
156
157 self.backing_storage.set_by_uint64(size, B256::ZERO)?;
158 self.size.set(size - 1)
159 }
160}
161
162fn address_to_hash(addr: Address) -> B256 {
163 let mut bytes = [0u8; 32];
164 bytes[12..32].copy_from_slice(addr.as_slice());
165 B256::from(bytes)
166}
167
168fn uint_to_hash(val: u64) -> B256 {
169 B256::from(U256::from(val))
170}
171
172fn hash_to_uint64(hash: B256) -> u64 {
173 U256::from_be_bytes(hash.0).to::<u64>()
174}