# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes and functions to handle debug-dump data of TensorFlow Debugger."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import glob
import json
import os
import platform
import re

import numpy as np
import six

from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python.debug.lib import debug_graphs
from tensorflow.python.framework import tensor_util
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat


# TODO(cais): Tie these string constants in with C++?
METADATA_FILE_PREFIX = "_tfdbg_"
CORE_METADATA_TAG = "core_metadata_"
GRAPH_FILE_TAG = "graph_"
DEVICE_TAG = "device_"
HASH_TAG = "hash"

FETCHES_INFO_FILE_TAG = "fetches_info_"
FEED_KEYS_INFO_FILE_TAG = "feed_keys_info_"


def _glob(glob_pattern):
  if platform.system() == "Windows":
    return glob.glob(glob_pattern)
  else:
    return gfile.Glob(glob_pattern)


class InconvertibleTensorProto(object):
  """Represents a TensorProto that cannot be converted to np.ndarray."""

  def __init__(self, tensor_proto, initialized=True):
    """Constructor.

    Args:
      tensor_proto: the `TensorProto` object that cannot be represented as a
        `np.ndarray` object.
      initialized: (`bool`) whether the Tensor is initialized.
    """
    self._tensor_proto = tensor_proto
    self._initialized = initialized

  def __str__(self):
    output = "" if self._initialized else "Uninitialized tensor:\n"
    output += str(self._tensor_proto)
    return output

  @property
  def initialized(self):
    return self._initialized


def load_tensor_from_event_file(event_file_path):
  """Load a tensor from an event file.

  Assumes that the event file contains an `Event` protobuf and the `Event`
  protobuf contains a `Tensor` value.

  Args:
    event_file_path: (`str`) path to the event file.

  Returns:
    The tensor value loaded from the event file, as a `numpy.ndarray`. For
    uninitialized Tensors and for Tensors of data types that cannot be
    converted to `numpy.ndarray` (e.g., `tf.resource`), returns an
    `InconvertibleTensorProto` object instead.
  """

  event = event_pb2.Event()
  with gfile.Open(event_file_path, "rb") as f:
    event.ParseFromString(f.read())
    return load_tensor_from_event(event)


def load_tensor_from_event(event):
  """Load a tensor from an Event proto.

  Args:
    event: The Event proto, assumed to hold a tensor value in its
      summary.value[0] field.

  Returns:
    The tensor value loaded from the event file, as a `numpy.ndarray`, if
    representation of the tensor value by a `numpy.ndarray` is possible.
    For uninitialized Tensors and for Tensors of data types that cannot be
    represented as `numpy.ndarray` (e.g., `tf.resource`), returns an
    `InconvertibleTensorProto` object that wraps the `TensorProto` without
    converting it to a `numpy.ndarray`.
  """

  tensor_proto = event.summary.value[0].tensor
  shape = tensor_util.TensorShapeProtoToList(tensor_proto.tensor_shape)
  num_elements = 1
  for shape_dim in shape:
    num_elements *= shape_dim

  if tensor_proto.tensor_content or tensor_proto.string_val or not num_elements:
    # Initialized tensor or empty tensor.
    if tensor_proto.dtype == types_pb2.DT_RESOURCE:
      tensor_value = InconvertibleTensorProto(tensor_proto)
    else:
      try:
        tensor_value = tensor_util.MakeNdarray(tensor_proto)
      except KeyError:
        tensor_value = InconvertibleTensorProto(tensor_proto)
  else:
    # Uninitialized tensor or tensor of unconvertible data type.
    tensor_value = InconvertibleTensorProto(tensor_proto, False)

  return tensor_value


def _load_graph_def_from_event_file(event_file_path):
  event = event_pb2.Event()
  with gfile.Open(event_file_path, "rb") as f:
    event.ParseFromString(f.read())

  return graph_pb2.GraphDef.FromString(event.graph_def)


def _load_log_message_from_event_file(event_file_path):
  event = event_pb2.Event()
  with gfile.Open(event_file_path, "rb") as f:
    event.ParseFromString(f.read())

  return event.log_message.message


def _is_graph_file(file_name):
  return file_name.startswith(METADATA_FILE_PREFIX + GRAPH_FILE_TAG)


def _is_run_fetches_info_file(file_name):
  return file_name == METADATA_FILE_PREFIX + FETCHES_INFO_FILE_TAG


def _is_run_feed_keys_info_file(file_name):
  return file_name == METADATA_FILE_PREFIX + FEED_KEYS_INFO_FILE_TAG


def _get_tensor_name(node_name, output_slot):
  """Get tensor name given node name and output slot index.

  Args:
    node_name: Name of the node that outputs the tensor, as a string.
    output_slot: Output slot index of the tensor, as an integer.

  Returns:
    Name of the tensor, as a string.
  """

  return "%s:%d" % (node_name, output_slot)


def _get_tensor_watch_key(node_name, output_slot, debug_op):
  """Get the string representation of a debug watch on a tensor.

  Args:
    node_name: Name of the node by which the watched tensor is produced, as a
      string.
    output_slot: Output slot index of the tensor, as an integer.
    debug_op: Name of the debug op that is used to watch the tensor, as a
      string.

  Returns:
    A string representing the debug watch on the tensor (i.e., the "watch
    key").
  """
  return "%s:%s" % (_get_tensor_name(node_name, output_slot), debug_op)


def has_inf_or_nan(datum, tensor):
  """A predicate for whether a tensor consists of any bad numerical values.

  This predicate is common enough to merit definition in this module.
  Bad numerical values include `nan`s and `inf`s.
  The signature of this function follows the requirement of the method
  `DebugDumpDir.find()`.

  Args:
    datum: (`DebugTensorDatum`) Datum metadata.
    tensor: (`numpy.ndarray` or `InconvertibleTensorProto`) Value of the
      tensor. An `InconvertibleTensorProto` represents an uninitialized tensor
      or a tensor whose value could not be converted to a `numpy.ndarray`.

  Returns:
    (`bool`) True if and only if tensor consists of any nan or inf values.
  """

  _ = datum  # Datum metadata is unused in this predicate.

  if isinstance(tensor, InconvertibleTensorProto):
    # Uninitialized tensor doesn't have bad numerical values.
    # Also return False for data types that cannot be represented as numpy
    # arrays.
    return False
  elif (np.issubdtype(tensor.dtype, np.floating) or
        np.issubdtype(tensor.dtype, np.complex) or
        np.issubdtype(tensor.dtype, np.integer)):
    return np.any(np.isnan(tensor)) or np.any(np.isinf(tensor))
  else:
    return False


_CoreMetadata = collections.namedtuple("CoreMetadata", [
    "global_step", "session_run_index", "executor_step_index", "input_names",
    "output_names", "target_nodes"
])


def extract_core_metadata_from_event_proto(event):
  json_metadata = json.loads(event.log_message.message)
  return _CoreMetadata(json_metadata["global_step"],
                       json_metadata["session_run_index"],
                       json_metadata["executor_step_index"],
                       json_metadata["input_names"],
                       json_metadata["output_names"],
                       json_metadata["target_nodes"])


def device_name_to_device_path(device_name):
  """Convert device name to device path."""
  device_name_items = compat.as_text(device_name).split("/")
  device_name_items = [item.replace(":", "_") for item in device_name_items]
  return METADATA_FILE_PREFIX + DEVICE_TAG + ",".join(device_name_items)


def device_path_to_device_name(device_dir):
  """Parse device name from device path.

  Args:
    device_dir: (str) a directory name for the device.

  Returns:
    (str) parsed device name.
  """
  path_items = os.path.basename(device_dir)[
      len(METADATA_FILE_PREFIX) + len(DEVICE_TAG):].split(",")
  return "/".join([
      path_item.replace("device_", "device:").replace("_", ":", 1)
      for path_item in path_items])
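

# Illustrative note (not part of the public API): the two helpers above are
# intended to be inverses of each other. A sketch, using a typical TensorFlow
# device string chosen here purely for illustration:
#
#   path = device_name_to_device_path(
#       "/job:localhost/replica:0/task:0/device:GPU:0")
#   # path == "_tfdbg_device_,job_localhost,replica_0,task_0,device_GPU_0"
#   device_path_to_device_name(path)
#   # == "/job:localhost/replica:0/task:0/device:GPU:0"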


class DebugTensorDatum(object):
  """A single tensor dumped by TensorFlow Debugger (tfdbg).

  Contains metadata about the dumped tensor, including `timestamp`,
  `node_name`, `output_slot`, `debug_op`, and path to the dump file
  (`file_path`).

  This type does not hold the generally space-expensive tensor value (numpy
  array). Instead, it points to the file from which the tensor value can be
  loaded (with the `get_tensor` method) if needed.
  """

  def __init__(self, dump_root, debug_dump_rel_path):
    """`DebugTensorDatum` constructor.

    Args:
      dump_root: (`str`) Debug dump root directory. This path should not
        include the path component that represents the device name (see also
        below).
      debug_dump_rel_path: (`str`) Path to a debug dump file, relative to the
        `dump_root`. The first item of this relative path is assumed to be
        a path representing the name of the device that the Tensor belongs to.
        See `device_path_to_device_name` for more details on the device path.
        For example, suppose the debug dump root
        directory is `/tmp/tfdbg_1` and the dump file is at
        `/tmp/tfdbg_1/<device_path>/ns_1/node_a_0_DebugIdentity_1234456789`,
        then the value of the debug_dump_rel_path should be
        `<device_path>/ns_1/node_a_0_DebugIdentity_1234456789`.

    Raises:
      ValueError: If the base file name of the dump file does not conform to
        the dump file naming pattern:
        `node_name`_`output_slot`_`debug_op`_`timestamp`
    """

    path_components = os.path.normpath(debug_dump_rel_path).split(os.sep)
    self._device_name = device_path_to_device_name(path_components[0])
    base = path_components[-1]
    if base.count("_") < 3:
      raise ValueError(
          "Dump file path does not conform to the naming pattern: %s" % base)

    self._extended_timestamp = base.split("_")[-1]
    # It may include an index suffix at the end if file path collision happened
    # due to identical timestamps.
    if "-" in self._extended_timestamp:
      self._timestamp = int(
          self._extended_timestamp[:self._extended_timestamp.find("-")])
    else:
      self._timestamp = int(self._extended_timestamp)

    self._debug_op = base.split("_")[-2]
    self._output_slot = int(base.split("_")[-3])

    node_base_name = "_".join(base.split("_")[:-3])
    self._node_name = "/".join(path_components[1:-1] + [node_base_name])

    self._file_path = os.path.join(dump_root, debug_dump_rel_path)
    self._dump_size_bytes = (gfile.Stat(self._file_path).length if
                             gfile.Exists(self._file_path) else None)

  def __str__(self):
    return "{DebugTensorDatum (%s) %s:%d @ %s @ %d}" % (self.device_name,
                                                        self.node_name,
                                                        self.output_slot,
                                                        self.debug_op,
                                                        self.timestamp)

  def __repr__(self):
    return self.__str__()

  def get_tensor(self):
    """Get tensor from the dump (`Event`) file.

    Returns:
      The tensor loaded from the dump (`Event`) file.
    """

    return load_tensor_from_event_file(self.file_path)

  # TODO(cais): Add time unit suffix to timestamp and t0 (us).
  @property
  def timestamp(self):
    """Timestamp of when this tensor value was dumped.

    Returns:
      (`int`) The timestamp in microseconds.
    """

    return self._timestamp

  @property
  def extended_timestamp(self):
    """Extended timestamp, possibly with an index suffix.

    The index suffix, e.g., "-1", is for disambiguating multiple dumps of the
    same tensor with the same timestamp, which can occur if the dumping events
    are spaced by shorter than the temporal resolution of the timestamps.

    Returns:
      (`str`) The extended timestamp.
    """

    return self._extended_timestamp

  @property
  def debug_op(self):
    """Name of the debug op.

    Returns:
      (`str`) debug op name (e.g., `DebugIdentity`).
    """

    return self._debug_op

  @property
  def device_name(self):
    """Name of the device that the tensor belongs to.

    Returns:
      (`str`) device name.
    """

    return self._device_name

  @property
  def node_name(self):
    """Name of the node from which the tensor value was dumped.

    Returns:
      (`str`) name of the node watched by the debug op.
    """

    return self._node_name

  @property
  def output_slot(self):
    """Output slot index from which the tensor value was dumped.

    Returns:
      (`int`) output slot index watched by the debug op.
    """

    return self._output_slot

  @property
  def tensor_name(self):
    """Name of the tensor watched by the debug op.

    Returns:
      (`str`) `Tensor` name, in the form of `node_name`:`output_slot`
    """

    return _get_tensor_name(self.node_name, self.output_slot)

  @property
  def watch_key(self):
    """Watch key identifying a debug watch on a tensor.

    Returns:
      (`str`) A watch key, in the form of `tensor_name`:`debug_op`.
    """

    return _get_tensor_watch_key(self.node_name, self.output_slot,
                                 self.debug_op)

  @property
  def file_path(self):
    """Path to the file which stores the value of the dumped tensor."""

    return self._file_path

  @property
  def dump_size_bytes(self):
    """Size of the dump file.

    Unit: byte.

    Returns:
      If the dump file exists, size of the dump file, in bytes.
      If the dump file does not exist, None.
    """

    return self._dump_size_bytes


class WatchKeyDoesNotExistInDebugDumpDirError(ValueError):
  pass


class DebugDumpDir(object):
  """Data set from a debug-dump directory on filesystem.

  An instance of `DebugDumpDir` contains all `DebugTensorDatum` instances
  in a tfdbg dump root directory.
  """

  def __init__(self, dump_root, partition_graphs=None, validate=True):
    """`DebugDumpDir` constructor.

    Args:
      dump_root: (`str`) path to the dump root directory.
      partition_graphs: A repeated field of GraphDefs representing the
        partition graphs executed by the TensorFlow runtime.
      validate: (`bool`) whether the dump files are to be validated against the
        partition graphs.

    Raises:
      IOError: If dump_root does not exist as a directory.
      ValueError: If more than one core metadata file is found under the dump
        root directory.
    """

    if not gfile.IsDirectory(dump_root):
      raise IOError("Dump root directory %s does not exist" % dump_root)

    self._core_metadata = []

    # Find the list of devices.
    self._dump_root = dump_root

    self._load_core_metadata()
    self._load_fetches_info()
    self._load_feeds_info()
    self._load_all_device_dumps(partition_graphs, validate)

    self._python_graph = None

  def _load_all_device_dumps(self, partition_graphs, validate):
    """Load the dump data for all devices."""
    device_dirs = _glob(os.path.join(
        self._dump_root, METADATA_FILE_PREFIX + DEVICE_TAG + "*"))

    self._device_names = []
    self._t0s = {}
    self._dump_tensor_data = {}
    self._dump_graph_file_paths = {}
    self._debug_watches = {}
    self._watch_key_to_devices = {}
    self._watch_key_to_datum = {}
    self._watch_key_to_rel_time = {}
    self._watch_key_to_dump_size_bytes = {}
    for device_dir in device_dirs:
      device_name = device_path_to_device_name(device_dir)
      self._device_names.append(device_name)
      self._load_device_dumps(device_name, device_dir)
    self._load_partition_graphs(partition_graphs, validate)
    self._calculate_t0()

    for device_name in self._device_names:
      self._create_tensor_watch_maps(device_name)

  def _load_device_dumps(self, device_name, device_root):
    """Load `DebugTensorDatum` instances from the dump root of a given device.

    Populates a map {device_name: a list of `DebugTensorDatum`}, where the list
    is sorted by ascending timestamp.

    This sorting order reflects the order in which the TensorFlow executor
    processed the nodes of the graph. It is one of many possible topological
    sorts of the nodes.
    This is useful for displaying tensors in the debugger frontend as well as
    for the use case in which the user wants to find a "culprit tensor", i.e.,
    the first tensor in the graph that exhibits certain problematic properties,
    e.g., all-zero values or bad numerical values such as nan and inf.

    In addition, creates a map from node name to debug watches. In this map,
    the key is the watched node name; the value is a dictionary. Of this
    dictionary, the key is the watched output slot and the value is the set of
    debug ops watching that slot.

    This method attempts to load the debug watches from the tensor dump files
    first, before loading the full set of debug watches from the partition
    graphs as done later. This is necessary because sometimes the partition
    graphs may not be available, e.g., when the run errors out.

    Args:
      device_name: (`str`) name of the device.
      device_root: (`str`) dump root directory of the given device.

    Raises:
      ValueError: If GraphDef for the device is not available.
    """

    self._dump_tensor_data[device_name] = []
    self._debug_watches[device_name] = collections.defaultdict(
        lambda: collections.defaultdict(set))

    for root, _, files in gfile.Walk(device_root):
      for f in files:
        if _is_graph_file(f):
          self._dump_graph_file_paths[device_name] = os.path.join(root, f)
        else:
          datum = self._dump_file_name_to_datum(root, f)
          self._dump_tensor_data[device_name].append(datum)
          self._debug_watches[device_name][datum.node_name][
              datum.output_slot].add(datum.debug_op)

    self._dump_tensor_data[device_name] = sorted(
        self._dump_tensor_data[device_name],
        key=lambda x: x.extended_timestamp)

    if self._dump_tensor_data[device_name]:
      self._t0s[device_name] = self._dump_tensor_data[device_name][0].timestamp
    else:
      self._t0s[device_name] = None

  def _calculate_t0(self):
    """Calculate the first timestamp across all devices."""
    t0s = [t0 for t0 in six.itervalues(self._t0s) if t0 is not None]
    self._t0 = min(t0s) if t0s else None

  def _load_core_metadata(self):
    core_metadata_files = _glob(os.path.join(
        self._dump_root, METADATA_FILE_PREFIX + CORE_METADATA_TAG + "*"))
    for core_metadata_file in core_metadata_files:
      with gfile.Open(core_metadata_file, "rb") as f:
        event = event_pb2.Event()
        event.ParseFromString(f.read())
        self._core_metadata.append(
            extract_core_metadata_from_event_proto(event))

  def _load_fetches_info(self):
    fetches_info_files = _glob(os.path.join(
        self._dump_root, METADATA_FILE_PREFIX + FETCHES_INFO_FILE_TAG + "*"))
    self._run_fetches_info = []
    for fetches_info_file in fetches_info_files:
      self._run_fetches_info.append(
          _load_log_message_from_event_file(fetches_info_file))

  def _load_feeds_info(self):
    feeds_info_files = _glob(os.path.join(
        self._dump_root, METADATA_FILE_PREFIX + FEED_KEYS_INFO_FILE_TAG + "*"))
    self._run_feed_keys_info = []
    for feeds_info_file in feeds_info_files:
      self._run_feed_keys_info.append(
          _load_log_message_from_event_file(feeds_info_file))

  def _dump_file_name_to_datum(self, dir_name, file_name):
    """Obtain a DebugTensorDatum from the directory and file name.

    Args:
      dir_name: (`str`) Name of the directory in which the dump file resides.
      file_name: (`str`) Base name of the dump file.

    Returns:
      (`DebugTensorDatum`) The `DebugTensorDatum` loaded from the dump file.
    """

    # Calculate the relative path of the dump file with respect to the root.
    debug_dump_rel_path = os.path.join(
        os.path.relpath(dir_name, self._dump_root), file_name)
    return DebugTensorDatum(self._dump_root, debug_dump_rel_path)

  def _create_tensor_watch_maps(self, device_name):
    """Create maps from tensor watch keys to datum and to timestamps.

    Create a map from watch key (tensor name + debug op) to `DebugTensorDatum`
    item. Also make a map from watch key to relative timestamp.
    "relative" means (absolute timestamp - t0).

    Args:
      device_name: (str) name of the device.
    """

    self._watch_key_to_datum[device_name] = {}
    self._watch_key_to_rel_time[device_name] = {}
    self._watch_key_to_dump_size_bytes[device_name] = {}
    for datum in self._dump_tensor_data[device_name]:
      if datum.watch_key not in self._watch_key_to_devices:
        self._watch_key_to_devices[datum.watch_key] = {device_name}
      else:
        self._watch_key_to_devices[datum.watch_key].add(device_name)

      if datum.watch_key not in self._watch_key_to_datum[device_name]:
        self._watch_key_to_datum[device_name][datum.watch_key] = [datum]
        self._watch_key_to_rel_time[device_name][datum.watch_key] = [
            datum.timestamp - self._t0]
        self._watch_key_to_dump_size_bytes[device_name][datum.watch_key] = [
            datum.dump_size_bytes]
      else:
        self._watch_key_to_datum[device_name][datum.watch_key].append(datum)
        self._watch_key_to_rel_time[device_name][datum.watch_key].append(
            datum.timestamp - self._t0)
        self._watch_key_to_dump_size_bytes[device_name][datum.watch_key].append(
            datum.dump_size_bytes)

  def set_python_graph(self, python_graph):
    """Provide Python `Graph` object to the wrapper.

    Unlike the partition graphs, which are protobuf `GraphDef` objects, `Graph`
    is a Python object and carries additional information such as the traceback
    of the construction of the nodes in the graph.

    Args:
      python_graph: (ops.Graph) The Python Graph object.
    """

    self._python_graph = python_graph
    self._node_traceback = {}
    if self._python_graph:
      for op in self._python_graph.get_operations():
        self._node_traceback[op.name] = op.traceback

  @property
  def python_graph(self):
    """Get the Python graph.

    Returns:
      If the Python graph has been set, returns a `tf.Graph` object. Otherwise,
      returns None.
    """

    return self._python_graph

  @property
  def core_metadata(self):
    """Metadata about the `Session.run()` call from the core runtime.

    Of the three counters available in the return value, `global_step` is
    supplied by the caller of the debugged `Session.run()`, while
    `session_run_index` and `executor_step_index` are determined by the state
    of the core runtime, automatically. For the same fetch list, feed keys and
    debug tensor watch options, the same executor will be used and
    `executor_step_index` should increase by one at a time. However, runs with
    different fetch lists, feed keys and debug tensor watch options that all
    share the same `Session` object can lead to gaps in `session_run_index`.

    Returns:
      If core metadata are loaded, a `namedtuple` with the fields:
        `global_step`: A global step count supplied by the caller of
          `Session.run()`. It is optional to the caller. If the caller did not
          supply this parameter, its value will be -1.
        `session_run_index`: A sorted index for Run() calls to the underlying
          TensorFlow `Session` object.
        `executor_step_index`: A counter for invocations of a given runtime
          executor. The same executor is re-used for the same fetched tensors,
          target nodes, input feed keys and debug tensor watch options.
        `input_names`: Names of the input (feed) Tensors.
        `output_names`: Names of the output (fetched) Tensors.
        `target_nodes`: Names of the target nodes.
      If the core metadata have not been loaded, `None`.
      If more than one core metadata file exists, a `list` of the `namedtuple`s
        described above.
    """

    output = self._core_metadata
    return output[0] if len(output) == 1 else output

  @property
  def dumped_tensor_data(self):
    """Retrieve dumped tensor data."""
    if len(self.devices()) == 1:
      return self._dump_tensor_data[self.devices()[0]]
    else:
      all_devices_data = six.itervalues(self._dump_tensor_data)
      data = []
      for device_data in all_devices_data:
        data.extend(device_data)
      return sorted(data, key=lambda x: x.extended_timestamp)

  @property
  def t0(self):
    """Absolute timestamp of the first dumped tensor across all devices.

    Returns:
      (`int`) absolute timestamp of the first dumped tensor, in microseconds.
    """
    return self._t0

  @property
  def size(self):
    """Total number of dumped tensors in the dump root directory.

    Returns:
      (`int`) The total number of dumped tensors in the dump root directory.
    """
    return sum(len(self._dump_tensor_data[device_name])
               for device_name in self._dump_tensor_data)

  def _load_partition_graphs(self, client_partition_graphs, validate):
    """Load and process partition graphs.

    Load the graphs; parse the input and control input structure; obtain the
    device and op type of each node; remove the Copy and debug ops inserted
    by the debugger. The gathered information can be used to validate the
    tensor dumps.

    Args:
      client_partition_graphs: A repeated field of GraphDefs representing the
        partition graphs executed by the TensorFlow runtime, from the Python
        client. These partition graphs are used only if partition graphs
        cannot be loaded from the dump directory on the file system.
      validate: (`bool`) Whether the dump files are to be validated against the
        partition graphs.

    Raises:
      ValueError: If the partition GraphDef of one or more devices fails to be
        loaded.
    """
    self._debug_graphs = {}
    self._node_devices = {}

    partition_graphs_and_device_names = []
    for device_name in self._device_names:
      partition_graph = None
      if device_name in self._dump_graph_file_paths:
        partition_graph = _load_graph_def_from_event_file(
            self._dump_graph_file_paths[device_name])
      else:
        logging.warn(
            "Failed to load partition graphs for device %s from disk. "
            "As a fallback, the client graphs will be used. This "
            "may cause mismatches in device names." % device_name)
        partition_graph = self._find_partition_graph(client_partition_graphs,
                                                     device_name)

      if partition_graph:
        partition_graphs_and_device_names.append((partition_graph,
                                                  device_name))

    for partition_graph, maybe_device_name in partition_graphs_and_device_names:
      debug_graph = debug_graphs.DebugGraph(partition_graph,
                                            device_name=maybe_device_name)
      self._debug_graphs[debug_graph.device_name] = debug_graph
      self._collect_node_devices(debug_graph)

      if validate and debug_graph.device_name in self._dump_tensor_data:
        self._validate_dump_with_graphs(debug_graph.device_name)

  def _find_partition_graph(self, partition_graphs, device_name):
    if partition_graphs is None:
      return None
    else:
      for graph_def in partition_graphs:
        for node_def in graph_def.node:
          if node_def.device == device_name:
            return graph_def
      return None

  def _collect_node_devices(self, debug_graph):
    for node_name in debug_graph.node_devices:
      if node_name in self._node_devices:
        self._node_devices[node_name] = self._node_devices[node_name].union(
            debug_graph.node_devices[node_name])
      else:
        self._node_devices[node_name] = debug_graph.node_devices[node_name]

  def _validate_dump_with_graphs(self, device_name):
    """Validate the dumped tensor data against the partition graphs.

    Only the watched nodes are validated by this method, because tfdbg allows
    clients to watch only a subset of the nodes.

    Args:
      device_name: (`str`) device name.

    Raises:
      LookupError: If the partition graphs have not been loaded yet.
      ValueError: If dumps contain node names not found in partition graph.
        Or if the temporal order of the dump's timestamps violates the
        input relations on the partition graphs.
    """
    if not self._debug_graphs:
      raise LookupError(
          "No partition graphs loaded for device %s" % device_name)
    debug_graph = self._debug_graphs[device_name]

    # Verify that the node names in the dump data are all present in the
    # partition graphs.
    for datum in self._dump_tensor_data[device_name]:
      if datum.node_name not in debug_graph.node_inputs:
        raise ValueError("Node name '%s' is not found in partition graphs of "
                         "device %s." % (datum.node_name, device_name))

    pending_inputs = {}
    for node in debug_graph.node_inputs:
      pending_inputs[node] = []
      inputs = debug_graph.node_inputs[node]
      for inp in inputs:
        inp_node = debug_graphs.get_node_name(inp)
        inp_output_slot = debug_graphs.get_output_slot(inp)
        # Inputs from Enter and NextIteration nodes are not validated because
        # DebugNodeInserter::InsertNodes() in the debugger core skips creating
        # control edges from debug ops watching these types of nodes.
        if (inp_node in self._debug_watches[device_name] and
            inp_output_slot in self._debug_watches[device_name][inp_node] and
            debug_graph.node_op_types.get(inp) not in (
                "Enter", "NextIteration") and
            (inp_node, inp_output_slot) not in pending_inputs[node]):
          pending_inputs[node].append((inp_node, inp_output_slot))

    for i, datum in enumerate(self._dump_tensor_data[device_name]):
      node = datum.node_name
      slot = datum.output_slot
      # In some cases (e.g., system clocks with insufficient precision),
      # the upstream and downstream tensors may have identical timestamps; the
      # following check examines this possibility and avoids raising an error
      # in that case.
      if not self._satisfied_at_timestamp(
          device_name, pending_inputs[node], datum.timestamp, start_i=i + 1):
        raise ValueError("Causality violated in timing relations of debug "
                         "dumps: %s (%d): "
                         "these input(s) are not satisfied: %s" %
                         (node, datum.timestamp, repr(pending_inputs[node])))

      recipients = debug_graph.node_recipients[node]
      for recipient in recipients:
        recipient_pending_inputs = pending_inputs[recipient]
        if (node, slot) in recipient_pending_inputs:
          if self.node_op_type(recipient) == "Merge":
            # If this is a Merge op, we automatically clear the list because
            # a Merge node only requires one of its two inputs.
            del recipient_pending_inputs[:]
          else:
            del recipient_pending_inputs[
                recipient_pending_inputs.index((node, slot))]

  def _satisfied_at_timestamp(self, device_name, pending, timestamp, start_i=0):
    """Determine whether pending inputs are satisfied at given timestamp.

    Note: This method mutates the input argument "pending".

    Args:
      device_name: (str) device name.
      pending: A list of 2-tuple (node_name, output_slot): the dependencies to
        check.
      timestamp: (int) the timestamp in question.
      start_i: (int) the index in self._dump_tensor_data to start searching for
        the timestamp.

    Returns:
      (bool) Whether all the dependencies in pending are satisfied at the
        timestamp. If pending is empty to begin with, return True.
    """
    if not pending:
      return True

    for datum in self._dump_tensor_data[device_name][start_i:]:
      if datum.timestamp > timestamp:
        break
      if (datum.timestamp == timestamp and
          (datum.node_name, datum.output_slot) in pending):
        pending.remove((datum.node_name, datum.output_slot))
        if not pending:
          return True

    return not pending

  def loaded_partition_graphs(self):
    """Test whether partition graphs have been loaded."""
    return bool(self._debug_graphs)

  def partition_graphs(self):
    """Get the partition graphs.

    Returns:
      Partition graphs as a list of GraphDef.

    Raises:
      LookupError: If no partition graphs have been loaded.
    """
    if not self._debug_graphs:
      raise LookupError("No partition graphs have been loaded.")
    return [self._debug_graphs[key].debug_graph_def
            for key in self._debug_graphs]

  def reconstructed_non_debug_partition_graphs(self):
    """Reconstruct partition graphs with the debugger-inserted ops stripped.

    The reconstructed partition graphs are identical to the original (i.e.,
    non-debugger-decorated) partition graphs except in the following respects:
      1) The exact names of the runtime-inserted internal nodes may differ.
         These include _Send, _Recv, _HostSend, _HostRecv, _Retval ops.
      2) As a consequence of 1, the nodes that receive input directly from such
         send- and recv-type ops will have different input names.
      3) The parallel_iteration attribute of while-loop Enter ops is set to 1.

    Returns:
      A dict mapping device names (`str`s) to reconstructed `tf.GraphDef`s.
    """
    non_debug_graphs = dict()
    for key in self._debug_graphs:
      non_debug_graphs[key] = self._debug_graphs[key].non_debug_graph_def
    return non_debug_graphs

  @property
  def run_fetches_info(self):
    """Get a str representation of the fetches used in the Session.run() call.

    Returns:
      If the information is available from one `Session.run` call, a `str`
        obtained from `repr(fetches)`.
      If the information is available from multiple `Session.run` calls, a
        `list` of `str` obtained from `repr(fetches)`.
      If the information is not available, `None`.
    """

    output = self._run_fetches_info
    return output[0] if len(output) == 1 else output

  @property
  def run_feed_keys_info(self):
    """Get a str representation of the feed_dict used in the Session.run() call.

    Returns:
      If the information is available from one `Session.run` call, a `str`
        obtained from `repr(feed_dict)`.
      If the information is available from multiple `Session.run` calls, a
        `list` of `str` obtained from `repr(feed_dict)`.
      If the information is not available, `None`.
    """

    output = self._run_feed_keys_info
    return output[0] if len(output) == 1 else output

  def _infer_device_name(self, device_name, node_name):
    """Infer the device name given node name.

    If device_name is provided (i.e., not None), it is returned directly.

    Args:
      device_name: (str or None) name of the device. If None, will try to infer
        the device name by looking at the available nodes.
      node_name: (str) name of the node.

    Returns:
      (str) Inferred name of the device, if available.

    Raises:
      ValueError: If the node name does not exist on any of the available
        devices or if there are multiple devices that contain the node with
        the given name.
    """
    if device_name is None:
      if node_name in self._node_devices:
        if len(self._node_devices[node_name]) == 1:
          return list(self._node_devices[node_name])[0]
        else:
          raise ValueError(
              "There are multiple (%d) devices with nodes named '%s' but "
              "device_name is not specified." %
              (len(self._node_devices[node_name]), node_name))
      else:
        raise ValueError("None of the %d device(s) has a node named '%s'." %
                         (len(self._device_names), node_name))
    else:
      return device_name

  def nodes(self, device_name=None):
    """Get a list of all nodes from the partition graphs.

    Args:
      device_name: (`str`) name of device. If None, all nodes from all
        available devices will be included.

    Returns:
      All nodes' names, as a list of str.

    Raises:
      LookupError: If no partition graphs have been loaded.
      ValueError: If the specified device name does not exist.
    """
    if not self._debug_graphs:
      raise LookupError("No partition graphs have been loaded.")
    if device_name is None:
      nodes = []
      for device_name in self._debug_graphs:
        nodes.extend(self._debug_graphs[device_name].node_inputs.keys())
      return nodes
    else:
      if device_name not in self._debug_graphs:
        raise ValueError("Invalid device name: %s" % device_name)
      return self._debug_graphs[device_name].node_inputs.keys()

  def node_attributes(self, node_name, device_name=None):
    """Get the attributes of a node.

    Args:
      node_name: Name of the node in question.
      device_name: (`str`) name of the device. If there is only one device or
        if node_name exists on only one device, this argument is optional.

    Returns:
      Attributes of the node.

    Raises:
      LookupError: If no partition graphs have been loaded.
    """
    if not self._debug_graphs:
      raise LookupError("No partition graphs have been loaded.")

    device_name = self._infer_device_name(device_name, node_name)
    return self._debug_graphs[device_name].node_attributes[node_name]

  def node_inputs(self, node_name, is_control=False, device_name=None):
    """Get the inputs of given node according to partition graphs.

    Args:
      node_name: Name of the node.
      is_control: (`bool`) Whether control inputs, rather than non-control
        inputs, are to be returned.
      device_name: (`str`) name of the device. If there is only one device or
        if node_name exists on only one device, this argument is optional.

    Returns:
      (`list` of `str`) inputs to the node, as a list of node names.

    Raises:
      LookupError: If node inputs and control inputs have not been loaded
        from partition graphs yet.
    """
    if not self._debug_graphs:
      raise LookupError(
          "Node inputs are not loaded from partition graphs yet.")

    device_name = self._infer_device_name(device_name, node_name)
    if is_control:
      return self._debug_graphs[device_name].node_ctrl_inputs[node_name]
    else:
      return self._debug_graphs[device_name].node_inputs[node_name]

  def transitive_inputs(self,
                        node_name,
                        include_control=True,
                        include_reversed_ref=False,
                        device_name=None):
    """Get the transitive inputs of given node according to partition graphs.

    Args:
      node_name: Name of the node.
      include_control: Include control inputs (True by default).
      include_reversed_ref: Whether a ref input, say from A to B, is to be also
        considered as an input from B to A. The rationale is that ref inputs
        generally let the recipient (e.g., B in this case) mutate the value of
        the source (e.g., A in this case). So the reverse direction of the ref
        edge reflects the direction of information flow.
      device_name: (`str`) name of the device. If there is only one device or
        if node_name exists on only one device, this argument is optional.

    Returns:
      (`list` of `str`) all transitive inputs to the node, as a list of node
        names.

    Raises:
      LookupError: If node inputs and control inputs have not been loaded
        from partition graphs yet.
    """
    if not self._debug_graphs:
      raise LookupError(
          "Node inputs are not loaded from partition graphs yet.")

    device_name = self._infer_device_name(device_name, node_name)

    input_lists = [self._debug_graphs[device_name].node_inputs]
    if include_control:
      input_lists.append(self._debug_graphs[device_name].node_ctrl_inputs)
    if include_reversed_ref:
      input_lists.append(
          self._debug_graphs[device_name].node_reversed_ref_inputs)
    tracer = debug_graphs.DFSGraphTracer(
        input_lists,
        skip_node_names=self._get_merge_node_names(device_name))
    tracer.trace(node_name)
    return tracer.inputs()

  def _get_merge_node_names(self, device_name):
    """Lazily get a list of Merge nodes on a given device."""
    if device_name not in self._device_names:
      raise ValueError("Invalid device name: %s" % device_name)

    if not hasattr(self, "_merge_node_names"):
      self._merge_node_names = {}
    if device_name not in self._merge_node_names:
      debug_graph = self._debug_graphs[device_name]
      self._merge_node_names[device_name] = [
          node for node in debug_graph.node_op_types
          if debug_graph.node_op_types[node] == "Merge"]
    return self._merge_node_names[device_name]

  def find_some_path(self,
                     src_node_name,
                     dst_node_name,
                     include_control=True,
                     include_reversed_ref=False,
                     device_name=None):
    """Find a path between a source node and a destination node.

    Limitation: the source and destination are required to be on the same
    device, i.e., this method does not yet take into account Send/Recv nodes
    across devices.

    TODO(cais): Make this method work across device edges by tracing Send/Recv
      nodes.

    Args:
      src_node_name: (`str`) name of the source node or name of an output
        tensor of the node.
      dst_node_name: (`str`) name of the destination node or name of an output
        tensor of the node.
      include_control: (`bool`) whether control edges are considered in the
        graph tracing.
      include_reversed_ref: Whether a ref input, say from A to B, is to be also
        considered as an input from B to A. The rationale is that ref inputs
        generally let the recipient (e.g., B in this case) mutate the value of
        the source (e.g., A in this case). So the reverse direction of the ref
        edge reflects the direction of information flow.
      device_name: (`str`) name of the device. If there is only one device or
        if node_name exists on only one device, this argument is optional.

    Returns:
      A path from the src_node_name to dst_node_name, as a `list` of `str`, if
      it exists. The list includes src_node_name as the first item and
      dst_node_name as the last.
      If such a path does not exist, `None`.

    Raises:
      ValueError: If the source and destination nodes are not on the same
        device.
    """
    src_device_name = self._infer_device_name(device_name, src_node_name)
    dst_device_name = self._infer_device_name(device_name, dst_node_name)

    if src_device_name != dst_device_name:
      raise ValueError(
          "Source (%s) and destination (%s) are not on the same device: "
          "%s vs. %s" % (src_node_name, dst_node_name, src_device_name,
                         dst_device_name))

    input_lists = [self._debug_graphs[dst_device_name].node_inputs]
    debug_graph = self._debug_graphs[dst_device_name]
    if include_control:
      input_lists.append(debug_graph.node_ctrl_inputs)
    if include_reversed_ref:
      input_lists.append(debug_graph.node_reversed_ref_inputs)
    tracer = debug_graphs.DFSGraphTracer(
        input_lists,
        skip_node_names=self._get_merge_node_names(dst_device_name),
        destination_node_name=src_node_name)
    # Here the value of destination_node_name is src_node_name, because we
    # are tracing the graph from output to its inputs (i.e., going backwards
    # on the graph).

    try:
      tracer.trace(dst_node_name)
    except debug_graphs.GraphTracingReachedDestination:
      # Prune nodes not on the path.
      inputs = [dst_node_name] + tracer.inputs()
      depth_list = [0] + tracer.depth_list()

      path = []
      curr_depth = depth_list[-1]
      for inp, depth in zip(reversed(inputs), reversed(depth_list)):
        if depth == curr_depth:
          path.append(inp)
          curr_depth -= 1
      return path
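
  # Illustrative sketch (the node names below are hypothetical, not part of
  # this module): given a loaded `DebugDumpDir` named `dump`, one might trace
  # how a source node reaches a destination node on the same device:
  #
  #   path = dump.find_some_path("input_layer/x", "train/loss")
  #   # e.g., ["input_layer/x", "hidden/MatMul", "hidden/Relu", "train/loss"];
  #   # `None` if no such path exists.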

  def node_recipients(self, node_name, is_control=False, device_name=None):
    """Get recipients of the given node's output according to partition graphs.

    Args:
      node_name: (`str`) name of the node.
      is_control: (`bool`) whether control outputs, rather than non-control
        outputs, are to be returned.
      device_name: (`str`) name of the device. If there is only one device or
        if node_name exists on only one device, this argument is optional.

    Returns:
      (`list` of `str`) all recipients of the node's output, as a list of node
        names.

    Raises:
      LookupError: If node inputs and control inputs have not been loaded
        from partition graphs yet.
    """

    if not self._debug_graphs:
      raise LookupError(
          "Node recipients are not loaded from partition graphs yet.")

    device_name = self._infer_device_name(device_name, node_name)
    debug_graph = self._debug_graphs[device_name]
    if is_control:
      return debug_graph.node_ctrl_recipients[node_name]
    else:
      return debug_graph.node_recipients[node_name]

  def devices(self):
    """Get the list of device names.

    Returns:
      (`list` of `str`) names of the devices.
    """
    return self._device_names

  def node_exists(self, node_name, device_name=None):
    """Test if a node exists in the partition graphs.

    Args:
      node_name: (`str`) name of the node to be checked.
      device_name: optional device name. If None, will search for the node
        on all available devices. Otherwise, search for the node only on
        the given device.

    Returns:
      A boolean indicating whether the node exists.

    Raises:
      LookupError: If no partition graphs have been loaded yet.
      ValueError: If device_name is specified but cannot be found.
    """
    if not self._debug_graphs:
      raise LookupError(
          "Nodes have not been loaded from partition graphs yet.")

    if (device_name is not None) and device_name not in self._debug_graphs:
      raise ValueError(
          "The specified device_name '%s' cannot be found." % device_name)

    for _, debug_graph in self._debug_graphs.items():
      if node_name in debug_graph.node_inputs:
        return True
    return False

  def node_device(self, node_name):
    """Get the names of the devices that have nodes of the specified name.

    Args:
      node_name: (`str`) name of the node.

    Returns:
      (`str` or `list` of `str`) name of the device(s) on which the node of the
        given name is found. Returns a `str` if there is only one such device,
        otherwise a `list` of `str`.

    Raises:
      LookupError: If node inputs and control inputs have not been loaded
        from partition graphs yet.
      ValueError: If the node does not exist in partition graphs.
    """
    if not self._debug_graphs:
      raise LookupError(
          "Node devices are not loaded from partition graphs yet.")

    if node_name not in self._node_devices:
      raise ValueError("Node '%s' does not exist in partition graphs." %
                       node_name)

    output = list(self._node_devices[node_name])
    return output[0] if len(output) == 1 else output

  def node_op_type(self, node_name, device_name=None):
    """Get the op type of given node.

    Args:
      node_name: (`str`) name of the node.
      device_name: (`str`) name of the device. If there is only one device or
        if node_name exists on only one device, this argument is optional.

    Returns:
      (`str`) op type of the node.

    Raises:
      LookupError: If node op types have not been loaded
        from partition graphs yet.
    """
    if not self._debug_graphs:
      raise LookupError(
          "Node op types are not loaded from partition graphs yet.")

    device_name = self._infer_device_name(device_name, node_name)
    return self._debug_graphs[device_name].node_op_types[node_name]

  def debug_watch_keys(self, node_name, device_name=None):
    """Get all tensor watch keys of given node according to partition graphs.

    Args:
      node_name: (`str`) name of the node.
      device_name: (`str`) name of the device. If there is only one device or
        if node_name exists on only one device, this argument is optional.

    Returns:
      (`list` of `str`) all debug tensor watch keys. Returns an empty list if
        the node name does not correspond to any debug watch keys.

    Raises:
      `LookupError`: If debug watch information has not been loaded from
        partition graphs yet.
    """

    try:
      device_name = self._infer_device_name(device_name, node_name)
    except ValueError:
      return []

    if node_name not in self._debug_watches[device_name]:
      return []

    watch_keys = []
    for watched_slot in self._debug_watches[device_name][node_name]:
      debug_ops = self._debug_watches[device_name][node_name][watched_slot]
      for debug_op in debug_ops:
        watch_keys.append(
            _get_tensor_watch_key(node_name, watched_slot, debug_op))

    return watch_keys

  def watch_key_to_data(self, debug_watch_key, device_name=None):
    """Get all `DebugTensorDatum` instances corresponding to a debug watch key.

    Args:
      debug_watch_key: (`str`) debug watch key.
      device_name: (`str`) name of the device. If there is only one device or
        if the specified debug_watch_key exists on only one device, this
        argument is optional.

    Returns:
      A list of `DebugTensorDatum` instances that correspond to the debug watch
        key.
      If the watch key does not exist, returns an empty list.

    Raises:
      ValueError: If there are multiple devices that have the debug_watch_key,
        but device_name is not specified.
    """
    if device_name is None:
      matching_device_names = [
          name for name in self._watch_key_to_datum
          if debug_watch_key in self._watch_key_to_datum[name]]
      if not matching_device_names:
        return []
      elif len(matching_device_names) == 1:
        device_name = matching_device_names[0]
      else:
        raise ValueError(
            "The debug watch key '%s' exists on multiple (%d) devices, but "
            "device name is not specified." %
            (debug_watch_key, len(matching_device_names)))
    elif device_name not in self._watch_key_to_datum:
      raise ValueError(
          "There is no device named '%s' consisting of debug watch keys." %
          device_name)

    return self._watch_key_to_datum[device_name].get(debug_watch_key, [])

  def find(self,
           predicate,
           first_n=0,
           device_name=None,
           exclude_node_names=None):
    """Find dumped tensor data by a certain predicate.

    Args:
      predicate: A callable that takes two input arguments:

        ```python
        def predicate(debug_tensor_datum, tensor):
          # returns a bool
        ```

        where `debug_tensor_datum` is an instance of `DebugTensorDatum`, which
        carries the metadata, such as the `Tensor`'s node name, output slot,
        timestamp, debug op name, etc.; and `tensor` is the dumped tensor value
        as a `numpy.ndarray`.
      first_n: (`int`) return only the first n `DebugTensorDatum` instances (in
        time order) for which the predicate returns True. To return all the
        `DebugTensorDatum` instances, let first_n be <= 0.
      device_name: optional device name.
      exclude_node_names: Optional regular expression to exclude nodes with
        names matching the regular expression.

    Returns:
      A list of all `DebugTensorDatum` objects in this `DebugDumpDir` object
      for which predicate returns True, sorted in ascending order of the
      timestamp.
    """
    if exclude_node_names:
      exclude_node_names = re.compile(exclude_node_names)

    matched_data = []
    for device in (self._dump_tensor_data if device_name is None
                   else (device_name,)):
      for datum in self._dump_tensor_data[device]:
        if exclude_node_names and exclude_node_names.match(datum.node_name):
          continue

        if predicate(datum, datum.get_tensor()):
          matched_data.append(datum)

          if first_n > 0 and len(matched_data) >= first_n:
            return matched_data

    return matched_data
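
  # Illustrative sketch (assumes `/tmp/tfdbg_1` holds a dump written by tfdbg;
  # the path is the same example used in the docstrings above): `find` pairs
  # with the module-level `has_inf_or_nan` predicate to locate the first
  # dumped tensor that contains nan or inf values.
  #
  #   dump = DebugDumpDir("/tmp/tfdbg_1")
  #   bad_data = dump.find(has_inf_or_nan, first_n=1)
  #   if bad_data:
  #     print(bad_data[0].watch_key, bad_data[0].timestamp)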

  def get_tensor_file_paths(self,
                            node_name,
                            output_slot,
                            debug_op,
                            device_name=None):
    """Get the file paths of a debug-dumped tensor.

    Args:
      node_name: (`str`) name of the node that the tensor is produced by.
      output_slot: (`int`) output slot index of tensor.
      debug_op: (`str`) name of the debug op.
      device_name: (`str`) name of the device. If there is only one device or
        if the specified debug_watch_key exists on only one device, this
        argument is optional.

    Returns:
      List of file path(s) loaded. This is a list because each debugged tensor
        may be dumped multiple times.

    Raises:
      WatchKeyDoesNotExistInDebugDumpDirError: If the tensor does not exist in
        the debug-dump data.
    """

    device_name = self._infer_device_name(device_name, node_name)
    watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op)
    if watch_key not in self._watch_key_to_datum[device_name]:
      raise WatchKeyDoesNotExistInDebugDumpDirError(
          "Watch key \"%s\" does not exist in the debug dump of device %s" %
          (watch_key, device_name))

    return [datum.file_path for datum in
            self._watch_key_to_datum[device_name][watch_key]]

  def get_tensors(self, node_name, output_slot, debug_op, device_name=None):
    """Get the tensor values of a debug-dumped tensor.

    The tensor may be dumped multiple times in the dump root directory, so a
    list of tensors (`numpy.ndarray`) is returned.

    Args:
      node_name: (`str`) name of the node that the tensor is produced by.
      output_slot: (`int`) output slot index of tensor.
      debug_op: (`str`) name of the debug op.
      device_name: (`str`) name of the device. If there is only one device or
        if the specified debug_watch_key exists on only one device, this
        argument is optional.

    Returns:
      List of tensors (`numpy.ndarray`) loaded from the debug-dump file(s).

    Raises:
      WatchKeyDoesNotExistInDebugDumpDirError: If the tensor does not exist in
        the debug-dump data.
    """

    watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op)
    try:
      device_name = self._infer_device_name(device_name, node_name)
      return [datum.get_tensor() for datum in
              self._watch_key_to_datum[device_name][watch_key]]
    except (ValueError, KeyError):
      raise WatchKeyDoesNotExistInDebugDumpDirError(
          "Watch key \"%s\" does not exist in the debug dump of device %s" %
          (watch_key, device_name))

  def get_rel_timestamps(self,
                         node_name,
                         output_slot,
                         debug_op,
                         device_name=None):
    """Get the relative timestamps of a debug-dumped tensor.

    Relative timestamp means (absolute timestamp - `t0`), where `t0` is the
    absolute timestamp of the first dumped tensor in the dump root. The tensor
    may be dumped multiple times in the dump root directory, so a list of
    relative timestamps (`int`) is returned.

    Args:
      node_name: (`str`) name of the node that the tensor is produced by.
      output_slot: (`int`) output slot index of tensor.
      debug_op: (`str`) name of the debug op.
      device_name: (`str`) name of the device. If there is only one device or
        if the specified debug_watch_key exists on only one device, this
        argument is optional.

    Returns:
      (`list` of `int`) list of relative timestamps.

    Raises:
      WatchKeyDoesNotExistInDebugDumpDirError: If the tensor watch key does not
        exist in the debug dump data.
    """

    device_name = self._infer_device_name(device_name, node_name)
    watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op)
    if watch_key not in self._watch_key_to_datum[device_name]:
      raise WatchKeyDoesNotExistInDebugDumpDirError(
          "Watch key \"%s\" does not exist in the debug dump" % watch_key)

    # TODO(cais): Figure out whether this should be relative to the global t0.
    return self._watch_key_to_rel_time[device_name][watch_key]

  def get_dump_sizes_bytes(self,
                           node_name,
                           output_slot,
                           debug_op,
                           device_name=None):
    """Get the sizes of the dump files for a debug-dumped tensor.

    Unit of the file size: byte.

    Args:
      node_name: (`str`) name of the node that the tensor is produced by.
      output_slot: (`int`) output slot index of tensor.
      debug_op: (`str`) name of the debug op.
      device_name: (`str`) name of the device. If there is only one device or
        if the specified debug_watch_key exists on only one device, this
        argument is optional.

    Returns:
      (`list` of `int`): list of dump file sizes in bytes.

    Raises:
      WatchKeyDoesNotExistInDebugDumpDirError: If the tensor watch key does not
        exist in the debug dump data.
    """

    device_name = self._infer_device_name(device_name, node_name)
    watch_key = _get_tensor_watch_key(node_name, output_slot, debug_op)
    if watch_key not in self._watch_key_to_datum[device_name]:
      raise WatchKeyDoesNotExistInDebugDumpDirError(
          "Watch key \"%s\" does not exist in the debug dump of device %s" %
          (watch_key, device_name))

    return self._watch_key_to_dump_size_bytes[device_name][watch_key]

  def node_traceback(self, element_name):
    """Try to retrieve the Python traceback of node's construction.

    Args:
      element_name: (`str`) Name of a graph element (node or tensor).

    Returns:
      (list) The traceback list object, in the format returned by the
        `extract_stack` function of Python's traceback module.

    Raises:
      LookupError: If Python graph is not available for traceback lookup.
      KeyError: If the node cannot be found in the Python graph loaded.
    """

    if self._python_graph is None:
      raise LookupError("Python graph is not available for traceback lookup")

    node_name = debug_graphs.get_node_name(element_name)
    if node_name not in self._node_traceback:
      raise KeyError("Cannot find node \"%s\" in Python graph" % node_name)

    return self._node_traceback[node_name]
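

# Illustrative usage sketch (kept as a comment so nothing runs on import; the
# dump path and node name are hypothetical). A typical debugging session loads
# a tfdbg dump root written by a wrapped `Session.run()` call and inspects the
# dumped tensors:
#
#   dump = DebugDumpDir("/tmp/tfdbg_1")
#   dump.set_python_graph(sess.graph)  # optional; enables node_traceback()
#   print(dump.size, dump.devices())
#   values = dump.get_tensors("hidden/weights", 0, "DebugIdentity")
#   rel_times = dump.get_rel_timestamps("hidden/weights", 0, "DebugIdentity")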