# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes and methods for processing debugger-decorated graphs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import op_def_registry
from tensorflow.python.platform import tf_logging as logging


def parse_node_or_tensor_name(name):
  """Get the node name from a string that can be node or tensor name.

  Args:
    name: An input node name (e.g., "node_a") or tensor name (e.g.,
      "node_a:0"), as a str.

  Returns:
    1) The node name, as a str. If the input name is a tensor name, i.e.,
      contains a colon, the final colon and the following output slot
      will be stripped.
    2) If the input name is a tensor name, the output slot, as an int. If
      the input name is not a tensor name, None.
  """
  # A trailing colon (e.g., "node_a:") is not a valid tensor name, so treat
  # it as a plain node name.
  if ":" in name and not name.endswith(":"):
    node_name = name[:name.rfind(":")]
    output_slot = int(name[name.rfind(":") + 1:])

    return node_name, output_slot
  else:
    return name, None


def get_node_name(element_name):
  """Get the node name from a graph element name (node or tensor name).

  Args:
    element_name: (`str`) name of the graph element (e.g., "node_a" or
      "node_a:0").

  Returns:
    (`str`) name of the node, with any ":<output_slot>" suffix stripped.
  """
  node_name, _ = parse_node_or_tensor_name(element_name)
  return node_name


def get_output_slot(element_name):
  """Get the output slot number from the name of a graph element.

  If element_name is a node name without output slot at the end, 0 will be
  assumed.

  Args:
    element_name: (`str`) name of the graph element in question.

  Returns:
    (`int`) output slot number.
  """
  _, output_slot = parse_node_or_tensor_name(element_name)
  return output_slot if output_slot is not None else 0


def is_copy_node(node_name):
  """Determine whether a node name is that of a debug Copy node.

  Such nodes are inserted by TensorFlow core upon request in
  RunOptions.debug_options.debug_tensor_watch_opts.

  Args:
    node_name: Name of the node.

  Returns:
    A bool indicating whether the input argument is the name of a debug Copy
    node.
  """
  return node_name.startswith("__copy_")


def is_debug_node(node_name):
  """Determine whether a node name is that of a debug node.

  Such nodes are inserted by TensorFlow core upon request in
  RunOptions.debug_options.debug_tensor_watch_opts.

  Args:
    node_name: Name of the node.

  Returns:
    A bool indicating whether the input argument is the name of a debug node.
  """
  return node_name.startswith("__dbg_")


def parse_debug_node_name(node_name):
  """Parse the name of a debug node.

  Args:
    node_name: Name of the debug node.

  Returns:
    1. Name of the watched node, as a str.
    2. Output slot index of the watched tensor, as an int.
    3. Index of the debug node, as an int.
    4. Name of the debug op, as a str, e.g., "DebugIdentity".

  Raises:
    ValueError: If the input node name is not a valid debug node name.
  """
  prefix = "__dbg_"

  name = node_name
  if not name.startswith(prefix):
    raise ValueError("Invalid prefix in debug node name: '%s'" % node_name)

  name = name[len(prefix):]

  # A valid debug node name has the form
  # "__dbg_<watched_tensor_name>_<debug_op_index>_<debug_op>", so at least
  # two underscores must remain after stripping the prefix.
  if name.count("_") < 2:
    raise ValueError("Invalid debug node name: '%s'" % node_name)

  # Peel off the debug op name and debug op index from the right.
  debug_op = name[name.rindex("_") + 1:]
  name = name[:name.rindex("_")]

  debug_op_index = int(name[name.rindex("_") + 1:])
  name = name[:name.rindex("_")]

  if name.count(":") != 1:
    raise ValueError("Invalid tensor name in debug node name: '%s'" % node_name)

  watched_node_name = name[:name.index(":")]
  watched_output_slot = int(name[name.index(":") + 1:])

  return watched_node_name, watched_output_slot, debug_op_index, debug_op


class GraphTracingReachedDestination(Exception):
  """Raised by DFSGraphTracer when the destination node has been reached."""
  pass


