Skip to main content

dfir_lang/graph/
meta_graph.rs

1#![warn(missing_docs)]
2
3extern crate proc_macro;
4
5use std::collections::{BTreeMap, BTreeSet};
6use std::fmt::Debug;
7use std::iter::FusedIterator;
8
9use itertools::Itertools;
10use proc_macro2::{Ident, Literal, Span, TokenStream};
11use quote::{ToTokens, format_ident, quote, quote_spanned};
12use serde::{Deserialize, Serialize};
13use slotmap::{Key, SecondaryMap, SlotMap, SparseSecondaryMap};
14use syn::spanned::Spanned;
15
16use super::graph_write::{Dot, GraphWrite, Mermaid};
17use super::ops::{
18    DelayType, OPERATORS, OperatorWriteOutput, WriteContextArgs, find_op_op_constraints,
19    null_write_iterator_fn,
20};
21use super::{
22    CONTEXT, Color, DiMulGraph, GRAPH, GraphEdgeId, GraphLoopId, GraphNode, GraphNodeId,
23    GraphSubgraphId, HANDOFF_NODE_STR, HandoffKind, MODULE_BOUNDARY_NODE_STR, OperatorInstance,
24    PortIndexValue, SINGLETON_SLOT_NODE_STR, Varname, change_spans, get_operator_generics,
25};
26use crate::diagnostic::{Diagnostic, Diagnostics, Level};
27use crate::pretty_span::{PrettyRowCol, PrettySpan};
28use crate::process_singletons;
29
/// A resolved singleton reference: the target node ID plus mutability and access group info.
///
/// [`DfirGraph`] stores a list of these per node; see
/// [`DfirGraph::node_singleton_references`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ResolvedSingletonRef {
    /// The resolved target node ID (`None` if unresolved/error).
    pub node_id: Option<GraphNodeId>,
    /// Whether this is a mutable reference (`#mut var`).
    pub is_mut: bool,
    /// Optional access group for ordering (`#{N} var`).
    pub access_group: Option<u32>,
}
40
/// An abstract "meta graph" representation of a DFIR graph.
///
/// Can be with or without subgraph partitioning, stratification, and handoff insertion. This is
/// the meta graph used for generating Rust source code in macros from DFIR syntax.
///
/// This struct has a lot of methods for manipulating the graph, vaguely grouped together in
/// separate `impl` blocks. You might notice a few particularly specific arbitrary-seeming methods
/// in here--those are just what was needed for the compilation algorithms. If you need another
/// method then add it.
#[derive(Default, Debug, Serialize, Deserialize)]
pub struct DfirGraph {
    /// Each node type (operator or handoff).
    nodes: SlotMap<GraphNodeId, GraphNode>,

    /// Instance data corresponding to each operator node.
    /// This field will be empty after deserialization.
    #[serde(skip)]
    operator_instances: SecondaryMap<GraphNodeId, OperatorInstance>,
    /// Debugging/tracing tag for each operator node.
    operator_tag: SecondaryMap<GraphNodeId, String>,
    /// Graph data structure (two-way adjacency list).
    graph: DiMulGraph<GraphNodeId, GraphEdgeId>,
    /// Source-side and destination-side port for each edge, in that order.
    ports: SecondaryMap<GraphEdgeId, (PortIndexValue, PortIndexValue)>,

    /// Which loop a node belongs to (or none for top-level).
    node_loops: SecondaryMap<GraphNodeId, GraphLoopId>,
    /// Which nodes belong to each loop.
    loop_nodes: SlotMap<GraphLoopId, Vec<GraphNodeId>>,
    /// For the loop, what is its parent (no entry, i.e. `None` on lookup, for top-level).
    loop_parent: SparseSecondaryMap<GraphLoopId, GraphLoopId>,
    /// What loops are at the root.
    root_loops: Vec<GraphLoopId>,
    /// For the loop, what are its child loops.
    loop_children: SecondaryMap<GraphLoopId, Vec<GraphLoopId>>,

    /// Which subgraph each node belongs to.
    node_subgraph: SecondaryMap<GraphNodeId, GraphSubgraphId>,

    /// Which nodes belong to each subgraph.
    subgraph_nodes: SlotMap<GraphSubgraphId, Vec<GraphNodeId>>,

    /// Resolved singleton varname references, per node.
    node_singleton_references: SparseSecondaryMap<GraphNodeId, Vec<ResolvedSingletonRef>>,
    /// What variable name each graph node belongs to (if any). For debugging (graph writing) purposes only.
    node_varnames: SparseSecondaryMap<GraphNodeId, Varname>,

    /// Delay type for handoff nodes that represent tick-boundary back-edges.
    /// Set by `order_subgraphs` for `defer_tick` / `defer_tick_lazy`, either on handoff nodes
    /// it injects or on existing handoff nodes that it marks as tick-boundary back-edges.
    handoff_delay_type: SparseSecondaryMap<GraphNodeId, DelayType>,
}
93
94/// Basic methods.
95impl DfirGraph {
96    /// Create a new empty graph.
97    pub fn new() -> Self {
98        Default::default()
99    }
100}
101
102/// Node methods.
103impl DfirGraph {
104    /// Get a node with its operator instance (if applicable).
105    pub fn node(&self, node_id: GraphNodeId) -> &GraphNode {
106        self.nodes.get(node_id).expect("Node not found.")
107    }
108
109    /// Get the `OperatorInstance` for a given node. Node must be an operator and have an
110    /// `OperatorInstance` present, otherwise will return `None`.
111    ///
112    /// Note that no operator instances will be persent after deserialization.
113    pub fn node_op_inst(&self, node_id: GraphNodeId) -> Option<&OperatorInstance> {
114        self.operator_instances.get(node_id)
115    }
116
117    /// Get the debug variable name attached to a graph node.
118    pub fn node_varname(&self, node_id: GraphNodeId) -> Option<&Varname> {
119        self.node_varnames.get(node_id)
120    }
121
122    /// Get subgraph for node.
123    pub fn node_subgraph(&self, node_id: GraphNodeId) -> Option<GraphSubgraphId> {
124        self.node_subgraph.get(node_id).copied()
125    }
126
127    /// Degree into a node, i.e. the number of predecessors.
128    pub fn node_degree_in(&self, node_id: GraphNodeId) -> usize {
129        self.graph.degree_in(node_id)
130    }
131
132    /// Degree out of a node, i.e. the number of successors.
133    pub fn node_degree_out(&self, node_id: GraphNodeId) -> usize {
134        self.graph.degree_out(node_id)
135    }
136
137    /// Successors, iterator of `(GraphEdgeId, GraphNodeId)` of outgoing edges.
138    pub fn node_successors(
139        &self,
140        src: GraphNodeId,
141    ) -> impl '_
142    + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
143    + ExactSizeIterator
144    + FusedIterator
145    + Clone
146    + Debug {
147        self.graph.successors(src)
148    }
149
150    /// Predecessors, iterator of `(GraphEdgeId, GraphNodeId)` of incoming edges.
151    pub fn node_predecessors(
152        &self,
153        dst: GraphNodeId,
154    ) -> impl '_
155    + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
156    + ExactSizeIterator
157    + FusedIterator
158    + Clone
159    + Debug {
160        self.graph.predecessors(dst)
161    }
162
163    /// Successor edges, iterator of `GraphEdgeId` of outgoing edges.
164    pub fn node_successor_edges(
165        &self,
166        src: GraphNodeId,
167    ) -> impl '_
168    + DoubleEndedIterator<Item = GraphEdgeId>
169    + ExactSizeIterator
170    + FusedIterator
171    + Clone
172    + Debug {
173        self.graph.successor_edges(src)
174    }
175
176    /// Predecessor edges, iterator of `GraphEdgeId` of incoming edges.
177    pub fn node_predecessor_edges(
178        &self,
179        dst: GraphNodeId,
180    ) -> impl '_
181    + DoubleEndedIterator<Item = GraphEdgeId>
182    + ExactSizeIterator
183    + FusedIterator
184    + Clone
185    + Debug {
186        self.graph.predecessor_edges(dst)
187    }
188
189    /// Successor nodes, iterator of `GraphNodeId`.
190    pub fn node_successor_nodes(
191        &self,
192        src: GraphNodeId,
193    ) -> impl '_
194    + DoubleEndedIterator<Item = GraphNodeId>
195    + ExactSizeIterator
196    + FusedIterator
197    + Clone
198    + Debug {
199        self.graph.successor_vertices(src)
200    }
201
202    /// Predecessor nodes, iterator of `GraphNodeId`.
203    pub fn node_predecessor_nodes(
204        &self,
205        dst: GraphNodeId,
206    ) -> impl '_
207    + DoubleEndedIterator<Item = GraphNodeId>
208    + ExactSizeIterator
209    + FusedIterator
210    + Clone
211    + Debug {
212        self.graph.predecessor_vertices(dst)
213    }
214
215    /// Iterator of node IDs `GraphNodeId`.
216    pub fn node_ids(&self) -> slotmap::basic::Keys<'_, GraphNodeId, GraphNode> {
217        self.nodes.keys()
218    }
219
220    /// Iterator over `(GraphNodeId, &Node)` pairs.
221    pub fn nodes(&self) -> slotmap::basic::Iter<'_, GraphNodeId, GraphNode> {
222        self.nodes.iter()
223    }
224
225    /// Insert a node, assigning the given varname.
226    pub fn insert_node(
227        &mut self,
228        node: GraphNode,
229        varname_opt: Option<Ident>,
230        loop_opt: Option<GraphLoopId>,
231    ) -> GraphNodeId {
232        let node_id = self.nodes.insert(node);
233        if let Some(varname) = varname_opt {
234            self.node_varnames.insert(node_id, Varname(varname));
235        }
236        if let Some(loop_id) = loop_opt {
237            self.node_loops.insert(node_id, loop_id);
238            self.loop_nodes[loop_id].push(node_id);
239        }
240        node_id
241    }
242
    /// Insert an operator instance for the given node.
    ///
    /// # Panics
    /// Panics if `node_id` is not an operator node in this graph, or if an operator instance
    /// is already set for it.
    pub fn insert_node_op_inst(&mut self, node_id: GraphNodeId, op_inst: OperatorInstance) {
        // Only operator nodes may carry an `OperatorInstance`.
        assert!(matches!(
            self.nodes.get(node_id),
            Some(GraphNode::Operator(_))
        ));
        // `insert` returns the previous value; there must not have been one.
        let old_inst = self.operator_instances.insert(node_id, op_inst);
        assert!(old_inst.is_none());
    }
