# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes and methods for processing debugger-decorated graphs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import op_def_registry
from tensorflow.python.platform import tf_logging as logging


def parse_node_or_tensor_name(name):
  """Get the node name from a string that can be node or tensor name.

  Args:
    name: An input node name (e.g., "node_a") or tensor name (e.g.,
      "node_a:0"), as a str.

  Returns:
    1) The node name, as a str. If the input name is a tensor name, i.e.,
      contains a colon, the final colon and the following output slot
      will be stripped.
    2) If the input name is a tensor name, the output slot, as an int. If
      the input name is not a tensor name, None.

  Raises:
    ValueError: If the text after the final colon is not a valid integer.
  """
  # A trailing colon (e.g., "node_a:") is not treated as a tensor name; the
  # string is returned untouched in that case.
  if ":" in name and not name.endswith(":"):
    node_name, _, slot_str = name.rpartition(":")
    return node_name, int(slot_str)
  return name, None


def get_node_name(element_name):
  """Get the node name from the name of a graph element.

  Args:
    element_name: (`str`) name of a node (e.g., "node_a") or of a tensor
      (e.g., "node_a:0").

  Returns:
    (`str`) the node name, with any output slot suffix stripped.
  """
  node_name, _ = parse_node_or_tensor_name(element_name)
  return node_name


def get_output_slot(element_name):
  """Get the output slot number from the name of a graph element.

  If element_name is a node name without output slot at the end, 0 will be
  assumed.

  Args:
    element_name: (`str`) name of the graph element in question.

  Returns:
    (`int`) output slot number.
  """
  _, output_slot = parse_node_or_tensor_name(element_name)
  return output_slot if output_slot is not None else 0


def is_copy_node(node_name):
  """Determine whether a node name is that of a debug Copy node.

  Such nodes are inserted by TensorFlow core upon request in
  RunOptions.debug_options.debug_tensor_watch_opts.

  Args:
    node_name: Name of the node.

  Returns:
    A bool indicating whether the input argument is the name of a debug Copy
    node.
  """
  return node_name.startswith("__copy_")


def is_debug_node(node_name):
  """Determine whether a node name is that of a debug node.

  Such nodes are inserted by TensorFlow core upon request in
  RunOptions.debug_options.debug_tensor_watch_opts.

  Args:
    node_name: Name of the node.

  Returns:
    A bool indicating whether the input argument is the name of a debug node.
  """
  return node_name.startswith("__dbg_")


def parse_debug_node_name(node_name):
  """Parse the name of a debug node.

  The expected layout, as implemented below, is
  "__dbg_<watched_node>:<slot>_<debug_op_index>_<debug_op>".

  Args:
    node_name: Name of the debug node.

  Returns:
    1. Name of the watched node, as a str.
    2. Output slot index of the watched tensor, as an int.
    3. Index of the debug node, as an int.
    4. Name of the debug op, as a str, e.g, "DebugIdentity".

  Raises:
    ValueError: If the input node name is not a valid debug node name.
  """
  prefix = "__dbg_"

  name = node_name
  if not name.startswith(prefix):
    raise ValueError("Invalid prefix in debug node name: '%s'" % node_name)

  name = name[len(prefix):]

  if name.count("_") < 2:
    raise ValueError("Invalid debug node name: '%s'" % node_name)

  # Peel the fields off from the right: the watched node name itself may
  # contain underscores, so only the last two separators are meaningful.
  debug_op = name[name.rindex("_") + 1:]
  name = name[:name.rindex("_")]

  debug_op_index = int(name[name.rindex("_") + 1:])
  name = name[:name.rindex("_")]

  if name.count(":") != 1:
    raise ValueError("Invalid tensor name in debug node name: '%s'" % node_name)

  watched_node_name = name[:name.index(":")]
  watched_output_slot = int(name[name.index(":") + 1:])

  return watched_node_name, watched_output_slot, debug_op_index, debug_op


class GraphTracingReachedDestination(Exception):
  """Raised by `DFSGraphTracer.trace` when the destination node is reached."""
  pass