class DFSGraphTracer(object):
  """Graph input tracer using depth-first search."""

  def __init__(self,
               input_lists,
               skip_node_names=None,
               destination_node_name=None):
    """Constructor of _DFSGraphTracer.

    Args:
      input_lists: A list of dicts. Each dict is an adjacency (input) map from
        the recipient node name as the key and the list of input node names
        as the value.
      skip_node_names: Optional: a list of node names to skip tracing.
      destination_node_name: Optional: destination node name. If not `None`, it
        should be the name of a destination node as a str and the graph tracing
        will raise GraphTracingReachedDestination as soon as the node has been
        reached during a `trace()` call.
    """

    self._input_lists = input_lists
    # Guard against the `None` default: `trace()` performs membership tests
    # on this attribute, which would raise TypeError on `None`.
    self._skip_node_names = skip_node_names or []

    self._inputs = []
    # A set: only used for membership tests, so O(1) lookups suffice.
    self._visited_nodes = set()
    self._depth_count = 0
    self._depth_list = []

    self._destination_node_name = destination_node_name

  def trace(self, graph_element_name):
    """Trace inputs.

    Args:
      graph_element_name: Name of the node or an output tensor of the node, as a
        str.

    Raises:
      GraphTracingReachedDestination: if destination_node_name of this tracer
        object is not None and the specified node is reached.
    """
    self._depth_count += 1
    # try/finally guarantees the depth counter is decremented on every exit
    # path. Without it, the early returns for skipped/visited nodes would
    # leak an increment and inflate the depths recorded for later siblings.
    try:
      node_name = get_node_name(graph_element_name)
      if node_name == self._destination_node_name:
        raise GraphTracingReachedDestination()

      if node_name in self._skip_node_names:
        return
      if node_name in self._visited_nodes:
        return

      self._visited_nodes.add(node_name)

      for input_list in self._input_lists:
        if node_name not in input_list:
          continue
        for inp in input_list[node_name]:
          if get_node_name(inp) in self._visited_nodes:
            continue
          self._inputs.append(inp)
          self._depth_list.append(self._depth_count)
          self.trace(inp)
    finally:
      self._depth_count -= 1

  def inputs(self):
    """Get the list of input graph element names traced so far."""
    return self._inputs

  def depth_list(self):
    """Get the list of depths corresponding to the elements in `inputs()`."""
    return self._depth_list


def _infer_device_name(graph_def):
  """Infer device name from a partition GraphDef.

  Args:
    graph_def: A `GraphDef` proto, presumably a partition graph.

  Returns:
    The device name of the first node that carries a non-empty device field,
    or `None` (with a warning logged) if no node has one.
  """
  device_name = None
  for node in graph_def.node:
    if node.device:
      device_name = node.device
      break
  if device_name is None:
    # `warning` is the non-deprecated spelling of `warn`.
    logging.warning(
        "Failed to infer device name from partition GraphDef: none of the "
        "nodes of the GraphDef has a non-empty device name.")
  return device_name