252
    /// Assign all operator instances if not set. Write diagnostic messages/errors into `diagnostics`.
    ///
    /// For every operator node that does not yet have an `OperatorInstance`:
    /// * `handoff()`, `singleton()` and `optional()` pseudo-operators are replaced in-place by
    ///   a `GraphNode::Handoff` of the matching `HandoffKind` (an error is reported if they
    ///   are given arguments or generic arguments, but the node is still replaced).
    /// * Unknown operators produce an error diagnostic and are skipped.
    /// * Otherwise an `OperatorInstance` is built from the node's sorted input/output ports and
    ///   its generics; generic-arity mismatches are reported but the instance is still inserted.
    pub fn insert_node_op_insts_all(&mut self, diagnostics: &mut Diagnostics) {
        // Handle all nodes in two phases, since the helper methods take total ownership of `&self`.
        // Possible to do in one phase, but would require accessing fields directly for partial mutable ownership.

        // Collect operator instances, then assign.
        let mut op_insts = Vec::new();
        // Collect nodes that should be lowered to handoffs (the `handoff()`/`singleton()`/`optional()`
        // pseudo-operators).
        let mut handoff_nodes: Vec<(GraphNodeId, HandoffKind, Span)> = Vec::new();

        for (node_id, node) in self.nodes() {
            // Non-operator nodes never get an operator instance.
            let GraphNode::Operator(operator) = node else {
                continue;
            };
            // Already assigned: leave as-is.
            if self.node_op_inst(node_id).is_some() {
                continue;
            };

            // Recognize `handoff()`/`singleton()`/`optional()` pseudo-operators and lower to
            // GraphNode::Handoff.
            let handoff_kind = match &*operator.name_string() {
                "handoff" => Some(HandoffKind::Vec),
                "singleton" => Some(HandoffKind::Singleton),
                "optional" => Some(HandoffKind::Optional),
                _ => None,
            };
            if let Some(kind) = handoff_kind {
                if !operator.args.is_empty() {
                    diagnostics.push(Diagnostic::spanned(
                        operator.path.span(),
                        Level::Error,
                        format!("`{}` takes no arguments.", operator.name_string()),
                    ));
                }
                if operator.type_arguments().is_some() {
                    diagnostics.push(Diagnostic::spanned(
                        operator.path.span(),
                        Level::Error,
                        format!("`{}` takes no generic arguments.", operator.name_string()),
                    ));
                }
                // Queue for replacement even if errors were reported above.
                handoff_nodes.push((node_id, kind, operator.path.span()));
                continue;
            }

            // Op constraints.
            let Some(op_constraints) = find_op_op_constraints(operator) else {
                diagnostics.push(Diagnostic::spanned(
                    operator.path.span(),
                    Level::Error,
                    format!("Unknown operator `{}`", operator.name_string()),
                ));
                continue;
            };

            // Input and output ports.
            let (input_ports, output_ports) = {
                // Collect input arguments (predecessors): the destination-side port of each
                // incoming edge.
                let mut input_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
                    .node_predecessors(node_id)
                    .map(|(edge_id, pred_id)| (self.edge_ports(edge_id).1, pred_id))
                    .collect();
                // Ensure sorted by port index.
                input_edges.sort();
                let input_ports: Vec<PortIndexValue> = input_edges
                    .into_iter()
                    .map(|(port, _pred)| port)
                    .cloned()
                    .collect();

                // Collect output arguments (successors): the source-side port of each
                // outgoing edge.
                let mut output_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
                    .node_successors(node_id)
                    .map(|(edge_id, succ)| (self.edge_ports(edge_id).0, succ))
                    .collect();
                // Ensure sorted by port index.
                output_edges.sort();
                let output_ports: Vec<PortIndexValue> = output_edges
                    .into_iter()
                    .map(|(port, _succ)| port)
                    .cloned()
                    .collect();

                (input_ports, output_ports)
            };

            // Generic arguments.
            let generics = get_operator_generics(diagnostics, operator);
            // Generic argument errors.
            {
                // Span of `generic_args` (if it exists), otherwise span of the operator name.
                let generics_span = generics
                    .generic_args
                    .as_ref()
                    .map(Spanned::span)
                    .unwrap_or_else(|| operator.path.span());

                if !op_constraints
                    .persistence_args
                    .contains(&generics.persistence_args.len())
                {
                    diagnostics.push(Diagnostic::spanned(
                        generics.persistence_args_span().unwrap_or(generics_span),
                        Level::Error,
                        format!(
                            "`{}` should have {} persistence lifetime arguments, actually has {}.",
                            op_constraints.name,
                            op_constraints.persistence_args.human_string(),
                            generics.persistence_args.len()
                        ),
                    ));
                }
                if !op_constraints.type_args.contains(&generics.type_args.len()) {
                    diagnostics.push(Diagnostic::spanned(
                        generics.type_args_span().unwrap_or(generics_span),
                        Level::Error,
                        format!(
                            "`{}` should have {} generic type arguments, actually has {}.",
                            op_constraints.name,
                            op_constraints.type_args.human_string(),
                            generics.type_args.len()
                        ),
                    ));
                }
            }

            op_insts.push((
                node_id,
                OperatorInstance {
                    op_constraints,
                    input_ports,
                    output_ports,
                    singletons_referenced: operator.singletons_referenced.clone(),
                    generics,
                    arguments_pre: operator.args.clone(),
                    arguments_raw: operator.args_raw.clone(),
                },
            ));
        }

        // Second phase: the shared borrow of `self` has ended, so mutate.
        for (node_id, op_inst) in op_insts {
            self.insert_node_op_inst(node_id, op_inst);
        }

        // Replace pseudo-operator nodes with GraphNode::Handoff. The operator's path span is
        // used for both the source and destination span.
        for (node_id, kind, span) in handoff_nodes {
            self.nodes[node_id] = GraphNode::Handoff {
                kind,
                src_span: span,
                dst_span: span,
            };
        }
    }
404
    /// Inserts a node between two existing nodes connected by the given `edge_id`.
    ///
    /// * `edge_id`: the existing edge to split.
    /// * `new_node`: the node to place in the middle of that edge.
    ///
    /// Before: A (src) ------------> B (dst)
    /// After:  A (src) -> X (new) -> B (dst)
    ///
    /// Returns the ID of X & ID of edge OUT of X.
    ///
    /// Note that both the edges will be new and `edge_id` will be removed. The edge into X keeps
    /// the original source-side port and the edge out of X keeps the original destination-side
    /// port; the two ports attached to X itself are elided (carrying X's span).
    ///
    /// If `new_node` is a known operator, a corresponding `OperatorInstance` is created for it.
    ///
    /// # Panics
    /// Panics if `edge_id` is not in the graph, or if resolving the new operator's generics
    /// produces any diagnostics.
    pub fn insert_intermediate_node(
        &mut self,
        edge_id: GraphEdgeId,
        new_node: GraphNode,
    ) -> (GraphNodeId, GraphEdgeId) {
        let span = Some(new_node.span());

        // Make corresponding operator instance (if `node` is an operator).
        let op_inst_opt = 'oc: {
            let GraphNode::Operator(operator) = &new_node else {
                break 'oc None;
            };
            // Unknown operators get no instance (and no diagnostic here).
            let Some(op_constraints) = find_op_op_constraints(operator) else {
                break 'oc None;
            };
            // NOTE(review): `self.ports` stores `(src port, dst port)`, so `input_port` here is
            // the original edge's source-side port and `output_port` its destination-side port,
            // while the edges actually attached to the new node below use elided ports --
            // confirm this is intended.
            let (input_port, output_port) = self.ports.get(edge_id).cloned().unwrap();

            // Diagnostics are collected into a throwaway sink and must be empty.
            let mut dummy_diagnostics = Diagnostics::new();
            let generics = get_operator_generics(&mut dummy_diagnostics, operator);
            assert!(dummy_diagnostics.is_empty());

            Some(OperatorInstance {
                op_constraints,
                input_ports: vec![input_port],
                output_ports: vec![output_port],
                singletons_referenced: operator.singletons_referenced.clone(),
                generics,
                arguments_pre: operator.args.clone(),
                arguments_raw: operator.args_raw.clone(),
            })
        };

        // Insert new `node`.
        let node_id = self.nodes.insert(new_node);
        // Insert corresponding `OperatorInstance` if applicable.
        if let Some(op_inst) = op_inst_opt {
            self.operator_instances.insert(node_id, op_inst);
        }
        // Update edges to insert node within `edge_id`.
        // `e0` is the new edge into the node, `e1` the new edge out of it.
        let (e0, e1) = self
            .graph
            .insert_intermediate_vertex(node_id, edge_id)
            .unwrap();

        // Update corresponding ports.
        let (src_idx, dst_idx) = self.ports.remove(edge_id).unwrap();
        self.ports
            .insert(e0, (src_idx, PortIndexValue::Elided(span)));
        self.ports
            .insert(e1, (PortIndexValue::Elided(span), dst_idx));

        (node_id, e1)
    }
