• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ========================================================================
15"""A utility to trace tensor values on TPU."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import operator
22
23import os
24import os.path
25import sys
26
27import numpy as np
28import six
29
30from tensorflow.core.framework import summary_pb2
31from tensorflow.python.framework import constant_op
32from tensorflow.python.framework import dtypes
33from tensorflow.python.framework import graph_io
34from tensorflow.python.framework import ops
35from tensorflow.python.framework import tensor_util
36from tensorflow.python.lib.io import file_io
37from tensorflow.python.ops import array_ops
38from tensorflow.python.ops import control_flow_ops
39from tensorflow.python.ops import control_flow_util
40from tensorflow.python.ops import gen_math_ops
41from tensorflow.python.ops import init_ops
42from tensorflow.python.ops import linalg_ops
43from tensorflow.python.ops import logging_ops
44from tensorflow.python.ops import math_ops
45from tensorflow.python.ops import nn_impl
46from tensorflow.python.ops import state_ops
47from tensorflow.python.ops import summary_ops_v2 as summary
48from tensorflow.python.ops import variable_scope
49from tensorflow.python.platform import analytics
50from tensorflow.python.platform import gfile
51from tensorflow.python.platform import tf_logging as logging
52from tensorflow.python.summary import summary_iterator
53from tensorflow.python.tpu import tensor_tracer_flags
54from tensorflow.python.tpu import tensor_tracer_report
55from tensorflow.python.tpu import tpu
56from tensorflow.python.tpu.ops import tpu_ops
57from tensorflow.python.training import training_util
58
# Device types accepted by TensorTracer.check_device_type.
_DEVICE_TYPE_TPU = 'tpu'
_DEVICE_TYPE_CPU = 'cpu'
# Passed as `summarize` to the print op when only part of a tensor is printed.
_TRACE_MODE_PART_TENSOR_SIZE = 3

# Tags describing why an op/tensor was traced ('traced-*') or skipped
# ('not-traced-*').
_REASON_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range'
_REASON_UNSAFE_OP = 'not-traced-unsafe-op'
_REASON_WHILELOOP_OP = 'not-traced-special-whileloop-op'
_REASON_CONTROLFLOW_OP = 'not-traced-control-flow-op'
_REASON_UNSAFE_SCALAR = 'not-traced-unsafe-scalar'
_REASON_SKIP_SCALAR = 'not-traced-scalar'
_REASON_LESS_INTERESTING_OP = 'not-traced-less-interesting-op'
_REASON_DEVICE_MISMATCH = 'not-traced-device-mismatch'
_REASON_DYNAMIC_SHAPE = 'not-traced-dynamic-shape'
_REASON_SCALAR_GET_TRACED = 'traced-scalar'
_REASON_TENSOR_GET_TRACED = 'traced-tensor'
_REASON_USER_INCLUDED = 'traced-user-included'
_REASON_USER_EXCLUDED = 'not-traced-user-excluded'
_REASON_NOT_EXECUTED = 'not-traced-not-in-exec-path'
_REASON_NON_NUMERIC_TENSOR = 'not-traced-non-numeric-tensor'
_REASON_FEEDS_WHILELOOP_OP = 'not-traced-feeds-special-whileloop-op'

# Prefix prepended to the trace file path to form the print op output stream.
_OUTPUT_STREAM_ESCAPE = 'file://'
# Graph collection that tensor_tracepoint adds (tensor, checkpoint_name) pairs
# to.
_TENSOR_TRACER_COLLECTION = 'tensor_tracer_variables'
# File name joined to the trace directory to form the trace output path.
_TRACE_FILE_NAME = 'trace.all'
_COMPACT_TRACE_FILE_PREFIX = 'compact_trace.'
# Initial value of the entries of a tensor values cache variable.
_COMPACT_TRACE_ENTRY_INIT_VALUE = -1.0
# Collection that the tensor values cache variables are added to.
_TENSOR_TRACER_STORAGE = 'tensor_tracer_storage'
# Name prefix of the tensor values cache variables.
_TT_SNAPSHOT = 'tensor_tracer_snapshot'
_REPLICA_ID_TAG = '#replica-id: '

# Local aliases of the summary signature names defined in tensor_tracer_flags.
_TT_SUMMARY_NORM = tensor_tracer_flags.TT_SUMMARY_NORM
_TT_SUMMARY_MAX = tensor_tracer_flags.TT_SUMMARY_MAX
_TT_SUMMARY_MIN = tensor_tracer_flags.TT_SUMMARY_MIN
_TT_SUMMARY_MEAN = tensor_tracer_flags.TT_SUMMARY_MEAN
_TT_SUMMARY_VAR = tensor_tracer_flags.TT_SUMMARY_VAR
_TT_SUMMARY_SIZE = tensor_tracer_flags.TT_SUMMARY_SIZE

# Also used as the name of the cache that holds the signature updates.
_TT_SUMMARY_TAG = 'tensor_tracer_summary'
_TT_TENSORBOARD_PLUGIN_NAME = 'tensor_tracer'
_TT_HOSTCALL_KEY = 'tensor_tracer_host_call'
_TT_EVENT_FILE_SUFFIX = '.tensor_tracer'

_TT_SUMMARY_MAX_QUEUE = 100
102
103
def op_priority(op_type):
  """Returns the priority of the op.

  An op with priority k is traced only if trace_level >= k, so the higher the
  value, the higher the trace level needed before the op gets traced.

  Args:
    op_type: String name of the operation type.
  Returns:
    Integer priority of the op type, between 0 and 7. Op types that are not
    listed in any of the groups below have priority 0.
  """
  priority_groups = (
      # Lowest priority ops, e.g. ops that are constant across different
      # steps. Traced only if trace_level >= 7.
      (7, ('Const', 'Shape', 'BroadcastGradientArgs', 'Range',
           'VariableShape', 'Fill', 'OneHot', 'ShapeN')),
      # Operations without numerical effects. Traced only if trace_level >= 6.
      (6, ('Identity', 'Cast', 'Reshape', 'ExpandDims', 'StopGradient',
           'PreventGradient', 'Squeeze')),
      # Operations that merge or slice an input. Traced if trace_level >= 5.
      (5, ('ConcatV2', 'Concat', 'StridedSlice', 'Slice', 'Pack', 'Tile',
           'CollectivePermute', 'SplitV')),
      # Operations less likely to provide useful information. Traced if
      # trace_level >= 4.
      (4, ('Pad', 'RandomUniformInt', 'GreaterEqual')),
      # Add operations that are less likely to create any issues. Traced if
      # trace_level >= 3 (default=3).
      (3, ('Sum', 'AddV2', 'Add', 'AddN', 'BiasAdd', 'CrossReplicaSum')),
      # Sub operations that are less likely to create any issues. Traced if
      # trace_level >= 2.
      (2, ('Neg', 'Sub')),
      # Multiplication and some other operations. Traced if trace_level >= 1.
      (1, ('Mul', 'Square', 'MatMul', 'RandomUniform', 'Select',
           'Maximum', 'Mean', 'Variance')),
  )
  for priority, op_types in priority_groups:
    if op_type in op_types:
      return priority
  return 0
145
146
def read_tensor_tracer_event_file(event_file):
  """Reads the event file written by tensor tracer.

  Args:
    event_file: Path to the event file that contains only tensor tracer events.
  Returns:
    An event dictionary in the form of
    {step_number: {tensor_name: tensor_content}}
    where tensor_content is a numpy array reshaped to the shape stored with
    the tensor.
  Raises:
    ValueError: If an event holds a number of summary values other than one.
  """
  events_by_step = {}
  for event in summary_iterator.summary_iterator(event_file):
    # Events without a summary carry no trace data, e.g. the first event,
    # which only holds file_version: "brain.Event:2".
    if not event.HasField('summary'):
      continue
    step_tensors = events_by_step.setdefault(event.step, {})

    summary_values = event.summary.value
    if len(summary_values) != 1:
      raise ValueError('Single step contains %d summary values,'
                       ' expected 1.' % len(summary_values))
    summary_value = summary_values[0]
    tensor_proto = summary_value.tensor

    dims = [dim.size for dim in tensor_proto.tensor_shape.dim]
    numpy_dtype = dtypes.DType(tensor_proto.dtype).as_numpy_dtype()
    flat_content = np.frombuffer(tensor_proto.tensor_content, numpy_dtype)
    step_tensors[summary_value.tag] = flat_content.reshape(dims)
  return events_by_step
180
181
def tensor_tracepoint(tensor, checkpoint_name):
  """Adds a checkpoint with the given checkpoint name for the given tensor.

  The tensor will be added to the list of tensors that will be traced by the
  tensor tracer: the (tensor, checkpoint_name) pair is appended to the tensor
  tracer collection of the graph that owns the tensor.

  Args:
     tensor: the tensor object for which the tracing is requested.
     checkpoint_name: a string name for the checkpoint. This name has to be a
     unique name if used within model comparison. The tensors that have the same
     checkpoint identifier are compared in model comparison.
  Returns:
    The provided tensor.
  """
  # Appending is all that is needed; the collection does not have to be
  # fetched first.
  tensor.graph.add_to_collection(_TENSOR_TRACER_COLLECTION,
                                 (tensor, checkpoint_name))
  return tensor
201
202
def keras_layer_tracepoint(layer, checkpoint_name):
  """An interface for adding the tensor outputs of a keras layer.

  Encapsulates tensor_tracepoint. If layer.output is a single tensor, it is
  registered under checkpoint_name. Otherwise layer.output is iterated and
  every element that is a tensor is registered as '<checkpoint_name>_<idx>',
  where idx is the position of the element; non-tensor elements are skipped.

  Args:
     layer: A keras layer.
     checkpoint_name: a string name for the checkpoint. This name has to be a
     unique name if used within model comparison. The tensors that have the same
     checkpoint identifier are compared in model comparison.

  Returns:
    The provided layer.
  """
  try:
    outputs = layer.output
    if tensor_util.is_tensor(outputs):
      tensor_tracepoint(outputs, '%s' % (checkpoint_name))
    else:
      for idx, output_tensor in enumerate(outputs):
        # Test the element itself. Testing the enclosing `outputs` here would
        # always fail, as this branch is only reached when it is not a tensor.
        if tensor_util.is_tensor(output_tensor):
          tensor_tracepoint(output_tensor, '%s_%d' % (checkpoint_name, idx))
  except (AttributeError, RuntimeError):
    # Adding tracepoints is best effort: on these errors the layer is returned
    # without (further) tracepoints.
    # NOTE(review): presumably raised by layer.output when the layer has no
    # usable output yet -- confirm.
    pass
  return layer
232
233
234def _trace_files_need_precreated(output_dir):
235  """Return True if trace files must be pre-created by users."""
236
237  if not output_dir.startswith('/'):
238    return False
239  if len(output_dir) < 5:
240    return False
241  if output_dir[2] != 'n':
242    return False
243  if output_dir[3] != 's':
244    return False
245  if output_dir[1] != 'c':
246    return False
247  if output_dir[4] != '/':
248    return False
249  return True
250
251
252class TensorTracer(object):
253  """A software construct for tracing tensor values in a TF graph on TPU.
254
255  This utility is disabled by default. It can be enabled by setting
256  the TENSOR_TRACER_FLAGS env variable as:
257    export TENSOR_TRACER_FLAGS="--enable=1"
258  If it is enabled, it will trace the output tensor values of
259  selected Ops in the graph. It has two outputs: (1) the traces and (2)
260  a report. The traces are dumped to a specified local file on the TPU
261  host. The report is printed to the log.info of the TPU job.
262  By passing options via the env variable, users can change:
263     (1) the trace mode (e.g., detecting NaN/Inf, printing partial or
264         full tensor values)
265     (2) which Ops to be traced (via op.name or op.type)
266     (3) output trace file path.
267  """
268  # The set of graphs that are rewritten by tensor tracer.
269  _traced_graphs = set()
270
271  @staticmethod
272  def is_enabled():
273    """Returns True if TensorTracer is enabled."""
274    return tensor_tracer_flags.TTParameters().is_enabled()
275
276  @staticmethod
277  def check_device_type(device_type):
278    """Checks if the given device type is valid."""
279
280    if device_type not in (_DEVICE_TYPE_TPU, _DEVICE_TYPE_CPU):
281      raise ValueError('Invalid device_type "%s"'%device_type)
282
283  @staticmethod
284  def check_trace_mode(device_type, trace_mode):
285    """Checks if the given trace mode work on the given device type.
286
287    Args:
288      device_type: Device type, TPU, GPU, CPU.
289      trace_mode: Tensor tracer trace mode.
290    Raises:
291      ValueError: If the given trace mode is not supported for the device.
292    """
293    if trace_mode == tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY:
294      if device_type != _DEVICE_TYPE_TPU:
295        raise ValueError('Device_type "%s" is not yet supported for '
296                         'trace mode "%s"' % (device_type, trace_mode))
297
298  @staticmethod
299  def loop_cond_op(op):
300    return op.type in ('LoopCond', 'RefLoopCond')
301
302  @staticmethod
303  def while_loop_op(op):
304    """Returns true if op is one of the special ops of in a while loop.
305
306    Args:
307       op: A tf.Operation.
308
309    Returns:
310       True if the given op is one of [Switch, Merge, Enter, Exit,
311       NextIteration, LoopCond], which are all building blocks for TF while
312       loops.
313    """
314    return  (control_flow_util.IsLoopSwitch(op) or
315             control_flow_util.IsLoopMerge(op) or
316             control_flow_util.IsLoopEnter(op) or
317             control_flow_util.IsLoopExit(op) or
318             TensorTracer.loop_cond_op(op) or
319             op.type in ('RefNextIteration', 'NextIteration'))
320
321  @staticmethod
322  def control_flow_op(op):
323    """Returns true if op is one of the special ops of in a while loop.
324
325    Args:
326       op: A tf.Operation.
327
328    Returns:
329       True if the given op is one of [Switch, Merge, Enter, Exit,
330       NextIteration, LoopCond], which are all building blocks for TF while
331       loops.
332    """
333    return  (control_flow_util.IsSwitch(op) or
334             control_flow_util.IsMerge(op))
335
336  @staticmethod
337  def unsafe_op(op):
338    """Returns True if this op is not safe to be traced."""
339
340    if control_flow_util.IsInCond(op):
341      return True
342    # Reasons for not including following op types:
343    #    Assign: cause incorrect result with CPU tracing.
344    if op.type == 'Assign':
345      return True
346    return False
347
348  @staticmethod
349  def device_mismatch(device_type, op):
350    if device_type == _DEVICE_TYPE_TPU:
351      # pylint: disable=protected-access
352      return tpu._TPU_REPLICATE_ATTR not in op.node_def.attr
353      # pylint: enable=protected-access
354    return False
355
356  @staticmethod
357  def unsafe_scalar_trace(op):
358    """Return true if scalar output tensor from Op is not safe to be traced."""
359
360    # Tracing the following causes cycle in the graph on TPU.
361    if op.type in ('LoopCond', 'Enter', 'Merge', 'Const',
362                   'Switch', 'Less', 'ReadVariableOp'):
363      return True
364    # Tracing the following will cause casting-issue
365    # with the norm tracing mode or other compilation issues on CPU.
366    if op.type in ('VarHandleOp', 'IteratorToStringHandle',
367                   'IteratorGetNext', 'OneShotIterator',
368                   'IteratorV2', 'MakeIterator',
369                   'BatchDatasetV2', 'MapDataset',
370                   'FixedLengthRecordDataset', 'TakeDataset', 'ZipDataset',
371                   'Placeholder', 'PlaceholderWithDefault', 'StridedSlice'):
372      return True
373    return False
374
375  def _is_interesting_op(self, op):
376    """Returns True if the given op is not an interesting one to be traced."""
377    # If flag is set to include less interesting ops, then include everything.
378    if self._parameters.include_less_interesting_ops:
379      return True
380    return op_priority(op.type) <= self._parameters.trace_level
381
382  @staticmethod
383  def reason(op_idx, details):
384    """Returns reason why the Op at op_idx is traced or not."""
385
386    return '%d %s'%(op_idx, details)
387
388  def __init__(self):
389    """Initializes a TensorTracer.
390
391    Sets the various member fields from the flags (if given) or the defaults.
392    """
393    self._replica_id = None
394    self._tt_config = tensor_tracer_report.TensorTracerConfig()
395    self._parameters = tensor_tracer_flags.TTParameters()
396    self._included_op_full_names = set()
397    self._host_call_fn = {}
398    self._cache_variables = {}
399    self._traced_op_names = set()
400
401  def _get_all_cache_variables(self):
402    return self._cache_variables
403
404  def _create_or_get_tensor_values_cache(self, cache_name, graph=None,
405                                         shape=None, dtype=dtypes.float32):
406    """Creates a variable as the cache to store intermediate tensor values.
407
408    Args:
409      cache_name: Name to be given to the cache (an instance of tf.variable).
410      graph: Tensorflow graph.
411      shape: A list of dimensions.
412      dtype: Data type of created cache.
413    Returns:
414      A ref to newly created or existing cache with the given dimensions.
415    Raises:
416      ValueError: If missing a parameter to create the cache.
417    """
418    def _escape_namescopes(variable_name):
419      # TODO(deveci): This might cause name collisions as in "foo/bar/mytensor"
420      # and "foo_bar/mytensor".
421      return variable_name.replace('/', '_').replace(':', '_')
422
423    if cache_name not in self._cache_variables:
424      if graph is None:
425        raise ValueError('Graph must be provided at cache creation.')
426      if shape is None:
427        raise ValueError('shape must be provided at cache creation.')
428      graph = graph or ops.get_default_graph()
429      if dtype.is_integer:
430        init_val = int(_COMPACT_TRACE_ENTRY_INIT_VALUE)
431      else:
432        init_val = _COMPACT_TRACE_ENTRY_INIT_VALUE
433
434      # Create in proper graph and base name_scope.
435      with graph.as_default() as g, g.name_scope(None):
436        self._cache_variables[cache_name] = variable_scope.get_variable(
437            _TT_SNAPSHOT + '_' + _escape_namescopes(cache_name),
438            shape=shape, dtype=dtype,
439            initializer=init_ops.constant_initializer(init_val),
440            trainable=False,
441            use_resource=True,
442            collections=[_TENSOR_TRACER_STORAGE, ops.GraphKeys.LOCAL_VARIABLES])
443    return self._cache_variables[cache_name]
444
445  def _add_replica_id_to_graph(self):
446    """Adds nodes for computing the replica ID to the graph."""
447
448    if self._tt_config.num_replicas:
449      with ops.control_dependencies(None):
450        # Uses None as dependency to run outside of TPU graph rewrites.
451        self._replica_id = tpu_ops.tpu_replicated_input(
452            list(range(self._tt_config.num_replicas)),
453            name='tt_replica_id')
454    else:
455      self._replica_id = 'unknown'
456
457  def _inside_op_range(self, idx):
458    """Return True if the given index is inside the selected range."""
459
460    if idx < self._parameters.op_range[0]:
461      return False
462    return (self._parameters.op_range[1] < 0 or
463            idx <= self._parameters.op_range[1])
464
465  def _is_user_included_op(self, op):
466    """Checks whether the op is included in the tensor tracer flags.
467
468    Args:
469      op: tf Operation
470    Returns:
471      True, if the op is included.
472      An op is included if:
473      - Its op name is given in included_opnames
474      - Its op type is given in included_optypes
475      - The op is at most _trace_ops_before_included hops before an included op
476      - The op is at most _trace_ops_after_included hops after an included op
477    """
478
479    def _is_op_or_any_neighbor_included(op, check_before=0, check_after=0):
480      """Helper function to check if op is included or not."""
481      if op.name in self._included_op_full_names:
482        return True
483      for opname_re in self._parameters.included_opname_re_list:
484        if opname_re.match(op.name):
485          self._included_op_full_names.add(op.name)
486          return True
487
488      for optype_re in self._parameters.included_optype_re_list:
489        if optype_re.match(op.type):
490          self._included_op_full_names.add(op.name)
491          return True
492
493      if check_after > 0:
494        for out_tensor in op.outputs:
495          for consumer in out_tensor.consumers():
496            if _is_op_or_any_neighbor_included(consumer, check_after - 1, 0):
497              self._included_op_full_names.add(op.name)
498              return True
499      if check_before > 0:
500        for input_tensor in op.inputs:
501          if _is_op_or_any_neighbor_included(input_tensor.op,
502                                             0,
503                                             check_before - 1):
504            self._included_op_full_names.add(op.name)
505            return True
506      return False
507    # check_after and check_before are swapped below, as below operation
508    # checks the distance from an arbitrary op to included ops.
509    return _is_op_or_any_neighbor_included(
510        op, self._parameters.trace_ops_after_included,
511        self._parameters.trace_ops_before_included)
512
513  def _is_user_excluded_op(self, op):
514    for opname_re in self._parameters.excluded_opname_re_list:
515      if opname_re.match(op.name):
516        return True
517    for optype_re in self._parameters.excluded_optype_re_list:
518      if optype_re.match(op.type):
519        return True
520    return False
521
522  def _signature_types(self):
523    """Returns a dictionary holding the order of signatures in the cache for the selected trace mode."""
524    if self._parameters.trace_mode in set([
525        tensor_tracer_flags.TRACE_MODE_NAN_INF,
526        tensor_tracer_flags.TRACE_MODE_NORM,
527        tensor_tracer_flags.TRACE_MODE_MAX_ABS]):
528      return {self._parameters.trace_mode: 0}
529    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY:
530      return self._parameters.summary_signatures
531    return {}
532
533  def _num_signature_dimensions(self):
534    return len(self._signature_types())
535
536  def _use_tensor_values_cache(self):
537    """Returns True if immediate tensors should be first saved to a cache."""
538    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY:
539      # For summary tace mode only compact format is supported.
540      return True
541
542    if self._parameters.trace_mode not in set([
543        tensor_tracer_flags.TRACE_MODE_NAN_INF,
544        tensor_tracer_flags.TRACE_MODE_NORM,
545        tensor_tracer_flags.TRACE_MODE_MAX_ABS,
546        tensor_tracer_flags.TRACE_MODE_SUMMARY
547    ]):
548      return False
549    if (self._parameters.trace_dir and
550        _trace_files_need_precreated(self._parameters.trace_dir)):
551      return True
552    return self._parameters.use_compact_trace
553
554  def _use_tensor_buffer(self):
555    """Returns true if the whole tensor needs to be cached/buffered in memory."""
556    return (self._parameters.trace_mode ==
557            tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY)
558
559  def _save_tensor_value_to_cache_op(self, cache_idx, updates):
560    """Returns an op that will save the given updates to an entry in the cache.
561
562    Args:
563      cache_idx: The cache index of the tensor within the cache.
564      updates: A dictionary of the signature updates.
565    Returns:
566      Cache update operation.
567    """
568    # state_ops.scatter_update allows updates only along the first dimension.
569    # Make a compact array by concantating different signatures, and update
570    # them all together.
571    sorted_update = []
572    if self._num_signature_dimensions() > 1:
573      signature_indices = self._signature_types()
574      for _, val in sorted(updates.items(),
575                           key=lambda item: signature_indices[item[0]]):
576        sorted_update.append(val)
577      updates = array_ops.stack(sorted_update, axis=0)
578      updates = array_ops.reshape(updates, [1,
579                                            self._num_signature_dimensions()])
580    else:
581      (_, val), = updates.items()
582      updates = array_ops.reshape(val, [1, self._num_signature_dimensions()])
583    indices = constant_op.constant([cache_idx])
584    cache = self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG)
585    return state_ops.scatter_update(cache, indices, updates).op
586
587  def _snapshot_tensor(self, tensor):
588    """Creates a new tf.Variable and a new tf.Operation that assigns the value of the tensor to this variable.
589
590    Args:
591      tensor: tensor whose values will be stored in a new tf.Variable.
592    Returns:
593      An assignment operation.
594    """
595
596    snapshot_variable = self._create_or_get_tensor_values_cache(
597        tensor.name, tensor.op.graph,
598        tensor.shape.as_list(), tensor.dtype)
599    return state_ops.assign(snapshot_variable, tensor).op
600
601  def _preprocess_traced_tensor(self, tensor):
602    """Computes NAN/Norm/Max on TPUs before sending to CPU.
603
604    Args:
605      tensor: The tensor to be traced.
606    Returns:
607      A tensor that should be input to the trace_function.
608    Raises:
609      RuntimeError: If the trace mode is invalid.
610    """
611
612    def _detect_nan_inf(tensor):
613      """Trace function for detecting any NaN/Inf in the tensor."""
614
615      if tensor.dtype.is_floating:
616        mask = math_ops.reduce_any(
617            gen_math_ops.logical_or(
618                gen_math_ops.is_nan(tensor), gen_math_ops.is_inf(tensor)))
619        output_tensor = control_flow_ops.cond(
620            mask,
621            lambda: constant_op.constant([1.0]),
622            lambda: constant_op.constant([0.0]))
623      else:
624        output_tensor = constant_op.constant([0.0])
625      return output_tensor
626
627    def _compute_signature(tensor, tf_op, cast_to_f32=True):
628      if cast_to_f32:
629        tensor = math_ops.cast(tensor, dtypes.float32)
630      output_tensor = tf_op(tensor)
631      # Return type should be scalar. Set it if it does not have the
632      # information.
633      if not output_tensor.get_shape().is_fully_defined():
634        output_tensor = array_ops.reshape(output_tensor, [])
635      return output_tensor
636
637    def _show_size(tensor):
638      # In order to check the size of a tensor.
639      # Not all sizes are known at the compile time, also, different replicas
640      # sometimes get different sizes of tensors.
641      # Collect it here to be used in merging replica data.
642      tsize = _compute_signature(tensor, array_ops.size, cast_to_f32=False)
643      # Cast to float32, so that it can be placed into same cache with other
644      # signatures.
645      return math_ops.cast(tsize, dtypes.float32)
646
647    def _show_max(tensor, cast_to_f32=True):
648      # returns -inf for empty tensor
649      return _compute_signature(tensor, math_ops.reduce_max, cast_to_f32)
650
651    def _show_min(tensor, cast_to_f32=True):
652      # returns inf for empty tensor
653      return _compute_signature(tensor, math_ops.reduce_min, cast_to_f32)
654
655    def _show_norm(tensor, cast_to_f32=True):
656      # returns 0 for empty tensor
657      return _compute_signature(tensor, linalg_ops.norm, cast_to_f32)
658
659    def _show_mean_and_variance(tensor, cast_to_f32=True):
660      """Returns the mean and variance of the given tensor."""
661      if cast_to_f32:
662        tensor = math_ops.cast(tensor, dtypes.float32)
663      # returns nan for empty tensor
664      mean, var = nn_impl.moments(array_ops.reshape(tensor, [-1]), axes=[0])
665      # The shape has to be 1. Set it if it does not have the information.
666      if not mean.get_shape().is_fully_defined():
667        mean = array_ops.reshape(mean, [])
668      if not var.get_shape().is_fully_defined():
669        var = array_ops.reshape(var, [])
670      return mean, var
671
672    def _show_max_abs(tensor):
673      tensor = math_ops.cast(tensor, dtypes.float32)
674      output_tensor = math_ops.reduce_max(math_ops.abs(tensor))
675      zero = constant_op.constant(0, dtypes.float32)
676      output_tensor = gen_math_ops.maximum(zero, output_tensor)
677      # The shape has to be 1. Set it if it does not have the information.
678      output_tensor = array_ops.reshape(output_tensor, [1])
679      return output_tensor
680
681    def _detect_inf_nan_producer(tensor):
682      """Checks if the tensor is the first NaN/Inf tensor in the computation path."""
683      if tensor.op.inputs:
684        inp_check = [
685            _detect_nan_inf(inp_tensor) for inp_tensor in tensor.op.inputs
686        ]
687        is_any_input_inf_nan = math_ops.add_n(inp_check)
688      else:
689        is_any_input_inf_nan = constant_op.constant(0, dtypes.bool)
690      is_current_tensor_inf_nan = _detect_nan_inf(tensor)
691      # An op is NaN/INF producer only when all inputs are nan/inf free (
692      # is_any_input_inf_nan = 0), and its output has nan/inf (
693      # is_current_tensor_inf_nan=1). Below will be 1 if op nan/inf is producer.
694      is_nan_producer = is_current_tensor_inf_nan - is_any_input_inf_nan
695      is_nan_producer = math_ops.reduce_any(is_nan_producer > 0)
696      return is_nan_producer
697
698    if (self._parameters.trace_mode ==
699        tensor_tracer_flags.TRACE_MODE_FULL_IF_NAN):
700      return {self._parameters.trace_mode: _detect_inf_nan_producer(tensor)}
701    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NAN_INF:
702      return {self._parameters.trace_mode: _detect_nan_inf(tensor)}
703    if (self._parameters.trace_mode ==
704        tensor_tracer_flags.TRACE_MODE_PART_TENSOR):
705      return {self._parameters.trace_mode: tensor}
706    if (self._parameters.trace_mode in (
707        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR,
708        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY)):
709      return {self._parameters.trace_mode: tensor}
710    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NORM:
711      return {self._parameters.trace_mode: array_ops.reshape(
712          _show_norm(tensor), [1])}
713    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_MAX_ABS:
714      return {self._parameters.trace_mode: _show_max_abs(tensor)}
715
716    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY:
717      tensor = math_ops.cast(tensor, dtypes.float32)
718      result_dict = {}
719      # Call mean and variance computation here to avoid adding the same nodes
720      # twice.
721      if (_TT_SUMMARY_MEAN in self._signature_types() or
722          _TT_SUMMARY_VAR in self._signature_types()):
723        mean, variance = _show_mean_and_variance(tensor, cast_to_f32=False)
724
725      for signature_name, _ in sorted(self._signature_types().items(),
726                                      key=lambda x: x[1]):
727        if signature_name == _TT_SUMMARY_NORM:
728          signature_result_tensor = _show_norm(tensor, cast_to_f32=False)
729        elif signature_name == _TT_SUMMARY_MAX:
730          signature_result_tensor = _show_max(tensor, cast_to_f32=False)
731        elif signature_name == _TT_SUMMARY_MIN:
732          signature_result_tensor = _show_min(tensor, cast_to_f32=False)
733        elif signature_name == _TT_SUMMARY_SIZE:
734          signature_result_tensor = _show_size(tensor)
735        elif signature_name == _TT_SUMMARY_MEAN:
736          signature_result_tensor = mean
737        elif signature_name == _TT_SUMMARY_VAR:
738          signature_result_tensor = variance
739        else:
740          raise ValueError('Unknown signature type :%s.' % signature_name)
741
742        result_dict[signature_name] = signature_result_tensor
743      return result_dict
744
745    raise RuntimeError(
746        'Tensor trace fun for %s is not yet implemented'
747        % self._parameters.trace_mode)
748
749  def _make_tensor_trace_fun(self, tensor_name, tensor_trace_order):
750    """Makes the tensor tracing function called by outside compilation.
751
752    Args:
753      tensor_name: name of the tensor being traced.
754      tensor_trace_order: TensorTraceOrder object holding tensorname to id map.
755    Returns:
756      A function to be passed as the first argument to outside compilation.
757
758    Raises:
759      RuntimeError: If the trace mode is invalid.
760    """
761
762    def _print_tensor(tensor_name, num_elements, tensor, output_tensor):
763      """Prints a tensor value to a file.
764
765      Args:
766        tensor_name: name of the tensor being traced.
767        num_elements: number of elements to print (-1 means print all).
768        tensor: the tensor needs to be returned.
769        output_tensor: the tensor needs to be printed.
770
771      Returns:
772        The same tensor passed via the "tensor" argument.
773
774      Raises:
775        ValueError: If tensor_name is not already in
776                    self._tensorname_idx_map.
777      """
778
779      if self._parameters.is_brief_mode():
780        if tensor_name not in tensor_trace_order.tensorname_idx_map:
781          raise ValueError(
782              'Tensor name %s is not in the tensorname_idx_map'%tensor_name)
783        msg = '%d'%self._tensorname_idx_map[tensor_name]
784      else:
785        msg = '"%s"'%tensor_name
786
787      if self._parameters.trace_dir:
788        output_path = os.path.join(self._parameters.trace_dir, _TRACE_FILE_NAME)
789        output_stream = _OUTPUT_STREAM_ESCAPE + output_path
790      else:
791        output_stream = sys.stderr
792      return logging_ops.print_v2(msg, array_ops.shape(output_tensor),
793                                  '@', self._replica_id,
794                                  '\n', output_tensor, '\n',
795                                  summarize=num_elements,
796                                  output_stream=output_stream)
797
798    def _show_part_tensor(tensor):
799      """Trace function for printing part of the tensor."""
800
801      return _print_tensor(tensor_name, _TRACE_MODE_PART_TENSOR_SIZE,
802                           tensor, tensor)
803
804    def _show_full_tensor(tensor):
805      """Trace function for printing the entire tensor."""
806
807      return _print_tensor(tensor_name, -1, tensor, tensor)
808
809    def _show_full_tensors(tensor):
810      """Prints the full tensor values for the tensors that are _trace_stack_size hops away from a given tensor."""
811
812      def _get_distance_k_tensors(k_before=0):
813        """Returns the tensors that are at most k_before hops away from the tensor."""
814        if k_before < 0:
815          return []
816        visited_tensors = {tensor: 0}
817        visitor_queue = [tensor]
818        head = 0
819        while head < len(visitor_queue):
820          current_tensor = visitor_queue[head]
821          head += 1
822          distance = visited_tensors[current_tensor]
823          if distance == k_before:
824            break
825          for input_tensor in current_tensor.op.inputs:
826            if input_tensor in visited_tensors:
827              continue
828            visitor_queue.append(input_tensor)
829            visited_tensors[input_tensor] = distance + 1
830        return visitor_queue
831
832      tensors_to_print = _get_distance_k_tensors(
833          self._parameters.trace_stack_size)
834      print_ops = [_print_tensor(t.name, -1, t, t) for t in tensors_to_print]
835      with ops.control_dependencies(print_ops):
836        return constant_op.constant(True)
837
838    if (self._parameters.trace_mode ==
839        tensor_tracer_flags.TRACE_MODE_FULL_IF_NAN):
840      return _show_full_tensors
841    if (self._parameters.trace_mode ==
842        tensor_tracer_flags.TRACE_MODE_PART_TENSOR):
843      return _show_part_tensor
844    # The input tensor has a shape of "[1]" for TRACE_MODE_NAN_INF,
845    # TRACE_MODE_NORM, and TRACE_MODE_MAX_ABS, as related computations are
846    # performed within TPUs and only their results are transferred to CPU.
847    # Simply, print the full tensor for these trace modes.
848    if self._parameters.trace_mode in (
849        tensor_tracer_flags.TRACE_MODE_NAN_INF,
850        tensor_tracer_flags.TRACE_MODE_NORM,
851        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR,
852        tensor_tracer_flags.TRACE_MODE_MAX_ABS,
853        tensor_tracer_flags.TRACE_MODE_SUMMARY
854        ):
855      return _show_full_tensor
856
857    raise RuntimeError('Tensor trace fun for %s is not yet implemented'
858                       %self._parameters.trace_mode)
859
860  def _skip_op(self, op_id, op, ops_in_exec_path, report_handler):
861    """Returns True if we should not trace Op.
862
863    Args:
864      op_id: Topological index of the op.
865      op: tf.Operation
866      ops_in_exec_path: Set of operations that are in the execution path.
867      report_handler: An instance of tensor_tracer_report.TTReportHandle.
868    Returns:
869      True if the op should not be traced, false otherwise.
870    """
871    if TensorTracer.while_loop_op(op):
872      report_handler.instrument_op(
873          op, TensorTracer.reason(op_id, _REASON_WHILELOOP_OP))
874      return True
875    if TensorTracer.control_flow_op(op):
876      report_handler.instrument_op(
877          op, TensorTracer.reason(op_id, _REASON_CONTROLFLOW_OP))
878      return True
879    if TensorTracer.unsafe_op(op):
880      report_handler.instrument_op(
881          op, TensorTracer.reason(op_id, _REASON_UNSAFE_OP))
882      return True
883    if TensorTracer.device_mismatch(self._tt_config.device_type, op):
884      report_handler.instrument_op(
885          op, TensorTracer.reason(op_id, _REASON_DEVICE_MISMATCH))
886      return True
887    if op not in ops_in_exec_path:
888      report_handler.instrument_op(
889          op, TensorTracer.reason(op_id, _REASON_NOT_EXECUTED))
890      return True
891
892    if self._is_user_included_op(op):
893      report_handler.instrument_op(
894          op, TensorTracer.reason(op_id, _REASON_USER_INCLUDED))
895      return False
896
897    if not self._inside_op_range(op_id):
898      report_handler.instrument_op(
899          op, TensorTracer.reason(op_id, _REASON_OUTSIDE_OP_RANGE))
900      return True
901    if not self._is_interesting_op(op):
902      report_handler.instrument_op(
903          op, TensorTracer.reason(op_id, _REASON_LESS_INTERESTING_OP))
904      return True
905    if self._is_user_excluded_op(op):
906      report_handler.instrument_op(
907          op, TensorTracer.reason(op_id, _REASON_USER_EXCLUDED))
908      return True
909    return False
910
911  def _skip_tensor(self, op_id, out_tensor, report_handler):
912    """Returns True if we should not trace out_tensor.
913
914    Args:
915      op_id: Topological index of the op producing tensor.
916      out_tensor: tf.Tensor
917      report_handler: An instance of tensor_tracer_report.TTReportHandle.
918    Returns:
919      True if the tensor should not be traced, false otherwise.
920    """
921
922    # Skips a tensor if the tensor has a non-numeric type.
923    #   Note: we cannot use check_ops.is_numeric_tensor(out_tensor)
924    #         because it also excludes tensors with dtypes, bool, and
925    #         float32_ref, which we actually want to trace.
926    non_numeric_tensor_types = set([dtypes.variant, dtypes.resource,
927                                    dtypes.string])
928    if out_tensor.dtype in non_numeric_tensor_types:
929
930      report_handler.instrument_tensor(
931          out_tensor, TensorTracer.reason(op_id, _REASON_NON_NUMERIC_TENSOR))
932      return True
933    # Skip a tensor if it feeds a special while loop op.
934    if [consumer for consumer in out_tensor.consumers() if
935        TensorTracer.while_loop_op(consumer)]:
936      report_handler.instrument_tensor(
937          out_tensor, TensorTracer.reason(op_id, _REASON_FEEDS_WHILELOOP_OP))
938      return True
939    if self._is_user_included_op(out_tensor.op):
940      report_handler.instrument_tensor(
941          out_tensor, TensorTracer.reason(op_id, _REASON_USER_INCLUDED))
942      return False
943    if self._is_user_excluded_op(out_tensor.op):
944      report_handler.instrument_tensor(
945          out_tensor, TensorTracer.reason(op_id, _REASON_USER_EXCLUDED))
946      return True
947    if not out_tensor.get_shape().is_fully_defined():
948      # If trace mode is nan-inf, norm or max, then the tensor will be reduced
949      # to a scalar before the outside compilation call.
950      if self._parameters.trace_mode in (
951          tensor_tracer_flags.TRACE_MODE_NAN_INF,
952          tensor_tracer_flags.TRACE_MODE_NORM,
953          tensor_tracer_flags.TRACE_MODE_MAX_ABS,
954          tensor_tracer_flags.TRACE_MODE_SUMMARY
955          ):
956        report_handler.instrument_tensor(
957            out_tensor, TensorTracer.reason(op_id, _REASON_TENSOR_GET_TRACED))
958        return False
959      else:
960        report_handler.instrument_tensor(
961            out_tensor, TensorTracer.reason(op_id, _REASON_DYNAMIC_SHAPE))
962        return True
963    rank = len(out_tensor.shape)
964    if rank < 1:
965      # scalar
966      if self._parameters.trace_scalar_ops:
967        if TensorTracer.unsafe_scalar_trace(out_tensor.op):
968          report_handler.instrument_tensor(
969              out_tensor, TensorTracer.reason(op_id, _REASON_UNSAFE_SCALAR))
970          return True
971        else:
972          report_handler.instrument_tensor(
973              out_tensor, TensorTracer.reason(op_id, _REASON_SCALAR_GET_TRACED))
974          return False
975      else:
976        report_handler.instrument_tensor(
977            out_tensor, TensorTracer.reason(op_id, _REASON_SKIP_SCALAR))
978        return True
979    else:
980      # tensor
981      report_handler.instrument_tensor(
982          out_tensor, TensorTracer.reason(op_id, _REASON_TENSOR_GET_TRACED))
983      return False
984
985  def _filter_execution_path_operations(self, operations, fetches):
986    """Returns the set of ops in the execution path to compute given fetches."""
987
988    # If no fetch provided, then return all operations.
989    if fetches is None:
990      return set(operations)
991    # Convert to list, if a single element is provided.
992    if not isinstance(fetches, (list, tuple)):
993      fetches = [fetches]
994    # If a tensor is given as fetch, convert it to op.
995    op_fetches = []
996    for fetch in fetches:
997      if isinstance(fetch, ops.Operation):
998        op_fetches.append(fetch)
999      elif isinstance(fetch, ops.Tensor):
1000        op_fetches.append(fetch.op)
1001      else:
1002        raise RuntimeError('Given fetch:%s is neither a tensor nor an op.'
1003                           %fetch)
1004
1005    execution_path_operations = set(op_fetches)
1006    traverse_stack = list(op_fetches)
1007    while True:
1008      if not traverse_stack:
1009        break
1010      head_op = traverse_stack.pop()
1011      input_ops = [tensor_input.op for tensor_input in head_op.inputs]
1012      input_ops.extend(head_op.control_inputs)
1013
1014      for input_op in input_ops:
1015        if input_op not in execution_path_operations:
1016          # Filter out loop condition operations, tracing them causes a cycle.
1017          # Trace only the loop-body.
1018          if TensorTracer.loop_cond_op(input_op):
1019            continue
1020          execution_path_operations.add(input_op)
1021          traverse_stack.append(input_op)
1022    return execution_path_operations
1023
1024  def _determine_and_instrument_traced_tensors(self, graph_order,
1025                                               ops_in_exec_path,
1026                                               tensor_trace_points,
1027                                               report_handler):
1028    """Determines the tensors to trace and instruments the trace details.
1029
1030    Args:
1031      graph_order: graph_order tuple containing graph (tf.graph), operations
1032        (list of operations), op_to_idx (op id mapping), (tensors) list of
1033        tensors, tensor_to_idx (tensor id mapping), contains_cycle (whether
1034        there is a cycle in the graph), topological_order_or_cycle (list of ops
1035        in topological order or list of ops creating a cycle).
1036      ops_in_exec_path: Set of ops in the execution path.
1037      tensor_trace_points: Collection of programatic tensor trace points.
1038      report_handler: An instance of tensor_tracer_report.TTReportHandle.
1039    Returns:
1040      List of tensors to be traced.
1041    """
1042
1043    traced_tensors = []
1044    checkpoint_operations = set([tensor.op
1045                                 for (tensor, _) in tensor_trace_points])
1046    for op_id, op in enumerate(graph_order.operations):
1047      if checkpoint_operations and op not in checkpoint_operations:
1048        continue
1049      if self._skip_op(op_id, op, ops_in_exec_path, report_handler):
1050        continue
1051      for i in range(len(op.outputs)):
1052        out_tensor = op.outputs[i]
1053        if not self._skip_tensor(op_id, out_tensor, report_handler):
1054          traced_tensors.append(out_tensor)
1055    return traced_tensors
1056
1057  def _check_trace_files(self):
1058    """Checks if any requirements for trace files are satisfied."""
1059
1060    if not self._parameters.trace_dir:
1061      # traces will be written to stderr. No need to check trace files.
1062      return
1063    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY:
1064      # Output files are handled by tf.summary operations, no need to precreate
1065      # them.
1066      return
1067    if _trace_files_need_precreated(self._parameters.trace_dir):
1068      for replica_id in range(0, self._tt_config.num_replicas):
1069        trace_file_path = os.path.join(
1070            self._parameters.trace_dir,
1071            _COMPACT_TRACE_FILE_PREFIX) + '%d'%replica_id
1072        if not gfile.Exists(trace_file_path):
1073          raise RuntimeError(
1074              '%s must be pre-created with the '
1075              'appropriate properties.'%trace_file_path)
1076    else:
1077      if not gfile.Exists(self._parameters.trace_dir):
1078        file_io.recursive_create_dir(self._parameters.trace_dir)
1079        if not gfile.Exists(self._parameters.trace_dir):
1080          raise RuntimeError('Failed to create %s'%self._parameters.trace_dir)
1081
1082  def _determine_trace_and_create_report(self, graph, ops_in_exec_path):
1083    """Work needs to be done prior to TPU or CPU tracing.
1084
1085    Args:
1086      graph: tf.graph
1087      ops_in_exec_path: Set of operations in the execution path.
1088    Returns:
1089      An instance of tensor_tracer_report.TensorTraceOrder, containing list of
1090      tensors to be traced with their topological order information.
1091    """
1092
1093    self._check_trace_files()
1094
1095    graph_order = tensor_tracer_report.sort_tensors_and_ops(graph)
1096    tensor_trace_points = graph.get_collection(_TENSOR_TRACER_COLLECTION)
1097
1098    report_handler = tensor_tracer_report.TTReportHandle()
1099    traced_tensors = self._determine_and_instrument_traced_tensors(
1100        graph_order, ops_in_exec_path, tensor_trace_points, report_handler)
1101
1102    tensor_trace_order = tensor_tracer_report.TensorTraceOrder(graph_order,
1103                                                               traced_tensors)
1104    num_signatures = self._num_signature_dimensions()
1105    if num_signatures:
1106      self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG,
1107                                              graph,
1108                                              [len(traced_tensors),
1109                                               num_signatures])
1110    if self._parameters.trace_mode in (
1111        tensor_tracer_flags.TRACE_MODE_SUMMARY,
1112        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY):
1113      report_proto = report_handler.create_report_proto(self._tt_config,
1114                                                        self._parameters,
1115                                                        tensor_trace_order,
1116                                                        tensor_trace_points,
1117                                                        self._signature_types())
1118      report_handler.write_report_proto(report_proto, self._parameters)
1119    else:
1120      report_handler.create_report(self._tt_config, self._parameters,
1121                                   tensor_trace_order, tensor_trace_points)
1122    return tensor_trace_order
1123
1124  def _create_host_call(self):
1125    return self._parameters.trace_mode in (
1126        tensor_tracer_flags.TRACE_MODE_SUMMARY,
1127        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY)
1128
1129  def _generate_flush_cache_op(self, num_replicas, on_tpu):
1130    """Generates an Op that will flush the cache to file.
1131
1132    Args:
1133      num_replicas: total number of replicas.
1134      on_tpu: if the graph is executed on TPU.
1135
1136    Returns:
1137      The Op to flush the cache to file.
1138    """
1139
1140    def _flush_fun(cache, replica_id):
1141      """Flushes the cache to a file corresponding to replica_id."""
1142
1143      def _f(file_index):
1144        """Generates a func that flushes the cache to a file."""
1145        def _print_cache():
1146          """Flushes the cache to a file."""
1147          replica_str = ('%d' % file_index)
1148          if self._parameters.trace_dir:
1149            output_path = (os.path.join(self._parameters.trace_dir,
1150                                        _COMPACT_TRACE_FILE_PREFIX)
1151                           + replica_str)
1152            output_stream = _OUTPUT_STREAM_ESCAPE + output_path
1153          else:
1154            output_stream = sys.stderr
1155
1156          new_step_line = _REPLICA_ID_TAG + replica_str
1157          print_ops = []
1158          for i in range(self._num_signature_dimensions()):
1159            print_ops.append(logging_ops.print_v2(
1160                new_step_line, '\n',
1161                cache[:, i], '\n',
1162                summarize=-1,
1163                output_stream=output_stream))
1164          with ops.control_dependencies(print_ops):
1165            return constant_op.constant(0).op
1166        return _print_cache
1167
1168      def _eq(file_index):
1169        return math_ops.equal(replica_id, file_index)
1170
1171      flush_op_cases = {}
1172      for i in range(num_replicas):
1173        flush_op_cases[_eq(i)] = _f(i)
1174      # Each replica needs to determine where to write their output.
1175      # To do this, we check if replica_id is 0, then 1, ..., and then
1176      # num_replicas - 1 statically; and return the corresponding static file
1177      # name. We cannot simply set the file name in python, as replica_id is
1178      # only known during tf runtime, and we cannot create dynamic filenames.
1179      return control_flow_ops.case(flush_op_cases, exclusive=True)
1180
1181    cache = self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG)
1182    if on_tpu:
1183      flush_op = tpu.outside_compilation(_flush_fun,
1184                                         cache.value(), self._replica_id)
1185    else:
1186      flush_op = _flush_fun(cache.value(), self._replica_id)
1187
1188    with ops.control_dependencies([flush_op]):
1189      reset_value = constant_op.constant(_COMPACT_TRACE_ENTRY_INIT_VALUE,
1190                                         dtype=cache.dtype,
1191                                         shape=cache.shape)
1192      assign_op = state_ops.assign(cache, reset_value).op
1193      with ops.control_dependencies([assign_op]):
1194        return constant_op.constant(0).op
1195
1196  def _flush_tensor_values_cache(self, tensor_fetches, op_fetches, on_tpu):
1197    """Flushes the intermediate tensor values in the graph to the cache.
1198
1199    Args:
1200      tensor_fetches: list of tensor results returned by the model_fn.
1201      op_fetches: list of ops that are returned by the model_fn, e.g., train_op.
1202      on_tpu: if the graph is executed on TPU.
1203
1204    Returns:
1205      An identical copy of tensor_fetches.
1206    """
1207    # Add a dependency to op and tensor fetches to make sure that all tracing
1208    # ops are executed before flushing trace results.
1209    with ops.control_dependencies(op_fetches +
1210                                  [tensor.op for tensor in tensor_fetches]):
1211      flush_cache_op = self._generate_flush_cache_op(
1212          self._tt_config.num_replicas, on_tpu)
1213      return control_flow_ops.tuple(tensor_fetches,
1214                                    control_inputs=[flush_cache_op])
1215
1216  def _process_tensor_fetches(self, tensor_fetches):
1217    """Check that tensor_fetches is not empty and have valid tensors."""
1218    # If none or empty list.
1219    if tensor_fetches is None:
1220      raise RuntimeError('tensor_fetches provided to tensor_tracer cannot be '
1221                         'None.')
1222    if not isinstance(tensor_fetches, (list, tuple)):
1223      tensor_fetches = [tensor_fetches]
1224    elif not tensor_fetches:
1225      raise RuntimeError('tensor_fetches provided to tensor_tracer cannot be '
1226                         'empty list.')
1227    fetches = []
1228    for fetch in tensor_fetches:
1229      if isinstance(fetch, ops.Tensor):
1230        fetches.append(fetch)
1231      else:
1232        raise RuntimeError('Given tensor_fetch:%s is not a tensor.' % fetch)
1233    return fetches
1234
1235  def _process_op_fetches(self, op_fetches):
1236    """Check that op_fetches have valid ops."""
1237    if op_fetches is None:
1238      return []
1239
1240    if not isinstance(op_fetches, (list, tuple)):
1241      op_fetches = [op_fetches]
1242
1243    fetches = []
1244    for fetch in op_fetches:
1245      if isinstance(fetch, ops.Operation):
1246        fetches.append(fetch)
1247      elif isinstance(fetch, ops.Tensor):
1248        fetches.append(fetch.op)
1249      else:
1250        logging.warning('Ignoring the given op_fetch:%s, which is not an op.' %
1251                        fetch)
1252    return fetches
1253
1254  def _convert_fetches_to_input_format(self, input_fetches, current_fetches):
1255    """Changes current_fetches' format, so that it matches input_fetches."""
1256    if isinstance(input_fetches, ops.Tensor):
1257      if len(current_fetches) != 1:
1258        raise RuntimeError('Tensor tracer input/output fetches do not match.')
1259      return current_fetches[0]
1260    else:
1261      if len(current_fetches) != len(current_fetches):
1262        raise RuntimeError('Tensor tracer input/output fetches do not match.')
1263      elif isinstance(input_fetches, tuple):
1264        return tuple(current_fetches)
1265      else:
1266        return current_fetches
1267
1268  def _get_op_control_flow_context(self, op):
1269    """Returns the control flow of the given op.
1270
1271    Args:
1272      op: tf.Operation for which the control flow context is requested.
1273    Returns:
1274      op_control_flow_context: which the is control flow context of the given
1275      op. If the operation type is LoopExit, returns the outer control flow
1276      context.
1277    """
1278    # pylint: disable=protected-access
1279    op_control_flow_context = op._control_flow_context
1280    # pylint: enable=protected-access
1281    if control_flow_util.IsLoopExit(op):
1282      op_control_flow_context = op_control_flow_context.outer_context
1283    return op_control_flow_context
1284
1285  def _prepare_host_call_fn(self, processed_t_fetches, op_fetches):
1286    """Creates a host call function that will write the cache as tb summary.
1287
1288    Args:
1289      processed_t_fetches: List of tensor provided to session.run.
1290      op_fetches: List of operations provided to session.run.
1291    Raises:
1292      ValueError if trace_dir is not set.
1293    """
1294    if self._parameters.trace_dir is None:
1295      raise ValueError('Provide a trace_dir for tensor tracer in summary mode. '
1296                       '--trace_dir=/model/dir')
1297
1298    def _write_cache(step, **kwargs):
1299      """Writes the given caches as tensor summary.
1300
1301      Args:
1302        step: Step tensor with dimension [num_cores].
1303        **kwargs: The dictionary of tensors that needs to be written as
1304          summaries. Key and value pairs within kwargs correspond to the tag
1305          name, and tensor content that will be written using summary.write.
1306          The trace_modes that use this function are:
1307            - summary: In summary mode, kwargs includes a single (tag, content)
1308            pair which are, _TT_SUMMARY_TAG and a tf.float32 signature_cache
1309            variable. The dimension of the signature_cache is:
1310              num_cores x num_traced_tensors x num_signatures.
1311            - full_tensor_summary: kwargs will include all traced tensors. Tag
1312            and content correspond to the name of the tensor, and its actual
1313            content.
1314      Returns:
1315        A tf.Operation that needs to be executed for the host call dependencies.
1316      Raises:
1317        RuntimeError: if there is no aggregate function defined for a signature.
1318      """
1319
1320      # TODO(deveci): Parametrize max_queue, so that flushing op can be called
1321      # less frequently.
1322      # Setting max_queue to 100 appears to be safe even when the number of
1323      # iterations are much lower, as the destructor of the writer flushes it.
1324      summary_write_ops = []
1325      with summary.create_file_writer_v2(
1326          self._parameters.trace_dir,
1327          filename_suffix=_TT_EVENT_FILE_SUFFIX,
1328          max_queue=_TT_SUMMARY_MAX_QUEUE).as_default():
1329        summary_metadata = summary_pb2.SummaryMetadata(
1330            plugin_data=summary_pb2.SummaryMetadata.PluginData(
1331                plugin_name=_TT_TENSORBOARD_PLUGIN_NAME))
1332        for key, value in kwargs.items():
1333          # Check whether we need to compute aggregated statistics that merge
1334          # all cores statistics.
1335          if not self._parameters.collect_summary_per_core:
1336            # Merge only statistics tensor, if it is any other tensor we simply,
1337            # concatenate them.
1338            if key == _TT_SUMMARY_TAG:
1339              agg_fn_map = self._parameters.get_signature_to_agg_fn_map()
1340              signature_idx_map = self._signature_types()
1341              aggregation_result = []
1342              for signature, idx in sorted(signature_idx_map.items(),
1343                                           key=operator.itemgetter(1)):
1344                if signature not in agg_fn_map:
1345                  raise RuntimeError('No aggregation function is defined for '
1346                                     'signature %s.' % signature)
1347
1348                # The dimensions of the statistics tensor is
1349                # num_cores x num_traced_tensors x num_signatures
1350                # value[:,:,idx] will return the portion of the tensor relasted
1351                # to signature.
1352                signature_tensor = value[:, :, idx]
1353                # Merge it along the first (core) axis.
1354                agg_fn = agg_fn_map[signature]
1355                agg_tensor = agg_fn(signature_tensor, axis=0)
1356                aggregation_result.append(agg_tensor)
1357              # Merge results corresponding to different signatures
1358
1359              merged_signatures = array_ops.stack(aggregation_result)
1360              # merged_signatures has dimensions
1361              # num_signatures x num_traced_tensors, transpose it so that it
1362              # will match with the original structure
1363              # num_traced_tensors x num_signatures.
1364              transposed_signatures = array_ops.transpose(merged_signatures)
1365              # Expand 1 more dimension so that it will match with the expected
1366              # structure num_cores x num_traced_tensors x num_signatures.
1367              value = array_ops.expand_dims(transposed_signatures, axis=0)
1368
1369          with ops.control_dependencies(
1370              summary.summary_writer_initializer_op()):
1371            summary_write_ops.append(summary.write(
1372                _TT_SUMMARY_TAG + '/' + key, value, metadata=summary_metadata,
1373                step=step[0]))
1374      return control_flow_ops.group(summary_write_ops)
1375
1376    step = array_ops.reshape(training_util.get_or_create_global_step(), [1])
1377    self._host_call_fn = {}
1378
1379    host_call_deps = op_fetches + [tensor.op for tensor in processed_t_fetches]
1380
1381    caches_to_write = {}
1382    with ops.control_dependencies(host_call_deps):
1383      all_caches = self._get_all_cache_variables()
1384      for cache_name, cache_variable in all_caches.items():
1385        # Increase the cache rank by 1, so that when host call concatenates
1386        # tensors from different replicas, we can identify them with [core_id].
1387        new_cache_shape = [1]
1388        new_cache_shape.extend(cache_variable.shape.as_list())
1389        cache = array_ops.reshape(cache_variable.value(), new_cache_shape)
1390        caches_to_write[cache_name] = cache
1391    # Add step to parameter dictionary.
1392    caches_to_write['step'] = step
1393    # Other options without adding step to parameter dictionary are
1394    #  * host_call_fn = (_write_cache(step, caches_to_write)) : fails as it
1395    #    considers caches_to_write as a single parameter, rather than a keyword
1396    #    parameters.
1397    #  * host_call_fn = (_write_cache(step, **caches_to_write)) : fails with
1398    #    a syntax error.
1399    self._host_call_fn[_TT_HOSTCALL_KEY] = (_write_cache, caches_to_write)
1400
1401  def host_call_deps_and_fn(self):
1402    return self._host_call_fn
1403
1404  def get_traced_op_names(self):
1405    """Returns the set of traced op names."""
1406    return self._traced_op_names
1407
1408  def _trace_execution(self, graph,
1409                       tensor_fetches,
1410                       op_fetches=None,
1411                       on_tpu=True):
1412    """Commong tracing function for both CPU and TPUs.
1413
1414    The caller function should set device_type, num_replicas,
1415    num_replicas_per_host, num_hosts and replica_id before calling
1416    _trace_execution.
1417
1418
1419    Args:
1420      graph: the graph of Ops executed on the TPU.
1421      tensor_fetches: a (list,tuple,or a single object) of tensor fetches
1422        returned by model_fn given to session.run. Function must be provided
1423        with as least one tensor to fetch.
1424      op_fetches: A list of op fetches returned by model_fn given to
1425        session.run. op_fetches and tensor_fetches are used to determine the
1426        nodes that will be executed. Can be None.
1427      on_tpu: True if executing on TPU.
1428
1429    Returns:
1430      tensor_fetches: an exact copy of tensor_fetches that has additional
1431                      dependencies.
1432    Raises:
1433      RuntimeError: If tensor_fetches is None or empty.
1434    """
1435    def _cast_unsupported_dtypes(tensor):
1436      """Casts tensor to a supported type."""
1437
1438      if tensor.dtype.__eq__(dtypes.int64):
1439        # outside-compilation doesn't support int64 input yet.
1440        return math_ops.cast(tensor, dtypes.int32)
1441      if tensor.dtype.__eq__(dtypes.bfloat16) or tensor.dtype.__eq__(
1442          dtypes.float16):
1443        # Since host can't handle bf16, convert tensor to f32.
1444        return math_ops.cast(tensor, dtypes.float32)
1445      return tensor
1446
1447    trace_mode = self._parameters.trace_mode
1448    device_type = self._tt_config.device_type
1449
1450    analytics.track_usage('tensor_tracer', [trace_mode, device_type])
1451    TensorTracer.check_device_type(device_type)
1452    TensorTracer.check_trace_mode(device_type, trace_mode)
1453    # Check in_tensor_fetches, and op_fetches and convert them to lists.
1454    processed_t_fetches = self._process_tensor_fetches(tensor_fetches)
1455    op_fetches = self._process_op_fetches(op_fetches)
1456    all_fetches = op_fetches + [tensor.op for tensor in processed_t_fetches]
1457
1458    # Filter out the operations that won't be executed.
1459    # if fetches=None, then ops_in_exec_path = set(operations)
1460    exec_op_set = self._filter_execution_path_operations(graph.get_operations(),
1461                                                         all_fetches)
1462    # Write report file, and determine the traced tensors.
1463    tensor_trace_order = self._determine_trace_and_create_report(
1464        graph, exec_op_set)
1465
1466    tensor_fetch_set = set(processed_t_fetches)
1467    tracing_ops = []
1468
1469    # pylint: disable=protected-access
1470    current_control_flow_context = graph._get_control_flow_context()
1471    # pylint: enable=protected-access
1472
1473    sorted_exec_op_list = list(exec_op_set)
1474    sorted_exec_op_list.sort(key=lambda op: op.name)
1475    # Trace ops only if they are in the execution path.
1476    for op in sorted_exec_op_list:
1477      for i in range(len(op.outputs)):
1478        out_tensor = op.outputs[i]
1479        tensor_name = out_tensor.name
1480        if tensor_name not in tensor_trace_order.tensorname_to_cache_idx:
1481          continue
1482        self._traced_op_names.add(op.name)
1483        # Create the list of consumers before calling _preprocess_traced_tensor.
1484        # Otherwise, adding control input below, will introduce a cycle in the
1485        # graph.
1486        consumers = out_tensor.consumers()
1487        # Not all consumers may be in the exec path. Filter out the consumers
1488        # to keep the graph simpler.
1489        consumers = [cop for cop in consumers if cop in exec_op_set]
1490
1491        # If there is no consumer of the tensor, there is no need to trace it;
1492        # unless the tensor itself is one of the fetches.
1493        is_a_fetched_tensor = out_tensor in tensor_fetch_set
1494        if (not consumers) and (not is_a_fetched_tensor):
1495          continue
1496
1497        op_control_flow_context = self._get_op_control_flow_context(op)
1498        if op_control_flow_context:
1499          # pylint: disable=protected-access
1500          graph._set_control_flow_context(op_control_flow_context)
1501          # pylint: enable=protected-access
1502
1503        processed_tensors = self._preprocess_traced_tensor(out_tensor)
1504
1505        if on_tpu:
1506          for signature in processed_tensors.keys():
1507            processed_tensors[signature] = _cast_unsupported_dtypes(
1508                processed_tensors[signature])
1509
1510        if self._use_tensor_values_cache():
1511          # Use a small cache to store the characteristics of the tensor.
1512          cache_idx = tensor_trace_order.tensorname_to_cache_idx[tensor_name]
1513          trace_op = self._save_tensor_value_to_cache_op(cache_idx,
1514                                                         processed_tensors)
1515        elif self._use_tensor_buffer():
1516          if len(processed_tensors) != 1:
1517            raise RuntimeError('Multiple stats are only allowed in compact '
1518                               'mode.')
1519          processed_out_tensor = processed_tensors.values()[0]
1520          # Store the whole tensor in a buffer.
1521          trace_op = self._snapshot_tensor(processed_out_tensor)
1522        else:
1523
1524          def tpu_wrap_trace_fn(tensor, out_tensor_name):
1525            """Wraps the trace_fn with outside compilation if on TPUs."""
1526            tensor_trace_fn = self._make_tensor_trace_fun(out_tensor_name,
1527                                                          tensor_trace_order)
1528            if on_tpu:
1529              return tpu.outside_compilation(tensor_trace_fn, tensor)
1530            else:
1531              return tensor_trace_fn(tensor)
1532
1533          def conditional_trace_fn(predicate_tensor, out_tensor, trace_fn,
1534                                   out_tensor_name):
1535            """Creates a cond op that traces the out_tensor if predicate is satisfied."""
1536            return control_flow_ops.cond(
1537                predicate_tensor, lambda: trace_fn(out_tensor, out_tensor_name),
1538                lambda: constant_op.constant(False)).op
1539
1540          if len(processed_tensors) != 1:
1541            raise RuntimeError('Multiple stats are only allowed in compact '
1542                               'mode.')
1543          # Collecting multiple statistics are only supported in the summary
1544          # mode that uses compact format(self._use_tensor_values_cache = true).
1545          # Non-compact mode currently allows single stat per tensor.
1546          processed_out_tensor = six.next(six.itervalues(processed_tensors))
1547
1548          if self._parameters.is_conditional_trace:
1549            trace_op = conditional_trace_fn(processed_out_tensor, out_tensor,
1550                                            tpu_wrap_trace_fn, tensor_name)
1551          elif self._parameters.included_cores:
1552            should_print = constant_op.constant(False)
1553            for core in self._parameters.included_cores:
1554              should_print = gen_math_ops.logical_or(
1555                  should_print, gen_math_ops.equal(self._replica_id, core))
1556            trace_op = conditional_trace_fn(should_print, processed_out_tensor,
1557                                            tpu_wrap_trace_fn, tensor_name)
1558
1559          else:
1560            trace_op = tpu_wrap_trace_fn(processed_out_tensor, tensor_name)
1561
1562        if op_control_flow_context:
1563          # pylint: disable=protected-access
1564          graph._set_control_flow_context(current_control_flow_context)
1565          # pylint: enable=protected-access
1566
1567        if is_a_fetched_tensor:
1568          tracing_ops.append(trace_op)
1569          continue
1570        # Add it to all consumers, as some consumers may not be executed if they
1571        # are in a control flow.
1572        for consumer_op in consumers:
1573          # pylint: disable=protected-access
1574          consumer_op._add_control_input(trace_op)
1575          # pylint: enable=protected-access
1576
1577    # pylint: disable=protected-access
1578    graph._set_control_flow_context(current_control_flow_context)
1579    # pylint: enable=protected-access
1580    if tracing_ops:
1581      # If we are tracing a fetched tensor, their dependency is stored in
1582      # tracing_ops.
1583      processed_t_fetches = control_flow_ops.tuple(processed_t_fetches,
1584                                                   control_inputs=tracing_ops)
1585    if self._use_tensor_values_cache() or self._use_tensor_buffer():
1586      if self._create_host_call():
1587        self._prepare_host_call_fn(processed_t_fetches, op_fetches)
1588        if not on_tpu:
1589          write_cache, caches_to_write = self._host_call_fn[_TT_HOSTCALL_KEY]
1590          cache_write_op = write_cache(**caches_to_write)
1591          processed_t_fetches = control_flow_ops.tuple(
1592              processed_t_fetches, control_inputs=[cache_write_op])
1593          del self._host_call_fn[_TT_HOSTCALL_KEY]
1594      else:
1595        processed_t_fetches = self._flush_tensor_values_cache(
1596            processed_t_fetches, op_fetches, on_tpu=on_tpu)
1597
1598    # processed_t_fetches is a list at this point. Convert it to the same
1599    # format as given in tensor_fetches.
1600    return self._convert_fetches_to_input_format(tensor_fetches,
1601                                                 processed_t_fetches)
1602
1603  def trace_tpu(self, graph,
1604                tensor_fetches,
1605                op_fetches=None,
1606                num_replicas=None,
1607                num_replicas_per_host=None,
1608                num_hosts=None):
1609    """Traces the tensors generated by TPU Ops in a TF graph.
1610
1611    Args:
1612      graph: the graph of Ops executed on the TPU.
1613      tensor_fetches: a (list,tuple,or a single object) of tensor fetches
1614        returned by model_fn given to session.run. Function must be provided
1615        with as least one tensor to fetch.
1616      op_fetches: A list of op fetches returned by model_fn given to
1617        session.run. op_fetches and tensor_fetches are used to determine the
1618        nodes that will be executed. Can be None.
1619      num_replicas: number of replicas used on the TPU.
1620      num_replicas_per_host: number of replicas per TPU host.
1621      num_hosts: total number of TPU hosts.
1622
1623    Returns:
1624      tensor_fetches: an exact copy of tensor_fetches that has additional
1625                      dependencies.
1626    Raises:
1627      RuntimeError: If num_replicas_per_host > 8.
1628      RuntimeError: If tensor_fetches is None or empty.
1629    """
1630    if graph in TensorTracer._traced_graphs:
1631      logging.warning('Graph is already rewritten with tensor tracer, ignoring '
1632                      'multiple calls.')
1633      return tensor_fetches
1634    else:
1635      TensorTracer._traced_graphs.add(graph)
1636
1637    self._tt_config.device_type = _DEVICE_TYPE_TPU
1638    self._tt_config.num_replicas = num_replicas
1639    self._tt_config.num_replicas_per_host = num_replicas_per_host
1640    self._tt_config.num_hosts = num_hosts
1641    if self._tt_config.num_replicas is not None:
1642      if self._tt_config.num_replicas_per_host is None:
1643        self._tt_config.num_replicas_per_host = 8
1644      if self._tt_config.num_hosts is None:
1645        self._tt_config.num_hosts = (
1646            num_replicas // self._tt_config.num_replicas_per_host +
1647            (num_replicas % self._tt_config.num_replicas_per_host > 0))
1648
1649    if self._parameters.graph_dump_path:
1650      graph_io.write_graph(graph, self._parameters.graph_dump_path,
1651                           'graph_before_tt.pbtxt')
1652    with graph.as_default():
1653      self._add_replica_id_to_graph()
1654      tensor_fetches = self._trace_execution(graph, tensor_fetches, op_fetches,
1655                                             on_tpu=True)
1656    if self._parameters.graph_dump_path:
1657      graph_io.write_graph(graph, self._parameters.graph_dump_path,
1658                           'graph_after_tt.pbtxt')
1659    return tensor_fetches
1660
1661  def trace_cpu(self, graph, tensor_fetches, op_fetches=None):
1662    """Traces the tensors generated by CPU Ops in a TF graph.
1663
1664    Args:
1665      graph: the graph of Ops executed on the CPU.
1666      tensor_fetches: a (list,tuple,or a single object) of tensor fetches
1667        returned by model_fn given to session.run. Function must be provided
1668        with as least one tensor to fetch.
1669      op_fetches: A list of op fetches returned by model_fn given to
1670        session.run. op_fetches and tensor_fetches are used to determine the
1671        nodes that will be executed. Can be None.
1672
1673    Returns:
1674      tensor_fetches: an exact copy of tensor_fetches that has additional
1675                      dependencies.
1676    Raises:
1677      RuntimeError: If tensor_fetches is None or empty.
1678    """
1679
1680    if graph in TensorTracer._traced_graphs:
1681      logging.warning('Graph is already rewritten with tensor tracer, ignoring '
1682                      'multiple calls.')
1683      return tensor_fetches
1684    else:
1685      TensorTracer._traced_graphs.add(graph)
1686
1687    self._tt_config.device_type = _DEVICE_TYPE_CPU
1688    self._tt_config.num_replicas = 1
1689    self._tt_config.num_replicas_per_host = 1
1690    self._tt_config.num_hosts = 1
1691    self._replica_id = 0
1692    if self._parameters.graph_dump_path:
1693      graph_io.write_graph(graph, self._parameters.graph_dump_path,
1694                           'graph_before_tt.pbtxt')
1695    with graph.as_default():
1696      tensor_fetches = self._trace_execution(graph, tensor_fetches, op_fetches,
1697                                             on_tpu=False)
1698    if self._parameters.graph_dump_path:
1699      graph_io.write_graph(graph, self._parameters.graph_dump_path,
1700                           'graph_after_tt.pbtxt')
1701    return tensor_fetches
1702