• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Public names exported by a star-import of this module.
__all__ = ["TopologicalSorter", "CycleError"]

# Sentinel values stored in _NodeInfo.npredecessors once a node is no longer
# waiting on predecessors (real predecessor counts are >= 0).
# _NODE_OUT: the node has been returned by get_ready() but not yet marked done.
_NODE_OUT = -1
# _NODE_DONE: the node has been marked as processed by done().
_NODE_DONE = -2
5
6
class _NodeInfo:
    """Bookkeeping record attached to a single graph node.

    Attributes:
        node: The node this record augments.
        npredecessors: Count of predecessors, normally >= 0. After the count
            reaches 0 and the node is handed out by get_ready(), the field is
            overwritten with _NODE_OUT; once done() is called for the node it
            becomes _NODE_DONE.
        successors: List of successor nodes. Duplicates are allowed provided
            each occurrence is also counted in that successor's
            npredecessors.
    """

    __slots__ = ("node", "npredecessors", "successors")

    def __init__(self, node):
        # Start with no edges in either direction; callers fill these in.
        self.successors = []
        self.npredecessors = 0
        self.node = node
22
23
class CycleError(ValueError):
    """Raised by TopologicalSorter.prepare when the working graph has a cycle.

    This is a subclass of ValueError. When several cycles exist, only one of
    them (which one is unspecified) is reported. The reported cycle is
    available as the second element of the exception's *args* attribute: a
    list of nodes in which every node is, in the graph, an immediate
    predecessor of the node that follows it. The first and the last entries
    of the list are the same node, making the cyclic nature explicit.
    """