469
470    /// Remove the node `node_id` but preserves and connects the single predecessor and single successor.
471    /// Panics if the node does not have exactly one predecessor and one successor, or is not in the graph.
472    pub fn remove_intermediate_node(&mut self, node_id: GraphNodeId) {
473        assert_eq!(
474            1,
475            self.node_degree_in(node_id),
476            "Removed intermediate node must have one predecessor"
477        );
478        assert_eq!(
479            1,
480            self.node_degree_out(node_id),
481            "Removed intermediate node must have one successor"
482        );
483        assert!(
484            self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
485            "Should not remove intermediate node after subgraph partitioning"
486        );
487
488        assert!(self.nodes.remove(node_id).is_some());
489        let (new_edge_id, (pred_edge_id, succ_edge_id)) =
490            self.graph.remove_intermediate_vertex(node_id).unwrap();
491        self.operator_instances.remove(node_id);
492        self.node_varnames.remove(node_id);
493
494        let (src_port, _) = self.ports.remove(pred_edge_id).unwrap();
495        let (_, dst_port) = self.ports.remove(succ_edge_id).unwrap();
496        self.ports.insert(new_edge_id, (src_port, dst_port));
497    }
498
499    /// Helper method: determine the "color" (pull vs push) of a node based on its in and out degree,
500    /// excluding reference edges. If linear (1 in, 1 out), color is `None`, indicating it can be
501    /// either push or pull.
502    ///
503    /// Note that this does NOT consider `DelayType` barriers (which generally implies `Pull`).
504    pub(crate) fn node_color(&self, node_id: GraphNodeId) -> Option<Color> {
505        if matches!(self.node(node_id), GraphNode::Handoff { .. }) {
506            return Some(Color::Hoff);
507        }
508
509        // TODO(shadaj): this is a horrible hack
510        if let GraphNode::Operator(op) = self.node(node_id)
511            && (op.name_string() == "resolve_futures_blocking"
512                || op.name_string() == "resolve_futures_blocking_ordered")
513        {
514            return Some(Color::Push);
515        }
516
517        // In-degree, excluding ref-edges.
518        let inn_degree = self.node_predecessor_nodes(node_id).len();
519        // Out-degree excluding ref-edges.
520        let out_degree = self.node_successor_nodes(node_id).len();
521
522        match (inn_degree, out_degree) {
523            (0, 0) => None, // Generally should not happen, "Degenerate subgraph detected".
524            (0, 1) => Some(Color::Pull),
525            (1, 0) => Some(Color::Push),
526            (1, 1) => None, // Linear, can be either push or pull.
527            (_many, 0 | 1) => Some(Color::Pull),
528            (0 | 1, _many) => Some(Color::Push),
529            (_many, _to_many) => Some(Color::Comp),
530        }
531    }
532
533    /// Set the operator tag (for debugging/tracing).
534    pub fn set_operator_tag(&mut self, node_id: GraphNodeId, tag: String) {
535        self.operator_tag.insert(node_id, tag);
536    }
537}
538
539/// Singleton references.
540impl DfirGraph {
541    /// Set the singletons referenced for the `node_id` operator. Each reference corresponds to the
542    /// same index in the [`crate::parse::Operator::singletons_referenced`] vec.
543    pub fn set_node_singleton_references(
544        &mut self,
545        node_id: GraphNodeId,
546        singletons_referenced: Vec<ResolvedSingletonRef>,
547    ) -> Option<Vec<ResolvedSingletonRef>> {
548        self.node_singleton_references
549            .insert(node_id, singletons_referenced)
550    }
551
552    /// Gets the singletons referenced by a node. Returns an empty slice for non-operators and
553    /// operators that do not reference singletons.
554    pub fn node_singleton_references(&self, node_id: GraphNodeId) -> &[ResolvedSingletonRef] {
555        self.node_singleton_references
556            .get(node_id)
557            .map(std::ops::Deref::deref)
558            .unwrap_or_default()
559    }
560
561    /// Collect all refs, grouped by the singleton they're pointing at, then by the access group idx `Option<u32>`.
562    pub fn node_singleton_reference_groups(&self) -> NodeSingletonReferenceGroups<'_> {
563        let mut singleton_references = NodeSingletonReferenceGroups::new();
564        for node_id in self.node_ids() {
565            if let GraphNode::Operator(operator) = self.node(node_id) {
566                let resolved = self.node_singleton_references(node_id);
567                for (resolved_ref, ref_token) in
568                    resolved.iter().zip(operator.singletons_referenced.iter())
569                {
570                    if let Some(target_nid) = resolved_ref.node_id {
571                        singleton_references
572                            .entry(target_nid)
573                            .or_default()
574                            .entry(resolved_ref.access_group)
575                            .or_default()
576                            .push((node_id, resolved_ref, ref_token.span()));
577                    }
578                }
579            }
580        }
581        singleton_references
582    }
583}
584
/// Singleton references grouped per referenced singleton node, in turn grouped by access group.
/// Map: singleton_node_id -> access_group -> (source `GraphNodeId`, `ResolvedSingletonRef`, `#ref` span)
///
/// The lifetime `'a` borrows the `ResolvedSingletonRef`s from the graph they were collected from.
pub type NodeSingletonReferenceGroups<'a> = BTreeMap<
    GraphNodeId,
    BTreeMap<Option<u32>, Vec<(GraphNodeId, &'a ResolvedSingletonRef, Span)>>,
>;
591
592/// Module methods.
593impl DfirGraph {
594    /// When modules are imported into a flat graph, they come with an input and output ModuleBoundary node.
595    /// The partitioner doesn't understand these nodes and will panic if it encounters them.
596    /// merge_modules removes them from the graph, stitching the input and ouput sides of the ModuleBondaries based on their ports
597    /// For example:
598    ///     source_iter([]) -> \[myport\]ModuleBoundary(input)\[my_port\] -> map(|x| x) -> ModuleBoundary(output) -> null();
599    /// in the above eaxmple, the \[myport\] port will be used to connect the source_iter with the map that is inside of the module.
600    /// The output module boundary has elided ports, this is also used to match up the input/output across the module boundary.
601    pub fn merge_modules(&mut self) -> Result<(), Diagnostic> {
602        let mod_bound_nodes = self
603            .nodes()
604            .filter(|(_nid, node)| matches!(node, GraphNode::ModuleBoundary { .. }))
605            .map(|(nid, _node)| nid)
606            .collect::<Vec<_>>();
607
608        for mod_bound_node in mod_bound_nodes {
609            self.remove_module_boundary(mod_bound_node)?;
610        }
611
612        Ok(())
613    }
614
    /// See `merge_modules`.
    /// This function removes a singular module boundary from the graph and performs the necessary stitching to fix the graph afterward.
    /// `merge_modules` calls this function for each module boundary in the graph.
    ///
    /// Incoming and outgoing edges of the boundary are paired up by port: an incoming edge
    /// whose boundary-side port equals the boundary-side port of an outgoing edge is fused with
    /// it into one direct edge that keeps the two far-side ports.
    ///
    /// # Errors
    /// Returns an error `Diagnostic` (spanned at the import expression) if the set of ports on
    /// the incoming side differs from the set of ports on the outgoing side.
    ///
    /// # Panics
    /// Panics if called after subgraph partitioning, or if the port sets mismatch and
    /// `mod_bound_node` is not a `GraphNode::ModuleBoundary`.
    fn remove_module_boundary(&mut self, mod_bound_node: GraphNodeId) -> Result<(), Diagnostic> {
        assert!(
            self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
            "Should not remove intermediate node after subgraph partitioning"
        );

        // Boundary-side port -> (edge, far-side port), for incoming and outgoing edges.
        // NOTE(review): two edges sharing the same boundary-side port would overwrite each
        // other here -- presumably ruled out earlier; confirm.
        let mut mod_pred_ports = BTreeMap::new();
        let mut mod_succ_ports = BTreeMap::new();

        // Edges coming into the boundary, keyed by their destination (boundary-side) port.
        for mod_out_edge in self.node_predecessor_edges(mod_bound_node) {
            let (pred_port, succ_port) = self.edge_ports(mod_out_edge);
            mod_pred_ports.insert(succ_port.clone(), (mod_out_edge, pred_port.clone()));
        }

        // Edges leaving the boundary, keyed by their source (boundary-side) port.
        for mod_inn_edge in self.node_successor_edges(mod_bound_node) {
            let (pred_port, succ_port) = self.edge_ports(mod_inn_edge);
            mod_succ_ports.insert(pred_port.clone(), (mod_inn_edge, succ_port.clone()));
        }

        // Both sides must expose exactly the same set of ports.
        if mod_pred_ports.keys().collect::<BTreeSet<_>>()
            != mod_succ_ports.keys().collect::<BTreeSet<_>>()
        {
            // get module boundary node
            let GraphNode::ModuleBoundary { input, import_expr } = self.node(mod_bound_node) else {
                panic!();
            };

            // The message wording depends on whether this is the input or output boundary.
            if *input {
                return Err(Diagnostic {
                    span: *import_expr,
                    level: Level::Error,
                    message: format!(
                        "The ports into the module did not match. input: {:?}, expected: {:?}",
                        mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
                        mod_succ_ports.keys().map(|x| x.to_string()).join(", ")
                    ),
                });
            } else {
                return Err(Diagnostic {
                    span: *import_expr,
                    level: Level::Error,
                    message: format!(
                        "The ports out of the module did not match. output: {:?}, expected: {:?}",
                        mod_succ_ports.keys().map(|x| x.to_string()).join(", "),
                        mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
                    ),
                });
            }
        }

        // Fuse each matching incoming/outgoing edge pair into one direct edge.
        for (port, (pred_edge, pred_port)) in mod_pred_ports {
            // Cannot fail: the key sets were just checked to be equal.
            let (succ_edge, succ_port) = mod_succ_ports.remove(&port).unwrap();

            let (src, _) = self.edge(pred_edge);
            let (_, dst) = self.edge(succ_edge);
            self.remove_edge(pred_edge);
            self.remove_edge(succ_edge);

            let new_edge_id = self.graph.insert_edge(src, dst);
            self.ports.insert(new_edge_id, (pred_port, succ_port));
        }

        // Finally drop the boundary itself from both the graph structure and the node map.
        self.graph.remove_vertex(mod_bound_node);
        self.nodes.remove(mod_bound_node);

        Ok(())
    }
685}
686
687/// Edge methods.
688impl DfirGraph {
689    /// Get the `src` and `dst` for an edge: `(src GraphNodeId, dst GraphNodeId)`.
690    pub fn edge(&self, edge_id: GraphEdgeId) -> (GraphNodeId, GraphNodeId) {
691        let (src, dst) = self.graph.edge(edge_id).expect("Edge not found.");
692        (src, dst)
693    }
694
695    /// Get the source and destination ports for an edge: `(src &PortIndexValue, dst &PortIndexValue)`.
696    pub fn edge_ports(&self, edge_id: GraphEdgeId) -> (&PortIndexValue, &PortIndexValue) {
697        let (src_port, dst_port) = self.ports.get(edge_id).expect("Edge not found.");
698        (src_port, dst_port)
699    }
700
701    /// Iterator of all edge IDs `GraphEdgeId`.
702    pub fn edge_ids(&self) -> slotmap::basic::Keys<'_, GraphEdgeId, (GraphNodeId, GraphNodeId)> {
703        self.graph.edge_ids()
704    }
705
706    /// Iterator over all edges: `(GraphEdgeId, (src GraphNodeId, dst GraphNodeId))`.
707    pub fn edges(
708        &self,
709    ) -> impl '_
710    + ExactSizeIterator<Item = (GraphEdgeId, (GraphNodeId, GraphNodeId))>
711    + FusedIterator
712    + Clone
713    + Debug {
714        self.graph.edges()
715    }
716
717    /// Insert an edge between nodes thru the given ports.
718    pub fn insert_edge(
719        &mut self,
720        src: GraphNodeId,
721        src_port: PortIndexValue,
722        dst: GraphNodeId,
723        dst_port: PortIndexValue,
724    ) -> GraphEdgeId {
725        let edge_id = self.graph.insert_edge(src, dst);
726        self.ports.insert(edge_id, (src_port, dst_port));
727        edge_id
728    }
729
730    /// Removes an edge and its corresponding ports and edge type info.
731    pub fn remove_edge(&mut self, edge: GraphEdgeId) {
732        let (_src, _dst) = self.graph.remove_edge(edge).unwrap();
733        let (_src_port, _dst_port) = self.ports.remove(edge).unwrap();
734    }
735}
736
737/// Subgraph methods.
738impl DfirGraph {
739    /// Nodes belonging to the given subgraph.
740    pub fn subgraph(&self, subgraph_id: GraphSubgraphId) -> &Vec<GraphNodeId> {
741        self.subgraph_nodes
742            .get(subgraph_id)
743            .expect("Subgraph not found.")
744    }
745
746    /// Iterator over all subgraph IDs.
747    pub fn subgraph_ids(&self) -> slotmap::basic::Keys<'_, GraphSubgraphId, Vec<GraphNodeId>> {
748        self.subgraph_nodes.keys()
749    }
750
751    /// Iterator over all subgraphs, ID and members: `(GraphSubgraphId, Vec<GraphNodeId>)`.
752    pub fn subgraphs(&self) -> slotmap::basic::Iter<'_, GraphSubgraphId, Vec<GraphNodeId>> {
753        self.subgraph_nodes.iter()
754    }
755
756    /// Create a subgraph consisting of `node_ids`. Returns an error if any of the nodes are already in a subgraph.
757    pub fn insert_subgraph(
758        &mut self,
759        node_ids: Vec<GraphNodeId>,
760    ) -> Result<GraphSubgraphId, (GraphNodeId, GraphSubgraphId)> {
761        // Check none are already in subgraphs
762        for &node_id in node_ids.iter() {
763            if let Some(&old_sg_id) = self.node_subgraph.get(node_id) {
764                return Err((node_id, old_sg_id));
765            }
766        }
767        let subgraph_id = self.subgraph_nodes.insert_with_key(|sg_id| {
768            for &node_id in node_ids.iter() {
769                self.node_subgraph.insert(node_id, sg_id);
770            }
771            node_ids
772        });
773
774        Ok(subgraph_id)
775    }
776
777    /// Removes a node from its subgraph. Returns true if the node was in a subgraph.
778    pub fn remove_from_subgraph(&mut self, node_id: GraphNodeId) -> bool {
779        if let Some(old_sg_id) = self.node_subgraph.remove(node_id) {
780            self.subgraph_nodes[old_sg_id].retain(|&other_node_id| other_node_id != node_id);
781            true
782        } else {
783            false
784        }
785    }
786
787    /// Gets the delay type for a handoff node, if set.
788    pub fn handoff_delay_type(&self, node_id: GraphNodeId) -> Option<DelayType> {
789        self.handoff_delay_type.get(node_id).copied()
790    }
791
792    /// Sets the delay type for a handoff node.
793    pub fn set_handoff_delay_type(&mut self, node_id: GraphNodeId, delay_type: DelayType) {
794        self.handoff_delay_type.insert(node_id, delay_type);
795    }
796
797    /// Helper: finds the first index in `subgraph_nodes` where it transitions from pull to push.
798    fn find_pull_to_push_idx(&self, subgraph_nodes: &[GraphNodeId]) -> usize {
799        subgraph_nodes
800            .iter()
801            .position(|&node_id| {
802                self.node_color(node_id)
803                    .is_some_and(|color| Color::Pull != color)
804            })
805            .unwrap_or(subgraph_nodes.len())
806    }
807}
808
809/// Display/output methods.
810impl DfirGraph {
811    /// Helper to generate a deterministic `Ident` for the given node.
812    fn node_as_ident(&self, node_id: GraphNodeId, is_pred: bool) -> Ident {
813        let name = match &self.nodes[node_id] {
814            GraphNode::Operator(_) => format!("op_{:?}", node_id.data()),
815            GraphNode::Handoff {
816                kind: HandoffKind::Vec,
817                ..
818            } => format!(
819                "hoff_{:?}_{}",
820                node_id.data(),
821                if is_pred { "recv" } else { "send" }
822            ),
823            GraphNode::Handoff {
824                kind: HandoffKind::Singleton | HandoffKind::Optional,
825                ..
826            } => format!(
827                "singleton_{:?}_{}",
828                node_id.data(),
829                if is_pred { "recv" } else { "send" }
830            ),
831            GraphNode::ModuleBoundary { .. } => panic!(),
832        };
833        let span = match (is_pred, &self.nodes[node_id]) {
834            (_, GraphNode::Operator(operator)) => operator.span(),
835            (true, &GraphNode::Handoff { src_span, .. }) => src_span,
836            (false, &GraphNode::Handoff { dst_span, .. }) => dst_span,
837            (_, GraphNode::ModuleBoundary { .. }) => panic!(),
838        };
839        Ident::new(&name, span)
840    }
841
842    /// Helper to generate the main buffer `Ident` for a handoff node.
843    fn hoff_buf_ident(&self, hoff_id: GraphNodeId, span: Span) -> Ident {
844        Ident::new(&format!("hoff_{:?}_buf", hoff_id.data()), span)
845    }
846
847    /// Helper to generate the back (double-buffer) `Ident` for a handoff node.
848    fn hoff_back_ident(&self, hoff_id: GraphNodeId, span: Span) -> Ident {
849        Ident::new(&format!("hoff_{:?}_back", hoff_id.data()), span)
850    }
851
852    /// Resolve the singletons via [`Self::node_singleton_references`] for the given `node_id`.
853    /// Returns token streams for each reference:
854    /// - For HandoffKind::Singleton: `buf.as_ref().unwrap()` (shared, `&T`) or
855    ///   `buf.as_mut().unwrap()` (mutable, `&mut T`)
856    /// - For HandoffKind::Optional: `&buf` (shared, `&Option<T>`) or
857    ///   `&mut buf` (mutable, `&mut Option<T>`)
858    /// - For HandoffKind::Vec: `&buf` (shared, `&Vec<T>`) or
859    ///   `&mut buf` (mutable, `&mut Vec<T>`)
860    fn helper_resolve_singletons(&self, node_id: GraphNodeId, span: Span) -> Vec<TokenStream> {
861        self.node_singleton_references(node_id)
862            .iter()
863            .map(|resolved_ref| {
864                // TODO(mingwei): this `expect` should be caught in error checking
865                let ref_node_id = resolved_ref
866                    .node_id
867                    .expect("Expected singleton to be resolved but was not, this is a bug.");
868                let is_mut = resolved_ref.is_mut;
869                match self.node(ref_node_id) {
870                    GraphNode::Handoff {
871                        kind: HandoffKind::Singleton,
872                        ..
873                    } => {
874                        let buf_ident = self.hoff_buf_ident(ref_node_id, span);
875                        if is_mut {
876                            quote_spanned! {span=> #buf_ident.as_mut().unwrap() }
877                        } else {
878                            quote_spanned! {span=> #buf_ident.as_ref().unwrap() }
879                        }
880                    }
881                    GraphNode::Handoff {
882                        kind: HandoffKind::Optional | HandoffKind::Vec,
883                        ..
884                    } => {
885                        let buf_ident = self.hoff_buf_ident(ref_node_id, span);
886                        if is_mut {
887                            quote_spanned! {span=> &mut #buf_ident }
888                        } else {
889                            quote_spanned! {span=> &#buf_ident }
890                        }
891                    }
892                    _ => unreachable!(
893                        "Only handoff nodes should be reachable as singleton references"
894                    ),
895                }
896            })
897            .collect::<Vec<_>>()
898    }
899
900    /// Returns each subgraph's receive and send handoffs.
901    /// `Map<GraphSubgraphId, (recv handoffs, send handoffs)>`
902    fn helper_collect_subgraph_handoffs(
903        &self,
904    ) -> SecondaryMap<GraphSubgraphId, (Vec<GraphNodeId>, Vec<GraphNodeId>)> {
905        // Get data on handoff src and dst subgraphs.
906        let mut subgraph_handoffs: SecondaryMap<
907            GraphSubgraphId,
908            (Vec<GraphNodeId>, Vec<GraphNodeId>),
909        > = self
910            .subgraph_nodes
911            .keys()
912            .map(|k| (k, Default::default()))
913            .collect();
914
915        // For each handoff/singleton node, add it to the `send`/`recv` lists for the corresponding subgraphs.
916        for (hoff_id, hoff) in self.nodes() {
917            if !matches!(hoff, GraphNode::Handoff { .. }) {
918                continue;
919            }
920            // Receivers from the handoff. (Should really only be one).
921            for (_edge, succ_id) in self.node_successors(hoff_id) {
922                let succ_sg = self.node_subgraph(succ_id).unwrap();
923                subgraph_handoffs[succ_sg].0.push(hoff_id);
924            }
925            // Senders into the handoff. (Should really only be one).
926            for (_edge, pred_id) in self.node_predecessors(hoff_id) {
927                let pred_sg = self.node_subgraph(pred_id).unwrap();
928                subgraph_handoffs[pred_sg].1.push(hoff_id);
929            }
930        }
931
932        subgraph_handoffs
933    }
934
935    /// Emit this graph as runnable Rust source code tokens that execute inline.
936    /// Generates a flat `async move |df: &mut Context|` closure where subgraph
937    /// blocks are inlined in topological order, using local `Vec<T>` buffers
938    /// instead of runtime handoffs. Each call to the closure runs one tick.
939    ///
940    /// The generated code block evaluates to a `Dfir` instance wrapping the
941    /// closure. Operator prologues run at construction time on the `Context`
942    /// before it is moved into `Dfir::new`. `Dfir` provides the `Context`
943    /// to the closure on each tick run.
944    ///
945    /// # Errors
946    ///
947    /// Returns all diagnostics as `Err(diagnostics)` if any are errors
948    /// (leaving `&mut diagnostics` empty).
949    pub fn as_code(
950        &self,
951        root: &TokenStream,
952        include_type_guards: bool,
953        prefix: TokenStream,
954        diagnostics: &mut Diagnostics,
955    ) -> Result<TokenStream, Diagnostics> {
956        self.as_code_with_options(root, include_type_guards, true, prefix, diagnostics)
957    }
958
959    /// Like [`Self::as_code`], but with `include_meta` controlling whether
960    /// the runtime meta graph + diagnostics JSON blobs are baked into the
961    /// generated `Dfir::new(...)` call.
962    ///
    /// The simulator calls `Dfir::new()` on each iteration, and as part of that it parses the
    /// meta graph and diagnostics blobs. Parsing one of them causes spans to be allocated; each
    /// time a span is allocated, a thread-local `u32` is incremented, and on a long simulator
    /// run that `u32` overflows and panics.
967    pub fn as_code_with_options(
968        &self,
969        root: &TokenStream,
970        include_type_guards: bool,
971        include_meta: bool,
972        prefix: TokenStream,
973        diagnostics: &mut Diagnostics,
974    ) -> Result<TokenStream, Diagnostics> {
975        let df = Ident::new(GRAPH, Span::call_site());
976        let context = Ident::new(CONTEXT, Span::call_site());
977        // Tick-local bump-allocated Vec handoff declarations (inside the tick closure).
978        let bump_ident = Ident::new("__dfir_bump", Span::call_site());
979
980        // 1. Collect all handoff nodes.
981        let handoff_nodes = self
982            .nodes
983            .iter()
984            .filter_map(|(node_id, node)| match node {
985                &GraphNode::Handoff {
986                    kind,
987                    src_span,
988                    dst_span,
989                } => Some((node_id, kind, (src_span, dst_span))),
990                GraphNode::Operator(_) => None,
991                GraphNode::ModuleBoundary { .. } => panic!(),
992            })
993            .collect::<Vec<_>>();
994
995        // Determine which handoff nodes are tick-boundary (defer_tick) back-edges.
996        // These must remain as captured Vec<T> since they persist across ticks.
997        // All other Vec handoffs will be bump-allocated (tick-local).
998        let back_edge_hoffs_and_lazyness = handoff_nodes
999            .iter()
1000            .map(|&(node_id, _, _)| node_id)
1001            .filter_map(|node_id| {
1002                if let Some(delay_type) = self.handoff_delay_type(node_id) {
1003                    assert!(
1004                        matches!(delay_type, DelayType::Tick | DelayType::TickLazy),
1005                        "Handoff `DelayType` must be either `Tick` or `TickLazy` (or unset)."
1006                    );
1007                    Some((node_id, matches!(delay_type, DelayType::TickLazy)))
1008                } else {
1009                    None
1010                }
1011            })
1012            .collect::<SparseSecondaryMap<_, _>>();
1013
1014        // Back buffer idents, buf idents, and if they are lazy.
1015        let back_buffer_idents_laziness = handoff_nodes
1016            .iter()
1017            .filter_map(|&(hoff_id, _kind, (src_span, dst_span))| {
1018                back_edge_hoffs_and_lazyness.get(hoff_id).map(|&is_lazy| {
1019                    let span = src_span.join(dst_span).unwrap_or(src_span);
1020                    let back_ident = self.hoff_back_ident(hoff_id, span);
1021                    let buf_ident = self.hoff_buf_ident(hoff_id, span);
1022                    (back_ident, buf_ident, is_lazy)
1023                })
1024            })
1025            .collect::<Vec<_>>();
1026
1027        // Generate swap code for tick-boundary (defer_tick / defer_tick_lazy) handoffs.
1028        // At the end of each tick, swap the regular buffer and back buffer so the
1029        // consumer reads last tick's data from the back buffer.
1030        let back_edge_swap_code = handoff_nodes
1031            .iter()
1032            .filter(|&&(node_id, _kind, _)| back_edge_hoffs_and_lazyness.contains_key(node_id))
1033            .map(|&(hoff_id, _kind, _)| {
1034                let span = self.nodes[hoff_id].span();
1035                let buf_ident = self.hoff_buf_ident(hoff_id, span);
1036                let back_ident = self.hoff_back_ident(hoff_id, span);
1037                quote_spanned! {span=>
1038                    ::std::mem::swap(&mut #buf_ident, &mut #back_ident);
1039                }
1040            })
1041            .collect::<Vec<_>>();
1042
1043        // 2. Collect per-subgraph recv & send handoffs.
1044        let subgraph_handoffs = self.helper_collect_subgraph_handoffs();
1045
1046        // 3. Sort subgraphs topologically and collect non-lazy defer_tick buffer idents.
1047        //
1048        // Handoffs marked with a `DelayType` (Tick/TickLazy) are tick-boundary back-edges.
1049        // These are excluded from the topo sort (no ordering constraint). Double-buffering
1050        // ensures data written by the producer in tick N is only visible to the consumer
1051        // in tick N+1, regardless of execution order.
1052        //
1053        // While iterating handoffs, we also collect buffer idents for non-lazy tick-boundary
1054        // edges (defer_tick). When these buffers are non-empty at end of tick, we set
1055        // can_start_tick so that run_available continues ticking.
1056        //
1057        // TODO(mingwei): right now we topo sort more than once in the build process, we should keep a single order.
1058        let all_subgraphs = {
1059            // Build predecessor map for subgraphs.
1060            let mut sg_preds: SecondaryMap<GraphSubgraphId, Vec<GraphSubgraphId>> =
1061                SecondaryMap::<_, Vec<_>>::with_capacity(self.subgraph_nodes.len());
1062            for (hoff_id, hoff) in self.nodes() {
1063                if !matches!(hoff, GraphNode::Handoff { .. }) {
1064                    // Not a handoff; skip.
1065                    continue;
1066                }
1067                if 0 == self.node_successors(hoff_id).len() {
1068                    // Is a handoff only used by reference, not consumed.
1069                    continue;
1070                }
1071                assert_eq!(1, self.node_successors(hoff_id).len());
1072                assert_eq!(1, self.node_predecessors(hoff_id).len());
1073                let (_edge_id, pred) = self.node_predecessors(hoff_id).next().unwrap();
1074                let (_edge_id, succ) = self.node_successors(hoff_id).next().unwrap();
1075                let pred_sg = self.node_subgraph(pred).unwrap();
1076                let succ_sg = self.node_subgraph(succ).unwrap();
1077                if pred_sg == succ_sg {
1078                    panic!("bug: unexpected subgraph self-handoff cycle");
1079                }
1080                // Only consider non-back-edges.
1081                if !back_edge_hoffs_and_lazyness.contains_key(hoff_id) {
1082                    sg_preds.entry(succ_sg).unwrap().or_default().push(pred_sg);
1083                }
1084            }
1085
1086            // Include singleton reference edges: if node A references the
1087            // singleton output of node B, then A's subgraph must run after B's.
1088            for dst_id in self.node_ids() {
1089                for src_ref_id in self
1090                    .node_singleton_references(dst_id)
1091                    .iter()
1092                    .filter_map(|r| r.node_id)
1093                {
1094                    // For handoff nodes (no subgraph), use the predecessor's subgraph.
1095                    let src_sg = if let Some(sg) = self.node_subgraph(src_ref_id) {
1096                        sg
1097                    } else {
1098                        let (_edge, pred) = self
1099                            .node_predecessors(src_ref_id)
1100                            .next()
1101                            .expect("handoff must have a predecessor");
1102                        self.node_subgraph(pred).unwrap()
1103                    };
1104                    let dst_sg = self
1105                        .node_subgraph(dst_id)
1106                        .expect("bug: singleton ref consumer must belong to a subgraph");
1107                    if src_sg != dst_sg {
1108                        sg_preds.entry(dst_sg).unwrap().or_default().push(src_sg);
1109                    }
1110
1111                    // Ensure the borrower runs before the pipe consumer
1112                    // (which takes/drains the value).
1113                    // All handoffs should have at most one successor.
1114                    if self.node_subgraph(src_ref_id).is_none() {
1115                        assert!(
1116                            self.node_degree_out(src_ref_id) <= 1,
1117                            "handoff should have at most one successor"
1118                        );
1119                        if let Some((_edge, succ_id)) = self.node_successors(src_ref_id).next()
1120                            && let Some(consumer_sg) = self.node_subgraph(succ_id)
1121                            && consumer_sg != dst_sg
1122                        {
1123                            sg_preds
1124                                .entry(consumer_sg)
1125                                .unwrap()
1126                                .or_default()
1127                                .push(dst_sg);
1128                        }
1129                    }
1130                }
1131            }
1132
1133            let topo_sort = super::graph_algorithms::topo_sort(self.subgraph_ids(), |sg_id| {
1134                sg_preds.get(sg_id).into_iter().flatten().copied()
1135            })
1136            .expect("bug: unexpected cycle between subgraphs within the tick");
1137
1138            topo_sort
1139                .into_iter()
1140                .map(|sg_id| (sg_id, self.subgraph(sg_id)))
1141                .collect::<Vec<_>>()
1142        };
1143
1144        // TODO(mingwei): If a handoff has no pipe consumers we should drop it as soon as possible, after all reference
1145        // consumers. Right now we just let these handoffs die at the end of the tick.
1146
1147        let mut op_prologue_code = Vec::new();
1148        let mut op_tick_end_code = Vec::new();
1149        let mut subgraph_blocks = Vec::new();
1150        {
1151            for &(subgraph_id, subgraph_nodes) in all_subgraphs.iter() {
1152                let sg_metrics_ffi = subgraph_id.data().as_ffi();
1153                let (recv_hoffs, send_hoffs) = &subgraph_handoffs[subgraph_id];
1154
1155                // Generate buffer ident helpers for this subgraph's handoffs.
1156                let recv_port_idents: Vec<Ident> = recv_hoffs
1157                    .iter()
1158                    .map(|&hoff_id| self.node_as_ident(hoff_id, true))
1159                    .collect();
1160                let send_port_idents: Vec<Ident> = send_hoffs
1161                    .iter()
1162                    .map(|&hoff_id| self.node_as_ident(hoff_id, false))
1163                    .collect();
1164
1165                // Map handoff node IDs to buffer idents.
1166                let recv_buf_idents: Vec<Ident> = recv_hoffs
1167                    .iter()
1168                    .map(|&hoff_id| self.hoff_buf_ident(hoff_id, self.nodes[hoff_id].span()))
1169                    .collect();
1170                let send_buf_idents: Vec<Ident> = send_hoffs
1171                    .iter()
1172                    .map(|&hoff_id| self.hoff_buf_ident(hoff_id, self.nodes[hoff_id].span()))
1173                    .collect();
1174
1175                // Handoff kinds
1176                let recv_kinds = recv_hoffs
1177                    .iter()
1178                    .map(|&hoff_id| {
1179                        let GraphNode::Handoff { kind, .. } = self.node(hoff_id) else {
1180                            panic!()
1181                        };
1182                        *kind
1183                    })
1184                    .collect::<Vec<_>>();
1185                let send_kinds = send_hoffs
1186                    .iter()
1187                    .map(|&hoff_id| {
1188                        let GraphNode::Handoff { kind, .. } = self.node(hoff_id) else {
1189                            panic!()
1190                        };
1191                        *kind
1192                    })
1193                    .collect::<Vec<_>>();
1194
1195                // Recv port code: drain from buffer into iterator, tracking if non-empty.
1196                // For back-edge (defer_tick) handoffs, drain from the back buffer instead.
1197                // Also update handoff metrics (measured at recv, not send — see graph.rs).
1198                let recv_port_code: Vec<TokenStream> = recv_port_idents
1199                    .iter()
1200                    .zip(recv_buf_idents.iter())
1201                    .zip(recv_kinds.iter())
1202                    .zip(recv_hoffs.iter())
1203                    .map(|(((port_ident, buf_ident), &kind), &hoff_id)| {
1204                        let hoff_ffi = hoff_id.data().as_ffi();
1205                        // Use call_site span for internal identifiers to avoid
1206                        // hygiene issues when invoked through declarative macros
1207                        // (e.g. dfir_expect_warnings!). TODO(#2781): define these once.
1208                        let work_done = Ident::new("__dfir_work_done", Span::call_site());
1209                        let metrics = Ident::new("__dfir_metrics", Span::call_site());
1210
1211                        // Compute len and drain expressions based on handoff kind.
1212                        let (len_expr, drain_expr) = match kind {
1213                            HandoffKind::Singleton | HandoffKind::Optional => (
1214                                quote! { if #buf_ident.is_some() { 1usize } else { 0usize } },
1215                                quote! { #root::dfir_pipes::pull::iter(#buf_ident.take().into_iter()) },
1216                            ),
1217                            HandoffKind::Vec => {
1218                                // Special asymmetric handling for defer tick handoffs, which are double-buffered.
1219                                // The producer writes to the regular buffer; at end-of-tick the buffers are swapped,
1220                                // so the consumer drains from the back buffer (here).
1221                                let drain_ident = if back_edge_hoffs_and_lazyness.contains_key(hoff_id) {
1222                                    &self.hoff_back_ident(hoff_id, buf_ident.span())
1223                                } else {
1224                                    buf_ident
1225                                };
1226                                (
1227                                    quote! { #drain_ident.len() },
1228                                    quote! { #root::dfir_pipes::pull::iter(#drain_ident.drain(..)) },
1229                                )
1230                            }
1231                        };
1232
1233                        quote_spanned! {port_ident.span()=>
1234                            {
1235                                let hoff_len = #len_expr;
1236                                if hoff_len > 0 {
1237                                    #work_done = true;
1238                                }
1239                                let hoff_metrics = &#metrics.handoffs[
1240                                    #root::slotmap::KeyData::from_ffi(#hoff_ffi).into()
1241                                ];
1242                                hoff_metrics.total_items_count.update(|x| x + hoff_len);
1243                                hoff_metrics.curr_items_count.set(hoff_len);
1244                            }
1245                            let #port_ident = #drain_expr;
1246                        }
1247                    })
1248                    .collect();
1249
1250                // Send port code: push into buffer.
1251                let send_port_code: Vec<TokenStream> = send_port_idents
1252                    .iter()
1253                    .zip(send_buf_idents.iter())
1254                    .zip(send_kinds.iter())
1255                    .map(|((port_ident, buf_ident), &kind)| {
1256                        match kind {
1257                            HandoffKind::Singleton => {
1258                                // Singleton slot: store exactly one item, panic on duplicate.
1259                                quote_spanned! {port_ident.span()=>
1260                                    let #port_ident = #root::dfir_pipes::push::for_each(|__item| {
1261                                        if #buf_ident.replace(__item).is_some() {
1262                                            panic!("singleton() received more than one item");
1263                                        }
1264                                    });
1265                                }
1266                            }
1267                            HandoffKind::Optional => {
1268                                // Optional slot: store at most one item, panic on duplicate.
1269                                quote_spanned! {port_ident.span()=>
1270                                    let #port_ident = #root::dfir_pipes::push::for_each(|__item| {
1271                                        if #buf_ident.replace(__item).is_some() {
1272                                            panic!("optional() received more than one item");
1273                                        }
1274                                    });
1275                                }
1276                            }
1277                            HandoffKind::Vec => {
1278                                quote_spanned! {port_ident.span()=>
1279                                    // TODO(mingwei): use `#root::dfir_pipes::push::vec_push`?
1280                                    let #port_ident = #root::dfir_pipes::push::for_each(|item| { #buf_ident.push(item); });
1281                                }
1282                            }
1283                        }
1284                    })
1285                    .collect();
1286
1287                // All nodes in a subgraph should be in the same loop.
1288                let loop_id = self.node_loop(subgraph_nodes[0]);
1289
1290                let mut subgraph_op_iter_code = Vec::new();
1291                let mut subgraph_op_iter_after_code = Vec::new();
1292                {
1293                    let pull_to_push_idx = self.find_pull_to_push_idx(subgraph_nodes);
1294
1295                    let (pull_half, push_half) = subgraph_nodes.split_at(pull_to_push_idx);
1296                    let nodes_iter = pull_half.iter().chain(push_half.iter().rev());
1297
1298                    for (idx, &node_id) in nodes_iter.enumerate() {
1299                        let node = &self.nodes[node_id];
1300                        assert!(
1301                            matches!(node, GraphNode::Operator(_)),
1302                            "Handoffs are not part of subgraphs."
1303                        );
1304                        let op_inst = &self.operator_instances[node_id];
1305
1306                        let op_span = node.span();
1307                        let op_name = op_inst.op_constraints.name;
1308                        // Use op's span for root. #root is expected to be correct, any errors should span back to the op gen.
1309                        let root = change_spans(root.clone(), op_span);
1310                        let op_constraints = OPERATORS
1311                            .iter()
1312                            .find(|op| op_name == op.name)
1313                            .unwrap_or_else(|| panic!("Failed to find op: {}", op_name));
1314
1315                        let ident = self.node_as_ident(node_id, false);
1316
1317                        {
1318                            // TODO clean this up.
1319                            // Collect input arguments (predecessors).
1320                            let mut input_edges = self
1321                                .graph
1322                                .predecessor_edges(node_id)
1323                                .map(|edge_id| (self.edge_ports(edge_id).1, edge_id))
1324                                .collect::<Vec<_>>();
1325                            // Ensure sorted by port index.
1326                            input_edges.sort();
1327
1328                            let inputs = input_edges
1329                                .iter()
1330                                .map(|&(_port, edge_id)| {
1331                                    let (pred, _) = self.edge(edge_id);
1332                                    self.node_as_ident(pred, true)
1333                                })
1334                                .collect::<Vec<_>>();
1335
1336                            // Collect output arguments (successors).
1337                            let mut output_edges = self
1338                                .graph
1339                                .successor_edges(node_id)
1340                                .map(|edge_id| (&self.ports[edge_id].0, edge_id))
1341                                .collect::<Vec<_>>();
1342                            // Ensure sorted by port index.
1343                            output_edges.sort();
1344
1345                            let outputs = output_edges
1346                                .iter()
1347                                .map(|&(_port, edge_id)| {
1348                                    let (_, succ) = self.edge(edge_id);
1349                                    self.node_as_ident(succ, false)
1350                                })
1351                                .collect::<Vec<_>>();
1352
1353                            let is_pull = idx < pull_to_push_idx;
1354
1355                            // There's a bit of dark magic hidden in `Span`s... you'd think it's just a `file:line:column`,
1356                            // but it has one extra bit of info for _name resolution_, used for `Ident`s. `Span::call_site()`
1357                            // has the (unhygienic) resolution we want, an ident is just solely determined by its string name,
1358                            // which is what you'd expect out of unhygienic proc macros like this. Meanwhile, declarative macros
1359                            // use `Span::mixed_site()` which is weird and I don't understand it. It turns out that if you call
1360                            // the dfir syntax proc macro from _within_ a declarative macro then `op_span` will have the
1361                            // bad `Span::mixed_site()` name resolution and cause "Cannot find value `df/context`" errors. So
1362                            // we call `.resolved_at()` to fix resolution back to `Span::call_site()`. -Mingwei
1363                            let df_local = &Ident::new(GRAPH, op_span.resolved_at(df.span()));
1364                            let context = &Ident::new(CONTEXT, op_span.resolved_at(context.span()));
1365
1366                            let singletons_resolved =
1367                                self.helper_resolve_singletons(node_id, op_span);
1368
1369                            let arguments = &process_singletons::postprocess_singletons(
1370                                op_inst.arguments_raw.clone(),
1371                                singletons_resolved,
1372                            );
1373
1374                            let source_tag = 'a: {
1375                                if let Some(tag) = self.operator_tag.get(node_id).cloned() {
1376                                    break 'a tag;
1377                                }
1378
1379                                if proc_macro::is_available() {
1380                                    let op_span = op_span.unwrap();
1381                                    break 'a format!(
1382                                        "loc_{}_{}_{}_{}_{}",
1383                                        crate::pretty_span::make_source_path_relative(
1384                                            &op_span.file()
1385                                        )
1386                                        .display()
1387                                        .to_string()
1388                                        .replace(|x: char| !x.is_ascii_alphanumeric(), "_"),
1389                                        op_span.start().line(),
1390                                        op_span.start().column(),
1391                                        op_span.end().line(),
1392                                        op_span.end().column(),
1393                                    );
1394                                }
1395
1396                                format!(
1397                                    "loc_nopath_{}_{}_{}_{}",
1398                                    op_span.start().line,
1399                                    op_span.start().column,
1400                                    op_span.end().line,
1401                                    op_span.end().column
1402                                )
1403                            };
1404
1405                            let work_fn = format_ident!(
1406                                "{}__{}__{}",
1407                                ident,
1408                                op_name,
1409                                source_tag,
1410                                span = op_span
1411                            );
1412                            let work_fn_async = format_ident!("{}__async", work_fn, span = op_span);
1413
1414                            let context_args = WriteContextArgs {
1415                                root: &root,
1416                                df_ident: df_local,
1417                                context,
1418                                subgraph_id,
1419                                node_id,
1420                                loop_id,
1421                                op_span,
1422                                op_tag: self.operator_tag.get(node_id).cloned(),
1423                                work_fn: &work_fn,
1424                                work_fn_async: &work_fn_async,
1425                                ident: &ident,
1426                                is_pull,
1427                                inputs: &inputs,
1428                                outputs: &outputs,
1429                                op_name,
1430                                op_inst,
1431                                arguments,
1432                            };
1433
1434                            let write_result =
1435                                (op_constraints.write_fn)(&context_args, diagnostics);
1436                            let OperatorWriteOutput {
1437                                write_prologue,
1438                                write_iterator,
1439                                write_iterator_after,
1440                                write_tick_end,
1441                            } = write_result.unwrap_or_else(|()| {
1442                                assert!(
1443                                    diagnostics.has_error(),
1444                                    "Operator `{}` returned `Err` but emitted no diagnostics, this is a bug.",
1445                                    op_name,
1446                                );
1447                                OperatorWriteOutput {
1448                                    write_iterator: null_write_iterator_fn(&context_args),
1449                                    ..Default::default()
1450                                }
1451                            });
1452
1453                            op_prologue_code.push(syn::parse_quote! {
1454                                #[allow(non_snake_case)]
1455                                #[inline(always)]
1456                                fn #work_fn<T>(thunk: impl ::std::ops::FnOnce() -> T) -> T {
1457                                    thunk()
1458                                }
1459
1460                                #[allow(non_snake_case)]
1461                                #[inline(always)]
1462                                async fn #work_fn_async<T>(
1463                                    thunk: impl ::std::future::Future<Output = T>,
1464                                ) -> T {
1465                                    thunk.await
1466                                }
1467                            });
1468                            op_prologue_code.push(write_prologue);
1469                            op_tick_end_code.push(write_tick_end);
1470                            subgraph_op_iter_code.push(write_iterator);
1471
1472                            if include_type_guards {
1473                                let type_guard = if is_pull {
1474                                    quote_spanned! {op_span=>
1475                                        let #ident = {
1476                                            #[allow(non_snake_case)]
1477                                            #[inline(always)]
1478                                            pub fn #work_fn<Item, Input>(input: Input)
1479                                                -> impl #root::dfir_pipes::pull::Pull<Item = Item, Meta = (), CanPend = Input::CanPend, CanEnd = Input::CanEnd>
1480                                            where
1481                                                Input: #root::dfir_pipes::pull::Pull<Item = Item, Meta = ()>,
1482                                            {
1483                                                #root::pin_project_lite::pin_project! {
1484                                                    #[repr(transparent)]
1485                                                    struct Pull<Item, Input: #root::dfir_pipes::pull::Pull<Item = Item>> {
1486                                                        #[pin]
1487                                                        inner: Input
1488                                                    }
1489                                                }
1490
1491                                                impl<Item, Input> #root::dfir_pipes::pull::Pull for Pull<Item, Input>
1492                                                where
1493                                                    Input: #root::dfir_pipes::pull::Pull<Item = Item>,
1494                                                {
1495                                                    type Ctx<'ctx> = Input::Ctx<'ctx>;
1496
1497                                                    type Item = Item;
1498                                                    type Meta = Input::Meta;
1499                                                    type CanPend = Input::CanPend;
1500                                                    type CanEnd = Input::CanEnd;
1501
1502                                                    #[inline(always)]
1503                                                    fn pull(
1504                                                        self: ::std::pin::Pin<&mut Self>,
1505                                                        ctx: &mut Self::Ctx<'_>,
1506                                                    ) -> #root::dfir_pipes::pull::PullStep<Self::Item, Self::Meta, Self::CanPend, Self::CanEnd> {
1507                                                        #root::dfir_pipes::pull::Pull::pull(self.project().inner, ctx)
1508                                                    }
1509
1510                                                    #[inline(always)]
1511                                                    fn size_hint(&self) -> (usize, Option<usize>) {
1512                                                        #root::dfir_pipes::pull::Pull::size_hint(&self.inner)
1513                                                    }
1514                                                }
1515
1516                                                Pull {
1517                                                    inner: input
1518                                                }
1519                                            }
1520                                            #work_fn::<_, _>( #ident )
1521                                        };
1522                                    }
1523                                } else {
1524                                    quote_spanned! {op_span=>
1525                                        let #ident = {
1526                                            #[allow(non_snake_case)]
1527                                            #[inline(always)]
1528                                            pub fn #work_fn<Item, Psh>(psh: Psh) -> impl #root::dfir_pipes::push::Push<Item, (), CanPend = Psh::CanPend>
1529                                            where
1530                                                Psh: #root::dfir_pipes::push::Push<Item, ()>
1531                                            {
1532                                                #root::pin_project_lite::pin_project! {
1533                                                    #[repr(transparent)]
1534                                                    struct PushGuard<Psh> {
1535                                                        #[pin]
1536                                                        inner: Psh,
1537                                                    }
1538                                                }
1539
1540                                                impl<Item, Psh> #root::dfir_pipes::push::Push<Item, ()> for PushGuard<Psh>
1541                                                where
1542                                                    Psh: #root::dfir_pipes::push::Push<Item, ()>,
1543                                                {
1544                                                    type Ctx<'ctx> = Psh::Ctx<'ctx>;
1545
1546                                                    type CanPend = Psh::CanPend;
1547
1548                                                    #[inline(always)]
1549                                                    fn poll_ready(
1550                                                        self: ::std::pin::Pin<&mut Self>,
1551                                                        ctx: &mut Self::Ctx<'_>,
1552                                                    ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
1553                                                        #root::dfir_pipes::push::Push::poll_ready(self.project().inner, ctx)
1554                                                    }
1555
1556                                                    #[inline(always)]
1557                                                    fn start_send(
1558                                                        self: ::std::pin::Pin<&mut Self>,
1559                                                        item: Item,
1560                                                        meta: (),
1561                                                    ) {
1562                                                        #root::dfir_pipes::push::Push::start_send(self.project().inner, item, meta)
1563                                                    }
1564
1565                                                    #[inline(always)]
1566                                                    fn poll_finalize(
1567                                                        self: ::std::pin::Pin<&mut Self>,
1568                                                        ctx: &mut Self::Ctx<'_>,
1569                                                    ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
1570                                                        #root::dfir_pipes::push::Push::poll_finalize(self.project().inner, ctx)
1571                                                    }
1572
1573                                                    #[inline(always)]
1574                                                    fn size_hint(
1575                                                        self: ::std::pin::Pin<&mut Self>,
1576                                                        hint: (usize, Option<usize>),
1577                                                    ) {
1578                                                        #root::dfir_pipes::push::Push::size_hint(self.project().inner, hint)
1579                                                    }
1580                                                }
1581
1582                                                PushGuard {
1583                                                    inner: psh
1584                                                }
1585                                            }
1586                                            #work_fn( #ident )
1587                                        };
1588                                    }
1589                                };
1590                                subgraph_op_iter_code.push(type_guard);
1591                            }
1592                            subgraph_op_iter_after_code.push(write_iterator_after);
1593                        }
1594                    }
1595
1596                    {
1597                        // Determine pull and push halves of the `Pivot`.
1598                        let pull_ident = if 0 < pull_to_push_idx {
1599                            self.node_as_ident(subgraph_nodes[pull_to_push_idx - 1], false)
1600                        } else {
1601                            // Entire subgraph is push (with a single recv/pull handoff input).
1602                            recv_port_idents[0].clone()
1603                        };
1604
1605                        #[rustfmt::skip]
1606                        let push_ident = if let Some(&node_id) =
1607                            subgraph_nodes.get(pull_to_push_idx)
1608                        {
1609                            self.node_as_ident(node_id, false)
1610                        } else if 1 == send_port_idents.len() {
1611                            // Entire subgraph is pull (with a single send/push handoff output).
1612                            send_port_idents[0].clone()
1613                        } else {
1614                            diagnostics.push(Diagnostic::spanned(
1615                                pull_ident.span(),
1616                                Level::Error,
1617                                "Degenerate subgraph detected, is there a disconnected `null()` or other degenerate pipeline somewhere?",
1618                            ));
1619                            continue;
1620                        };
1621
1622                        // Pivot span is combination of pull and push spans (or if not possible, just take the push).
1623                        let pivot_span = pull_ident
1624                            .span()
1625                            .join(push_ident.span())
1626                            .unwrap_or_else(|| push_ident.span());
1627                        let pivot_fn_ident = Ident::new(
1628                            &format!("pivot_run_sg_{:?}", subgraph_id.data()),
1629                            pivot_span,
1630                        );
1631                        let root = change_spans(root.clone(), pivot_span);
1632                        subgraph_op_iter_code.push(quote_spanned! {pivot_span=>
1633                            #[inline(always)]
1634                            fn #pivot_fn_ident<Pul, Psh, Item>(pull: Pul, push: Psh)
1635                                -> impl ::std::future::Future<Output = ()>
1636                            where
1637                                Pul: #root::dfir_pipes::pull::Pull<Item = Item>,
1638                                Psh: #root::dfir_pipes::push::Push<Item, Pul::Meta>,
1639                            {
1640                                #root::dfir_pipes::pull::Pull::send_push(pull, push)
1641                            }
1642                            (#pivot_fn_ident)(#pull_ident, #push_ident).await;
1643                        });
1644                    }
1645                };
1646
1647                // Each subgraph block is an async block so it can be individually instrumented.
1648                // Note: this ident is for the subgraph future, not a runtime SubgraphId binding
1649                // (unlike the scheduled path's `sg_ident`).
1650                let sg_fut_ident = subgraph_id.as_ident(Span::call_site());
1651
1652                // Generate send-side curr_items_count updates (after subgraph runs).
1653                let send_metrics_code = send_hoffs
1654                    .iter()
1655                    .zip(send_buf_idents.iter())
1656                    .zip(send_kinds.iter())
1657                    .map(|((&hoff_id, buf_ident), &kind)| {
1658                        let hoff_ffi = hoff_id.data().as_ffi();
1659                        let len_expr = match kind {
1660                            HandoffKind::Singleton | HandoffKind::Optional => {
1661                                quote! { if #buf_ident.is_some() { 1 } else { 0 } }
1662                            }
1663                            HandoffKind::Vec => {
1664                                quote! { #buf_ident.len() }
1665                            }
1666                        };
1667                        quote! {
1668                            __dfir_metrics.handoffs[
1669                                #root::slotmap::KeyData::from_ffi(#hoff_ffi).into()
1670                            ].curr_items_count.set(#len_expr);
1671                        }
1672                    })
1673                    .collect::<Vec<_>>();
1674
1675                // Create the handoffs we are about to push to (send).
1676                let send_hoff_make_code = send_buf_idents.iter()
1677                    .zip(send_kinds.iter())
1678                    .zip(send_hoffs.iter())
1679                    .map(|((buf_ident, &kind), &hoff_id)| {
1680                        let span = buf_ident.span();
1681                        if back_edge_hoffs_and_lazyness.contains_key(hoff_id) {
1682                            // Defer_tick send buffers are declared outside the tick closure
1683                            // as std::vec::Vec for O(1) swap. Just clear here.
1684                            quote_spanned! {span=>
1685                                #buf_ident.clear();
1686                            }
1687                        } else {
1688                            match kind {
1689                                HandoffKind::Vec => quote_spanned! {span=>
1690                                    let mut #buf_ident = #root::bumpalo::collections::Vec::new_in(&#bump_ident);
1691                                },
1692                                HandoffKind::Singleton | HandoffKind::Optional => quote_spanned! {span=>
1693                                    let mut #buf_ident = ::std::option::Option::None;
1694                                },
1695                            }
1696                        }
1697                    });
1698                // Drop the handoffs we just drained (recv).
1699                // TODO(mingwei): we could use `.into_iter()` instead of `.drain(..)` to consume the handoffs directly.
1700                // This only works for handoffs within the tick, though, not `defer_tick` handoffs.
1701                let recv_hoff_drop_code = recv_buf_idents
1702                    .iter()
1703                    .zip(recv_hoffs.iter())
1704                    .filter(|&(_, &hoff_id)| !back_edge_hoffs_and_lazyness.contains_key(hoff_id))
1705                    .map(|(buf_ident, _)| {
1706                        let span = buf_ident.span();
1707                        quote_spanned! {span=>
1708                            let _ = #buf_ident;
1709                        }
1710                    });
1711
1712                subgraph_blocks.push(quote! {
1713                    // Create the handoffs we are about to push to (send).
1714                    #( #send_hoff_make_code )*
1715
1716                    let #sg_fut_ident = async {
1717                        let #context = &#df;
1718                        #( #recv_port_code )*
1719                        #( #send_port_code )*
1720                        #( #subgraph_op_iter_code )*
1721                        #( #subgraph_op_iter_after_code )*
1722                    };
1723                    {
1724                        // Instrument w/ the subgraph metrics.
1725                        let sg_metrics = &__dfir_metrics.subgraphs[
1726                            #root::slotmap::KeyData::from_ffi(#sg_metrics_ffi).into()
1727                        ];
1728                        #root::scheduled::metrics::InstrumentSubgraph::new(
1729                            #sg_fut_ident, sg_metrics
1730                        ).await;
1731                        sg_metrics.total_run_count.update(|x| x + 1);
1732
1733                        // Update send (output) handoff metrics.
1734                        #( #send_metrics_code )*
1735
1736                        // Drop the handoffs we just drained (recv).
1737                        #( #recv_hoff_drop_code )*
1738                    }
1739                });
1740
1741                // Collect per-subgraph prologues into the main prologue lists.
1742                // (They are already pushed above in the operator loop.)
1743            }
1744        }
1745
1746        if diagnostics.has_error() {
1747            return Err(std::mem::take(diagnostics));
1748        }
1749        let _ = diagnostics; // Ensure no more diagnostics may be added after checking for errors.
1750
1751        let (meta_graph_arg, diagnostics_arg) = if include_meta {
1752            let meta_graph_json = serde_json::to_string(&self).unwrap();
1753            let meta_graph_json = Literal::string(&meta_graph_json);
1754
1755            let serde_diagnostics: Vec<_> = diagnostics.iter().map(Diagnostic::to_serde).collect();
1756            let diagnostics_json = serde_json::to_string(&*serde_diagnostics).unwrap();
1757            let diagnostics_json = Literal::string(&diagnostics_json);
1758
1759            (
1760                quote! { Some(#meta_graph_json) },
1761                quote! { Some(#diagnostics_json) },
1762            )
1763        } else {
1764            (quote! { None }, quote! { None })
1765        };
1766
1767        // Generate metrics initialization: one entry per handoff and per subgraph.
1768        let metrics_init_code = {
1769            let handoff_inits = handoff_nodes.iter().map(|&(node_id, _, _)| {
1770                let ffi = node_id.data().as_ffi();
1771                quote! {
1772                    dfir_metrics.handoffs.insert(
1773                        #root::slotmap::KeyData::from_ffi(#ffi).into(),
1774                        ::std::default::Default::default(),
1775                    );
1776                }
1777            });
1778            let subgraph_inits = all_subgraphs.iter().map(|&(sg_id, _)| {
1779                let ffi = sg_id.data().as_ffi();
1780                quote! {
1781                    dfir_metrics.subgraphs.insert(
1782                        #root::slotmap::KeyData::from_ffi(#ffi).into(),
1783                        ::std::default::Default::default(),
1784                    );
1785                }
1786            });
1787            handoff_inits.chain(subgraph_inits).collect::<Vec<_>>()
1788        };
1789
1790        // For creating back-buffer handoff vecs.
1791        let back_buffer_idents = back_buffer_idents_laziness
1792            .iter()
1793            .map(|(back_ident, _, _)| back_ident);
1794        // For creating the send-side buffer for defer_tick handoffs (also outside the closure).
1795        let defer_tick_buf_idents = back_buffer_idents_laziness
1796            .iter()
1797            .map(|(_, buf_ident, _)| buf_ident);
1798        // For checking if we should start the next tick (`schedule_subgraph`):
1799        // check the regular (send) buffer for non-lazy defer_tick handoffs, since
1800        // that's where the producer writes during this tick.
1801        let non_lazy_buf_idents = back_buffer_idents_laziness
1802            .iter()
1803            .filter_map(|(_, buf_ident, is_lazy)| (!is_lazy).then_some(buf_ident));
1804
1805        // Prologues and buffer declarations persist across ticks (outside the closure).
1806        // Subgraph blocks run each tick (inside the closure).
1807        Ok(quote! {
1808            {
1809                #prefix
1810
1811                use #root::{var_expr, var_args};
1812
1813                let __dfir_wake_state = ::std::sync::Arc::new(
1814                    #root::scheduled::context::WakeState::default()
1815                );
1816
1817                let __dfir_metrics = {
1818                    let mut dfir_metrics = #root::scheduled::metrics::DfirMetrics::default();
1819                    #( #metrics_init_code )*
1820                    ::std::rc::Rc::new(dfir_metrics)
1821                };
1822
1823                #[allow(unused_mut)]
1824                let mut #df = #root::scheduled::context::Context::new(
1825                    ::std::clone::Clone::clone(&__dfir_wake_state),
1826                    __dfir_metrics,
1827                );
1828
1829                #( #op_prologue_code )*
1830
1831                // For tick-boundary handoffs (`defer_tick` / `defer_tick_lazy`), declare both the
1832                // send buffer and the "back" buffer as std::vec::Vec outside the tick closure.
1833                // This enables O(1) mem::swap at end of tick for double-buffering.
1834                #( let mut #back_buffer_idents = ::std::vec::Vec::new(); )*
1835                #( let mut #defer_tick_buf_idents = ::std::vec::Vec::new(); )*
1836
1837                // Bump allocator for handoffs (except for back-edge handoffs, above).
1838                let mut #bump_ident = #root::bumpalo::Bump::new();
1839
1840                // Pre-set to true so the first tick always returns true
1841                // (matching Dfir pre-scheduling behavior). Subsequent ticks
1842                // start false (from take()) and are set true by recv port code
1843                // if any handoff buffer has data.
1844                let mut __dfir_work_done = true;
1845                #[allow(unused_qualifications, unused_mut, unused_variables, clippy::await_holding_refcell_ref, clippy::deref_addrof)]
1846                let __dfir_inline_tick = async move |#df: &mut #root::scheduled::context::Context| {
1847                    // Reset arena between ticks (start-of-tick)
1848                    #bump_ident.reset();
1849
1850                    {
1851                        let __dfir_metrics = #df.metrics();
1852
1853                        #( #subgraph_blocks )*
1854
1855                        // For non-lazy defer_tick: if any deferred buffer has data,
1856                        // signal that another tick should run.
1857                        if false #( || !#non_lazy_buf_idents.is_empty() )* {
1858                            #df.schedule_subgraph(true);
1859                        }
1860
1861                        // Double-buffer swap for defer_tick handoffs: move last tick's producer output (regular buffer)
1862                        // into the back buffer for the consumer to drain.
1863                        #( #back_edge_swap_code )*
1864                    }
1865
1866                    // End-of-tick per-operator state handling (i.e. 'tick persistence).
1867                    #( #op_tick_end_code )*
1868
1869                    #df.__end_tick();
1870
1871                    ::std::mem::take(&mut __dfir_work_done)
1872                };
1873                #root::scheduled::context::Dfir::new(
1874                    __dfir_inline_tick,
1875                    #df,
1876                    #meta_graph_arg,
1877                    #diagnostics_arg,
1878                )
1879            }
1880        })
1881    }
1882
1883    /// Color mode (pull vs. push, handoff vs. comp) for nodes. Some nodes can be push *OR* pull;
1884    /// those nodes will not be set in the returned map.
1885    pub fn node_color_map(&self) -> SparseSecondaryMap<GraphNodeId, Color> {
1886        let mut node_color_map: SparseSecondaryMap<GraphNodeId, Color> = self
1887            .node_ids()
1888            .filter_map(|node_id| {
1889                let op_color = self.node_color(node_id)?;
1890                Some((node_id, op_color))
1891            })
1892            .collect();
1893
1894        // Fill in rest via subgraphs.
1895        for sg_nodes in self.subgraph_nodes.values() {
1896            let pull_to_push_idx = self.find_pull_to_push_idx(sg_nodes);
1897
1898            for (idx, node_id) in sg_nodes.iter().copied().enumerate() {
1899                let is_pull = idx < pull_to_push_idx;
1900                node_color_map.insert(node_id, if is_pull { Color::Pull } else { Color::Push });
1901            }
1902        }
1903
1904        node_color_map
1905    }
1906
1907    /// Writes this graph as mermaid into a string.
1908    pub fn to_mermaid(&self, write_config: &WriteConfig) -> String {
1909        let mut output = String::new();
1910        self.write_mermaid(&mut output, write_config).unwrap();
1911        output
1912    }
1913
1914    /// Writes this graph as mermaid into the given `Write`.
1915    pub fn write_mermaid(
1916        &self,
1917        output: impl std::fmt::Write,
1918        write_config: &WriteConfig,
1919    ) -> std::fmt::Result {
1920        let mut graph_write = Mermaid::new(output);
1921        self.write_graph(&mut graph_write, write_config)
1922    }
1923
1924    /// Writes this graph as DOT (graphviz) into a string.
1925    pub fn to_dot(&self, write_config: &WriteConfig) -> String {
1926        let mut output = String::new();
1927        let mut graph_write = Dot::new(&mut output);
1928        self.write_graph(&mut graph_write, write_config).unwrap();
1929        output
1930    }
1931
1932    /// Writes this graph as DOT (graphviz) into the given `Write`.
1933    pub fn write_dot(
1934        &self,
1935        output: impl std::fmt::Write,
1936        write_config: &WriteConfig,
1937    ) -> std::fmt::Result {
1938        let mut graph_write = Dot::new(output);
1939        self.write_graph(&mut graph_write, write_config)
1940    }
1941
    /// Write out this graph using the given `GraphWrite`, e.g. `Mermaid` or `Dot`.
