diff --git a/src/uu/tsort/src/tsort.rs b/src/uu/tsort/src/tsort.rs index 85380bf403b..0ae78e15306 100644 --- a/src/uu/tsort/src/tsort.rs +++ b/src/uu/tsort/src/tsort.rs @@ -4,7 +4,8 @@ // file that was distributed with this source code. //spell-checker:ignore TAOCP indegree use clap::{Arg, Command}; -use std::collections::{HashMap, HashSet, VecDeque}; +use std::collections::hash_map::Entry; +use std::collections::{HashMap, VecDeque}; use std::ffi::OsString; use std::path::Path; use thiserror::Error; @@ -34,13 +35,15 @@ enum TsortError { /// The graph contains a cycle. #[error("{input}: {message}", input = .0, message = translate!("tsort-error-loop"))] Loop(String), - - /// A particular node in a cycle. (This is mainly used for printing.) - #[error("{0}")] - LoopNode(String), } +// Auxiliary struct, just for printing loop nodes via show! macro +#[derive(Debug, Error)] +#[error("{0}")] +struct LoopNode<'a>(&'a str); + impl UError for TsortError {} +impl UError for LoopNode<'_> {} #[uucore::main] pub fn uumain(args: impl uucore::Args) -> UResult<()> { @@ -131,6 +134,12 @@ struct Graph<'input> { nodes: HashMap<&'input str, Node<'input>>, } +#[derive(Clone, Copy, PartialEq, Eq)] +enum VisitedState { + Opened, + Closed, +} + impl<'input> Graph<'input> { fn new(name: String) -> Graph<'input> { Self { @@ -224,8 +233,8 @@ impl<'input> Graph<'input> { fn find_and_break_cycle(&mut self, frontier: &mut VecDeque<&'input str>) { let cycle = self.detect_cycle(); show!(TsortError::Loop(self.name.clone())); - for node in &cycle { - show!(TsortError::LoopNode((*node).to_string())); + for &node in &cycle { + show!(LoopNode(node)); } let u = cycle[0]; let v = cycle[1]; @@ -240,41 +249,76 @@ impl<'input> Graph<'input> { let mut nodes: Vec<_> = self.nodes.keys().collect(); nodes.sort_unstable(); - let mut visited = HashSet::new(); + let mut visited = HashMap::new(); let mut stack = Vec::with_capacity(self.nodes.len()); for node in nodes { - if !visited.contains(node) && self.dfs(node, &mut visited, &mut stack) { - return stack; + if self.dfs(node, &mut visited, &mut stack) { + // last element in the stack appears twice: at the begin + // and at the end of the loop + let (loop_entry, _) = stack.pop().expect("loop is not empty"); + + // skip the prefix which doesn't belong to the loop + return stack + .into_iter() + .map(|(node, _)| node) + .skip_while(|&node| node != loop_entry) + .collect(); } } - unreachable!(); + unreachable!("detect_cycle is expected to be called only on graphs with cycles"); } - fn dfs( - &self, + fn dfs<'a>( + &'a self, node: &'input str, - visited: &mut HashSet<&'input str>, - stack: &mut Vec<&'input str>, + visited: &mut HashMap<&'input str, VisitedState>, + stack: &mut Vec<(&'input str, &'a [&'input str])>, ) -> bool { - if stack.contains(&node) { - return true; - } - if visited.contains(&node) { + stack.push(( + node, + self.nodes.get(node).map_or(&[], |n| &n.successor_names), + )); + let state = *visited.entry(node).or_insert(VisitedState::Opened); + + if state == VisitedState::Closed { return false; } - visited.insert(node); - stack.push(node); - - if let Some(successor_names) = self.nodes.get(node).map(|n| &n.successor_names) { - for &successor in successor_names { - if self.dfs(successor, visited, stack) { - return true; + while let Some((node, pending_successors)) = stack.pop() { + let Some((&next_node, pending)) = pending_successors.split_first() else { + // no more pending successors in the list -> close the node + visited.insert(node, VisitedState::Closed); + continue; + }; + + // schedule processing for the pending part of successors for this node + stack.push((node, pending)); + + match visited.entry(next_node) { + Entry::Vacant(v) => { + // It's a first time we enter this node + v.insert(VisitedState::Opened); + stack.push(( + next_node, + self.nodes + .get(next_node) + .map_or(&[], |n| &n.successor_names), + )); + } + Entry::Occupied(o) => { + if *o.get() == VisitedState::Opened { + // we are entering the same opened node again -> loop found + // stack contains it + // + // But part of the stack may not be belonging to this loop + // push found node to the stack to be able to trace the beginning of the loop + stack.push((next_node, &[])); + return true; + } } } } - stack.pop(); false } } diff --git a/tests/by-util/test_tsort.rs b/tests/by-util/test_tsort.rs index eb1a8630d31..9033ea4a7ef 100644 --- a/tests/by-util/test_tsort.rs +++ b/tests/by-util/test_tsort.rs @@ -122,3 +122,34 @@ fn test_two_cycles() { .stdout_is("a\nc\nd\nb\n") .stderr_is("tsort: -: input contains a loop:\ntsort: b\ntsort: c\ntsort: -: input contains a loop:\ntsort: b\ntsort: d\n"); } + +#[test] +fn test_long_loop_no_stack_overflow() { + use std::fmt::Write; + const N: usize = 100_000; + let mut input = String::new(); + for v in 0..N { + let next = (v + 1) % N; + let _ = write!(input, "{v} {next} "); + } + new_ucmd!() + .pipe_in(input) + .fails_with_code(1) + .stderr_contains("tsort: -: input contains a loop"); +} + +#[test] +fn test_loop_for_iterative_dfs_correctness() { + let input = r" + A B + B C + C B + C D + D A + "; + + new_ucmd!() + .pipe_in(input) + .fails_with_code(1) + .stderr_contains("tsort: -: input contains a loop:\ntsort: B\ntsort: C"); +}