From 8102ddfd54b5d9598be2d87bf101ca47b43a4185 Mon Sep 17 00:00:00 2001 From: Adrien Dufraux Date: Wed, 21 Apr 2021 10:24:29 -0700 Subject: [PATCH] Add minimization algorithm for Acyclic FST --- gtn/functions.cpp | 117 ++++++++++++++++++++++++++++++++++++++++++++++ gtn/functions.h | 3 ++ 2 files changed, 120 insertions(+) diff --git a/gtn/functions.cpp b/gtn/functions.cpp index 361d313..62a447c 100644 --- a/gtn/functions.cpp +++ b/gtn/functions.cpp @@ -329,4 +329,121 @@ Graph viterbiPath(const Graph& g) { return detail::shortestPath(g); } +Graph minimizeAcyclicFST(const Graph& g){ + Graph graph; + std::vector oldToNew(g.numNodes(), -1); // a map between the nodes of g and the minimized graph. + std::vector oldProcessed; // store which nodes has been processed in the g graph + std::set predecessors; // a subset will become candidates to explore + + auto addPredecessors = [&predecessors, &g](int node) { + for (auto in_arc : g.in(node)){ + predecessors.insert(g.srcNode(in_arc)); + } + }; + + auto isToMerge = [&g, &oldToNew] (int node1, int node2){ + if (g.isStart(node1) == g.isStart(node2) && + g.isAccept(node1) == g.isAccept(node2) && + g.numOut(node1) == g.numOut(node2)){ + + //find out if there is a 1:1 mapping between the out arcs of node1 and node2 + if ( std::equal(g.out(node1).begin(), g.out(node1).end(), g.out(node2).begin(), [&g, &oldToNew](int a1, int a2){ + return (g.ilabel(a1) == g.ilabel(a2) && + g.olabel(a1) == g.olabel(a2) && + oldToNew[g.dstNode(a1)] == oldToNew[g.dstNode(a2)]);}) + ){ + return true; + } + } + return false; + }; + + + //Initialization + //a. Find all states with no outgoing arcs. (Since we are dealing with an acyclic FST, it is always possible.) + //b. Split the resulting set into 4 sets according to their START and ACCEPT status. + + int nodeStartAccept = -1 , nodeStartNoAccept = -1, nodeNoStartAccept = -1, nodeNoStartNoAccept = -1; + for (auto n = 0; n < g.numNodes(); ++n) { + if (g.numOut(n) == 0){ + + if (g.isStart(n) && g.isAccept(n)){ + if (nodeStartAccept < 0){ + nodeStartAccept = graph.addNode(true, true); + } + oldToNew[n] = nodeStartAccept; + + } else if (g.isStart(n) && !g.isAccept(n)){ + if (nodeStartNoAccept < 0){ + nodeStartNoAccept = graph.addNode(true, false); + } + oldToNew[n] = nodeStartNoAccept; + + } else if (!g.isStart(n) && g.isAccept(n)){ + if (nodeNoStartAccept < 0){ + nodeNoStartAccept = graph.addNode(false, true); + } + oldToNew[n] = nodeNoStartAccept; + + } else if (!g.isStart(n) && !g.isAccept(n)){ + if (nodeNoStartNoAccept < 0) { + nodeNoStartNoAccept = graph.addNode(false, false); + } + oldToNew[n] = nodeNoStartNoAccept; + } + + addPredecessors(n); // fill predecessors accordingly + oldProcessed.push_back(n); + } + } + + std::vector> candidateSets; + while (!predecessors.empty()) { + candidateSets.clear(); + // find candidates in predecessors and separate them in subsets with same: + // - start state + // - final state + // - out arcs (same ilabel, same olable, same destNode). + for (auto predNode : predecessors){ + //verfiy if this node lead to only processed nodes + if (std::all_of(g.out(predNode).begin(), g.out(predNode).end(), + [&g, &oldProcessed](int a) {return std::count(oldProcessed.begin(), oldProcessed.end(), g.dstNode(a)) > 0;})){ + // place this candidate in an exiting subset of candidateSets if possible + auto it = std::find_if(candidateSets.begin(), candidateSets.end(), [&g, &predNode, &isToMerge](std::vector subset){ + return isToMerge(subset[0], predNode); + }); + + //if subset not found + if (it == candidateSets.end()){ + candidateSets.push_back({predNode}); + } else{ + it->push_back(predNode); + } + } + } + + predecessors.clear(); + + for (auto subset : candidateSets) { + int mergedNode = graph.addNode(g.isStart(subset[0]), g.isAccept(subset[0])); + for (auto n : subset){ + addPredecessors(n); + oldProcessed.push_back(n); + oldToNew[n] = mergedNode; + } + //reattaching arcs as appropriate + for (auto a : g.out(subset[0])){ + graph.addArc( + mergedNode, + oldToNew[g.dstNode(a)], + g.ilabel(a), + g.olabel(a), + g.weight(a)); // should be change to support weighted graphs + } + } + } + + return graph; +} + } // namespace gtn diff --git a/gtn/functions.h b/gtn/functions.h index efaec63..2cdbce9 100644 --- a/gtn/functions.h +++ b/gtn/functions.h @@ -151,5 +151,8 @@ Graph viterbiScore(const Graph& g); */ Graph viterbiPath(const Graph& g); +/** Minimize an Acyclic FST */ +Graph minimizeAcyclicFST(const Graph& g); + /** @} */ } // namespace gtn