Skip to main content

dfir_lang/graph/
mod.rs

1//! Graph representation stages for DFIR graphs.
2
3use std::borrow::Cow;
4use std::hash::Hash;
5
6use proc_macro2::{Ident, Span, TokenStream};
7use quote::ToTokens;
8use serde::{Deserialize, Serialize};
9use syn::punctuated::Punctuated;
10use syn::spanned::Spanned;
11use syn::{Expr, ExprPath, GenericArgument, Token, Type};
12
13use self::ops::{OperatorConstraints, Persistence};
14use crate::diagnostic::{Diagnostic, Diagnostics, Level};
15use crate::parse::{DfirCode, IndexInt, Operator, PortIndex, Ported, SingletonRef};
16use crate::pretty_span::PrettySpan;
17
18mod di_mul_graph;
19mod eliminate_extra_unions_tees;
20mod flat_graph_builder;
21mod flat_to_partitioned;
22mod graph_write;
23mod meta_graph;
24mod meta_graph_debugging;
25
26use std::fmt::Display;
27
28pub use di_mul_graph::DiMulGraph;
29pub use eliminate_extra_unions_tees::eliminate_extra_unions_tees;
30pub use flat_graph_builder::{FlatGraphBuilder, FlatGraphBuilderOutput};
31pub use flat_to_partitioned::partition_graph;
32pub use meta_graph::{DfirGraph, WriteConfig, WriteGraphType};
33
34pub use crate::graph_ids::{GraphEdgeId, GraphLoopId, GraphNodeId, GraphSubgraphId};
35
36pub mod graph_algorithms;
37pub mod ops;
38
39impl GraphSubgraphId {
40    /// Generate a deterministic `Ident` for the given subgraph ID.
41    pub fn as_ident(self, span: Span) -> Ident {
42        use slotmap::Key;
43        Ident::new(&format!("sgid_{:?}", self.data()), span)
44    }
45}
46
47impl GraphLoopId {
48    /// Generate a deterministic `Ident` for the given loop ID.
49    pub fn as_ident(self, span: Span) -> Ident {
50        use slotmap::Key;
51        Ident::new(&format!("loop_{:?}", self.data()), span)
52    }
53}
54
/// Context identifier as a string.
const CONTEXT: &str = "context";
/// Runnable DFIR graph object identifier as a string.
const GRAPH: &str = "df";

