1use std::sync::{Arc, OnceLock};
7
8use alloy_consensus::BlockHeader;
9use alloy_primitives::B256;
10use alloy_rpc_types_eth::BlockNumberOrTag;
11use base64::{
12 alphabet,
13 engine::{DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig},
14 Engine as _,
15};
16use jsonrpsee::core::RpcResult;
17use parking_lot::RwLock;
18use reth_provider::{BlockNumReader, BlockReaderIdExt, HeaderProvider};
19use tracing::{debug, info};
20
21use crate::{
22 block_producer::{BlockProducer, BlockProductionInput},
23 nitro_execution::{
24 NitroExecutionApiServer, RpcConsensusSyncData, RpcFinalityData, RpcMaintenanceStatus,
25 RpcMessageResult, RpcMessageWithMetadata, RpcMessageWithMetadataAndBlockInfo,
26 },
27};
28
/// Consensus sync status as last reported to this node through `setConsensusSyncData`.
///
/// The `Default` value is "not synced, zero messages".
#[derive(Debug, Default)]
pub struct NitroExecutionState {
    /// `synced` flag from the most recent consensus sync report.
    pub synced: bool,
    /// `max_message_count` from the most recent consensus sync report.
    pub max_message_count: u64,
}
37
38pub struct NitroExecutionHandler<Provider, BP> {
43 provider: Provider,
44 block_producer: Arc<BP>,
45 state: Arc<RwLock<NitroExecutionState>>,
46 genesis_block_num: u64,
48}
49
50impl<Provider, BP> NitroExecutionHandler<Provider, BP> {
51 pub fn new(provider: Provider, block_producer: Arc<BP>, genesis_block_num: u64) -> Self {
53 Self {
54 provider,
55 block_producer,
56 state: Arc::new(RwLock::new(NitroExecutionState::default())),
57 genesis_block_num,
58 }
59 }
60
61 fn message_index_to_block_number(&self, msg_idx: u64) -> u64 {
63 self.genesis_block_num + msg_idx
64 }
65
66 fn block_number_to_message_index(&self, block_num: u64) -> Option<u64> {
68 if block_num < self.genesis_block_num {
69 return None;
70 }
71 Some(block_num - self.genesis_block_num)
72 }
73}
74
75impl<Provider, BP> NitroExecutionHandler<Provider, BP>
76where
77 Provider: BlockReaderIdExt + HeaderProvider,
78{
79 fn get_header(
81 &self,
82 block_num: u64,
83 ) -> Result<
84 Option<reth_primitives_traits::SealedHeader<<Provider as HeaderProvider>::Header>>,
85 String,
86 > {
87 self.provider
88 .sealed_header_by_number_or_tag(BlockNumberOrTag::Number(block_num))
89 .map_err(|e| e.to_string())
90 }
91
92 fn send_root_from_header(header: &impl BlockHeader) -> B256 {
94 let extra = header.extra_data();
95 if extra.len() >= 32 {
96 B256::from_slice(&extra[..32])
97 } else {
98 B256::ZERO
99 }
100 }
101}
102
103fn internal_error(msg: impl Into<String>) -> jsonrpsee::types::ErrorObjectOwned {
104 jsonrpsee::types::ErrorObject::owned(
105 jsonrpsee::types::error::INTERNAL_ERROR_CODE,
106 msg.into(),
107 None::<()>,
108 )
109}
110
111fn decode_l2_msg(l2_msg: &Option<String>) -> Result<Vec<u8>, String> {
116 match l2_msg {
117 Some(s) if !s.is_empty() => base64_decode(s).map_err(|e| format!("base64 decode: {e}")),
118 _ => Ok(vec![]),
119 }
120}
121
/// The 64 symbols of the standard base64 alphabet, in index order.
///
/// `base64_decode` uses it to validate a lone trailing symbol before discarding it.
const STANDARD_ALPHABET: &[u8; 64] =
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
124
125fn base64_engine() -> &'static GeneralPurpose {
126 static ENGINE: OnceLock<GeneralPurpose> = OnceLock::new();
127 ENGINE.get_or_init(|| {
128 let cfg = GeneralPurposeConfig::new()
129 .with_decode_padding_mode(DecodePaddingMode::Indifferent)
130 .with_decode_allow_trailing_bits(true);
131 GeneralPurpose::new(&alphabet::STANDARD, cfg)
132 })
133}
134
135fn base64_decode(input: &str) -> Result<Vec<u8>, String> {
136 let stripped = input.trim_end_matches('=');
137 let body_len = stripped.len() & !3;
140 let tail = &stripped.as_bytes()[body_len..];
141 let body = if tail.len() == 1 {
142 let b = tail[0];
143 if !STANDARD_ALPHABET.contains(&b) {
144 return Err(format!("invalid base64 character: {}", b as char));
145 }
146 &stripped[..body_len]
147 } else {
148 stripped
149 };
150 base64_engine()
151 .decode(body)
152 .map_err(|e| format!("invalid base64: {e}"))
153}
154
155#[async_trait::async_trait]
156impl<Provider, BP> NitroExecutionApiServer for NitroExecutionHandler<Provider, BP>
157where
158 Provider: BlockNumReader + BlockReaderIdExt + HeaderProvider + 'static,
159 BP: BlockProducer,
160{
161 async fn digest_message(
162 &self,
163 msg_idx: u64,
164 message: RpcMessageWithMetadata,
165 _message_for_prefetch: Option<RpcMessageWithMetadata>,
166 ) -> RpcResult<RpcMessageResult> {
167 let block_num = self.message_index_to_block_number(msg_idx);
168 let kind = message.message.header.kind;
169 info!(target: "nitroexecution", msg_idx, block_num, kind, "digestMessage called");
170
171 if kind == 11 {
175 let l2_msg = decode_l2_msg(&message.message.l2_msg).map_err(internal_error)?;
176 self.block_producer
177 .cache_init_message(&l2_msg)
178 .map_err(|e| internal_error(e.to_string()))?;
179
180 let genesis_header = self
182 .get_header(self.genesis_block_num)
183 .map_err(internal_error)?
184 .ok_or_else(|| internal_error("Genesis block not found for Init message"))?;
185 let send_root = Self::send_root_from_header(genesis_header.header());
186 info!(target: "nitroexecution", "Init message cached, returning genesis block");
187 return Ok(RpcMessageResult {
188 block_hash: genesis_header.hash(),
189 send_root,
190 });
191 }
192
193 if let Some(header) = self.get_header(block_num).map_err(internal_error)? {
195 let send_root = Self::send_root_from_header(header.header());
196 debug!(target: "nitroexecution", block_num, ?send_root, "Block already exists");
197 return Ok(RpcMessageResult {
198 block_hash: header.hash(),
199 send_root,
200 });
201 }
202
203 let l2_msg = decode_l2_msg(&message.message.l2_msg).map_err(internal_error)?;
205
206 let batch_data_stats = message
208 .message
209 .batch_data_tokens
210 .as_ref()
211 .map(|s| (s.length, s.nonzeros));
212
213 let input = BlockProductionInput {
215 kind,
216 sender: message.message.header.sender,
217 l1_block_number: message.message.header.block_number,
218 l1_timestamp: message.message.header.timestamp,
219 request_id: message.message.header.request_id,
220 l1_base_fee: message.message.header.base_fee_l1,
221 l2_msg,
222 delayed_messages_read: message.delayed_messages_read,
223 batch_gas_cost: message.message.batch_gas_cost,
224 batch_data_stats,
225 };
226
227 let result = self
229 .block_producer
230 .produce_block(msg_idx, input)
231 .await
232 .map_err(|e| internal_error(e.to_string()))?;
233
234 Ok(RpcMessageResult {
235 block_hash: result.block_hash,
236 send_root: result.send_root,
237 })
238 }
239
240 async fn reorg(
241 &self,
242 msg_idx_of_first_msg_to_add: u64,
243 new_messages: Vec<RpcMessageWithMetadataAndBlockInfo>,
244 _old_messages: Vec<RpcMessageWithMetadata>,
245 ) -> RpcResult<Vec<RpcMessageResult>> {
246 info!(
247 target: "nitroexecution",
248 msg_idx_of_first_msg_to_add,
249 new_msgs = new_messages.len(),
250 "reorg"
251 );
252
253 let target_block = msg_idx_of_first_msg_to_add
258 .saturating_sub(1)
259 .saturating_add(self.genesis_block_num);
260
261 self.block_producer
262 .reset_to_block(target_block)
263 .await
264 .map_err(|e| internal_error(format!("reset_to_block: {e}")))?;
265
266 let mut results = Vec::with_capacity(new_messages.len());
268 for (i, wrapped) in new_messages.into_iter().enumerate() {
269 let msg_idx = msg_idx_of_first_msg_to_add + i as u64;
270 let meta = wrapped.message;
271 let l2_msg = decode_l2_msg(&meta.message.l2_msg).map_err(internal_error)?;
272 let batch_data_stats = meta
273 .message
274 .batch_data_tokens
275 .as_ref()
276 .map(|s| (s.length, s.nonzeros));
277 let input = BlockProductionInput {
278 kind: meta.message.header.kind,
279 sender: meta.message.header.sender,
280 l1_block_number: meta.message.header.block_number,
281 l1_timestamp: meta.message.header.timestamp,
282 request_id: meta.message.header.request_id,
283 l1_base_fee: meta.message.header.base_fee_l1,
284 l2_msg,
285 delayed_messages_read: meta.delayed_messages_read,
286 batch_gas_cost: meta.message.batch_gas_cost,
287 batch_data_stats,
288 };
289 let produced = self
290 .block_producer
291 .produce_block(msg_idx, input)
292 .await
293 .map_err(|e| internal_error(format!("reorg replay msg {msg_idx}: {e}")))?;
294 results.push(RpcMessageResult {
295 block_hash: produced.block_hash,
296 send_root: produced.send_root,
297 });
298 }
299 Ok(results)
300 }
301
302 async fn head_message_index(&self) -> RpcResult<u64> {
303 let best = self
304 .provider
305 .best_block_number()
306 .map_err(|e| internal_error(e.to_string()))?;
307
308 let msg_idx = self.block_number_to_message_index(best).unwrap_or(0);
309 debug!(target: "nitroexecution", best, msg_idx, "headMessageIndex");
310 Ok(msg_idx)
311 }
312
313 async fn result_at_message_index(&self, msg_idx: u64) -> RpcResult<RpcMessageResult> {
314 let block_num = self.message_index_to_block_number(msg_idx);
315
316 let header = self
317 .get_header(block_num)
318 .map_err(internal_error)?
319 .ok_or_else(|| internal_error(format!("Block {block_num} not found")))?;
320
321 let send_root = Self::send_root_from_header(header.header());
322
323 Ok(RpcMessageResult {
324 block_hash: header.hash(),
325 send_root,
326 })
327 }
328
329 fn set_finality_data(
330 &self,
331 safe: Option<RpcFinalityData>,
332 finalized: Option<RpcFinalityData>,
333 validated: Option<RpcFinalityData>,
334 ) -> RpcResult<()> {
335 debug!(target: "nitroexecution", ?safe, ?finalized, ?validated, "setFinalityData");
336 self.block_producer
337 .set_finality(
338 safe.map(|f| f.block_hash),
339 finalized.map(|f| f.block_hash),
340 validated.map(|f| f.block_hash),
341 )
342 .map_err(|e| internal_error(format!("set_finality: {e}")))?;
343 Ok(())
344 }
345
346 fn set_consensus_sync_data(&self, sync_data: RpcConsensusSyncData) -> RpcResult<()> {
347 let mut state = self.state.write();
348 state.synced = sync_data.synced;
349 state.max_message_count = sync_data.max_message_count;
350 debug!(target: "nitroexecution", synced = sync_data.synced, max = sync_data.max_message_count, "setConsensusSyncData");
351 Ok(())
352 }
353
354 fn mark_feed_start(&self, to: u64) -> RpcResult<()> {
355 debug!(target: "nitroexecution", to, "markFeedStart");
356 Ok(())
357 }
358
359 async fn trigger_maintenance(&self) -> RpcResult<()> {
360 Ok(())
361 }
362
363 async fn should_trigger_maintenance(&self) -> RpcResult<bool> {
364 Ok(false)
365 }
366
367 async fn maintenance_status(&self) -> RpcResult<RpcMaintenanceStatus> {
368 Ok(RpcMaintenanceStatus { is_running: false })
369 }
370
371 async fn arbos_version_for_message_index(&self, msg_idx: u64) -> RpcResult<u64> {
372 let block_num = self.message_index_to_block_number(msg_idx);
373
374 let header = self
375 .get_header(block_num)
376 .map_err(internal_error)?
377 .ok_or_else(|| internal_error(format!("Block {block_num} not found")))?;
378
379 let mix = header.header().mix_hash().unwrap_or_default();
380 let arbos_version = u64::from_be_bytes(mix.0[16..24].try_into().unwrap_or_default());
381
382 Ok(arbos_version)
383 }
384}
385
#[cfg(test)]
mod tests {
    // Unit tests for the lenient base64 decoding used for `l2_msg` payloads.
    use super::*;
    use base64::engine::general_purpose::STANDARD as B64;