class DFSGraphTracer(object):
  """Graph input tracer using depth-first search."""

  def __init__(self,
               input_lists,
               skip_node_names=None,
               destination_node_name=None):
    """Constructor of DFSGraphTracer.

    Args:
      input_lists: A list of dicts. Each dict is an adjacency (input) map from
        the recipient node name as the key and the list of input node names
        as the value.
      skip_node_names: Optional: a list of node names to skip tracing. `None`
        (the default) means that no node is skipped.
      destination_node_name: Optional: destination node name. If not `None`, it
        should be the name of a destination node as a str and the graph tracing
        will raise GraphTracingReachedDestination as soon as the node has been
        reached.
    """

    self._input_lists = input_lists
    # Normalize to a set so that the default `None` means "skip nothing"
    # (a membership test against `None` would raise TypeError in `trace`) and
    # so that lookups are constant-time.
    self._skip_node_names = (
        frozenset(skip_node_names) if skip_node_names else frozenset())

    self._inputs = []
    # Only used for membership tests, hence a set rather than a list.
    self._visited_nodes = set()
    self._depth_count = 0
    self._depth_list = []

    self._destination_node_name = destination_node_name

  def trace(self, graph_element_name):
    """Trace inputs.

    The traced inputs and their depths (1 for direct inputs of the starting
    element, 2 for inputs of those inputs, etc.) are accumulated on this object
    and can be retrieved through `inputs()` and `depth_list()`.

    Args:
      graph_element_name: Name of the node or an output tensor of the node, as a
        str.

    Raises:
      GraphTracingReachedDestination: if destination_node_name of this tracer
        object is not None and the specified node is reached.
    """
    node_name = get_node_name(graph_element_name)
    if node_name == self._destination_node_name:
      raise GraphTracingReachedDestination()

    if node_name in self._skip_node_names or node_name in self._visited_nodes:
      return

    self._visited_nodes.add(node_name)

    # The depth is bumped only once we know this node will be expanded, and is
    # restored in `finally`, so that neither an early return above nor an
    # exception propagating from a deeper level can leave the counter
    # inflated for subsequent siblings or subsequent `trace` calls.
    self._depth_count += 1
    try:
      for input_list in self._input_lists:
        if node_name not in input_list:
          continue
        for inp in input_list[node_name]:
          if get_node_name(inp) in self._visited_nodes:
            continue
          self._inputs.append(inp)
          self._depth_list.append(self._depth_count)
          self.trace(inp)
    finally:
      self._depth_count -= 1

  def inputs(self):
    """Returns the list of input names collected by `trace`, in DFS order."""
    return self._inputs

  def depth_list(self):
    """Returns the depths of the entries of `inputs()`, in the same order."""
    return self._depth_list


def _infer_device_name(graph_def):
  """Infer device name from a partition GraphDef.

  Args:
    graph_def: A `GraphDef` whose nodes are inspected.

  Returns:
    The `device` field of the first node that has a non-empty one, or `None`
    (after logging a warning) if there is no such node.
  """
  device_name = None
  for node in graph_def.node:
    if node.device:
      device_name = node.device
      break
  if device_name is None:
    logging.warn(
        "Failed to infer device name from partition GraphDef: none of the "
        "nodes of the GraphDef has a non-empty device name.")
  return device_name