1943    pub(crate) fn write_graph<W>(
1944        &self,
1945        mut graph_write: W,
1946        write_config: &WriteConfig,
1947    ) -> Result<(), W::Err>
1948    where
1949        W: GraphWrite,
1950    {
1951        fn helper_edge_label(
1952            src_port: &PortIndexValue,
1953            dst_port: &PortIndexValue,
1954        ) -> Option<String> {
1955            let src_label = match src_port {
1956                PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
1957                PortIndexValue::Int(index) => Some(index.value.to_string()),
1958                _ => None,
1959            };
1960            let dst_label = match dst_port {
1961                PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
1962                PortIndexValue::Int(index) => Some(index.value.to_string()),
1963                _ => None,
1964            };
1965            let label = match (src_label, dst_label) {
1966                (Some(l1), Some(l2)) => Some(format!("{}\n{}", l1, l2)),
1967                (Some(l1), None) => Some(l1),
1968                (None, Some(l2)) => Some(l2),
1969                (None, None) => None,
1970            };
1971            label
1972        }
1973
1974        // Make node color map one time.
1975        let node_color_map = self.node_color_map();
1976
1977        // Write prologue.
1978        graph_write.write_prologue()?;
1979
1980        // Define nodes.
1981        let mut skipped_handoffs = BTreeSet::new();
1982        let mut subgraph_handoffs = <BTreeMap<GraphSubgraphId, Vec<GraphNodeId>>>::new();
1983        for (node_id, node) in self.nodes() {
1984            if matches!(node, GraphNode::Handoff { .. }) {
1985                if write_config.no_handoffs {
1986                    skipped_handoffs.insert(node_id);
1987                    continue;
1988                } else {
1989                    let pred_node = self.node_predecessor_nodes(node_id).next().unwrap();
1990                    let pred_sg = self.node_subgraph(pred_node);
1991                    let succ_node = self.node_successor_nodes(node_id).next();
1992                    let succ_sg = succ_node.and_then(|n| self.node_subgraph(n));
1993                    if let Some((pred_sg, succ_sg)) = pred_sg.zip(succ_sg)
1994                        && pred_sg == succ_sg
1995                    {
1996                        subgraph_handoffs.entry(pred_sg).or_default().push(node_id);
1997                    }
1998                }
1999            }
2000            graph_write.write_node_definition(
2001                node_id,
2002                &if write_config.op_short_text {
2003                    node.to_name_string()
2004                } else if write_config.op_text_no_imports {
2005                    // Remove any lines that start with "use" (imports)
2006                    let full_text = node.to_pretty_string();
2007                    let mut output = String::new();
2008                    for sentence in full_text.split('\n') {
2009                        if sentence.trim().starts_with("use") {
2010                            continue;
2011                        }
2012                        output.push('\n');
2013                        output.push_str(sentence);
2014                    }
2015                    output.into()
2016                } else {
2017                    node.to_pretty_string()
2018                },
2019                if write_config.no_pull_push {
2020                    None
2021                } else {
2022                    node_color_map.get(node_id).copied()
2023                },
2024            )?;
2025        }
2026
2027        // Write edges.
2028        for (edge_id, (src_id, mut dst_id)) in self.edges() {
2029            // Handling for if `write_config.no_handoffs` true.
2030            if skipped_handoffs.contains(&src_id) {
2031                continue;
2032            }
2033
2034            let (src_port, mut dst_port) = self.edge_ports(edge_id);
2035            if skipped_handoffs.contains(&dst_id) {
2036                let mut handoff_succs = self.node_successors(dst_id);
2037                assert_eq!(1, handoff_succs.len());
2038                let (succ_edge, succ_node) = handoff_succs.next().unwrap();
2039                dst_id = succ_node;
2040                dst_port = self.edge_ports(succ_edge).1;
2041            }
2042
2043            let label = helper_edge_label(src_port, dst_port);
2044            let delay_type = self
2045                .node_op_inst(dst_id)
2046                .and_then(|op_inst| (op_inst.op_constraints.input_delaytype_fn)(dst_port));
2047            graph_write.write_edge(src_id, dst_id, delay_type, label.as_deref(), false)?;
2048        }
2049
2050        // Write reference edges.
2051        if !write_config.no_references {
2052            for dst_id in self.node_ids() {
2053                for src_ref_id in self
2054                    .node_singleton_references(dst_id)
2055                    .iter()
2056                    .filter_map(|r| r.node_id)
2057                {
2058                    let delay_type = Some(DelayType::Stratum);
2059                    let label = None;
2060                    graph_write.write_edge(src_ref_id, dst_id, delay_type, label, true)?;
2061                }
2062            }
2063        }
2064
2065        // The following code is a little bit tricky. Generally, the graph has the hierarchy:
2066        // `loop -> subgraph -> varname -> node`. However, each of these can be disabled via the `write_config`. To
2067        // handle both the enabled and disabled case, this code is structured as a series of nested loops. If the layer
2068        // is disabled, then the HashMap<Option<KEY>, Vec<VALUE>> will only have a single key (`None`) with a
2069        // corresponding `Vec` value containing everything. This way no special handling is needed for the next layer.
2070
2071        // Loop -> Subgraphs
2072        let loop_subgraphs = self.subgraph_ids().map(|sg_id| {
2073            let loop_id = if write_config.no_loops {
2074                None
2075            } else {
2076                self.subgraph_loop(sg_id)
2077            };
2078            (loop_id, sg_id)
2079        });
2080        let loop_subgraphs = into_group_map(loop_subgraphs);
2081        for (loop_id, subgraph_ids) in loop_subgraphs {
2082            if let Some(loop_id) = loop_id {
2083                graph_write.write_loop_start(loop_id)?;
2084            }
2085
2086            // Subgraph -> Varnames.
2087            let subgraph_varnames_nodes = subgraph_ids.into_iter().flat_map(|sg_id| {
2088                self.subgraph(sg_id).iter().copied().map(move |node_id| {
2089                    let opt_sg_id = if write_config.no_subgraphs {
2090                        None
2091                    } else {
2092                        Some(sg_id)
2093                    };
2094                    (opt_sg_id, (self.node_varname(node_id), node_id))
2095                })
2096            });
2097            let subgraph_varnames_nodes = into_group_map(subgraph_varnames_nodes);
2098            for (sg_id, varnames) in subgraph_varnames_nodes {
2099                if let Some(sg_id) = sg_id {
2100                    graph_write.write_subgraph_start(sg_id)?;
2101                }
2102
2103                // Varnames -> Nodes.
2104                let varname_nodes = varnames.into_iter().map(|(varname, node)| {
2105                    let varname = if write_config.no_varnames {
2106                        None
2107                    } else {
2108                        varname
2109                    };
2110                    (varname, node)
2111                });
2112                let varname_nodes = into_group_map(varname_nodes);
2113                for (varname, node_ids) in varname_nodes {
2114                    if let Some(varname) = varname {
2115                        graph_write.write_varname_start(&varname.0.to_string(), sg_id)?;
2116                    }
2117
2118                    // Write all nodes.
2119                    for node_id in node_ids {
2120                        graph_write.write_node(node_id)?;
2121                    }
2122
2123                    if varname.is_some() {
2124                        graph_write.write_varname_end()?;
2125                    }
2126                }
2127
2128                if sg_id.is_some() {
2129                    graph_write.write_subgraph_end()?;
2130                }
2131            }
2132
2133            if loop_id.is_some() {
2134                graph_write.write_loop_end()?;
2135            }
2136        }
2137
2138        // Write epilogue.
2139        graph_write.write_epilogue()?;
2140
2141        Ok(())
2142    }
2143
2144    /// Convert back into surface syntax.
2145    pub fn surface_syntax_string(&self) -> String {
2146        let mut string = String::new();
2147        self.write_surface_syntax(&mut string).unwrap();
2148        string
2149    }
2150
2151    /// Convert back into surface syntax.
2152    pub fn write_surface_syntax(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
2153        for (key, node) in self.nodes.iter() {
2154            match node {
2155                GraphNode::Operator(op) => {
2156                    writeln!(write, "{:?} = {};", key.data(), op.to_token_stream())?;
2157                }
2158                GraphNode::Handoff {
2159                    kind: HandoffKind::Vec,
2160                    ..
2161                } => {
2162                    writeln!(write, "{:?} = handoff();", key.data())?;
2163                }
2164                GraphNode::Handoff {
2165                    kind: HandoffKind::Singleton,
2166                    ..
2167                } => {
2168                    writeln!(write, "{:?} = singleton();", key.data())?;
2169                }
2170                GraphNode::Handoff {
2171                    kind: HandoffKind::Optional,
2172                    ..
2173                } => {
2174                    writeln!(write, "{:?} = optional();", key.data())?;
2175                }
2176                GraphNode::ModuleBoundary { .. } => panic!(),
2177            }
2178        }
2179        writeln!(write)?;
2180        for (_e, (src_key, dst_key)) in self.graph.edges() {
2181            writeln!(write, "{:?} -> {:?};", src_key.data(), dst_key.data())?;
2182        }
2183        Ok(())
2184    }
2185
2186    /// Convert into a [mermaid](https://mermaid-js.github.io/) graph. Ignores subgraphs.
2187    pub fn mermaid_string_flat(&self) -> String {
2188        let mut string = String::new();
2189        self.write_mermaid_flat(&mut string).unwrap();
2190        string
2191    }
2192
2193    /// Convert into a [mermaid](https://mermaid-js.github.io/) graph. Ignores subgraphs.
2194    pub fn write_mermaid_flat(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
2195        writeln!(write, "flowchart TB")?;
2196        for (key, node) in self.nodes.iter() {
2197            match node {
2198                GraphNode::Operator(operator) => writeln!(
2199                    write,
2200                    "    %% {span}\n    {id:?}[\"{row_col} <tt>{code}</tt>\"]",
2201                    span = PrettySpan(node.span()),
2202                    id = key.data(),
2203                    row_col = PrettyRowCol(node.span()),
2204                    code = operator
2205                        .to_token_stream()
2206                        .to_string()
2207                        .replace('&', "&amp;")
2208                        .replace('<', "&lt;")
2209                        .replace('>', "&gt;")
2210                        .replace('"', "&quot;")
2211                        .replace('\n', "<br>"),
2212                ),
2213                GraphNode::Handoff {
2214                    kind: HandoffKind::Vec,
2215                    ..
2216                } => {
2217                    writeln!(write, r#"    {:?}{{"{}"}}"#, key.data(), HANDOFF_NODE_STR)
2218                }
2219                GraphNode::Handoff {
2220                    kind: HandoffKind::Singleton | HandoffKind::Optional,
2221                    ..
2222                } => {
2223                    writeln!(
2224                        write,
2225                        r#"    {:?}{{"{}"}}"#,
2226                        key.data(),
2227                        SINGLETON_SLOT_NODE_STR
2228                    )
2229                }
2230                GraphNode::ModuleBoundary { .. } => {
2231                    writeln!(
2232                        write,
2233                        r#"    {:?}{{"{}"}}"#,
2234                        key.data(),
2235                        MODULE_BOUNDARY_NODE_STR
2236                    )
2237                }
2238            }?;
2239        }
2240        writeln!(write)?;
2241        for (_e, (src_key, dst_key)) in self.graph.edges() {
2242            writeln!(write, "    {:?}-->{:?}", src_key.data(), dst_key.data())?;
2243        }
2244        Ok(())
2245    }
2246}
2247
2248/// Loops
2249impl DfirGraph {
2250    /// Iterator over all loop IDs.
2251    pub fn loop_ids(&self) -> slotmap::basic::Keys<'_, GraphLoopId, Vec<GraphNodeId>> {
2252        self.loop_nodes.keys()
2253    }
2254
2255    /// Iterator over all loops, ID and members: `(GraphLoopId, Vec<GraphNodeId>)`.
2256    pub fn loops(&self) -> slotmap::basic::Iter<'_, GraphLoopId, Vec<GraphNodeId>> {
2257        self.loop_nodes.iter()
2258    }
2259
2260    /// Create a new loop context, with the given parent loop (or `None`).
2261    pub fn insert_loop(&mut self, parent_loop: Option<GraphLoopId>) -> GraphLoopId {
2262        let loop_id = self.loop_nodes.insert(Vec::new());
2263        self.loop_children.insert(loop_id, Vec::new());
2264        if let Some(parent_loop) = parent_loop {
2265            self.loop_parent.insert(loop_id, parent_loop);
2266            self.loop_children
2267                .get_mut(parent_loop)
2268                .unwrap()
2269                .push(loop_id);
2270        } else {
2271            self.root_loops.push(loop_id);
2272        }
2273        loop_id
2274    }
2275
2276    /// Get a node's loop context (or `None` for root).
2277    pub fn node_loop(&self, node_id: GraphNodeId) -> Option<GraphLoopId> {
2278        self.node_loops.get(node_id).copied()
2279    }
2280
2281    /// Get a subgraph's loop context (or `None` for root).
2282    pub fn subgraph_loop(&self, subgraph_id: GraphSubgraphId) -> Option<GraphLoopId> {
2283        let &node_id = self.subgraph(subgraph_id).first().unwrap();
2284        let out = self.node_loop(node_id);
2285        debug_assert!(
2286            self.subgraph(subgraph_id)
2287                .iter()
2288                .all(|&node_id| self.node_loop(node_id) == out),
2289            "Subgraph nodes should all have the same loop context."
2290        );
2291        out
2292    }
2293
2294    /// Get a loop context's parent loop context (or `None` for root).
2295    pub fn loop_parent(&self, loop_id: GraphLoopId) -> Option<GraphLoopId> {
2296        self.loop_parent.get(loop_id).copied()
2297    }
2298
2299    /// Get a loop context's child loops.
2300    pub fn loop_children(&self, loop_id: GraphLoopId) -> &Vec<GraphLoopId> {
2301        self.loop_children.get(loop_id).unwrap()
2302    }
2303}
2304
/// Configuration for writing graphs.
// Every option is a plain `bool` and `Default` is derived, so `WriteConfig::default()` has all
// options off (everything is rendered, with full operator text).
// NOTE(review): with the `clap-derive` feature this struct derives `clap::Args` and each field is
// a `--long` flag. clap presumably reuses the `///` comments below as CLI help text and the field
// order as flag order, so treat both as user-facing — confirm against the clap derive docs.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "clap-derive", derive(clap::Args))]
pub struct WriteConfig {
    /// Subgraphs will not be rendered if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_subgraphs: bool,
    /// Variable names will not be rendered if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_varnames: bool,
    /// Will not render pull/push shapes if set.
    // `write_graph` passes no node color to the writer when this is set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_pull_push: bool,
    /// Will not render handoffs if set.
    // `write_graph` then re-routes each edge entering a handoff to that handoff's successor.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_handoffs: bool,
    /// Will not render singleton references if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_references: bool,
    /// Will not render loops if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_loops: bool,