class DebugGraph(object):
  """Represents a debugger-decorated graph."""

  def __init__(self, debug_graph_def, device_name=None):
    """Constructor of DebugGraph.

    Args:
      debug_graph_def: The debugger-decorated `GraphDef`, with the
        debugger-inserted Copy* and Debug* nodes.
      device_name: (`str` or `None`) name of the device. If `None`, it is
        inferred from the non-empty device fields of the GraphDef's nodes.
    """
    self._debug_graph_def = debug_graph_def
    # Lazily reconstructed by _reconstruct_non_debug_graph_def().
    self._non_debug_graph_def = None

    self._node_attributes = {}
    self._node_inputs = {}
    self._node_reversed_ref_inputs = {}
    self._node_ctrl_inputs = {}
    self._node_recipients = {}
    self._node_ctrl_recipients = {}
    self._node_devices = {}
    self._node_op_types = {}
    self._copy_send_nodes = []
    self._ref_args = {}

    self._device_name = device_name
    if not self._device_name:
      self._device_name = _infer_device_name(debug_graph_def)

    for node in debug_graph_def.node:
      self._process_debug_graph_node(node)

    self._prune_non_control_edges_of_debug_ops()
    self._prune_control_edges_of_debug_ops()
    self._prune_nodes_from_input_and_recipient_maps(self._get_copy_nodes())

    self._populate_recipient_maps()

  def _process_debug_graph_node(self, node):
    """Process a node from the debug GraphDef.

    Args:
      node: (NodeDef) A partition-graph node to be processed.

    Raises:
      ValueError: If duplicate node names are encountered.
    """
    if is_debug_node(node.name):
      # This is a debug node. Parse the node name and retrieve the
      # information about debug watches on tensors. But do not include
      # the node in the graph.
      return

    if node.name in self._node_inputs:
      raise ValueError("Duplicate node name on device %s: '%s'" %
                       (self._device_name, node.name))

    self._node_attributes[node.name] = node.attr

    self._node_inputs[node.name] = []
    self._node_ctrl_inputs[node.name] = []
    self._node_recipients[node.name] = []
    self._node_ctrl_recipients[node.name] = []

    if node.name not in self._node_devices:
      self._node_devices[node.name] = set()
    self._node_devices[node.name].add(
        node.device if node.device else self._device_name)
    self._node_op_types[node.name] = node.op
    self._ref_args[node.name] = self._get_ref_args(node)

    for inp in node.input:
      # A _Send/_Retval node fed by a Copy node was inserted by the debugger;
      # remember it so its control edges can be excluded later.
      if is_copy_node(inp) and (node.op == "_Send" or node.op == "_Retval"):
        self._copy_send_nodes.append(node.name)

      if inp.startswith("^"):
        # A leading "^" marks a control input.
        cinp = inp[1:]
        self._node_ctrl_inputs[node.name].append(cinp)
      else:
        self._node_inputs[node.name].append(inp)

  def _get_ref_args(self, node):
    """Determine whether an input of an op is ref-type.

    Args:
      node: A `NodeDef`.

    Returns:
      A list of the arg names (as strs) that are ref-type.
    """
    op_def = op_def_registry.get(node.op)
    if op_def is None:
      return []

    ref_args = []
    for i, output_arg in enumerate(op_def.output_arg):
      if output_arg.is_ref:
        # Output slot 0 is named by the bare node name; other slots carry
        # an explicit ":<slot>" suffix.
        arg_name = node.name if i == 0 else ("%s:%d" % (node.name, i))
        ref_args.append(arg_name)
    return ref_args

  def _get_copy_nodes(self):
    """Find all Copy nodes in the loaded graph."""
    copy_nodes = []
    for node in self._node_inputs:
      if is_copy_node(node):
        copy_nodes.append(node)
    return copy_nodes

  def _prune_non_control_edges_of_debug_ops(self):
    """Prune (non-control) edges related to debug ops.

    Prune the Copy ops and associated _Send ops inserted by the debugger out
    from the non-control inputs and output recipients map. Replace the inputs
    and recipients with original ones.
    """
    for node in self._node_inputs:
      inputs = self._node_inputs[node]

      for i, inp in enumerate(inputs):
        if is_copy_node(inp):
          # Find the input to the Copy node, which should be the original
          # input to the node.
          orig_inp = self._node_inputs[inp][0]
          inputs[i] = orig_inp

  def _prune_control_edges_of_debug_ops(self):
    """Prune control edges related to the debug ops."""
    for node in self._node_ctrl_inputs:
      ctrl_inputs = self._node_ctrl_inputs[node]
      debug_op_inputs = []
      for ctrl_inp in ctrl_inputs:
        if is_debug_node(ctrl_inp):
          debug_op_inputs.append(ctrl_inp)
      # Remove after the scan so the list is not mutated while iterated.
      for debug_op_inp in debug_op_inputs:
        ctrl_inputs.remove(debug_op_inp)

  def _populate_recipient_maps(self):
    """Populate the map from node name to recipient(s) of its output(s).

    This method also populates the input map based on reversed ref edges.
    """
    for node in self._node_inputs:
      inputs = self._node_inputs[node]
      for inp in inputs:
        inp = get_node_name(inp)
        if inp not in self._node_recipients:
          self._node_recipients[inp] = []
        self._node_recipients[inp].append(node)

        if inp in self._ref_args:
          if inp not in self._node_reversed_ref_inputs:
            self._node_reversed_ref_inputs[inp] = []
          self._node_reversed_ref_inputs[inp].append(node)

    for node in self._node_ctrl_inputs:
      ctrl_inputs = self._node_ctrl_inputs[node]
      for ctrl_inp in ctrl_inputs:
        # Control edges from debugger-inserted _Send nodes are not part of
        # the original graph; skip them.
        if ctrl_inp in self._copy_send_nodes:
          continue

        if ctrl_inp not in self._node_ctrl_recipients:
          self._node_ctrl_recipients[ctrl_inp] = []
        self._node_ctrl_recipients[ctrl_inp].append(node)

  def _prune_nodes_from_input_and_recipient_maps(self, nodes_to_prune):
    """Prune nodes out of input and recipient maps.

    Args:
      nodes_to_prune: (`list` of `str`) Names of the nodes to be pruned.
    """
    for node in nodes_to_prune:
      del self._node_inputs[node]
      del self._node_ctrl_inputs[node]
      del self._node_recipients[node]
      del self._node_ctrl_recipients[node]

  def _reconstruct_non_debug_graph_def(self):
    """Reconstruct non-debug GraphDef.

    Non-debug GraphDef means the original GraphDef without the Copy* and Debug
    nodes inserted by the debugger. The result is cached in
    `self._non_debug_graph_def`.
    """
    if self._non_debug_graph_def:
      return

    self._non_debug_graph_def = graph_pb2.GraphDef()
    for node in self._debug_graph_def.node:
      if is_copy_node(node.name) or is_debug_node(node.name):
        continue

      new_node = self._non_debug_graph_def.node.add()
      new_node.CopyFrom(node)

      # Redo the list of inputs, because in _debug_graph_def, the list can
      # consist of Copy* and Debug* nodes inserted by the debugger. Those will
      # be replaced with the original inputs here.
      del new_node.input[:]
      for inp in self._node_inputs[node.name]:
        new_node.input.append(inp)
      for ctrl_inp in self._node_ctrl_inputs[node.name]:
        new_node.input.append("^" + ctrl_inp)

  @property
  def device_name(self):
    return self._device_name

  @property
  def debug_graph_def(self):
    """The debugger-decorated GraphDef."""
    return self._debug_graph_def

  @property
  def non_debug_graph_def(self):
    """The GraphDef without the Copy* and Debug* nodes added by the debugger."""
    self._reconstruct_non_debug_graph_def()
    return self._non_debug_graph_def

  @property
  def node_devices(self):
    return self._node_devices

  @property
  def node_op_types(self):
    return self._node_op_types

  @property
  def node_attributes(self):
    return self._node_attributes

  @property
  def node_inputs(self):
    return self._node_inputs

  @property
  def node_ctrl_inputs(self):
    return self._node_ctrl_inputs

  @property
  def node_reversed_ref_inputs(self):
    return self._node_reversed_ref_inputs

  @property
  def node_recipients(self):
    return self._node_recipients

  @property
  def node_ctrl_recipients(self):
    return self._node_ctrl_recipients


def reconstruct_non_debug_graph_def(debug_graph_def):
  """Reconstruct original (non-debugger-decorated) partition GraphDef.

  This method strips the input `tf.compat.v1.GraphDef` of the Copy* and
  Debug*-type nodes inserted by the debugger.

  The reconstructed partition graph is identical to the original (i.e.,
  non-debugger-decorated) partition graph except in the following respects:
    1) The exact names of the runtime-inserted internal nodes may differ.
       These include _Send, _Recv, _HostSend, _HostRecv, _Retval ops.
    2) As a consequence of 1, the nodes that receive input directly from such
       send- and recv-type ops will have different input names.
    3) The parallel_iteration attribute of while-loop Enter ops are set to 1.

  Args:
    debug_graph_def: The debugger-decorated `tf.compat.v1.GraphDef`, with the
      debugger-inserted Copy* and Debug* nodes.

  Returns:
    The reconstructed `tf.compat.v1.GraphDef` stripped of the debugger-inserted
    nodes.
  """
  return DebugGraph(debug_graph_def).non_debug_graph_def