Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 71 additions & 27 deletions src/uu/tsort/src/tsort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
// file that was distributed with this source code.
//spell-checker:ignore TAOCP indegree
use clap::{Arg, Command};
use std::collections::{HashMap, HashSet, VecDeque};
use std::collections::hash_map::Entry;
use std::collections::{HashMap, VecDeque};
use std::ffi::OsString;
use std::path::Path;
use thiserror::Error;
Expand Down Expand Up @@ -34,13 +35,15 @@ enum TsortError {
/// The graph contains a cycle.
#[error("{input}: {message}", input = .0, message = translate!("tsort-error-loop"))]
Loop(String),

/// A particular node in a cycle. (This is mainly used for printing.)
#[error("{0}")]
LoopNode(String),
}

/// Auxiliary error type used only to print a single node of a detected cycle
/// via the `show!` macro.
///
/// It borrows the node name, and its `Display` output is that name verbatim.
#[derive(Debug, Error)]
#[error("{0}")]
struct LoopNode<'a>(&'a str);

impl UError for TsortError {}
impl UError for LoopNode<'_> {}

#[uucore::main]
pub fn uumain(args: impl uucore::Args) -> UResult<()> {
Expand Down Expand Up @@ -131,6 +134,12 @@ struct Graph<'input> {
nodes: HashMap<&'input str, Node<'input>>,
}

/// Visiting state of a node during the depth-first search for a cycle.
#[derive(Clone, Copy, PartialEq, Eq)]
enum VisitedState {
/// The node has been entered, but its successors are not all processed yet.
Opened,
/// The node has no more pending successors left to process.
Closed,
}

impl<'input> Graph<'input> {
fn new(name: String) -> Graph<'input> {
Self {
Expand Down Expand Up @@ -224,8 +233,8 @@ impl<'input> Graph<'input> {
fn find_and_break_cycle(&mut self, frontier: &mut VecDeque<&'input str>) {
let cycle = self.detect_cycle();
show!(TsortError::Loop(self.name.clone()));
for node in &cycle {
show!(TsortError::LoopNode((*node).to_string()));
for &node in &cycle {
show!(LoopNode(node));
}
let u = cycle[0];
let v = cycle[1];
Expand All @@ -240,41 +249,76 @@ impl<'input> Graph<'input> {
let mut nodes: Vec<_> = self.nodes.keys().collect();
nodes.sort_unstable();

let mut visited = HashSet::new();
let mut visited = HashMap::new();
let mut stack = Vec::with_capacity(self.nodes.len());
for node in nodes {
if !visited.contains(node) && self.dfs(node, &mut visited, &mut stack) {
return stack;
if self.dfs(node, &mut visited, &mut stack) {
// the last element in the stack appears twice: at the beginning
// and at the end of the loop
let (loop_entry, _) = stack.pop().expect("loop is not empty");

// skip the prefix which doesn't belong to the loop
return stack
.into_iter()
.map(|(node, _)| node)
.skip_while(|&node| node != loop_entry)
.collect();
}
}
unreachable!();
unreachable!("detect_cycle is expected to be called only on graphs with cycles");
}

fn dfs(
&self,
fn dfs<'a>(
&'a self,
node: &'input str,
visited: &mut HashSet<&'input str>,
stack: &mut Vec<&'input str>,
visited: &mut HashMap<&'input str, VisitedState>,
stack: &mut Vec<(&'input str, &'a [&'input str])>,
) -> bool {
if stack.contains(&node) {
return true;
}
if visited.contains(&node) {
stack.push((
node,
self.nodes.get(node).map_or(&[], |n| &n.successor_names),
));
let state = *visited.entry(node).or_insert(VisitedState::Opened);

if state == VisitedState::Closed {
return false;
}

visited.insert(node);
stack.push(node);

if let Some(successor_names) = self.nodes.get(node).map(|n| &n.successor_names) {
for &successor in successor_names {
if self.dfs(successor, visited, stack) {
return true;
while let Some((node, pending_successors)) = stack.pop() {
let Some((&next_node, pending)) = pending_successors.split_first() else {
// no more pending successors in the list -> close the node
visited.insert(node, VisitedState::Closed);
continue;
};

// schedule processing for the pending part of successors for this node
stack.push((node, pending));

match visited.entry(next_node) {
Entry::Vacant(v) => {
// This is the first time we enter this node
v.insert(VisitedState::Opened);
stack.push((
next_node,
self.nodes
.get(next_node)
.map_or(&[], |n| &n.successor_names),
));
}
Entry::Occupied(o) => {
if *o.get() == VisitedState::Opened {
// we are entering an already opened node again -> loop found;
// the stack already contains it
//
// But part of the stack may not belong to this loop, so
// push the found node onto the stack to be able to trace the beginning of the loop
stack.push((next_node, &[]));
return true;
}
}
}
}

stack.pop();
false
}
}
31 changes: 31 additions & 0 deletions tests/by-util/test_tsort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,34 @@ fn test_two_cycles() {
.stdout_is("a\nc\nd\nb\n")
.stderr_is("tsort: -: input contains a loop:\ntsort: b\ntsort: c\ntsort: -: input contains a loop:\ntsort: b\ntsort: d\n");
}

#[test]
fn test_long_loop_no_stack_overflow() {
    use std::fmt::Write;

    // Number of nodes in the single long cycle 0 -> 1 -> ... -> N-1 -> 0.
    const N: usize = 100_000;

    // Emit one "v next " pair per node; the last node points back to node 0,
    // closing the cycle.
    let input = (0..N).fold(String::new(), |mut pairs, v| {
        let next = (v + 1) % N;
        let _ = write!(pairs, "{v} {next} ");
        pairs
    });

    // The command must report the loop rather than crash on such a long cycle.
    new_ucmd!()
        .pipe_in(input)
        .fails_with_code(1)
        .stderr_contains("tsort: -: input contains a loop");
}

// Checks the loop reported for an input in which B and C point at each other
// (pairs "B C" and "C B") while the pairs "A B", "C D" and "D A" also connect
// them to A and D. The command is expected to fail with code 1 and its stderr
// must contain the loop header followed by the nodes B and C, in that order.
#[test]
fn test_loop_for_iterative_dfs_correctness() {
let input = r"
A B
B C
C B
C D
D A
";

new_ucmd!()
.pipe_in(input)
.fails_with_code(1)
.stderr_contains("tsort: -: input contains a loop:\ntsort: B\ntsort: C");
}
Loading