• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4#
5# Copyright (c) 2025 Northeastern University
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10#     http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17#
18
19from typing import List, Iterable, Tuple, Union, Optional, Callable, Set
20
21import networkx as nx
22from networkx import DiGraph
23
24from ohos.sbom.data.ninja_json import NinjaJson
25from ohos.sbom.data.target import Target
26
27
class DependGraphAnalyzer:
    """
    Dependency graph service based on networkx.DiGraph.

    Each node is keyed by ``Target.target_name`` and stores the Target object
    under the ``"data"`` node attribute. An edge ``a -> b`` is created when
    target ``a`` lists ``b`` in its ``deps``; dependencies that do not name a
    target in the input set are dropped.
    """

    def __init__(self, src: Union[NinjaJson, List[Target]]) -> None:
        """
        Build the dependency graph.

        Parameters:
            src: a NinjaJson (all targets from ``all_targets()`` are used)
                or a list of Target objects

        Raises:
            TypeError: if src is neither a NinjaJson nor a list
        """
        if isinstance(src, NinjaJson):
            targets = list(src.all_targets())
        elif isinstance(src, list):
            targets = src
        else:
            raise TypeError("src must be NinjaJson or List[Target]")

        self._graph = self._build_graph(targets)

    @property
    def graph(self) -> DiGraph:
        """The underlying networkx graph (the live object, not a copy)."""
        return self._graph

    @staticmethod
    def _build_graph(targets: List[Target]) -> DiGraph:
        """Create one node per target and one edge per dependency on a known target."""
        g = nx.DiGraph()
        for t in targets:
            g.add_node(t.target_name, data=t)
        target_names = {t.target_name for t in targets}
        for t in targets:
            for dep in t.deps:
                # Only link to targets in the input set, so add_edge never
                # creates placeholder nodes without a "data" attribute.
                if dep in target_names:
                    g.add_edge(t.target_name, dep)
        return g

    def nodes(self) -> List[str]:
        """Return all node names."""
        return list(self._graph.nodes)

    def edges(self) -> List[Tuple[str, str]]:
        """Return all edges as (dependent, dependency) name pairs."""
        return list(self._graph.edges)

    def get_target(self, name: str) -> Target:
        """Return the object stored under the node's "data" attribute."""
        return self._graph.nodes[name]["data"]

    def predecessors(self, name: str) -> List[str]:
        """Return the direct dependents of ``name`` (nodes with an edge into it)."""
        return list(self._graph.predecessors(name))

    def successors(self, name: str) -> List[str]:
        """Return the direct dependencies of ``name`` (nodes it has an edge to)."""
        return list(self._graph.successors(name))

    def ancestors(self, name: str) -> List[str]:
        """Return every node from which ``name`` is reachable (transitive dependents)."""
        return list(nx.ancestors(self._graph, name))

    def descendants(self, name: str) -> List[str]:
        """Return every node reachable from ``name`` (transitive dependencies)."""
        return list(nx.descendants(self._graph, name))

    def shortest_path(self, source: str, target: str) -> List[str]:
        """Return the node names on a shortest directed path from source to target."""
        return nx.shortest_path(self._graph, source, target)

    def sub_graph(self, nodes: Iterable[str]) -> DiGraph:
        """Return an independent copy of the subgraph induced on ``nodes``."""
        return self._graph.subgraph(nodes).copy()

    def add_virtual_root(self, root_name: str, children: List[str]):
        """
        Add a synthetic root node with an edge to each of ``children``.

        Every child is validated before the graph is modified, so an unknown
        child name leaves the graph unchanged.

        Parameters:
            root_name: name for the virtual root node
            children: existing node names the root should point to

        Raises:
            ValueError: if any child is not a node in the graph
        """
        # Materialize once so a generator argument can be validated and then used.
        children = list(children)
        for child in children:
            if child not in self._graph:
                raise ValueError(f"virtual root child '{child}' not exist in graph")

        # NOTE(review): if root_name already names a real target, add_node
        # overwrites that node's "data" attribute and remove_virtual_root would
        # later delete it; callers should choose a non-colliding name.
        virtual_target = type("VirtualTarget", (), {
            "target_name": root_name,
            "type": "virtual_root",
            "outputs": [],
            "source_outputs": {}
        })()
        self._graph.add_node(root_name, data=virtual_target)
        self._graph.add_edges_from((root_name, child) for child in children)

    def remove_virtual_root(self, root_name: str):
        """Remove ``root_name`` and its edges if present; no-op otherwise."""
        if root_name in self._graph:
            self._graph.remove_node(root_name)

    def depend_subgraph(
            self,
            src: Union[str, Target],
            *,
            max_depth: Optional[int] = None,
    ) -> DiGraph:
        """
        Return the subgraph of nodes reachable from ``src`` along dependency edges.

        Parameters:
            src: start node (target name or Target object)
            max_depth: maximum number of hops from src (None means no limit)

        Returns:
            Subgraph induced on src plus every node within max_depth hops downstream.
        """
        if isinstance(src, Target):
            src = src.target_name
        if max_depth is None:
            # No simple path has more than len(graph) - 1 edges, so this radius is unbounded.
            max_depth = len(self._graph)
        return nx.ego_graph(self._graph, src, radius=max_depth, center=True, undirected=False)

    def dfs_downstream(
            self,
            start: Union[str, Target],
            max_depth: Optional[int] = None,
            pre_visit: Optional[Callable[[str, int, Optional[str]], bool]] = None,
            post_visit: Optional[Callable[[str, int, Optional[str]], None]] = None
    ) -> List[str]:
        """
        Perform depth-first traversal from the start point along downstream dependencies (successors)

        Parameters:
            start: traversal start point (target name or Target object)
            max_depth: maximum traversal depth (None means no limit)
            pre_visit: callback function before visiting a node
                Parameters: (current node name, current depth, parent node name)
                Return: bool - whether to continue traversing the node's children (False skips children)
            post_visit: callback function after visiting a node
                Parameters: (current node name, current depth, parent node name)

        Returns:
            List of nodes in traversal order (each node appears at most once)

        Raises:
            ValueError: if start is not a node in the graph
            RuntimeError: if a callback raises (original exception is chained)
        """
        return self._dfs(
            start=start,
            neighbor_func=self.successors,
            max_depth=max_depth,
            pre_visit=pre_visit,
            post_visit=post_visit
        )

    def dfs_upstream(
            self,
            start: Union[str, Target],
            max_depth: Optional[int] = None,
            pre_visit: Optional[Callable[[str, int, Optional[str]], bool]] = None,
            post_visit: Optional[Callable[[str, int, Optional[str]], None]] = None
    ) -> List[str]:
        """
        Perform depth-first traversal from the start point along upstream dependents (predecessors).

        Parameters, return value and exceptions are the same as ``dfs_downstream``;
        only the edge direction differs.
        """
        return self._dfs(
            start=start,
            neighbor_func=self.predecessors,
            max_depth=max_depth,
            pre_visit=pre_visit,
            post_visit=post_visit
        )

    def _dfs(
            self,
            start: Union[str, Target],
            neighbor_func: Callable[[str], List[str]],
            max_depth: Optional[int],
            pre_visit: Optional[Callable[[str, int, Optional[str]], bool]],
            post_visit: Optional[Callable[[str, int, Optional[str]], None]]
    ) -> List[str]:
        """Iterative DFS shared by dfs_downstream and dfs_upstream."""
        start_name = start.target_name if isinstance(start, Target) else start
        # O(1) graph membership instead of scanning a freshly built node list.
        if start_name not in self._graph:
            raise ValueError(f"node {start_name} not exist in graph")

        visited = set()
        traversal_order = []
        # Entries are (node, depth, parent, is_processed): False = pre-visit pending,
        # True = children already pushed/handled, run post-visit.
        # NOTE(review): a node first reached via a longer path is marked visited
        # there; if that path hit max_depth, its children are not re-expanded when
        # the node is later reached by a shorter path.
        stack = [(start_name, 0, None, False)]

        while stack:
            node, depth, parent, is_processed = stack.pop()

            if max_depth is not None and depth > max_depth:
                continue

            if not is_processed:
                continue_traverse = self._process_node_pre(node=node, depth=depth, parent=parent, visited=visited,
                                                           traversal_order=traversal_order, stack=stack,
                                                           pre_visit=pre_visit)
                if continue_traverse:
                    self._push_neighbors(node=node, depth=depth, visited=visited,
                                         neighbor_func=neighbor_func, stack=stack)
            else:
                self._process_node_post(node=node, depth=depth, parent=parent, post_visit=post_visit)

        return traversal_order

    def _process_node_pre(
            self,
            node: str,
            depth: int,
            parent: Optional[str],
            visited: Set[str],
            traversal_order: List[str],
            stack: List[Tuple[str, int, Optional[str], bool]],
            pre_visit: Optional[Callable[[str, int, Optional[str]], bool]]
    ) -> bool:
        """Handle pre-visit logic and return whether to continue traversing children."""
        if node in visited:
            return False
        visited.add(node)
        traversal_order.append(node)

        continue_traverse = True
        if pre_visit is not None:
            try:
                continue_traverse = pre_visit(node, depth, parent)
            except Exception as e:
                raise RuntimeError(f"pre_visit execute failed: {e}") from e

        # Push node back for post-processing; it sits below its children on the
        # stack, so post_visit runs after the whole subtree is handled.
        stack.append((node, depth, parent, True))

        return continue_traverse

    def _push_neighbors(
            self,
            node: str,
            depth: int,
            visited: Set[str],
            neighbor_func: Callable[[str], List[str]],
            stack: List[Tuple[str, int, Optional[str], bool]]
    ):
        """Push unvisited neighbors of ``node`` (recording ``node`` as their parent)."""
        neighbors = neighbor_func(node)
        # Reversed so the first neighbor is popped, and therefore visited, first.
        for neighbor in reversed(neighbors):
            if neighbor not in visited:
                stack.append((neighbor, depth + 1, node, False))

    def _process_node_post(
            self,
            node: str,
            depth: int,
            parent: Optional[str],
            post_visit: Optional[Callable[[str, int, Optional[str]], None]]
    ):
        """Handle post-visit logic."""
        if post_visit is not None:
            try:
                post_visit(node, depth, parent)
            except Exception as e:
                raise RuntimeError(f"post_visit execute failed: {e}") from e
255