1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Classes and functions to handle debug-dump data of TensorFlow Debugger.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import glob 23import json 24import os 25import platform 26import re 27 28import numpy as np 29import six 30 31from tensorflow.core.framework import graph_pb2 32from tensorflow.core.framework import types_pb2 33from tensorflow.core.util import event_pb2 34from tensorflow.python.debug.lib import debug_graphs 35from tensorflow.python.framework import tensor_util 36from tensorflow.python.platform import gfile 37from tensorflow.python.platform import tf_logging as logging 38from tensorflow.python.util import compat 39 40 41# TODO(cais): Tie these string constants in with C++? 42METADATA_FILE_PREFIX = "_tfdbg_" 43CORE_METADATA_TAG = "core_metadata_" 44GRAPH_FILE_TAG = "graph_" 45DEVICE_TAG = "device_" 46HASH_TAG = "hash" 47 48FETCHES_INFO_FILE_TAG = "fetches_info_" 49FEED_KEYS_INFO_FILE_TAG = "feed_keys_info_" 50 51 52def _glob(glob_pattern): 53 if platform.system() == "Windows": 54 return glob.glob(glob_pattern) 55 else: 56 return gfile.Glob(glob_pattern) 57 58 59class InconvertibleTensorProto(object): 60 """Represents a TensorProto that cannot be converted to np.ndarray.""" 61 62 def __init__(self, tensor_proto, initialized=True): 63 """Constructor. 64 65 Args: 66 tensor_proto: the `TensorProto` object that cannot be represented as a 67 `np.ndarray` object. 68 initialized: (`bool`) whether the Tensor is initialized. 69 """ 70 self._tensor_proto = tensor_proto 71 self._initialized = initialized 72 73 def __str__(self): 74 output = "" if self._initialized else "Uninitialized tensor:\n" 75 output += str(self._tensor_proto) 76 return output 77 78 @property 79 def initialized(self): 80 return self._initialized 81 82 83def load_tensor_from_event_file(event_file_path): 84 """Load a tensor from an event file. 85 86 Assumes that the event file contains a `Event` protobuf and the `Event` 87 protobuf contains a `Tensor` value. 88 89 Args: 90 event_file_path: (`str`) path to the event file. 91 92 Returns: 93 The tensor value loaded from the event file, as a `numpy.ndarray`. For 94 uninitialized Tensors, returns `None`. For Tensors of data types that 95 cannot be converted to `numpy.ndarray` (e.g., `tf.resource`), return 96 `None`. 97 """ 98 99 event = event_pb2.Event() 100 with gfile.Open(event_file_path, "rb") as f: 101 event.ParseFromString(f.read()) 102 return load_tensor_from_event(event) 103 104 105def load_tensor_from_event(event): 106 """Load a tensor from an Event proto. 107 108 Args: 109 event: The Event proto, assumed to hold a tensor value in its 110 summary.value[0] field. 111 112 Returns: 113 The tensor value loaded from the event file, as a `numpy.ndarray`, if 114 representation of the tensor value by a `numpy.ndarray` is possible. 115 For uninitialized Tensors, returns `None`. For Tensors of data types that 116 cannot be represented as `numpy.ndarray` (e.g., `tf.resource`), return 117 the `TensorProto` protobuf object without converting it to a 118 `numpy.ndarray`. 119 """ 120 121 tensor_proto = event.summary.value[0].tensor 122 shape = tensor_util.TensorShapeProtoToList(tensor_proto.tensor_shape) 123 num_elements = 1 124 for shape_dim in shape: 125 num_elements *= shape_dim 126 127 if tensor_proto.tensor_content or tensor_proto.string_val or not num_elements: 128 # Initialized tensor or empty tensor. 129 if tensor_proto.dtype == types_pb2.DT_RESOURCE: 130 tensor_value = InconvertibleTensorProto(tensor_proto) 131 else: 132 try: 133 tensor_value = tensor_util.MakeNdarray(tensor_proto) 134 except KeyError: 135 tensor_value = InconvertibleTensorProto(tensor_proto) 136 else: 137 # Uninitialized tensor or tensor of unconvertible data type. 138 tensor_value = InconvertibleTensorProto(tensor_proto, False) 139 140 return tensor_value 141 142 143def _load_graph_def_from_event_file(event_file_path): 144 event = event_pb2.Event() 145 with gfile.Open(event_file_path, "rb") as f: 146 event.ParseFromString(f.read()) 147 148 return graph_pb2.GraphDef.FromString(event.graph_def) 149 150 151def _load_log_message_from_event_file(event_file_path): 152 event = event_pb2.Event() 153 with gfile.Open(event_file_path, "rb") as f: 154 event.ParseFromString(f.read()) 155 156 return event.log_message.message 157 158 159def _is_graph_file(file_name): 160 return file_name.startswith(METADATA_FILE_PREFIX + GRAPH_FILE_TAG) 161 162 163def _is_run_fetches_info_file(file_name): 164 return file_name == METADATA_FILE_PREFIX + FETCHES_INFO_FILE_TAG 165 166 167def _is_run_feed_keys_info_file(file_name): 168 return file_name == METADATA_FILE_PREFIX + FEED_KEYS_INFO_FILE_TAG 169 170 171def _get_tensor_name(node_name, output_slot): 172 """Get tensor name given node name and output slot index. 173 174 Args: 175 node_name: Name of the node that outputs the tensor, as a string. 176 output_slot: Output slot index of the tensor, as an integer. 177 178 Returns: 179 Name of the tensor, as a string. 180 """ 181 182 return "%s:%d" % (node_name, output_slot) 183 184 185def _get_tensor_watch_key(node_name, output_slot, debug_op): 186 """Get the string representation of a debug watch on a tensor. 187 188 Args: 189 node_name: Name of the node by which the watched tensor is produced, as a 190 string. 191 output_slot: Output slot index of the tensor, as an integer. 192 debug_op: Name of the debug op that is used to watch the tensor, as a 193 string. 194 195 Returns: 196 A string representing the debug watch on the tensor (i.e., the "watch 197 key"). 198 """ 199 return "%s:%s" % (_get_tensor_name(node_name, output_slot), debug_op) 200 201 202def has_inf_or_nan(datum, tensor): 203 """A predicate for whether a tensor consists of any bad numerical values. 204 205 This predicate is common enough to merit definition in this module. 206 Bad numerical values include `nan`s and `inf`s. 207 The signature of this function follows the requirement of the method 208 `DebugDumpDir.find()`. 209 210 Args: 211 datum: (`DebugTensorDatum`) Datum metadata. 212 tensor: (`numpy.ndarray` or None) Value of the tensor. None represents 213 an uninitialized tensor. 214 215 Returns: 216 (`bool`) True if and only if tensor consists of any nan or inf values. 217 """ 218 219 _ = datum # Datum metadata is unused in this predicate. 220 221 if isinstance(tensor, InconvertibleTensorProto): 222 # Uninitialized tensor doesn't have bad numerical values. 223 # Also return False for data types that cannot be represented as numpy 224 # arrays. 225 return False 226 elif (np.issubdtype(tensor.dtype, np.floating) or 227 np.issubdtype(tensor.dtype, np.complex) or 228 np.issubdtype(tensor.dtype, np.integer)): 229 return np.any(np.isnan(tensor)) or np.any(np.isinf(tensor)) 230 else: 231 return False 232 233 234_CoreMetadata = collections.namedtuple("CoreMetadata", [ 235 "global_step", "session_run_index", "executor_step_index", "input_names", 236 "output_names", "target_nodes" 237]) 238 239 240def extract_core_metadata_from_event_proto(event): 241 json_metadata = json.loads(event.log_message.message) 242 return _CoreMetadata(json_metadata["global_step"], 243 json_metadata["session_run_index"], 244 json_metadata["executor_step_index"], 245 json_metadata["input_names"], 246 json_metadata["output_names"], 247 json_metadata["target_nodes"]) 248 249 250def device_name_to_device_path(device_name): 251 """Convert device name to device path.""" 252 device_name_items = compat.as_text(device_name).split("/") 253 device_name_items = [item.replace(":", "_") for item in device_name_items] 254 return METADATA_FILE_PREFIX + DEVICE_TAG + ",".join(device_name_items) 255 256 257def device_path_to_device_name(device_dir): 258 """Parse device name from device path. 259 260 Args: 261 device_dir: (str) a directory name for the device. 262 263 Returns: 264 (str) parsed device name. 265 """ 266 path_items = os.path.basename(device_dir)[ 267 len(METADATA_FILE_PREFIX) + len(DEVICE_TAG):].split(",") 268 return "/".join([ 269 path_item.replace("device_", "device:").replace("_", ":", 1) 270 for path_item in path_items]) 271 272 273class DebugTensorDatum(object): 274 """A single tensor dumped by TensorFlow Debugger (tfdbg). 275 276 Contains metadata about the dumped tensor, including `timestamp`, 277 `node_name`, `output_slot`, `debug_op`, and path to the dump file 278 (`file_path`). 279 280 This type does not hold the generally space-expensive tensor value (numpy 281 array). Instead, it points to the file from which the tensor value can be 282 loaded (with the `get_tensor` method) if needed. 283 """ 284 285 def __init__(self, dump_root, debug_dump_rel_path): 286 """`DebugTensorDatum` constructor. 287 288 Args: 289 dump_root: (`str`) Debug dump root directory. This path should not include 290 the path component that represents the device name (see also below). 291 debug_dump_rel_path: (`str`) Path to a debug dump file, relative to the 292 `dump_root`. The first item of this relative path is assumed to be 293 a path representing the name of the device that the Tensor belongs to. 294 See `device_path_to_device_name` for more details on the device path. 295 For example, suppose the debug dump root 296 directory is `/tmp/tfdbg_1` and the dump file is at 297 `/tmp/tfdbg_1/<device_path>/>ns_1/node_a_0_DebugIdentity_123456789`, 298 then the value of the debug_dump_rel_path should be 299 `<device_path>/ns_1/node_a_0_DebugIdentity_1234456789`. 300 301 Raises: 302 ValueError: If the base file name of the dump file does not conform to 303 the dump file naming pattern: 304 `node_name`_`output_slot`_`debug_op`_`timestamp` 305 """ 306 307 path_components = os.path.normpath(debug_dump_rel_path).split(os.sep) 308 self._device_name = device_path_to_device_name(path_components[0]) 309 base = path_components[-1] 310 if base.count("_") < 3: 311 raise ValueError( 312 "Dump file path does not conform to the naming pattern: %s" % base) 313 314 self._extended_timestamp = base.split("_")[-1] 315 # It may include an index suffix at the end if file path collision happened 316 # due to identical timestamps. 317 if "-" in self._extended_timestamp: 318 self._timestamp = int( 319 self._extended_timestamp[:self._extended_timestamp.find("-")]) 320 else: 321 self._timestamp = int(self._extended_timestamp) 322 323 self._debug_op = base.split("_")[-2] 324 self._output_slot = int(base.split("_")[-3]) 325 326 node_base_name = "_".join(base.split("_")[:-3]) 327 self._node_name = "/".join(path_components[1:-1] + [node_base_name]) 328 329 self._file_path = os.path.join(dump_root, debug_dump_rel_path) 330 self._dump_size_bytes = (gfile.Stat(self._file_path).length if 331 gfile.Exists(self._file_path) else None) 332 333 def __str__(self): 334 return "{DebugTensorDatum (%s) %s:%d @ %s @ %d}" % (self.device_name, 335 self.node_name, 336 self.output_slot, 337 self.debug_op, 338 self.timestamp) 339 340 def __repr__(self): 341 return self.__str__() 342 343 def get_tensor(self): 344 """Get tensor from the dump (`Event`) file. 345 346 Returns: 347 The tensor loaded from the dump (`Event`) file. 348 """ 349 350 return load_tensor_from_event_file(self.file_path) 351 352 # TODO(cais): Add time unit suffix to timestamp and t0 (us). 353 @property 354 def timestamp(self): 355 """Timestamp of when this tensor value was dumped. 356 357 Returns: 358 (`int`) The timestamp in microseconds. 359 """ 360 361 return self._timestamp 362 363 @property 364 def extended_timestamp(self): 365 """Extended timestamp, possibly with an index suffix. 366 367 The index suffix, e.g., "-1", is for disambiguating multiple dumps of the 368 same tensor with the same timestamp, which can occur if the dumping events 369 are spaced by shorter than the temporal resolution of the timestamps. 370 371 Returns: 372 (`str`) The extended timestamp. 373 """ 374 375 return self._extended_timestamp 376 377 @property 378 def debug_op(self): 379 """Name of the debug op. 380 381 Returns: 382 (`str`) debug op name (e.g., `DebugIdentity`). 383 """ 384 385 return self._debug_op 386 387 @property 388 def device_name(self): 389 """Name of the device that the tensor belongs to. 390 391 Returns: 392 (`str`) device name. 393 """ 394 395 return self._device_name 396 397 @property 398 def node_name(self): 399 """Name of the node from which the tensor value was dumped. 400 401 Returns: 402 (`str`) name of the node watched by the debug op. 403 """ 404 405 return self._node_name 406 407 @property 408 def output_slot(self): 409 """Output slot index from which the tensor value was dumped. 410 411 Returns: 412 (`int`) output slot index watched by the debug op. 413 """ 414 415 return self._output_slot 416 417 @property 418 def tensor_name(self): 419 """Name of the tensor watched by the debug op. 420 421 Returns: 422 (`str`) `Tensor` name, in the form of `node_name`:`output_slot` 423 """ 424 425 return _get_tensor_name(self.node_name, self.output_slot) 426 427 @property 428 def watch_key(self): 429 """Watch key identities a debug watch on a tensor. 430 431 Returns: 432 (`str`) A watch key, in the form of `tensor_name`:`debug_op`. 433 """ 434 435 return _get_tensor_watch_key(self.node_name, self.output_slot, 436 self.debug_op) 437 438 @property 439 def file_path(self): 440 """Path to the file which stores the value of the dumped tensor.""" 441 442 return self._file_path 443 444 @property 445 def dump_size_bytes(self): 446 """Size of the dump file. 447 448 Unit: byte. 449 450 Returns: 451 If the dump file exists, size of the dump file, in bytes. 452 If the dump file does not exist, None. 453 """ 454 455 return self._dump_size_bytes 456 457 458class WatchKeyDoesNotExistInDebugDumpDirError(ValueError): 459 pass 460 461 462class DebugDumpDir(object): 463 """Data set from a debug-dump directory on filesystem. 464 465 An instance of `DebugDumpDir` contains all `DebugTensorDatum` instances 466 in a tfdbg dump root directory. 467 """ 468 469 def __init__(self, dump_root, partition_graphs=None, validate=True): 470 """`DebugDumpDir` constructor. 471 472 Args: 473 dump_root: (`str`) path to the dump root directory. 474 partition_graphs: A repeated field of GraphDefs representing the 475 partition graphs executed by the TensorFlow runtime. 476 validate: (`bool`) whether the dump files are to be validated against the 477 partition graphs. 478 479 Raises: 480 IOError: If dump_root does not exist as a directory. 481 ValueError: If more than one core metadata file is found under the dump 482 root directory. 483 """ 484 485 if not gfile.IsDirectory(dump_root): 486 raise IOError("Dump root directory %s does not exist" % dump_root) 487 488 self._core_metadata = [] 489 490 # Find the list of devices. 491 self._dump_root = dump_root 492 493 self._load_core_metadata() 494 self._load_fetches_info() 495 self._load_feeds_info() 496 self._load_all_device_dumps(partition_graphs, validate) 497 498 self._python_graph = None 499 500 def _load_all_device_dumps(self, partition_graphs, validate): 501 """Load the dump data for all devices.""" 502 device_dirs = _glob(os.path.join( 503 self._dump_root, METADATA_FILE_PREFIX + DEVICE_TAG + "*")) 504 505 self._device_names = [] 506 self._t0s = {} 507 self._dump_tensor_data = {} 508 self._dump_graph_file_paths = {} 509 self._debug_watches = {} 510 self._watch_key_to_devices = {} 511 self._watch_key_to_datum = {} 512 self._watch_key_to_rel_time = {} 513 self._watch_key_to_dump_size_bytes = {} 514 for device_dir in device_dirs: 515 device_name = device_path_to_device_name(device_dir) 516 self._device_names.append(device_name) 517 self._load_device_dumps(device_name, device_dir) 518 self._load_partition_graphs(partition_graphs, validate) 519 self._calculate_t0() 520 521 for device_name in self._device_names: 522 self._create_tensor_watch_maps(device_name) 523 524 def _load_device_dumps(self, device_name, device_root): 525 """Load `DebugTensorDatum` instances from the dump root of a given device. 526 527 Populates a map {device_name: a list of `DebugTensorDatum`}, where the list 528 is sorted by ascending timestamp. 529 530 This sorting order reflects the order in which the TensorFlow executor 531 processed the nodes of the graph. It is (one of many possible) topological 532 sort of the nodes. This is useful for displaying tensors in the debugger 533 frontend as well as for the use case in which the user wants to find a 534 "culprit tensor", i.e., the first tensor in the graph that exhibits certain 535 problematic properties, i.e., all zero values, or bad numerical values such 536 as nan and inf. 537 538 In addition, creates a map from node name to debug watches. In this Map, 539 the key is the watched node name; the value is a dictionary. 540 Of this dictionary, the key is the watched_output_slot. 541 542 This method attempts to load the debug watches from the tensor dump files 543 first, before loading the full set of debug watches from the partition 544 graphs as done later. This is necessary because sometimes the partition 545 graphs may not be available, e.g., when the run errors out. 546 547 Args: 548 device_name: (`str`) name of the device. 549 device_root: (`str`) dump root directory of the given device. 550 551 Raises: 552 ValueError: If GraphDef for the device is not available. 553 """ 554 555 self._dump_tensor_data[device_name] = [] 556 self._debug_watches[device_name] = collections.defaultdict( 557 lambda: collections.defaultdict(set)) 558 559 for root, _, files in gfile.Walk(device_root): 560 for f in files: 561 if _is_graph_file(f): 562 self._dump_graph_file_paths[device_name] = os.path.join(root, f) 563 else: 564 datum = self._dump_file_name_to_datum(root, f) 565 self._dump_tensor_data[device_name].append(datum) 566 self._debug_watches[device_name][datum.node_name][ 567 datum.output_slot].add(datum.debug_op) 568 569 self._dump_tensor_data[device_name] = sorted( 570 self._dump_tensor_data[device_name], 571 key=lambda x: x.extended_timestamp) 572 573 if self._dump_tensor_data[device_name]: 574 self._t0s[device_name] = self._dump_tensor_data[device_name][0].timestamp 575 else: 576 self._t0s[device_name] = None 577 578 def _calculate_t0(self): 579 """Calculate the first timestamp across all devices.""" 580 t0s = [t0 for t0 in six.itervalues(self._t0s) if t0 is not None] 581 self._t0 = min(t0s) if t0s else None 582 583 def _load_core_metadata(self): 584 core_metadata_files = _glob(os.path.join( 585 self._dump_root, METADATA_FILE_PREFIX + CORE_METADATA_TAG + "*")) 586 for core_metadata_file in core_metadata_files: 587 with gfile.Open(core_metadata_file, "rb") as f: 588 event = event_pb2.Event() 589 event.ParseFromString(f.read()) 590 self._core_metadata.append( 591 extract_core_metadata_from_event_proto(event)) 592 593 def _load_fetches_info(self): 594 fetches_info_files = _glob(os.path.join( 595 self._dump_root, METADATA_FILE_PREFIX + FETCHES_INFO_FILE_TAG + "*")) 596 self._run_fetches_info = [] 597 for fetches_info_file in fetches_info_files: 598 self._run_fetches_info.append( 599 _load_log_message_from_event_file(fetches_info_file)) 600 601 def _load_feeds_info(self): 602 feeds_info_files = _glob(os.path.join( 603 self._dump_root, METADATA_FILE_PREFIX + FEED_KEYS_INFO_FILE_TAG + "*")) 604 self._run_feed_keys_info = [] 605 for feeds_info_file in feeds_info_files: 606 self._run_feed_keys_info.append( 607 _load_log_message_from_event_file(feeds_info_file)) 608 609 def _dump_file_name_to_datum(self, dir_name, file_name): 610 """Obtain a DebugTensorDatum from the directory and file name. 611 612 Args: 613 dir_name: (`str`) Name of the directory in which the dump file resides. 614 file_name: (`str`) Base name of the dump file. 615 616 Returns: 617 (`DebugTensorDatum`) The `DebugTensorDatum` loaded from the dump file. 618 """ 619 620 # Calculate the relative path of the dump file with respect to the root. 621 debug_dump_rel_path = os.path.join( 622 os.path.relpath(dir_name, self._dump_root), file_name) 623 return DebugTensorDatum(self._dump_root, debug_dump_rel_path) 624 625 def _create_tensor_watch_maps(self, device_name): 626 """Create maps from tensor watch keys to datum and to timestamps. 627 628 Create a map from watch key (tensor name + debug op) to `DebugTensorDatum` 629 item. Also make a map from watch key to relative timestamp. 630 "relative" means (absolute timestamp - t0). 631 632 Args: 633 device_name: (str) name of the device. 634 """ 635 636 self._watch_key_to_datum[device_name] = {} 637 self._watch_key_to_rel_time[device_name] = {} 638 self._watch_key_to_dump_size_bytes[device_name] = {} 639 for datum in self._dump_tensor_data[device_name]: 640 if datum.watch_key not in self._watch_key_to_devices: 641 self._watch_key_to_devices[datum.watch_key] = {device_name} 642 else: 643 self._watch_key_to_devices[datum.watch_key].add(device_name) 644 645 if datum.watch_key not in self._watch_key_to_datum[device_name]: 646 self._watch_key_to_datum[device_name][datum.watch_key] = [datum] 647 self._watch_key_to_rel_time[device_name][datum.watch_key] = [ 648 datum.timestamp - self._t0] 649 self._watch_key_to_dump_size_bytes[device_name][datum.watch_key] = [ 650 datum.dump_size_bytes] 651 else: 652 self._watch_key_to_datum[device_name][datum.watch_key].append(datum) 653 self._watch_key_to_rel_time[device_name][datum.watch_key].append( 654 datum.timestamp - self._t0) 655 self._watch_key_to_dump_size_bytes[device_name][datum.watch_key].append( 656 datum.dump_size_bytes) 657 658 def set_python_graph(self, python_graph): 659 """Provide Python `Graph` object to the wrapper. 660 661 Unlike the partition graphs, which are protobuf `GraphDef` objects, `Graph` 662 is a Python object and carries additional information such as the traceback 663 of the construction of the nodes in the graph. 664 665 Args: 666 python_graph: (ops.Graph) The Python Graph object. 667 """ 668 669 self._python_graph = python_graph 670 self._node_traceback = {} 671 if self._python_graph: 672 for op in self._python_graph.get_operations(): 673 self._node_traceback[op.name] = tuple(map(tuple, op.traceback)) 674 675 @property 676 def python_graph(self): 677 """Get the Python graph. 678 679 Returns: 680 If the Python graph has been set, returns a `tf.Graph` object. Otherwise, 681 returns None. 682 """ 683 684 return self._python_graph 685 686 @property 687 def core_metadata(self): 688 """Metadata about the `Session.run()` call from the core runtime. 689 690 Of the three counters available in the return value, `global_step` is 691 supplied by the caller of the debugged `Session.run()`, while 692 `session_run_index` and `executor_step_index` are determined by the state 693 of the core runtime, automatically. For the same fetch list, feed keys and 694 debug tensor watch options, the same executor will be used and 695 `executor_step_index` should increase by one at a time. However, runs with 696 different fetch lists, feed keys and debug_tensor watch options that all 697 share the same `Session` object can lead to gaps in `session_run_index`. 698 699 Returns: 700 If core metadata are loaded, a `namedtuple` with the fields: 701 `global_step`: A global step count supplied by the caller of 702 `Session.run()`. It is optional to the caller. If the caller did not 703 supply this parameter, its value will be -1. 704 `session_run_index`: A sorted index for Run() calls to the underlying 705 TensorFlow `Session` object. 706 `executor_step_index`: A counter for invocations of a given runtime 707 executor. The same executor is re-used for the same fetched tensors, 708 target nodes, input feed keys and debug tensor watch options. 709 `input_names`: Names of the input (feed) Tensors. 710 `output_names`: Names of the output (fetched) Tensors. 711 `target_nodes`: Names of the target nodes. 712 If the core metadata have not been loaded, `None`. 713 If more than one core metadata files exist, return a list of the 714 `nametuple` described above. 715 """ 716 717 output = self._core_metadata 718 return output[0] if len(output) == 1 else output 719 720 @property 721 def dumped_tensor_data(self): 722 """Retrieve dumped tensor data.""" 723 if len(self.devices()) == 1: 724 return self._dump_tensor_data[self.devices()[0]] 725 else: 726 all_devices_data = six.itervalues(self._dump_tensor_data) 727 data = [] 728 for device_data in all_devices_data: 729 data.extend(device_data) 730 return sorted(data, key=lambda x: x.extended_timestamp) 731 732 @property 733 def t0(self): 734 """Absolute timestamp of the first dumped tensor across all devices. 735 736 Returns: 737 (`int`) absolute timestamp of the first dumped tensor, in microseconds. 738 """ 739 return self._t0 740 741 @property 742 def size(self): 743 """Total number of dumped tensors in the dump root directory. 744 745 Returns: 746 (`int`) The total number of dumped tensors in the dump root directory. 747 """ 748 return sum(len(self._dump_tensor_data[device_name]) 749 for device_name in self._dump_tensor_data) 750 751 def _load_partition_graphs(self, client_partition_graphs, validate): 752 """Load and process partition graphs. 753 754 Load the graphs; parse the input and control input structure; obtain the 755 device and op type of each node; remove the Copy and debug ops inserted 756 by the debugger. The gathered information can be used to validate the 757 tensor dumps. 758 759 Args: 760 client_partition_graphs: A repeated field of GraphDefs representing the 761 partition graphs executed by the TensorFlow runtime, from the Python 762 client. These partition graphs are used only if partition graphs 763 cannot be loaded from the dump directory on the file system. 764 validate: (`bool`) Whether the dump files are to be validated against the 765 partition graphs. 766 767 Raises: 768 ValueError: If the partition GraphDef of one or more devices fail to be 769 loaded. 770 """ 771 self._debug_graphs = {} 772 self._node_devices = {} 773 774 partition_graphs_and_device_names = [] 775 for device_name in self._device_names: 776 partition_graph = None 777 if device_name in self._dump_graph_file_paths: 778 partition_graph = _load_graph_def_from_event_file( 779 self._dump_graph_file_paths[device_name]) 780 else: 781 logging.warn( 782 "Failed to load partition graphs for device %s from disk. " 783 "As a fallback, the client graphs will be used. This " 784 "may cause mismatches in device names." % device_name) 785 partition_graph = self._find_partition_graph(client_partition_graphs, 786 device_name) 787 788 if partition_graph: 789 partition_graphs_and_device_names.append((partition_graph, 790 device_name)) 791 792 for partition_graph, maybe_device_name in partition_graphs_and_device_names: 793 debug_graph = debug_graphs.DebugGraph(partition_graph, 794 device_name=maybe_device_name) 795 self._debug_graphs[debug_graph.device_name] = debug_graph 796 self._collect_node_devices(debug_graph) 797 798 if validate and debug_graph.device_name in self._dump_tensor_data: 799 self._validate_dump_with_graphs(debug_graph.device_name) 800 801 def _find_partition_graph(self, partition_graphs, device_name): 802 if partition_graphs is None: 803 return None 804 else: 805 for graph_def in partition_graphs: 806 for node_def in graph_def.node: 807 if node_def.device == device_name: 808 return graph_def 809 return None 810 811 def _collect_node_devices(self, debug_graph): 812 for node_name in debug_graph.node_devices: 813 if node_name in self._node_devices: 814 self._node_devices[node_name] = self._node_devices[node_name].union( 815 debug_graph.node_devices[node_name]) 816 else: 817 self._node_devices[node_name] = debug_graph.node_devices[node_name] 818 819 def _validate_dump_with_graphs(self, device_name): 820 """Validate the dumped tensor data against the partition graphs. 821 822 Only the watched nodes are validated by this method, because tfdbg allows 823 clients to watch only a subset of the nodes. 824 825 Args: 826 device_name: (`str`) device name. 827 828 Raises: 829 LookupError: If the partition graphs have not been loaded yet. 830 ValueError: If dumps contain node names not found in partition graph. 831 Or if the temporal order of the dump's timestamps violate the 832 input relations on the partition graphs. 833 """ 834 if not self._debug_graphs: 835 raise LookupError( 836 "No partition graphs loaded for device %s" % device_name) 837 debug_graph = self._debug_graphs[device_name] 838 839 # Verify that the node names in the dump data are all present in the 840 # partition graphs. 841 for datum in self._dump_tensor_data[device_name]: 842 if datum.node_name not in debug_graph.node_inputs: 843 raise ValueError("Node name '%s' is not found in partition graphs of " 844 "device %s." % (datum.node_name, device_name)) 845 846 pending_inputs = {} 847 for node in debug_graph.node_inputs: 848 pending_inputs[node] = [] 849 inputs = debug_graph.node_inputs[node] 850 for inp in inputs: 851 inp_node = debug_graphs.get_node_name(inp) 852 inp_output_slot = debug_graphs.get_output_slot(inp) 853 # Inputs from Enter and NextIteration nodes are not validated because 854 # DebugNodeInserter::InsertNodes() in the debugger core skips creating 855 # control edges from debug ops watching these types of nodes. 856 if (inp_node in self._debug_watches[device_name] and 857 inp_output_slot in self._debug_watches[device_name][inp_node] and 858 debug_graph.node_op_types.get(inp) not in ( 859 "Enter", "NextIteration") and 860 (inp_node, inp_output_slot) not in pending_inputs[node]): 861 pending_inputs[node].append((inp_node, inp_output_slot)) 862 863 for i, datum in enumerate(self._dump_tensor_data[device_name]): 864 node = datum.node_name 865 slot = datum.output_slot 866 # In some cases (e.g., system clocks with insufficient precision), 867 # the upstream and downstream tensors may have identical timestamps, the 868 # following check examines this possibility and avoids raising an error if 869 # that is the case. 870 if not self._satisfied_at_timestamp( 871 device_name, pending_inputs[node], datum.timestamp, start_i=i + 1): 872 raise ValueError("Causality violated in timing relations of debug " 873 "dumps: %s (%d): " 874 "these input(s) are not satisfied: %s" % 875 (node, datum.timestamp, repr(pending_inputs[node]))) 876 877 recipients = debug_graph.node_recipients[node] 878 for recipient in recipients: 879 recipient_pending_inputs = pending_inputs[recipient] 880 if (node, slot) in recipient_pending_inputs: 881 if self.node_op_type(recipient) == "Merge": 882 # If this is a Merge op, we automatically clear the list because 883 # a Merge node only requires one of its two inputs. 884 del recipient_pending_inputs[:] 885 else: 886 del recipient_pending_inputs[ 887 recipient_pending_inputs.index((node, slot))] 888 889 def _satisfied_at_timestamp(self, device_name, pending, timestamp, start_i=0): 890 """Determine whether pending inputs are satisfied at given timestamp. 891 892 Note: This method mutates the input argument "pending". 893 894 Args: 895 device_name: (str) device name. 896 pending: A list of 2-tuple (node_name, output_slot): the dependencies to 897 check. 898 timestamp: (int) the timestamp in question. 899 start_i: (int) the index in self._dump_tensor_data to start searching for 900 the timestamp. 901 902 Returns: 903 (bool) Whether all the dependencies in pending are satisfied at the 904 timestamp. If pending is empty to begin with, return True. 905 """ 906 if not pending: 907 return True 908 909 for datum in self._dump_tensor_data[device_name][start_i:]: 910 if datum.timestamp > timestamp: 911 break 912 if (datum.timestamp == timestamp and 913 (datum.node_name, datum.output_slot) in pending): 914 pending.remove((datum.node_name, datum.output_slot)) 915 if not pending: 916 return True 917 918 return not pending 919 920 def loaded_partition_graphs(self): 921 """Test whether partition graphs have been loaded.""" 922 return bool(self._debug_graphs) 923 924 def partition_graphs(self): 925 """Get the partition graphs. 926 927 Returns: 928 Partition graphs as a list of GraphDef. 929 930 Raises: 931 LookupError: If no partition graphs have been loaded. 932 """ 933 if not self._debug_graphs: 934 raise LookupError("No partition graphs have been loaded.") 935 return [self._debug_graphs[key].debug_graph_def 936 for key in self._debug_graphs] 937 938 def reconstructed_non_debug_partition_graphs(self): 939 """Reconstruct partition graphs with the debugger-inserted ops stripped. 940 941 The reconstructed partition graphs are identical to the original (i.e., 942 non-debugger-decorated) partition graphs except in the following respects: 943 1) The exact names of the runtime-inserted internal nodes may differ. 944 These include _Send, _Recv, _HostSend, _HostRecv, _Retval ops. 945 2) As a consequence of 1, the nodes that receive input directly from such 946 send- and recv-type ops will have different input names. 947 3) The parallel_iteration attribute of while-loop Enter ops are set to 1. 948 949 Returns: 950 A dict mapping device names (`str`s) to reconstructed 951 `tf.compat.v1.GraphDef`s. 952 """ 953 non_debug_graphs = {} 954 for key in self._debug_graphs: 955 non_debug_graphs[key] = self._debug_graphs[key].non_debug_graph_def 956 return non_debug_graphs 957 958 @property 959 def run_fetches_info(self): 960 """Get a str representation of the fetches used in the Session.run() call. 961 962 Returns: 963 If the information is available from one `Session.run` call, a `str` 964 obtained from `repr(fetches)`. 965 If the information is available from multiple `Session.run` calls, a 966 `list` of `str` from `repr(fetches)`. 967 If the information is not available, `None`. 968 """ 969 970 output = self._run_fetches_info 971 return output[0] if len(output) == 1 else output 972 973 @property 974 def run_feed_keys_info(self): 975 """Get a str representation of the feed_dict used in the Session.run() call. 976 977 Returns: 978 If the information is available from one `Session.run` call, a `str` 979 obtained from `repr(feed_dict)`. 980 If the information is available from multiple `Session.run` calls, a 981 `list` of `str` obtained from `repr(feed_dict)`. 982 If the information is not available, `None`. 983 """ 984 985 output = self._run_feed_keys_info 986 return output[0] if len(output) == 1 else output 987 988 def _infer_device_name(self, device_name, node_name): 989 """Infer the device name given node name. 990 991 If device_name is provided (i.e., not None), it'll be simply returned right 992 away. 993 994 Args: 995 device_name: (str or None) name of the device. If None, will try to infer 996 the device name by looking at the available nodes. 997 node_name: (str) name of the node. 998 999 Returns: 1000 (str) Inferred name of the device, if available. 1001 1002 Raises: 1003 ValueError: If the node name does not exist on any of the available 1004 devices or if there are multiple devices that contain the node with 1005 the given name. 1006 """ 1007 if device_name is None: 1008 if node_name in self._node_devices: 1009 if len(self._node_devices[node_name]) == 1: 1010 return list(self._node_devices[node_name])[0] 1011 else: 1012 raise ValueError( 1013 "There are multiple (%d) devices with nodes named '%s' but " 1014 "device_name is not specified." % 1015 (len(self._node_devices[node_name]), node_name)) 1016 else: 1017 raise ValueError("None of the %d device(s) has a node named '%s'." % 1018 (len(self._device_names), node_name)) 1019 else: 1020 return device_name 1021 1022 def nodes(self, device_name=None): 1023 """Get a list of all nodes from the partition graphs. 1024 1025 Args: 1026 device_name: (`str`) name of device. If None, all nodes from all available 1027 devices will be included. 1028 1029 Returns: 1030 All nodes' names, as a list of str. 1031 1032 Raises: 1033 LookupError: If no partition graphs have been loaded. 1034 ValueError: If specified node name does not exist. 1035 """ 1036 if not self._debug_graphs: 1037 raise LookupError("No partition graphs have been loaded.") 1038 if device_name is None: 1039 nodes = [] 1040 for device_name in self._debug_graphs: 1041 nodes.extend(self._debug_graphs[device_name].node_inputs.keys()) 1042 return nodes 1043 else: 1044 if device_name not in self._debug_graphs: 1045 raise ValueError("Invalid device name: %s" % device_name) 1046 return self._debug_graphs[device_name].node_inputs.keys() 1047 1048 def node_attributes(self, node_name, device_name=None): 1049 """Get the attributes of a node. 1050 1051 Args: 1052 node_name: Name of the node in question. 1053 device_name: (`str`) name of the device. If there is only one device or if 1054 node_name exists on only one device, this argument is optional. 1055 1056 Returns: 1057 Attributes of the node. 1058 1059 Raises: 1060 LookupError: If no partition graphs have been loaded. 1061 """ 1062 if not self._debug_graphs: 1063 raise LookupError("No partition graphs have been loaded.") 1064 1065 device_name = self._infer_device_name(device_name, node_name) 1066 return self._debug_graphs[device_name].node_attributes[node_name] 1067 1068 def node_inputs(self, node_name, is_control=False, device_name=None): 1069 """Get the inputs of given node according to partition graphs. 1070 1071 Args: 1072 node_name: Name of the node. 1073 is_control: (`bool`) Whether control inputs, rather than non-control 1074 inputs, are to be returned. 1075 device_name: (`str`) name of the device. If there is only one device or if 1076 node_name exists on only one device, this argument is optional. 1077 1078 Returns: 1079 (`list` of `str`) inputs to the node, as a list of node names. 1080 1081 Raises: 1082 LookupError: If node inputs and control inputs have not been loaded 1083 from partition graphs yet. 1084 """ 1085 if not self._debug_graphs: 1086 raise LookupError( 1087 "Node inputs are not loaded from partition graphs yet.") 1088 1089 device_name = self._infer_device_name(device_name, node_name) 1090 if is_control: 1091 return self._debug_graphs[device_name].node_ctrl_inputs[node_name] 1092 else: 1093 return self._debug_graphs[device_name].node_inputs[node_name] 1094 1095 def transitive_inputs(self, 1096 node_name, 1097 include_control=True, 1098 include_reversed_ref=False, 1099 device_name=None,): 1100 """Get the transitive inputs of given node according to partition graphs. 1101 1102 Args: 1103 node_name: Name of the node. 1104 include_control: Include control inputs (True by default). 1105 include_reversed_ref: Whether a ref input, say from A to B, is to be also 1106 considered as an input from B to A. The rationale is that ref inputs 1107 generally let the recipient (e.g., B in this case) mutate the value of 1108 the source (e.g., A in this case). So the reverse direction of the ref 1109 edge reflects the direction of information flow. 1110 device_name: (`str`) name of the device. If there is only one device or if 1111 node_name exists on only one device, this argument is optional. 1112 1113 Returns: 1114 (`list` of `str`) all transitive inputs to the node, as a list of node 1115 names. 1116 1117 Raises: 1118 LookupError: If node inputs and control inputs have not been loaded 1119 from partition graphs yet. 1120 """ 1121 if not self._debug_graphs: 1122 raise LookupError( 1123 "Node inputs are not loaded from partition graphs yet.") 1124 1125 device_name = self._infer_device_name(device_name, node_name) 1126 1127 input_lists = [self._debug_graphs[device_name].node_inputs] 1128 if include_control: 1129 input_lists.append(self._debug_graphs[device_name].node_ctrl_inputs) 1130 if include_reversed_ref: 1131 input_lists.append( 1132 self._debug_graphs[device_name].node_reversed_ref_inputs) 1133 tracer = debug_graphs.DFSGraphTracer( 1134 input_lists, 1135 skip_node_names=self._get_merge_node_names(device_name)) 1136 tracer.trace(node_name) 1137 return tracer.inputs() 1138 1139 def _get_merge_node_names(self, device_name): 1140 """Lazily get a list of Merge nodes on a given device.""" 1141 if device_name not in self._device_names: 1142 raise ValueError("Invalid device name: %s" % device_name) 1143 1144 if not hasattr(self, "_merge_node_names"): 1145 self._merge_node_names = {} 1146 if device_name not in self._merge_node_names: 1147 debug_graph = self._debug_graphs[device_name] 1148 self._merge_node_names[device_name] = [ 1149 node for node in debug_graph.node_op_types 1150 if debug_graph.node_op_types[node] == "Merge"] 1151 return self._merge_node_names[device_name] 1152 1153 def find_some_path(self, 1154 src_node_name, 1155 dst_node_name, 1156 include_control=True, 1157 include_reversed_ref=False, 1158 device_name=None): 1159 """Find a path between a source node and a destination node. 1160 1161 Limitation: the source and destination are required to be on the same 1162 device, i.e., this method does not yet take into account Send/Recv nodes 1163 across devices. 1164 1165 TODO(cais): Make this method work across device edges by tracing Send/Recv 1166 nodes. 1167 1168 Args: 1169 src_node_name: (`str`) name of the source node or name of an output tensor 1170 of the node. 1171 dst_node_name: (`str`) name of the destination node or name of an output 1172 tensor of the node. 1173 include_control: (`bool`) whrther control edges are considered in the 1174 graph tracing. 1175 include_reversed_ref: Whether a ref input, say from A to B, is to be also 1176 considered as an input from B to A. The rationale is that ref inputs 1177 generally let the recipient (e.g., B in this case) mutate the value of 1178 the source (e.g., A in this case). So the reverse direction of the ref 1179 edge reflects the direction of information flow. 1180 device_name: (`str`) name of the device. If there is only one device or if 1181 node_name exists on only one device, this argument is optional. 1182 1183 Returns: 1184 A path from the src_node_name to dst_node_name, as a `list` of `str`, if 1185 it exists. The list includes src_node_name as the first item and 1186 dst_node_name as the last. 1187 If such a path does not exist, `None`. 1188 1189 Raises: 1190 ValueError: If the source and destination nodes are not on the same 1191 device. 1192 """ 1193 src_device_name = self._infer_device_name(device_name, src_node_name) 1194 dst_device_name = self._infer_device_name(device_name, dst_node_name) 1195 1196 if src_device_name != dst_device_name: 1197 raise ValueError( 1198 "Source (%s) and destination (%s) are not on the same device: " 1199 "%s vs. %s" % (src_node_name, dst_node_name, src_device_name, 1200 dst_device_name)) 1201 1202 input_lists = [self._debug_graphs[dst_device_name].node_inputs] 1203 debug_graph = self._debug_graphs[dst_device_name] 1204 if include_control: 1205 input_lists.append(debug_graph.node_ctrl_inputs) 1206 if include_reversed_ref: 1207 input_lists.append(debug_graph.node_reversed_ref_inputs) 1208 tracer = debug_graphs.DFSGraphTracer( 1209 input_lists, 1210 skip_node_names=self._get_merge_node_names(dst_device_name), 1211 destination_node_name=src_node_name) 1212 # Here the value of destination_node_name is src_node_name, because we 1213 # are tracing the graph from output to its inputs (i.e., going backwards 1214 # on the graph). 1215 1216 try: 1217 tracer.trace(dst_node_name) 1218 except debug_graphs.GraphTracingReachedDestination: 1219 # Prune nodes not on the path. 1220 inputs = [dst_node_name] + tracer.inputs() 1221 depth_list = [0] + tracer.depth_list() 1222 1223 path = [] 1224 curr_depth = depth_list[-1] 1225 for inp, depth in zip(reversed(inputs), reversed(depth_list)): 1226 if depth == curr_depth: 1227 path.append(inp) 1228 curr_depth -= 1 1229 return path 1230 1231 def node_recipients(self, node_name, is_control=False, device_name=None): 1232 """Get recipient of the given node's output according to partition graphs. 1233 1234 Args: 1235 node_name: (`str`) name of the node. 1236 is_control: (`bool`) whether control outputs, rather than non-control 1237 outputs, are to be returned. 1238 device_name: (`str`) name of the device. If there is only one device or if 1239 node_name exists on only one device, this argument is optional. 1240 1241 Returns: 1242 (`list` of `str`) all inputs to the node, as a list of node names. 1243 1244 Raises: 1245 LookupError: If node inputs and control inputs have not been loaded 1246 from partition graphs yet. 1247 """ 1248 1249 if not self._debug_graphs: 1250 raise LookupError( 1251 "Node recipients are not loaded from partition graphs yet.") 1252 1253 device_name = self._infer_device_name(device_name, node_name) 1254 debug_graph = self._debug_graphs[device_name] 1255 if is_control: 1256 return debug_graph.node_ctrl_recipients[node_name] 1257 else: 1258 return debug_graph.node_recipients[node_name] 1259 1260 def devices(self): 1261 """Get the list of device names. 1262 1263 Returns: 1264 (`list` of `str`) names of the devices. 1265 """ 1266 return self._device_names 1267 1268 def node_exists(self, node_name, device_name=None): 1269 """Test if a node exists in the partition graphs. 1270 1271 Args: 1272 node_name: (`str`) name of the node to be checked. 1273 device_name: optional device name. If None, will search for the node 1274 on all available devices. Otherwise, search for the node only on 1275 the given device. 1276 1277 Returns: 1278 A boolean indicating whether the node exists. 1279 1280 Raises: 1281 LookupError: If no partition graphs have been loaded yet. 1282 ValueError: If device_name is specified but cannot be found. 1283 """ 1284 if not self._debug_graphs: 1285 raise LookupError( 1286 "Nodes have not been loaded from partition graphs yet.") 1287 1288 if (device_name is not None) and device_name not in self._debug_graphs: 1289 raise ValueError( 1290 "The specified device_name '%s' cannot be found." % device_name) 1291 1292 for _, debug_graph in self._debug_graphs.items(): 1293 if node_name in debug_graph.node_inputs: 1294 return True 1295 return False 1296 1297 def node_device(self, node_name): 1298 """Get the names of the devices that has nodes of the specified name. 1299 1300 Args: 1301 node_name: (`str`) name of the node. 1302 1303 Returns: 1304 (`str` or `list` of `str`) name of the device(s) on which the node of the 1305 given name is found. Returns a `str` if there is only one such device, 1306 otherwise return a `list` of `str`. 1307 1308 Raises: 1309 LookupError: If node inputs and control inputs have not been loaded 1310 from partition graphs yet. 1311 ValueError: If the node does not exist in partition graphs. 1312 """ 1313 if not self._debug_graphs: 1314 raise LookupError( 1315 "Node devices are not loaded from partition graphs yet.") 1316 1317 if node_name not in self._node_devices: 1318 raise ValueError("Node '%s' does not exist in partition graphs." % 1319 node_name) 1320 1321 output = list(self._node_devices[node_name]) 1322 return output[0] if len(output) == 1 else output 1323 1324 def node_op_type(self, node_name, device_name=None): 1325 """Get the op type of given node. 1326 1327 Args: 1328 node_name: (`str`) name of the node. 1329 device_name: (`str`) name of the device. If there is only one device or if 1330 node_name exists on only one device, this argument is optional. 1331 1332 Returns: 1333 (`str`) op type of the node. 1334 1335 Raises: 1336 LookupError: If node op types have not been loaded 1337 from partition graphs yet. 1338 """ 1339 if not self._debug_graphs: 1340 raise LookupError( 1341 "Node op types are not loaded from partition graphs yet.") 1342 1343 device_name = self._infer_device_name(device_name, node_name) 1344 return self._debug_graphs[device_name].node_op_types[node_name] 1345 1346 def debug_watch_keys(self, node_name, device_name=None): 1347 """Get all tensor watch keys of given node according to partition graphs. 1348 1349 Args: 1350 node_name: (`str`) name of the node. 1351 device_name: (`str`) name of the device. If there is only one device or if 1352 node_name exists on only one device, this argument is optional. 1353 1354 Returns: 1355 (`list` of `str`) all debug tensor watch keys. Returns an empty list if 1356 the node name does not correspond to any debug watch keys. 1357 1358 Raises: 1359 `LookupError`: If debug watch information has not been loaded from 1360 partition graphs yet. 1361 """ 1362 1363 try: 1364 device_name = self._infer_device_name(device_name, node_name) 1365 except ValueError: 1366 return [] 1367 1368 if node_name not in self._debug_watches[device_name]: 1369 return [] 1370 1371 watch_keys = [] 1372 for watched_slot in self._debug_watches[device_name][node_name]: 1373 debug_ops = self._debug_watches[device_name][node_name][watched_slot] 1374 for debug_op in debug_ops: 1375 watch_keys.append( 1376 _get_tensor_watch_key(node_name, watched_slot, debug_op)) 1377 1378 return watch_keys 1379 1380 def watch_key_to_data(self, debug_watch_key, device_name=None): 1381 """Get all `DebugTensorDatum` instances corresponding to a debug watch key. 1382 1383 Args: 1384 debug_watch_key: (`str`) debug watch key. 1385 device_name: (`str`) name of the device. If there is only one device or if 1386 the specified debug_watch_key exists on only one device, this argument 1387 is optional. 1388 1389 Returns: 1390 A list of `DebugTensorDatum` instances that correspond to the debug watch 1391 key. If the watch key does not exist, returns an empty list. 1392 1393 Raises: 1394 ValueError: If there are multiple devices that have the debug_watch_key, 1395 but device_name is not specified. 1396 """ 1397 if device_name is None: 1398 matching_device_names = [ 1399 name for name in self._watch_key_to_datum 1400 if debug_watch_key in self._watch_key_to_datum[name]] 1401 if not matching_device_names: 1402 return [] 1403 elif len(matching_device_names) == 1: 1404 device_name = matching_device_names[0] 1405 else: 1406 raise ValueError( 1407 "The debug watch key '%s' exists on multiple (%d) devices, but " 1408 "device name is not specified." % 1409 (debug_watch_key, len(matching_device_names))) 1410 elif device_name not in self._debug_key_to_datum: 1411 raise ValueError( 1412 "There is no device named '%s' consisting of debug watch keys." % 1413 device_name) 1414 1415 return self._watch_key_to_datum[device_name].get(debug_watch_key, []) 1416 1417 def find(self, 1418 predicate, 1419 first_n=0, 1420 device_name=None, 1421 exclude_node_names=None): 1422 """Find dumped tensor data by a certain predicate. 1423 1424 Args: 1425 predicate: A callable that takes two input arguments: 1426 1427 ```python 1428 def predicate(debug_tensor_datum, tensor): 1429 # returns a bool 1430 ``` 1431 1432 where `debug_tensor_datum` is an instance of `DebugTensorDatum`, which 1433 carries the metadata, such as the `Tensor`'s node name, output slot 1434 timestamp, debug op name, etc.; and `tensor` is the dumped tensor value 1435 as a `numpy.ndarray`. 1436 first_n: (`int`) return only the first n `DebugTensotDatum` instances (in 1437 time order) for which the predicate returns True. To return all the 1438 `DebugTensotDatum` instances, let first_n be <= 0. 1439 device_name: optional device name. 1440 exclude_node_names: Optional regular expression to exclude nodes with 1441 names matching the regular expression. 1442 1443 Returns: 1444 A list of all `DebugTensorDatum` objects in this `DebugDumpDir` object 1445 for which predicate returns True, sorted in ascending order of the 1446 timestamp. 1447 """ 1448 if exclude_node_names: 1449 exclude_node_names = re.compile(exclude_node_names) 1450 1451 matched_data = [] 1452 for device in (self._dump_tensor_data if device_name is None 1453 else (self._dump_tensor_data[device_name],)): 1454 for datum in self._dump_tensor_data[device]: 1455 if exclude_node_names and exclude_node_names.match(datum.node_name): 1456 continue 1457 1458 if predicate(datum, datum.get_tensor()): 1459 matched_data.append(datum) 1460 1461 if first_n > 0 and len(matched_data) >= first_n: 1462 return matched_data 1463 1464 return matched_data 1465 1466 def get_tensor_file_paths(self, 1467 node_name, 1468 output_slot, 1469 debug_op, 1470 device_name=None): 1471 """Get the file paths from a debug-dumped tensor. 1472 1473 Args: 1474 node_name: (`str`) name of the node that the tensor is produced by. 1475 output_slot: (`int`) output slot index of tensor. 1476 debug_op: (`str`) name of the debug op. 1477 device_name: (`str`) name of the device. If there is only one device or if 1478 the specified debug_watch_key exists on only one device, this argument 1479 is optional. 1480 1481 Returns: 1482 List of file path(s) loaded. This is a list because each debugged tensor 1483 may be dumped multiple times. 1484 1485 Raises: 1486 WatchKeyDoesNotExistInDebugDumpDirError: If the tensor does not exist in 1487 the debug-dump data. 1488 """ 1489 1490 device_name = self._infer_device_name(device_name, node_name) 1491 watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op) 1492 if watch_key not in self._watch_key_to_datum[device_name]: 1493 raise WatchKeyDoesNotExistInDebugDumpDirError( 1494 "Watch key \"%s\" does not exist in the debug dump of device %s" % 1495 (watch_key, device_name)) 1496 1497 return [datum.file_path for datum in 1498 self._watch_key_to_datum[device_name][watch_key]] 1499 1500 def get_tensors(self, node_name, output_slot, debug_op, device_name=None): 1501 """Get the tensor value from for a debug-dumped tensor. 1502 1503 The tensor may be dumped multiple times in the dump root directory, so a 1504 list of tensors (`numpy.ndarray`) is returned. 1505 1506 Args: 1507 node_name: (`str`) name of the node that the tensor is produced by. 1508 output_slot: (`int`) output slot index of tensor. 1509 debug_op: (`str`) name of the debug op. 1510 device_name: (`str`) name of the device. If there is only one device or if 1511 the specified debug_watch_key exists on only one device, this argument 1512 is optional. 1513 1514 Returns: 1515 List of tensors (`numpy.ndarray`) loaded from the debug-dump file(s). 1516 1517 Raises: 1518 WatchKeyDoesNotExistInDebugDumpDirError: If the tensor does not exist in 1519 the debug-dump data. 1520 """ 1521 1522 watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op) 1523 try: 1524 device_name = self._infer_device_name(device_name, node_name) 1525 return [datum.get_tensor() for datum in 1526 self._watch_key_to_datum[device_name][watch_key]] 1527 except (ValueError, KeyError): 1528 raise WatchKeyDoesNotExistInDebugDumpDirError( 1529 "Watch key \"%s\" does not exist in the debug dump of device %s" % 1530 (watch_key, device_name)) 1531 1532 def get_rel_timestamps(self, 1533 node_name, 1534 output_slot, 1535 debug_op, 1536 device_name=None): 1537 """Get the relative timestamp from for a debug-dumped tensor. 1538 1539 Relative timestamp means (absolute timestamp - `t0`), where `t0` is the 1540 absolute timestamp of the first dumped tensor in the dump root. The tensor 1541 may be dumped multiple times in the dump root directory, so a list of 1542 relative timestamps (`numpy.ndarray`) is returned. 1543 1544 Args: 1545 node_name: (`str`) name of the node that the tensor is produced by. 1546 output_slot: (`int`) output slot index of tensor. 1547 debug_op: (`str`) name of the debug op. 1548 device_name: (`str`) name of the device. If there is only one device or if 1549 the specified debug_watch_key exists on only one device, this argument 1550 is optional. 1551 1552 Returns: 1553 (`list` of `int`) list of relative timestamps. 1554 1555 Raises: 1556 WatchKeyDoesNotExistInDebugDumpDirError: If the tensor watch key does not 1557 exist in the debug dump data. 1558 """ 1559 1560 device_name = self._infer_device_name(device_name, node_name) 1561 watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op) 1562 if watch_key not in self._watch_key_to_datum[device_name]: 1563 raise WatchKeyDoesNotExistInDebugDumpDirError( 1564 "Watch key \"%s\" does not exist in the debug dump" % watch_key) 1565 1566 # TODO(cais): Figure out whether this should be relative to the global t0. 1567 return self._watch_key_to_rel_time[device_name][watch_key] 1568 1569 def get_dump_sizes_bytes(self, 1570 node_name, 1571 output_slot, 1572 debug_op, 1573 device_name=None): 1574 """Get the sizes of the dump files for a debug-dumped tensor. 1575 1576 Unit of the file size: byte. 1577 1578 Args: 1579 node_name: (`str`) name of the node that the tensor is produced by. 1580 output_slot: (`int`) output slot index of tensor. 1581 debug_op: (`str`) name of the debug op. 1582 device_name: (`str`) name of the device. If there is only one device or if 1583 the specified debug_watch_key exists on only one device, this argument 1584 is optional. 1585 1586 Returns: 1587 (`list` of `int`): list of dump file sizes in bytes. 1588 1589 Raises: 1590 WatchKeyDoesNotExistInDebugDumpDirError: If the tensor watch key does not 1591 exist in the debug dump data. 1592 """ 1593 1594 device_name = self._infer_device_name(device_name, node_name) 1595 watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op) 1596 if watch_key not in self._watch_key_to_datum[device_name]: 1597 raise WatchKeyDoesNotExistInDebugDumpDirError( 1598 "Watch key \"%s\" does not exist in the debug dump of device %s" % 1599 (watch_key, device_name)) 1600 1601 return self._watch_key_to_dump_size_bytes[device_name][watch_key] 1602 1603 def node_traceback(self, element_name): 1604 """Try to retrieve the Python traceback of node's construction. 1605 1606 Args: 1607 element_name: (`str`) Name of a graph element (node or tensor). 1608 1609 Returns: 1610 (list) The traceback list object as returned by the `extract_trace` 1611 method of Python's traceback module. 1612 1613 Raises: 1614 LookupError: If Python graph is not available for traceback lookup. 1615 KeyError: If the node cannot be found in the Python graph loaded. 1616 """ 1617 1618 if self._python_graph is None: 1619 raise LookupError("Python graph is not available for traceback lookup") 1620 1621 node_name = debug_graphs.get_node_name(element_name) 1622 if node_name not in self._node_traceback: 1623 raise KeyError("Cannot find node \"%s\" in Python graph" % node_name) 1624 1625 return self._node_traceback[node_name] 1626