Skip to main content

cranelift_isle/
recursion.rs

1//! Recursion checking for ISLE terms.
2
3use std::collections::{HashMap, HashSet};
4
5use crate::{
6    error::{Error, Span},
7    sema::{TermEnv, TermId},
8    trie_again::{Binding, RuleSet},
9};
10
11/// Check for recursive terms.
12pub fn check(terms: &[(TermId, RuleSet)], termenv: &TermEnv) -> Result<(), Vec<Error>> {
13    // Search for cycles in the term dependency graph.
14    let cyclic_terms = terms_in_cycles(terms);
15
16    // Cyclic terms should be explicitly permitted with the `rec` attribute.
17    let mut errors = Vec::new();
18    for term_id in cyclic_terms {
19        // Error if term is not explicitly marked recursive.
20        let term = &termenv.terms[term_id.index()];
21        if !term.is_recursive() {
22            errors.push(Error::RecursionError {
23                msg: "Term is recursive but does not have the `rec` attribute".to_string(),
24                span: Span::new_single(term.decl_pos),
25            });
26        }
27    }
28
29    if errors.is_empty() {
30        Ok(())
31    } else {
32        Err(errors)
33    }
34}
35
36// Find terms that are in cycles in the term dependency graph.
37fn terms_in_cycles(terms: &[(TermId, RuleSet)]) -> HashSet<TermId> {
38    // Construct term dependency graph.
39    let edges: HashMap<TermId, HashSet<TermId>> = terms
40        .iter()
41        .map(|(term_id, rule_set)| (*term_id, terms_in_rule_set(rule_set)))
42        .collect();
43
44    // Depth-first search with a stack.
45    enum Event {
46        Enter(TermId),
47        Exit(TermId),
48    }
49    let mut stack = Vec::from_iter(edges.keys().copied().map(Event::Enter));
50
51    // State of each term.
52    enum State {
53        Visiting,
54        Visited,
55    }
56    let mut states = HashMap::new();
57
58    // Maintain current path.
59    let mut path = Vec::new();
60
61    // Collect terms that are in cycles.
62    let mut in_cycle = HashSet::new();
63
64    // Process DFS stack.
65    while let Some(event) = stack.pop() {
66        match event {
67            Event::Enter(term_id) => match states.get(&term_id) {
68                None => {
69                    states.insert(term_id, State::Visiting);
70                    path.push(term_id);
71                    stack.push(Event::Exit(term_id));
72                    if let Some(deps) = edges.get(&term_id) {
73                        for dep in deps {
74                            stack.push(Event::Enter(*dep));
75                        }
76                    }
77                }
78                Some(State::Visiting) => {
79                    // Cycle detected. Reconstruct the cycle from path.
80                    let begin = path
81                        .iter()
82                        .rposition(|&t| t == term_id)
83                        .expect("cycle origin should be in path");
84                    in_cycle.extend(&path[begin..]);
85                }
86                Some(State::Visited) => {}
87            },
88            Event::Exit(term_id) => {
89                states.insert(term_id, State::Visited);
90                let last = path.pop().expect("exit with empty path");
91                debug_assert_eq!(last, term_id, "exit term does not match last path term");
92            }
93        }
94    }
95
96    debug_assert!(path.is_empty(), "search finished with non-empty path");
97
98    in_cycle
99}
100
101fn terms_in_rule_set(rule_set: &RuleSet) -> HashSet<TermId> {
102    rule_set
103        .bindings
104        .iter()
105        .filter_map(binding_used_term)
106        .collect()
107}
108
109fn binding_used_term(binding: &Binding) -> Option<TermId> {
110    match binding {
111        Binding::Constructor { term, .. } | Binding::Extractor { term, .. } => Some(*term),
112        _ => None,
113    }
114}