/// Display name used for [`GraphNode::Handoff`] nodes whose kind is [`HandoffKind::Vec`].
const HANDOFF_NODE_STR: &str = "handoff";
/// Display name used for [`GraphNode::Handoff`] nodes whose kind is [`HandoffKind::Singleton`]
/// or [`HandoffKind::Optional`].
const SINGLETON_SLOT_NODE_STR: &str = "singleton";
/// Display name used for [`GraphNode::ModuleBoundary`] nodes.
const MODULE_BOUNDARY_NODE_STR: &str = "module_boundary";
63
64mod serde_syn {
65    use serde::{Deserialize, Deserializer, Serializer};
66
67    pub fn serialize<S, T>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
68    where
69        S: Serializer,
70        T: quote::ToTokens,
71    {
72        serializer.serialize_str(&value.to_token_stream().to_string())
73    }
74
75    pub fn deserialize<'de, D, T>(deserializer: D) -> Result<T, D::Error>
76    where
77        D: Deserializer<'de>,
78        T: syn::parse::Parse,
79    {
80        let s = String::deserialize(deserializer)?;
81        syn::parse_str(&s).map_err(<D::Error as serde::de::Error>::custom)
82    }
83}
84
/// A variable name assigned to a pipeline in DFIR syntax.
///
/// Fundamentally a serializable/deserializable wrapper around [`syn::Ident`].
/// The identifier is (de)serialized as a string via the `serde_syn` helper module.
#[derive(Clone, Debug, Serialize, Deserialize, PartialOrd, Ord, PartialEq, Eq, Hash)]
pub struct Varname(#[serde(with = "serde_syn")] pub Ident);
90
/// The kind of inter-subgraph handoff.
///
/// Selects the storage used by a [`GraphNode::Handoff`] node.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum HandoffKind {
    /// A `Vec<T>` buffer for streams (zero or more items).
    Vec,
    /// An `Option<T>` slot for singletons (exactly one item expected).
    /// `#varname` gives `&T` (panics if empty).
    Singleton,
    /// An `Option<T>` slot for optionals (zero or one item).
    /// `#varname` gives `&Option<T>`.
    Optional,
}
103
/// A node, corresponding to an operator or a handoff.
///
/// All [`Span`] fields are skipped by `serde` and are restored as [`Span::call_site`] when
/// deserializing.
#[derive(Clone, Serialize, Deserialize)]
pub enum GraphNode {
    /// An operator.
    ///
    /// Serialized as its token-stream string via the `serde_syn` helper module.
    Operator(#[serde(with = "serde_syn")] Operator),
    /// An inter-subgraph handoff point for buffering data between subgraphs.
    Handoff {
        /// What kind of storage this handoff uses.
        kind: HandoffKind,
        /// The span of the input into the handoff.
        #[serde(skip, default = "Span::call_site")]
        src_span: Span,
        /// The span of the output out of the handoff.
        #[serde(skip, default = "Span::call_site")]
        dst_span: Span,
    },

    /// Module Boundary, used for importing modules. Only exists prior to partitioning.
    ModuleBoundary {
        /// If this module is an input or output boundary.
        input: bool,

        /// The span of the import!() expression that imported this module.
        /// The value of this span when the ModuleBoundary node is still inside the module is Span::call_site()
        /// TODO: This could one day reference into the module file itself?
        #[serde(skip, default = "Span::call_site")]
        import_expr: Span,
    },
}
133impl GraphNode {
134    /// Return the node as a human-readable string.
135    pub fn to_pretty_string(&self) -> Cow<'static, str> {
136        match self {
137            GraphNode::Operator(op) => op.to_pretty_string().into(),
138            GraphNode::Handoff {
139                kind: HandoffKind::Vec,
140                ..
141            } => HANDOFF_NODE_STR.into(),
142            GraphNode::Handoff {
143                kind: HandoffKind::Singleton | HandoffKind::Optional,
144                ..
145            } => SINGLETON_SLOT_NODE_STR.into(),
146            GraphNode::ModuleBoundary { .. } => MODULE_BOUNDARY_NODE_STR.into(),
147        }
148    }
149
150    /// Return the name of the node as a string, excluding parenthesis and op source code.
151    pub fn to_name_string(&self) -> Cow<'static, str> {
152        match self {
153            GraphNode::Operator(op) => op.name_string().into(),
154            GraphNode::Handoff {
155                kind: HandoffKind::Vec,
156                ..
157            } => HANDOFF_NODE_STR.into(),
158            GraphNode::Handoff {
159                kind: HandoffKind::Singleton | HandoffKind::Optional,
160                ..
161            } => SINGLETON_SLOT_NODE_STR.into(),
162            GraphNode::ModuleBoundary { .. } => MODULE_BOUNDARY_NODE_STR.into(),
163        }
164    }
165
166    /// Return the source code span of the node.
167    pub fn span(&self) -> Span {
168        match self {
169            Self::Operator(op) => op.span(),
170            &Self::Handoff {
171                src_span, dst_span, ..
172            } => src_span.join(dst_span).unwrap_or(src_span),
173            Self::ModuleBoundary { import_expr, .. } => *import_expr,
174        }
175    }
176}
177impl std::fmt::Debug for GraphNode {
178    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179        match self {
180            Self::Operator(operator) => {
181                write!(f, "Node::Operator({} span)", PrettySpan(operator.span()))
182            }
183            Self::Handoff { kind, .. } => write!(f, "Node::Handoff({kind:?})"),
184            Self::ModuleBoundary { input, .. } => {
185                write!(f, "Node::ModuleBoundary{{input: {}}}", input)
186            }
187        }
188    }
189}
190
/// Meta-data relating to operators which may be useful throughout the compilation process.
///
/// This data can be generated from the graph, but it is useful to have it readily available
/// pre-computed as many algorithms use the same info. Stuff like port names, arguments, and the
/// [`OperatorConstraints`] for the operator.
///
/// Because it is derived from the graph itself, there can be "cache invalidation"-esque issues
/// if this data is not kept in sync with the graph.
#[derive(Clone, Debug)]
pub struct OperatorInstance {
    /// The static constraints of the operator (its name is [`OperatorConstraints::name`]).
    pub op_constraints: &'static OperatorConstraints,
    /// Port values used as this operator's input.
    pub input_ports: Vec<PortIndexValue>,
    /// Port values used as this operator's output.
    pub output_ports: Vec<PortIndexValue>,
    /// Singleton references within the operator arguments.
    pub singletons_referenced: Vec<SingletonRef>,