    /// Op text will only be their name instead of the whole source.
    // Takes precedence over `op_text_no_imports` in `write_graph` when both are set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub op_short_text: bool,
    /// Op text will exclude any line that starts with "use".
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub op_text_no_imports: bool,
}
2335
/// Enum for choosing between mermaid and dot graph writing.
// NOTE(review): with the `clap-derive` feature this enum derives `clap::Parser` and
// `clap::ValueEnum`; clap presumably reuses the `///` comments here as CLI help text, so treat
// their wording as user-facing — confirm against the clap derive docs before editing.
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "clap-derive", derive(clap::Parser, clap::ValueEnum))]
pub enum WriteGraphType {
    /// Mermaid graphs.
    Mermaid,
    /// Dot (Graphviz) graphs.
    Dot,
}
2345
/// [`itertools::Itertools::into_group_map`], but for `BTreeMap`.
///
/// Collects `(key, value)` pairs into a map from each key to the `Vec` of its values, with values
/// kept in the order they were yielded.
fn into_group_map<K, V>(iter: impl IntoIterator<Item = (K, V)>) -> BTreeMap<K, Vec<V>>
where
    K: Ord,
{
    iter.into_iter()
        .fold(BTreeMap::new(), |mut groups, (key, value)| {
            groups.entry(key).or_insert_with(Vec::new).push(value);
            groups
        })
}