# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========================================================================
"""A utility to trace tensor values on TPU."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import operator

import os
import os.path
import sys

import numpy as np
import six

from tensorflow.core.framework import summary_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import summary_ops_v2 as summary
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import analytics
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary_iterator
from tensorflow.python.tpu import tensor_tracer_flags
from tensorflow.python.tpu import tensor_tracer_report
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.training import training_util

# The only device types accepted by TensorTracer.check_device_type.
_DEVICE_TYPE_TPU = 'tpu'
_DEVICE_TYPE_CPU = 'cpu'
# Number of elements printed per tensor in the part-tensor trace mode (passed
# as `summarize` to logging_ops.print_v2).
_TRACE_MODE_PART_TENSOR_SIZE = 3

# Reason strings recorded in the tensor tracer report for each op/tensor,
# explaining why it was or was not traced. TensorTracer.reason() prefixes them
# with the op's topological index.
_REASON_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range'
_REASON_UNSAFE_OP = 'not-traced-unsafe-op'
_REASON_WHILELOOP_OP = 'not-traced-special-whileloop-op'
_REASON_CONTROLFLOW_OP = 'not-traced-control-flow-op'
_REASON_UNSAFE_SCALAR = 'not-traced-unsafe-scalar'
_REASON_SKIP_SCALAR = 'not-traced-scalar'
_REASON_LESS_INTERESTING_OP = 'not-traced-less-interesting-op'
_REASON_DEVICE_MISMATCH = 'not-traced-device-mismatch'
_REASON_DYNAMIC_SHAPE = 'not-traced-dynamic-shape'
_REASON_SCALAR_GET_TRACED = 'traced-scalar'
_REASON_TENSOR_GET_TRACED = 'traced-tensor'
_REASON_USER_INCLUDED = 'traced-user-included'
_REASON_USER_EXCLUDED = 'not-traced-user-excluded'
_REASON_NOT_EXECUTED = 'not-traced-not-in-exec-path'
_REASON_NON_NUMERIC_TENSOR = 'not-traced-non-numeric-tensor'
_REASON_FEEDS_WHILELOOP_OP = 'not-traced-feeds-special-whileloop-op'

# Prepended to a file path to form the `output_stream` argument of
# logging_ops.print_v2 when traces are written to trace_dir.
_OUTPUT_STREAM_ESCAPE = 'file://'
# Graph collection in which tensor_tracepoint() records
# (tensor, checkpoint_name) pairs.
_TENSOR_TRACER_COLLECTION = 'tensor_tracer_variables'
# File name (inside trace_dir) that printed traces are written to.
_TRACE_FILE_NAME = 'trace.all'
_COMPACT_TRACE_FILE_PREFIX = 'compact_trace.'
# Initial value of every entry of the cache variables created by
# TensorTracer._create_or_get_tensor_values_cache (cast to int for integer
# caches).
_COMPACT_TRACE_ENTRY_INIT_VALUE = -1.0
# Variable collection that tensor tracer cache variables are added to.
_TENSOR_TRACER_STORAGE = 'tensor_tracer_storage'
# Name prefix of tensor tracer cache variables.
_TT_SNAPSHOT = 'tensor_tracer_snapshot'
_REPLICA_ID_TAG = '#replica-id: '

# Local aliases of the summary signature names used by the summary trace mode.
_TT_SUMMARY_NORM = tensor_tracer_flags.TT_SUMMARY_NORM
_TT_SUMMARY_MAX = tensor_tracer_flags.TT_SUMMARY_MAX
_TT_SUMMARY_MIN = tensor_tracer_flags.TT_SUMMARY_MIN
_TT_SUMMARY_MEAN = tensor_tracer_flags.TT_SUMMARY_MEAN
_TT_SUMMARY_VAR = tensor_tracer_flags.TT_SUMMARY_VAR
_TT_SUMMARY_SIZE = tensor_tracer_flags.TT_SUMMARY_SIZE

# Name of the shared cache that per-tensor signatures are scattered into.
_TT_SUMMARY_TAG = 'tensor_tracer_summary'
_TT_TENSORBOARD_PLUGIN_NAME = 'tensor_tracer'
_TT_HOSTCALL_KEY = 'tensor_tracer_host_call'
_TT_EVENT_FILE_SUFFIX = '.tensor_tracer'

_TT_SUMMARY_MAX_QUEUE = 100


def op_priority(op_type):
  """Returns the priority of the op.

  If the priority of the op is k, it will be traced if trace_level>=k.

  Args:
    op_type: String name of the operation type.
  Returns:
    Integer value corresponding to the priority of the op. Op types that are
    not listed below get 0, the highest priority.
  """
  if op_type in ('Const', 'Shape', 'BroadcastGradientArgs', 'Range',
                 'VariableShape', 'Fill', 'OneHot', 'ShapeN'):
    # Lowest priority ops, e.g., constant ops across different steps.
    # They will be traced only if trace_level>=7.
    return 7

  if op_type in ('Identity', 'Cast', 'Reshape', 'ExpandDims', 'StopGradient',
                 'PreventGradient', 'Squeeze'):
    # Operations without numerical effects.
    # They will be traced only if trace_level>=6.
    return 6
  if op_type in ('ConcatV2', 'Concat', 'StridedSlice', 'Slice', 'Pack', 'Tile',
                 'CollectivePermute', 'SplitV'):
    # Operations that merge or slice an input, will be traced if trace_level>=5.
    return 5
  if op_type in ('Pad', 'RandomUniformInt', 'GreaterEqual'):
    # Operations less likely to provide useful information,
    # will be traced if trace_level>=4.
    return 4
  if op_type in ('Sum', 'AddV2', 'Add', 'AddN', 'BiasAdd', 'CrossReplicaSum'):
    # Add operations that are less likely to create any issues, will be traced
    # if trace_level>=3 (default=3).
    return 3
  if op_type in ('Neg', 'Sub'):
    # Sub operations that are less likely to create any issues, will be traced
    # if trace_level>=2.
    return 2
  if op_type in ('Mul', 'Square', 'MatMul', 'RandomUniform', 'Select',
                 'Maximum', 'Mean', 'Variance'):
    # Multiplication and some other operations, will be traced if
    # trace_level>=1.
    return 1
  return 0


def read_tensor_tracer_event_file(event_file):
  """Reads the event file written by tensor tracer.

  Args:
    event_file: Path to the event file that contains only tensor tracer events.
  Returns:
    An event dictionary in the form of
    {step_number: {tensor_name: tensor_content}}
  Raises:
    ValueError: If an unexpected trace is found.
  """
  event_dict = {}
  for trace_event in summary_iterator.summary_iterator(event_file):
    # First event is an event with file_version: "brain.Event:2"
    if not trace_event.HasField('summary'):
      continue
    step = trace_event.step
    step_dict = event_dict.setdefault(step, {})

    if len(trace_event.summary.value) != 1:
      raise ValueError('Single step contains %d summary values,'
                       ' expected 1.' % len(trace_event.summary.value))
    tensor_value = trace_event.summary.value[0]
    tensor_name = tensor_value.tag

    real_shape = [d.size for d in tensor_value.tensor.tensor_shape.dim]
    # `as_numpy_dtype` is a property holding the numpy type itself; calling
    # it would build a scalar instance of that type rather than name the dtype.
    numpy_dtype = dtypes.DType(tensor_value.tensor.dtype).as_numpy_dtype
    tensor_content = np.frombuffer(
        tensor_value.tensor.tensor_content, numpy_dtype).reshape(real_shape)
    step_dict[tensor_name] = tensor_content
  return event_dict


def tensor_tracepoint(tensor, checkpoint_name):
  """Adds a checkpoint with the given checkpoint name for the given tensor.

  The tensor will be added to the list of tensors that will be traced by the
  tensor tracer.

  Args:
     tensor: the tensor object for which the tracing is requested.
     checkpoint_name: a string name for the checkpoint. This name has to be a
     unique name if used within model comparison. The tensors that have the same
     checkpoint identifier are compared in model comparison.
  Returns:
    The provided tensor.
  """
  tensor.graph.add_to_collection(_TENSOR_TRACER_COLLECTION,
                                 (tensor, checkpoint_name))
  return tensor


def keras_layer_tracepoint(layer, checkpoint_name):
  """An interface for adding the tensor outputs of a keras layer.

  Encapsulates tensor_tracepoint. A single-tensor output is registered under
  `checkpoint_name`. For a sequence of outputs, each tensor element is
  registered as '<checkpoint_name>_<position>'.

  Args:
     layer: A keras layer.
     checkpoint_name: a string name for the checkpoint. This name has to be a
     unique name if used within model comparison. The tensors that have the same
     checkpoint identifier are compared in model comparison.

  Returns:
    The provided layer.
  """
  try:
    outputs = layer.output
    if tensor_util.is_tensor(outputs):
      tensor_tracepoint(outputs, '%s' % (checkpoint_name))
    else:
      for idx, output_tensor in enumerate(outputs):
        # Test the element, not the container: `outputs` is known not to be a
        # tensor in this branch, so testing it would never trace anything.
        if tensor_util.is_tensor(output_tensor):
          tensor_tracepoint(output_tensor, '%s_%d' % (checkpoint_name, idx))
  except (AttributeError, RuntimeError):
    # Best effort: a layer whose outputs cannot be accessed or registered is
    # returned untouched. NOTE(review): presumably this covers layers that
    # have not been connected/called yet — confirm.
    pass
  return layer


def _trace_files_need_precreated(output_dir):
  """Return True if trace files must be pre-created by users.

  This is the case exactly when `output_dir` starts with '/cns/'.
  """
  return output_dir.startswith('/cns/')


class TensorTracer(object):
  """A software construct for tracing tensor values in a TF graph on TPU.

  This utility is disabled by default. It can be enabled by setting
  the TENSOR_TRACER_FLAGS env variable as:
    export TENSOR_TRACER_FLAGS="--enable=1"
  If it is enabled, it will trace the output tensor values of
  selected Ops in the graph. It has two outputs: (1) the traces and (2)
  a report. The traces are dumped to a specified local file on the TPU
  host. The report is printed to the log.info of the TPU job.
  By passing options via the env variable, users can change:
     (1) the trace mode (e.g., detecting NaN/Inf, printing partial or
         full tensor values)
     (2) which Ops to be traced (via op.name or op.type)
     (3) output trace file path.
  """
  # The set of graphs that are rewritten by tensor tracer.
269 _traced_graphs = set() 270 271 @staticmethod 272 def is_enabled(): 273 """Returns True if TensorTracer is enabled.""" 274 return tensor_tracer_flags.TTParameters().is_enabled() 275 276 @staticmethod 277 def check_device_type(device_type): 278 """Checks if the given device type is valid.""" 279 280 if device_type not in (_DEVICE_TYPE_TPU, _DEVICE_TYPE_CPU): 281 raise ValueError('Invalid device_type "%s"'%device_type) 282 283 @staticmethod 284 def check_trace_mode(device_type, trace_mode): 285 """Checks if the given trace mode work on the given device type. 286 287 Args: 288 device_type: Device type, TPU, GPU, CPU. 289 trace_mode: Tensor tracer trace mode. 290 Raises: 291 ValueError: If the given trace mode is not supported for the device. 292 """ 293 if trace_mode == tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY: 294 if device_type != _DEVICE_TYPE_TPU: 295 raise ValueError('Device_type "%s" is not yet supported for ' 296 'trace mode "%s"' % (device_type, trace_mode)) 297 298 @staticmethod 299 def loop_cond_op(op): 300 return op.type in ('LoopCond', 'RefLoopCond') 301 302 @staticmethod 303 def while_loop_op(op): 304 """Returns true if op is one of the special ops of in a while loop. 305 306 Args: 307 op: A tf.Operation. 308 309 Returns: 310 True if the given op is one of [Switch, Merge, Enter, Exit, 311 NextIteration, LoopCond], which are all building blocks for TF while 312 loops. 313 """ 314 return (control_flow_util.IsLoopSwitch(op) or 315 control_flow_util.IsLoopMerge(op) or 316 control_flow_util.IsLoopEnter(op) or 317 control_flow_util.IsLoopExit(op) or 318 TensorTracer.loop_cond_op(op) or 319 op.type in ('RefNextIteration', 'NextIteration')) 320 321 @staticmethod 322 def control_flow_op(op): 323 """Returns true if op is one of the special ops of in a while loop. 324 325 Args: 326 op: A tf.Operation. 
327 328 Returns: 329 True if the given op is one of [Switch, Merge, Enter, Exit, 330 NextIteration, LoopCond], which are all building blocks for TF while 331 loops. 332 """ 333 return (control_flow_util.IsSwitch(op) or 334 control_flow_util.IsMerge(op)) 335 336 @staticmethod 337 def unsafe_op(op): 338 """Returns True if this op is not safe to be traced.""" 339 340 if control_flow_util.IsInCond(op): 341 return True 342 # Reasons for not including following op types: 343 # Assign: cause incorrect result with CPU tracing. 344 if op.type == 'Assign': 345 return True 346 return False 347 348 @staticmethod 349 def device_mismatch(device_type, op): 350 if device_type == _DEVICE_TYPE_TPU: 351 # pylint: disable=protected-access 352 return tpu._TPU_REPLICATE_ATTR not in op.node_def.attr 353 # pylint: enable=protected-access 354 return False 355 356 @staticmethod 357 def unsafe_scalar_trace(op): 358 """Return true if scalar output tensor from Op is not safe to be traced.""" 359 360 # Tracing the following causes cycle in the graph on TPU. 361 if op.type in ('LoopCond', 'Enter', 'Merge', 'Const', 362 'Switch', 'Less', 'ReadVariableOp'): 363 return True 364 # Tracing the following will cause casting-issue 365 # with the norm tracing mode or other compilation issues on CPU. 366 if op.type in ('VarHandleOp', 'IteratorToStringHandle', 367 'IteratorGetNext', 'OneShotIterator', 368 'IteratorV2', 'MakeIterator', 369 'BatchDatasetV2', 'MapDataset', 370 'FixedLengthRecordDataset', 'TakeDataset', 'ZipDataset', 371 'Placeholder', 'PlaceholderWithDefault', 'StridedSlice'): 372 return True 373 return False 374 375 def _is_interesting_op(self, op): 376 """Returns True if the given op is not an interesting one to be traced.""" 377 # If flag is set to include less interesting ops, then include everything. 
378 if self._parameters.include_less_interesting_ops: 379 return True 380 return op_priority(op.type) <= self._parameters.trace_level 381 382 @staticmethod 383 def reason(op_idx, details): 384 """Returns reason why the Op at op_idx is traced or not.""" 385 386 return '%d %s'%(op_idx, details) 387 388 def __init__(self): 389 """Initializes a TensorTracer. 390 391 Sets the various member fields from the flags (if given) or the defaults. 392 """ 393 self._replica_id = None 394 self._tt_config = tensor_tracer_report.TensorTracerConfig() 395 self._parameters = tensor_tracer_flags.TTParameters() 396 self._included_op_full_names = set() 397 self._host_call_fn = {} 398 self._cache_variables = {} 399 self._traced_op_names = set() 400 401 def _get_all_cache_variables(self): 402 return self._cache_variables 403 404 def _create_or_get_tensor_values_cache(self, cache_name, graph=None, 405 shape=None, dtype=dtypes.float32): 406 """Creates a variable as the cache to store intermediate tensor values. 407 408 Args: 409 cache_name: Name to be given to the cache (an instance of tf.variable). 410 graph: Tensorflow graph. 411 shape: A list of dimensions. 412 dtype: Data type of created cache. 413 Returns: 414 A ref to newly created or existing cache with the given dimensions. 415 Raises: 416 ValueError: If missing a parameter to create the cache. 417 """ 418 def _escape_namescopes(variable_name): 419 # TODO(deveci): This might cause name collisions as in "foo/bar/mytensor" 420 # and "foo_bar/mytensor". 
421 return variable_name.replace('/', '_').replace(':', '_') 422 423 if cache_name not in self._cache_variables: 424 if graph is None: 425 raise ValueError('Graph must be provided at cache creation.') 426 if shape is None: 427 raise ValueError('shape must be provided at cache creation.') 428 graph = graph or ops.get_default_graph() 429 if dtype.is_integer: 430 init_val = int(_COMPACT_TRACE_ENTRY_INIT_VALUE) 431 else: 432 init_val = _COMPACT_TRACE_ENTRY_INIT_VALUE 433 434 # Create in proper graph and base name_scope. 435 with graph.as_default() as g, g.name_scope(None): 436 self._cache_variables[cache_name] = variable_scope.get_variable( 437 _TT_SNAPSHOT + '_' + _escape_namescopes(cache_name), 438 shape=shape, dtype=dtype, 439 initializer=init_ops.constant_initializer(init_val), 440 trainable=False, 441 use_resource=True, 442 collections=[_TENSOR_TRACER_STORAGE, ops.GraphKeys.LOCAL_VARIABLES]) 443 return self._cache_variables[cache_name] 444 445 def _add_replica_id_to_graph(self): 446 """Adds nodes for computing the replica ID to the graph.""" 447 448 if self._tt_config.num_replicas: 449 with ops.control_dependencies(None): 450 # Uses None as dependency to run outside of TPU graph rewrites. 451 self._replica_id = tpu_ops.tpu_replicated_input( 452 list(range(self._tt_config.num_replicas)), 453 name='tt_replica_id') 454 else: 455 self._replica_id = 'unknown' 456 457 def _inside_op_range(self, idx): 458 """Return True if the given index is inside the selected range.""" 459 460 if idx < self._parameters.op_range[0]: 461 return False 462 return (self._parameters.op_range[1] < 0 or 463 idx <= self._parameters.op_range[1]) 464 465 def _is_user_included_op(self, op): 466 """Checks whether the op is included in the tensor tracer flags. 467 468 Args: 469 op: tf Operation 470 Returns: 471 True, if the op is included. 
472 An op is included if: 473 - Its op name is given in included_opnames 474 - Its op type is given in included_optypes 475 - The op is at most _trace_ops_before_included hops before an included op 476 - The op is at most _trace_ops_after_included hops after an included op 477 """ 478 479 def _is_op_or_any_neighbor_included(op, check_before=0, check_after=0): 480 """Helper function to check if op is included or not.""" 481 if op.name in self._included_op_full_names: 482 return True 483 for opname_re in self._parameters.included_opname_re_list: 484 if opname_re.match(op.name): 485 self._included_op_full_names.add(op.name) 486 return True 487 488 for optype_re in self._parameters.included_optype_re_list: 489 if optype_re.match(op.type): 490 self._included_op_full_names.add(op.name) 491 return True 492 493 if check_after > 0: 494 for out_tensor in op.outputs: 495 for consumer in out_tensor.consumers(): 496 if _is_op_or_any_neighbor_included(consumer, check_after - 1, 0): 497 self._included_op_full_names.add(op.name) 498 return True 499 if check_before > 0: 500 for input_tensor in op.inputs: 501 if _is_op_or_any_neighbor_included(input_tensor.op, 502 0, 503 check_before - 1): 504 self._included_op_full_names.add(op.name) 505 return True 506 return False 507 # check_after and check_before are swapped below, as below operation 508 # checks the distance from an arbitrary op to included ops. 
509 return _is_op_or_any_neighbor_included( 510 op, self._parameters.trace_ops_after_included, 511 self._parameters.trace_ops_before_included) 512 513 def _is_user_excluded_op(self, op): 514 for opname_re in self._parameters.excluded_opname_re_list: 515 if opname_re.match(op.name): 516 return True 517 for optype_re in self._parameters.excluded_optype_re_list: 518 if optype_re.match(op.type): 519 return True 520 return False 521 522 def _signature_types(self): 523 """Returns a dictionary holding the order of signatures in the cache for the selected trace mode.""" 524 if self._parameters.trace_mode in set([ 525 tensor_tracer_flags.TRACE_MODE_NAN_INF, 526 tensor_tracer_flags.TRACE_MODE_NORM, 527 tensor_tracer_flags.TRACE_MODE_MAX_ABS]): 528 return {self._parameters.trace_mode: 0} 529 if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY: 530 return self._parameters.summary_signatures 531 return {} 532 533 def _num_signature_dimensions(self): 534 return len(self._signature_types()) 535 536 def _use_tensor_values_cache(self): 537 """Returns True if immediate tensors should be first saved to a cache.""" 538 if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY: 539 # For summary tace mode only compact format is supported. 
540 return True 541 542 if self._parameters.trace_mode not in set([ 543 tensor_tracer_flags.TRACE_MODE_NAN_INF, 544 tensor_tracer_flags.TRACE_MODE_NORM, 545 tensor_tracer_flags.TRACE_MODE_MAX_ABS, 546 tensor_tracer_flags.TRACE_MODE_SUMMARY 547 ]): 548 return False 549 if (self._parameters.trace_dir and 550 _trace_files_need_precreated(self._parameters.trace_dir)): 551 return True 552 return self._parameters.use_compact_trace 553 554 def _use_tensor_buffer(self): 555 """Returns true if the whole tensor needs to be cached/buffered in memory.""" 556 return (self._parameters.trace_mode == 557 tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY) 558 559 def _save_tensor_value_to_cache_op(self, cache_idx, updates): 560 """Returns an op that will save the given updates to an entry in the cache. 561 562 Args: 563 cache_idx: The cache index of the tensor within the cache. 564 updates: A dictionary of the signature updates. 565 Returns: 566 Cache update operation. 567 """ 568 # state_ops.scatter_update allows updates only along the first dimension. 569 # Make a compact array by concantating different signatures, and update 570 # them all together. 571 sorted_update = [] 572 if self._num_signature_dimensions() > 1: 573 signature_indices = self._signature_types() 574 for _, val in sorted(updates.items(), 575 key=lambda item: signature_indices[item[0]]): 576 sorted_update.append(val) 577 updates = array_ops.stack(sorted_update, axis=0) 578 updates = array_ops.reshape(updates, [1, 579 self._num_signature_dimensions()]) 580 else: 581 (_, val), = updates.items() 582 updates = array_ops.reshape(val, [1, self._num_signature_dimensions()]) 583 indices = constant_op.constant([cache_idx]) 584 cache = self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG) 585 return state_ops.scatter_update(cache, indices, updates).op 586 587 def _snapshot_tensor(self, tensor): 588 """Creates a new tf.Variable and a new tf.Operation that assigns the value of the tensor to this variable. 
589 590 Args: 591 tensor: tensor whose values will be stored in a new tf.Variable. 592 Returns: 593 An assignment operation. 594 """ 595 596 snapshot_variable = self._create_or_get_tensor_values_cache( 597 tensor.name, tensor.op.graph, 598 tensor.shape.as_list(), tensor.dtype) 599 return state_ops.assign(snapshot_variable, tensor).op 600 601 def _preprocess_traced_tensor(self, tensor): 602 """Computes NAN/Norm/Max on TPUs before sending to CPU. 603 604 Args: 605 tensor: The tensor to be traced. 606 Returns: 607 A tensor that should be input to the trace_function. 608 Raises: 609 RuntimeError: If the trace mode is invalid. 610 """ 611 612 def _detect_nan_inf(tensor): 613 """Trace function for detecting any NaN/Inf in the tensor.""" 614 615 if tensor.dtype.is_floating: 616 mask = math_ops.reduce_any( 617 gen_math_ops.logical_or( 618 gen_math_ops.is_nan(tensor), gen_math_ops.is_inf(tensor))) 619 output_tensor = control_flow_ops.cond( 620 mask, 621 lambda: constant_op.constant([1.0]), 622 lambda: constant_op.constant([0.0])) 623 else: 624 output_tensor = constant_op.constant([0.0]) 625 return output_tensor 626 627 def _compute_signature(tensor, tf_op, cast_to_f32=True): 628 if cast_to_f32: 629 tensor = math_ops.cast(tensor, dtypes.float32) 630 output_tensor = tf_op(tensor) 631 # Return type should be scalar. Set it if it does not have the 632 # information. 633 if not output_tensor.get_shape().is_fully_defined(): 634 output_tensor = array_ops.reshape(output_tensor, []) 635 return output_tensor 636 637 def _show_size(tensor): 638 # In order to check the size of a tensor. 639 # Not all sizes are known at the compile time, also, different replicas 640 # sometimes get different sizes of tensors. 641 # Collect it here to be used in merging replica data. 642 tsize = _compute_signature(tensor, array_ops.size, cast_to_f32=False) 643 # Cast to float32, so that it can be placed into same cache with other 644 # signatures. 
645 return math_ops.cast(tsize, dtypes.float32) 646 647 def _show_max(tensor, cast_to_f32=True): 648 # returns -inf for empty tensor 649 return _compute_signature(tensor, math_ops.reduce_max, cast_to_f32) 650 651 def _show_min(tensor, cast_to_f32=True): 652 # returns inf for empty tensor 653 return _compute_signature(tensor, math_ops.reduce_min, cast_to_f32) 654 655 def _show_norm(tensor, cast_to_f32=True): 656 # returns 0 for empty tensor 657 return _compute_signature(tensor, linalg_ops.norm, cast_to_f32) 658 659 def _show_mean_and_variance(tensor, cast_to_f32=True): 660 """Returns the mean and variance of the given tensor.""" 661 if cast_to_f32: 662 tensor = math_ops.cast(tensor, dtypes.float32) 663 # returns nan for empty tensor 664 mean, var = nn_impl.moments(array_ops.reshape(tensor, [-1]), axes=[0]) 665 # The shape has to be 1. Set it if it does not have the information. 666 if not mean.get_shape().is_fully_defined(): 667 mean = array_ops.reshape(mean, []) 668 if not var.get_shape().is_fully_defined(): 669 var = array_ops.reshape(var, []) 670 return mean, var 671 672 def _show_max_abs(tensor): 673 tensor = math_ops.cast(tensor, dtypes.float32) 674 output_tensor = math_ops.reduce_max(math_ops.abs(tensor)) 675 zero = constant_op.constant(0, dtypes.float32) 676 output_tensor = gen_math_ops.maximum(zero, output_tensor) 677 # The shape has to be 1. Set it if it does not have the information. 
678 output_tensor = array_ops.reshape(output_tensor, [1]) 679 return output_tensor 680 681 def _detect_inf_nan_producer(tensor): 682 """Checks if the tensor is the first NaN/Inf tensor in the computation path.""" 683 if tensor.op.inputs: 684 inp_check = [ 685 _detect_nan_inf(inp_tensor) for inp_tensor in tensor.op.inputs 686 ] 687 is_any_input_inf_nan = math_ops.add_n(inp_check) 688 else: 689 is_any_input_inf_nan = constant_op.constant(0, dtypes.bool) 690 is_current_tensor_inf_nan = _detect_nan_inf(tensor) 691 # An op is NaN/INF producer only when all inputs are nan/inf free ( 692 # is_any_input_inf_nan = 0), and its output has nan/inf ( 693 # is_current_tensor_inf_nan=1). Below will be 1 if op nan/inf is producer. 694 is_nan_producer = is_current_tensor_inf_nan - is_any_input_inf_nan 695 is_nan_producer = math_ops.reduce_any(is_nan_producer > 0) 696 return is_nan_producer 697 698 if (self._parameters.trace_mode == 699 tensor_tracer_flags.TRACE_MODE_FULL_IF_NAN): 700 return {self._parameters.trace_mode: _detect_inf_nan_producer(tensor)} 701 if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NAN_INF: 702 return {self._parameters.trace_mode: _detect_nan_inf(tensor)} 703 if (self._parameters.trace_mode == 704 tensor_tracer_flags.TRACE_MODE_PART_TENSOR): 705 return {self._parameters.trace_mode: tensor} 706 if (self._parameters.trace_mode in ( 707 tensor_tracer_flags.TRACE_MODE_FULL_TENSOR, 708 tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY)): 709 return {self._parameters.trace_mode: tensor} 710 if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NORM: 711 return {self._parameters.trace_mode: array_ops.reshape( 712 _show_norm(tensor), [1])} 713 if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_MAX_ABS: 714 return {self._parameters.trace_mode: _show_max_abs(tensor)} 715 716 if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY: 717 tensor = math_ops.cast(tensor, dtypes.float32) 718 result_dict = {} 719 # 
Call mean and variance computation here to avoid adding the same nodes 720 # twice. 721 if (_TT_SUMMARY_MEAN in self._signature_types() or 722 _TT_SUMMARY_VAR in self._signature_types()): 723 mean, variance = _show_mean_and_variance(tensor, cast_to_f32=False) 724 725 for signature_name, _ in sorted(self._signature_types().items(), 726 key=lambda x: x[1]): 727 if signature_name == _TT_SUMMARY_NORM: 728 signature_result_tensor = _show_norm(tensor, cast_to_f32=False) 729 elif signature_name == _TT_SUMMARY_MAX: 730 signature_result_tensor = _show_max(tensor, cast_to_f32=False) 731 elif signature_name == _TT_SUMMARY_MIN: 732 signature_result_tensor = _show_min(tensor, cast_to_f32=False) 733 elif signature_name == _TT_SUMMARY_SIZE: 734 signature_result_tensor = _show_size(tensor) 735 elif signature_name == _TT_SUMMARY_MEAN: 736 signature_result_tensor = mean 737 elif signature_name == _TT_SUMMARY_VAR: 738 signature_result_tensor = variance 739 else: 740 raise ValueError('Unknown signature type :%s.' % signature_name) 741 742 result_dict[signature_name] = signature_result_tensor 743 return result_dict 744 745 raise RuntimeError( 746 'Tensor trace fun for %s is not yet implemented' 747 % self._parameters.trace_mode) 748 749 def _make_tensor_trace_fun(self, tensor_name, tensor_trace_order): 750 """Makes the tensor tracing function called by outside compilation. 751 752 Args: 753 tensor_name: name of the tensor being traced. 754 tensor_trace_order: TensorTraceOrder object holding tensorname to id map. 755 Returns: 756 A function to be passed as the first argument to outside compilation. 757 758 Raises: 759 RuntimeError: If the trace mode is invalid. 760 """ 761 762 def _print_tensor(tensor_name, num_elements, tensor, output_tensor): 763 """Prints a tensor value to a file. 764 765 Args: 766 tensor_name: name of the tensor being traced. 767 num_elements: number of elements to print (-1 means print all). 768 tensor: the tensor needs to be returned. 
769 output_tensor: the tensor needs to be printed. 770 771 Returns: 772 The same tensor passed via the "tensor" argument. 773 774 Raises: 775 ValueError: If tensor_name is not already in 776 self._tensorname_idx_map. 777 """ 778 779 if self._parameters.is_brief_mode(): 780 if tensor_name not in tensor_trace_order.tensorname_idx_map: 781 raise ValueError( 782 'Tensor name %s is not in the tensorname_idx_map'%tensor_name) 783 msg = '%d'%self._tensorname_idx_map[tensor_name] 784 else: 785 msg = '"%s"'%tensor_name 786 787 if self._parameters.trace_dir: 788 output_path = os.path.join(self._parameters.trace_dir, _TRACE_FILE_NAME) 789 output_stream = _OUTPUT_STREAM_ESCAPE + output_path 790 else: 791 output_stream = sys.stderr 792 return logging_ops.print_v2(msg, array_ops.shape(output_tensor), 793 '@', self._replica_id, 794 '\n', output_tensor, '\n', 795 summarize=num_elements, 796 output_stream=output_stream) 797 798 def _show_part_tensor(tensor): 799 """Trace function for printing part of the tensor.""" 800 801 return _print_tensor(tensor_name, _TRACE_MODE_PART_TENSOR_SIZE, 802 tensor, tensor) 803 804 def _show_full_tensor(tensor): 805 """Trace function for printing the entire tensor.""" 806 807 return _print_tensor(tensor_name, -1, tensor, tensor) 808 809 def _show_full_tensors(tensor): 810 """Prints the full tensor values for the tensors that are _trace_stack_size hops away from a given tensor.""" 811 812 def _get_distance_k_tensors(k_before=0): 813 """Returns the tensors that are at most k_before hops away from the tensor.""" 814 if k_before < 0: 815 return [] 816 visited_tensors = {tensor: 0} 817 visitor_queue = [tensor] 818 head = 0 819 while head < len(visitor_queue): 820 current_tensor = visitor_queue[head] 821 head += 1 822 distance = visited_tensors[current_tensor] 823 if distance == k_before: 824 break 825 for input_tensor in current_tensor.op.inputs: 826 if input_tensor in visited_tensors: 827 continue 828 visitor_queue.append(input_tensor) 829 
visited_tensors[input_tensor] = distance + 1 830 return visitor_queue 831 832 tensors_to_print = _get_distance_k_tensors( 833 self._parameters.trace_stack_size) 834 print_ops = [_print_tensor(t.name, -1, t, t) for t in tensors_to_print] 835 with ops.control_dependencies(print_ops): 836 return constant_op.constant(True) 837 838 if (self._parameters.trace_mode == 839 tensor_tracer_flags.TRACE_MODE_FULL_IF_NAN): 840 return _show_full_tensors 841 if (self._parameters.trace_mode == 842 tensor_tracer_flags.TRACE_MODE_PART_TENSOR): 843 return _show_part_tensor 844 # The input tensor has a shape of "[1]" for TRACE_MODE_NAN_INF, 845 # TRACE_MODE_NORM, and TRACE_MODE_MAX_ABS, as related computations are 846 # performed within TPUs and only their results are transferred to CPU. 847 # Simply, print the full tensor for these trace modes. 848 if self._parameters.trace_mode in ( 849 tensor_tracer_flags.TRACE_MODE_NAN_INF, 850 tensor_tracer_flags.TRACE_MODE_NORM, 851 tensor_tracer_flags.TRACE_MODE_FULL_TENSOR, 852 tensor_tracer_flags.TRACE_MODE_MAX_ABS, 853 tensor_tracer_flags.TRACE_MODE_SUMMARY 854 ): 855 return _show_full_tensor 856 857 raise RuntimeError('Tensor trace fun for %s is not yet implemented' 858 %self._parameters.trace_mode) 859 860 def _skip_op(self, op_id, op, ops_in_exec_path, report_handler): 861 """Returns True if we should not trace Op. 862 863 Args: 864 op_id: Topological index of the op. 865 op: tf.Operation 866 ops_in_exec_path: Set of operations that are in the execution path. 867 report_handler: An instance of tensor_tracer_report.TTReportHandle. 868 Returns: 869 True if the op should not be traced, false otherwise. 
870 """ 871 if TensorTracer.while_loop_op(op): 872 report_handler.instrument_op( 873 op, TensorTracer.reason(op_id, _REASON_WHILELOOP_OP)) 874 return True 875 if TensorTracer.control_flow_op(op): 876 report_handler.instrument_op( 877 op, TensorTracer.reason(op_id, _REASON_CONTROLFLOW_OP)) 878 return True 879 if TensorTracer.unsafe_op(op): 880 report_handler.instrument_op( 881 op, TensorTracer.reason(op_id, _REASON_UNSAFE_OP)) 882 return True 883 if TensorTracer.device_mismatch(self._tt_config.device_type, op): 884 report_handler.instrument_op( 885 op, TensorTracer.reason(op_id, _REASON_DEVICE_MISMATCH)) 886 return True 887 if op not in ops_in_exec_path: 888 report_handler.instrument_op( 889 op, TensorTracer.reason(op_id, _REASON_NOT_EXECUTED)) 890 return True 891 892 if self._is_user_included_op(op): 893 report_handler.instrument_op( 894 op, TensorTracer.reason(op_id, _REASON_USER_INCLUDED)) 895 return False 896 897 if not self._inside_op_range(op_id): 898 report_handler.instrument_op( 899 op, TensorTracer.reason(op_id, _REASON_OUTSIDE_OP_RANGE)) 900 return True 901 if not self._is_interesting_op(op): 902 report_handler.instrument_op( 903 op, TensorTracer.reason(op_id, _REASON_LESS_INTERESTING_OP)) 904 return True 905 if self._is_user_excluded_op(op): 906 report_handler.instrument_op( 907 op, TensorTracer.reason(op_id, _REASON_USER_EXCLUDED)) 908 return True 909 return False 910 911 def _skip_tensor(self, op_id, out_tensor, report_handler): 912 """Returns True if we should not trace out_tensor. 913 914 Args: 915 op_id: Topological index of the op producing tensor. 916 out_tensor: tf.Tensor 917 report_handler: An instance of tensor_tracer_report.TTReportHandle. 918 Returns: 919 True if the tensor should not be traced, false otherwise. 920 """ 921 922 # Skips a tensor if the tensor has a non-numeric type. 
923 # Note: we cannot use check_ops.is_numeric_tensor(out_tensor) 924 # because it also excludes tensors with dtypes, bool, and 925 # float32_ref, which we actually want to trace. 926 non_numeric_tensor_types = set([dtypes.variant, dtypes.resource, 927 dtypes.string]) 928 if out_tensor.dtype in non_numeric_tensor_types: 929 930 report_handler.instrument_tensor( 931 out_tensor, TensorTracer.reason(op_id, _REASON_NON_NUMERIC_TENSOR)) 932 return True 933 # Skip a tensor if it feeds a special while loop op. 934 if [consumer for consumer in out_tensor.consumers() if 935 TensorTracer.while_loop_op(consumer)]: 936 report_handler.instrument_tensor( 937 out_tensor, TensorTracer.reason(op_id, _REASON_FEEDS_WHILELOOP_OP)) 938 return True 939 if self._is_user_included_op(out_tensor.op): 940 report_handler.instrument_tensor( 941 out_tensor, TensorTracer.reason(op_id, _REASON_USER_INCLUDED)) 942 return False 943 if self._is_user_excluded_op(out_tensor.op): 944 report_handler.instrument_tensor( 945 out_tensor, TensorTracer.reason(op_id, _REASON_USER_EXCLUDED)) 946 return True 947 if not out_tensor.get_shape().is_fully_defined(): 948 # If trace mode is nan-inf, norm or max, then the tensor will be reduced 949 # to a scalar before the outside compilation call. 
950 if self._parameters.trace_mode in ( 951 tensor_tracer_flags.TRACE_MODE_NAN_INF, 952 tensor_tracer_flags.TRACE_MODE_NORM, 953 tensor_tracer_flags.TRACE_MODE_MAX_ABS, 954 tensor_tracer_flags.TRACE_MODE_SUMMARY 955 ): 956 report_handler.instrument_tensor( 957 out_tensor, TensorTracer.reason(op_id, _REASON_TENSOR_GET_TRACED)) 958 return False 959 else: 960 report_handler.instrument_tensor( 961 out_tensor, TensorTracer.reason(op_id, _REASON_DYNAMIC_SHAPE)) 962 return True 963 rank = len(out_tensor.shape) 964 if rank < 1: 965 # scalar 966 if self._parameters.trace_scalar_ops: 967 if TensorTracer.unsafe_scalar_trace(out_tensor.op): 968 report_handler.instrument_tensor( 969 out_tensor, TensorTracer.reason(op_id, _REASON_UNSAFE_SCALAR)) 970 return True 971 else: 972 report_handler.instrument_tensor( 973 out_tensor, TensorTracer.reason(op_id, _REASON_SCALAR_GET_TRACED)) 974 return False 975 else: 976 report_handler.instrument_tensor( 977 out_tensor, TensorTracer.reason(op_id, _REASON_SKIP_SCALAR)) 978 return True 979 else: 980 # tensor 981 report_handler.instrument_tensor( 982 out_tensor, TensorTracer.reason(op_id, _REASON_TENSOR_GET_TRACED)) 983 return False 984 985 def _filter_execution_path_operations(self, operations, fetches): 986 """Returns the set of ops in the execution path to compute given fetches.""" 987 988 # If no fetch provided, then return all operations. 989 if fetches is None: 990 return set(operations) 991 # Convert to list, if a single element is provided. 992 if not isinstance(fetches, (list, tuple)): 993 fetches = [fetches] 994 # If a tensor is given as fetch, convert it to op. 995 op_fetches = [] 996 for fetch in fetches: 997 if isinstance(fetch, ops.Operation): 998 op_fetches.append(fetch) 999 elif isinstance(fetch, ops.Tensor): 1000 op_fetches.append(fetch.op) 1001 else: 1002 raise RuntimeError('Given fetch:%s is neither a tensor nor an op.' 
1003 %fetch) 1004 1005 execution_path_operations = set(op_fetches) 1006 traverse_stack = list(op_fetches) 1007 while True: 1008 if not traverse_stack: 1009 break 1010 head_op = traverse_stack.pop() 1011 input_ops = [tensor_input.op for tensor_input in head_op.inputs] 1012 input_ops.extend(head_op.control_inputs) 1013 1014 for input_op in input_ops: 1015 if input_op not in execution_path_operations: 1016 # Filter out loop condition operations, tracing them causes a cycle. 1017 # Trace only the loop-body. 1018 if TensorTracer.loop_cond_op(input_op): 1019 continue 1020 execution_path_operations.add(input_op) 1021 traverse_stack.append(input_op) 1022 return execution_path_operations 1023 1024 def _determine_and_instrument_traced_tensors(self, graph_order, 1025 ops_in_exec_path, 1026 tensor_trace_points, 1027 report_handler): 1028 """Determines the tensors to trace and instruments the trace details. 1029 1030 Args: 1031 graph_order: graph_order tuple containing graph (tf.graph), operations 1032 (list of operations), op_to_idx (op id mapping), (tensors) list of 1033 tensors, tensor_to_idx (tensor id mapping), contains_cycle (whether 1034 there is a cycle in the graph), topological_order_or_cycle (list of ops 1035 in topological order or list of ops creating a cycle). 1036 ops_in_exec_path: Set of ops in the execution path. 1037 tensor_trace_points: Collection of programatic tensor trace points. 1038 report_handler: An instance of tensor_tracer_report.TTReportHandle. 1039 Returns: 1040 List of tensors to be traced. 
1041 """ 1042 1043 traced_tensors = [] 1044 checkpoint_operations = set([tensor.op 1045 for (tensor, _) in tensor_trace_points]) 1046 for op_id, op in enumerate(graph_order.operations): 1047 if checkpoint_operations and op not in checkpoint_operations: 1048 continue 1049 if self._skip_op(op_id, op, ops_in_exec_path, report_handler): 1050 continue 1051 for i in range(len(op.outputs)): 1052 out_tensor = op.outputs[i] 1053 if not self._skip_tensor(op_id, out_tensor, report_handler): 1054 traced_tensors.append(out_tensor) 1055 return traced_tensors 1056 1057 def _check_trace_files(self): 1058 """Checks if any requirements for trace files are satisfied.""" 1059 1060 if not self._parameters.trace_dir: 1061 # traces will be written to stderr. No need to check trace files. 1062 return 1063 if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY: 1064 # Output files are handled by tf.summary operations, no need to precreate 1065 # them. 1066 return 1067 if _trace_files_need_precreated(self._parameters.trace_dir): 1068 for replica_id in range(0, self._tt_config.num_replicas): 1069 trace_file_path = os.path.join( 1070 self._parameters.trace_dir, 1071 _COMPACT_TRACE_FILE_PREFIX) + '%d'%replica_id 1072 if not gfile.Exists(trace_file_path): 1073 raise RuntimeError( 1074 '%s must be pre-created with the ' 1075 'appropriate properties.'%trace_file_path) 1076 else: 1077 if not gfile.Exists(self._parameters.trace_dir): 1078 file_io.recursive_create_dir(self._parameters.trace_dir) 1079 if not gfile.Exists(self._parameters.trace_dir): 1080 raise RuntimeError('Failed to create %s'%self._parameters.trace_dir) 1081 1082 def _determine_trace_and_create_report(self, graph, ops_in_exec_path): 1083 """Work needs to be done prior to TPU or CPU tracing. 1084 1085 Args: 1086 graph: tf.graph 1087 ops_in_exec_path: Set of operations in the execution path. 
1088 Returns: 1089 An instance of tensor_tracer_report.TensorTraceOrder, containing list of 1090 tensors to be traced with their topological order information. 1091 """ 1092 1093 self._check_trace_files() 1094 1095 graph_order = tensor_tracer_report.sort_tensors_and_ops(graph) 1096 tensor_trace_points = graph.get_collection(_TENSOR_TRACER_COLLECTION) 1097 1098 report_handler = tensor_tracer_report.TTReportHandle() 1099 traced_tensors = self._determine_and_instrument_traced_tensors( 1100 graph_order, ops_in_exec_path, tensor_trace_points, report_handler) 1101 1102 tensor_trace_order = tensor_tracer_report.TensorTraceOrder(graph_order, 1103 traced_tensors) 1104 num_signatures = self._num_signature_dimensions() 1105 if num_signatures: 1106 self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG, 1107 graph, 1108 [len(traced_tensors), 1109 num_signatures]) 1110 if self._parameters.trace_mode in ( 1111 tensor_tracer_flags.TRACE_MODE_SUMMARY, 1112 tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY): 1113 report_proto = report_handler.create_report_proto(self._tt_config, 1114 self._parameters, 1115 tensor_trace_order, 1116 tensor_trace_points, 1117 self._signature_types()) 1118 report_handler.write_report_proto(report_proto, self._parameters) 1119 else: 1120 report_handler.create_report(self._tt_config, self._parameters, 1121 tensor_trace_order, tensor_trace_points) 1122 return tensor_trace_order 1123 1124 def _create_host_call(self): 1125 return self._parameters.trace_mode in ( 1126 tensor_tracer_flags.TRACE_MODE_SUMMARY, 1127 tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY) 1128 1129 def _generate_flush_cache_op(self, num_replicas, on_tpu): 1130 """Generates an Op that will flush the cache to file. 1131 1132 Args: 1133 num_replicas: total number of replicas. 1134 on_tpu: if the graph is executed on TPU. 1135 1136 Returns: 1137 The Op to flush the cache to file. 
1138 """ 1139 1140 def _flush_fun(cache, replica_id): 1141 """Flushes the cache to a file corresponding to replica_id.""" 1142 1143 def _f(file_index): 1144 """Generates a func that flushes the cache to a file.""" 1145 def _print_cache(): 1146 """Flushes the cache to a file.""" 1147 replica_str = ('%d' % file_index) 1148 if self._parameters.trace_dir: 1149 output_path = (os.path.join(self._parameters.trace_dir, 1150 _COMPACT_TRACE_FILE_PREFIX) 1151 + replica_str) 1152 output_stream = _OUTPUT_STREAM_ESCAPE + output_path 1153 else: 1154 output_stream = sys.stderr 1155 1156 new_step_line = _REPLICA_ID_TAG + replica_str 1157 print_ops = [] 1158 for i in range(self._num_signature_dimensions()): 1159 print_ops.append(logging_ops.print_v2( 1160 new_step_line, '\n', 1161 cache[:, i], '\n', 1162 summarize=-1, 1163 output_stream=output_stream)) 1164 with ops.control_dependencies(print_ops): 1165 return constant_op.constant(0).op 1166 return _print_cache 1167 1168 def _eq(file_index): 1169 return math_ops.equal(replica_id, file_index) 1170 1171 flush_op_cases = {} 1172 for i in range(num_replicas): 1173 flush_op_cases[_eq(i)] = _f(i) 1174 # Each replica needs to determine where to write their output. 1175 # To do this, we check if replica_id is 0, then 1, ..., and then 1176 # num_replicas - 1 statically; and return the corresponding static file 1177 # name. We cannot simply set the file name in python, as replica_id is 1178 # only known during tf runtime, and we cannot create dynamic filenames. 
1179 return control_flow_ops.case(flush_op_cases, exclusive=True) 1180 1181 cache = self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG) 1182 if on_tpu: 1183 flush_op = tpu.outside_compilation(_flush_fun, 1184 cache.value(), self._replica_id) 1185 else: 1186 flush_op = _flush_fun(cache.value(), self._replica_id) 1187 1188 with ops.control_dependencies([flush_op]): 1189 reset_value = constant_op.constant(_COMPACT_TRACE_ENTRY_INIT_VALUE, 1190 dtype=cache.dtype, 1191 shape=cache.shape) 1192 assign_op = state_ops.assign(cache, reset_value).op 1193 with ops.control_dependencies([assign_op]): 1194 return constant_op.constant(0).op 1195 1196 def _flush_tensor_values_cache(self, tensor_fetches, op_fetches, on_tpu): 1197 """Flushes the intermediate tensor values in the graph to the cache. 1198 1199 Args: 1200 tensor_fetches: list of tensor results returned by the model_fn. 1201 op_fetches: list of ops that are returned by the model_fn, e.g., train_op. 1202 on_tpu: if the graph is executed on TPU. 1203 1204 Returns: 1205 An identical copy of tensor_fetches. 1206 """ 1207 # Add a dependency to op and tensor fetches to make sure that all tracing 1208 # ops are executed before flushing trace results. 1209 with ops.control_dependencies(op_fetches + 1210 [tensor.op for tensor in tensor_fetches]): 1211 flush_cache_op = self._generate_flush_cache_op( 1212 self._tt_config.num_replicas, on_tpu) 1213 return control_flow_ops.tuple(tensor_fetches, 1214 control_inputs=[flush_cache_op]) 1215 1216 def _process_tensor_fetches(self, tensor_fetches): 1217 """Check that tensor_fetches is not empty and have valid tensors.""" 1218 # If none or empty list. 
1219 if tensor_fetches is None: 1220 raise RuntimeError('tensor_fetches provided to tensor_tracer cannot be ' 1221 'None.') 1222 if not isinstance(tensor_fetches, (list, tuple)): 1223 tensor_fetches = [tensor_fetches] 1224 elif not tensor_fetches: 1225 raise RuntimeError('tensor_fetches provided to tensor_tracer cannot be ' 1226 'empty list.') 1227 fetches = [] 1228 for fetch in tensor_fetches: 1229 if isinstance(fetch, ops.Tensor): 1230 fetches.append(fetch) 1231 else: 1232 raise RuntimeError('Given tensor_fetch:%s is not a tensor.' % fetch) 1233 return fetches 1234 1235 def _process_op_fetches(self, op_fetches): 1236 """Check that op_fetches have valid ops.""" 1237 if op_fetches is None: 1238 return [] 1239 1240 if not isinstance(op_fetches, (list, tuple)): 1241 op_fetches = [op_fetches] 1242 1243 fetches = [] 1244 for fetch in op_fetches: 1245 if isinstance(fetch, ops.Operation): 1246 fetches.append(fetch) 1247 elif isinstance(fetch, ops.Tensor): 1248 fetches.append(fetch.op) 1249 else: 1250 logging.warning('Ignoring the given op_fetch:%s, which is not an op.' % 1251 fetch) 1252 return fetches 1253 1254 def _convert_fetches_to_input_format(self, input_fetches, current_fetches): 1255 """Changes current_fetches' format, so that it matches input_fetches.""" 1256 if isinstance(input_fetches, ops.Tensor): 1257 if len(current_fetches) != 1: 1258 raise RuntimeError('Tensor tracer input/output fetches do not match.') 1259 return current_fetches[0] 1260 else: 1261 if len(current_fetches) != len(current_fetches): 1262 raise RuntimeError('Tensor tracer input/output fetches do not match.') 1263 elif isinstance(input_fetches, tuple): 1264 return tuple(current_fetches) 1265 else: 1266 return current_fetches 1267 1268 def _get_op_control_flow_context(self, op): 1269 """Returns the control flow of the given op. 1270 1271 Args: 1272 op: tf.Operation for which the control flow context is requested. 
1273 Returns: 1274 op_control_flow_context: which the is control flow context of the given 1275 op. If the operation type is LoopExit, returns the outer control flow 1276 context. 1277 """ 1278 # pylint: disable=protected-access 1279 op_control_flow_context = op._control_flow_context 1280 # pylint: enable=protected-access 1281 if control_flow_util.IsLoopExit(op): 1282 op_control_flow_context = op_control_flow_context.outer_context 1283 return op_control_flow_context 1284 1285 def _prepare_host_call_fn(self, processed_t_fetches, op_fetches): 1286 """Creates a host call function that will write the cache as tb summary. 1287 1288 Args: 1289 processed_t_fetches: List of tensor provided to session.run. 1290 op_fetches: List of operations provided to session.run. 1291 Raises: 1292 ValueError if trace_dir is not set. 1293 """ 1294 if self._parameters.trace_dir is None: 1295 raise ValueError('Provide a trace_dir for tensor tracer in summary mode. ' 1296 '--trace_dir=/model/dir') 1297 1298 def _write_cache(step, **kwargs): 1299 """Writes the given caches as tensor summary. 1300 1301 Args: 1302 step: Step tensor with dimension [num_cores]. 1303 **kwargs: The dictionary of tensors that needs to be written as 1304 summaries. Key and value pairs within kwargs correspond to the tag 1305 name, and tensor content that will be written using summary.write. 1306 The trace_modes that use this function are: 1307 - summary: In summary mode, kwargs includes a single (tag, content) 1308 pair which are, _TT_SUMMARY_TAG and a tf.float32 signature_cache 1309 variable. The dimension of the signature_cache is: 1310 num_cores x num_traced_tensors x num_signatures. 1311 - full_tensor_summary: kwargs will include all traced tensors. Tag 1312 and content correspond to the name of the tensor, and its actual 1313 content. 1314 Returns: 1315 A tf.Operation that needs to be executed for the host call dependencies. 
1316 Raises: 1317 RuntimeError: if there is no aggregate function defined for a signature. 1318 """ 1319 1320 # TODO(deveci): Parametrize max_queue, so that flushing op can be called 1321 # less frequently. 1322 # Setting max_queue to 100 appears to be safe even when the number of 1323 # iterations are much lower, as the destructor of the writer flushes it. 1324 summary_write_ops = [] 1325 with summary.create_file_writer_v2( 1326 self._parameters.trace_dir, 1327 filename_suffix=_TT_EVENT_FILE_SUFFIX, 1328 max_queue=_TT_SUMMARY_MAX_QUEUE).as_default(): 1329 summary_metadata = summary_pb2.SummaryMetadata( 1330 plugin_data=summary_pb2.SummaryMetadata.PluginData( 1331 plugin_name=_TT_TENSORBOARD_PLUGIN_NAME)) 1332 for key, value in kwargs.items(): 1333 # Check whether we need to compute aggregated statistics that merge 1334 # all cores statistics. 1335 if not self._parameters.collect_summary_per_core: 1336 # Merge only statistics tensor, if it is any other tensor we simply, 1337 # concatenate them. 1338 if key == _TT_SUMMARY_TAG: 1339 agg_fn_map = self._parameters.get_signature_to_agg_fn_map() 1340 signature_idx_map = self._signature_types() 1341 aggregation_result = [] 1342 for signature, idx in sorted(signature_idx_map.items(), 1343 key=operator.itemgetter(1)): 1344 if signature not in agg_fn_map: 1345 raise RuntimeError('No aggregation function is defined for ' 1346 'signature %s.' % signature) 1347 1348 # The dimensions of the statistics tensor is 1349 # num_cores x num_traced_tensors x num_signatures 1350 # value[:,:,idx] will return the portion of the tensor relasted 1351 # to signature. 1352 signature_tensor = value[:, :, idx] 1353 # Merge it along the first (core) axis. 
1354 agg_fn = agg_fn_map[signature] 1355 agg_tensor = agg_fn(signature_tensor, axis=0) 1356 aggregation_result.append(agg_tensor) 1357 # Merge results corresponding to different signatures 1358 1359 merged_signatures = array_ops.stack(aggregation_result) 1360 # merged_signatures has dimensions 1361 # num_signatures x num_traced_tensors, transpose it so that it 1362 # will match with the original structure 1363 # num_traced_tensors x num_signatures. 1364 transposed_signatures = array_ops.transpose(merged_signatures) 1365 # Expand 1 more dimension so that it will match with the expected 1366 # structure num_cores x num_traced_tensors x num_signatures. 1367 value = array_ops.expand_dims(transposed_signatures, axis=0) 1368 1369 with ops.control_dependencies( 1370 summary.summary_writer_initializer_op()): 1371 summary_write_ops.append(summary.write( 1372 _TT_SUMMARY_TAG + '/' + key, value, metadata=summary_metadata, 1373 step=step[0])) 1374 return control_flow_ops.group(summary_write_ops) 1375 1376 step = array_ops.reshape(training_util.get_or_create_global_step(), [1]) 1377 self._host_call_fn = {} 1378 1379 host_call_deps = op_fetches + [tensor.op for tensor in processed_t_fetches] 1380 1381 caches_to_write = {} 1382 with ops.control_dependencies(host_call_deps): 1383 all_caches = self._get_all_cache_variables() 1384 for cache_name, cache_variable in all_caches.items(): 1385 # Increase the cache rank by 1, so that when host call concatenates 1386 # tensors from different replicas, we can identify them with [core_id]. 1387 new_cache_shape = [1] 1388 new_cache_shape.extend(cache_variable.shape.as_list()) 1389 cache = array_ops.reshape(cache_variable.value(), new_cache_shape) 1390 caches_to_write[cache_name] = cache 1391 # Add step to parameter dictionary. 
1392 caches_to_write['step'] = step 1393 # Other options without adding step to parameter dictionary are 1394 # * host_call_fn = (_write_cache(step, caches_to_write)) : fails as it 1395 # considers caches_to_write as a single parameter, rather than a keyword 1396 # parameters. 1397 # * host_call_fn = (_write_cache(step, **caches_to_write)) : fails with 1398 # a syntax error. 1399 self._host_call_fn[_TT_HOSTCALL_KEY] = (_write_cache, caches_to_write) 1400 1401 def host_call_deps_and_fn(self): 1402 return self._host_call_fn 1403 1404 def get_traced_op_names(self): 1405 """Returns the set of traced op names.""" 1406 return self._traced_op_names 1407 1408 def _trace_execution(self, graph, 1409 tensor_fetches, 1410 op_fetches=None, 1411 on_tpu=True): 1412 """Commong tracing function for both CPU and TPUs. 1413 1414 The caller function should set device_type, num_replicas, 1415 num_replicas_per_host, num_hosts and replica_id before calling 1416 _trace_execution. 1417 1418 1419 Args: 1420 graph: the graph of Ops executed on the TPU. 1421 tensor_fetches: a (list,tuple,or a single object) of tensor fetches 1422 returned by model_fn given to session.run. Function must be provided 1423 with as least one tensor to fetch. 1424 op_fetches: A list of op fetches returned by model_fn given to 1425 session.run. op_fetches and tensor_fetches are used to determine the 1426 nodes that will be executed. Can be None. 1427 on_tpu: True if executing on TPU. 1428 1429 Returns: 1430 tensor_fetches: an exact copy of tensor_fetches that has additional 1431 dependencies. 1432 Raises: 1433 RuntimeError: If tensor_fetches is None or empty. 1434 """ 1435 def _cast_unsupported_dtypes(tensor): 1436 """Casts tensor to a supported type.""" 1437 1438 if tensor.dtype.__eq__(dtypes.int64): 1439 # outside-compilation doesn't support int64 input yet. 
1440 return math_ops.cast(tensor, dtypes.int32) 1441 if tensor.dtype.__eq__(dtypes.bfloat16) or tensor.dtype.__eq__( 1442 dtypes.float16): 1443 # Since host can't handle bf16, convert tensor to f32. 1444 return math_ops.cast(tensor, dtypes.float32) 1445 return tensor 1446 1447 trace_mode = self._parameters.trace_mode 1448 device_type = self._tt_config.device_type 1449 1450 analytics.track_usage('tensor_tracer', [trace_mode, device_type]) 1451 TensorTracer.check_device_type(device_type) 1452 TensorTracer.check_trace_mode(device_type, trace_mode) 1453 # Check in_tensor_fetches, and op_fetches and convert them to lists. 1454 processed_t_fetches = self._process_tensor_fetches(tensor_fetches) 1455 op_fetches = self._process_op_fetches(op_fetches) 1456 all_fetches = op_fetches + [tensor.op for tensor in processed_t_fetches] 1457 1458 # Filter out the operations that won't be executed. 1459 # if fetches=None, then ops_in_exec_path = set(operations) 1460 exec_op_set = self._filter_execution_path_operations(graph.get_operations(), 1461 all_fetches) 1462 # Write report file, and determine the traced tensors. 1463 tensor_trace_order = self._determine_trace_and_create_report( 1464 graph, exec_op_set) 1465 1466 tensor_fetch_set = set(processed_t_fetches) 1467 tracing_ops = [] 1468 1469 # pylint: disable=protected-access 1470 current_control_flow_context = graph._get_control_flow_context() 1471 # pylint: enable=protected-access 1472 1473 sorted_exec_op_list = list(exec_op_set) 1474 sorted_exec_op_list.sort(key=lambda op: op.name) 1475 # Trace ops only if they are in the execution path. 1476 for op in sorted_exec_op_list: 1477 for i in range(len(op.outputs)): 1478 out_tensor = op.outputs[i] 1479 tensor_name = out_tensor.name 1480 if tensor_name not in tensor_trace_order.tensorname_to_cache_idx: 1481 continue 1482 self._traced_op_names.add(op.name) 1483 # Create the list of consumers before calling _preprocess_traced_tensor. 
1484 # Otherwise, adding control input below, will introduce a cycle in the 1485 # graph. 1486 consumers = out_tensor.consumers() 1487 # Not all consumers may be in the exec path. Filter out the consumers 1488 # to keep the graph simpler. 1489 consumers = [cop for cop in consumers if cop in exec_op_set] 1490 1491 # If there is no consumer of the tensor, there is no need to trace it; 1492 # unless the tensor itself is one of the fetches. 1493 is_a_fetched_tensor = out_tensor in tensor_fetch_set 1494 if (not consumers) and (not is_a_fetched_tensor): 1495 continue 1496 1497 op_control_flow_context = self._get_op_control_flow_context(op) 1498 if op_control_flow_context: 1499 # pylint: disable=protected-access 1500 graph._set_control_flow_context(op_control_flow_context) 1501 # pylint: enable=protected-access 1502 1503 processed_tensors = self._preprocess_traced_tensor(out_tensor) 1504 1505 if on_tpu: 1506 for signature in processed_tensors.keys(): 1507 processed_tensors[signature] = _cast_unsupported_dtypes( 1508 processed_tensors[signature]) 1509 1510 if self._use_tensor_values_cache(): 1511 # Use a small cache to store the characteristics of the tensor. 1512 cache_idx = tensor_trace_order.tensorname_to_cache_idx[tensor_name] 1513 trace_op = self._save_tensor_value_to_cache_op(cache_idx, 1514 processed_tensors) 1515 elif self._use_tensor_buffer(): 1516 if len(processed_tensors) != 1: 1517 raise RuntimeError('Multiple stats are only allowed in compact ' 1518 'mode.') 1519 processed_out_tensor = processed_tensors.values()[0] 1520 # Store the whole tensor in a buffer. 
1521 trace_op = self._snapshot_tensor(processed_out_tensor) 1522 else: 1523 1524 def tpu_wrap_trace_fn(tensor, out_tensor_name): 1525 """Wraps the trace_fn with outside compilation if on TPUs.""" 1526 tensor_trace_fn = self._make_tensor_trace_fun(out_tensor_name, 1527 tensor_trace_order) 1528 if on_tpu: 1529 return tpu.outside_compilation(tensor_trace_fn, tensor) 1530 else: 1531 return tensor_trace_fn(tensor) 1532 1533 def conditional_trace_fn(predicate_tensor, out_tensor, trace_fn, 1534 out_tensor_name): 1535 """Creates a cond op that traces the out_tensor if predicate is satisfied.""" 1536 return control_flow_ops.cond( 1537 predicate_tensor, lambda: trace_fn(out_tensor, out_tensor_name), 1538 lambda: constant_op.constant(False)).op 1539 1540 if len(processed_tensors) != 1: 1541 raise RuntimeError('Multiple stats are only allowed in compact ' 1542 'mode.') 1543 # Collecting multiple statistics are only supported in the summary 1544 # mode that uses compact format(self._use_tensor_values_cache = true). 1545 # Non-compact mode currently allows single stat per tensor. 
1546 processed_out_tensor = six.next(six.itervalues(processed_tensors)) 1547 1548 if self._parameters.is_conditional_trace: 1549 trace_op = conditional_trace_fn(processed_out_tensor, out_tensor, 1550 tpu_wrap_trace_fn, tensor_name) 1551 elif self._parameters.included_cores: 1552 should_print = constant_op.constant(False) 1553 for core in self._parameters.included_cores: 1554 should_print = gen_math_ops.logical_or( 1555 should_print, gen_math_ops.equal(self._replica_id, core)) 1556 trace_op = conditional_trace_fn(should_print, processed_out_tensor, 1557 tpu_wrap_trace_fn, tensor_name) 1558 1559 else: 1560 trace_op = tpu_wrap_trace_fn(processed_out_tensor, tensor_name) 1561 1562 if op_control_flow_context: 1563 # pylint: disable=protected-access 1564 graph._set_control_flow_context(current_control_flow_context) 1565 # pylint: enable=protected-access 1566 1567 if is_a_fetched_tensor: 1568 tracing_ops.append(trace_op) 1569 continue 1570 # Add it to all consumers, as some consumers may not be executed if they 1571 # are in a control flow. 1572 for consumer_op in consumers: 1573 # pylint: disable=protected-access 1574 consumer_op._add_control_input(trace_op) 1575 # pylint: enable=protected-access 1576 1577 # pylint: disable=protected-access 1578 graph._set_control_flow_context(current_control_flow_context) 1579 # pylint: enable=protected-access 1580 if tracing_ops: 1581 # If we are tracing a fetched tensor, their dependency is stored in 1582 # tracing_ops. 
1583 processed_t_fetches = control_flow_ops.tuple(processed_t_fetches, 1584 control_inputs=tracing_ops) 1585 if self._use_tensor_values_cache() or self._use_tensor_buffer(): 1586 if self._create_host_call(): 1587 self._prepare_host_call_fn(processed_t_fetches, op_fetches) 1588 if not on_tpu: 1589 write_cache, caches_to_write = self._host_call_fn[_TT_HOSTCALL_KEY] 1590 cache_write_op = write_cache(**caches_to_write) 1591 processed_t_fetches = control_flow_ops.tuple( 1592 processed_t_fetches, control_inputs=[cache_write_op]) 1593 del self._host_call_fn[_TT_HOSTCALL_KEY] 1594 else: 1595 processed_t_fetches = self._flush_tensor_values_cache( 1596 processed_t_fetches, op_fetches, on_tpu=on_tpu) 1597 1598 # processed_t_fetches is a list at this point. Convert it to the same 1599 # format as given in tensor_fetches. 1600 return self._convert_fetches_to_input_format(tensor_fetches, 1601 processed_t_fetches) 1602 1603 def trace_tpu(self, graph, 1604 tensor_fetches, 1605 op_fetches=None, 1606 num_replicas=None, 1607 num_replicas_per_host=None, 1608 num_hosts=None): 1609 """Traces the tensors generated by TPU Ops in a TF graph. 1610 1611 Args: 1612 graph: the graph of Ops executed on the TPU. 1613 tensor_fetches: a (list,tuple,or a single object) of tensor fetches 1614 returned by model_fn given to session.run. Function must be provided 1615 with as least one tensor to fetch. 1616 op_fetches: A list of op fetches returned by model_fn given to 1617 session.run. op_fetches and tensor_fetches are used to determine the 1618 nodes that will be executed. Can be None. 1619 num_replicas: number of replicas used on the TPU. 1620 num_replicas_per_host: number of replicas per TPU host. 1621 num_hosts: total number of TPU hosts. 1622 1623 Returns: 1624 tensor_fetches: an exact copy of tensor_fetches that has additional 1625 dependencies. 1626 Raises: 1627 RuntimeError: If num_replicas_per_host > 8. 1628 RuntimeError: If tensor_fetches is None or empty. 
1629 """ 1630 if graph in TensorTracer._traced_graphs: 1631 logging.warning('Graph is already rewritten with tensor tracer, ignoring ' 1632 'multiple calls.') 1633 return tensor_fetches 1634 else: 1635 TensorTracer._traced_graphs.add(graph) 1636 1637 self._tt_config.device_type = _DEVICE_TYPE_TPU 1638 self._tt_config.num_replicas = num_replicas 1639 self._tt_config.num_replicas_per_host = num_replicas_per_host 1640 self._tt_config.num_hosts = num_hosts 1641 if self._tt_config.num_replicas is not None: 1642 if self._tt_config.num_replicas_per_host is None: 1643 self._tt_config.num_replicas_per_host = 8 1644 if self._tt_config.num_hosts is None: 1645 self._tt_config.num_hosts = ( 1646 num_replicas // self._tt_config.num_replicas_per_host + 1647 (num_replicas % self._tt_config.num_replicas_per_host > 0)) 1648 1649 if self._parameters.graph_dump_path: 1650 graph_io.write_graph(graph, self._parameters.graph_dump_path, 1651 'graph_before_tt.pbtxt') 1652 with graph.as_default(): 1653 self._add_replica_id_to_graph() 1654 tensor_fetches = self._trace_execution(graph, tensor_fetches, op_fetches, 1655 on_tpu=True) 1656 if self._parameters.graph_dump_path: 1657 graph_io.write_graph(graph, self._parameters.graph_dump_path, 1658 'graph_after_tt.pbtxt') 1659 return tensor_fetches 1660 1661 def trace_cpu(self, graph, tensor_fetches, op_fetches=None): 1662 """Traces the tensors generated by CPU Ops in a TF graph. 1663 1664 Args: 1665 graph: the graph of Ops executed on the CPU. 1666 tensor_fetches: a (list,tuple,or a single object) of tensor fetches 1667 returned by model_fn given to session.run. Function must be provided 1668 with as least one tensor to fetch. 1669 op_fetches: A list of op fetches returned by model_fn given to 1670 session.run. op_fetches and tensor_fetches are used to determine the 1671 nodes that will be executed. Can be None. 1672 1673 Returns: 1674 tensor_fetches: an exact copy of tensor_fetches that has additional 1675 dependencies. 
1676 Raises: 1677 RuntimeError: If tensor_fetches is None or empty. 1678 """ 1679 1680 if graph in TensorTracer._traced_graphs: 1681 logging.warning('Graph is already rewritten with tensor tracer, ignoring ' 1682 'multiple calls.') 1683 return tensor_fetches 1684 else: 1685 TensorTracer._traced_graphs.add(graph) 1686 1687 self._tt_config.device_type = _DEVICE_TYPE_CPU 1688 self._tt_config.num_replicas = 1 1689 self._tt_config.num_replicas_per_host = 1 1690 self._tt_config.num_hosts = 1 1691 self._replica_id = 0 1692 if self._parameters.graph_dump_path: 1693 graph_io.write_graph(graph, self._parameters.graph_dump_path, 1694 'graph_before_tt.pbtxt') 1695 with graph.as_default(): 1696 tensor_fetches = self._trace_execution(graph, tensor_fetches, op_fetches, 1697 on_tpu=False) 1698 if self._parameters.graph_dump_path: 1699 graph_io.write_graph(graph, self._parameters.graph_dump_path, 1700 'graph_after_tt.pbtxt') 1701 return tensor_fetches 1702