# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========================================================================
"""Utilities to handle tensor tracer parameters."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21
22import os
23import os.path
24import re
25
26from tensorflow.python.ops import linalg_ops
27from tensorflow.python.ops import math_ops
28from tensorflow.python.platform import tf_logging as logging
29
# Supported values of the --trace_mode flag.
TRACE_MODE_NAN_INF = 'nan-inf'
TRACE_MODE_PART_TENSOR = 'part-tensor'
TRACE_MODE_FULL_TENSOR = 'full-tensor'
TRACE_MODE_FULL_IF_NAN = 'trace-back-if-nan'
TRACE_MODE_NORM = 'norm'
TRACE_MODE_MAX_ABS = 'max-abs'
# Summary mode collects a finite set of signatures for each traced tensor
# (such as norm, max, min, mean) and dumps them using tb summaries.
TRACE_MODE_SUMMARY = 'summary'
# Full tensor summary mode dumps the whole tensor values for the traced tensors
# without any processing on them, using tb summaries.
TRACE_MODE_FULL_TENSOR_SUMMARY = 'full_tensor_summary'

# Supported values of the --submode flag.
_SUBMODE_BRIEF = 'brief'
_SUBMODE_DETAILED = 'detailed'

# Name of the environment variable that holds the Tensor Tracer flags.
_FLAGS_ENV_VAR = 'TENSOR_TRACER_FLAGS'

# Patterns that recognize a single flag in the flags string. Group 1 is the
# flag name; group 2 (when the pattern has one) is the flag value.
_FLAG_SINGLE_QUOTE_PAT = re.compile(r"\s*--([^=]+)='([^']*)'")
_FLAG_DOUBLE_QUOTE_PAT = re.compile(r'\s*--([^=]+)="([^"]*)"')
_FLAG_NO_QUOTE_PAT = re.compile(r'\s*--([^=]+)=(\S*)')
_FLAG_NO_EQUAL_PAT = re.compile(r'\s*--([^=]+)\s*')

# Flag names.
_FLAG_NAME_ENABLE = 'enable'
_FLAG_NAME_TRACE_MODE = 'trace_mode'
_FLAG_NAME_USE_COMPACT_TRACE = 'compact_trace'
_FLAG_NAME_TRACE_SCALAR_OPS = 'trace_scalar'
_FLAG_NAME_TRACE_BEFORE_OPS = 'trace_before_included_ops'
_FLAG_NAME_TRACE_AFTER_OPS = 'trace_after_included_ops'
_FLAG_NAME_TRACE_STACK_SIZE = 'trace_stack_size'
_FLAG_NAME_SUBMODE = 'submode'
_FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS = 'include_less_interesting_ops'
_FLAG_NAME_EXCLUDED_OPNAMES = 'excluded_opnames'
_FLAG_NAME_EXCLUDED_OPTYPES = 'excluded_optypes'
_FLAG_NAME_INCLUDED_OPNAMES = 'included_opnames'
_FLAG_NAME_INCLUDED_OPTYPES = 'included_optypes'
_FLAG_NAME_INCLUDED_CORES = 'included_cores'
_FLAG_NAME_TRACE_LEVEL = 'trace_level'
_FLAG_NAME_TRACE_DIR = 'trace_dir'
_FLAG_NAME_REPORT_FILE = 'report_file'
_FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR = 'use_test_undeclared_outputs_dir'
_FLAG_NAME_OP_RANGE = 'op_range'
# Folder to dump the pre (before tensor tracer updates) and post graphs (after
# tensor tracer updates).
_FLAG_DUMP_BEFORE_AFTER_GRAPHS = 'dump_graphs'
_FLAG_SUMMARY_SIGNATURES = 'signatures'
_FLAG_NAME_SUMMARY_PER_CORE = 'collect_summary_per_core'

# Expected format of the --op_range value: "<first>:<last>".
_OP_RANGE_PAT = re.compile(r'(\d+):(\d+)')
_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR'

# Value used for --trace_level when the flag is not given.
_TT_DEFAULT_TRACE_LEVEL = 3
_TT_PREFIX = 'tensor_tracer'

# Short names of the summary signatures.
_TT_NORM = 'norm'
_TT_MAX = 'max'
_TT_MIN = 'min'
_TT_MEAN = 'mean'
_TT_VAR = 'var'
_TT_SIZE = 'size'

# Fully qualified signature names, i.e. '<prefix>_<short name>'.
TT_SUMMARY_NORM = '%s_%s' % (_TT_PREFIX, _TT_NORM)
TT_SUMMARY_MAX = '%s_%s' % (_TT_PREFIX, _TT_MAX)
TT_SUMMARY_MIN = '%s_%s' % (_TT_PREFIX, _TT_MIN)
TT_SUMMARY_MEAN = '%s_%s' % (_TT_PREFIX, _TT_MEAN)
TT_SUMMARY_VAR = '%s_%s' % (_TT_PREFIX, _TT_VAR)
TT_SUMMARY_SIZE = '%s_%s' % (_TT_PREFIX, _TT_SIZE)

TT_SUMMARY_SIGNATURES = (TT_SUMMARY_NORM, TT_SUMMARY_MAX, TT_SUMMARY_MIN,
                         TT_SUMMARY_MEAN, TT_SUMMARY_VAR, TT_SUMMARY_SIZE)
97
98
class TTParameters(object):
  """A class that handles the parameters of Tensor Tracer.

  The parameters are given as a string of `--name=value` flags stored under
  the `TENSOR_TRACER_FLAGS` key of an environment mapping, for example:
  `--enable=1 --trace_mode=norm --included_opnames="foo,bar"`.
  """

  def __init__(self, env=None):
    """Parses all the Tensor Tracer flags and caches their values.

    Args:
      env: a mapping to read the flags from. If it is None or empty,
        `os.environ` is used instead.

    Raises:
      ValueError: if an unknown flag name, trace mode or submode is given, or
        if the flags that select the output locations conflict.
    """
    self._env = env if env else os.environ
    self._validate_flag_names()
    self.trace_mode = self._get_trace_mode()
    self.submode = self._get_submode()
    self.trace_dir = self._get_trace_dir()
    self.report_file_path = self._get_report_filepath()
    self.op_range = self._get_op_range()
    self.excluded_opname_re_list = self._flag_value_to_re_list(
        _FLAG_NAME_EXCLUDED_OPNAMES)
    self.excluded_optype_re_list = self._flag_value_to_re_list(
        _FLAG_NAME_EXCLUDED_OPTYPES)

    self.included_opname_re_list = self._flag_value_to_re_list(
        _FLAG_NAME_INCLUDED_OPNAMES)
    self.included_optype_re_list = self._flag_value_to_re_list(
        _FLAG_NAME_INCLUDED_OPTYPES)

    self.is_conditional_trace = self._is_conditional_trace_mode()
    self.trace_scalar_ops = self.is_flag_on(_FLAG_NAME_TRACE_SCALAR_OPS)
    self.use_compact_trace = self.is_flag_on(_FLAG_NAME_USE_COMPACT_TRACE)

    # trace_ops_before_included and trace_ops_after_included denote the depth
    # of tracing relative to the ops given in --included_opnames or
    # --included_optypes.
    # For example, in the below graph
    #                op1 --> op2 --> op3 --> op4 --> op5
    # If --included_opnames=op3 then only op3 will be traced.
    # If also --trace_before_included_ops=2 (trace_ops_before_included), then
    # op1 and op2 will be traced as they are at most 2 hops apart from an
    # included op. Similarly, if --trace_after_included_ops=2, then op4 and op5
    # will also be traced.
    self.trace_ops_before_included = self._get_flag_int_value(
        _FLAG_NAME_TRACE_BEFORE_OPS, 0)
    self.trace_ops_after_included = self._get_flag_int_value(
        _FLAG_NAME_TRACE_AFTER_OPS, 0)
    self.trace_stack_size = self._get_flag_int_value(
        _FLAG_NAME_TRACE_STACK_SIZE, 1)
    _, self.graph_dump_path = self.get_flag_value(
        _FLAG_DUMP_BEFORE_AFTER_GRAPHS)
    self.included_cores = self._flag_value_as_int_list(
        _FLAG_NAME_INCLUDED_CORES)
    self.include_less_interesting_ops = self.is_flag_on(
        _FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS)
    self.trace_level = self._get_flag_int_value(
        _FLAG_NAME_TRACE_LEVEL, _TT_DEFAULT_TRACE_LEVEL)
    self.summary_signatures = self._get_summary_signatures()
    self.collect_summary_per_core = self.is_flag_on(_FLAG_NAME_SUMMARY_PER_CORE)

  def _is_conditional_trace_mode(self):
    """Returns True if the trace mode is 'trace-back-if-nan'."""
    return self.trace_mode == TRACE_MODE_FULL_IF_NAN

  def _get_report_filepath(self):
    """Returns the path of the output report file.

    Returns:
      The value of --report_file (None if the flag is not given). If
      --use_test_undeclared_outputs_dir is on, the value is joined to the
      directory named by the TEST_UNDECLARED_OUTPUTS_DIR environment variable.

    Raises:
      ValueError: if --use_test_undeclared_outputs_dir is on and the report
        file path is absolute, or TEST_UNDECLARED_OUTPUTS_DIR is not set.
    """
    found, report_file_path = self.get_flag_value(_FLAG_NAME_REPORT_FILE)
    if not (found and report_file_path
            and self.use_test_undeclared_outputs_dir()):
      return report_file_path
    if os.path.isabs(report_file_path):
      raise ValueError('If use_test_undeclared_outputs_dir is set, '
                       'report_file_path cannot be an absolute path (%s)'
                       % report_file_path)
    outputs_dir = self._env.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR)
    if outputs_dir is None:
      # os.path.join(None, ...) would fail with an unhelpful TypeError.
      raise ValueError('--%s is set but the environment variable "%s" is not '
                       'defined.'
                       % (_FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR,
                          _TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR))
    return os.path.join(outputs_dir, report_file_path)

  def _get_op_range(self):
    """Returns the index range of the Ops that we will consider tracing.

    Returns:
      A pair of ints parsed from an --op_range value that starts with
      "<first>:<last>", or (-1, -1), which means including all ops, if the
      flag is missing, empty or does not start with that format.
    """
    found, op_range = self.get_flag_value(_FLAG_NAME_OP_RANGE)
    match = _OP_RANGE_PAT.match(op_range) if found and op_range else None
    if not match:
      return (-1, -1)  # this means including all ops.
    return (int(match.group(1)), int(match.group(2)))

  def _get_trace_dir(self):
    """Returns the directory where the trace files are written.

    Returns:
      The value of the TEST_UNDECLARED_OUTPUTS_DIR environment variable if
      --use_test_undeclared_outputs_dir is on, otherwise the value of
      --trace_dir (None if the flag is not given).

    Raises:
      ValueError: if both --trace_dir and --use_test_undeclared_outputs_dir
        are given.
    """
    found, trace_dir = self.get_flag_value(_FLAG_NAME_TRACE_DIR)
    use_outputs_dir = self.use_test_undeclared_outputs_dir()
    if found and trace_dir and use_outputs_dir:
      raise ValueError('Cannot use --%s and --%s at the same time'
                       % (_FLAG_NAME_TRACE_DIR,
                          _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR))
    if use_outputs_dir:
      trace_dir = self._env.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR)
    return trace_dir

  def _get_trace_mode(self):
    """Returns the trace mode, checking that it is valid.

    Returns:
      The value of --trace_mode, or 'norm' if the flag is missing or empty.

    Raises:
      ValueError: if the given trace mode is not one of the supported modes.
    """
    found, trace_mode = self.get_flag_value(_FLAG_NAME_TRACE_MODE)
    if not found or not trace_mode:
      trace_mode = TRACE_MODE_NORM
    valid_trace_modes = [
        TRACE_MODE_NAN_INF, TRACE_MODE_PART_TENSOR, TRACE_MODE_FULL_TENSOR,
        TRACE_MODE_NORM, TRACE_MODE_MAX_ABS, TRACE_MODE_FULL_IF_NAN,
        TRACE_MODE_SUMMARY, TRACE_MODE_FULL_TENSOR_SUMMARY
    ]
    if trace_mode not in valid_trace_modes:
      raise ValueError('Invalid trace mode "%s" given to the Tensor_Tracer. '
                       'Valid trace modes are: %s'
                       % (trace_mode, valid_trace_modes))
    return trace_mode

  def is_brief_mode(self):
    """Returns True if the submode is 'brief'."""
    return self.submode == _SUBMODE_BRIEF

  def _get_submode(self):
    """Returns the submode, checking that it is valid.

    Returns:
      The value of --submode, or 'detailed' if the flag is missing or empty.

    Raises:
      ValueError: if the given submode is not one of the supported submodes.
    """
    found, submode = self.get_flag_value(_FLAG_NAME_SUBMODE)
    if not found or not submode:
      submode = _SUBMODE_DETAILED
    valid_submodes = [_SUBMODE_DETAILED, _SUBMODE_BRIEF]
    if submode not in valid_submodes:
      raise ValueError('Invalid submode "%s" given to the Tensor_Tracer. '
                       'Valid submodes are: %s' % (submode, valid_submodes))
    return submode

  @staticmethod
  def match_next_flag(flags, pos):
    """Returns the match for the next TensorTracer flag.

    Args:
       flags: a string that contains the flags.
       pos: where in flags to start the search.

    Returns:
       A pair where the first element is the regular-expression
       match found and the second element indicates if the match
       has a value.
    """
    # The quoted forms are tried before the unquoted one, so that a quoted
    # value is captured without its quotes.
    patterns_with_value = (_FLAG_DOUBLE_QUOTE_PAT, _FLAG_SINGLE_QUOTE_PAT,
                           _FLAG_NO_QUOTE_PAT)
    for pattern in patterns_with_value:
      match = pattern.match(flags, pos)
      if match:
        return match, True
    match = _FLAG_NO_EQUAL_PAT.match(flags, pos)
    if match:
      # The flag is found but is not given a value.
      return match, False
    # The flag is not found.
    return None, False

  def _validate_flag_names(self):
    """Validates if the TensorTrace flags passed are valid.

    Flags are read one after another from the flags string; validation stops
    at the first position where no flag can be matched.

    Raises:
      ValueError: if a flag with an unknown name is found.
    """
    valid_flag_names = [
        _FLAG_NAME_ENABLE, _FLAG_NAME_TRACE_MODE, _FLAG_NAME_USE_COMPACT_TRACE,
        _FLAG_NAME_TRACE_SCALAR_OPS, _FLAG_NAME_TRACE_BEFORE_OPS,
        _FLAG_NAME_TRACE_AFTER_OPS, _FLAG_NAME_TRACE_STACK_SIZE,
        _FLAG_NAME_SUBMODE, _FLAG_NAME_EXCLUDED_OPNAMES,
        _FLAG_NAME_EXCLUDED_OPTYPES, _FLAG_NAME_INCLUDED_OPNAMES,
        _FLAG_NAME_INCLUDED_OPTYPES, _FLAG_NAME_TRACE_DIR,
        _FLAG_NAME_INCLUDED_CORES, _FLAG_NAME_REPORT_FILE,
        _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR,
        _FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS, _FLAG_NAME_OP_RANGE,
        _FLAG_DUMP_BEFORE_AFTER_GRAPHS, _FLAG_NAME_TRACE_LEVEL,
        _FLAG_SUMMARY_SIGNATURES, _FLAG_NAME_SUMMARY_PER_CORE
    ]
    tensor_tracer_flags = self._env.get(_FLAGS_ENV_VAR)
    if not tensor_tracer_flags:
      return
    pos = 0
    while True:
      match, _ = TTParameters.match_next_flag(tensor_tracer_flags, pos)
      if not match:
        break
      flag_name = match.group(1)
      if flag_name not in valid_flag_names:
        raise ValueError(
            'The flag name "%s" passed via the environment variable "%s" '
            'is invalid. Valid flag names are:'
            '\n%s' % (flag_name, _FLAGS_ENV_VAR, valid_flag_names))
      pos = match.end()

  def _get_summary_signatures(self):
    """Verifies and returns the summary signatures.

    Signatures may be given with or without the 'tensor_tracer_' prefix.
    Unknown signatures are skipped with a warning.

    Returns:
      A dictionary of the signature identifiers {signature: index} that will be
      computed when trace_mode is summary.
    """
    signatures = self._flag_value_as_list(_FLAG_SUMMARY_SIGNATURES)

    tt_signatures = []
    for signature in signatures:
      signature_with_prefix = '%s_%s' % (_TT_PREFIX, signature)
      if signature in TT_SUMMARY_SIGNATURES:
        tt_signatures.append(signature)
      elif signature_with_prefix in TT_SUMMARY_SIGNATURES:
        tt_signatures.append(signature_with_prefix)
      else:
        logging.warning('Unknown signature:%s. Supported signatures: %s',
                        signature, TT_SUMMARY_SIGNATURES)
    if not tt_signatures:
      # Default case collects norm and max only.
      return {TT_SUMMARY_MAX: 0, TT_SUMMARY_NORM: 1}
    return {signature: idx for idx, signature in enumerate(tt_signatures)}

  def get_signature_to_agg_fn_map(self):
    """Returns a map that contains the aggregate function for each signature."""
    return {TT_SUMMARY_NORM: linalg_ops.norm,
            TT_SUMMARY_MAX: math_ops.reduce_max,
            TT_SUMMARY_MIN: math_ops.reduce_min,
            TT_SUMMARY_MEAN: math_ops.reduce_mean,
            TT_SUMMARY_VAR: math_ops.reduce_max,  # Simply reduce max variance.
            TT_SUMMARY_SIZE: math_ops.reduce_sum}

  def _flag_value_as_list(self, wanted_flag_name):
    """Returns the string list of a TensorTracer flag.

    Args:
      wanted_flag_name: the name of the flag we are looking for.

    Returns:
      The comma-separated items of the flag value as a list of strings, or an
      empty list if the flag is not found or is given without a value.
    """
    found, flag_value = self.get_flag_value(wanted_flag_name)
    # A flag given without '=' has a None value, which cannot be split.
    if not found or flag_value is None:
      return []
    return flag_value.split(',')

  def _flag_value_as_int_list(self, wanted_flag_name):
    """Returns the integer list of a TensorTracer flag.

    Args:
      wanted_flag_name: the name of the flag we are looking for.

    Returns:
      The comma-separated items of the flag value as a list of ints, or an
      empty list if the flag is not found, is given without a value, or has an
      item that is not an integer (a warning is logged in the last case).
    """
    found, flag_value = self.get_flag_value(wanted_flag_name)
    # A flag given without '=' has a None value, which cannot be split.
    if not found or flag_value is None:
      return []
    try:
      return [int(int_val) for int_val in flag_value.split(',')]
    except ValueError:
      logging.warning('Cannot convert %s to int for flag %s', flag_value,
                      wanted_flag_name)
      return []

  def _get_flag_int_value(self, wanted_flag_name, default_value):
    """Returns the int value of a TensorTracer flag.

    Args:
      wanted_flag_name: the name of the flag we are looking for.
      default_value: the default value for the flag, if not provided.

    Returns:
      The flag value converted to int, or default_value if the flag is not
      found or its value cannot be converted (a warning is logged in the
      latter case).
    """
    found, flag_value = self.get_flag_value(wanted_flag_name)
    if not found:
      return default_value
    try:
      return int(flag_value)
    except (TypeError, ValueError):
      # TypeError covers a flag given without a value (flag_value is None).
      logging.warning('Cannot convert %s to int for flag %s', flag_value,
                      wanted_flag_name)
      return default_value

  def get_flag_value(self, wanted_flag_name):
    """Returns the value of a TensorTracer flag.

    Args:
      wanted_flag_name: the name of the flag we are looking for.

    Returns:
      A pair where the first element indicates if the flag is
      found and the second element is the value of the flag. The value is
      None if the flag is not found or is given without a value. If the flag
      is given several times, the first occurrence wins.
    """
    tensor_tracer_flags = self._env.get(_FLAGS_ENV_VAR)
    if not tensor_tracer_flags:
      return False, None
    pos = 0
    while True:
      match, has_value = TTParameters.match_next_flag(
          tensor_tracer_flags, pos)
      if not match:
        return False, None
      if match.group(1) == wanted_flag_name:
        return True, match.group(2) if has_value else None
      pos = match.end()

  def _flag_value_to_re_list(self, flag_name):
    """Converts the comma-separated value of a flag to a list of compiled REs.

    Args:
      flag_name: the name of the flag we are looking for.

    Returns:
      A list with one compiled regular expression per comma-separated item of
      the flag value; empty if the flag is missing or has no value.
    """
    found, flag_value = self.get_flag_value(flag_name)
    if not found or not flag_value:
      return []
    return [re.compile(v) for v in flag_value.split(',')]

  def is_flag_on(self, flag_name):
    """Returns True if the given flag is on.

    A flag is on if it is given without a value, or if its value is one of
    '1', 't', 'true', 'y', 'yes' (case-insensitive).

    Args:
      flag_name: the name of the flag we are looking for.

    Returns:
      True if the flag is on, False otherwise.
    """
    found, flag_value = self.get_flag_value(flag_name)
    if not found:
      return False
    if flag_value is None:
      return True
    # Depends on the flag value.
    return flag_value.lower() in ['1', 't', 'true', 'y', 'yes']

  def is_enabled(self):
    """Returns True if TensorTracer is enabled."""

    if self.is_flag_on(_FLAG_NAME_ENABLE):
      logging.info('Tensor Tracer is enabled with flags %s.',
                   self._env.get(_FLAGS_ENV_VAR))
      return True
    return False

  def use_test_undeclared_outputs_dir(self):
    """Decides the output directory of the report and trace files.

    Returns:
       True if the output files should be written to the
       test-undeclared-outputs-directory defined via an
       env variable.
    """
    return self.is_flag_on(_FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR)
469