arb_stylus/
middleware.rs

1use std::{collections::HashMap, sync::RwLock};
2
3use wasmer_compiler::{FunctionMiddleware, MiddlewareReaderState, ModuleMiddleware};
4use wasmer_types::{
5    ExportIndex, FunctionIndex, FunctionType, GlobalIndex, GlobalInit, ImportIndex,
6    LocalFunctionIndex, MiddlewareError, ModuleInfo, SignatureIndex, Type,
7};
8use wasmparser::{BlockType, Operator, ValType};
9
10use crate::meter::{STYLUS_ENTRY_POINT, STYLUS_INK_LEFT, STYLUS_INK_STATUS, STYLUS_STACK_LEFT};
11
12const SCRATCH_GLOBAL: &str = "stylus_scratch_global";
13
14fn mw_err(msg: impl Into<String>) -> MiddlewareError {
15    MiddlewareError::new("stylus", msg.into())
16}
17
// ── StartMover ──────────────────────────────────────────────────────
//
// Moves the WASM start function to an export named "stylus_start" so it is not
// run at module instantiation and, outside debug mode, drops every export that
// is not on the allowed whitelist.
// Must run before the metering middleware.

/// Export name under which the module's start function is relocated.
const STYLUS_START: &str = "stylus_start";

/// Module middleware that relocates the start function and filters exports.
#[derive(Debug)]
pub struct StartMover {
    /// When `true`, start functions are tolerated and the export whitelist is not applied.
    debug: bool,
}

