Skip to main content

dfir_lang/graph/
flat_to_partitioned.rs

//! Subgraph partitioning algorithm
2
3use std::collections::{BTreeMap, BTreeSet};
4
5use itertools::Itertools;
6use proc_macro2::Span;
7use slotmap::{SecondaryMap, SparseSecondaryMap};
8
9use super::meta_graph::DfirGraph;
10use super::ops::{DelayType, FloType};
11use super::{
12    Color, GraphEdgeId, GraphNode, GraphNodeId, GraphSubgraphId, HandoffKind, graph_algorithms,
13};
14use crate::diagnostic::{Diagnostic, Level};
15use crate::union_find::UnionFind;
16
17/// Helper struct for tracking barrier crossers, see [`find_barrier_crossers`].
18struct BarrierCrossers {
19    /// Edge barrier crossers, including what type.
20    pub edge_barrier_crossers: SecondaryMap<GraphEdgeId, DelayType>,
21    /// Singleton reference barrier crossers, considered to be [`DelayType::Stratum`].
22    pub singleton_barrier_crossers: Vec<(GraphNodeId, GraphNodeId)>,
23}
24impl BarrierCrossers {
25    /// Iterate pairs of nodes that are across a barrier. Excludes `DelayType::NextIteration` pairs.
26    fn iter_node_pairs<'a>(
27        &'a self,
28        partitioned_graph: &'a DfirGraph,
29    ) -> impl 'a + Iterator<Item = ((GraphNodeId, GraphNodeId), DelayType)> {
30        let edge_pairs_iter = self
31            .edge_barrier_crossers
32            .iter()
33            .map(|(edge_id, &delay_type)| {
34                let src_dst = partitioned_graph.edge(edge_id);
35                (src_dst, delay_type)
36            });
37        let singleton_pairs_iter = self
38            .singleton_barrier_crossers
39            .iter()
40            .map(|&src_dst| (src_dst, DelayType::Stratum));
41        edge_pairs_iter.chain(singleton_pairs_iter)
42    }
43
44    /// Insert/replace edge.
45    fn replace_edge(&mut self, old_edge_id: GraphEdgeId, new_edge_id: GraphEdgeId) {
46        if let Some(delay_type) = self.edge_barrier_crossers.remove(old_edge_id) {
47            self.edge_barrier_crossers.insert(new_edge_id, delay_type);
48        }
49    }
50}
51
52/// Find all the barrier crossers.
53fn find_barrier_crossers(partitioned_graph: &DfirGraph) -> BarrierCrossers {
54    let edge_barrier_crossers = partitioned_graph
55        .edges()
56        .filter(|&(_, (_src, dst))| {
57            // Ignore barriers within `loop {` blocks.
58            partitioned_graph.node_loop(dst).is_none()
59        })
60        .filter_map(|(edge_id, (_src, dst))| {
61            let (_src_port, dst_port) = partitioned_graph.edge_ports(edge_id);
62            let op_constraints = partitioned_graph.node_op_inst(dst)?.op_constraints;
63            let input_barrier = (op_constraints.input_delaytype_fn)(dst_port)?;
64            Some((edge_id, input_barrier))
65        })
66        .collect();
67
68    // Basic singleton barriers: producer → consumer.
69    let mut singleton_barrier_crossers: Vec<(GraphNodeId, GraphNodeId)> = partitioned_graph
70        .node_ids()
71        .flat_map(|dst| {
72            partitioned_graph
73                .node_singleton_references(dst)
74                .iter()
75                .filter_map(|r| r.node_id)
76                .map(move |src_ref| (src_ref, dst))
77        })
78        .collect();
79
80    // Access group ordering barriers: for the same singleton target, operators in
81    // lower access groups must run before operators in higher access groups.
82    // Also: ungrouped mutable refs must run after ungrouped shared refs.
83    let refs_by_target = partitioned_graph.node_singleton_reference_groups();
84    // For each singleton target...
85    for (_singleton, groups) in refs_by_target {
86        // For sequential access groups...
87        for (group_a, group_b) in groups.values().tuple_windows() {
88            // Add ordering barriers so every node in the lower group must run before every node in the higher group.
89            for &(node_a, _, _) in group_a {
90                for &(node_b, _, _) in group_b {
91                    // TODO(mingwei): handle with diagnostics.
92                    assert_ne!(
93                        node_a, node_b,
94                        "encounted conflicted or cyclical singleton references\n{:?}\n{:?}",
95                        group_a, group_b,
96                    );
97                    singleton_barrier_crossers.push((node_a, node_b));
98                }
99            }
100        }
101    }
102
103    BarrierCrossers {
104        edge_barrier_crossers,
105        singleton_barrier_crossers,
106    }
107}
108
109fn find_subgraph_unionfind(
110    partitioned_graph: &DfirGraph,
111    barrier_crossers: &BarrierCrossers,
112) -> (UnionFind<GraphNodeId>, BTreeSet<GraphEdgeId>) {
113    // Modality (color) of nodes, push or pull.
114    // TODO(mingwei)? This does NOT consider `DelayType` barriers (which generally imply `Pull`),
115    // which makes it inconsistant with the final output in `as_code()`. But this doesn't create
116    // any bugs since we exclude `DelayType` edges from joining subgraphs anyway.
117    let mut node_color = partitioned_graph
118        .node_ids()
119        .filter_map(|node_id| {
120            let op_color = partitioned_graph.node_color(node_id)?;
121            Some((node_id, op_color))
122        })
123        .collect::<SparseSecondaryMap<_, _>>();
124
125    let mut subgraph_unionfind: UnionFind<GraphNodeId> =
126        UnionFind::with_capacity(partitioned_graph.nodes().len());
127
128    // Will contain all edges which are handoffs. Starts out with all edges and
129    // we remove from this set as we combine nodes into subgraphs.
130    let mut handoff_edges: BTreeSet<GraphEdgeId> = partitioned_graph.edge_ids().collect();
131    // Would sort edges here for priority (for now, no sort/priority).
132
133    // Each edge gets looked at in order. However we may not know if a linear
134    // chain of operators is PUSH vs PULL until we look at the ends. A fancier
135    // algorithm would know to handle linear chains from the outside inward.
136    // But instead we just run through the edges in a loop until no more
137    // progress is made. Could have some sort of O(N^2) pathological worst
138    // case.
139    let mut progress = true;
140    while progress {
141        progress = false;
142        // TODO(mingwei): Could this iterate `handoff_edges` instead? (Modulo ownership). Then no case (1) below.
143        for (edge_id, (src, dst)) in partitioned_graph.edges().collect::<Vec<_>>() {
144            // Ignore (1) already added edges as well as (2) new self-cycles. (Unless reference edge).
145            if subgraph_unionfind.same_set(src, dst) {
146                // Note that the _edge_ `edge_id` might not be in the subgraph even when both `src` and `dst` are. This prevents case 2.
147                // Handoffs will be inserted later for this self-loop.
148                continue;
149            }
150
151            // Do not connect stratum crossers (next edges).
152            if barrier_crossers
153                .iter_node_pairs(partitioned_graph)
154                .any(|((x_src, x_dst), _)| {
155                    (subgraph_unionfind.same_set(x_src, src)
156                        && subgraph_unionfind.same_set(x_dst, dst))
157                        || (subgraph_unionfind.same_set(x_src, dst)
158                            && subgraph_unionfind.same_set(x_dst, src))
159                })
160            {
161                continue;
162            }
163
164            // Do not connect across loop contexts.
165            if partitioned_graph.node_loop(src) != partitioned_graph.node_loop(dst) {
166                continue;
167            }
168            // Do not connect `next_iteration()`.
169            if partitioned_graph.node_op_inst(dst).is_some_and(|op_inst| {
170                Some(FloType::NextIteration) == op_inst.op_constraints.flo_type
171            }) {
172                continue;
173            }
174
175            if can_connect_colorize(&mut node_color, src, dst) {
176                // At this point we have selected this edge and its src & dst to be
177                // within a single subgraph.
178                subgraph_unionfind.union(src, dst);
179                assert!(handoff_edges.remove(&edge_id));
180                progress = true;
181            }
182        }
183    }
184
185    (subgraph_unionfind, handoff_edges)
186}
187
188/// Builds the datastructures for checking which subgraph each node belongs to
189/// after handoffs have already been inserted to partition subgraphs.
190/// This list of nodes in each subgraph are returned in topological sort order.
191fn make_subgraph_collect(
192    partitioned_graph: &DfirGraph,
193    mut subgraph_unionfind: UnionFind<GraphNodeId>,
194) -> SecondaryMap<GraphNodeId, Vec<GraphNodeId>> {
195    // We want the nodes of each subgraph to be listed in topo-sort order.
196    // We could do this on each subgraph, or we could do it all at once on the
197    // whole node graph by ignoring handoffs, which is what we do here:
198    let topo_sort = graph_algorithms::topo_sort(
199        partitioned_graph
200            .nodes()
201            .filter(|&(_, node)| !matches!(node, GraphNode::Handoff { .. }))
202            .map(|(node_id, _)| node_id),
203        |v| {
204            partitioned_graph
205                .node_predecessor_nodes(v)
206                .filter(|&pred_id| {
207                    let pred = partitioned_graph.node(pred_id);
208                    !matches!(pred, GraphNode::Handoff { .. })
209                })
210        },
211    )
212    .expect("Subgraphs are in-out trees.");
213
214    let mut grouped_nodes: SecondaryMap<GraphNodeId, Vec<GraphNodeId>> = Default::default();
215    for node_id in topo_sort {
216        let repr_node = subgraph_unionfind.find(node_id);
217        if !grouped_nodes.contains_key(repr_node) {
218            grouped_nodes.insert(repr_node, Default::default());
219        }
220        grouped_nodes[repr_node].push(node_id);
221    }
222    grouped_nodes
223}
224
225/// Find subgraph and insert handoffs.
226/// Modifies barrier_crossers so that the edge OUT of an inserted handoff has
227/// the DelayType data.
228fn make_subgraphs(partitioned_graph: &mut DfirGraph, barrier_crossers: &mut BarrierCrossers) {
229    // Algorithm:
230    // 1. Each node begins as its own subgraph.
231    // 2. Collect edges. (Future optimization: sort so edges which should not be split across a handoff come first).
232    // 3. For each edge, try to join `(to, from)` into the same subgraph.
233
234    // TODO(mingwei):
235    // self.partitioned_graph.assert_valid();
236
237    let (subgraph_unionfind, handoff_edges) =
238        find_subgraph_unionfind(partitioned_graph, barrier_crossers);
239
240    // Insert handoffs between subgraphs (or on subgraph self-loop edges)
241    for edge_id in handoff_edges {
242        let (src_id, dst_id) = partitioned_graph.edge(edge_id);
243
244        // Already has a handoff, no need to insert one.
245        let src_node = partitioned_graph.node(src_id);
246        let dst_node = partitioned_graph.node(dst_id);
247        if matches!(src_node, GraphNode::Handoff { .. })
248            || matches!(dst_node, GraphNode::Handoff { .. })
249        {
250            continue;
251        }
252
253        let hoff = GraphNode::Handoff {
254            kind: HandoffKind::Vec,
255            src_span: src_node.span(),
256            dst_span: dst_node.span(),
257        };
258        let (_node_id, out_edge_id) = partitioned_graph.insert_intermediate_node(edge_id, hoff);
259
260        // Update barrier_crossers for inserted node.
261        barrier_crossers.replace_edge(edge_id, out_edge_id);
262    }
263
264    // Determine node's subgraph and subgraph's nodes.
265    // This list of nodes in each subgraph are to be in topological sort order.
266    // Eventually returned directly in the [`DfirGraph`].
267    let grouped_nodes = make_subgraph_collect(partitioned_graph, subgraph_unionfind);
268    for (_repr_node, member_nodes) in grouped_nodes {
269        partitioned_graph.insert_subgraph(member_nodes).unwrap();
270    }
271}
272
273/// Set `src` or `dst` color if `None` based on the other (if possible):
274/// `None` indicates an op could be pull or push i.e. unary-in & unary-out.
275/// So in that case we color `src` or `dst` based on its newfound neighbor (the other one).
276///
277/// Returns if `src` and `dst` can be in the same subgraph.
278fn can_connect_colorize(
279    node_color: &mut SparseSecondaryMap<GraphNodeId, Color>,
280    src: GraphNodeId,
281    dst: GraphNodeId,
282) -> bool {
283    // Pull -> Pull
284    // Push -> Push
285    // Pull -> [Computation] -> Push
286    // Push -> [Handoff] -> Pull
287    let can_connect = match (node_color.get(src), node_color.get(dst)) {
288        // Linear chain, can't connect because it may cause future conflicts.
289        // But if it doesn't in the _future_ we can connect it (once either/both ends are determined).
290        (None, None) => false,
291
292        // Infer left side.
293        (None, Some(Color::Pull | Color::Comp)) => {
294            node_color.insert(src, Color::Pull);
295            true
296        }
297        (None, Some(Color::Push | Color::Hoff)) => {
298            node_color.insert(src, Color::Push);
299            true
300        }
301
302        // Infer right side.
303        (Some(Color::Pull | Color::Hoff), None) => {
304            node_color.insert(dst, Color::Pull);
305            true
306        }
307        (Some(Color::Comp | Color::Push), None) => {
308            node_color.insert(dst, Color::Push);
309            true
310        }
311
312        // Both sides already specified.
313        (Some(Color::Pull), Some(Color::Pull)) => true,
314        (Some(Color::Pull), Some(Color::Comp)) => true,
315        (Some(Color::Pull), Some(Color::Push)) => true,
316
317        (Some(Color::Comp), Some(Color::Pull)) => false,
318        (Some(Color::Comp), Some(Color::Comp)) => false,
319        (Some(Color::Comp), Some(Color::Push)) => true,
320
321        (Some(Color::Push), Some(Color::Pull)) => false,
322        (Some(Color::Push), Some(Color::Comp)) => false,
323        (Some(Color::Push), Some(Color::Push)) => true,
324
325        // Handoffs are not part of subgraphs.
326        (Some(Color::Hoff), Some(_)) => false,
327        (Some(_), Some(Color::Hoff)) => false,
328    };
329    can_connect
330}
331
332/// Topologically sorts subgraphs and marks tick-boundary (`defer_tick` / `defer_tick_lazy`)
333/// handoffs with their delay type for double-buffered codegen in `as_code`.
334///
335/// Returns an error if there is an intra-tick cycle (i.e. the subgraph DAG has a cycle when
336/// tick-boundary edges are excluded).
337fn order_subgraphs(
338    partitioned_graph: &mut DfirGraph,
339    barrier_crossers: &BarrierCrossers,
340) -> Result<(), Diagnostic> {
341    // Build a subgraph-level directed graph, excluding tick-boundary edges.
342    let mut sg_preds: BTreeMap<GraphSubgraphId, Vec<GraphSubgraphId>> = Default::default();
343
344    // Track which handoff edges are tick-boundary, keyed by (src_sg, dst_sg).
345    let mut tick_edges: Vec<(GraphEdgeId, DelayType)> = Vec::new();
346
347    // Iterate handoffs between subgraphs.
348    for (hoff_id, hoff) in partitioned_graph.nodes() {
349        if !matches!(hoff, GraphNode::Handoff { .. }) {
350            continue;
351        }
352
353        // Handoffs may have 0 successors if only used by reference. Skip ordering those.
354        if partitioned_graph.node_degree_out(hoff_id) == 0 {
355            continue;
356        }
357        assert_eq!(1, partitioned_graph.node_degree_out(hoff_id));
358
359        let (succ_edge, succ) = partitioned_graph.node_successors(hoff_id).next().unwrap();
360
361        let succ_edge_delaytype = barrier_crossers
362            .edge_barrier_crossers
363            .get(succ_edge)
364            .copied();
365        // Tick edges are excluded from the topo sort — they are cross-tick by design.
366        if let Some(delay_type @ (DelayType::Tick | DelayType::TickLazy)) = succ_edge_delaytype {
367            tick_edges.push((succ_edge, delay_type));
368            continue;
369        }
370
371        assert_eq!(1, partitioned_graph.node_degree_in(hoff_id));
372        let (_edge_id, pred) = partitioned_graph.node_predecessors(hoff_id).next().unwrap();
373
374        let pred_sg = partitioned_graph
375            .node_subgraph(pred)
376            .expect("Handoff pred not in subgraph, may be a doubled/adjacent handoff");
377        let succ_sg = partitioned_graph
378            .node_subgraph(succ)
379            .expect("Handoff succ not in subgraph, may be a doubled/adjacent handoff");
380
381        sg_preds.entry(succ_sg).or_default().push(pred_sg);
382    }
383    // Include singleton reference edges.
384    for &(pred, succ) in barrier_crossers.singleton_barrier_crossers.iter() {
385        assert_ne!(pred, succ);
386        // For handoff nodes (which have no subgraph), use the predecessor's subgraph.
387        let pred_sg = if let Some(sg) = partitioned_graph.node_subgraph(pred) {
388            sg
389        } else {
390            // pred is a handoff node — find its predecessor operator's subgraph.
391            let (_edge, pred_pred) = partitioned_graph
392                .node_predecessors(pred)
393                .next()
394                .expect("handoff must have a predecessor");
395            partitioned_graph.node_subgraph(pred_pred).unwrap()
396        };
397        let succ_sg = partitioned_graph.node_subgraph(succ).unwrap();
398        if pred_sg == succ_sg {
399            continue;
400        }
401        sg_preds.entry(succ_sg).or_default().push(pred_sg);
402
403        // For handoff nodes: borrower must run before pipe consumer.
404        // All handoffs should have at most one successor.
405        if matches!(partitioned_graph.node(pred), GraphNode::Handoff { .. }) {
406            assert!(
407                partitioned_graph.node_degree_out(pred) <= 1,
408                "handoff should have at most one successor"
409            );
410            if let Some((_edge, consumer)) = partitioned_graph.node_successors(pred).next() {
411                let consumer_sg = partitioned_graph.node_subgraph(consumer).unwrap();
412                if consumer_sg != succ_sg {
413                    sg_preds.entry(consumer_sg).or_default().push(succ_sg);
414                }
415            }
416        }
417    }
418
419    // Topological sort — rejects intra-tick cycles.
420    if let Err(cycle) = graph_algorithms::topo_sort(partitioned_graph.subgraph_ids(), |v| {
421        sg_preds.get(&v).into_iter().flatten().copied()
422    }) {
423        let span = cycle
424            .first()
425            .and_then(|&sg_id| partitioned_graph.subgraph(sg_id).first().copied())
426            .map(|n| partitioned_graph.node(n).span())
427            .unwrap_or_else(Span::call_site);
428        return Err(Diagnostic::spanned(
429            span,
430            Level::Error,
431            "Cyclical dataflow within a tick is not supported. Use `defer_tick()` or `defer_tick_lazy()` to break the cycle across ticks.",
432        ));
433    }
434
435    // Mark tick-boundary handoffs with their delay type.
436    // These handoffs are excluded from the intra-tick topo ordering in
437    // `as_code`; instead, their double-buffered handoff semantics defer data
438    // across the tick boundary to the next tick.
439    for (edge_id, delay_type) in tick_edges {
440        let (hoff, _dst) = partitioned_graph.edge(edge_id);
441        assert!(matches!(
442            partitioned_graph.node(hoff),
443            GraphNode::Handoff {
444                kind: HandoffKind::Vec,
445                ..
446            }
447        ));
448        partitioned_graph.set_handoff_delay_type(hoff, delay_type);
449    }
450    Ok(())
451}
452
453/// Main method for this module. Partitions a flat [`DfirGraph`] into one with subgraphs.
454///
455/// Returns an error if an intra-tick cycle exists in the graph.
456pub fn partition_graph(flat_graph: DfirGraph) -> Result<DfirGraph, Diagnostic> {
457    // Pre-find barrier crossers (input edges with a `DelayType`).
458    let mut barrier_crossers = find_barrier_crossers(&flat_graph);
459    let mut partitioned_graph = flat_graph;
460
461    // Partition into subgraphs.
462    make_subgraphs(&mut partitioned_graph, &mut barrier_crossers);
463
464    // Topologically order subgraphs and mark tick-boundary handoffs for double-buffering.
465    order_subgraphs(&mut partitioned_graph, &barrier_crossers)?;
466
467    Ok(partitioned_graph)
468}