class DebugGraph(object):
  """Represents a debugger-decorated graph."""

  def __init__(self, debug_graph_def, device_name=None):
    """Constructor of DebugGraph.

    Args:
      debug_graph_def: The debugger-decorated `GraphDef`.
      device_name: Optional: name of the device. If empty or `None`, it is
        inferred from the first node of `debug_graph_def` that has a non-empty
        device field.

    Raises:
      ValueError: If duplicate node names are encountered.
    """
    self._debug_graph_def = debug_graph_def
    # Built lazily by `_reconstruct_non_debug_graph_def`.
    self._non_debug_graph_def = None

    # All of the maps below are keyed by node name.
    self._node_attributes = {}
    self._node_inputs = {}
    self._node_reversed_ref_inputs = {}
    self._node_ctrl_inputs = {}
    self._node_recipients = {}
    self._node_ctrl_recipients = {}
    self._node_devices = {}
    self._node_op_types = {}
    self._copy_send_nodes = []
    self._ref_args = {}

    self._device_name = device_name
    if not self._device_name:
      self._device_name = _infer_device_name(debug_graph_def)

    for node in debug_graph_def.node:
      self._process_debug_graph_node(node)

    # The order matters: the edges going through Copy nodes must be rewired
    # before the Copy nodes themselves are dropped from the maps, and the
    # recipient maps are derived from the cleaned-up input maps.
    self._prune_non_control_edges_of_debug_ops()
    self._prune_control_edges_of_debug_ops()
    self._prune_nodes_from_input_and_recipient_maps(self._get_copy_nodes())

    self._populate_recipient_maps()

  def _process_debug_graph_node(self, node):
    """Process a node from the debug GraphDef.

    Args:
      node: (NodeDef) A partition-graph node to be processed.

    Raises:
      ValueError: If duplicate node names are encountered.
    """
    if is_debug_node(node.name):
      # Debug nodes are inserted by the debugger; they are not included in
      # any of the maps describing the graph.
      return

    if node.name in self._node_inputs:
      raise ValueError("Duplicate node name on device %s: '%s'" %
                       (self._device_name, node.name))

    self._node_attributes[node.name] = node.attr

    self._node_inputs[node.name] = []
    self._node_ctrl_inputs[node.name] = []
    self._node_recipients[node.name] = []
    self._node_ctrl_recipients[node.name] = []

    if node.name not in self._node_devices:
      self._node_devices[node.name] = set()
    self._node_devices[node.name].add(
        node.device if node.device else self._device_name)
    self._node_op_types[node.name] = node.op
    self._ref_args[node.name] = self._get_ref_args(node)

    for inp in node.input:
      if is_copy_node(inp) and (node.op == "_Send" or node.op == "_Retval"):
        self._copy_send_nodes.append(node.name)

      # A leading caret marks a control input; it is stored without the caret.
      if inp.startswith("^"):
        cinp = inp[1:]
        self._node_ctrl_inputs[node.name].append(cinp)
      else:
        self._node_inputs[node.name].append(inp)

  def _get_ref_args(self, node):
    """Determine which outputs of a node are ref-type.

    Args:
      node: A `NodeDef`.

    Returns:
      A list of the names (as strs) of the ref-type outputs of the node, per
      the registered `OpDef` of the node's op. Output 0 is named after the node
      itself; output i > 0 is named "<node_name>:<i>". The list is empty if the
      op is not registered.
    """
    op_def = op_def_registry.get_registered_ops().get(node.op)
    ref_args = []
    if op_def:
      for i, output_arg in enumerate(op_def.output_arg):
        if output_arg.is_ref:
          arg_name = node.name if i == 0 else ("%s:%d" % (node.name, i))
          ref_args.append(arg_name)
    return ref_args

  def _get_copy_nodes(self):
    """Find all Copy nodes in the loaded graph."""
    return [node for node in self._node_inputs if is_copy_node(node)]

  def _prune_non_control_edges_of_debug_ops(self):
    """Prune (non-control) edges related to debug ops.

    Wherever a node takes a debugger-inserted Copy node as a non-control
    input, replace that input in place with the first input of the Copy node,
    i.e., with the original input.
    """
    for node in self._node_inputs:
      inputs = self._node_inputs[node]

      for i in xrange(len(inputs)):
        inp = inputs[i]
        if is_copy_node(inp):
          # Find the input to the Copy node, which should be the original
          # input to the node.
          orig_inp = self._node_inputs[inp][0]
          inputs[i] = orig_inp

  def _prune_control_edges_of_debug_ops(self):
    """Prune control edges related to the debug ops."""
    for node in self._node_ctrl_inputs:
      ctrl_inputs = self._node_ctrl_inputs[node]
      # Collect first and remove afterwards, so that the list is not mutated
      # while it is being iterated over.
      debug_op_inputs = [
          ctrl_inp for ctrl_inp in ctrl_inputs if is_debug_node(ctrl_inp)]
      for debug_op_inp in debug_op_inputs:
        ctrl_inputs.remove(debug_op_inp)

  def _populate_recipient_maps(self):
    """Populate the map from node name to recipient(s) of its output(s).

    This method also populates the input map based on reversed ref edges.
    """
    for node in self._node_inputs:
      inputs = self._node_inputs[node]
      for inp in inputs:
        inp = get_node_name(inp)
        if inp not in self._node_recipients:
          self._node_recipients[inp] = []
        self._node_recipients[inp].append(node)

        # NOTE(review): `_ref_args` has an entry (a possibly empty list) for
        # every processed non-debug node, so this test is true for any such
        # node, not only for producers of ref-type outputs. Confirm whether
        # the contents of `self._ref_args[...]` were meant to be consulted.
        if inp in self._ref_args:
          if inp not in self._node_reversed_ref_inputs:
            self._node_reversed_ref_inputs[inp] = []
          self._node_reversed_ref_inputs[inp].append(node)

    for node in self._node_ctrl_inputs:
      ctrl_inputs = self._node_ctrl_inputs[node]
      for ctrl_inp in ctrl_inputs:
        if ctrl_inp in self._copy_send_nodes:
          continue

        if ctrl_inp not in self._node_ctrl_recipients:
          self._node_ctrl_recipients[ctrl_inp] = []
        self._node_ctrl_recipients[ctrl_inp].append(node)

  def _prune_nodes_from_input_and_recipient_maps(self, nodes_to_prune):
    """Prune nodes out of input and recipient maps.

    Args:
      nodes_to_prune: (`list` of `str`) Names of the nodes to be pruned.

    Raises:
      KeyError: If any of the names is absent from one of the maps.
    """
    for node in nodes_to_prune:
      del self._node_inputs[node]
      del self._node_ctrl_inputs[node]
      del self._node_recipients[node]
      del self._node_ctrl_recipients[node]

  def _reconstruct_non_debug_graph_def(self):
    """Reconstruct non-debug GraphDef.

    Non-debug GraphDef means the original GraphDef without the Copy* and Debug
    nodes inserted by the debugger. The result is cached, so the
    reconstruction happens at most once per `DebugGraph` object.
    """
    # Explicit `None` check: the cache state should not depend on the
    # truthiness of a proto message.
    if self._non_debug_graph_def is not None:
      return

    self._non_debug_graph_def = graph_pb2.GraphDef()
    for node in self._debug_graph_def.node:
      if is_copy_node(node.name) or is_debug_node(node.name):
        continue

      new_node = self._non_debug_graph_def.node.add()
      new_node.CopyFrom(node)

      # Redo the list of inputs, because in _debug_graph_def, the list can
      # consist of Copy* and Debug* nodes inserted by the debugger. Those will
      # be replaced with the original inputs here.
      del new_node.input[:]
      for inp in self._node_inputs[node.name]:
        new_node.input.append(inp)
      for ctrl_inp in self._node_ctrl_inputs[node.name]:
        new_node.input.append("^" + ctrl_inp)

  @property
  def device_name(self):
    """Name of the device, as passed to or inferred by the constructor."""
    return self._device_name

  @property
  def debug_graph_def(self):
    """The debugger-decorated GraphDef."""
    return self._debug_graph_def

  @property
  def non_debug_graph_def(self):
    """The GraphDef without the Copy* and Debug* nodes added by the debugger."""
    self._reconstruct_non_debug_graph_def()
    return self._non_debug_graph_def

  @property
  def node_devices(self):
    """Map from node name to the set of device names of the node."""
    return self._node_devices

  @property
  def node_op_types(self):
    """Map from node name to the op type of the node."""
    return self._node_op_types

  @property
  def node_attributes(self):
    """Map from node name to the `attr` field of the node."""
    return self._node_attributes

  @property
  def node_inputs(self):
    """Map from node name to the list of its non-control inputs."""
    return self._node_inputs

  @property
  def node_ctrl_inputs(self):
    """Map from node name to the list of its control inputs."""
    return self._node_ctrl_inputs

  @property
  def node_reversed_ref_inputs(self):
    """Map from node name to the recipients recorded as reversed ref edges."""
    return self._node_reversed_ref_inputs

  @property
  def node_recipients(self):
    """Map from node name to the recipients of its non-control outputs."""
    return self._node_recipients

  @property
  def node_ctrl_recipients(self):
    """Map from node name to the recipients of its control edges."""
    return self._node_ctrl_recipients


def reconstruct_non_debug_graph_def(debug_graph_def):
  """Reconstruct original (non-debugger-decorated) partition GraphDef.

  This method strips the input `tf.GraphDef` of the Copy* and Debug*-type nodes
  inserted by the debugger.

  The reconstructed partition graph is identical to the original (i.e.,
    non-debugger-decorated) partition graph except in the following respects:
      1) The exact names of the runtime-inserted internal nodes may differ.
         These include _Send, _Recv, _HostSend, _HostRecv, _Retval ops.
      2) As a consequence of 1, the nodes that receive input directly from such
         send- and recv-type ops will have different input names.
      3) The parallel_iteration attribute of while-loop Enter ops are set to 1.

  Args:
    debug_graph_def: The debugger-decorated `tf.GraphDef`, with the
      debugger-inserted Copy* and Debug* nodes.

  Returns:
    The reconstructed `tf.GraphDef` stripped of the debugger-inserted nodes.
  """
  return DebugGraph(debug_graph_def).non_debug_graph_def