1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ======================================================================== 15"""A utility to trace tensor values on TPU.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import operator 22 23import os 24import os.path 25import sys 26 27import numpy as np 28import six 29 30from tensorflow.core.framework import summary_pb2 31from tensorflow.python.framework import constant_op 32from tensorflow.python.framework import dtypes 33from tensorflow.python.framework import func_graph 34from tensorflow.python.framework import function 35from tensorflow.python.framework import graph_io 36from tensorflow.python.framework import ops 37from tensorflow.python.framework import tensor_util 38from tensorflow.python.lib.io import file_io 39from tensorflow.python.ops import array_ops 40from tensorflow.python.ops import control_flow_ops 41from tensorflow.python.ops import control_flow_util 42from tensorflow.python.ops import gen_math_ops 43from tensorflow.python.ops import init_ops 44from tensorflow.python.ops import linalg_ops 45from tensorflow.python.ops import logging_ops 46from tensorflow.python.ops import math_ops 47from tensorflow.python.ops import nn_impl 48from tensorflow.python.ops import state_ops 49from tensorflow.python.ops import string_ops 50from tensorflow.python.ops import summary_ops_v2 as summary 51from tensorflow.python.ops import variable_scope 52from tensorflow.python.platform import analytics 53from tensorflow.python.platform import gfile 54from tensorflow.python.platform import remote_utils 55from tensorflow.python.platform import tf_logging as logging 56from tensorflow.python.summary import summary_iterator 57from tensorflow.python.tpu import tensor_tracer_flags 58from tensorflow.python.tpu import tensor_tracer_report 59from tensorflow.python.tpu import tpu 60from tensorflow.python.tpu.ops import tpu_ops 61from tensorflow.python.training import training_util 62 63_DEVICE_TYPE_TPU = 'tpu' 64_DEVICE_TYPE_CPU = 'cpu' 65_TRACE_MODE_PART_TENSOR_SIZE = 3 66 67_REASON_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range' 68_REASON_UNSAFE_OP = 'not-traced-unsafe-op' 69_REASON_WHILELOOP_OP = 'not-traced-special-whileloop-op' 70_REASON_CONTROLFLOW_OP = 'not-traced-control-flow-op' 71_REASON_IN_CONTROL_FLOW = 'not-traced-in-control-flow' 72_REASON_UNSAFE_SCALAR = 'not-traced-unsafe-scalar' 73_REASON_SKIP_SCALAR = 'not-traced-scalar' 74_REASON_LESS_INTERESTING_OP = 'not-traced-less-interesting-op' 75_REASON_DEVICE_MISMATCH = 'not-traced-device-mismatch' 76_REASON_DYNAMIC_SHAPE = 'not-traced-dynamic-shape' 77_REASON_SCALAR_GET_TRACED = 'traced-scalar' 78_REASON_TENSOR_GET_TRACED = 'traced-tensor' 79_REASON_USER_INCLUDED = 'traced-user-included' 80_REASON_USER_EXCLUDED = 'not-traced-user-excluded' 81_REASON_NOT_EXECUTED = 'not-traced-not-in-exec-path' 82_REASON_NON_NUMERIC_TENSOR = 'not-traced-non-numeric-tensor' 83_REASON_FEEDS_WHILELOOP_OP = 'not-traced-feeds-special-whileloop-op' 84 85_OUTPUT_STREAM_ESCAPE = 'file://' 86_TENSOR_TRACER_COLLECTION = 'tensor_tracer_variables' 87TENSOR_TRACER_SUMMARY_COLLECTION = 'tensor_tracer_summary_writers' 88_TRACE_FILE_NAME = 'trace.all' 89_COMPACT_TRACE_FILE_PREFIX = 'compact_trace.' 90_COMPACT_TRACE_ENTRY_INIT_VALUE = -1.0 91_TENSOR_TRACER_STORAGE = 'tensor_tracer_storage' 92_TT_SNAPSHOT = 'tensor_tracer_snapshot' 93_REPLICA_ID_TAG = '#replica-id: ' 94_SKIP_REPORT_FILE = 'None' # Do not write report proto if --report_file=None 95 96_TT_SUMMARY_NORM = tensor_tracer_flags.TT_SUMMARY_NORM 97_TT_SUMMARY_MAX = tensor_tracer_flags.TT_SUMMARY_MAX 98_TT_SUMMARY_MAX_ABS = tensor_tracer_flags.TT_SUMMARY_MAX_ABS 99_TT_SUMMARY_MIN = tensor_tracer_flags.TT_SUMMARY_MIN 100_TT_SUMMARY_MEAN = tensor_tracer_flags.TT_SUMMARY_MEAN 101_TT_SUMMARY_VAR = tensor_tracer_flags.TT_SUMMARY_VAR 102_TT_SUMMARY_SIZE = tensor_tracer_flags.TT_SUMMARY_SIZE 103 104_TT_SUMMARY_TAG = 'tensor_tracer_summary' 105_TT_TENSORBOARD_PLUGIN_NAME = 'tensor_tracer' 106_TT_HOSTCALL_KEY = 'tensor_tracer_host_call' 107_TT_EVENT_FILE_SUFFIX = '.tensor_tracer' 108 109_TT_SUMMARY_MAX_QUEUE = 10 110 111 112def set_parameters(tensor_tracer_params=None): 113 """Enables tensor tracer and sets its parameters. 114 115 Example usage: 116 tensor_tracer_parameters = {'trace_dir': '/usr/tmp/trace_dir', 117 'trace_mode': 'norm', 118 'report_file': '/usr/tmp/trace_dir/report.all'} 119 tensor_tracer.set_parameters(tensor_tracer_parameters) 120 121 This sets up the parameters for tensor tracer. A call to tensor tracer as 122 below is necessary to enable debugging on CPUs and GPUs. On TPUs below can be 123 skipped as this call is hooked into tpu.rewrite. 124 tt = tensor_tracer.TensorTracer() 125 loss = tt.trace_cpu(tf.get_default_graph(), tensor_fetches=loss) 126 127 Args: 128 tensor_tracer_params: Tensor tracer parameter dictionary. Below gives 129 examples of these parameters: See tensor_tracer_report.py for all 130 parameters. 131 - enable: If set, tensor tracer will be enabled. Calling 132 enable_tensor_tracer automatically adds this parameters. 133 - trace_mode: The trace_mode to be used by tensor tracer. These include: 134 - summary: Collects multiple statistics for traced tensors, and writes 135 them a summary file that can be visualized using tensorboard. This 136 mode currently only works for TPUEstimator. It can be also be used 137 for other models, but outfeed must be handled by the user. 138 - norm: Collects norm of each traced tensor and writes them into a 139 text file pointed by 'trace_dir' flag. (Default mode). 140 - nan-inf: Checks the existince of NaNs and Infs in the tensor, and 141 writes a boolean value to a text file pointed by 'trace_dir' flag. 142 Note that 'norm' mode can also capture this information with more 143 numerical info. 144 - max-abs: Collects the absolute max for each traced tensors and 145 writes it into a text file pointed by 'trace_dir' flag. 146 - full-tensor: Writes the full tensor content of the traced tensors 147 into a text file pointed by 'trace_dir' flag. 148 - part-tensor: Writes a part of the tensor content of the traced 149 tensors into a text file pointed by 'trace_dir' flag. 150 - full_tensor_summary: Writes the full tensors as binary event files. 151 The outputs can be read using: trace = 152 tensor_tracer.read_tensor_tracer_event_file(event_file_path) 153 154 - report_file: Path to the metadata file that is written during graph 155 construction. If not set, metadata will be printed to stdout during 156 graph construction. 157 - trace_dir: Path where the execution traces will be written during the 158 graph execution. If not set, trace will be printed to stderr. 159 - trace_level: Tensor tracer aims to trace everything it can. This 160 introduces some overhead on graph execution and graph compilation 161 times. Using trace_level parameter, it is possible to trace operation 162 based on their priorities. For example, - trace_level=7 is the highest 163 trace_level, in which every op is traced. - trace_level=6 will skip 164 constant operations such as tf.constant. - trace_level=5 will skip 165 less important ops such as tf.identities. - The default trace_level=3, 166 that will skip concat ops, or random number generators. - To reduce 167 the graph compile time overhead, trace_level can be set to 0, that 168 will skip additions, and substractions, and multiplications as well. 169 - excluded_opnames: If set, any matching op name will not be traced. 170 excluded_opnames can be set as a regular expression. E.g, 171 excluded_opnames=.* will exclude everything. 172 - excluded_optypes: If set, any matching op type will not be traced. 173 excluded_optypes can be set as a regular expression. E.g, 174 excluded_optypes=.* will exclude everything. excluded_optypes=MatMul 175 will exclude all MatMul ops from tracing. 176 - included_opnames: If set, any matching op name will be forced to be 177 traced. included_opnames can be set as a regular expression. E.g, 178 '--included_opnames=some_op --excluded_opname=*.' will only trace 179 some_op. 180 - included_optypes: If set, any matching op type will be forced to be 181 traced. included_optypes can be set as a regular expression. E.g, 182 '--included_optypes=some_op_type --excluded_optypes=*.' will trace 183 only the ops with type 'some_op_type' 184 - flush_summaries: If summary mode is used, flush_summaries=1 will 185 flush summaries using outside compilation. Note that, if used with 186 low level APIs, flush_summaries=1 is necessary to obtain results. 187 Advanced Flags: 188 - trace_scalar: Scalar values are not traced by default. If this flag is 189 set, scalar values will also be traced. 190 - op_range: In the form of '%d:%d' that limits the tracing to the ops 191 within this limit. --op_range='5:10' will trace only the ops that have 192 topological order between 5-10. 193 - submode: 'brief' or 'detailed'. If the trace mode is not compact, 194 brief mode will print only the id of each traced tensor to save some 195 space. 'detailed' mode prints the full tensor name. 196 - use_fingerprint_subdirectory: The trace directory will be chosen as 197 using the fingerprint of the trace metadata under the provided 198 trace_dir. 199 """ 200 flags = '--%s=1' % tensor_tracer_flags.FLAG_NAME_ENABLE 201 if tensor_tracer_params: 202 for key, value in tensor_tracer_params.items(): 203 flags += ' --%s=%s' % (key, value) 204 os.environ[tensor_tracer_flags.FLAGS_ENV_VAR] = flags 205 206 207def op_priority(op_type): 208 """Returns the priority of the op. 209 210 If the priority of the op is k, it will be traced if trace_level>=k. 211 Args: 212 op_type: String name of the operation type. 213 Returns: 214 Integer value corresponding the priority of the op. 215 """ 216 if op_type in ('Const', 'Shape', 'BroadcastGradientArgs', 'Range', 217 'VariableShape', 'Fill', 'OneHot', 'ShapeN'): 218 # Lowest priority ops, e.g., constant ops across different steps, 219 # They will be traced only if trace_level>=7 220 return 7 221 222 if op_type in ('Identity', 'Cast', 'Reshape', 'ExpandDims', 'StopGradient', 223 'PreventGradient', 'Squeeze'): 224 # Operations without numerical effects. 225 # They will be only if trace_level>=6 226 return 6 227 if op_type in ('ConcatV2', 'Concat', 'StridedSlice', 'Slice', 'Pack', 'Tile', 228 'CollectivePermute', 'SplitV'): 229 # Operations that merge or slice an input, will be traced if trace_level>=5 230 return 5 231 if op_type in ('Pad', 'RandomUniformInt', 'GreaterEqual'): 232 # Operations less likely to provide useful information, 233 # will be traced if trace_level>=4 234 return 4 235 if op_type in ('Sum', 'AddV2', 'Add', 'AddN', 'BiasAdd', 'CrossReplicaSum'): 236 # Add operations that are less likely create any issues, will be traced 237 # if trace_level>=3 (default=3) 238 return 3 239 if op_type in ('Neg', 'Sub'): 240 # Sub operations that are less likely create any issues, will be traced 241 # trace_level>=2 242 return 2 243 if op_type in ('Mul', 'Square', 'MatMul', 'RandomUniform', 'Select', 244 'Maximum', 'Mean', 'Variance'): 245 # Multiplication and some other operations, will be traced if trace_level>=1 246 return 1 247 return 0 248 249 250def read_tensor_tracer_event_file(event_file): 251 """Reads the event file written by tensor tracer. 252 253 This can be used to read the full tensors written into binary event files by 254 by TensorTracer with trace_mode=full_tensor_summary. 255 256 Example usage: 257 result_dict = tensor_tracer.read_tensor_tracer_event_file(event_file_path) 258 for step, tensor_dict in result_dict.items(): 259 for tensor_name, full_tensor_content in tensor_dict.items(): 260 logging.info(tensor_name, full_tensor_content) 261 262 Args: 263 event_file: Path to the event file that contains only tensor tracer events. 264 Returns: 265 An event dictionary in the form of 266 {step_number: {tensor_name: tensor_content}} 267 Raises: 268 ValueError: If an unexpected trace is found. 269 """ 270 event_dict = {} 271 for trace_event in summary_iterator.summary_iterator(event_file): 272 # First event is an event with file_version: "brain.Event:2" 273 if not trace_event.HasField('summary'): 274 continue 275 step = trace_event.step 276 if step not in event_dict: 277 event_dict[step] = {} 278 279 if len(trace_event.summary.value) != 1: 280 raise ValueError('Single step contains %d summary values,' 281 ' expected 1.' % len(trace_event.summary.value)) 282 tensor_value = trace_event.summary.value[0] 283 tensor_name = tensor_value.tag 284 285 real_shape = [d.size for d in tensor_value.tensor.tensor_shape.dim] 286 tensor_content = np.frombuffer( 287 tensor_value.tensor.tensor_content, 288 dtypes.DType(tensor_value.tensor.dtype).as_numpy_dtype() 289 ).reshape(real_shape) 290 event_dict[step][tensor_name] = tensor_content 291 return event_dict 292 293 294def trace_tensor(tensor, tracepoint_name=None): 295 """Programmatic interface to trace a tensor with Tensor Tracer. 296 297 Tensor Tracer, by default, traces all tensors in the execution. This function 298 can be used to limit traced tensors. If this function is called for a subset 299 of the tensors, only those will be traced. 300 301 For example, Tensor Traacer will only trace c below. 302 c = tf.MatMul(a, b) 303 tensor_tracer.trace_tensor(c) 304 d = tf.add(c, 1) 305 Args: 306 tensor: the tensor object for which the tracing is requested. 307 tracepoint_name: an optional tensor tracepoint name string. A tracepoint 308 name is an Tensor Tracer internal name for the tensor. It is useful when 309 comparing equivalent traces from different models that have different 310 tensor namings. Equivalent tensors (with different names) can be mapped 311 to each other by assigning a common tracepoint_name. 312 313 Returns: 314 The provided tensor. 315 """ 316 if tracepoint_name is None: 317 tracepoint_name = tensor.name 318 tensor.graph.get_collection(_TENSOR_TRACER_COLLECTION) 319 tensor.graph.add_to_collection(_TENSOR_TRACER_COLLECTION, 320 (tensor, tracepoint_name)) 321 return tensor 322 323 324def keras_layer_tracepoint(layer, checkpoint_name): 325 """An interface for adding the tensor outputs of a keras layer. 326 327 Encapsulates trace_tensor. 328 329 Args: 330 layer: A keras layer. 331 checkpoint_name: a string name for the checkpoint. This name has to be a 332 unique name if used within model comparison. The tensors that have the same 333 checkpoint identifier is compared in model comparison. 334 335 Returns: 336 The provided layer. 337 """ 338 try: 339 outputs = layer.output 340 if tensor_util.is_tf_type(outputs): 341 trace_tensor(outputs, '%s' % (checkpoint_name)) 342 else: 343 idx = 0 344 for output_tensor in outputs: 345 if tensor_util.is_tf_type(outputs): 346 trace_tensor(output_tensor, '%s_%d' % (checkpoint_name, idx)) 347 idx += 1 348 except AttributeError: 349 pass 350 except RuntimeError: 351 pass 352 return layer 353 354 355class TensorTracer(object): 356 """A software construct for tracing tensor values in a TF graph. 357 358 This utility is disabled by default. It is hooked into tpu.rewrite, so it can 359 easily be enabled on TPUs by setting the TENSOR_TRACER_FLAGS env variable as 360 below without a code change. 361 export TENSOR_TRACER_FLAGS="--enable=1" 362 363 Below is the use example to enable it on CPUs or GPUs, or for more advance use 364 cases on TPUs. 365 366 a = x + 1 367 b = a * 2 368 rs = tf.reduce_sum(b) 369 tensor_tracer.set_parameters({'trace_dir': 'path/to/trace_dir', 370 'report_file: 'path/to/report/file'}) 371 tt = tensor_tracer.TensorTracer() 372 if on_tpu: 373 rs = tt.trace_tpu(tf.get_default_graph(), 374 tensor_fetches=rs) 375 else: 376 rs = tt.trace_cpu(tf.get_default_graph(), 377 tensor_fetches=rs) 378 session.run(rs) 379 380 If it is enabled, it will trace the output tensor values of 381 selected Ops in the graph. It has two outputs: (1) the traces and (2) 382 a report. The traces are dumped to a specified directory during the graph 383 execution, while the report is dumped during the graph construction. 384 By passing options via the env variable, users can change: 385 (1) the trace mode (e.g., detecting NaN/Inf, printing partial or 386 full tensor values) 387 (2) which Ops to be traced (via op.name or op.type) 388 (3) output trace file path. 389 390 """ 391 # The set of graphs that are rewritten by tensor tracer. 392 _traced_graphs = set() 393 394 @staticmethod 395 def is_enabled(): 396 """Returns True if TensorTracer is enabled.""" 397 return tensor_tracer_flags.TTParameters().is_enabled() 398 399 @staticmethod 400 def check_device_type(device_type): 401 """Checks if the given device type is valid.""" 402 403 if device_type not in (_DEVICE_TYPE_TPU, _DEVICE_TYPE_CPU): 404 raise ValueError('Invalid device_type "%s"'%device_type) 405 406 @staticmethod 407 def check_trace_mode(device_type, trace_mode): 408 """Checks if the given trace mode work on the given device type. 409 410 Args: 411 device_type: Device type, TPU, GPU, CPU. 412 trace_mode: Tensor tracer trace mode. 413 Raises: 414 ValueError: If the given trace mode is not supported for the device. 415 """ 416 if trace_mode == tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY: 417 if device_type != _DEVICE_TYPE_TPU: 418 raise ValueError('Device_type "%s" is not yet supported for ' 419 'trace mode "%s"' % (device_type, trace_mode)) 420 421 @staticmethod 422 def loop_cond_op(op): 423 return op.type in ('LoopCond', 'RefLoopCond') 424 425 @staticmethod 426 def while_loop_op(op): 427 """Returns true if op is one of the special ops of in a while loop. 428 429 Args: 430 op: A tf.Operation. 431 432 Returns: 433 True if the given op is one of [Switch, Merge, Enter, Exit, 434 NextIteration, LoopCond], which are all building blocks for TF while 435 loops. 436 """ 437 return (control_flow_util.IsLoopSwitch(op) or 438 control_flow_util.IsLoopMerge(op) or 439 control_flow_util.IsLoopEnter(op) or 440 control_flow_util.IsLoopExit(op) or 441 TensorTracer.loop_cond_op(op) or 442 op.type in ('RefNextIteration', 'NextIteration')) 443 444 @staticmethod 445 def control_flow_op(op): 446 """Returns true if op is one of the special ops of in a while loop. 447 448 Args: 449 op: A tf.Operation. 450 451 Returns: 452 True if the given op is one of [Switch, Merge, Enter, Exit, 453 NextIteration, LoopCond], which are all building blocks for TF while 454 loops. 455 """ 456 return (control_flow_util.IsSwitch(op) or 457 control_flow_util.IsMerge(op)) 458 459 @staticmethod 460 def unsafe_op(op): 461 """Returns True if this op is not safe to be traced.""" 462 463 # Reasons for not including following op types: 464 # Assign: cause incorrect result with CPU tracing. 465 if op.type == 'Assign': 466 return True 467 return False 468 469 @staticmethod 470 def device_mismatch(device_type, op): 471 if device_type == _DEVICE_TYPE_TPU: 472 # pylint: disable=protected-access 473 return tpu._TPU_REPLICATE_ATTR not in op.node_def.attr 474 # pylint: enable=protected-access 475 return False 476 477 @staticmethod 478 def unsafe_scalar_trace(op): 479 """Return true if scalar output tensor from Op is not safe to be traced.""" 480 481 # Tracing the following causes cycle in the graph on TPU. 482 if op.type in ('LoopCond', 'Enter', 'Merge', 'Const', 483 'Switch', 'Less', 'ReadVariableOp'): 484 return True 485 # Tracing the following will cause casting-issue 486 # with the norm tracing mode or other compilation issues on CPU. 487 if op.type in ('VarHandleOp', 'IteratorToStringHandle', 488 'IteratorGetNext', 'OneShotIterator', 489 'IteratorV2', 'MakeIterator', 490 'BatchDatasetV2', 'MapDataset', 491 'FixedLengthRecordDataset', 'TakeDataset', 'ZipDataset', 492 'Placeholder', 'PlaceholderWithDefault', 'StridedSlice'): 493 return True 494 return False 495 496 def _is_interesting_op(self, op): 497 """Returns True if the given op is not an interesting one to be traced.""" 498 return op_priority(op.type) <= self._parameters.trace_level 499 500 @staticmethod 501 def reason(op_idx, details): 502 """Returns reason why the Op at op_idx is traced or not.""" 503 504 return '%d %s'%(op_idx, details) 505 506 def __init__(self): 507 """Initializes a TensorTracer. 508 509 Sets the various member fields from the flags (if given) or the defaults. 510 """ 511 self._replica_id = None 512 self._tt_config = tensor_tracer_report.TensorTracerConfig() 513 self._parameters = None 514 self._host_call_fn = {} 515 self._cache_variables = {} 516 self._traced_op_names = set() 517 self._report_proto = None 518 self._temp_cache_var = [] 519 self._report_proto_path = '' 520 self._outmost_context = None 521 522 def report_proto(self): 523 """Getter for tensor_tracer.proto object for summary and full_tensor_summary modes. 524 525 Returns: 526 A tensor_tracer.proto object. 527 Raises: 528 ValueError if called before tracing happens, or when trace mode is not 529 summary or full_tensor_summary. 530 """ 531 if self._report_proto: 532 return self._report_proto 533 else: 534 raise ValueError('Call to report_proto must be done after tracing.' 535 'Report proto only exists for ' 536 'trace_mode=[summary|full_tensor_summary]') 537 538 def report_proto_path(self): 539 """Getter for path where tensor_tracer.proto object should be written. 540 541 Returns: 542 A string path. 543 """ 544 return self._report_proto_path 545 546 def _get_all_cache_variables(self): 547 return self._cache_variables 548 549 def _create_or_get_tensor_values_cache(self, cache_name, graph=None, 550 shape=None, dtype=dtypes.float32): 551 """Creates a variable as the cache to store intermediate tensor values. 552 553 Args: 554 cache_name: Name to be given to the cache (an instance of tf.variable). 555 graph: Tensorflow graph. 556 shape: A list of dimensions. 557 dtype: Data type of created cache. 558 Returns: 559 A ref to newly created or existing cache with the given dimensions. 560 Raises: 561 ValueError: If missing a parameter to create the cache. 562 """ 563 def _escape_namescopes(variable_name): 564 # TODO(deveci): This might cause name collisions as in "foo/bar/mytensor" 565 # and "foo_bar/mytensor". 566 return variable_name.replace('/', '_').replace(':', '_') 567 568 if cache_name not in self._cache_variables: 569 if graph is None: 570 raise ValueError('Graph must be provided at cache creation.') 571 if shape is None: 572 raise ValueError('shape must be provided at cache creation.') 573 graph = graph or ops.get_default_graph() 574 if dtype.is_integer: 575 init_val = int(_COMPACT_TRACE_ENTRY_INIT_VALUE) 576 else: 577 init_val = _COMPACT_TRACE_ENTRY_INIT_VALUE 578 579 # Create in proper graph and base name_scope. 580 with graph.as_default() as g, g.name_scope(None): 581 self._cache_variables[cache_name] = variable_scope.get_variable( 582 _TT_SNAPSHOT + '_' + _escape_namescopes(cache_name), 583 shape=shape, dtype=dtype, 584 initializer=init_ops.constant_initializer(init_val), 585 trainable=False, 586 use_resource=True, 587 collections=[_TENSOR_TRACER_STORAGE, ops.GraphKeys.LOCAL_VARIABLES]) 588 return self._cache_variables[cache_name] 589 590 def _add_replica_id_to_graph(self): 591 """Adds nodes for computing the replica ID to the graph.""" 592 593 if self._tt_config.num_replicas: 594 with ops.control_dependencies(None): 595 # Uses None as dependency to run outside of TPU graph rewrites. 596 self._replica_id = tpu_ops.tpu_replicated_input( 597 list(range(self._tt_config.num_replicas)), 598 name='tt_replica_id') 599 else: 600 self._replica_id = 'unknown' 601 602 def _inside_op_range(self, idx): 603 """Return True if the given index is inside the selected range.""" 604 605 if idx < self._parameters.op_range[0]: 606 return False 607 return (self._parameters.op_range[1] < 0 or 608 idx <= self._parameters.op_range[1]) 609 610 def _is_user_included_op(self, op): 611 """Checks whether the op is included in the tensor tracer flags. 612 613 Args: 614 op: tf Operation 615 Returns: 616 True, if the op is included. 617 An op is included if: 618 - Its op name is given in included_opnames 619 - Its op type is given in included_optypes 620 - The op is at most _trace_ops_before_included hops before an included op 621 - The op is at most _trace_ops_after_included hops after an included op 622 """ 623 for opname_re in self._parameters.included_opname_re_list: 624 if opname_re.match(op.name): 625 return True 626 627 for optype_re in self._parameters.included_optype_re_list: 628 if optype_re.match(op.type): 629 return True 630 return False 631 632 def _is_user_excluded_op(self, op): 633 for opname_re in self._parameters.excluded_opname_re_list: 634 if opname_re.match(op.name): 635 return True 636 for optype_re in self._parameters.excluded_optype_re_list: 637 if optype_re.match(op.type): 638 return True 639 return False 640 641 def _signature_types(self): 642 """Returns a dictionary holding the order of signatures in the cache for the selected trace mode.""" 643 if self._parameters.trace_mode in set([ 644 tensor_tracer_flags.TRACE_MODE_NAN_INF, 645 tensor_tracer_flags.TRACE_MODE_NORM, 646 tensor_tracer_flags.TRACE_MODE_MAX_ABS]): 647 return {self._parameters.trace_mode: 0} 648 if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY: 649 return self._parameters.summary_signatures 650 return {} 651 652 def _num_signature_dimensions(self): 653 return len(self._signature_types()) 654 655 def _use_temp_cache(self): 656 """Returns true if the intermediate values should be stacked instead of being stored in a tf.Variable. 657 658 Returns: 659 A boolean, denoting whether to use a temporary cache or not. 660 """ 661 # If full tensors need to be stored tf.variables, then do not use temp 662 # variables to store them. 663 if self._use_tensor_buffer(): 664 return False 665 if self._use_tensor_values_cache(): 666 return self._parameters.use_temp_cache_var 667 else: 668 # Temporary caches only replaces tf.Variables caches. If no cache is used 669 # return False. 670 return False 671 672 def _use_tensor_values_cache(self): 673 """Returns True if immediate tensors should be first saved to a cache.""" 674 return self._parameters.use_compact_trace 675 676 def _use_tensor_buffer(self): 677 """Returns true if the whole tensor needs to be cached/buffered in memory.""" 678 return (self._parameters.trace_mode == 679 tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY) 680 681 def _merge_tensor_signatures(self, signatures): 682 """Returns a tensor that merges the given signatures. 683 684 Args: 685 signatures: A dictionary of the signature updates from signature name to 686 a tensor of dimension [1]. 687 Returns: 688 A tensor that concats the signature values in a predefined order. 689 """ 690 sorted_update = [] 691 if self._num_signature_dimensions() > 1: 692 signature_indices = self._signature_types() 693 for _, val in sorted(signatures.items(), 694 key=lambda item: signature_indices[item[0]]): 695 sorted_update.append(val) 696 updates = array_ops.stack( 697 sorted_update, axis=0, name='merge_single_op_signatures') 698 elif self._num_signature_dimensions() == 1: 699 # Avoid stack operation if there is only a single signature. 700 (_, val), = signatures.items() 701 updates = val 702 else: 703 raise ValueError('Cannot merge 0 signatures.') 704 return updates 705 706 def _save_tensor_value_to_tmp_cache(self, cache_idx, updates): 707 """Returns an op that will save the given updates to an entry in the cache. 708 709 Args: 710 cache_idx: The cache index of the tensor within the cache. 711 updates: A dictionary of the signature updates from signature name to 712 a tensor of dimension [1]. 713 """ 714 updates = self._merge_tensor_signatures(updates) 715 updates = array_ops.reshape(updates, 716 [self._num_signature_dimensions()]) 717 self._temp_cache_var[cache_idx] = updates 718 719 def _save_tensor_value_to_cache_op(self, cache_idx, updates): 720 """Returns an op that will save the given updates to an entry in the cache. 721 722 Args: 723 cache_idx: The cache index of the tensor within the cache. 724 updates: A dictionary of the signature updates. 725 Returns: 726 Cache update operation. 727 """ 728 # state_ops.scatter_update allows updates only along the first dimension. 729 # Make a compact array by concatenating different signatures, and update 730 # them all together. 731 updates = self._merge_tensor_signatures(updates) 732 updates = array_ops.reshape(updates, 733 [1, self._num_signature_dimensions()]) 734 indices = constant_op.constant([cache_idx]) 735 cache = self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG) 736 return state_ops.scatter_update(cache, indices, updates).op 737 738 def _snapshot_tensor(self, tensor): 739 """Creates a new tf.Variable and a new tf.Operation that assigns the value of the tensor to this variable. 740 741 Args: 742 tensor: tensor whose values will be stored in a new tf.Variable. 743 Returns: 744 An assignment operation. 745 """ 746 747 snapshot_variable = self._create_or_get_tensor_values_cache( 748 tensor.name, tensor.op.graph, 749 tensor.shape.as_list(), tensor.dtype) 750 return state_ops.assign(snapshot_variable, tensor).op 751 752 def _preprocess_traced_tensor(self, tensor): 753 """Computes NAN/Norm/Max on TPUs before sending to CPU. 754 755 Args: 756 tensor: The tensor to be traced. 757 Returns: 758 A tensor that should be input to the trace_function. 759 Raises: 760 RuntimeError: If the trace mode is invalid. 761 """ 762 763 def _detect_nan_inf(tensor): 764 """Trace function for detecting any NaN/Inf in the tensor.""" 765 766 if tensor.dtype.is_floating: 767 mask = math_ops.reduce_any( 768 gen_math_ops.logical_or( 769 gen_math_ops.is_nan(tensor), gen_math_ops.is_inf(tensor))) 770 output_tensor = control_flow_ops.cond( 771 mask, 772 lambda: constant_op.constant([1.0]), 773 lambda: constant_op.constant([0.0])) 774 else: 775 output_tensor = constant_op.constant([0.0]) 776 return output_tensor 777 778 def _compute_signature(tensor, tf_op, cast_to_f32=True): 779 if cast_to_f32: 780 tensor = math_ops.cast(tensor, dtypes.float32) 781 output_tensor = tf_op(tensor) 782 # Return type should be scalar. Set it if it does not have the 783 # information. 784 if not output_tensor.get_shape().is_fully_defined(): 785 output_tensor = array_ops.reshape(output_tensor, []) 786 return output_tensor 787 788 def _show_size(tensor): 789 # In order to check the size of a tensor. 790 # Not all sizes are known at the compile time, also, different replicas 791 # sometimes get different sizes of tensors. 792 # Collect it here to be used in merging replica data. 793 tsize = _compute_signature(tensor, array_ops.size, cast_to_f32=False) 794 # Cast to float32, so that it can be placed into same cache with other 795 # signatures. 796 return math_ops.cast(tsize, dtypes.float32) 797 798 def _show_max(tensor, cast_to_f32=True): 799 # returns -inf for empty tensor 800 return _compute_signature(tensor, math_ops.reduce_max, cast_to_f32) 801 802 def _show_min(tensor, cast_to_f32=True): 803 # returns inf for empty tensor 804 return _compute_signature(tensor, math_ops.reduce_min, cast_to_f32) 805 806 def _show_norm(tensor, cast_to_f32=True): 807 # returns 0 for empty tensor 808 return _compute_signature(tensor, linalg_ops.norm, cast_to_f32) 809 810 def _show_mean_and_variance(tensor, cast_to_f32=True): 811 """Returns the mean and variance of the given tensor.""" 812 if cast_to_f32: 813 tensor = math_ops.cast(tensor, dtypes.float32) 814 # returns nan for empty tensor 815 mean, var = nn_impl.moments(array_ops.reshape(tensor, [-1]), axes=[0]) 816 # The shape has to be 1. Set it if it does not have the information. 817 if not mean.get_shape().is_fully_defined(): 818 mean = array_ops.reshape(mean, []) 819 if not var.get_shape().is_fully_defined(): 820 var = array_ops.reshape(var, []) 821 return mean, var 822 823 def _show_max_abs(tensor, cast_to_f32=True): 824 return _compute_signature( 825 tensor, lambda t: math_ops.reduce_max(math_ops.abs(t)), cast_to_f32) 826 827 if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NAN_INF: 828 return {self._parameters.trace_mode: _detect_nan_inf(tensor)} 829 if (self._parameters.trace_mode == 830 tensor_tracer_flags.TRACE_MODE_PART_TENSOR): 831 return {self._parameters.trace_mode: tensor} 832 if (self._parameters.trace_mode in ( 833 tensor_tracer_flags.TRACE_MODE_FULL_TENSOR, 834 tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY)): 835 return {self._parameters.trace_mode: tensor} 836 if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NORM: 837 return {self._parameters.trace_mode: array_ops.reshape( 838 _show_norm(tensor), [1])} 839 if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_MAX_ABS: 840 return {self._parameters.trace_mode: _show_max_abs(tensor)} 841 842 if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY: 843 tensor = math_ops.cast(tensor, dtypes.float32) 844 result_dict = {} 845 # Call mean and variance computation here to avoid adding the same nodes 846 # twice. 847 if (_TT_SUMMARY_MEAN in self._signature_types() or 848 _TT_SUMMARY_VAR in self._signature_types()): 849 mean, variance = _show_mean_and_variance(tensor, cast_to_f32=False) 850 851 for signature_name, _ in sorted(self._signature_types().items(), 852 key=lambda x: x[1]): 853 if signature_name == _TT_SUMMARY_NORM: 854 signature_result_tensor = _show_norm(tensor, cast_to_f32=False) 855 elif signature_name == _TT_SUMMARY_MAX: 856 signature_result_tensor = _show_max(tensor, cast_to_f32=False) 857 elif signature_name == _TT_SUMMARY_MAX_ABS: 858 signature_result_tensor = _show_max_abs(tensor, cast_to_f32=False) 859 elif signature_name == _TT_SUMMARY_MIN: 860 signature_result_tensor = _show_min(tensor, cast_to_f32=False) 861 elif signature_name == _TT_SUMMARY_SIZE: 862 signature_result_tensor = _show_size(tensor) 863 elif signature_name == _TT_SUMMARY_MEAN: 864 signature_result_tensor = mean 865 elif signature_name == _TT_SUMMARY_VAR: 866 signature_result_tensor = variance 867 else: 868 raise ValueError('Unknown signature type :%s.' % signature_name) 869 870 result_dict[signature_name] = signature_result_tensor 871 return result_dict 872 873 raise RuntimeError( 874 'Tensor trace fun for %s is not yet implemented' 875 % self._parameters.trace_mode) 876 877 def _make_tensor_trace_fun(self, tensor_name, tensor_trace_order): 878 """Makes the tensor tracing function called by outside compilation. 879 880 Args: 881 tensor_name: name of the tensor being traced. 882 tensor_trace_order: TensorTraceOrder object holding tensorname to id map. 883 Returns: 884 A function to be passed as the first argument to outside compilation. 885 886 Raises: 887 RuntimeError: If the trace mode is invalid. 888 """ 889 890 def _print_tensor(tensor_name, num_elements, tensor, output_tensor): 891 """Prints a tensor value to a file. 892 893 Args: 894 tensor_name: name of the tensor being traced. 895 num_elements: number of elements to print (-1 means print all). 896 tensor: the tensor needs to be returned. 897 output_tensor: the tensor needs to be printed. 898 899 Returns: 900 The same tensor passed via the "tensor" argument. 901 902 Raises: 903 ValueError: If tensor_name is not already in 904 tensor_trace_order.tensorname_to_cache_idx. 905 """ 906 907 if self._parameters.is_brief_mode(): 908 if tensor_name not in tensor_trace_order.tensorname_to_cache_idx: 909 raise ValueError( 910 'Tensor name %s is not in the tensorname_to_cache_idx' % 911 tensor_name) 912 msg = '%d' % tensor_trace_order.tensorname_to_cache_idx[tensor_name] 913 else: 914 msg = '"%s"' % tensor_name 915 916 if self._parameters.trace_dir: 917 output_path = os.path.join( 918 self._parameters.trace_dir, 919 _TRACE_FILE_NAME + self._get_outfile_suffix()) 920 output_stream = _OUTPUT_STREAM_ESCAPE + output_path 921 else: 922 output_stream = sys.stderr 923 return logging_ops.print_v2(msg, array_ops.shape(output_tensor), 924 '@', self._replica_id, 925 '\n', output_tensor, '\n', 926 summarize=num_elements, 927 output_stream=output_stream) 928 929 def _show_part_tensor(tensor): 930 """Trace function for printing part of the tensor.""" 931 932 return _print_tensor(tensor_name, _TRACE_MODE_PART_TENSOR_SIZE, 933 tensor, tensor) 934 935 def _show_full_tensor(tensor): 936 """Trace function for printing the entire tensor.""" 937 938 return _print_tensor(tensor_name, -1, tensor, tensor) 939 940 if (self._parameters.trace_mode == 941 tensor_tracer_flags.TRACE_MODE_PART_TENSOR): 942 return _show_part_tensor 943 # The input tensor has a shape of "[1]" for TRACE_MODE_NAN_INF, 944 # TRACE_MODE_NORM, and TRACE_MODE_MAX_ABS, as related computations are 945 # performed within TPUs and only their results are transferred to CPU. 946 # Simply, print the full tensor for these trace modes. 947 if self._parameters.trace_mode in ( 948 tensor_tracer_flags.TRACE_MODE_NAN_INF, 949 tensor_tracer_flags.TRACE_MODE_NORM, 950 tensor_tracer_flags.TRACE_MODE_FULL_TENSOR, 951 tensor_tracer_flags.TRACE_MODE_MAX_ABS, 952 tensor_tracer_flags.TRACE_MODE_SUMMARY 953 ): 954 return _show_full_tensor 955 956 raise RuntimeError('Tensor trace fun for %s is not yet implemented' 957 %self._parameters.trace_mode) 958 959 def _is_in_control_flow(self, op): 960 """Returns true if the given op is inside a tf.cond or in tf.while_loop. 961 962 Args: 963 op: A tensorflow op that should be checked whether in control flow or not. 964 Returns: 965 A boolean value whether the op is in control flow or not. 966 """ 967 return control_flow_util.IsInCond(op) 968 969 def _is_in_outmost_while_loop(self, op): 970 """Returns true if the op is at the same level with the training loop. 971 972 Returns false if the op is in an inner while loop or if it is outside of the 973 training loop. 974 Args: 975 op: tf.Operation 976 977 Returns: 978 A boolean. 979 """ 980 ctxt = self._get_op_control_flow_context(op) 981 outer_while_context = control_flow_util.GetContainingWhileContext(ctxt) 982 return outer_while_context == control_flow_util.GetContainingWhileContext( 983 self._outmost_context) 984 985 def _should_trace_in_control_flow(self): 986 """Returns false incase it is not safe to trace ops in tf.cond or tf.while_loop.""" 987 # As different from the other trace modes, TRACE_MODE_OPTIONAL_SUMMARY 988 # forces the execution of the traced tensors. We should not trace the ops 989 # that may not be executed due to control flow. 990 if self._use_temp_cache(): 991 return False 992 elif self._tt_config.device_type == _DEVICE_TYPE_TPU: 993 # On TPUs do not trace in control flow unless we use caches to store 994 # intermediate values as calling outside compilation within an inner loop 995 # causes errors. 996 return self._use_tensor_values_cache() or self._use_tensor_buffer() 997 return True 998 999 def _skip_op(self, op_id, op, ops_in_exec_path, report_handler): 1000 """Returns True if we should not trace Op. 1001 1002 Args: 1003 op_id: Topological index of the op. 1004 op: tf.Operation 1005 ops_in_exec_path: Set of operations that are in the execution path. 1006 report_handler: An instance of tensor_tracer_report.TTReportHandle. 1007 Returns: 1008 True if the op should not be traced, false otherwise. 1009 """ 1010 if TensorTracer.while_loop_op(op): 1011 report_handler.instrument_op( 1012 op, TensorTracer.reason(op_id, _REASON_WHILELOOP_OP)) 1013 return True 1014 if TensorTracer.control_flow_op(op): 1015 report_handler.instrument_op( 1016 op, TensorTracer.reason(op_id, _REASON_CONTROLFLOW_OP)) 1017 return True 1018 if TensorTracer.unsafe_op(op): 1019 report_handler.instrument_op( 1020 op, TensorTracer.reason(op_id, _REASON_UNSAFE_OP)) 1021 return True 1022 if TensorTracer.device_mismatch(self._tt_config.device_type, op): 1023 report_handler.instrument_op( 1024 op, TensorTracer.reason(op_id, _REASON_DEVICE_MISMATCH)) 1025 return True 1026 if op not in ops_in_exec_path: 1027 report_handler.instrument_op( 1028 op, TensorTracer.reason(op_id, _REASON_NOT_EXECUTED)) 1029 return True 1030 # TensorTracer will not trace the operations that are in an inner while loop 1031 # or tf.cond when a temporary cache is used. Temporary cache adds direct 1032 # data dependencies to traced operations, and needs a static number of 1033 # traced operations. For these cases, 1034 # - We do not know the number of slots required when there are inner while 1035 # loops. TensorTracer can only trace the result of a while loop. 1036 # - We do not know ahead of time which branch of the tf.cond 1037 # will be taken, so we avoid introducing data dependencies for the 1038 # operations inside a tf.cond. 1039 # - We also cannot have a data dependency to an operation in a different 1040 # while context. 1041 if self._is_in_control_flow(op) or not self._is_in_outmost_while_loop(op): 1042 if not self._should_trace_in_control_flow(): 1043 report_handler.instrument_op( 1044 op, TensorTracer.reason(op_id, _REASON_IN_CONTROL_FLOW)) 1045 return True 1046 if self._is_user_included_op(op): 1047 report_handler.instrument_op( 1048 op, TensorTracer.reason(op_id, _REASON_USER_INCLUDED)) 1049 return False 1050 1051 if not self._inside_op_range(op_id): 1052 report_handler.instrument_op( 1053 op, TensorTracer.reason(op_id, _REASON_OUTSIDE_OP_RANGE)) 1054 return True 1055 if not self._is_interesting_op(op): 1056 report_handler.instrument_op( 1057 op, TensorTracer.reason(op_id, _REASON_LESS_INTERESTING_OP)) 1058 return True 1059 if self._is_user_excluded_op(op): 1060 report_handler.instrument_op( 1061 op, TensorTracer.reason(op_id, _REASON_USER_EXCLUDED)) 1062 return True 1063 return False 1064 1065 def _skip_tensor(self, op_id, out_tensor, report_handler): 1066 """Returns True if we should not trace out_tensor. 1067 1068 Args: 1069 op_id: Topological index of the op producing tensor. 1070 out_tensor: tf.Tensor 1071 report_handler: An instance of tensor_tracer_report.TTReportHandle. 1072 Returns: 1073 True if the tensor should not be traced, false otherwise. 1074 """ 1075 1076 # Skips a tensor if the tensor has a non-numeric type. 1077 # Note: we cannot use check_ops.is_numeric_tensor(out_tensor) 1078 # because it also excludes tensors with dtypes, bool, and 1079 # float32_ref, which we actually want to trace. 1080 non_numeric_tensor_types = set([dtypes.variant, dtypes.resource, 1081 dtypes.string]) 1082 if out_tensor.dtype in non_numeric_tensor_types: 1083 1084 report_handler.instrument_tensor( 1085 out_tensor, TensorTracer.reason(op_id, _REASON_NON_NUMERIC_TENSOR)) 1086 return True 1087 # Skip a tensor if it feeds a special while loop op. 1088 if [consumer for consumer in out_tensor.consumers() if 1089 TensorTracer.while_loop_op(consumer)]: 1090 report_handler.instrument_tensor( 1091 out_tensor, TensorTracer.reason(op_id, _REASON_FEEDS_WHILELOOP_OP)) 1092 return True 1093 if self._is_user_included_op(out_tensor.op): 1094 report_handler.instrument_tensor( 1095 out_tensor, TensorTracer.reason(op_id, _REASON_USER_INCLUDED)) 1096 return False 1097 if self._is_user_excluded_op(out_tensor.op): 1098 report_handler.instrument_tensor( 1099 out_tensor, TensorTracer.reason(op_id, _REASON_USER_EXCLUDED)) 1100 return True 1101 if not out_tensor.get_shape().is_fully_defined(): 1102 # If trace mode is nan-inf, norm or max, then the tensor will be reduced 1103 # to a scalar before the outside compilation call. 1104 if self._parameters.trace_mode in ( 1105 tensor_tracer_flags.TRACE_MODE_NAN_INF, 1106 tensor_tracer_flags.TRACE_MODE_NORM, 1107 tensor_tracer_flags.TRACE_MODE_MAX_ABS, 1108 tensor_tracer_flags.TRACE_MODE_SUMMARY 1109 ): 1110 report_handler.instrument_tensor( 1111 out_tensor, TensorTracer.reason(op_id, _REASON_TENSOR_GET_TRACED)) 1112 return False 1113 else: 1114 report_handler.instrument_tensor( 1115 out_tensor, TensorTracer.reason(op_id, _REASON_DYNAMIC_SHAPE)) 1116 return True 1117 rank = len(out_tensor.shape) 1118 if rank < 1: 1119 # scalar 1120 if self._parameters.trace_scalar_ops: 1121 if TensorTracer.unsafe_scalar_trace(out_tensor.op): 1122 report_handler.instrument_tensor( 1123 out_tensor, TensorTracer.reason(op_id, _REASON_UNSAFE_SCALAR)) 1124 return True 1125 else: 1126 report_handler.instrument_tensor( 1127 out_tensor, TensorTracer.reason(op_id, _REASON_SCALAR_GET_TRACED)) 1128 return False 1129 else: 1130 report_handler.instrument_tensor( 1131 out_tensor, TensorTracer.reason(op_id, _REASON_SKIP_SCALAR)) 1132 return True 1133 else: 1134 # tensor 1135 report_handler.instrument_tensor( 1136 out_tensor, TensorTracer.reason(op_id, _REASON_TENSOR_GET_TRACED)) 1137 return False 1138 1139 def _filter_execution_path_operations(self, operations, fetches): 1140 """Returns the set of ops in the execution path to compute given fetches.""" 1141 1142 # If no fetch provided, then return all operations. 1143 if fetches is None: 1144 return set(operations) 1145 # Convert to list, if a single element is provided. 1146 if not isinstance(fetches, (list, tuple)): 1147 fetches = [fetches] 1148 # If a tensor is given as fetch, convert it to op. 1149 op_fetches = [] 1150 for fetch in fetches: 1151 if isinstance(fetch, ops.Operation): 1152 op_fetches.append(fetch) 1153 elif isinstance(fetch, ops.Tensor): 1154 op_fetches.append(fetch.op) 1155 else: 1156 raise RuntimeError('Given fetch:%s is neither a tensor nor an op.' 1157 %fetch) 1158 1159 execution_path_operations = set(op_fetches) 1160 traverse_stack = list(op_fetches) 1161 while True: 1162 if not traverse_stack: 1163 break 1164 head_op = traverse_stack.pop() 1165 input_ops = [tensor_input.op for tensor_input in head_op.inputs] 1166 input_ops.extend(head_op.control_inputs) 1167 1168 for input_op in input_ops: 1169 if input_op not in execution_path_operations: 1170 # Filter out loop condition operations, tracing them causes a cycle. 1171 # Trace only the loop-body. 1172 if TensorTracer.loop_cond_op(input_op): 1173 continue 1174 execution_path_operations.add(input_op) 1175 traverse_stack.append(input_op) 1176 return execution_path_operations 1177 1178 def _determine_and_instrument_traced_tensors(self, graph_order, 1179 ops_in_exec_path, 1180 tensor_trace_points, 1181 report_handler): 1182 """Determines the tensors to trace and instruments the trace details. 1183 1184 Args: 1185 graph_order: graph_order tuple containing graph (tf.graph), operations 1186 (list of operations), op_to_idx (op id mapping), (tensors) list of 1187 tensors, tensor_to_idx (tensor id mapping), contains_cycle (whether 1188 there is a cycle in the graph), topological_order_or_cycle (list of ops 1189 in topological order or list of ops creating a cycle). 1190 ops_in_exec_path: Set of ops in the execution path. 1191 tensor_trace_points: Collection of programatic tensor trace points. 1192 report_handler: An instance of tensor_tracer_report.TTReportHandle. 1193 Returns: 1194 List of tensors to be traced. 1195 """ 1196 1197 traced_tensors = [] 1198 checkpoint_operations = set([tensor.op 1199 for (tensor, _) in tensor_trace_points]) 1200 for op_id, op in enumerate(graph_order.operations): 1201 if checkpoint_operations and op not in checkpoint_operations: 1202 continue 1203 if self._skip_op(op_id, op, ops_in_exec_path, report_handler): 1204 continue 1205 for i in range(len(op.outputs)): 1206 out_tensor = op.outputs[i] 1207 if not self._skip_tensor(op_id, out_tensor, report_handler): 1208 traced_tensors.append(out_tensor) 1209 return traced_tensors 1210 1211 def _check_trace_files(self): 1212 """Checks if any requirements for trace files are satisfied.""" 1213 1214 if not self._parameters.trace_dir: 1215 # traces will be written to stderr. No need to check trace files. 1216 return 1217 if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY: 1218 # Output files are handled by tf.summary operations, no need to precreate 1219 # them. 1220 return 1221 if not gfile.Exists(self._parameters.trace_dir): 1222 file_io.recursive_create_dir(self._parameters.trace_dir) 1223 if not gfile.Exists(self._parameters.trace_dir): 1224 raise RuntimeError('Failed to create %s'%self._parameters.trace_dir) 1225 1226 def _create_temp_cache(self, num_traced_tensors, num_signatures): 1227 """Creates a temporary cache with the given dimensions. 1228 1229 Fills the self._temp_cache_var with num_traced_tensors tf.constant() ops 1230 that have shape of [num_signatures]. 1231 Args: 1232 num_traced_tensors: Int, denoting total number of traced tensors. 1233 num_signatures: Int, denoting the number of statistics collected per 1234 tensors. 1235 """ 1236 init_value = constant_op.constant(_COMPACT_TRACE_ENTRY_INIT_VALUE, 1237 dtype=dtypes.float32, 1238 shape=[num_signatures]) 1239 self._temp_cache_var = [init_value for _ in range(num_traced_tensors)] 1240 1241 def _determine_trace_and_create_report(self, graph, ops_in_exec_path): 1242 """Work needs to be done prior to TPU or CPU tracing. 1243 1244 Args: 1245 graph: tf.graph 1246 ops_in_exec_path: Set of operations in the execution path. 1247 Returns: 1248 An instance of tensor_tracer_report.TensorTraceOrder, containing list of 1249 tensors to be traced with their topological order information. 1250 """ 1251 1252 self._check_trace_files() 1253 1254 graph_order = tensor_tracer_report.sort_tensors_and_ops(graph) 1255 tensor_trace_points = graph.get_collection(_TENSOR_TRACER_COLLECTION) 1256 1257 report_handler = tensor_tracer_report.TTReportHandle() 1258 traced_tensors = self._determine_and_instrument_traced_tensors( 1259 graph_order, ops_in_exec_path, tensor_trace_points, report_handler) 1260 logging.info('TensorTracer is tracing %d tensors.', len(traced_tensors)) 1261 1262 tensor_trace_order = tensor_tracer_report.TensorTraceOrder(graph_order, 1263 traced_tensors) 1264 num_signatures = self._num_signature_dimensions() 1265 # Create a cache variable if compact_tracing is used. 1266 if num_signatures and self._use_tensor_values_cache(): 1267 if self._use_temp_cache(): 1268 self._create_temp_cache(len(traced_tensors), num_signatures) 1269 else: 1270 self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG, 1271 graph, 1272 [len(traced_tensors), 1273 num_signatures]) 1274 if self._parameters.trace_mode in ( 1275 tensor_tracer_flags.TRACE_MODE_SUMMARY, 1276 tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY): 1277 self._report_proto = report_handler.create_report_proto( 1278 self._tt_config, self._parameters, tensor_trace_order, 1279 tensor_trace_points, self._signature_types()) 1280 if self._parameters.use_fingerprint_subdir: 1281 self._parameters.trace_dir = os.path.join( 1282 self._parameters.trace_dir, self._report_proto.fingerprint) 1283 logging.info('TensorTracer updating trace_dir to %s', 1284 self._parameters.trace_dir) 1285 self._report_proto_path = tensor_tracer_report.report_proto_path( 1286 self._parameters.trace_dir) 1287 if self._parameters.report_file_path != _SKIP_REPORT_FILE: 1288 report_handler.write_report_proto(self._report_proto, self._parameters) 1289 else: 1290 report_handler.create_report(self._tt_config, self._parameters, 1291 tensor_trace_order, tensor_trace_points) 1292 return tensor_trace_order 1293 1294 def _create_host_call(self): 1295 return self._parameters.trace_mode in ( 1296 tensor_tracer_flags.TRACE_MODE_SUMMARY, 1297 tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY) 1298 1299 def _inspect_summary_cache(self, cache, replica_id, step_num, output_stream, 1300 tensor_trace_order): 1301 """Generates a print operation to print trace inspection. 1302 1303 Args: 1304 cache: Tensor storing the trace results for the step. 1305 replica_id: Tensor storing the replica id of the running core. 1306 step_num: Step number. 1307 output_stream: Where to print the outputs, e.g., file path, or sys.stderr. 1308 tensor_trace_order: TensorTraceOrder object holding tensorname to id map. 1309 1310 Returns: 1311 The Op to flush the cache to file. 1312 """ 1313 def _inspect_tensor(tensor): 1314 """Returns the text to be printed for inspection output.""" 1315 if (self._parameters.trace_mode == 1316 tensor_tracer_flags.TRACE_MODE_NAN_INF): 1317 return control_flow_ops.cond( 1318 math_ops.greater(tensor, 0.0), 1319 lambda: 'has NaNs/Infs!', 1320 lambda: 'has no NaNs or Infs.') 1321 else: 1322 return tensor 1323 1324 # Check if the cache includes any nan or inf 1325 if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NAN_INF: 1326 # Cache has 1s or 0s if the mode is NaN_INF 1327 step_has_nan_or_inf = math_ops.greater(math_ops.reduce_sum(cache), 0.0) 1328 else: 1329 # Cache has the actual numerics for other modes. 1330 step_has_nan_or_inf = math_ops.reduce_any( 1331 gen_math_ops.logical_or( 1332 gen_math_ops.is_nan(cache), gen_math_ops.is_inf(cache))) 1333 1334 # Summarizing message for each step. 1335 step_error_message = control_flow_ops.cond( 1336 step_has_nan_or_inf, 1337 lambda: 'NaNs or Infs in the step!', 1338 lambda: 'No numerical issues have been found for the step.') 1339 1340 # No need to print core numbers if the cache is merged already. 1341 if self._parameters.collect_summary_per_core: 1342 stats = ['\n\n', 'core:', replica_id, ',', 'step:', step_num, '-->', 1343 step_error_message, 1344 'Printing tensors for mode:%s...' % self._parameters.trace_mode] 1345 else: 1346 stats = ['\n\n', 'step:', step_num, '-->', step_error_message, 1347 'Printing tensors for mode:%s...' % self._parameters.trace_mode] 1348 1349 for tensor_name, cache_idx in sorted( 1350 tensor_trace_order.tensorname_to_cache_idx.items(), 1351 key=lambda item: item[1]): 1352 if self._parameters.collect_summary_per_core: 1353 stats.extend([ 1354 '\n', 'core:', replica_id, ',', 'step:', step_num, ',', 1355 tensor_name, '-->', _inspect_tensor(cache[cache_idx, 0])]) 1356 else: 1357 stats.extend([ 1358 '\n', 'step:', step_num, ',', 1359 tensor_name, '-->', _inspect_tensor(cache[cache_idx, 0])]) 1360 return logging_ops.print_v2(*stats, summarize=-1, 1361 output_stream=output_stream) 1362 1363 def _get_outfile_suffix(self): 1364 if remote_utils.is_remote_path(self._parameters.trace_dir): 1365 return remote_utils.get_appendable_file_encoding() 1366 else: 1367 return '' 1368 1369 def _generate_flush_cache_op(self, num_replicas, on_tpu, tensor_trace_order): 1370 """Generates an Op that will flush the cache to file. 1371 1372 Args: 1373 num_replicas: total number of replicas. 1374 on_tpu: if the graph is executed on TPU. 1375 tensor_trace_order: TensorTraceOrder object holding tensorname to id map. 1376 1377 Returns: 1378 The Op to flush the cache to file. 1379 """ 1380 1381 def _flush_fun(cache, replica_id, step_num): 1382 """Flushes the cache to a file corresponding to replica_id.""" 1383 1384 def _f(file_index): 1385 """Generates a func that flushes the cache to a file.""" 1386 def _print_cache(): 1387 """Flushes the cache to a file.""" 1388 replica_str = ('%d' % file_index) 1389 if self._parameters.trace_dir: 1390 output_path = (os.path.join(self._parameters.trace_dir, 1391 _COMPACT_TRACE_FILE_PREFIX) 1392 + replica_str + self._get_outfile_suffix()) 1393 output_stream = _OUTPUT_STREAM_ESCAPE + output_path 1394 else: 1395 output_stream = sys.stderr 1396 1397 new_step_line = _REPLICA_ID_TAG + replica_str 1398 print_ops = [] 1399 if self._parameters.inspect_trace: 1400 if self._num_signature_dimensions() > 1: 1401 raise ValueError('Inspecting multi signatures are not supported.') 1402 print_ops.append(self._inspect_summary_cache( 1403 cache=cache, replica_id=replica_id, step_num=step_num, 1404 output_stream=output_stream, 1405 tensor_trace_order=tensor_trace_order)) 1406 else: 1407 for i in range(self._num_signature_dimensions()): 1408 print_ops.append(logging_ops.print_v2( 1409 new_step_line, '\n', 1410 cache[:, i], '\n', 1411 summarize=-1, 1412 output_stream=output_stream)) 1413 with ops.control_dependencies(print_ops): 1414 return constant_op.constant(0).op 1415 return _print_cache 1416 1417 def _eq(file_index): 1418 return math_ops.equal(replica_id, file_index) 1419 1420 flush_op_cases = {} 1421 flush_op_cases[_eq(0)] = _f(0) 1422 for i in range(1, num_replicas): 1423 if on_tpu and not self._parameters.collect_summary_per_core: 1424 # If this is the case, the cache is already merged for all cores. 1425 # Only first core flushes the cache. 1426 flush_op_cases[_eq(i)] = control_flow_ops.no_op 1427 else: 1428 flush_op_cases[_eq(i)] = _f(i) 1429 # Each replica needs to determine where to write their output. 1430 # To do this, we check if replica_id is 0, then 1, ..., and then 1431 # num_replicas - 1 statically; and return the corresponding static file 1432 # name. We cannot simply set the file name in python, as replica_id is 1433 # only known during tf runtime, and we cannot create dynamic filenames. 1434 return control_flow_ops.case(flush_op_cases, exclusive=True) 1435 1436 cache = self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG) 1437 if self._use_temp_cache(): 1438 cache_val = cache 1439 else: 1440 cache_val = cache.value() 1441 1442 if on_tpu: 1443 # If we do not need to collect traces for all cores, merge and aggregate 1444 # per core trace. 1445 if not self._parameters.collect_summary_per_core: 1446 cache_val = self.merge_caches_on_tpu(cache_val) 1447 cache_val = self.aggregate_global_cache(cache_val)[0] 1448 1449 flush_op = tpu.outside_compilation( 1450 _flush_fun, cache_val, self._replica_id, 1451 array_ops.identity(training_util.get_or_create_global_step())) 1452 else: 1453 flush_op = _flush_fun(cache_val, self._replica_id, 1454 training_util.get_or_create_global_step()) 1455 if self._use_temp_cache(): 1456 with ops.control_dependencies([flush_op]): 1457 return constant_op.constant(0).op 1458 else: 1459 # Re-initialize the local cache variable. 1460 with ops.control_dependencies([flush_op]): 1461 reset_value = constant_op.constant(_COMPACT_TRACE_ENTRY_INIT_VALUE, 1462 dtype=cache.dtype, 1463 shape=cache.shape) 1464 assign_op = state_ops.assign(cache, reset_value).op 1465 with ops.control_dependencies([assign_op]): 1466 return constant_op.constant(0).op 1467 1468 def _flush_tensor_values_cache(self, tensor_fetches, op_fetches, on_tpu, 1469 tensor_trace_order): 1470 """Flushes the intermediate tensor values in the graph to the cache. 1471 1472 Args: 1473 tensor_fetches: list of tensor results returned by the model_fn. 1474 op_fetches: list of ops that are returned by the model_fn, e.g., train_op. 1475 on_tpu: if the graph is executed on TPU. 1476 tensor_trace_order: TensorTraceOrder object holding tensorname to id map. 1477 1478 Returns: 1479 An identical copy of tensor_fetches. 1480 """ 1481 # Add a dependency to op and tensor fetches to make sure that all tracing 1482 # ops are executed before flushing trace results. 1483 with ops.control_dependencies(op_fetches + 1484 [tensor.op for tensor in tensor_fetches]): 1485 flush_cache_op = self._generate_flush_cache_op( 1486 self._tt_config.num_replicas, on_tpu, tensor_trace_order) 1487 return control_flow_ops.tuple(tensor_fetches, 1488 control_inputs=[flush_cache_op]) 1489 1490 def _process_tensor_fetches(self, tensor_fetches): 1491 """Check that tensor_fetches is not empty and have valid tensors.""" 1492 # If none or empty list. 1493 if tensor_fetches is None: 1494 raise RuntimeError('tensor_fetches provided to tensor_tracer cannot be ' 1495 'None.') 1496 if not isinstance(tensor_fetches, (list, tuple)): 1497 tensor_fetches = [tensor_fetches] 1498 elif not tensor_fetches: 1499 raise RuntimeError('tensor_fetches provided to tensor_tracer cannot be ' 1500 'empty list.') 1501 fetches = [] 1502 for fetch in tensor_fetches: 1503 if isinstance(fetch, ops.Tensor): 1504 fetches.append(fetch) 1505 else: 1506 raise RuntimeError('Given tensor_fetch:%s is not a tensor.' % fetch) 1507 return fetches 1508 1509 def _process_op_fetches(self, op_fetches): 1510 """Check that op_fetches have valid ops.""" 1511 if op_fetches is None: 1512 return [] 1513 1514 if not isinstance(op_fetches, (list, tuple)): 1515 op_fetches = [op_fetches] 1516 1517 fetches = [] 1518 for fetch in op_fetches: 1519 if isinstance(fetch, ops.Operation): 1520 fetches.append(fetch) 1521 elif isinstance(fetch, ops.Tensor): 1522 fetches.append(fetch.op) 1523 else: 1524 logging.warning('Ignoring the given op_fetch:%s, which is not an op.' % 1525 fetch) 1526 return fetches 1527 1528 def _convert_fetches_to_input_format(self, input_fetches, current_fetches): 1529 """Changes current_fetches' format, so that it matches input_fetches.""" 1530 if isinstance(input_fetches, ops.Tensor): 1531 if len(current_fetches) != 1: 1532 raise RuntimeError('Tensor tracer input/output fetches do not match.') 1533 return current_fetches[0] 1534 else: 1535 if len(current_fetches) != len(current_fetches): 1536 raise RuntimeError('Tensor tracer input/output fetches do not match.') 1537 elif isinstance(input_fetches, tuple): 1538 return tuple(current_fetches) 1539 else: 1540 return current_fetches 1541 1542 def _get_op_control_flow_context(self, op): 1543 """Returns the control flow of the given op. 1544 1545 Args: 1546 op: tf.Operation for which the control flow context is requested. 1547 Returns: 1548 op_control_flow_context: which the is control flow context of the given 1549 op. If the operation type is LoopExit, returns the outer control flow 1550 context. 1551 """ 1552 # pylint: disable=protected-access 1553 op_control_flow_context = op._control_flow_context 1554 # pylint: enable=protected-access 1555 if control_flow_util.IsLoopExit(op): 1556 op_control_flow_context = op_control_flow_context.outer_context 1557 return op_control_flow_context 1558 1559 def merge_caches_on_tpu(self, local_tpu_cache_tensor): 1560 """Merges the given caches on tpu. 1561 1562 Args: 1563 local_tpu_cache_tensor: A local tensor that needs to be merged 1564 by concanting data from other tpu cores. 1565 Returns: 1566 A merged tf.Tensor. 1567 Raises: 1568 RuntimeError: if there is no aggregate function defined for a signature. 1569 """ 1570 x = array_ops.broadcast_to( 1571 local_tpu_cache_tensor, 1572 shape=[self._tt_config.num_replicas] + 1573 local_tpu_cache_tensor.shape.as_list()) 1574 return tpu_ops.all_to_all( 1575 x, concat_dimension=0, split_dimension=0, 1576 split_count=self._tt_config.num_replicas) 1577 1578 def aggregate_global_cache(self, global_tt_summary_cache): 1579 """Merges the given caches on tpu. 1580 1581 Args: 1582 global_tt_summary_cache: The global tensor tracer summary cache tensor 1583 with shape (num_cores, num_traced_tensors, num_traced_signatures). First 1584 dimension corresponds to core_id, where global_tpu_cache_tensor[i] 1585 correspond to the local cache from core-i. 1586 Returns: 1587 An aggregated tf.Tensor. 1588 Raises: 1589 RuntimeError: if there is no aggregate function defined for a signature. 1590 """ 1591 1592 # Merge only statistics tensor, if it is any other tensor we simply, 1593 # concatenate them. 1594 agg_fn_map = self._parameters.get_signature_to_agg_fn_map() 1595 signature_idx_map = self._signature_types() 1596 aggregation_result = [] 1597 for signature, idx in sorted(signature_idx_map.items(), 1598 key=operator.itemgetter(1)): 1599 if signature not in agg_fn_map: 1600 raise RuntimeError('No aggregation function is defined for ' 1601 'signature %s.' % signature) 1602 # The dimensions of the statistics tensor is 1603 # num_cores x num_traced_tensors x num_signatures 1604 # value[:,:,idx] will return the portion of the tensor related 1605 # to signature. 1606 signature_tensor = global_tt_summary_cache[:, :, idx] 1607 # Merge it along the first (core) axis. 1608 agg_fn = agg_fn_map[signature] 1609 agg_tensor = agg_fn(signature_tensor, axis=0) 1610 aggregation_result.append(agg_tensor) 1611 # Merge results corresponding to different signatures 1612 1613 merged_signatures = array_ops.stack(aggregation_result) 1614 # merged_signatures has dimensions 1615 # num_signatures x num_traced_tensors, transpose it so that it 1616 # will match with the original structure 1617 # num_traced_tensors x num_signatures. 1618 transposed_signatures = array_ops.transpose(merged_signatures) 1619 # Expand 1 more dimension so that it will match with the expected 1620 # structure num_cores x num_traced_tensors x num_signatures. 1621 return array_ops.expand_dims(transposed_signatures, axis=0) 1622 1623 def _prepare_host_call_fn(self, processed_t_fetches, op_fetches): 1624 """Creates a host call function that will write the cache as tb summary. 1625 1626 Args: 1627 processed_t_fetches: List of tensor provided to session.run. 1628 op_fetches: List of operations provided to session.run. 1629 Raises: 1630 ValueError if trace_dir is not set. 1631 """ 1632 if self._parameters.trace_dir is None: 1633 raise ValueError('Provide a trace_dir for tensor tracer in summary mode. ' 1634 '--trace_dir=/model/dir') 1635 1636 def _write_cache(step, event_file_suffix=None, **kwargs): 1637 """Writes the given caches as tensor summary. 1638 1639 Args: 1640 step: Step tensor with dimension [num_cores]. 1641 event_file_suffix: Event filename suffix tensor. 1642 **kwargs: The dictionary of tensors that needs to be written as 1643 summaries. Key and value pairs within kwargs correspond to the tag 1644 name, and tensor content that will be written using summary.write. 1645 The trace_modes that use this function are: 1646 - summary: In summary mode, kwargs includes a single (tag, content) 1647 pair which are, _TT_SUMMARY_TAG and a tf.float32 signature_cache 1648 variable. The dimension of the signature_cache is: 1649 num_cores x num_traced_tensors x num_signatures. 1650 - full_tensor_summary: kwargs will include all traced tensors. Tag 1651 and content correspond to the name of the tensor, and its actual 1652 content. 1653 Returns: 1654 A tf.Operation that needs to be executed for the host call dependencies. 1655 Raises: 1656 RuntimeError: if there is no aggregate function defined for a signature. 1657 """ 1658 file_suffix = _TT_EVENT_FILE_SUFFIX 1659 if event_file_suffix is not None: 1660 file_suffix = string_ops.string_join([file_suffix, event_file_suffix], 1661 separator='.') 1662 # TODO(deveci): Parametrize max_queue, so that flushing op can be called 1663 # less frequently. 1664 # Setting max_queue to 100 appears to be safe even when the number of 1665 # iterations are much lower, as the destructor of the writer flushes it. 1666 summary_write_ops = [] 1667 summary_writer = summary.create_file_writer_v2( 1668 self._parameters.trace_dir, 1669 filename_suffix=file_suffix, 1670 max_queue=_TT_SUMMARY_MAX_QUEUE) 1671 ops.get_default_graph().add_to_collection( 1672 TENSOR_TRACER_SUMMARY_COLLECTION, summary_writer) 1673 with summary_writer.as_default(): 1674 summary_metadata = summary_pb2.SummaryMetadata( 1675 plugin_data=summary_pb2.SummaryMetadata.PluginData( 1676 plugin_name=_TT_TENSORBOARD_PLUGIN_NAME)) 1677 for key, value in kwargs.items(): 1678 # Check whether we need to compute aggregated statistics that merge 1679 # all cores statistics. 1680 if not self._parameters.collect_summary_per_core: 1681 # Merge only statistics tensor, if it is any other tensor we simply, 1682 # concatenate them. 1683 # Also, if there is only a single core (first dim. is 0), then skip 1684 # aggregation. 1685 if key == _TT_SUMMARY_TAG and value.shape.as_list()[0] != 1: 1686 value = self.aggregate_global_cache(value) 1687 1688 with ops.control_dependencies([summary_writer.init()]): 1689 summary_write_ops.append(summary.write( 1690 _TT_SUMMARY_TAG + '/' + key, value, metadata=summary_metadata, 1691 step=step[0])) 1692 return control_flow_ops.group(summary_write_ops) 1693 1694 step = array_ops.reshape(training_util.get_or_create_global_step(), [1]) 1695 self._host_call_fn = {} 1696 1697 host_call_deps = op_fetches + [tensor.op for tensor in processed_t_fetches] 1698 1699 caches_to_write = {} 1700 with ops.control_dependencies(host_call_deps): 1701 all_caches = self._get_all_cache_variables() 1702 for cache_name, cache_variable in all_caches.items(): 1703 # Increase the cache rank by 1, so that when host call concatenates 1704 # tensors from different replicas, we can identify them with [core_id]. 1705 new_cache_shape = [1] 1706 new_cache_shape.extend(cache_variable.shape.as_list()) 1707 cache = array_ops.reshape(cache_variable, new_cache_shape) 1708 caches_to_write[cache_name] = cache 1709 # Add step to parameter dictionary. 1710 caches_to_write['step'] = step 1711 # Other options without adding step to parameter dictionary are 1712 # * host_call_fn = (_write_cache(step, caches_to_write)) : fails as it 1713 # considers caches_to_write as a single parameter, rather than a keyword 1714 # parameters. 1715 # * host_call_fn = (_write_cache(step, **caches_to_write)) : fails with 1716 # a syntax error. 1717 self._host_call_fn[_TT_HOSTCALL_KEY] = (_write_cache, caches_to_write) 1718 1719 def host_call_deps_and_fn(self): 1720 return self._host_call_fn 1721 1722 def get_traced_op_names(self): 1723 """Returns the set of traced op names.""" 1724 return self._traced_op_names 1725 1726 def _trace_execution(self, graph, 1727 tensor_fetches, 1728 op_fetches=None, 1729 on_tpu=True): 1730 """Commong tracing function for both CPU and TPUs. 1731 1732 The caller function should set device_type, num_replicas, 1733 num_replicas_per_host, num_hosts and replica_id before calling 1734 _trace_execution. 1735 1736 1737 Args: 1738 graph: the graph of Ops executed on the TPU. 1739 tensor_fetches: a (list,tuple,or a single object) of tensor fetches 1740 returned by model_fn given to session.run. Function must be provided 1741 with as least one tensor to fetch. 1742 op_fetches: A list of op fetches returned by model_fn given to 1743 session.run. op_fetches and tensor_fetches are used to determine the 1744 nodes that will be executed. Can be None. 1745 on_tpu: True if executing on TPU. 1746 1747 Returns: 1748 tensor_fetches: an exact copy of tensor_fetches that has additional 1749 dependencies. 1750 Raises: 1751 RuntimeError: If tensor_fetches is None or empty. 1752 """ 1753 def _cast_unsupported_dtypes(tensor): 1754 """Casts tensor to a supported type.""" 1755 1756 if tensor.dtype.__eq__(dtypes.int64): 1757 # outside-compilation doesn't support int64 input yet. 1758 return math_ops.cast(tensor, dtypes.int32) 1759 if tensor.dtype.__eq__(dtypes.bfloat16) or tensor.dtype.__eq__( 1760 dtypes.float16): 1761 # Since host can't handle bf16, convert tensor to f32. 1762 return math_ops.cast(tensor, dtypes.float32) 1763 return tensor 1764 1765 trace_mode = self._parameters.trace_mode 1766 device_type = self._tt_config.device_type 1767 # pylint: disable=protected-access 1768 self._outmost_context = graph._get_control_flow_context() 1769 # pylint: enable=protected-access 1770 1771 analytics.track_usage('tensor_tracer', [trace_mode, device_type]) 1772 TensorTracer.check_device_type(device_type) 1773 TensorTracer.check_trace_mode(device_type, trace_mode) 1774 # Check in_tensor_fetches, and op_fetches and convert them to lists. 1775 processed_t_fetches = self._process_tensor_fetches(tensor_fetches) 1776 op_fetches = self._process_op_fetches(op_fetches) 1777 all_fetches = op_fetches + [tensor.op for tensor in processed_t_fetches] 1778 1779 # Filter out the operations that won't be executed. 1780 # if fetches=None, then ops_in_exec_path = set(operations) 1781 exec_op_set = self._filter_execution_path_operations(graph.get_operations(), 1782 all_fetches) 1783 # Write report file, and determine the traced tensors. 1784 tensor_trace_order = self._determine_trace_and_create_report( 1785 graph, exec_op_set) 1786 1787 tensor_fetch_set = set(processed_t_fetches) 1788 tracing_ops = [] 1789 1790 sorted_exec_op_list = list(exec_op_set) 1791 sorted_exec_op_list.sort(key=lambda op: op.name) 1792 # Trace ops only if they are in the execution path. 1793 for op in sorted_exec_op_list: 1794 for i in range(len(op.outputs)): 1795 out_tensor = op.outputs[i] 1796 tensor_name = out_tensor.name 1797 if tensor_name not in tensor_trace_order.tensorname_to_cache_idx: 1798 continue 1799 self._traced_op_names.add(op.name) 1800 # Create the list of consumers before calling _preprocess_traced_tensor. 1801 # Otherwise, adding control input below, will introduce a cycle in the 1802 # graph. 1803 consumers = out_tensor.consumers() 1804 # Not all consumers may be in the exec path. Filter out the consumers 1805 # to keep the graph simpler. 1806 consumers = [cop for cop in consumers if cop in exec_op_set] 1807 1808 # If there is no consumer of the tensor, there is no need to trace it; 1809 # unless the tensor itself is one of the fetches. 1810 is_a_fetched_tensor = out_tensor in tensor_fetch_set 1811 if (not consumers) and (not is_a_fetched_tensor): 1812 continue 1813 1814 op_control_flow_context = self._get_op_control_flow_context(op) 1815 if op_control_flow_context: 1816 # pylint: disable=protected-access 1817 graph._set_control_flow_context(op_control_flow_context) 1818 # pylint: enable=protected-access 1819 1820 processed_tensors = self._preprocess_traced_tensor(out_tensor) 1821 1822 if on_tpu: 1823 for signature in processed_tensors.keys(): 1824 processed_tensors[signature] = _cast_unsupported_dtypes( 1825 processed_tensors[signature]) 1826 1827 if self._use_tensor_values_cache(): 1828 # Use a small cache (either temp cache or tf local variable) to store 1829 # the characteristics of the tensor. 1830 if self._use_temp_cache(): 1831 cache_idx = tensor_trace_order.tensorname_to_cache_idx[tensor_name] 1832 self._save_tensor_value_to_tmp_cache(cache_idx, processed_tensors) 1833 trace_op = None 1834 else: 1835 cache_idx = tensor_trace_order.tensorname_to_cache_idx[tensor_name] 1836 trace_op = self._save_tensor_value_to_cache_op(cache_idx, 1837 processed_tensors) 1838 elif self._use_tensor_buffer(): 1839 if len(processed_tensors) != 1: 1840 raise RuntimeError('Multiple stats are only allowed in compact ' 1841 'mode.') 1842 processed_out_tensor = list(processed_tensors.values())[0] 1843 # Store the whole tensor in a buffer. 1844 trace_op = self._snapshot_tensor(processed_out_tensor) 1845 else: 1846 1847 def tpu_wrap_trace_fn(tensor, out_tensor_name): 1848 """Wraps the trace_fn with outside compilation if on TPUs.""" 1849 tensor_trace_fn = self._make_tensor_trace_fun(out_tensor_name, 1850 tensor_trace_order) 1851 if on_tpu: 1852 return tpu.outside_compilation(tensor_trace_fn, tensor) 1853 else: 1854 return tensor_trace_fn(tensor) 1855 1856 if len(processed_tensors) != 1: 1857 raise RuntimeError('Multiple stats are only allowed in compact ' 1858 'mode.') 1859 # Collecting multiple statistics are only supported in the summary 1860 # mode that uses compact format(self._use_tensor_values_cache = true). 1861 # Non-compact mode currently allows single stat per tensor. 1862 processed_out_tensor = six.next(six.itervalues(processed_tensors)) 1863 trace_op = tpu_wrap_trace_fn(processed_out_tensor, tensor_name) 1864 1865 if op_control_flow_context: 1866 # pylint: disable=protected-access 1867 graph._set_control_flow_context(self._outmost_context) 1868 # pylint: enable=protected-access 1869 if trace_op: 1870 if is_a_fetched_tensor: 1871 tracing_ops.append(trace_op) 1872 continue 1873 # Add it to all consumers, as some consumers may not be executed if 1874 # they are in a control flow. 1875 for consumer_op in consumers: 1876 # pylint: disable=protected-access 1877 consumer_op._add_control_input(trace_op) 1878 # pylint: enable=protected-access 1879 1880 # pylint: disable=protected-access 1881 graph._set_control_flow_context(self._outmost_context) 1882 # pylint: enable=protected-access 1883 if tracing_ops: 1884 # If we are tracing a fetched tensor, their dependency is stored in 1885 # tracing_ops. 1886 processed_t_fetches = control_flow_ops.tuple(processed_t_fetches, 1887 control_inputs=tracing_ops) 1888 if self._use_tensor_values_cache() or self._use_tensor_buffer(): 1889 if self._use_temp_cache(): 1890 # Create the temporary tf cache variable by concantanating all 1891 # statistics. 1892 self._cache_variables[_TT_SUMMARY_TAG] = array_ops.stack( 1893 self._temp_cache_var, axis=0, name='stack_all_op_signatures') 1894 if self._create_host_call(): 1895 self._prepare_host_call_fn(processed_t_fetches, op_fetches) 1896 if not on_tpu: 1897 write_cache, caches_to_write = self._host_call_fn[_TT_HOSTCALL_KEY] 1898 cache_write_op = write_cache(**caches_to_write) 1899 processed_t_fetches = control_flow_ops.tuple( 1900 processed_t_fetches, control_inputs=[cache_write_op]) 1901 del self._host_call_fn[_TT_HOSTCALL_KEY] 1902 elif self._parameters.flush_summaries_with_outside_compile: 1903 write_cache, caches_to_write = self._host_call_fn[_TT_HOSTCALL_KEY] 1904 if (_TT_SUMMARY_TAG in caches_to_write and 'step' in caches_to_write): 1905 step = caches_to_write['step'] 1906 tensor_tracer_summary = caches_to_write[_TT_SUMMARY_TAG] 1907 tt_core_summary = self.merge_caches_on_tpu(tensor_tracer_summary[0]) 1908 if not self._parameters.collect_summary_per_core: 1909 tt_core_summary = self.aggregate_global_cache(tt_core_summary) 1910 1911 def write_if_core_0(step, replica_id, tt_summary): 1912 1913 return control_flow_ops.cond( 1914 math_ops.equal(replica_id, 0), 1915 lambda: write_cache(step=step, event_file_suffix=None, # pylint: disable=g-long-lambda 1916 tensor_tracer_summary=tt_summary), 1917 control_flow_ops.no_op) 1918 1919 write_op = tpu.outside_compilation(write_if_core_0, step=step, 1920 replica_id=self._replica_id, 1921 tt_summary=tt_core_summary) 1922 processed_t_fetches = control_flow_ops.tuple( 1923 processed_t_fetches, control_inputs=[write_op]) 1924 del self._host_call_fn[_TT_HOSTCALL_KEY] 1925 else: 1926 raise ValueError('Outside compiled flush in only supported for ' 1927 'summary mode') 1928 else: 1929 processed_t_fetches = self._flush_tensor_values_cache( 1930 processed_t_fetches, op_fetches, on_tpu=on_tpu, 1931 tensor_trace_order=tensor_trace_order) 1932 1933 # processed_t_fetches is a list at this point. Convert it to the same 1934 # format as given in tensor_fetches. 1935 return self._convert_fetches_to_input_format(tensor_fetches, 1936 processed_t_fetches) 1937 1938 def trace_tpu(self, graph, 1939 tensor_fetches, 1940 op_fetches=None, 1941 num_replicas=None, 1942 num_replicas_per_host=None, 1943 num_hosts=None): 1944 """Traces the tensors generated by TPU Ops in a TF graph. 1945 1946 Args: 1947 graph: the graph of Ops executed on the TPU. 1948 tensor_fetches: a (list,tuple,or a single object) of tensor fetches 1949 returned by model_fn given to session.run. Function must be provided 1950 with as least one tensor to fetch. 1951 op_fetches: A list of op fetches returned by model_fn given to 1952 session.run. op_fetches and tensor_fetches are used to determine the 1953 nodes that will be executed. Can be None. 1954 num_replicas: number of replicas used on the TPU. 1955 num_replicas_per_host: number of replicas per TPU host. 1956 num_hosts: total number of TPU hosts. 1957 1958 Returns: 1959 tensor_fetches: an exact copy of tensor_fetches that has additional 1960 dependencies. 1961 Raises: 1962 RuntimeError: If num_replicas_per_host > 8. 1963 RuntimeError: If tensor_fetches is None or empty. 1964 """ 1965 if isinstance(graph, func_graph.FuncGraph) or isinstance( 1966 graph, function._FuncGraph): # pylint: disable=protected-access 1967 logging.warning('Tensor Tracer is not supported for tracing FuncGraphs. ' 1968 'Ignoring tracing.') 1969 return tensor_fetches 1970 1971 if graph in TensorTracer._traced_graphs: 1972 logging.warning('Graph is already rewritten with tensor tracer, ignoring ' 1973 'multiple calls.') 1974 return tensor_fetches 1975 else: 1976 TensorTracer._traced_graphs.add(graph) 1977 # Reset the parameters in case parameters are changed. 1978 self._parameters = tensor_tracer_flags.TTParameters() 1979 self._tt_config.device_type = _DEVICE_TYPE_TPU 1980 self._tt_config.num_replicas = num_replicas 1981 self._tt_config.num_replicas_per_host = num_replicas_per_host 1982 self._tt_config.num_hosts = num_hosts 1983 if self._tt_config.num_replicas is not None: 1984 if self._tt_config.num_replicas_per_host is None: 1985 self._tt_config.num_replicas_per_host = 8 1986 if self._tt_config.num_hosts is None: 1987 self._tt_config.num_hosts = ( 1988 num_replicas // self._tt_config.num_replicas_per_host + 1989 (num_replicas % self._tt_config.num_replicas_per_host > 0)) 1990 1991 if self._parameters.graph_dump_path: 1992 graph_io.write_graph(graph, self._parameters.graph_dump_path, 1993 'graph_before_tt.pbtxt') 1994 with graph.as_default(): 1995 self._add_replica_id_to_graph() 1996 tensor_fetches = self._trace_execution(graph, tensor_fetches, op_fetches, 1997 on_tpu=True) 1998 if self._parameters.graph_dump_path: 1999 graph_io.write_graph(graph, self._parameters.graph_dump_path, 2000 'graph_after_tt.pbtxt') 2001 return tensor_fetches 2002 2003 def trace_cpu(self, graph, tensor_fetches, op_fetches=None): 2004 """Traces the tensors generated by CPU Ops in a TF graph. 2005 2006 Args: 2007 graph: the graph of Ops executed on the CPU. 2008 tensor_fetches: a (list,tuple,or a single object) of tensor fetches 2009 returned by model_fn given to session.run. Function must be provided 2010 with as least one tensor to fetch. 2011 op_fetches: A list of op fetches returned by model_fn given to 2012 session.run. op_fetches and tensor_fetches are used to determine the 2013 nodes that will be executed. Can be None. 2014 2015 Returns: 2016 tensor_fetches: an exact copy of tensor_fetches that has additional 2017 dependencies. 2018 Raises: 2019 RuntimeError: If tensor_fetches is None or empty. 2020 """ 2021 if isinstance(graph, func_graph.FuncGraph) or isinstance( 2022 graph, function._FuncGraph): # pylint: disable=protected-access 2023 logging.warning('Tensor Tracer is not supported for tracing FuncGraphs. ' 2024 'Ignoring tracing.') 2025 return tensor_fetches 2026 2027 if graph in TensorTracer._traced_graphs: 2028 logging.warning('Graph is already rewritten with tensor tracer, ignoring ' 2029 'multiple calls.') 2030 return tensor_fetches 2031 else: 2032 TensorTracer._traced_graphs.add(graph) 2033 # Reset the parameters in case parameters are changed. 2034 self._parameters = tensor_tracer_flags.TTParameters() 2035 2036 self._tt_config.device_type = _DEVICE_TYPE_CPU 2037 self._tt_config.num_replicas = 1 2038 self._tt_config.num_replicas_per_host = 1 2039 self._tt_config.num_hosts = 1 2040 self._replica_id = 0 2041 if self._parameters.graph_dump_path: 2042 graph_io.write_graph(graph, self._parameters.graph_dump_path, 2043 'graph_before_tt.pbtxt') 2044 with graph.as_default(): 2045 tensor_fetches = self._trace_execution(graph, tensor_fetches, op_fetches, 2046 on_tpu=False) 2047 if self._parameters.graph_dump_path: 2048 graph_io.write_graph(graph, self._parameters.graph_dump_path, 2049 'graph_after_tt.pbtxt') 2050 return tensor_fetches 2051