• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use std::collections::HashSet;
2 use std::hash::Hash;
3 
4 #[derive(Debug, thiserror::Error)]
5 #[error("Cycle detected")]
6 pub struct TopoSortCycle;
7 
toposort<K, I>( input: impl IntoIterator<Item = K>, deps: impl Fn(&K) -> I, ) -> Result<Vec<K>, TopoSortCycle> where K: Eq + Hash + Clone, I: Iterator<Item = K>,8 pub fn toposort<K, I>(
9     input: impl IntoIterator<Item = K>,
10     deps: impl Fn(&K) -> I,
11 ) -> Result<Vec<K>, TopoSortCycle>
12 where
13     K: Eq + Hash + Clone,
14     I: Iterator<Item = K>,
15 {
16     struct Ts<K, D, I>
17     where
18         K: Eq + Hash + Clone,
19         I: Iterator<Item = K>,
20         D: Fn(&K) -> I,
21     {
22         result_set: HashSet<K>,
23         result: Vec<K>,
24         deps: D,
25         stack: HashSet<K>,
26     }
27 
28     impl<K, D, I> Ts<K, D, I>
29     where
30         K: Eq + Hash + Clone,
31         I: Iterator<Item = K>,
32         D: Fn(&K) -> I,
33     {
34         fn visit(&mut self, i: &K) -> Result<(), TopoSortCycle> {
35             if self.result_set.contains(i) {
36                 return Ok(());
37             }
38 
39             if !self.stack.insert(i.clone()) {
40                 return Err(TopoSortCycle);
41             }
42             for dep in (self.deps)(i) {
43                 self.visit(&dep)?;
44             }
45 
46             let removed = self.stack.remove(i);
47             assert!(removed);
48 
49             self.result.push(i.clone());
50             self.result_set.insert(i.clone());
51 
52             Ok(())
53         }
54     }
55 
56     let mut ts = Ts {
57         result: Vec::new(),
58         result_set: HashSet::new(),
59         deps,
60         stack: HashSet::new(),
61     };
62 
63     for i in input {
64         ts.visit(&i)?;
65     }
66 
67     Ok(ts.result)
68 }
69 
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use crate::toposort::toposort;
    use crate::toposort::TopoSortCycle;

    /// Parses a space-separated graph description and topologically sorts it.
    ///
    /// Each token is either a bare key (`a`) or a key with its comma-separated
    /// dependencies (`a->b,c`). Keys are passed to [`toposort`] in the order
    /// they appear; keys with no `->` entry have no dependencies.
    fn test_toposort(input: &str) -> Result<Vec<&str>, TopoSortCycle> {
        let mut keys: Vec<&str> = Vec::new();
        let mut edges: HashMap<&str, Vec<&str>> = HashMap::new();
        for part in input.split(' ') {
            match part.split_once("->") {
                Some((k, vs)) => {
                    keys.push(k);
                    edges.insert(k, vs.split(',').collect());
                }
                None => keys.push(part),
            }
        }

        toposort(keys, |k| {
            edges
                .get(k)
                .map(|v| v.as_slice())
                .unwrap_or_default()
                .iter()
                .copied()
        })
    }

    /// Asserts that sorting `input` succeeds and yields the space-separated
    /// `expected` key order.
    fn test_toposort_check(input: &str, expected: &str) {
        let sorted = test_toposort(input).unwrap();
        let expected = expected.split(' ').collect::<Vec<_>>();
        assert_eq!(expected, sorted);
    }

    #[test]
    fn test() {
        test_toposort_check("1 2 3", "1 2 3");
        test_toposort_check("1->2 2->3 3", "3 2 1");
        test_toposort_check("1 2->1 3->2", "1 2 3");
        test_toposort_check("1->2,3 2->3 3", "3 2 1");
    }

    #[test]
    fn diamond() {
        test_toposort_check("1->2,3 2->4 3->4 4", "4 2 3 1");
    }

    #[test]
    fn unlisted_dependency_is_included() {
        // `2` is only reachable as a dependency of `1`, never given as input.
        test_toposort_check("1->2", "2 1");
    }

    #[test]
    fn duplicates_are_emitted_once() {
        test_toposort_check("1 1 2->1 2", "1 2");
    }

    #[test]
    fn empty_input() {
        let sorted = toposort(Vec::<u32>::new(), |_| std::iter::empty::<u32>()).unwrap();
        assert!(sorted.is_empty());
    }

    #[test]
    fn cycle() {
        assert!(test_toposort("1->1").is_err());
        assert!(test_toposort("1->2 2->1").is_err());
        assert!(test_toposort("1->2 2->3 3->1").is_err());
        // A cycle reachable only through an acyclic prefix is still detected.
        assert!(test_toposort("0->1 1->2 2->1").is_err());
    }
}
120