    /// Generic arguments.
    pub generics: OpInstGenerics,
    /// Arguments provided by the user into the operator as arguments.
    /// I.e. the `a, b, c` in `-> my_op(a, b, c) -> `.
    ///
    /// These arguments do not include singleton postprocessing codegen. Instead use
    /// [`ops::WriteContextArgs::arguments`].
    pub arguments_pre: Punctuated<Expr, Token![,]>,
    /// Unparsed arguments, for singleton parsing.
    pub arguments_raw: TokenStream,
}
221
/// Operator generic arguments, split into specific categories.
#[derive(Clone, Debug)]
pub struct OpInstGenerics {
    /// Operator generic (type or lifetime) arguments.
    pub generic_args: Option<Punctuated<GenericArgument, Token![,]>>,
    /// Lifetime persistence arguments. Corresponds to a prefix of [`Self::generic_args`].
    pub persistence_args: Vec<Persistence>,
    /// Type arguments. Corresponds to the run of type arguments in [`Self::generic_args`] that
    /// directly follows the persistence prefix.
    pub type_args: Vec<Type>,
}
232
233impl OpInstGenerics {
234    /// Helper to join a sequence of spans into a single span, if possible.
235    ///
236    /// Returns `None` if there are no spans or if any `Span::join` call fails
237    /// (for example, when spans are not contiguous).
238    fn join_spans<I>(mut spans: I) -> Option<Span>
239    where
240        I: Iterator<Item = Span>,
241    {
242        let mut span = spans.next()?;
243        for s in spans {
244            span = span.join(s)?;
245        }
246        Some(span)
247    }
248
249    /// Returns a [`Span`] containing all persistence (lifetime) args if possible.
250    pub fn persistence_args_span(&self) -> Option<Span> {
251        self.generic_args.as_ref().and_then(|args| {
252            Self::join_spans(
253                args.iter()
254                    .filter(|a| matches!(a, GenericArgument::Lifetime(_)))
255                    .map(|a| a.span()),
256            )
257        })
258    }
259
260    /// Returns a [`Span`] containing all type args if possible.
261    pub fn type_args_span(&self) -> Option<Span> {
262        self.generic_args.as_ref().and_then(|args| {
263            Self::join_spans(
264                args.iter()
265                    .filter(|a| matches!(a, GenericArgument::Type(_)))
266                    .map(|a| a.span()),
267            )
268        })
269    }
270}
271
272/// Gets the generic arguments for the operator.
273///
274/// This helper method is useful due to the special handling of persistence lifetimes (`'static`,
275/// `'tick`, `'mutable`) which must come before other generic parameters.
276pub fn get_operator_generics(diagnostics: &mut Diagnostics, operator: &Operator) -> OpInstGenerics {
277    // Generic arguments.
278    let generic_args = operator.type_arguments().cloned();
279    let persistence_args = generic_args.iter().flatten().map_while(|generic_arg| match generic_arg {
280            GenericArgument::Lifetime(lifetime) => {
281                match &*lifetime.ident.to_string() {
282                    "none" => Some(Persistence::None),
283                    "loop" => Some(Persistence::Loop),
284                    "tick" => Some(Persistence::Tick),
285                    "static" => Some(Persistence::Static),
286                    "mutable" => Some(Persistence::Mutable),
287                    _ => {
288                        diagnostics.push(Diagnostic::spanned(
289                            generic_arg.span(),
290                            Level::Error,
291                            format!("Unknown lifetime generic argument `'{}`, expected `'none`, `'loop`, `'tick`, `'static`, or `'mutable`.", lifetime.ident),
292                        ));
293                        // TODO(mingwei): should really keep going and not short circuit?
294                        None
295                    }
296                }
297            },
298            _ => None,
299        }).collect::<Vec<_>>();
300    let type_args = generic_args
301        .iter()
302        .flatten()
303        .skip(persistence_args.len())
304        .map_while(|generic_arg| match generic_arg {
305            GenericArgument::Type(typ) => Some(typ),
306            _ => None,
307        })
308        .cloned()
309        .collect::<Vec<_>>();
310
311    OpInstGenerics {
312        generic_args,
313        persistence_args,
314        type_args,
315    }
316}
317
/// Push, Pull, Comp, or Hoff polarity.
///
/// The derived ordering follows declaration order: `Pull < Push < Comp < Hoff`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum Color {
    /// Pull (green)
    Pull,
    /// Push (blue)
    Push,
    /// Computation (yellow)
    Comp,
    /// Handoff (grey) -- not a color for operators, inserted between subgraphs.
    Hoff,
}
330
/// Helper struct for [`PortIndex`] which keeps span information for elided ports.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum PortIndexValue {
    /// An integer value: `[0]`, `[1]`, etc. Can be negative although we don't use that (2023-08-16).
    Int(#[serde(with = "serde_syn")] IndexInt),
    /// A name or path. `[pos]`, `[neg]`, etc. Can use `::` separators but we don't use that (2023-08-16).
    Path(#[serde(with = "serde_syn")] ExprPath),
    /// Elided, unspecified port. We have this variant, rather than wrapping in `Option`, in order
    /// to preserve the `Span` information.
    ///
    /// The span is skipped by `serde`, so it is not preserved across serialization.
    Elided(#[serde(skip)] Option<Span>),
}
342impl PortIndexValue {
343    /// For a [`Ported`] value like `[port_in]name[port_out]`, get the `port_in` and `port_out` as
344    /// [`PortIndexValue`]s.
345    pub fn from_ported<Inner>(ported: Ported<Inner>) -> (Self, Inner, Self)
346    where
347        Inner: Spanned,
348    {
349        let ported_span = Some(ported.inner.span());
350        let port_inn = ported
351            .inn
352            .map(|idx| idx.index.into())
353            .unwrap_or_else(|| Self::Elided(ported_span));
354        let inner = ported.inner;
355        let port_out = ported
356            .out
357            .map(|idx| idx.index.into())
358            .unwrap_or_else(|| Self::Elided(ported_span));
359        (port_inn, inner, port_out)
360    }
361
362    /// Returns `true` if `self` is not [`PortIndexValue::Elided`].
363    pub fn is_specified(&self) -> bool {
364        !matches!(self, Self::Elided(_))
365    }
366
367    /// Returns whichever of the two ports are specified.
368    /// If both are [`Self::Elided`], returns [`Self::Elided`].
369    /// If both are specified, returns `Err(self)`.
370    #[allow(clippy::allow_attributes, reason = "Only triggered on nightly.")]
371    #[allow(
372        clippy::result_large_err,
373        reason = "variants are same size, error isn't to be propagated."
374    )]
375    pub fn combine(self, other: Self) -> Result<Self, Self> {
376        match (self.is_specified(), other.is_specified()) {
377            (false, _other) => Ok(other),
378            (true, false) => Ok(self),
379            (true, true) => Err(self),
380        }
381    }
382
383    /// Formats self as a human-readable string for error messages.
384    pub fn as_error_message_string(&self) -> String {
385        match self {
386            PortIndexValue::Int(n) => format!("`{}`", n.value),
387            PortIndexValue::Path(path) => format!("`{}`", path.to_token_stream()),
388            PortIndexValue::Elided(_) => "<elided>".to_owned(),
389        }
390    }
391
392    /// Returns the span of this port value.
393    pub fn span(&self) -> Span {
394        match self {
395            PortIndexValue::Int(x) => x.span(),
396            PortIndexValue::Path(x) => x.span(),
397            PortIndexValue::Elided(span) => span.unwrap_or_else(Span::call_site),
398        }
399    }
400}
401impl From<PortIndex> for PortIndexValue {
402    fn from(value: PortIndex) -> Self {
403        match value {
404            PortIndex::Int(x) => Self::Int(x),
405            PortIndex::Path(x) => Self::Path(x),
406        }
407    }
408}
409impl PartialEq for PortIndexValue {
410    fn eq(&self, other: &Self) -> bool {
411        match (self, other) {
412            (Self::Int(l0), Self::Int(r0)) => l0 == r0,
413            (Self::Path(l0), Self::Path(r0)) => l0 == r0,
414            (Self::Elided(_), Self::Elided(_)) => true,
415            _else => false,
416        }
417    }
418}
419impl Eq for PortIndexValue {}
420impl PartialOrd for PortIndexValue {
421    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
422        Some(self.cmp(other))
423    }
424}
425impl Ord for PortIndexValue {
426    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
427        match (self, other) {
428            (Self::Int(s), Self::Int(o)) => s.cmp(o),
429            (Self::Path(s), Self::Path(o)) => s
430                .to_token_stream()
431                .to_string()
432                .cmp(&o.to_token_stream().to_string()),
433            (Self::Elided(_), Self::Elided(_)) => std::cmp::Ordering::Equal,
434            (Self::Int(_), Self::Path(_)) => std::cmp::Ordering::Less,
435            (Self::Path(_), Self::Int(_)) => std::cmp::Ordering::Greater,
436            (_, Self::Elided(_)) => std::cmp::Ordering::Less,
437            (Self::Elided(_), _) => std::cmp::Ordering::Greater,
438        }
439    }
440}
441
442impl Display for PortIndexValue {
443    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
444        match self {
445            PortIndexValue::Int(x) => write!(f, "{}", x.to_token_stream()),
446            PortIndexValue::Path(x) => write!(f, "{}", x.to_token_stream()),
447            PortIndexValue::Elided(_) => write!(f, "[]"),
448        }
449    }
450}
451
/// Output of [`build_dfir_code`].
///
/// Only produced on success; on failure [`build_dfir_code`] returns the diagnostics as its error.
pub struct BuildDfirCodeOutput {
    /// The now-partitioned graph.
    pub partitioned_graph: DfirGraph,
    /// The Rust source code tokens for the DFIR.
    pub code: TokenStream,
    /// Any (non-error) diagnostics emitted.
    pub diagnostics: Diagnostics,
}
461
462/// Compiles a [`DfirCode`] AST into inline source code that runs the dataflow.
463pub fn build_dfir_code(
464    dfir_code: DfirCode,
465    root: &TokenStream,
466) -> Result<BuildDfirCodeOutput, Diagnostics> {
467    let flat_graph_builder = FlatGraphBuilder::from_dfir(dfir_code);
468
469    let FlatGraphBuilderOutput {
470        mut flat_graph,
471        uses,
472        mut diagnostics,
473    } = flat_graph_builder.build()?;
474
475    let () = match flat_graph.merge_modules() {
476        Ok(()) => (),
477        Err(d) => {
478            diagnostics.push(d);
479            return Err(diagnostics);
480        }
481    };
482
483    eliminate_extra_unions_tees(&mut flat_graph);
484
485    // Reject `loop { }` blocks (not yet supported in inline codegen).
486    // TODO(cleanup): find a better home for this check — ideally inside `partition_graph` once
487    // it supports returning multiple diagnostics.
488    for (_loop_id, nodes) in flat_graph.loops() {
489        let span = nodes
490            .first()
491            .map_or_else(Span::call_site, |&n| flat_graph.node(n).span());
492        diagnostics.push(Diagnostic::spanned(
493            span,
494            Level::Error,
495            "`loop { }` blocks are not (yet) supported in `dfir_syntax!`.",
496        ));
497    }
498    if diagnostics.has_error() {
499        return Err(diagnostics);
500    }
501
502    let partitioned_graph = match partition_graph(flat_graph) {
503        Ok(partitioned_graph) => partitioned_graph,
504        Err(d) => {
505            diagnostics.push(d);
506            return Err(diagnostics);
507        }
508    };
509
510    let code =
511        partitioned_graph.as_code(root, true, quote::quote! { #( #uses )* }, &mut diagnostics)?;
512
513    Ok(BuildDfirCodeOutput {
514        partitioned_graph,
515        code,
516        diagnostics,
517    })
518}
519
520/// Changes all of token's spans to `span`, recursing into groups.
521fn change_spans(tokens: TokenStream, span: Span) -> TokenStream {
522    use proc_macro2::{Group, TokenTree};
523    tokens
524        .into_iter()
525        .map(|token| match token {
526            TokenTree::Group(mut group) => {
527                group.set_span(span);
528                TokenTree::Group(Group::new(
529                    group.delimiter(),
530                    change_spans(group.stream(), span),
531                ))
532            }
533            TokenTree::Ident(mut ident) => {
534                ident.set_span(span.resolved_at(ident.span()));
535                TokenTree::Ident(ident)
536            }
537            TokenTree::Punct(mut punct) => {
538                punct.set_span(span);
539                TokenTree::Punct(punct)
540            }
541            TokenTree::Literal(mut literal) => {
542                literal.set_span(span);
543                TokenTree::Literal(literal)
544            }
545        })
546        .collect()
547}