use std::collections::HashSet; use std::hash::Hash; #[derive(Debug, thiserror::Error)] #[error("Cycle detected")] pub struct TopoSortCycle; pub fn toposort( input: impl IntoIterator, deps: impl Fn(&K) -> I, ) -> Result, TopoSortCycle> where K: Eq + Hash + Clone, I: Iterator, { struct Ts where K: Eq + Hash + Clone, I: Iterator, D: Fn(&K) -> I, { result_set: HashSet, result: Vec, deps: D, stack: HashSet, } impl Ts where K: Eq + Hash + Clone, I: Iterator, D: Fn(&K) -> I, { fn visit(&mut self, i: &K) -> Result<(), TopoSortCycle> { if self.result_set.contains(i) { return Ok(()); } if !self.stack.insert(i.clone()) { return Err(TopoSortCycle); } for dep in (self.deps)(i) { self.visit(&dep)?; } let removed = self.stack.remove(i); assert!(removed); self.result.push(i.clone()); self.result_set.insert(i.clone()); Ok(()) } } let mut ts = Ts { result: Vec::new(), result_set: HashSet::new(), deps, stack: HashSet::new(), }; for i in input { ts.visit(&i)?; } Ok(ts.result) } #[cfg(test)] mod tests { use std::collections::HashMap; use crate::toposort::toposort; use crate::toposort::TopoSortCycle; fn test_toposort(input: &str) -> Result, TopoSortCycle> { let mut keys: Vec<&str> = Vec::new(); let mut edges: HashMap<&str, Vec<&str>> = HashMap::new(); for part in input.split(" ") { match part.split_once("->") { Some((k, vs)) => { keys.push(k); edges.insert(k, vs.split(",").collect()); } None => keys.push(part), }; } toposort(keys, |k| { edges .get(k) .map(|v| v.as_slice()) .unwrap_or_default() .into_iter() .copied() }) } fn test_toposort_check(input: &str, expected: &str) { let sorted = test_toposort(input).unwrap(); let expected = expected.split(" ").collect::>(); assert_eq!(expected, sorted); } #[test] fn test() { test_toposort_check("1 2 3", "1 2 3"); test_toposort_check("1->2 2->3 3", "3 2 1"); test_toposort_check("1 2->1 3->2", "1 2 3"); test_toposort_check("1->2,3 2->3 3", "3 2 1"); } #[test] fn cycle() { assert!(test_toposort("1->1").is_err()); assert!(test_toposort("1->2 2->1").is_err()); } }