arb_stylus/
middleware.rs

1use std::{collections::HashMap, sync::RwLock};
2
3use wasmer_compiler::{FunctionMiddleware, MiddlewareReaderState, ModuleMiddleware};
4use wasmer_types::{
5    FunctionIndex, GlobalIndex, GlobalInit, ImportIndex, LocalFunctionIndex, MiddlewareError,
6    ModuleInfo, Type,
7};
8use wasmparser::{BlockType, Operator};
9
10use crate::meter::{STYLUS_INK_LEFT, STYLUS_INK_STATUS, STYLUS_STACK_LEFT};
11
12const SCRATCH_GLOBAL: &str = "stylus_scratch_global";
13
14fn mw_err(msg: impl Into<String>) -> MiddlewareError {
15    MiddlewareError::new("stylus", msg.into())
16}
17
18// ── InkMeter ────────────────────────────────────────────────────────
19
20#[derive(Debug)]
21pub struct InkMeter {
22    header_cost: u64,
23    globals: RwLock<Option<[GlobalIndex; 2]>>,
24    sigs: RwLock<HashMap<u32, usize>>,
25}
26
27impl InkMeter {
28    pub fn new(header_cost: u64) -> Self {
29        Self {
30            header_cost,
31            globals: RwLock::new(None),
32            sigs: RwLock::new(HashMap::new()),
33        }
34    }
35
36    fn globals(&self) -> [GlobalIndex; 2] {
37        self.globals
38            .read()
39            .expect("ink globals lock poisoned")
40            .expect("missing ink globals")
41    }
42}
43
44impl ModuleMiddleware for InkMeter {
45    fn transform_module_info(&self, info: &mut ModuleInfo) -> Result<(), MiddlewareError> {
46        let ink_ty = wasmer_types::GlobalType::new(Type::I64, wasmer_types::Mutability::Var);
47        let status_ty = wasmer_types::GlobalType::new(Type::I32, wasmer_types::Mutability::Var);
48
49        let ink_idx = info.globals.push(ink_ty);
50        let status_idx = info.globals.push(status_ty);
51        info.global_initializers.push(GlobalInit::I64Const(0));
52        info.global_initializers.push(GlobalInit::I32Const(0));
53
54        info.exports.insert(
55            STYLUS_INK_LEFT.to_string(),
56            wasmer_types::ExportIndex::Global(ink_idx),
57        );
58        info.exports.insert(
59            STYLUS_INK_STATUS.to_string(),
60            wasmer_types::ExportIndex::Global(status_idx),
61        );
62
63        let mut sig_map = self.sigs.write().expect("ink sigs lock poisoned");
64        for (sig_idx, sig) in info.signatures.iter() {
65            sig_map.insert(sig_idx.as_u32(), sig.params().len());
66        }
67
68        *self.globals.write().expect("ink globals lock poisoned") = Some([ink_idx, status_idx]);
69        Ok(())
70    }
71
72    fn generate_function_middleware(&self, _: LocalFunctionIndex) -> Box<dyn FunctionMiddleware> {
73        let [ink, status] = self.globals();
74        let sigs = self.sigs.read().expect("ink sigs lock poisoned").clone();
75        Box::new(InkMeterFn {
76            ink_global: ink,
77            status_global: status,
78            block: vec![],
79            block_cost: 0,
80            header_cost: self.header_cost,
81            sigs,
82        })
83    }
84}
85
86#[derive(Debug)]
87struct InkMeterFn {
88    ink_global: GlobalIndex,
89    status_global: GlobalIndex,
90    block: Vec<Operator<'static>>,
91    block_cost: u64,
92    header_cost: u64,
93    sigs: HashMap<u32, usize>,
94}
95
96fn ends_basic_block(op: &Operator) -> bool {
97    use Operator::*;
98    matches!(
99        op,
100        End | Else
101            | Return
102            | Loop { .. }
103            | Br { .. }
104            | BrTable { .. }
105            | BrIf { .. }
106            | If { .. }
107            | Call { .. }
108            | CallIndirect { .. }
109    )
110}
111
112impl FunctionMiddleware for InkMeterFn {
113    fn feed<'a>(
114        &mut self,
115        op: Operator<'a>,
116        state: &mut MiddlewareReaderState<'a>,
117    ) -> Result<(), MiddlewareError> {
118        let end = ends_basic_block(&op);
119        let op_cost = opcode_ink_cost(&op, &self.sigs);
120        let mut cost = self.block_cost.saturating_add(op_cost);
121        self.block_cost = cost;
122
123        // SAFETY: Operator variants we support contain no borrowed data.
124        // We buffer them as 'static and transmute back when draining.
125        let op_static = unsafe { std::mem::transmute::<Operator<'a>, Operator<'static>>(op) };
126        self.block.push(op_static);
127
128        if end {
129            let ink = self.ink_global.as_u32();
130            let status = self.status_global.as_u32();
131            cost = cost.saturating_add(self.header_cost);
132
133            state.push_operator(Operator::GlobalGet { global_index: ink });
134            state.push_operator(Operator::I64Const { value: cost as i64 });
135            state.push_operator(Operator::I64LtU);
136            state.push_operator(Operator::If {
137                blockty: BlockType::Empty,
138            });
139            state.push_operator(Operator::I32Const { value: 1 });
140            state.push_operator(Operator::GlobalSet {
141                global_index: status,
142            });
143            state.push_operator(Operator::Unreachable);
144            state.push_operator(Operator::End);
145
146            state.push_operator(Operator::GlobalGet { global_index: ink });
147            state.push_operator(Operator::I64Const { value: cost as i64 });
148            state.push_operator(Operator::I64Sub);
149            state.push_operator(Operator::GlobalSet { global_index: ink });
150
151            for buffered in self.block.drain(..) {
152                let op_a =
153                    unsafe { std::mem::transmute::<Operator<'static>, Operator<'a>>(buffered) };
154                state.push_operator(op_a);
155            }
156            self.block_cost = 0;
157        }
158        Ok(())
159    }
160}
161
162// ── DynamicMeter ────────────────────────────────────────────────────
163
164#[derive(Debug)]
165pub struct DynamicMeter {
166    memory_fill_ink: u64,
167    memory_copy_ink: u64,
168    globals: RwLock<Option<[GlobalIndex; 3]>>,
169}
170
171impl DynamicMeter {
172    pub fn new(memory_fill_ink: u64, memory_copy_ink: u64) -> Self {
173        Self {
174            memory_fill_ink,
175            memory_copy_ink,
176            globals: RwLock::new(None),
177        }
178    }
179}
180
181impl ModuleMiddleware for DynamicMeter {
182    fn transform_module_info(&self, info: &mut ModuleInfo) -> Result<(), MiddlewareError> {
183        let ink_idx = info
184            .exports
185            .get(STYLUS_INK_LEFT)
186            .and_then(|e| match e {
187                wasmer_types::ExportIndex::Global(g) => Some(*g),
188                _ => None,
189            })
190            .ok_or_else(|| mw_err("ink global not found"))?;
191
192        let status_idx = info
193            .exports
194            .get(STYLUS_INK_STATUS)
195            .and_then(|e| match e {
196                wasmer_types::ExportIndex::Global(g) => Some(*g),
197                _ => None,
198            })
199            .ok_or_else(|| mw_err("ink status global not found"))?;
200
201        let scratch_ty = wasmer_types::GlobalType::new(Type::I32, wasmer_types::Mutability::Var);
202        let scratch_idx = info.globals.push(scratch_ty);
203        info.global_initializers.push(GlobalInit::I32Const(0));
204        info.exports.insert(
205            SCRATCH_GLOBAL.to_string(),
206            wasmer_types::ExportIndex::Global(scratch_idx),
207        );
208
209        *self.globals.write().expect("dynamic meter lock poisoned") =
210            Some([ink_idx, status_idx, scratch_idx]);
211        Ok(())
212    }
213
214    fn generate_function_middleware(&self, _: LocalFunctionIndex) -> Box<dyn FunctionMiddleware> {
215        let globals = self
216            .globals
217            .read()
218            .expect("dynamic meter lock poisoned")
219            .expect("missing dynamic globals");
220        Box::new(DynamicMeterFn {
221            memory_fill_ink: self.memory_fill_ink,
222            memory_copy_ink: self.memory_copy_ink,
223            globals,
224        })
225    }
226}
227
228#[derive(Debug)]
229struct DynamicMeterFn {
230    memory_fill_ink: u64,
231    memory_copy_ink: u64,
232    globals: [GlobalIndex; 3],
233}
234
235impl FunctionMiddleware for DynamicMeterFn {
236    fn feed<'a>(
237        &mut self,
238        op: Operator<'a>,
239        state: &mut MiddlewareReaderState<'a>,
240    ) -> Result<(), MiddlewareError> {
241        use Operator::*;
242
243        let [ink, status, scratch] = self.globals.map(|x| x.as_u32());
244        let blockty = BlockType::Empty;
245
246        let coefficient = match &op {
247            MemoryFill { .. } => Some(self.memory_fill_ink as i64),
248            MemoryCopy { .. } => Some(self.memory_copy_ink as i64),
249            _ => None,
250        };
251
252        if let Some(coeff) = coefficient {
253            // Stack has [dest, val/src, size]. Save size to scratch, compute cost,
254            // subtract from ink with overflow check, restore size.
255            state.extend([
256                GlobalSet {
257                    global_index: scratch,
258                },
259                GlobalGet { global_index: ink },
260                GlobalGet { global_index: ink },
261                GlobalGet {
262                    global_index: scratch,
263                },
264                I64ExtendI32U,
265                I64Const { value: coeff },
266                I64Mul,
267                I64Sub,
268                GlobalSet { global_index: ink },
269                GlobalGet { global_index: ink },
270                I64LtU,
271                If { blockty },
272                I32Const { value: 1 },
273                GlobalSet {
274                    global_index: status,
275                },
276                Unreachable,
277                End,
278                GlobalGet {
279                    global_index: scratch,
280                },
281            ]);
282        }
283
284        state.push_operator(op);
285        Ok(())
286    }
287}
288
289// ── DepthChecker ────────────────────────────────────────────────────
290
291#[derive(Debug)]
292pub struct DepthChecker {
293    max_depth: u32,
294    global: RwLock<Option<GlobalIndex>>,
295}
296
297impl DepthChecker {
298    pub fn new(max_depth: u32) -> Self {
299        Self {
300            max_depth,
301            global: RwLock::new(None),
302        }
303    }
304}
305
306impl ModuleMiddleware for DepthChecker {
307    fn transform_module_info(&self, info: &mut ModuleInfo) -> Result<(), MiddlewareError> {
308        let ty = wasmer_types::GlobalType::new(Type::I32, wasmer_types::Mutability::Var);
309        let idx = info.globals.push(ty);
310        info.global_initializers
311            .push(GlobalInit::I32Const(self.max_depth as i32));
312        info.exports.insert(
313            STYLUS_STACK_LEFT.to_string(),
314            wasmer_types::ExportIndex::Global(idx),
315        );
316        *self.global.write().expect("depth checker lock poisoned") = Some(idx);
317        Ok(())
318    }
319
320    fn generate_function_middleware(&self, _: LocalFunctionIndex) -> Box<dyn FunctionMiddleware> {
321        let g = self
322            .global
323            .read()
324            .expect("depth checker lock poisoned")
325            .expect("missing depth global");
326        Box::new(DepthCheckerFn {
327            global: g,
328            frame_cost: 1,
329            emitted_entry: false,
330        })
331    }
332}
333
334#[derive(Debug)]
335struct DepthCheckerFn {
336    global: GlobalIndex,
337    frame_cost: u32,
338    emitted_entry: bool,
339}
340
341impl FunctionMiddleware for DepthCheckerFn {
342    fn feed<'a>(
343        &mut self,
344        op: Operator<'a>,
345        state: &mut MiddlewareReaderState<'a>,
346    ) -> Result<(), MiddlewareError> {
347        if !self.emitted_entry {
348            self.emitted_entry = true;
349            let g = self.global.as_u32();
350            let cost = self.frame_cost as i32;
351
352            state.extend([
353                Operator::GlobalGet { global_index: g },
354                Operator::I32Const { value: cost },
355                Operator::I32LeU,
356                Operator::If {
357                    blockty: BlockType::Empty,
358                },
359                Operator::Unreachable,
360                Operator::End,
361                Operator::GlobalGet { global_index: g },
362                Operator::I32Const { value: cost },
363                Operator::I32Sub,
364                Operator::GlobalSet { global_index: g },
365            ]);
366        }
367
368        if matches!(op, Operator::Return) {
369            let g = self.global.as_u32();
370            state.extend([
371                Operator::GlobalGet { global_index: g },
372                Operator::I32Const {
373                    value: self.frame_cost as i32,
374                },
375                Operator::I32Add,
376                Operator::GlobalSet { global_index: g },
377            ]);
378        }
379
380        state.push_operator(op);
381        Ok(())
382    }
383}
384
385// ── HeapBound ───────────────────────────────────────────────────────
386
387#[derive(Debug)]
388pub struct HeapBound {
389    globals: RwLock<Option<(GlobalIndex, Option<FunctionIndex>)>>,
390}
391
392impl HeapBound {
393    pub fn new() -> Self {
394        Self {
395            globals: RwLock::new(None),
396        }
397    }
398}
399
400impl ModuleMiddleware for HeapBound {
401    fn transform_module_info(&self, info: &mut ModuleInfo) -> Result<(), MiddlewareError> {
402        let scratch_idx = info
403            .exports
404            .get(SCRATCH_GLOBAL)
405            .and_then(|e| match e {
406                wasmer_types::ExportIndex::Global(g) => Some(*g),
407                _ => None,
408            })
409            .ok_or_else(|| mw_err("scratch global not found"))?;
410
411        let pay_func = info.imports.iter().find_map(|(key, idx)| {
412            if key.field == "pay_for_memory_grow" {
413                if let ImportIndex::Function(f) = idx {
414                    return Some(*f);
415                }
416            }
417            None
418        });
419
420        *self.globals.write().expect("heap bound lock poisoned") = Some((scratch_idx, pay_func));
421        Ok(())
422    }
423
424    fn generate_function_middleware(&self, _: LocalFunctionIndex) -> Box<dyn FunctionMiddleware> {
425        let (scratch, pay_func) = self
426            .globals
427            .read()
428            .expect("heap bound lock poisoned")
429            .expect("missing heap globals");
430        Box::new(HeapBoundFn { scratch, pay_func })
431    }
432}
433
434#[derive(Debug)]
435struct HeapBoundFn {
436    scratch: GlobalIndex,
437    pay_func: Option<FunctionIndex>,
438}
439
440impl FunctionMiddleware for HeapBoundFn {
441    fn feed<'a>(
442        &mut self,
443        op: Operator<'a>,
444        state: &mut MiddlewareReaderState<'a>,
445    ) -> Result<(), MiddlewareError> {
446        if let (Operator::MemoryGrow { .. }, Some(pay)) = (&op, self.pay_func) {
447            let g = self.scratch.as_u32();
448            let f = pay.as_u32();
449            state.extend([
450                Operator::GlobalSet { global_index: g },
451                Operator::GlobalGet { global_index: g },
452                Operator::GlobalGet { global_index: g },
453                Operator::Call { function_index: f },
454            ]);
455        }
456        state.push_operator(op);
457        Ok(())
458    }
459}
460
461// ── Opcode ink costs (matches Nitro pricing_v1) ─────────────────────
462
463#[rustfmt::skip]
464fn opcode_ink_cost(op: &Operator, sigs: &HashMap<u32, usize>) -> u64 {
465    use Operator::*;
466
467    macro_rules! op {
468        ($first:ident $(,$opcode:ident)*) => { $first $(| $opcode)* };
469    }
470    macro_rules! dot {
471        ($first:ident $(,$opcode:ident)*) => { $first { .. } $(| $opcode { .. })* };
472    }
473
474    match op {
475        op!(Unreachable, Return) => 1,
476        op!(Nop) | dot!(I32Const, I64Const) => 1,
477        op!(Drop) => 9,
478
479        dot!(Block, Loop) | op!(Else, End) => 1,
480        dot!(Br, BrIf, If) => 765,
481        dot!(Select) => 1250,
482        dot!(Call) => 3800,
483        dot!(LocalGet, LocalTee) => 75,
484        dot!(LocalSet) => 210,
485        dot!(GlobalGet) => 225,
486        dot!(GlobalSet) => 575,
487        dot!(I32Load, I32Load8S, I32Load8U, I32Load16S, I32Load16U) => 670,
488        dot!(I64Load, I64Load8S, I64Load8U, I64Load16S, I64Load16U, I64Load32S, I64Load32U) => 680,
489        dot!(I32Store, I32Store8, I32Store16) => 825,
490        dot!(I64Store, I64Store8, I64Store16, I64Store32) => 950,
491        dot!(MemorySize) => 3000,
492        dot!(MemoryGrow) => 8050,
493
494        op!(I32Eqz, I32Eq, I32Ne, I32LtS, I32LtU, I32GtS, I32GtU, I32LeS, I32LeU, I32GeS, I32GeU) => 170,
495        op!(I64Eqz, I64Eq, I64Ne, I64LtS, I64LtU, I64GtS, I64GtU, I64LeS, I64LeU, I64GeS, I64GeU) => 225,
496
497        op!(I32Clz, I32Ctz) => 210,
498        op!(I32Add, I32Sub) => 70,
499        op!(I32Mul) => 160,
500        op!(I32DivS, I32DivU, I32RemS, I32RemU) => 1120,
501        op!(I32And, I32Or, I32Xor, I32Shl, I32ShrS, I32ShrU, I32Rotl, I32Rotr) => 70,
502
503        op!(I64Clz, I64Ctz) => 210,
504        op!(I64Add, I64Sub) => 100,
505        op!(I64Mul) => 160,
506        op!(I64DivS, I64DivU, I64RemS, I64RemU) => 1270,
507        op!(I64And, I64Or, I64Xor, I64Shl, I64ShrS, I64ShrU, I64Rotl, I64Rotr) => 100,
508
509        op!(I32Popcnt) => 2650,
510        op!(I64Popcnt) => 6000,
511
512        op!(I32WrapI64, I64ExtendI32S, I64ExtendI32U) => 100,
513        op!(I32Extend8S, I32Extend16S, I64Extend8S, I64Extend16S, I64Extend32S) => 100,
514        dot!(MemoryCopy) => 950,
515        dot!(MemoryFill) => 950,
516
517        BrTable { targets } => 2400 + 325 * targets.len() as u64,
518        CallIndirect { type_index, .. } => {
519            let params = sigs.get(type_index).copied().unwrap_or(0);
520            13610 + 650 * params as u64
521        },
522
523        _ => u64::MAX,
524    }
525}