    #[test]
    fn decode_empty_option_is_ok() {
        for payload in [None, Some(String::new())] {
            assert_eq!(decode_l2_msg(&payload).unwrap(), Vec::<u8>::new());
        }
    }

    #[test]
    fn decode_standard_padded() {
        let plain: &[u8] = b"Hello, world!";
        let decoded = base64_decode(&B64.encode(plain)).unwrap();
        assert_eq!(decoded, plain);
    }

    #[test]
    fn decode_accepts_unpadded() {
        let padded = B64.encode(b"Hello");
        let unpadded = padded.trim_end_matches('=');
        assert_eq!(base64_decode(unpadded).unwrap(), b"Hello");
    }

    #[test]
    fn decode_accepts_extra_padding() {
        for input in ["SGVsbG8==", "SGVsbG8===="] {
            assert_eq!(base64_decode(input).unwrap(), b"Hello");
        }
    }

    #[test]
    fn decode_rejects_invalid_character() {
        for input in ["SG!X", "a b", "hello world"] {
            assert!(base64_decode(input).is_err(), "expected an error for {input:?}");
        }
    }

    #[test]
    fn decode_rejects_padding_in_body() {
        for input in ["=SGVs", "SGVs=bG8"] {
            assert!(base64_decode(input).is_err(), "expected an error for {input:?}");
        }
    }

    #[test]
    fn decode_large_payload_matches_roundtrip() {
        let payload: Vec<u8> = (0..32 * 1024).map(|i| (i * 7 + 3) as u8).collect();
        let encoded = B64.encode(&payload);
        assert_eq!(base64_decode(&encoded).unwrap(), payload);
    }

    #[test]
    fn decode_preserves_lenient_padding_tail() {
        let cases: [(&str, &[u8]); 3] = [("S", &[]), ("SG", b"H"), ("SGV", b"He")];
        for (input, expected) in cases {
            assert_eq!(base64_decode(input).unwrap(), expected);
        }
    }

    #[test]
    fn decode_length_one_tail_validates_alphabet() {
        assert!(base64_decode("ABCD!").is_err());
    }

    #[test]
    fn decode_all_alphabet_characters() {
        let every_symbol = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
        let decoded = base64_decode(every_symbol).unwrap();
        assert_eq!(decoded.len(), 48);
    }
}