37
38
class TopologicalSorter:
    """Provides functionality to topologically sort a graph of hashable nodes.

    Nodes are registered with add() (or by passing a mapping to the
    constructor). After prepare() has been called, nodes are obtained in
    batches with get_ready() and released with done() for as long as
    is_active() is true. static_order() wraps that whole loop in a single
    iterable.
    """

    def __init__(self, graph=None):
        """Create a sorter, optionally seeded from *graph*.

        If *graph* is given it must provide ``items()`` yielding
        ``(node, predecessors)`` pairs, where *predecessors* is an iterable
        of nodes; every pair is registered through add().
        """
        # Maps each node to its _NodeInfo bookkeeping record.
        self._node2info = {}
        # None until prepare() is called; afterwards the list of nodes that
        # are ready to be returned by the next get_ready() call.
        self._ready_nodes = None
        # Number of nodes returned so far by get_ready().
        self._npassedout = 0
        # Number of nodes marked as processed so far by done().
        self._nfinished = 0

        if graph is not None:
            for node, predecessors in graph.items():
                self.add(node, *predecessors)

    def _get_nodeinfo(self, node):
        """Return the _NodeInfo for *node*, creating it on first use."""
        if (result := self._node2info.get(node)) is None:
            self._node2info[node] = result = _NodeInfo(node)
        return result

    def add(self, node, *predecessors):
        """Add a new node and its predecessors to the graph.

        Both the *node* and all elements in *predecessors* must be hashable.

        If called multiple times with the same node argument, the set of dependencies
        will be the union of all dependencies passed in.

        It is possible to add a node with no dependencies (*predecessors* is not provided)
        as well as provide a dependency twice. If a node that has not been provided before
        is included among *predecessors* it will be automatically added to the graph with
        no predecessors of its own.

        Raises ValueError if called after "prepare".
        """
        if self._ready_nodes is not None:
            raise ValueError("Nodes cannot be added after a call to prepare()")

        # Create the node -> predecessor edges
        nodeinfo = self._get_nodeinfo(node)
        nodeinfo.npredecessors += len(predecessors)

        # Create the predecessor -> node edges. A predecessor passed twice is
        # counted twice above and appended twice here, so the two stay in sync.
        for pred in predecessors:
            pred_info = self._get_nodeinfo(pred)
            pred_info.successors.append(node)

    def prepare(self):
        """Mark the graph as finished and check for cycles in the graph.

        If any cycle is detected, "CycleError" will be raised, but "get_ready" can
        still be used to obtain as many nodes as possible until cycles block more
        progress. After a call to this function, the graph cannot be modified and
        therefore no more nodes can be added using "add".

        This method may be called more than once as long as no node has been
        returned by "get_ready" yet; this allows checking for cycles with
        "prepare" and then iterating with "static_order".

        Raises ValueError if called after nodes have been returned by "get_ready".
        """
        if self._npassedout > 0:
            raise ValueError("cannot prepare() after starting sort")

        # Only compute the initial ready list once: add() is blocked after the
        # first prepare(), so the result cannot change on later calls.
        if self._ready_nodes is None:
            self._ready_nodes = [
                i.node for i in self._node2info.values() if i.npredecessors == 0
            ]
        # ready_nodes is set before we look for cycles on purpose:
        # if the user wants to catch the CycleError, that's fine,
        # they can continue using the instance to grab as many
        # nodes as possible before cycles block more progress
        cycle = self._find_cycle()
        if cycle:
            raise CycleError("nodes are in a cycle", cycle)

    def get_ready(self):
        """Return a tuple of all the nodes that are ready.

        Initially it returns all nodes with no predecessors; once those are marked
        as processed by calling "done", further calls will return all new nodes that
        have all their predecessors already processed. Once no more progress can be made,
        empty tuples are returned.

        Raises ValueError if called without calling "prepare" previously.
        """
        if self._ready_nodes is None:
            raise ValueError("prepare() must be called first")

        # Get the nodes that are ready and mark them
        result = tuple(self._ready_nodes)
        n2i = self._node2info
        for node in result:
            n2i[node].npredecessors = _NODE_OUT

        # Clean the list of nodes that are ready and update
        # the counter of nodes that we have returned.
        self._ready_nodes.clear()
        self._npassedout += len(result)

        return result

    def is_active(self):
        """Return ``True`` if more progress can be made and ``False`` otherwise.

        Progress can be made if cycles do not block the resolution and either there
        are still nodes ready that haven't yet been returned by "get_ready" or the
        number of nodes marked "done" is less than the number that have been returned
        by "get_ready".

        Raises ValueError if called without calling "prepare" previously.
        """
        if self._ready_nodes is None:
            raise ValueError("prepare() must be called first")
        return self._nfinished < self._npassedout or bool(self._ready_nodes)

    def __bool__(self):
        """Truth value of the sorter; equivalent to is_active()."""
        return self.is_active()

    def done(self, *nodes):
        """Marks a set of nodes returned by "get_ready" as processed.

        This method unblocks any successor of each node in *nodes* for being returned
        in the future by a call to "get_ready".

        Raises :exc:`ValueError` if any node in *nodes* has already been marked as
        processed by a previous call to this method, if a node was not added to the
        graph by using "add" or if called without calling "prepare" previously or if
        node has not yet been returned by "get_ready".
        """

        if self._ready_nodes is None:
            raise ValueError("prepare() must be called first")

        n2i = self._node2info

        for node in nodes:

            # Check if we know about this node (it was added previously using add()).
            if (nodeinfo := n2i.get(node)) is None:
                raise ValueError(f"node {node!r} was not added using add()")

            # If the node has not been returned (marked as ready) previously, inform the user.
            stat = nodeinfo.npredecessors
            if stat != _NODE_OUT:
                if stat >= 0:
                    raise ValueError(
                        f"node {node!r} was not passed out (still not ready)"
                    )
                elif stat == _NODE_DONE:
                    raise ValueError(f"node {node!r} was already marked done")
                else:
                    # Raised explicitly (rather than via ``assert``) so that a
                    # corrupted status is never silently accepted under -O.
                    raise AssertionError(f"node {node!r}: unknown status {stat}")

            # Mark the node as processed
            nodeinfo.npredecessors = _NODE_DONE

            # Go to all the successors and reduce the number of predecessors, collecting all the ones
            # that are ready to be returned in the next get_ready() call.
            for successor in nodeinfo.successors:
                successor_info = n2i[successor]
                successor_info.npredecessors -= 1
                if successor_info.npredecessors == 0:
                    self._ready_nodes.append(successor)
            self._nfinished += 1

    def _find_cycle(self):
        """Return one cycle in the graph as a list of nodes, or None.

        The search is an iterative depth-first traversal that follows
        successor edges. A returned list starts and ends with the same node.
        """
        n2i = self._node2info
        # Nodes on the path currently being explored, root first.
        stack = []
        # Parallel to ``stack``: the bound ``__next__`` of each path node's
        # successor iterator, so exploration can resume where it left off.
        itstack = []
        # Every node that has been visited at least once.
        seen = set()
        # Maps each node currently on ``stack`` to its index in ``stack``.
        node2stacki = {}

        for node in n2i:
            if node in seen:
                continue

            while True:
                if node in seen:
                    # If we have seen already the node and is in the
                    # current stack we have found a cycle.
                    if node in node2stacki:
                        return stack[node2stacki[node] :] + [node]
                    # else go on to get next successor
                else:
                    seen.add(node)
                    itstack.append(iter(n2i[node].successors).__next__)
                    node2stacki[node] = len(stack)
                    stack.append(node)

                # Backtrack to the topmost stack entry with
                # at least another successor.
                while stack:
                    try:
                        node = itstack[-1]()
                        break
                    except StopIteration:
                        del node2stacki[stack.pop()]
                        itstack.pop()
                else:
                    # The path is exhausted: move on to the next unvisited root.
                    break
        return None

    def static_order(self):
        """Returns an iterable of nodes in a topological order.

        The particular order that is returned may depend on the specific
        order in which the items were inserted in the graph.

        Using this method does not require to call "prepare" or "done". If any
        cycle is detected, :exc:`CycleError` will be raised.
        """
        self.prepare()
        while self.is_active():
            node_group = self.get_ready()
            yield from node_group
            self.done(*node_group)
247