impl StartMover {
    /// Creates the middleware; `debug` relaxes the start-function and export restrictions.
    pub fn new(debug: bool) -> Self {
        StartMover { debug }
    }
}
36
37impl ModuleMiddleware for StartMover {
38    fn transform_module_info(&self, info: &mut ModuleInfo) -> Result<(), MiddlewareError> {
39        let exports_before = info.exports.len();
40
41        let had_start = if let Some(start) = info.start_function.take() {
42            if info.exports.contains_key(STYLUS_START) {
43                return Err(mw_err(format!("function {STYLUS_START} already exists")));
44            }
45            info.exports
46                .insert(STYLUS_START.to_owned(), ExportIndex::Function(start));
47            info.function_names.insert(start, STYLUS_START.to_owned());
48            true
49        } else {
50            false
51        };
52
53        if had_start && !self.debug {
54            return Err(mw_err("start functions not allowed"));
55        }
56
57        if !self.debug {
58            // Drop all exports except the whitelist (entry point, start, memory).
59            info.exports.retain(|name, export| match name.as_str() {
60                STYLUS_ENTRY_POINT => matches!(export, ExportIndex::Function(_)),
61                STYLUS_START => matches!(export, ExportIndex::Function(_)),
62                "memory" => matches!(export, ExportIndex::Memory(_)),
63                _ => false,
64            });
65            info.function_names.clear();
66        }
67        tracing::debug!(target: "stylus",
68            had_start, exports_before, exports_after = info.exports.len(),
69            "StartMover applied");
70        Ok(())
71    }
72
73    fn generate_function_middleware<'a>(
74        &self,
75        _: LocalFunctionIndex,
76    ) -> Box<dyn FunctionMiddleware<'a> + 'a> {
77        Box::new(NoopFunctionMiddleware)
78    }
79}
80
81#[derive(Debug)]
82struct NoopFunctionMiddleware;
83
84impl<'a> FunctionMiddleware<'a> for NoopFunctionMiddleware {
85    fn feed(
86        &mut self,
87        op: Operator<'a>,
88        state: &mut MiddlewareReaderState<'a>,
89    ) -> Result<(), MiddlewareError> {
90        // SAFETY: Operator variants we encounter contain no borrowed data we keep.
91        let op_static = unsafe { std::mem::transmute::<Operator<'a>, Operator<'static>>(op) };
92        state.push_operator(op_static);
93        Ok(())
94    }
95}
96
97// ── InkMeter ────────────────────────────────────────────────────────
98
99#[derive(Debug)]
100pub struct InkMeter {
101    header_cost: u64,
102    globals: RwLock<Option<[GlobalIndex; 2]>>,
103    sigs: RwLock<HashMap<u32, usize>>,
104}
105
106impl InkMeter {
107    pub fn new(header_cost: u64) -> Self {
108        Self {
109            header_cost,
110            globals: RwLock::new(None),
111            sigs: RwLock::new(HashMap::new()),
112        }
113    }
114
115    fn globals(&self) -> [GlobalIndex; 2] {
116        self.globals
117            .read()
118            .expect("ink globals lock poisoned")
119            .expect("missing ink globals")
120    }
121}
122
123impl ModuleMiddleware for InkMeter {
124    fn transform_module_info(&self, info: &mut ModuleInfo) -> Result<(), MiddlewareError> {
125        let ink_ty = wasmer_types::GlobalType::new(Type::I64, wasmer_types::Mutability::Var);
126        let status_ty = wasmer_types::GlobalType::new(Type::I32, wasmer_types::Mutability::Var);
127
128        let ink_idx = info.globals.push(ink_ty);
129        let status_idx = info.globals.push(status_ty);
130        info.global_initializers.push(GlobalInit::I64Const(0));
131        info.global_initializers.push(GlobalInit::I32Const(0));
132
133        info.exports.insert(
134            STYLUS_INK_LEFT.to_string(),
135            wasmer_types::ExportIndex::Global(ink_idx),
136        );
137        info.exports.insert(
138            STYLUS_INK_STATUS.to_string(),
139            wasmer_types::ExportIndex::Global(status_idx),
140        );
141
142        let mut sig_map = self.sigs.write().expect("ink sigs lock poisoned");
143        for (sig_idx, sig) in info.signatures.iter() {
144            sig_map.insert(sig_idx.as_u32(), sig.params().len());
145        }
146
147        *self.globals.write().expect("ink globals lock poisoned") = Some([ink_idx, status_idx]);
148        Ok(())
149    }
150
151    fn generate_function_middleware<'a>(
152        &self,
153        _: LocalFunctionIndex,
154    ) -> Box<dyn FunctionMiddleware<'a> + 'a> {
155        let [ink, status] = self.globals();
156        let sigs = self.sigs.read().expect("ink sigs lock poisoned").clone();
157        Box::new(InkMeterFn {
158            ink_global: ink,
159            status_global: status,
160            block: vec![],
161            block_cost: 0,
162            header_cost: self.header_cost,
163            sigs,
164        })
165    }
166}
167
168#[derive(Debug)]
169struct InkMeterFn {
170    ink_global: GlobalIndex,
171    status_global: GlobalIndex,
172    block: Vec<Operator<'static>>,
173    block_cost: u64,
174    header_cost: u64,
175    sigs: HashMap<u32, usize>,
176}
177
178fn ends_basic_block(op: &Operator) -> bool {
179    use Operator::*;
180    matches!(
181        op,
182        End | Else
183            | Return
184            | Loop { .. }
185            | Br { .. }
186            | BrTable { .. }
187            | BrIf { .. }
188            | If { .. }
189            | Call { .. }
190            | CallIndirect { .. }
191    )
192}
193
194impl<'a> FunctionMiddleware<'a> for InkMeterFn {
195    fn feed(
196        &mut self,
197        op: Operator<'a>,
198        state: &mut MiddlewareReaderState<'a>,
199    ) -> Result<(), MiddlewareError> {
200        let end = ends_basic_block(&op);
201        let op_cost = opcode_ink_cost(&op, &self.sigs);
202        let mut cost = self.block_cost.saturating_add(op_cost);
203        self.block_cost = cost;
204
205        // SAFETY: Operator variants we support contain no borrowed data.
206        // We buffer them as 'static and transmute back when draining.
207        let op_static = unsafe { std::mem::transmute::<Operator<'a>, Operator<'static>>(op) };
208        self.block.push(op_static);
209
210        if end {
211            let ink = self.ink_global.as_u32();
212            let status = self.status_global.as_u32();
213            cost = cost.saturating_add(self.header_cost);
214
215            state.push_operator(Operator::GlobalGet { global_index: ink });
216            state.push_operator(Operator::I64Const { value: cost as i64 });
217            state.push_operator(Operator::I64LtU);
218            state.push_operator(Operator::If {
219                blockty: BlockType::Empty,
220            });
221            state.push_operator(Operator::I32Const { value: 1 });
222            state.push_operator(Operator::GlobalSet {
223                global_index: status,
224            });
225            state.push_operator(Operator::Unreachable);
226            state.push_operator(Operator::End);
227
228            state.push_operator(Operator::GlobalGet { global_index: ink });
229            state.push_operator(Operator::I64Const { value: cost as i64 });
230            state.push_operator(Operator::I64Sub);
231            state.push_operator(Operator::GlobalSet { global_index: ink });
232
233            for buffered in self.block.drain(..) {
234                let op_a =
235                    unsafe { std::mem::transmute::<Operator<'static>, Operator<'a>>(buffered) };
236                state.push_operator(op_a);
237            }
238            self.block_cost = 0;
239        }
240        Ok(())
241    }
242}
243
244// ── DynamicMeter ────────────────────────────────────────────────────
245
246#[derive(Debug)]
247pub struct DynamicMeter {
248    memory_fill_ink: u64,
249    memory_copy_ink: u64,
250    globals: RwLock<Option<[GlobalIndex; 3]>>,
251}
252
253impl DynamicMeter {
254    pub fn new(memory_fill_ink: u64, memory_copy_ink: u64) -> Self {
255        Self {
256            memory_fill_ink,
257            memory_copy_ink,
258            globals: RwLock::new(None),
259        }
260    }
261}
262
263impl ModuleMiddleware for DynamicMeter {
264    fn transform_module_info(&self, info: &mut ModuleInfo) -> Result<(), MiddlewareError> {
265        let ink_idx = info
266            .exports
267            .get(STYLUS_INK_LEFT)
268            .and_then(|e| match e {
269                wasmer_types::ExportIndex::Global(g) => Some(*g),
270                _ => None,
271            })
272            .ok_or_else(|| mw_err("ink global not found"))?;
273
274        let status_idx = info
275            .exports
276            .get(STYLUS_INK_STATUS)
277            .and_then(|e| match e {
278                wasmer_types::ExportIndex::Global(g) => Some(*g),
279                _ => None,
280            })
281            .ok_or_else(|| mw_err("ink status global not found"))?;
282
283        let scratch_ty = wasmer_types::GlobalType::new(Type::I32, wasmer_types::Mutability::Var);
284        let scratch_idx = info.globals.push(scratch_ty);
285        info.global_initializers.push(GlobalInit::I32Const(0));
286        info.exports.insert(
287            SCRATCH_GLOBAL.to_string(),
288            wasmer_types::ExportIndex::Global(scratch_idx),
289        );
290
291        *self.globals.write().expect("dynamic meter lock poisoned") =
292            Some([ink_idx, status_idx, scratch_idx]);
293        Ok(())
294    }
295
296    fn generate_function_middleware<'a>(
297        &self,
298        _: LocalFunctionIndex,
299    ) -> Box<dyn FunctionMiddleware<'a> + 'a> {
300        let globals = self
301            .globals
302            .read()
303            .expect("dynamic meter lock poisoned")
304            .expect("missing dynamic globals");
305        Box::new(DynamicMeterFn {
306            memory_fill_ink: self.memory_fill_ink,
307            memory_copy_ink: self.memory_copy_ink,
308            globals,
309        })
310    }
311}
312
313#[derive(Debug)]
314struct DynamicMeterFn {
315    memory_fill_ink: u64,
316    memory_copy_ink: u64,
317    globals: [GlobalIndex; 3],
318}
319
320impl<'a> FunctionMiddleware<'a> for DynamicMeterFn {
321    fn feed(
322        &mut self,
323        op: Operator<'a>,
324        state: &mut MiddlewareReaderState<'a>,
325    ) -> Result<(), MiddlewareError> {
326        use Operator::*;
327
328        let [ink, status, scratch] = self.globals.map(|x| x.as_u32());
329        let blockty = BlockType::Empty;
330
331        let coefficient = match &op {
332            MemoryFill { .. } => Some(self.memory_fill_ink as i64),
333            MemoryCopy { .. } => Some(self.memory_copy_ink as i64),
334            _ => None,
335        };
336
337        if let Some(coeff) = coefficient {
338            // Stack has [dest, val/src, size]. Save size to scratch, compute cost,
339            // subtract from ink with overflow check, restore size.
340            state.extend([
341                GlobalSet {
342                    global_index: scratch,
343                },
344                GlobalGet { global_index: ink },
345                GlobalGet { global_index: ink },
346                GlobalGet {
347                    global_index: scratch,
348                },
349                I64ExtendI32U,
350                I64Const { value: coeff },
351                I64Mul,
352                I64Sub,
353                GlobalSet { global_index: ink },
354                GlobalGet { global_index: ink },
355                I64LtU,
356                If { blockty },
357                I32Const { value: 1 },
358                GlobalSet {
359                    global_index: status,
360                },
361                Unreachable,
362                End,
363                GlobalGet {
364                    global_index: scratch,
365                },
366            ]);
367        }
368
369        state.push_operator(op);
370        Ok(())
371    }
372}
373
374// ── DepthChecker ────────────────────────────────────────────────────
375
376type FuncMap = HashMap<FunctionIndex, FunctionType>;
377type SigMap = HashMap<SignatureIndex, FunctionType>;
378
379#[derive(Debug)]
380pub struct DepthChecker {
381    frame_limit: u32,
382    frame_contention: u16,
383    global: RwLock<Option<GlobalIndex>>,
384    funcs: RwLock<Option<FuncMap>>,
385    sigs: RwLock<Option<SigMap>>,
386}
387
388impl DepthChecker {
389    pub fn new(frame_limit: u32, frame_contention: u16) -> Self {
390        Self {
391            frame_limit,
392            frame_contention,
393            global: RwLock::new(None),
394            funcs: RwLock::new(None),
395            sigs: RwLock::new(None),
396        }
397    }
398}
399
400impl ModuleMiddleware for DepthChecker {
401    fn transform_module_info(&self, info: &mut ModuleInfo) -> Result<(), MiddlewareError> {
402        let ty = wasmer_types::GlobalType::new(Type::I32, wasmer_types::Mutability::Var);
403        let idx = info.globals.push(ty);
404        info.global_initializers.push(GlobalInit::I32Const(0));
405        info.exports.insert(
406            STYLUS_STACK_LEFT.to_string(),
407            wasmer_types::ExportIndex::Global(idx),
408        );
409
410        let mut funcs = HashMap::new();
411        for (func_idx, sig_idx) in info.functions.iter() {
412            if let Some(sig) = info.signatures.get(*sig_idx) {
413                funcs.insert(func_idx, sig.clone());
414            }
415        }
416        let mut sigs = HashMap::new();
417        for (sig_idx, sig) in info.signatures.iter() {
418            sigs.insert(sig_idx, sig.clone());
419        }
420
421        *self.global.write().expect("depth checker lock poisoned") = Some(idx);
422        *self.funcs.write().expect("depth checker lock poisoned") = Some(funcs);
423        *self.sigs.write().expect("depth checker lock poisoned") = Some(sigs);
424        Ok(())
425    }
426
427    fn generate_function_middleware<'a>(
428        &self,
429        _: LocalFunctionIndex,
430    ) -> Box<dyn FunctionMiddleware<'a> + 'a> {
431        let g = self
432            .global
433            .read()
434            .expect("depth checker lock poisoned")
435            .expect("missing depth global");
436        let funcs = self
437            .funcs
438            .read()
439            .expect("depth checker lock poisoned")
440            .clone()
441            .expect("missing funcs");
442        let sigs = self
443            .sigs
444            .read()
445            .expect("depth checker lock poisoned")
446            .clone()
447            .expect("missing sigs");
448        Box::new(DepthCheckerFn {
449            global: g,
450            funcs,
451            sigs,
452            locals: None,
453            frame_limit: self.frame_limit,
454            frame_contention: self.frame_contention,
455            scopes: 1,
456            code: vec![],
457            done: false,
458        })
459    }
460}
461
462#[derive(Debug)]
463struct DepthCheckerFn {
464    global: GlobalIndex,
465    funcs: FuncMap,
466    sigs: SigMap,
467    locals: Option<usize>,
468    frame_limit: u32,
469    frame_contention: u16,
470    scopes: isize,
471    code: Vec<Operator<'static>>,
472    done: bool,
473}
474
475impl DepthCheckerFn {
476    #[rustfmt::skip]
477    fn worst_case_depth(&self) -> Result<u32, MiddlewareError> {
478        use Operator::*;
479
480        let mut worst: u32 = 0;
481        let mut stack: u32 = 0;
482
483        macro_rules! push {
484            ($count:expr) => {{ stack += $count; worst = worst.max(stack); }};
485            () => { push!(1) };
486        }
487        macro_rules! pop {
488            ($count:expr) => {{ stack = stack.saturating_sub($count); }};
489            () => { pop!(1) };
490        }
491        macro_rules! ins_and_outs {
492            ($ty:expr) => {{
493                let ins = $ty.params().len() as u32;
494                let outs = $ty.results().len() as u32;
495                push!(outs);
496                pop!(ins);
497            }};
498        }
499        macro_rules! op {
500            ($first:ident $(,$opcode:ident)* $(,)?) => {
501                $first $(| $opcode)*
502            };
503        }
504        macro_rules! dot {
505            ($first:ident $(,$opcode:ident)* $(,)?) => {
506                $first { .. } $(| $opcode { .. })*
507            };
508        }
509        macro_rules! block_type {
510            ($ty:expr) => {{
511                match $ty {
512                    BlockType::Empty => {}
513                    BlockType::Type(_) => push!(1),
514                    BlockType::FuncType(id) => {
515                        let index = SignatureIndex::from_u32(*id);
516                        let Some(ty) = self.sigs.get(&index) else {
517                            return Err(mw_err(format!("missing type for func {id}")));
518                        };
519                        ins_and_outs!(ty);
520                    }
521                }
522            }};
523        }
524
525        let mut scopes = vec![stack];
526
527        for op in &self.code {
528            match op {
529                Block { blockty } => {
530                    block_type!(blockty);
531                    scopes.push(stack);
532                }
533                Loop { blockty } => {
534                    block_type!(blockty);
535                    scopes.push(stack);
536                }
537                If { blockty } => {
538                    pop!();
539                    block_type!(blockty);
540                    scopes.push(stack);
541                }
542                Else => {
543                    stack = match scopes.last() {
544                        Some(scope) => *scope,
545                        None => return Err(mw_err("malformed if-else scope")),
546                    };
547                }
548                End => {
549                    stack = match scopes.pop() {
550                        Some(stack) => stack,
551                        None => return Err(mw_err("malformed scoping at end of block")),
552                    };
553                }
554
555                Call { function_index } => {
556                    let index = FunctionIndex::from_u32(*function_index);
557                    let Some(ty) = self.funcs.get(&index) else {
558                        return Err(mw_err(format!("missing type for func {function_index}")));
559                    };
560                    ins_and_outs!(ty)
561                }
562                CallIndirect { type_index, .. } => {
563                    let index = SignatureIndex::from_u32(*type_index);
564                    let Some(ty) = self.sigs.get(&index) else {
565                        return Err(mw_err(format!("missing type for signature {type_index}")));
566                    };
567                    ins_and_outs!(ty);
568                    pop!() // table index
569                }
570
571                MemoryFill { .. } | MemoryCopy { .. } => pop!(3), // 3 args, 0 returns
572
573                op!(
574                    Nop, Unreachable,
575                    I32Eqz, I64Eqz, I32Clz, I32Ctz, I32Popcnt, I64Clz, I64Ctz, I64Popcnt,
576                )
577                | dot!(
578                    Br, Return,
579                    LocalTee, MemoryGrow,
580                    I32Load, I64Load, F32Load, F64Load,
581                    I32Load8S, I32Load8U, I32Load16S, I32Load16U, I64Load8S, I64Load8U,
582                    I64Load16S, I64Load16U, I64Load32S, I64Load32U,
583                    I32WrapI64, I64ExtendI32S, I64ExtendI32U,
584                    I32Extend8S, I32Extend16S, I64Extend8S, I64Extend16S, I64Extend32S,
585                    F32Abs, F32Neg, F32Ceil, F32Floor, F32Trunc, F32Nearest, F32Sqrt,
586                    F64Abs, F64Neg, F64Ceil, F64Floor, F64Trunc, F64Nearest, F64Sqrt,
587                    I32TruncF32S, I32TruncF32U, I32TruncF64S, I32TruncF64U,
588                    I64TruncF32S, I64TruncF32U, I64TruncF64S, I64TruncF64U,
589                    F32ConvertI32S, F32ConvertI32U, F32ConvertI64S, F32ConvertI64U, F32DemoteF64,
590                    F64ConvertI32S, F64ConvertI32U, F64ConvertI64S, F64ConvertI64U, F64PromoteF32,
591                    I32ReinterpretF32, I64ReinterpretF64, F32ReinterpretI32, F64ReinterpretI64,
592                    I32TruncSatF32S, I32TruncSatF32U, I32TruncSatF64S, I32TruncSatF64U,
593                    I64TruncSatF32S, I64TruncSatF32U, I64TruncSatF64S, I64TruncSatF64U,
594                ) => {}
595
596                dot!(
597                    LocalGet, GlobalGet, MemorySize,
598                    I32Const, I64Const, F32Const, F64Const,
599                ) => push!(),
600
601                op!(
602                    Drop,
603                    I32Eq, I32Ne, I32LtS, I32LtU, I32GtS, I32GtU, I32LeS, I32LeU, I32GeS, I32GeU,
604                    I64Eq, I64Ne, I64LtS, I64LtU, I64GtS, I64GtU, I64LeS, I64LeU, I64GeS, I64GeU,
605                    F32Eq, F32Ne, F32Lt, F32Gt, F32Le, F32Ge,
606                    F64Eq, F64Ne, F64Lt, F64Gt, F64Le, F64Ge,
607                    I32Add, I32Sub, I32Mul, I32DivS, I32DivU, I32RemS, I32RemU,
608                    I64Add, I64Sub, I64Mul, I64DivS, I64DivU, I64RemS, I64RemU,
609                    I32And, I32Or, I32Xor, I32Shl, I32ShrS, I32ShrU, I32Rotl, I32Rotr,
610                    I64And, I64Or, I64Xor, I64Shl, I64ShrS, I64ShrU, I64Rotl, I64Rotr,
611                    F32Add, F32Sub, F32Mul, F32Div, F32Min, F32Max, F32Copysign,
612                    F64Add, F64Sub, F64Mul, F64Div, F64Min, F64Max, F64Copysign,
613                )
614                | dot!(BrIf, BrTable, LocalSet, GlobalSet) => pop!(),
615
616                dot!(
617                    Select,
618                    I32Store, I64Store, F32Store, F64Store,
619                    I32Store8, I32Store16, I64Store8, I64Store16, I64Store32,
620                ) => pop!(2),
621
622                unsupported @ dot!(Try, Catch, Throw, Rethrow, ThrowRef, TryTable) => {
623                    return Err(mw_err(format!("exception-handling not supported {unsupported:?}")));
624                }
625                unsupported @ dot!(ReturnCall, ReturnCallIndirect) => {
626                    return Err(mw_err(format!("tail-call not supported {unsupported:?}")));
627                }
628                unsupported @ dot!(CallRef, ReturnCallRef) => {
629                    return Err(mw_err(format!("typed function references not supported {unsupported:?}")));
630                }
631                unsupported @ (dot!(Delegate) | op!(CatchAll)) => {
632                    return Err(mw_err(format!("exception-handling not supported {unsupported:?}")));
633                }
634                unsupported @ (op!(RefIsNull) | dot!(TypedSelect, RefNull, RefFunc, RefEq)) => {
635                    return Err(mw_err(format!("reference-types not supported {unsupported:?}")));
636                }
637                unsupported @ dot!(RefAsNonNull, BrOnNull, BrOnNonNull) => {
638                    return Err(mw_err(format!("typed function references not supported {unsupported:?}")));
639                }
640                unsupported @ dot!(
641                    MemoryInit, DataDrop, TableInit, ElemDrop,
642                    TableCopy, TableFill, TableGet, TableSet, TableGrow, TableSize
643                ) => {
644                    return Err(mw_err(format!("bulk-memory not fully supported {unsupported:?}")));
645                }
646                unsupported @ dot!(MemoryDiscard) => {
647                    return Err(mw_err(format!("memory discard not supported {unsupported:?}")));
648                }
649                unsupported @ dot!(
650                    StructNew, StructNewDefault, StructGet, StructGetS, StructGetU, StructSet,
651                    ArrayNew, ArrayNewDefault, ArrayNewFixed, ArrayNewData, ArrayNewElem,
652                    ArrayGet, ArrayGetS, ArrayGetU, ArraySet, ArrayLen, ArrayFill, ArrayCopy,
653                    ArrayInitData, ArrayInitElem,
654                    RefTestNonNull, RefTestNullable, RefCastNonNull, RefCastNullable,
655                    BrOnCast, BrOnCastFail, AnyConvertExtern, ExternConvertAny,
656                    RefI31, I31GetS, I31GetU
657                ) => {
658                    return Err(mw_err(format!("GC extension not supported {unsupported:?}")));
659                }
660                unsupported @ dot!(
661                    MemoryAtomicNotify, MemoryAtomicWait32, MemoryAtomicWait64, AtomicFence,
662                    I32AtomicLoad, I64AtomicLoad, I32AtomicLoad8U, I32AtomicLoad16U,
663                    I64AtomicLoad8U, I64AtomicLoad16U, I64AtomicLoad32U,
664                    I32AtomicStore, I64AtomicStore, I32AtomicStore8, I32AtomicStore16,
665                    I64AtomicStore8, I64AtomicStore16, I64AtomicStore32,
666                    I32AtomicRmwAdd, I64AtomicRmwAdd, I32AtomicRmw8AddU, I32AtomicRmw16AddU,
667                    I64AtomicRmw8AddU, I64AtomicRmw16AddU, I64AtomicRmw32AddU,
668                    I32AtomicRmwSub, I64AtomicRmwSub, I32AtomicRmw8SubU, I32AtomicRmw16SubU,
669                    I64AtomicRmw8SubU, I64AtomicRmw16SubU, I64AtomicRmw32SubU,
670                    I32AtomicRmwAnd, I64AtomicRmwAnd, I32AtomicRmw8AndU, I32AtomicRmw16AndU,
671                    I64AtomicRmw8AndU, I64AtomicRmw16AndU, I64AtomicRmw32AndU,
672                    I32AtomicRmwOr, I64AtomicRmwOr, I32AtomicRmw8OrU, I32AtomicRmw16OrU,
673                    I64AtomicRmw8OrU, I64AtomicRmw16OrU, I64AtomicRmw32OrU,
674                    I32AtomicRmwXor, I64AtomicRmwXor, I32AtomicRmw8XorU, I32AtomicRmw16XorU,
675                    I64AtomicRmw8XorU, I64AtomicRmw16XorU, I64AtomicRmw32XorU,
676                    I32AtomicRmwXchg, I64AtomicRmwXchg, I32AtomicRmw8XchgU, I32AtomicRmw16XchgU,
677                    I64AtomicRmw8XchgU, I64AtomicRmw16XchgU, I64AtomicRmw32XchgU,
678                    I32AtomicRmwCmpxchg, I64AtomicRmwCmpxchg, I32AtomicRmw8CmpxchgU,
679                    I32AtomicRmw16CmpxchgU, I64AtomicRmw8CmpxchgU, I64AtomicRmw16CmpxchgU,
680                    I64AtomicRmw32CmpxchgU
681                ) => {
682                    return Err(mw_err(format!("threads extension not supported {unsupported:?}")));
683                }
684                unsupported @ dot!(
685                    V128Load, V128Load8x8S, V128Load8x8U, V128Load16x4S, V128Load16x4U,
686                    V128Load32x2S, V128Load8Splat, V128Load16Splat, V128Load32Splat,
687                    V128Load64Splat, V128Load32Zero, V128Load64Zero, V128Load32x2U,
688                    V128Store, V128Load8Lane, V128Load16Lane, V128Load32Lane, V128Load64Lane,
689                    V128Store8Lane, V128Store16Lane, V128Store32Lane, V128Store64Lane, V128Const,
690                    I8x16Shuffle, I8x16ExtractLaneS, I8x16ExtractLaneU, I8x16ReplaceLane,
691                    I16x8ExtractLaneS, I16x8ExtractLaneU, I16x8ReplaceLane,
692                    I32x4ExtractLane, I32x4ReplaceLane, I64x2ExtractLane, I64x2ReplaceLane,
693                    F32x4ExtractLane, F32x4ReplaceLane, F64x2ExtractLane, F64x2ReplaceLane,
694                    I8x16Swizzle, I8x16Splat, I16x8Splat, I32x4Splat, I64x2Splat,
695                    F32x4Splat, F64x2Splat,
696                    I8x16Eq, I8x16Ne, I8x16LtS, I8x16LtU, I8x16GtS, I8x16GtU,
697                    I8x16LeS, I8x16LeU, I8x16GeS, I8x16GeU,
698                    I16x8Eq, I16x8Ne, I16x8LtS, I16x8LtU, I16x8GtS, I16x8GtU,
699                    I16x8LeS, I16x8LeU, I16x8GeS, I16x8GeU,
700                    I32x4Eq, I32x4Ne, I32x4LtS, I32x4LtU, I32x4GtS, I32x4GtU,
701                    I32x4LeS, I32x4LeU, I32x4GeS, I32x4GeU,
702                    I64x2Eq, I64x2Ne, I64x2LtS, I64x2GtS, I64x2LeS, I64x2GeS,
703                    F32x4Eq, F32x4Ne, F32x4Lt, F32x4Gt, F32x4Le, F32x4Ge,
704                    F64x2Eq, F64x2Ne, F64x2Lt, F64x2Gt, F64x2Le, F64x2Ge,
705                    V128Not, V128And, V128AndNot, V128Or, V128Xor, V128Bitselect, V128AnyTrue,
706                    I8x16Abs, I8x16Neg, I8x16Popcnt, I8x16AllTrue, I8x16Bitmask,
707                    I8x16NarrowI16x8S, I8x16NarrowI16x8U,
708                    I8x16Shl, I8x16ShrS, I8x16ShrU, I8x16Add, I8x16AddSatS, I8x16AddSatU,
709                    I8x16Sub, I8x16SubSatS, I8x16SubSatU, I8x16MinS, I8x16MinU,
710                    I8x16MaxS, I8x16MaxU, I8x16AvgrU,
711                    I16x8ExtAddPairwiseI8x16S, I16x8ExtAddPairwiseI8x16U, I16x8Abs, I16x8Neg,
712                    I16x8Q15MulrSatS, I16x8AllTrue, I16x8Bitmask,
713                    I16x8NarrowI32x4S, I16x8NarrowI32x4U,
714                    I16x8ExtendLowI8x16S, I16x8ExtendHighI8x16S,
715                    I16x8ExtendLowI8x16U, I16x8ExtendHighI8x16U,
716                    I16x8Shl, I16x8ShrS, I16x8ShrU, I16x8Add, I16x8AddSatS, I16x8AddSatU,
717                    I16x8Sub, I16x8SubSatS, I16x8SubSatU, I16x8Mul,
718                    I16x8MinS, I16x8MinU, I16x8MaxS, I16x8MaxU, I16x8AvgrU,
719                    I16x8ExtMulLowI8x16S, I16x8ExtMulHighI8x16S,
720                    I16x8ExtMulLowI8x16U, I16x8ExtMulHighI8x16U,
721                    I32x4ExtAddPairwiseI16x8U, I32x4Abs, I32x4Neg, I32x4AllTrue, I32x4Bitmask,
722                    I32x4ExtAddPairwiseI16x8S,
723                    I32x4ExtendLowI16x8S, I32x4ExtendHighI16x8S,
724                    I32x4ExtendLowI16x8U, I32x4ExtendHighI16x8U,
725                    I32x4Shl, I32x4ShrS, I32x4ShrU, I32x4Add, I32x4Sub, I32x4Mul,
726                    I32x4MinS, I32x4MinU, I32x4MaxS, I32x4MaxU, I32x4DotI16x8S,
727                    I32x4ExtMulLowI16x8S, I32x4ExtMulHighI16x8S,
728                    I32x4ExtMulLowI16x8U, I32x4ExtMulHighI16x8U,
729                    I64x2Abs, I64x2Neg, I64x2AllTrue, I64x2Bitmask,
730                    I64x2ExtendLowI32x4S, I64x2ExtendHighI32x4S,
731                    I64x2ExtendLowI32x4U, I64x2ExtendHighI32x4U,
732                    I64x2Shl, I64x2ShrS, I64x2ShrU, I64x2Add, I64x2Sub, I64x2Mul,
733                    I64x2ExtMulLowI32x4S, I64x2ExtMulHighI32x4S,
734                    I64x2ExtMulLowI32x4U, I64x2ExtMulHighI32x4U,
735                    F32x4Ceil, F32x4Floor, F32x4Trunc, F32x4Nearest,
736                    F32x4Abs, F32x4Neg, F32x4Sqrt, F32x4Add, F32x4Sub, F32x4Mul, F32x4Div,
737                    F32x4Min, F32x4Max, F32x4PMin, F32x4PMax,
738                    F64x2Ceil, F64x2Floor, F64x2Trunc, F64x2Nearest,
739                    F64x2Abs, F64x2Neg, F64x2Sqrt, F64x2Add, F64x2Sub, F64x2Mul, F64x2Div,
740                    F64x2Min, F64x2Max, F64x2PMin, F64x2PMax,
741                    I32x4TruncSatF32x4S, I32x4TruncSatF32x4U,
742                    F32x4ConvertI32x4S, F32x4ConvertI32x4U,
743                    I32x4TruncSatF64x2SZero, I32x4TruncSatF64x2UZero,
744                    F64x2ConvertLowI32x4S, F64x2ConvertLowI32x4U,
745                    F32x4DemoteF64x2Zero, F64x2PromoteLowF32x4,
746                    I8x16RelaxedSwizzle,
747                    I32x4RelaxedTruncF32x4S, I32x4RelaxedTruncF32x4U,
748                    I32x4RelaxedTruncF64x2SZero, I32x4RelaxedTruncF64x2UZero,
749                    F32x4RelaxedMadd, F32x4RelaxedNmadd, F64x2RelaxedMadd, F64x2RelaxedNmadd,
750                    I8x16RelaxedLaneselect, I16x8RelaxedLaneselect,
751                    I32x4RelaxedLaneselect, I64x2RelaxedLaneselect,
752                    F32x4RelaxedMin, F32x4RelaxedMax, F64x2RelaxedMin, F64x2RelaxedMax,
753                    I16x8RelaxedQ15mulrS, I16x8RelaxedDotI8x16I7x16S,
754                    I32x4RelaxedDotI8x16I7x16AddS
755                ) => {
756                    return Err(mw_err(format!("SIMD extension not supported {unsupported:?}")));
757                }
758            };
759        }
760
761        if self.locals.is_none() {
762            return Err(mw_err("missing locals info"));
763        }
764
765        let contention = worst;
766        if contention > self.frame_contention.into() {
767            return Err(mw_err(format!(
768                "too many values on the stack at once: {contention} > {}",
769                self.frame_contention
770            )));
771        }
772
773        let locals = self.locals.unwrap_or_default();
774        Ok(worst + locals as u32 + 4)
775    }
776}
777
/// Per-function stack-depth instrumentation.
///
/// Operators are buffered in `self.code` until the `End` that closes the function body is seen.
/// At that point the worst-case frame size is computed and the whole function is re-emitted with:
/// - a prologue that traps (after zeroing the depth global) when the remaining budget is not
///   strictly larger than the frame size, and otherwise deducts the frame size from the budget;
/// - a reclaim sequence (budget += frame size) before every `Return`, including an extra
///   `Return` inserted just before the final `End`.
impl<'a> FunctionMiddleware<'a> for DepthCheckerFn {
    /// Records `locals.len()`; `worst_case_depth` errors if this was never called.
    fn locals_info(&mut self, locals: &[ValType]) {
        self.locals = Some(locals.len());
    }

