1#![warn(missing_docs)]
2
3extern crate proc_macro;
4
5use std::collections::{BTreeMap, BTreeSet, VecDeque};
6use std::fmt::Debug;
7use std::iter::FusedIterator;
8
9use itertools::Itertools;
10use proc_macro2::{Ident, Literal, Span, TokenStream};
11use quote::{ToTokens, TokenStreamExt, format_ident, quote, quote_spanned};
12use serde::{Deserialize, Serialize};
13use slotmap::{Key, SecondaryMap, SlotMap, SparseSecondaryMap};
14use syn::spanned::Spanned;
15
16use super::graph_write::{Dot, GraphWrite, Mermaid};
17use super::ops::{
18 DelayType, OPERATORS, OperatorWriteOutput, WriteContextArgs, find_op_op_constraints,
19 null_write_iterator_fn,
20};
21use super::{
22 CONTEXT, Color, DiMulGraph, GRAPH, GraphEdgeId, GraphLoopId, GraphNode, GraphNodeId,
23 GraphSubgraphId, HANDOFF_NODE_STR, MODULE_BOUNDARY_NODE_STR, OperatorInstance, PortIndexValue,
24 Varname, change_spans, get_operator_generics,
25};
26use crate::diagnostic::{Diagnostic, Level};
27use crate::pretty_span::{PrettyRowCol, PrettySpan};
28use crate::process_singletons;
29
30#[derive(Default, Debug, Serialize, Deserialize)]
40pub struct DfirGraph {
41 nodes: SlotMap<GraphNodeId, GraphNode>,
43
44 #[serde(skip)]
47 operator_instances: SecondaryMap<GraphNodeId, OperatorInstance>,
48 operator_tag: SecondaryMap<GraphNodeId, String>,
50 graph: DiMulGraph<GraphNodeId, GraphEdgeId>,
52 ports: SecondaryMap<GraphEdgeId, (PortIndexValue, PortIndexValue)>,
54
55 node_loops: SecondaryMap<GraphNodeId, GraphLoopId>,
57 loop_nodes: SlotMap<GraphLoopId, Vec<GraphNodeId>>,
59 loop_parent: SparseSecondaryMap<GraphLoopId, GraphLoopId>,
61 root_loops: Vec<GraphLoopId>,
63 loop_children: SecondaryMap<GraphLoopId, Vec<GraphLoopId>>,
65
66 node_subgraph: SecondaryMap<GraphNodeId, GraphSubgraphId>,
68
69 subgraph_nodes: SlotMap<GraphSubgraphId, Vec<GraphNodeId>>,
71 subgraph_stratum: SecondaryMap<GraphSubgraphId, usize>,
73
74 node_singleton_references: SparseSecondaryMap<GraphNodeId, Vec<Option<GraphNodeId>>>,
76 node_varnames: SparseSecondaryMap<GraphNodeId, Varname>,
78
79 subgraph_laziness: SecondaryMap<GraphSubgraphId, bool>,
83}
84
85impl DfirGraph {
87 pub fn new() -> Self {
89 Default::default()
90 }
91}
92
93impl DfirGraph {
95 pub fn node(&self, node_id: GraphNodeId) -> &GraphNode {
97 self.nodes.get(node_id).expect("Node not found.")
98 }
99
100 pub fn node_op_inst(&self, node_id: GraphNodeId) -> Option<&OperatorInstance> {
105 self.operator_instances.get(node_id)
106 }
107
108 pub fn node_varname(&self, node_id: GraphNodeId) -> Option<&Varname> {
110 self.node_varnames.get(node_id)
111 }
112
113 pub fn node_subgraph(&self, node_id: GraphNodeId) -> Option<GraphSubgraphId> {
115 self.node_subgraph.get(node_id).copied()
116 }
117
118 pub fn node_degree_in(&self, node_id: GraphNodeId) -> usize {
120 self.graph.degree_in(node_id)
121 }
122
123 pub fn node_degree_out(&self, node_id: GraphNodeId) -> usize {
125 self.graph.degree_out(node_id)
126 }
127
128 pub fn node_successors(
130 &self,
131 src: GraphNodeId,
132 ) -> impl '_
133 + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
134 + ExactSizeIterator
135 + FusedIterator
136 + Clone
137 + Debug {
138 self.graph.successors(src)
139 }
140
141 pub fn node_predecessors(
143 &self,
144 dst: GraphNodeId,
145 ) -> impl '_
146 + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
147 + ExactSizeIterator
148 + FusedIterator
149 + Clone
150 + Debug {
151 self.graph.predecessors(dst)
152 }
153
154 pub fn node_successor_edges(
156 &self,
157 src: GraphNodeId,
158 ) -> impl '_
159 + DoubleEndedIterator<Item = GraphEdgeId>
160 + ExactSizeIterator
161 + FusedIterator
162 + Clone
163 + Debug {
164 self.graph.successor_edges(src)
165 }
166
167 pub fn node_predecessor_edges(
169 &self,
170 dst: GraphNodeId,
171 ) -> impl '_
172 + DoubleEndedIterator<Item = GraphEdgeId>
173 + ExactSizeIterator
174 + FusedIterator
175 + Clone
176 + Debug {
177 self.graph.predecessor_edges(dst)
178 }
179
180 pub fn node_successor_nodes(
182 &self,
183 src: GraphNodeId,
184 ) -> impl '_
185 + DoubleEndedIterator<Item = GraphNodeId>
186 + ExactSizeIterator
187 + FusedIterator
188 + Clone
189 + Debug {
190 self.graph.successor_vertices(src)
191 }
192
193 pub fn node_predecessor_nodes(
195 &self,
196 dst: GraphNodeId,
197 ) -> impl '_
198 + DoubleEndedIterator<Item = GraphNodeId>
199 + ExactSizeIterator
200 + FusedIterator
201 + Clone
202 + Debug {
203 self.graph.predecessor_vertices(dst)
204 }
205
206 pub fn node_ids(&self) -> slotmap::basic::Keys<'_, GraphNodeId, GraphNode> {
208 self.nodes.keys()
209 }
210
211 pub fn nodes(&self) -> slotmap::basic::Iter<'_, GraphNodeId, GraphNode> {
213 self.nodes.iter()
214 }
215
216 pub fn insert_node(
218 &mut self,
219 node: GraphNode,
220 varname_opt: Option<Ident>,
221 loop_opt: Option<GraphLoopId>,
222 ) -> GraphNodeId {
223 let node_id = self.nodes.insert(node);
224 if let Some(varname) = varname_opt {
225 self.node_varnames.insert(node_id, Varname(varname));
226 }
227 if let Some(loop_id) = loop_opt {
228 self.node_loops.insert(node_id, loop_id);
229 self.loop_nodes[loop_id].push(node_id);
230 }
231 node_id
232 }
233
234 pub fn insert_node_op_inst(&mut self, node_id: GraphNodeId, op_inst: OperatorInstance) {
236 assert!(matches!(
237 self.nodes.get(node_id),
238 Some(GraphNode::Operator(_))
239 ));
240 let old_inst = self.operator_instances.insert(node_id, op_inst);
241 assert!(old_inst.is_none());
242 }
243
244 pub fn insert_node_op_insts_all(&mut self, diagnostics: &mut Vec<Diagnostic>) {
246 let mut op_insts = Vec::new();
247 for (node_id, node) in self.nodes() {
248 let GraphNode::Operator(operator) = node else {
249 continue;
250 };
251 if self.node_op_inst(node_id).is_some() {
252 continue;
253 };
254
255 let Some(op_constraints) = find_op_op_constraints(operator) else {
257 diagnostics.push(Diagnostic::spanned(
258 operator.path.span(),
259 Level::Error,
260 format!("Unknown operator `{}`", operator.name_string()),
261 ));
262 continue;
263 };
264
265 let (input_ports, output_ports) = {
267 let mut input_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
268 .node_predecessors(node_id)
269 .map(|(edge_id, pred_id)| (self.edge_ports(edge_id).1, pred_id))
270 .collect();
271 input_edges.sort();
273 let input_ports: Vec<PortIndexValue> = input_edges
274 .into_iter()
275 .map(|(port, _pred)| port)
276 .cloned()
277 .collect();
278
279 let mut output_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
281 .node_successors(node_id)
282 .map(|(edge_id, succ)| (self.edge_ports(edge_id).0, succ))
283 .collect();
284 output_edges.sort();
286 let output_ports: Vec<PortIndexValue> = output_edges
287 .into_iter()
288 .map(|(port, _succ)| port)
289 .cloned()
290 .collect();
291
292 (input_ports, output_ports)
293 };
294
295 let generics = get_operator_generics(diagnostics, operator);
297 {
299 let generics_span = generics
301 .generic_args
302 .as_ref()
303 .map(Spanned::span)
304 .unwrap_or_else(|| operator.path.span());
305
306 if !op_constraints
307 .persistence_args
308 .contains(&generics.persistence_args.len())
309 {
310 diagnostics.push(Diagnostic::spanned(
311 generics_span,
312 Level::Error,
313 format!(
314 "`{}` should have {} persistence lifetime arguments, actually has {}.",
315 op_constraints.name,
316 op_constraints.persistence_args.human_string(),
317 generics.persistence_args.len()
318 ),
319 ));
320 }
321 if !op_constraints.type_args.contains(&generics.type_args.len()) {
322 diagnostics.push(Diagnostic::spanned(
323 generics_span,
324 Level::Error,
325 format!(
326 "`{}` should have {} generic type arguments, actually has {}.",
327 op_constraints.name,
328 op_constraints.type_args.human_string(),
329 generics.type_args.len()
330 ),
331 ));
332 }
333 }
334
335 op_insts.push((
336 node_id,
337 OperatorInstance {
338 op_constraints,
339 input_ports,
340 output_ports,
341 singletons_referenced: operator.singletons_referenced.clone(),
342 generics,
343 arguments_pre: operator.args.clone(),
344 arguments_raw: operator.args_raw.clone(),
345 },
346 ));
347 }
348
349 for (node_id, op_inst) in op_insts {
350 self.insert_node_op_inst(node_id, op_inst);
351 }
352 }
353
354 pub fn insert_intermediate_node(
366 &mut self,
367 edge_id: GraphEdgeId,
368 new_node: GraphNode,
369 ) -> (GraphNodeId, GraphEdgeId) {
370 let span = Some(new_node.span());
371
372 let op_inst_opt = 'oc: {
374 let GraphNode::Operator(operator) = &new_node else {
375 break 'oc None;
376 };
377 let Some(op_constraints) = find_op_op_constraints(operator) else {
378 break 'oc None;
379 };
380 let (input_port, output_port) = self.ports.get(edge_id).cloned().unwrap();
381 let generics = get_operator_generics(
382 &mut Vec::new(), operator,
384 );
385 Some(OperatorInstance {
386 op_constraints,
387 input_ports: vec![input_port],
388 output_ports: vec![output_port],
389 singletons_referenced: operator.singletons_referenced.clone(),
390 generics,
391 arguments_pre: operator.args.clone(),
392 arguments_raw: operator.args_raw.clone(),
393 })
394 };
395
396 let node_id = self.nodes.insert(new_node);
398 if let Some(op_inst) = op_inst_opt {
400 self.operator_instances.insert(node_id, op_inst);
401 }
402 let (e0, e1) = self
404 .graph
405 .insert_intermediate_vertex(node_id, edge_id)
406 .unwrap();
407
408 let (src_idx, dst_idx) = self.ports.remove(edge_id).unwrap();
410 self.ports
411 .insert(e0, (src_idx, PortIndexValue::Elided(span)));
412 self.ports
413 .insert(e1, (PortIndexValue::Elided(span), dst_idx));
414
415 (node_id, e1)
416 }
417
418 pub fn remove_intermediate_node(&mut self, node_id: GraphNodeId) {
421 assert_eq!(
422 1,
423 self.node_degree_in(node_id),
424 "Removed intermediate node must have one predecessor"
425 );
426 assert_eq!(
427 1,
428 self.node_degree_out(node_id),
429 "Removed intermediate node must have one successor"
430 );
431 assert!(
432 self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
433 "Should not remove intermediate node after subgraph partitioning"
434 );
435
436 assert!(self.nodes.remove(node_id).is_some());
437 let (new_edge_id, (pred_edge_id, succ_edge_id)) =
438 self.graph.remove_intermediate_vertex(node_id).unwrap();
439 self.operator_instances.remove(node_id);
440 self.node_varnames.remove(node_id);
441
442 let (src_port, _) = self.ports.remove(pred_edge_id).unwrap();
443 let (_, dst_port) = self.ports.remove(succ_edge_id).unwrap();
444 self.ports.insert(new_edge_id, (src_port, dst_port));
445 }
446
447 pub(crate) fn node_color(&self, node_id: GraphNodeId) -> Option<Color> {
453 if matches!(self.node(node_id), GraphNode::Handoff { .. }) {
454 return Some(Color::Hoff);
455 }
456 let inn_degree = self.node_predecessor_nodes(node_id).count();
458 let out_degree = self.node_successor_nodes(node_id).count();
460
461 match (inn_degree, out_degree) {
462 (0, 0) => None, (0, 1) => Some(Color::Pull),
464 (1, 0) => Some(Color::Push),
465 (1, 1) => None, (_many, 0 | 1) => Some(Color::Pull),
467 (0 | 1, _many) => Some(Color::Push),
468 (_many, _to_many) => Some(Color::Comp),
469 }
470 }
471
472 pub fn set_operator_tag(&mut self, node_id: GraphNodeId, tag: String) {
474 self.operator_tag.insert(node_id, tag.to_owned());
475 }
476}
477
478impl DfirGraph {
480 pub fn set_node_singleton_references(
483 &mut self,
484 node_id: GraphNodeId,
485 singletons_referenced: Vec<Option<GraphNodeId>>,
486 ) -> Option<Vec<Option<GraphNodeId>>> {
487 self.node_singleton_references
488 .insert(node_id, singletons_referenced)
489 }
490
491 pub fn node_singleton_references(&self, node_id: GraphNodeId) -> &[Option<GraphNodeId>] {
494 self.node_singleton_references
495 .get(node_id)
496 .map(std::ops::Deref::deref)
497 .unwrap_or_default()
498 }
499}
500
501impl DfirGraph {
503 pub fn merge_modules(&mut self) -> Result<(), Diagnostic> {
511 let mod_bound_nodes = self
512 .nodes()
513 .filter(|(_nid, node)| matches!(node, GraphNode::ModuleBoundary { .. }))
514 .map(|(nid, _node)| nid)
515 .collect::<Vec<_>>();
516
517 for mod_bound_node in mod_bound_nodes {
518 self.remove_module_boundary(mod_bound_node)?;
519 }
520
521 Ok(())
522 }
523
524 fn remove_module_boundary(&mut self, mod_bound_node: GraphNodeId) -> Result<(), Diagnostic> {
528 assert!(
529 self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
530 "Should not remove intermediate node after subgraph partitioning"
531 );
532
533 let mut mod_pred_ports = BTreeMap::new();
534 let mut mod_succ_ports = BTreeMap::new();
535
536 for mod_out_edge in self.node_predecessor_edges(mod_bound_node) {
537 let (pred_port, succ_port) = self.edge_ports(mod_out_edge);
538 mod_pred_ports.insert(succ_port.clone(), (mod_out_edge, pred_port.clone()));
539 }
540
541 for mod_inn_edge in self.node_successor_edges(mod_bound_node) {
542 let (pred_port, succ_port) = self.edge_ports(mod_inn_edge);
543 mod_succ_ports.insert(pred_port.clone(), (mod_inn_edge, succ_port.clone()));
544 }
545
546 if mod_pred_ports.keys().collect::<BTreeSet<_>>()
547 != mod_succ_ports.keys().collect::<BTreeSet<_>>()
548 {
549 let GraphNode::ModuleBoundary { input, import_expr } = self.node(mod_bound_node) else {
551 panic!();
552 };
553
554 if *input {
555 return Err(Diagnostic {
556 span: *import_expr,
557 level: Level::Error,
558 message: format!(
559 "The ports into the module did not match. input: {:?}, expected: {:?}",
560 mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
561 mod_succ_ports.keys().map(|x| x.to_string()).join(", ")
562 ),
563 });
564 } else {
565 return Err(Diagnostic {
566 span: *import_expr,
567 level: Level::Error,
568 message: format!(
569 "The ports out of the module did not match. output: {:?}, expected: {:?}",
570 mod_succ_ports.keys().map(|x| x.to_string()).join(", "),
571 mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
572 ),
573 });
574 }
575 }
576
577 for (port, (pred_edge, pred_port)) in mod_pred_ports {
578 let (succ_edge, succ_port) = mod_succ_ports.remove(&port).unwrap();
579
580 let (src, _) = self.edge(pred_edge);
581 let (_, dst) = self.edge(succ_edge);
582 self.remove_edge(pred_edge);
583 self.remove_edge(succ_edge);
584
585 let new_edge_id = self.graph.insert_edge(src, dst);
586 self.ports.insert(new_edge_id, (pred_port, succ_port));
587 }
588
589 self.graph.remove_vertex(mod_bound_node);
590 self.nodes.remove(mod_bound_node);
591
592 Ok(())
593 }
594}
595
596impl DfirGraph {
598 pub fn edge(&self, edge_id: GraphEdgeId) -> (GraphNodeId, GraphNodeId) {
600 let (src, dst) = self.graph.edge(edge_id).expect("Edge not found.");
601 (src, dst)
602 }
603
604 pub fn edge_ports(&self, edge_id: GraphEdgeId) -> (&PortIndexValue, &PortIndexValue) {
606 let (src_port, dst_port) = self.ports.get(edge_id).expect("Edge not found.");
607 (src_port, dst_port)
608 }
609
610 pub fn edge_ids(&self) -> slotmap::basic::Keys<'_, GraphEdgeId, (GraphNodeId, GraphNodeId)> {
612 self.graph.edge_ids()
613 }
614
615 pub fn edges(
617 &self,
618 ) -> impl '_
619 + ExactSizeIterator<Item = (GraphEdgeId, (GraphNodeId, GraphNodeId))>
620 + FusedIterator
621 + Clone
622 + Debug {
623 self.graph.edges()
624 }
625
626 pub fn insert_edge(
628 &mut self,
629 src: GraphNodeId,
630 src_port: PortIndexValue,
631 dst: GraphNodeId,
632 dst_port: PortIndexValue,
633 ) -> GraphEdgeId {
634 let edge_id = self.graph.insert_edge(src, dst);
635 self.ports.insert(edge_id, (src_port, dst_port));
636 edge_id
637 }
638
639 pub fn remove_edge(&mut self, edge: GraphEdgeId) {
641 let (_src, _dst) = self.graph.remove_edge(edge).unwrap();
642 let (_src_port, _dst_port) = self.ports.remove(edge).unwrap();
643 }
644}
645
646impl DfirGraph {
648 pub fn subgraph(&self, subgraph_id: GraphSubgraphId) -> &Vec<GraphNodeId> {
650 self.subgraph_nodes
651 .get(subgraph_id)
652 .expect("Subgraph not found.")
653 }
654
655 pub fn subgraph_ids(&self) -> slotmap::basic::Keys<'_, GraphSubgraphId, Vec<GraphNodeId>> {
657 self.subgraph_nodes.keys()
658 }
659
660 pub fn subgraphs(&self) -> slotmap::basic::Iter<'_, GraphSubgraphId, Vec<GraphNodeId>> {
662 self.subgraph_nodes.iter()
663 }
664
665 pub fn insert_subgraph(
667 &mut self,
668 node_ids: Vec<GraphNodeId>,
669 ) -> Result<GraphSubgraphId, (GraphNodeId, GraphSubgraphId)> {
670 for &node_id in node_ids.iter() {
672 if let Some(&old_sg_id) = self.node_subgraph.get(node_id) {
673 return Err((node_id, old_sg_id));
674 }
675 }
676 let subgraph_id = self.subgraph_nodes.insert_with_key(|sg_id| {
677 for &node_id in node_ids.iter() {
678 self.node_subgraph.insert(node_id, sg_id);
679 }
680 node_ids
681 });
682
683 Ok(subgraph_id)
684 }
685
686 pub fn remove_from_subgraph(&mut self, node_id: GraphNodeId) -> bool {
688 if let Some(old_sg_id) = self.node_subgraph.remove(node_id) {
689 self.subgraph_nodes[old_sg_id].retain(|&other_node_id| other_node_id != node_id);
690 true
691 } else {
692 false
693 }
694 }
695
696 pub fn subgraph_stratum(&self, sg_id: GraphSubgraphId) -> Option<usize> {
698 self.subgraph_stratum.get(sg_id).copied()
699 }
700
701 pub fn set_subgraph_stratum(
703 &mut self,
704 sg_id: GraphSubgraphId,
705 stratum: usize,
706 ) -> Option<usize> {
707 self.subgraph_stratum.insert(sg_id, stratum)
708 }
709
710 fn subgraph_laziness(&self, sg_id: GraphSubgraphId) -> bool {
712 self.subgraph_laziness.get(sg_id).copied().unwrap_or(false)
713 }
714
715 pub fn set_subgraph_laziness(&mut self, sg_id: GraphSubgraphId, lazy: bool) -> bool {
717 self.subgraph_laziness.insert(sg_id, lazy).unwrap_or(false)
718 }
719
720 pub fn max_stratum(&self) -> Option<usize> {
722 self.subgraph_stratum.values().copied().max()
723 }
724
725 fn find_pull_to_push_idx(&self, subgraph_nodes: &[GraphNodeId]) -> usize {
727 subgraph_nodes
728 .iter()
729 .position(|&node_id| {
730 self.node_color(node_id)
731 .is_some_and(|color| Color::Pull != color)
732 })
733 .unwrap_or(subgraph_nodes.len())
734 }
735}
736
737impl DfirGraph {
739 fn node_as_ident(&self, node_id: GraphNodeId, is_pred: bool) -> Ident {
741 let name = match &self.nodes[node_id] {
742 GraphNode::Operator(_) => format!("op_{:?}", node_id.data()),
743 GraphNode::Handoff { .. } => format!(
744 "hoff_{:?}_{}",
745 node_id.data(),
746 if is_pred { "recv" } else { "send" }
747 ),
748 GraphNode::ModuleBoundary { .. } => panic!(),
749 };
750 let span = match (is_pred, &self.nodes[node_id]) {
751 (_, GraphNode::Operator(operator)) => operator.span(),
752 (true, &GraphNode::Handoff { src_span, .. }) => src_span,
753 (false, &GraphNode::Handoff { dst_span, .. }) => dst_span,
754 (_, GraphNode::ModuleBoundary { .. }) => panic!(),
755 };
756 Ident::new(&name, span)
757 }
758
759 fn node_as_singleton_ident(&self, node_id: GraphNodeId, span: Span) -> Ident {
761 Ident::new(&format!("singleton_op_{:?}", node_id.data()), span)
762 }
763
764 fn helper_resolve_singletons(&self, node_id: GraphNodeId, span: Span) -> Vec<Ident> {
766 self.node_singleton_references(node_id)
767 .iter()
768 .map(|singleton_node_id| {
769 self.node_as_singleton_ident(
771 singleton_node_id
772 .expect("Expected singleton to be resolved but was not, this is a bug."),
773 span,
774 )
775 })
776 .collect::<Vec<_>>()
777 }
778
779 fn helper_collect_subgraph_handoffs(
782 &self,
783 ) -> SecondaryMap<GraphSubgraphId, (Vec<GraphNodeId>, Vec<GraphNodeId>)> {
784 let mut subgraph_handoffs: SecondaryMap<
786 GraphSubgraphId,
787 (Vec<GraphNodeId>, Vec<GraphNodeId>),
788 > = self
789 .subgraph_nodes
790 .keys()
791 .map(|k| (k, Default::default()))
792 .collect();
793
794 for (hoff_id, node) in self.nodes() {
796 if !matches!(node, GraphNode::Handoff { .. }) {
797 continue;
798 }
799 for (_edge, succ_id) in self.node_successors(hoff_id) {
801 let succ_sg = self.node_subgraph(succ_id).unwrap();
802 subgraph_handoffs[succ_sg].0.push(hoff_id);
803 }
804 for (_edge, pred_id) in self.node_predecessors(hoff_id) {
806 let pred_sg = self.node_subgraph(pred_id).unwrap();
807 subgraph_handoffs[pred_sg].1.push(hoff_id);
808 }
809 }
810
811 subgraph_handoffs
812 }
813
814 fn codegen_nested_loops(&self, df: &Ident) -> TokenStream {
816 let mut out = TokenStream::new();
818 let mut queue = VecDeque::from_iter(self.root_loops.iter().copied());
819 while let Some(loop_id) = queue.pop_front() {
820 let parent_opt = self
821 .loop_parent(loop_id)
822 .map(|loop_id| loop_id.as_ident(Span::call_site()))
823 .map(|ident| quote! { Some(#ident) })
824 .unwrap_or_else(|| quote! { None });
825 let loop_name = loop_id.as_ident(Span::call_site());
826 out.append_all(quote! {
827 let #loop_name = #df.add_loop(#parent_opt);
828 });
829 queue.extend(self.loop_children.get(loop_id).into_iter().flatten());
830 }
831 out
832 }
833
834 pub fn as_code(
836 &self,
837 root: &TokenStream,
838 include_type_guards: bool,
839 prefix: TokenStream,
840 diagnostics: &mut Vec<Diagnostic>,
841 ) -> TokenStream {
842 let df = Ident::new(GRAPH, Span::call_site());
843 let context = Ident::new(CONTEXT, Span::call_site());
844
845 let handoff_code = self
847 .nodes
848 .iter()
849 .filter_map(|(node_id, node)| match node {
850 GraphNode::Operator(_) => None,
851 &GraphNode::Handoff { src_span, dst_span } => Some((node_id, (src_span, dst_span))),
852 GraphNode::ModuleBoundary { .. } => panic!(),
853 })
854 .map(|(node_id, (src_span, dst_span))| {
855 let ident_send = Ident::new(&format!("hoff_{:?}_send", node_id.data()), dst_span);
856 let ident_recv = Ident::new(&format!("hoff_{:?}_recv", node_id.data()), src_span);
857 let span = src_span.join(dst_span).unwrap_or(src_span);
858 let mut hoff_name = Literal::string(&format!("handoff {:?}", node_id));
859 hoff_name.set_span(span);
860 let hoff_type = quote_spanned! (span=> #root::scheduled::handoff::VecHandoff<_>);
861 quote_spanned! {span=>
862 let (#ident_send, #ident_recv) =
863 #df.make_edge::<_, #hoff_type>(#hoff_name);
864 }
865 });
866
867 let subgraph_handoffs = self.helper_collect_subgraph_handoffs();
868
869 let (subgraphs_without_preds, subgraphs_with_preds) = self
871 .subgraph_nodes
872 .iter()
873 .partition::<Vec<_>, _>(|(_, nodes)| {
874 nodes
875 .iter()
876 .any(|&node_id| self.node_degree_in(node_id) == 0)
877 });
878
879 let mut op_prologue_code = Vec::new();
880 let mut op_prologue_after_code = Vec::new();
881 let mut subgraphs = Vec::new();
882 {
883 for &(subgraph_id, subgraph_nodes) in subgraphs_without_preds
884 .iter()
885 .chain(subgraphs_with_preds.iter())
886 {
887 let (recv_hoffs, send_hoffs) = &subgraph_handoffs[subgraph_id];
888 let recv_ports: Vec<Ident> = recv_hoffs
889 .iter()
890 .map(|&hoff_id| self.node_as_ident(hoff_id, true))
891 .collect();
892 let send_ports: Vec<Ident> = send_hoffs
893 .iter()
894 .map(|&hoff_id| self.node_as_ident(hoff_id, false))
895 .collect();
896
897 let recv_port_code = recv_ports.iter().map(|ident| {
898 quote_spanned! {ident.span()=>
899 let mut #ident = #ident.borrow_mut_swap();
900 let #ident = #ident.drain(..);
901 }
902 });
903 let send_port_code = send_ports.iter().map(|ident| {
904 quote_spanned! {ident.span()=>
905 let #ident = #root::sinktools::for_each(|v| {
906 #ident.give(Some(v));
907 });
908 }
909 });
910
911 let loop_id = self
912 .node_loop(subgraph_nodes[0]);
914
915 let mut subgraph_op_iter_code = Vec::new();
916 let mut subgraph_op_iter_after_code = Vec::new();
917 {
918 let pull_to_push_idx = self.find_pull_to_push_idx(subgraph_nodes);
919
920 let (pull_half, push_half) = subgraph_nodes.split_at(pull_to_push_idx);
921 let nodes_iter = pull_half.iter().chain(push_half.iter().rev());
922
923 for (idx, &node_id) in nodes_iter.enumerate() {
924 let node = &self.nodes[node_id];
925 assert!(
926 matches!(node, GraphNode::Operator(_)),
927 "Handoffs are not part of subgraphs."
928 );
929 let op_inst = &self.operator_instances[node_id];
930
931 let op_span = node.span();
932 let op_name = op_inst.op_constraints.name;
933 let root = change_spans(root.clone(), op_span);
935 let op_constraints = OPERATORS
937 .iter()
938 .find(|op| op_name == op.name)
939 .unwrap_or_else(|| panic!("Failed to find op: {}", op_name));
940
941 let ident = self.node_as_ident(node_id, false);
942
943 {
944 let mut input_edges = self
947 .graph
948 .predecessor_edges(node_id)
949 .map(|edge_id| (self.edge_ports(edge_id).1, edge_id))
950 .collect::<Vec<_>>();
951 input_edges.sort();
953
954 let inputs = input_edges
955 .iter()
956 .map(|&(_port, edge_id)| {
957 let (pred, _) = self.edge(edge_id);
958 self.node_as_ident(pred, true)
959 })
960 .collect::<Vec<_>>();
961
962 let mut output_edges = self
964 .graph
965 .successor_edges(node_id)
966 .map(|edge_id| (&self.ports[edge_id].0, edge_id))
967 .collect::<Vec<_>>();
968 output_edges.sort();
970
971 let outputs = output_edges
972 .iter()
973 .map(|&(_port, edge_id)| {
974 let (_, succ) = self.edge(edge_id);
975 self.node_as_ident(succ, false)
976 })
977 .collect::<Vec<_>>();
978
979 let is_pull = idx < pull_to_push_idx;
980
981 let singleton_output_ident = &if op_constraints.has_singleton_output {
982 self.node_as_singleton_ident(node_id, op_span)
983 } else {
984 Ident::new(&format!("{}_has_no_singleton_output", op_name), op_span)
986 };
987
988 let df_local = &Ident::new(GRAPH, op_span.resolved_at(df.span()));
997 let context = &Ident::new(CONTEXT, op_span.resolved_at(context.span()));
998
999 let singletons_resolved =
1000 self.helper_resolve_singletons(node_id, op_span);
1001 let arguments = &process_singletons::postprocess_singletons(
1002 op_inst.arguments_raw.clone(),
1003 singletons_resolved.clone(),
1004 context,
1005 );
1006 let arguments_handles =
1007 &process_singletons::postprocess_singletons_handles(
1008 op_inst.arguments_raw.clone(),
1009 singletons_resolved.clone(),
1010 );
1011
1012 let source_tag = 'a: {
1013 if let Some(tag) = self.operator_tag.get(node_id).cloned() {
1014 break 'a tag;
1015 }
1016
1017 #[cfg(nightly)]
1018 if proc_macro::is_available() {
1019 let op_span = op_span.unwrap();
1020 break 'a format!(
1021 "loc_{}_{}_{}_{}_{}",
1022 crate::pretty_span::make_source_path_relative(
1023 &op_span.file()
1024 )
1025 .display()
1026 .to_string()
1027 .replace(|x: char| !x.is_ascii_alphanumeric(), "_"),
1028 op_span.start().line(),
1029 op_span.start().column(),
1030 op_span.end().line(),
1031 op_span.end().column(),
1032 );
1033 }
1034
1035 format!(
1036 "loc_nopath_{}_{}_{}_{}",
1037 op_span.start().line,
1038 op_span.start().column,
1039 op_span.end().line,
1040 op_span.end().column
1041 )
1042 };
1043
1044 let work_fn = format_ident!(
1045 "{}__{}__{}",
1046 ident,
1047 op_name,
1048 source_tag,
1049 span = op_span
1050 );
1051
1052 let context_args = WriteContextArgs {
1053 root: &root,
1054 df_ident: df_local,
1055 context,
1056 subgraph_id,
1057 node_id,
1058 loop_id,
1059 op_span,
1060 op_tag: self.operator_tag.get(node_id).cloned(),
1061 work_fn: &work_fn,
1062 ident: &ident,
1063 is_pull,
1064 inputs: &inputs,
1065 outputs: &outputs,
1066 singleton_output_ident,
1067 op_name,
1068 op_inst,
1069 arguments,
1070 arguments_handles,
1071 };
1072
1073 let write_result =
1074 (op_constraints.write_fn)(&context_args, diagnostics);
1075 let OperatorWriteOutput {
1076 write_prologue,
1077 write_prologue_after,
1078 write_iterator,
1079 write_iterator_after,
1080 } = write_result.unwrap_or_else(|()| {
1081 assert!(
1082 diagnostics.iter().any(Diagnostic::is_error),
1083 "Operator `{}` returned `Err` but emitted no diagnostics, this is a bug.",
1084 op_name,
1085 );
1086 OperatorWriteOutput { write_iterator: null_write_iterator_fn(&context_args), ..Default::default() }
1087 });
1088
1089 op_prologue_code.push(syn::parse_quote! {
1090 #[allow(non_snake_case)]
1091 #[inline(always)]
1092 fn #work_fn<T>(thunk: impl FnOnce() -> T) -> T {
1093 thunk()
1094 }
1095 });
1096 op_prologue_code.push(write_prologue);
1097 op_prologue_after_code.push(write_prologue_after);
1098 subgraph_op_iter_code.push(write_iterator);
1099
1100 if include_type_guards {
1101 let type_guard = if is_pull {
1102 quote_spanned! {op_span=>
1103 let #ident = {
1104 #[allow(non_snake_case)]
1105 #[inline(always)]
1106 pub fn #work_fn<Item, Input: ::std::iter::Iterator<Item = Item>>(input: Input) -> impl ::std::iter::Iterator<Item = Item> {
1107 #[repr(transparent)]
1108 struct Pull<Item, Input: ::std::iter::Iterator<Item = Item>> {
1109 inner: Input
1110 }
1111
1112 impl<Item, Input: ::std::iter::Iterator<Item = Item>> Iterator for Pull<Item, Input> {
1113 type Item = Item;
1114
1115 #[inline(always)]
1116 fn next(&mut self) -> Option<Self::Item> {
1117 self.inner.next()
1118 }
1119
1120 #[inline(always)]
1121 fn size_hint(&self) -> (usize, Option<usize>) {
1122 self.inner.size_hint()
1123 }
1124 }
1125
1126 Pull {
1127 inner: input
1128 }
1129 }
1130 #work_fn( #ident )
1131 };
1132 }
1133 } else {
1134 quote_spanned! {op_span=>
1135 let #ident = {
1136 #[allow(non_snake_case)]
1137 #[inline(always)]
1138 pub fn #work_fn<Item, Si>(si: Si) -> impl #root::futures::sink::Sink<Item, Error = #root::Never>
1139 where
1140 Si: #root::futures::sink::Sink<Item, Error = #root::Never>
1141 {
1142 #root::pin_project_lite::pin_project! {
1143 #[repr(transparent)]
1144 struct Push<Si> {
1145 #[pin]
1146 si: Si,
1147 }
1148 }
1149 impl<Item, Si> #root::futures::sink::Sink<Item> for Push<Si>
1150 where
1151 Si: #root::futures::sink::Sink<Item>,
1152 {
1153 type Error = Si::Error;
1154
1155 fn poll_ready(
1156 self: ::std::pin::Pin<&mut Self>,
1157 cx: &mut ::std::task::Context<'_>,
1158 ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
1159 self.project().si.poll_ready(cx)
1160 }
1161
1162 fn start_send(
1163 self: ::std::pin::Pin<&mut Self>,
1164 item: Item,
1165 ) -> ::std::result::Result<(), Self::Error> {
1166 self.project().si.start_send(item)
1167 }
1168
1169 fn poll_flush(
1170 self: ::std::pin::Pin<&mut Self>,
1171 cx: &mut ::std::task::Context<'_>,
1172 ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
1173 self.project().si.poll_flush(cx)
1174 }
1175
1176 fn poll_close(
1177 self: ::std::pin::Pin<&mut Self>,
1178 cx: &mut ::std::task::Context<'_>,
1179 ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
1180 self.project().si.poll_close(cx)
1181 }
1182 }
1183
1184 Push {
1185 si
1186 }
1187 }
1188 #work_fn( #ident )
1189 };
1190 }
1191 };
1192 subgraph_op_iter_code.push(type_guard);
1193 }
1194 subgraph_op_iter_after_code.push(write_iterator_after);
1195 }
1196 }
1197
1198 {
1199 let pull_ident = if 0 < pull_to_push_idx {
1201 self.node_as_ident(subgraph_nodes[pull_to_push_idx - 1], false)
1202 } else {
1203 recv_ports[0].clone()
1205 };
1206
1207 #[rustfmt::skip]
1208 let push_ident = if let Some(&node_id) =
1209 subgraph_nodes.get(pull_to_push_idx)
1210 {
1211 self.node_as_ident(node_id, false)
1212 } else if 1 == send_ports.len() {
1213 send_ports[0].clone()
1215 } else {
1216 diagnostics.push(Diagnostic::spanned(
1217 pull_ident.span(),
1218 Level::Error,
1219 "Degenerate subgraph detected, is there a disconnected `null()` or other degenerate pipeline somewhere?",
1220 ));
1221 continue;
1222 };
1223
1224 let pivot_span = pull_ident
1226 .span()
1227 .join(push_ident.span())
1228 .unwrap_or_else(|| push_ident.span());
1229 let pivot_fn_ident =
1230 Ident::new(&format!("pivot_run_sg_{:?}", subgraph_id.0), pivot_span);
1231 let root = change_spans(root.clone(), pivot_span);
1232 subgraph_op_iter_code.push(quote_spanned! {pivot_span=>
1233 #[inline(always)]
1234 fn #pivot_fn_ident<Pull, Push, Item>(pull: Pull, push: Push)
1235 -> impl ::std::future::Future<Output = ::std::result::Result<(), #root::Never>>
1236 where
1237 Pull: ::std::iter::Iterator<Item = Item>,
1238 Push: #root::futures::sink::Sink<Item, Error = #root::Never>,
1239 {
1240 #root::sinktools::send_iter(pull, push)
1241 }
1242 (#pivot_fn_ident)(#pull_ident, #push_ident).await.unwrap();
1243 });
1244 }
1245 };
1246
1247 let subgraph_name = Literal::string(&format!("Subgraph {:?}", subgraph_id));
1248 let stratum = Literal::usize_unsuffixed(
1249 self.subgraph_stratum.get(subgraph_id).cloned().unwrap_or(0),
1250 );
1251 let laziness = self.subgraph_laziness(subgraph_id);
1252
1253 let loop_id_opt = loop_id
1255 .map(|loop_id| loop_id.as_ident(Span::call_site()))
1256 .map(|ident| quote! { Some(#ident) })
1257 .unwrap_or_else(|| quote! { None });
1258
1259 let sg_ident = subgraph_id.as_ident(Span::call_site());
1260
1261 subgraphs.push(quote! {
1262 let #sg_ident = #df.add_subgraph_full(
1263 #subgraph_name,
1264 #stratum,
1265 var_expr!( #( #recv_ports ),* ),
1266 var_expr!( #( #send_ports ),* ),
1267 #laziness,
1268 #loop_id_opt,
1269 async move |#context, var_args!( #( #recv_ports ),* ), var_args!( #( #send_ports ),* )| {
1270 #( #recv_port_code )*
1271 #( #send_port_code )*
1272 #( #subgraph_op_iter_code )*
1273 #( #subgraph_op_iter_after_code )*
1274 },
1275 );
1276 });
1277 }
1278 }
1279
1280 let loop_code = self.codegen_nested_loops(&df);
1281
1282 let code = quote! {
1287 #( #handoff_code )*
1288 #loop_code
1289 #( #op_prologue_code )*
1290 #( #subgraphs )*
1291 #( #op_prologue_after_code )*
1292 };
1293
1294 let meta_graph_json = serde_json::to_string(&self).unwrap();
1295 let meta_graph_json = Literal::string(&meta_graph_json);
1296
1297 let serde_diagnostics: Vec<_> = diagnostics.iter().map(Diagnostic::to_serde).collect();
1298 let diagnostics_json = serde_json::to_string(&*serde_diagnostics).unwrap();
1299 let diagnostics_json = Literal::string(&diagnostics_json);
1300
1301 quote! {
1302 {
1303 #[allow(unused_qualifications, clippy::await_holding_refcell_ref)]
1304 {
1305 #prefix
1306
1307 use #root::{var_expr, var_args};
1308
1309 let mut #df = #root::scheduled::graph::Dfir::new();
1310 #df.__assign_meta_graph(#meta_graph_json);
1311 #df.__assign_diagnostics(#diagnostics_json);
1312
1313 #code
1314
1315 #df
1316 }
1317 }
1318 }
1319 }
1320
1321 pub fn node_color_map(&self) -> SparseSecondaryMap<GraphNodeId, Color> {
1324 let mut node_color_map: SparseSecondaryMap<GraphNodeId, Color> = self
1325 .node_ids()
1326 .filter_map(|node_id| {
1327 let op_color = self.node_color(node_id)?;
1328 Some((node_id, op_color))
1329 })
1330 .collect();
1331
1332 for sg_nodes in self.subgraph_nodes.values() {
1334 let pull_to_push_idx = self.find_pull_to_push_idx(sg_nodes);
1335
1336 for (idx, node_id) in sg_nodes.iter().copied().enumerate() {
1337 let is_pull = idx < pull_to_push_idx;
1338 node_color_map.insert(node_id, if is_pull { Color::Pull } else { Color::Push });
1339 }
1340 }
1341
1342 node_color_map
1343 }
1344
1345 pub fn to_mermaid(&self, write_config: &WriteConfig) -> String {
1347 let mut output = String::new();
1348 self.write_mermaid(&mut output, write_config).unwrap();
1349 output
1350 }
1351
1352 pub fn write_mermaid(
1354 &self,
1355 output: impl std::fmt::Write,
1356 write_config: &WriteConfig,
1357 ) -> std::fmt::Result {
1358 let mut graph_write = Mermaid::new(output);
1359 self.write_graph(&mut graph_write, write_config)
1360 }
1361
1362 pub fn to_dot(&self, write_config: &WriteConfig) -> String {
1364 let mut output = String::new();
1365 let mut graph_write = Dot::new(&mut output);
1366 self.write_graph(&mut graph_write, write_config).unwrap();
1367 output
1368 }
1369
1370 pub fn write_dot(
1372 &self,
1373 output: impl std::fmt::Write,
1374 write_config: &WriteConfig,
1375 ) -> std::fmt::Result {
1376 let mut graph_write = Dot::new(output);
1377 self.write_graph(&mut graph_write, write_config)
1378 }
1379
1380 pub(crate) fn write_graph<W>(
1382 &self,
1383 mut graph_write: W,
1384 write_config: &WriteConfig,
1385 ) -> Result<(), W::Err>
1386 where
1387 W: GraphWrite,
1388 {
1389 fn helper_edge_label(
1390 src_port: &PortIndexValue,
1391 dst_port: &PortIndexValue,
1392 ) -> Option<String> {
1393 let src_label = match src_port {
1394 PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
1395 PortIndexValue::Int(index) => Some(index.value.to_string()),
1396 _ => None,
1397 };
1398 let dst_label = match dst_port {
1399 PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
1400 PortIndexValue::Int(index) => Some(index.value.to_string()),
1401 _ => None,
1402 };
1403 let label = match (src_label, dst_label) {
1404 (Some(l1), Some(l2)) => Some(format!("{}\n{}", l1, l2)),
1405 (Some(l1), None) => Some(l1),
1406 (None, Some(l2)) => Some(l2),
1407 (None, None) => None,
1408 };
1409 label
1410 }
1411
1412 let node_color_map = self.node_color_map();
1414
1415 graph_write.write_prologue()?;
1417
1418 let mut skipped_handoffs = BTreeSet::new();
1420 let mut subgraph_handoffs = <BTreeMap<GraphSubgraphId, Vec<GraphNodeId>>>::new();
1421 for (node_id, node) in self.nodes() {
1422 if matches!(node, GraphNode::Handoff { .. }) {
1423 if write_config.no_handoffs {
1424 skipped_handoffs.insert(node_id);
1425 continue;
1426 } else {
1427 let pred_node = self.node_predecessor_nodes(node_id).next().unwrap();
1428 let pred_sg = self.node_subgraph(pred_node);
1429 let succ_node = self.node_successor_nodes(node_id).next().unwrap();
1430 let succ_sg = self.node_subgraph(succ_node);
1431 if let Some((pred_sg, succ_sg)) = pred_sg.zip(succ_sg)
1432 && pred_sg == succ_sg
1433 {
1434 subgraph_handoffs.entry(pred_sg).or_default().push(node_id);
1435 }
1436 }
1437 }
1438 graph_write.write_node_definition(
1439 node_id,
1440 &if write_config.op_short_text {
1441 node.to_name_string()
1442 } else if write_config.op_text_no_imports {
1443 let full_text = node.to_pretty_string();
1445 let mut output = String::new();
1446 for sentence in full_text.split('\n') {
1447 if sentence.trim().starts_with("use") {
1448 continue;
1449 }
1450 output.push('\n');
1451 output.push_str(sentence);
1452 }
1453 output.into()
1454 } else {
1455 node.to_pretty_string()
1456 },
1457 if write_config.no_pull_push {
1458 None
1459 } else {
1460 node_color_map.get(node_id).copied()
1461 },
1462 )?;
1463 }
1464
1465 for (edge_id, (src_id, mut dst_id)) in self.edges() {
1467 if skipped_handoffs.contains(&src_id) {
1469 continue;
1470 }
1471
1472 let (src_port, mut dst_port) = self.edge_ports(edge_id);
1473 if skipped_handoffs.contains(&dst_id) {
1474 let mut handoff_succs = self.node_successors(dst_id);
1475 assert_eq!(1, handoff_succs.len());
1476 let (succ_edge, succ_node) = handoff_succs.next().unwrap();
1477 dst_id = succ_node;
1478 dst_port = self.edge_ports(succ_edge).1;
1479 }
1480
1481 let label = helper_edge_label(src_port, dst_port);
1482 let delay_type = self
1483 .node_op_inst(dst_id)
1484 .and_then(|op_inst| (op_inst.op_constraints.input_delaytype_fn)(dst_port));
1485 graph_write.write_edge(src_id, dst_id, delay_type, label.as_deref(), false)?;
1486 }
1487
1488 if !write_config.no_references {
1490 for dst_id in self.node_ids() {
1491 for src_ref_id in self
1492 .node_singleton_references(dst_id)
1493 .iter()
1494 .copied()
1495 .flatten()
1496 {
1497 let delay_type = Some(DelayType::Stratum);
1498 let label = None;
1499 graph_write.write_edge(src_ref_id, dst_id, delay_type, label, true)?;
1500 }
1501 }
1502 }
1503
1504 let loop_subgraphs = self.subgraph_ids().map(|sg_id| {
1515 let loop_id = if write_config.no_loops {
1516 None
1517 } else {
1518 self.subgraph_loop(sg_id)
1519 };
1520 (loop_id, sg_id)
1521 });
1522 let loop_subgraphs = into_group_map(loop_subgraphs);
1523 for (loop_id, subgraph_ids) in loop_subgraphs {
1524 if let Some(loop_id) = loop_id {
1525 graph_write.write_loop_start(loop_id)?;
1526 }
1527
1528 let subgraph_varnames_nodes = subgraph_ids.into_iter().flat_map(|sg_id| {
1530 self.subgraph(sg_id).iter().copied().map(move |node_id| {
1531 let opt_sg_id = if write_config.no_subgraphs {
1532 None
1533 } else {
1534 Some(sg_id)
1535 };
1536 (opt_sg_id, (self.node_varname(node_id), node_id))
1537 })
1538 });
1539 let subgraph_varnames_nodes = into_group_map(subgraph_varnames_nodes);
1540 for (sg_id, varnames) in subgraph_varnames_nodes {
1541 if let Some(sg_id) = sg_id {
1542 let stratum = self.subgraph_stratum(sg_id).unwrap();
1543 graph_write.write_subgraph_start(sg_id, stratum)?;
1544 }
1545
1546 let varname_nodes = varnames.into_iter().map(|(varname, node)| {
1548 let varname = if write_config.no_varnames {
1549 None
1550 } else {
1551 varname
1552 };
1553 (varname, node)
1554 });
1555 let varname_nodes = into_group_map(varname_nodes);
1556 for (varname, node_ids) in varname_nodes {
1557 if let Some(varname) = varname {
1558 graph_write.write_varname_start(&varname.0.to_string(), sg_id)?;
1559 }
1560
1561 for node_id in node_ids {
1563 graph_write.write_node(node_id)?;
1564 }
1565
1566 if varname.is_some() {
1567 graph_write.write_varname_end()?;
1568 }
1569 }
1570
1571 if sg_id.is_some() {
1572 graph_write.write_subgraph_end()?;
1573 }
1574 }
1575
1576 if loop_id.is_some() {
1577 graph_write.write_loop_end()?;
1578 }
1579 }
1580
1581 graph_write.write_epilogue()?;
1583
1584 Ok(())
1585 }
1586
1587 pub fn surface_syntax_string(&self) -> String {
1589 let mut string = String::new();
1590 self.write_surface_syntax(&mut string).unwrap();
1591 string
1592 }
1593
1594 pub fn write_surface_syntax(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
1596 for (key, node) in self.nodes.iter() {
1597 match node {
1598 GraphNode::Operator(op) => {
1599 writeln!(write, "{:?} = {};", key.data(), op.to_token_stream())?;
1600 }
1601 GraphNode::Handoff { .. } => {
1602 writeln!(write, "// {:?} = <handoff>;", key.data())?;
1603 }
1604 GraphNode::ModuleBoundary { .. } => panic!(),
1605 }
1606 }
1607 writeln!(write)?;
1608 for (_e, (src_key, dst_key)) in self.graph.edges() {
1609 writeln!(write, "{:?} -> {:?};", src_key.data(), dst_key.data())?;
1610 }
1611 Ok(())
1612 }
1613
1614 pub fn mermaid_string_flat(&self) -> String {
1616 let mut string = String::new();
1617 self.write_mermaid_flat(&mut string).unwrap();
1618 string
1619 }
1620
1621 pub fn write_mermaid_flat(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
1623 writeln!(write, "flowchart TB")?;
1624 for (key, node) in self.nodes.iter() {
1625 match node {
1626 GraphNode::Operator(operator) => writeln!(
1627 write,
1628 " %% {span}\n {id:?}[\"{row_col} <tt>{code}</tt>\"]",
1629 span = PrettySpan(node.span()),
1630 id = key.data(),
1631 row_col = PrettyRowCol(node.span()),
1632 code = operator
1633 .to_token_stream()
1634 .to_string()
1635 .replace('&', "&")
1636 .replace('<', "<")
1637 .replace('>', ">")
1638 .replace('"', """)
1639 .replace('\n', "<br>"),
1640 ),
1641 GraphNode::Handoff { .. } => {
1642 writeln!(write, r#" {:?}{{"{}"}}"#, key.data(), HANDOFF_NODE_STR)
1643 }
1644 GraphNode::ModuleBoundary { .. } => {
1645 writeln!(
1646 write,
1647 r#" {:?}{{"{}"}}"#,
1648 key.data(),
1649 MODULE_BOUNDARY_NODE_STR
1650 )
1651 }
1652 }?;
1653 }
1654 writeln!(write)?;
1655 for (_e, (src_key, dst_key)) in self.graph.edges() {
1656 writeln!(write, " {:?}-->{:?}", src_key.data(), dst_key.data())?;
1657 }
1658 Ok(())
1659 }
1660}
1661
1662impl DfirGraph {
1664 pub fn loop_ids(&self) -> slotmap::basic::Keys<'_, GraphLoopId, Vec<GraphNodeId>> {
1666 self.loop_nodes.keys()
1667 }
1668
1669 pub fn loops(&self) -> slotmap::basic::Iter<'_, GraphLoopId, Vec<GraphNodeId>> {
1671 self.loop_nodes.iter()
1672 }
1673
1674 pub fn insert_loop(&mut self, parent_loop: Option<GraphLoopId>) -> GraphLoopId {
1676 let loop_id = self.loop_nodes.insert(Vec::new());
1677 self.loop_children.insert(loop_id, Vec::new());
1678 if let Some(parent_loop) = parent_loop {
1679 self.loop_parent.insert(loop_id, parent_loop);
1680 self.loop_children
1681 .get_mut(parent_loop)
1682 .unwrap()
1683 .push(loop_id);
1684 } else {
1685 self.root_loops.push(loop_id);
1686 }
1687 loop_id
1688 }
1689
1690 pub fn node_loop(&self, node_id: GraphNodeId) -> Option<GraphLoopId> {
1692 self.node_loops.get(node_id).copied()
1693 }
1694
1695 pub fn subgraph_loop(&self, subgraph_id: GraphSubgraphId) -> Option<GraphLoopId> {
1697 let &node_id = self.subgraph(subgraph_id).first().unwrap();
1698 let out = self.node_loop(node_id);
1699 debug_assert!(
1700 self.subgraph(subgraph_id)
1701 .iter()
1702 .all(|&node_id| self.node_loop(node_id) == out),
1703 "Subgraph nodes should all have the same loop context."
1704 );
1705 out
1706 }
1707
1708 pub fn loop_parent(&self, loop_id: GraphLoopId) -> Option<GraphLoopId> {
1710 self.loop_parent.get(loop_id).copied()
1711 }
1712
1713 pub fn loop_children(&self, loop_id: GraphLoopId) -> &Vec<GraphLoopId> {
1715 self.loop_children.get(loop_id).unwrap()
1716 }
1717}
1718
/// Options controlling what `DfirGraph::write_graph` (and the Mermaid/DOT helpers built on it)
/// includes in its output. Every flag defaults to `false`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(feature = "clap-derive", derive(clap::Args))]
pub struct WriteConfig {
    /// Do not group nodes into subgraphs.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_subgraphs: bool,
    /// Do not group nodes by variable name.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_varnames: bool,
    /// Do not pass pull/push node colors to the graph writer.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_pull_push: bool,
    /// Do not write handoff nodes; edges into a handoff go straight to its successor.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_handoffs: bool,
    /// Do not write singleton reference edges.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_references: bool,
    /// Do not group subgraphs into loops.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_loops: bool,

    /// Label each node with its short name instead of its full pretty-printed text.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub op_short_text: bool,
    /// Leave `use` lines out of each node's pretty-printed text.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub op_text_no_imports: bool,
}
1749
/// Selects which graph output format to write.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "clap-derive", derive(clap::Parser, clap::ValueEnum))]
pub enum WriteGraphType {
    /// Mermaid output.
    Mermaid,
    /// DOT output.
    Dot,
}
1759
/// Groups `(key, value)` pairs by key into a [`BTreeMap`].
///
/// Keys come out in `Ord` order; the values under each key keep the order in which they were
/// yielded by `iter`.
fn into_group_map<K, V>(iter: impl IntoIterator<Item = (K, V)>) -> BTreeMap<K, Vec<V>>
where
    K: Ord,
{
    iter.into_iter()
        .fold(BTreeMap::<K, Vec<V>>::new(), |mut groups, (key, value)| {
            groups.entry(key).or_insert_with(Vec::new).push(value);
            groups
        })
}