    /// Buffers `op`; on the function's final `End`, emits the instrumented body into `state`.
    ///
    /// # Errors
    /// - `feed` is called again after the function was already finalized;
    /// - an `End` closes more scopes than were opened;
    /// - `worst_case_depth` fails (e.g. unsupported operators, missing locals info, too many
    ///   values on the stack at once);
    /// - the computed frame size exceeds `self.frame_limit`.
    fn feed(
        &mut self,
        op: Operator<'a>,
        state: &mut MiddlewareReaderState<'a>,
    ) -> Result<(), MiddlewareError> {
        if self.done {
            return Err(mw_err("depth checker: feed called after finalization"));
        }

        // Track block nesting so the `End` that closes the function body can be recognized.
        // NOTE(review): this presumes `scopes` starts at 1 for the implicit function block; the
        // constructor is outside this view — confirm.
        match op {
            Operator::Block { .. } | Operator::Loop { .. } | Operator::If { .. } => {
                self.scopes += 1;
            }
            Operator::End => {
                self.scopes -= 1;
            }
            _ => {}
        }
        if self.scopes < 0 {
            return Err(mw_err("malformed scoping detected"));
        }

        // True only for the `End` that brings nesting back to zero, i.e. the function's last op.
        let last = self.scopes == 0 && matches!(op, Operator::End);

        // SAFETY: the lifetime is erased only so the operator can be buffered in `self.code`,
        // which holds `Operator<'static>`. NOTE(review): not every variant is free of borrowed
        // data — e.g. `BrTable` carries a reader over the wasm bytes. Soundness instead relies on
        // every buffered operator being transmuted back to `'a` and emitted (below) while that
        // borrow is still live, and on `self.code` never being handed out as `'static` — confirm.
        let op_static = unsafe { std::mem::transmute::<Operator<'a>, Operator<'static>>(op) };
        self.code.push(op_static);

        if !last {
            return Ok(());
        }

        let size = self.worst_case_depth()?;
        let g = self.global.as_u32();

        if size > self.frame_limit {
            return Err(mw_err(format!(
                "frame too large: {size} > {}-word limit",
                self.frame_limit
            )));
        }

        // Prologue: check and deduct depth budget
        // (`size as i32` keeps the bit pattern; the comparison below is unsigned.)
        state.extend([
            // if budget <= size { budget = 0; trap }
            Operator::GlobalGet { global_index: g },
            Operator::I32Const { value: size as i32 },
            Operator::I32LeU,
            Operator::If {
                blockty: BlockType::Empty,
            },
            Operator::I32Const { value: 0 },
            Operator::GlobalSet { global_index: g },
            Operator::Unreachable,
            Operator::End,
            // budget -= size
            Operator::GlobalGet { global_index: g },
            Operator::I32Const { value: size as i32 },
            Operator::I32Sub,
            Operator::GlobalSet { global_index: g },
        ]);

        // Insert an extraneous Return before the final End to match Arbitrator.
        // `unwrap` cannot fail: the final `End` was pushed onto `self.code` just above.
        let mut code = std::mem::take(&mut self.code);
        let final_end = code.pop().unwrap();
        code.push(Operator::Return);
        code.push(final_end);

        // Re-emit the buffered body, giving the frame back (budget += size) before every exit
        // via `Return`.
        for op_s in code {
            let is_return = matches!(op_s, Operator::Return);
            if is_return {
                state.extend([
                    Operator::GlobalGet { global_index: g },
                    Operator::I32Const { value: size as i32 },
                    Operator::I32Add,
                    Operator::GlobalSet { global_index: g },
                ]);
            }
            // SAFETY: reverses the lifetime erasure performed when `op_s` was buffered. Every
            // buffered operator arrived through `feed` with lifetime `'a`, assuming this instance
            // is only ever driven as a single `FunctionMiddleware<'a>` (as when boxed as
            // `dyn FunctionMiddleware<'a> + 'a`) — confirm.
            let op_a = unsafe { std::mem::transmute::<Operator<'static>, Operator<'a>>(op_s) };
            state.push_operator(op_a);
        }

        // Mark finalized so any further `feed` call is rejected.
        self.done = true;
        Ok(())
    }
}
867
868// ── HeapBound ───────────────────────────────────────────────────────
869
/// Module middleware that calls the `pay_for_memory_grow` import (when the module has one)
/// before every `memory.grow`.
///
/// `transform_module_info` resolves the indices it needs once per module; a per-function
/// `HeapBoundFn` is then generated from them for each local function.
#[derive(Debug)]
pub struct HeapBound {
    /// Scratch global index and optional `pay_for_memory_grow` import index.
    /// `None` until `transform_module_info` has run successfully.
    globals: RwLock<Option<(GlobalIndex, Option<FunctionIndex>)>>,
}
874
875impl HeapBound {
876    pub fn new() -> Self {
877        Self {
878            globals: RwLock::new(None),
879        }
880    }
881}
882
883impl ModuleMiddleware for HeapBound {
884    fn transform_module_info(&self, info: &mut ModuleInfo) -> Result<(), MiddlewareError> {
885        let scratch_idx = info
886            .exports
887            .get(SCRATCH_GLOBAL)
888            .and_then(|e| match e {
889                wasmer_types::ExportIndex::Global(g) => Some(*g),
890                _ => None,
891            })
892            .ok_or_else(|| mw_err("scratch global not found"))?;
893
894        let pay_func = info.imports.iter().find_map(|(key, idx)| {
895            if key.field == "pay_for_memory_grow" {
896                if let ImportIndex::Function(f) = idx {
897                    return Some(*f);
898                }
899            }
900            None
901        });
902
903        *self.globals.write().expect("heap bound lock poisoned") = Some((scratch_idx, pay_func));
904        Ok(())
905    }
906
907    fn generate_function_middleware<'a>(
908        &self,
909        _: LocalFunctionIndex,
910    ) -> Box<dyn FunctionMiddleware<'a> + 'a> {
911        let (scratch, pay_func) = self
912            .globals
913            .read()
914            .expect("heap bound lock poisoned")
915            .expect("missing heap globals");
916        Box::new(HeapBoundFn { scratch, pay_func })
917    }
918}
919
/// Per-function half of `HeapBound`: rewrites each `memory.grow` operator.
#[derive(Debug)]
struct HeapBoundFn {
    /// Global used to stash the `memory.grow` operand so it can be duplicated.
    scratch: GlobalIndex,
    /// The `pay_for_memory_grow` import, if any; when `None`, operators pass through untouched.
    pay_func: Option<FunctionIndex>,
}
925
926impl<'a> FunctionMiddleware<'a> for HeapBoundFn {
927    fn feed(
928        &mut self,
929        op: Operator<'a>,
930        state: &mut MiddlewareReaderState<'a>,
931    ) -> Result<(), MiddlewareError> {
932        if let (Operator::MemoryGrow { .. }, Some(pay)) = (&op, self.pay_func) {
933            let g = self.scratch.as_u32();
934            let f = pay.as_u32();
935            state.extend([
936                Operator::GlobalSet { global_index: g },
937                Operator::GlobalGet { global_index: g },
938                Operator::GlobalGet { global_index: g },
939                Operator::Call { function_index: f },
940            ]);
941        }
942        state.push_operator(op);
943        Ok(())
944    }
945}
946
947// ── Opcode ink costs ────────────────────────────────────────────────
948
/// Per-opcode ink cost used by the ink meter middleware.
///
/// Returns the ink charged for one execution of `op`. Most operators have a flat cost;
/// `BrTable` scales with `targets.len()`, and `CallIndirect` scales with the value stored in
/// `sigs` for its `type_index` (presumably the signature's parameter count — confirm with the
/// caller that builds `sigs`).
///
/// Every operator not listed here returns `u64::MAX`. NOTE(review): presumably the meter treats
/// this as "unsupported / unaffordable" — confirm at the call site.
///
/// NOTE(review): a `CallIndirect` whose `type_index` is absent from `sigs` is priced as if it had
/// zero parameters instead of failing; confirm `sigs` always covers every type index.
///
/// NOTE(review): the constants look like they must stay in lockstep with a reference pricing
/// table (the depth checker mentions matching Arbitrator); any change here alters gas accounting.
#[rustfmt::skip]
pub fn opcode_ink_cost(op: &Operator, sigs: &HashMap<u32, usize>) -> u64 {
    use Operator::*;

    // `op!(A, B)` expands to the or-pattern `A | B` (variants without fields).
    macro_rules! op {
        ($first:ident $(,$opcode:ident)*) => { $first $(| $opcode)* };
    }
    // `dot!(A, B)` expands to `A { .. } | B { .. }` (variants matched regardless of fields).
    macro_rules! dot {
        ($first:ident $(,$opcode:ident)*) => { $first { .. } $(| $opcode { .. })* };
    }

    match op {
        op!(Unreachable, Return) => 1,
        op!(Nop) | dot!(I32Const, I64Const) => 1,
        op!(Drop) => 9,

        dot!(Block, Loop) | op!(Else, End) => 1,
        dot!(Br, BrIf, If) => 765,
        dot!(Select) => 1250,
        dot!(Call) => 3800,
        dot!(LocalGet, LocalTee) => 75,
        dot!(LocalSet) => 210,
        dot!(GlobalGet) => 225,
        dot!(GlobalSet) => 575,
        dot!(I32Load, I32Load8S, I32Load8U, I32Load16S, I32Load16U) => 670,
        dot!(I64Load, I64Load8S, I64Load8U, I64Load16S, I64Load16U, I64Load32S, I64Load32U) => 680,
        dot!(I32Store, I32Store8, I32Store16) => 825,
        dot!(I64Store, I64Store8, I64Store16, I64Store32) => 950,
        dot!(MemorySize) => 3000,
        dot!(MemoryGrow) => 8050,

        op!(I32Eqz, I32Eq, I32Ne, I32LtS, I32LtU, I32GtS, I32GtU, I32LeS, I32LeU, I32GeS, I32GeU) => 170,
        op!(I64Eqz, I64Eq, I64Ne, I64LtS, I64LtU, I64GtS, I64GtU, I64LeS, I64LeU, I64GeS, I64GeU) => 225,

        op!(I32Clz, I32Ctz) => 210,
        op!(I32Add, I32Sub) => 70,
        op!(I32Mul) => 160,
        op!(I32DivS, I32DivU, I32RemS, I32RemU) => 1120,
        op!(I32And, I32Or, I32Xor, I32Shl, I32ShrS, I32ShrU, I32Rotl, I32Rotr) => 70,

        op!(I64Clz, I64Ctz) => 210,
        op!(I64Add, I64Sub) => 100,
        op!(I64Mul) => 160,
        op!(I64DivS, I64DivU, I64RemS, I64RemU) => 1270,
        op!(I64And, I64Or, I64Xor, I64Shl, I64ShrS, I64ShrU, I64Rotl, I64Rotr) => 100,

        op!(I32Popcnt) => 2650,
        op!(I64Popcnt) => 6000,

        op!(I32WrapI64, I64ExtendI32S, I64ExtendI32U) => 100,
        op!(I32Extend8S, I32Extend16S, I64Extend8S, I64Extend16S, I64Extend32S) => 100,
        dot!(MemoryCopy) => 950,
        dot!(MemoryFill) => 950,

        // Base cost plus a per-target surcharge.
        BrTable { targets } => 2400 + 325 * targets.len() as u64,
        // Base cost plus a surcharge per entry recorded in `sigs` for this type index;
        // a missing entry falls back to 0 (see the NOTE above).
        CallIndirect { type_index, .. } => {
            let params = sigs.get(type_index).copied().unwrap_or(0);
            13610 + 650 * params as u64
        },

        // Anything unlisted is priced at the maximum (see the NOTE above).
        _ => u64::MAX,
    }
}