1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ======================================================================== 15"""Utilities to handle tensor tracer parameters.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21 22import os 23import os.path 24import re 25 26from tensorflow.python.ops import linalg_ops 27from tensorflow.python.ops import math_ops 28from tensorflow.python.platform import tf_logging as logging 29 30TRACE_MODE_NAN_INF = 'nan-inf' 31TRACE_MODE_PART_TENSOR = 'part-tensor' 32TRACE_MODE_FULL_TENSOR = 'full-tensor' 33TRACE_MODE_FULL_IF_NAN = 'trace-back-if-nan' 34TRACE_MODE_NORM = 'norm' 35TRACE_MODE_MAX_ABS = 'max-abs' 36TRACE_MODE_SUMMARY = 'summary' 37# summary mode to collects a finite set of signatures for each traced tensor, 38# (such as norm, max, min, mean) and dumps it using tb summaries. 39TRACE_MODE_FULL_TENSOR_SUMMARY = 'full_tensor_summary' 40# Full tensor mode dumps the whole tensor values for the traced tensors without 41# any processing on them; using tb summaries. 
_FLAG_NAME_TRACE_STACK_SIZE = 'trace_stack_size'
_SUBMODE_BRIEF = 'brief'
_SUBMODE_DETAILED = 'detailed'
# Name of the environment variable that carries all Tensor Tracer flags.
_FLAGS_ENV_VAR = 'TENSOR_TRACER_FLAGS'
# Patterns for the supported flag spellings: --name='value', --name="value",
# --name=value and --name. Group 1 is the flag name, group 2 (if any) the
# value. The name group excludes whitespace as well as '='; otherwise a
# value-less flag followed by another flag ("--enable --trace_mode=norm")
# would be parsed as one flag whose name is "enable --trace_mode".
_FLAG_SINGLE_QUOTE_PAT = re.compile(r"\s*--([^=\s]+)='([^']*)'")
_FLAG_DOUBLE_QUOTE_PAT = re.compile(r'\s*--([^=\s]+)="([^"]*)"')
_FLAG_NO_QUOTE_PAT = re.compile(r'\s*--([^=\s]+)=(\S*)')
_FLAG_NO_EQUAL_PAT = re.compile(r'\s*--([^=\s]+)\s*')
_FLAG_NAME_ENABLE = 'enable'
_FLAG_NAME_TRACE_MODE = 'trace_mode'
_FLAG_NAME_USE_COMPACT_TRACE = 'compact_trace'
_FLAG_NAME_TRACE_SCALAR_OPS = 'trace_scalar'
_FLAG_NAME_TRACE_BEFORE_OPS = 'trace_before_included_ops'
_FLAG_NAME_TRACE_AFTER_OPS = 'trace_after_included_ops'
_FLAG_NAME_SUBMODE = 'submode'
_FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS = 'include_less_interesting_ops'
_FLAG_NAME_EXCLUDED_OPNAMES = 'excluded_opnames'
_FLAG_NAME_EXCLUDED_OPTYPES = 'excluded_optypes'
_FLAG_NAME_INCLUDED_OPNAMES = 'included_opnames'
_FLAG_NAME_INCLUDED_OPTYPES = 'included_optypes'
_FLAG_NAME_INCLUDED_CORES = 'included_cores'
_FLAG_NAME_TRACE_LEVEL = 'trace_level'
_FLAG_NAME_TRACE_DIR = 'trace_dir'
_FLAG_NAME_REPORT_FILE = 'report_file'
_FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR = 'use_test_undeclared_outputs_dir'
_FLAG_NAME_OP_RANGE = 'op_range'
# Folder to dump the pre (before tensor tracer updates) and post graphs (after
# tensor tracer updates).
_FLAG_DUMP_BEFORE_AFTER_GRAPHS = 'dump_graphs'
# Expected format of the --op_range value: "<first>:<last>".
_OP_RANGE_PAT = re.compile(r'(\d+):(\d+)')
_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR'
_FLAG_SUMMARY_SIGNATURES = 'signatures'
_FLAG_NAME_SUMMARY_PER_CORE = 'collect_summary_per_core'

_TT_DEFAULT_TRACE_LEVEL = 3
_TT_PREFIX = 'tensor_tracer'

_TT_NORM = 'norm'
_TT_MAX = 'max'
_TT_MIN = 'min'
_TT_MEAN = 'mean'
_TT_VAR = 'var'
_TT_SIZE = 'size'

TT_SUMMARY_NORM = '%s_%s' % (_TT_PREFIX, _TT_NORM)
TT_SUMMARY_MAX = '%s_%s' % (_TT_PREFIX, _TT_MAX)
TT_SUMMARY_MIN = '%s_%s' % (_TT_PREFIX, _TT_MIN)
TT_SUMMARY_MEAN = '%s_%s' % (_TT_PREFIX, _TT_MEAN)
TT_SUMMARY_VAR = '%s_%s' % (_TT_PREFIX, _TT_VAR)
TT_SUMMARY_SIZE = '%s_%s' % (_TT_PREFIX, _TT_SIZE)

TT_SUMMARY_SIGNATURES = (TT_SUMMARY_NORM, TT_SUMMARY_MAX, TT_SUMMARY_MIN,
                         TT_SUMMARY_MEAN, TT_SUMMARY_VAR, TT_SUMMARY_SIZE)


class TTParameters(object):
  """A class that handles the parameters of Tensor Tracer.

  The flags are read from the _FLAGS_ENV_VAR entry of the environment mapping
  and parsed once, in the constructor, into public attributes.
  """

  def __init__(self, env=None):
    """Validates the flag names and parses every flag into an attribute.

    Args:
      env: optional mapping (must support get()) to read the flags from.
        os.environ is used when it is None.

    Raises:
      ValueError: if a flag name, the trace mode, the submode or the
        combination of output-directory flags is invalid.
    """
    # Compare against None so that an explicitly passed empty mapping is
    # honored instead of being silently replaced by os.environ.
    self._env = os.environ if env is None else env
    self._validate_flag_names()
    self.trace_mode = self._get_trace_mode()
    self.submode = self._get_submode()
    self.trace_dir = self._get_trace_dir()
    self.report_file_path = self._get_report_filepath()
    self.op_range = self._get_op_range()
    self.excluded_opname_re_list = self._flag_value_to_re_list(
        _FLAG_NAME_EXCLUDED_OPNAMES)
    self.excluded_optype_re_list = self._flag_value_to_re_list(
        _FLAG_NAME_EXCLUDED_OPTYPES)

    self.included_opname_re_list = self._flag_value_to_re_list(
        _FLAG_NAME_INCLUDED_OPNAMES)
    self.included_optype_re_list = self._flag_value_to_re_list(
        _FLAG_NAME_INCLUDED_OPTYPES)

    self.is_conditional_trace = self._is_conditional_trace_mode()
    self.trace_scalar_ops = self.is_flag_on(_FLAG_NAME_TRACE_SCALAR_OPS)
    self.use_compact_trace = self.is_flag_on(_FLAG_NAME_USE_COMPACT_TRACE)

    # trace_ops_before_included and trace_ops_after_included denote the depth
    # of tracing relative to the ops given in --included_opnames or
    # --included_optypes.
    # For example, in the below graph
    #                op1 --> op2 --> op3 --> op4 --> op5
    # If --included_opnames=op3 then only op3 will be traced.
    # If also --trace_before_included_ops=2 (trace_ops_before_included), then
    # op1 and op2 will be traced as they are at most 2 hops apart from an
    # included op. Similarly, if --trace_after_included_ops=2, then op4 and op5
    # will also be traced.
    self.trace_ops_before_included = self._get_flag_int_value(
        _FLAG_NAME_TRACE_BEFORE_OPS, 0)
    self.trace_ops_after_included = self._get_flag_int_value(
        _FLAG_NAME_TRACE_AFTER_OPS, 0)
    self.trace_stack_size = self._get_flag_int_value(
        _FLAG_NAME_TRACE_STACK_SIZE, 1)
    _, self.graph_dump_path = self.get_flag_value(
        _FLAG_DUMP_BEFORE_AFTER_GRAPHS)
    self.included_cores = self._flag_value_as_int_list(
        _FLAG_NAME_INCLUDED_CORES)
    self.include_less_interesting_ops = self.is_flag_on(
        _FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS)
    self.trace_level = self._get_flag_int_value(
        _FLAG_NAME_TRACE_LEVEL, _TT_DEFAULT_TRACE_LEVEL)
    self.summary_signatures = self._get_summary_signatures()
    self.collect_summary_per_core = self.is_flag_on(_FLAG_NAME_SUMMARY_PER_CORE)

  def _is_conditional_trace_mode(self):
    """Returns True if the trace mode is TRACE_MODE_FULL_IF_NAN."""
    return self.trace_mode == TRACE_MODE_FULL_IF_NAN

  def _get_report_filepath(self):
    """Returns the path of the output report file.

    Returns:
      The value of the report file flag (None if the flag is absent). When the
      test undeclared outputs directory is in use, the value is joined onto
      that directory.

    Raises:
      ValueError: if the test undeclared outputs directory is in use and the
        report file path is absolute, or the directory env variable is unset.
    """
    found, report_file_path = self.get_flag_value(_FLAG_NAME_REPORT_FILE)
    if (found and report_file_path
        and self.use_test_undeclared_outputs_dir()):
      if os.path.isabs(report_file_path):
        raise ValueError('If use_test_undeclared_outputs_dir is set, '
                         'report_file_path cannot be an absolute path (%s)'
                         % report_file_path)
      outputs_dir = self._env.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR)
      if outputs_dir is None:
        # os.path.join would otherwise fail with an opaque TypeError.
        raise ValueError('--%s is set but the environment variable "%s" is '
                         'not defined'
                         % (_FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR,
                            _TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR))
      report_file_path = os.path.join(outputs_dir, report_file_path)
    return report_file_path

  def _get_op_range(self):
    """Returns the index range of the Ops that we will consider tracing.

    Returns:
      A (first, last) pair of ints parsed from the op range flag, or (-1, -1),
      meaning all ops are included, when the flag is absent, empty or does not
      start with "<int>:<int>".
    """
    found, op_range = self.get_flag_value(_FLAG_NAME_OP_RANGE)
    match = _OP_RANGE_PAT.match(op_range) if found and op_range else None
    if not match:
      return (-1, -1)  # this means including all ops.
    return (int(match.group(1)), int(match.group(2)))

  def _get_trace_dir(self):
    """Returns the directory the trace files are written to.

    Returns:
      The value of the test undeclared outputs env variable when
      --use_test_undeclared_outputs_dir is on, otherwise the value of the
      trace dir flag (None if absent).

    Raises:
      ValueError: if both the trace dir flag and
        --use_test_undeclared_outputs_dir are given.
    """
    found, trace_dir = self.get_flag_value(_FLAG_NAME_TRACE_DIR)
    use_outputs_dir = self.use_test_undeclared_outputs_dir()
    if found and trace_dir and use_outputs_dir:
      raise ValueError('Cannot use --%s and --%s at the same time'
                       % (_FLAG_NAME_TRACE_DIR,
                          _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR))
    if use_outputs_dir:
      trace_dir = self._env.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR)
    return trace_dir

  def _get_trace_mode(self):
    """Returns the validated trace mode.

    Returns:
      The value of the trace mode flag, or TRACE_MODE_NORM when the flag is
      absent or empty.

    Raises:
      ValueError: if the given trace mode is not a known one.
    """
    found, trace_mode = self.get_flag_value(_FLAG_NAME_TRACE_MODE)
    if not found or not trace_mode:
      trace_mode = TRACE_MODE_NORM
    valid_trace_modes = [
        TRACE_MODE_NAN_INF, TRACE_MODE_PART_TENSOR, TRACE_MODE_FULL_TENSOR,
        TRACE_MODE_NORM, TRACE_MODE_MAX_ABS, TRACE_MODE_FULL_IF_NAN,
        TRACE_MODE_SUMMARY, TRACE_MODE_FULL_TENSOR_SUMMARY
    ]
    if trace_mode not in valid_trace_modes:
      raise ValueError('Invalid trace mode "%s" given to the Tensor_Tracer. '
                       'Valid trace modes are: %s'
                       % (trace_mode, valid_trace_modes))
    return trace_mode

  def is_brief_mode(self):
    """Returns True if the submode is the brief one."""
    return self.submode == _SUBMODE_BRIEF

  def _get_submode(self):
    """Returns the validated submode.

    Returns:
      The value of the submode flag, or _SUBMODE_DETAILED when the flag is
      absent or empty.

    Raises:
      ValueError: if the given submode is not a known one.
    """
    found, submode = self.get_flag_value(_FLAG_NAME_SUBMODE)
    if not found or not submode:
      submode = _SUBMODE_DETAILED
    valid_submodes = [_SUBMODE_DETAILED, _SUBMODE_BRIEF]
    if submode not in valid_submodes:
      raise ValueError('Invalid submode "%s" given to the Tensor_Tracer. '
                       'Valid submodes are: %s' % (submode, valid_submodes))
    return submode

  @staticmethod
  def match_next_flag(flags, pos):
    """Returns the match for the next TensorTracer flag.

    Args:
       flags: a string that contains the flags.
       pos: where in flags to start the search.

    Returns:
       A pair where the first element is the regular-expression
       match found (None if no flag starts at pos) and the second
       element indicates if the match has a value.
    """
    # The quoted forms must be tried before the unquoted one, which would
    # otherwise match them and keep the quotes (and stop at the first space).
    patterns_with_value = (_FLAG_DOUBLE_QUOTE_PAT, _FLAG_SINGLE_QUOTE_PAT,
                           _FLAG_NO_QUOTE_PAT)
    for pattern in patterns_with_value:
      match = pattern.match(flags, pos)
      if match:
        return match, True
    # Either the flag is given without a value, or no flag is found (None).
    return _FLAG_NO_EQUAL_PAT.match(flags, pos), False

  def _validate_flag_names(self):
    """Validates the names of the TensorTracer flags that were passed.

    Raises:
      ValueError: if a flag with an unknown name is found.
    """
    valid_flag_names = [
        _FLAG_NAME_ENABLE, _FLAG_NAME_TRACE_MODE, _FLAG_NAME_USE_COMPACT_TRACE,
        _FLAG_NAME_TRACE_SCALAR_OPS, _FLAG_NAME_TRACE_BEFORE_OPS,
        _FLAG_NAME_TRACE_AFTER_OPS, _FLAG_NAME_TRACE_STACK_SIZE,
        _FLAG_NAME_SUBMODE, _FLAG_NAME_EXCLUDED_OPNAMES,
        _FLAG_NAME_EXCLUDED_OPTYPES, _FLAG_NAME_INCLUDED_OPNAMES,
        _FLAG_NAME_INCLUDED_OPTYPES, _FLAG_NAME_TRACE_DIR,
        _FLAG_NAME_INCLUDED_CORES, _FLAG_NAME_REPORT_FILE,
        _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR,
        _FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS, _FLAG_NAME_OP_RANGE,
        _FLAG_DUMP_BEFORE_AFTER_GRAPHS, _FLAG_NAME_TRACE_LEVEL,
        _FLAG_SUMMARY_SIGNATURES, _FLAG_NAME_SUMMARY_PER_CORE
    ]
    tensor_tracer_flags = self._env.get(_FLAGS_ENV_VAR)
    if not tensor_tracer_flags:
      return
    pos = 0
    while True:
      match, _ = TTParameters.match_next_flag(tensor_tracer_flags, pos)
      if not match:
        break
      flag_name = match.group(1)
      if flag_name not in valid_flag_names:
        raise ValueError(
            'The flag name "%s" passed via the environment variable "%s" '
            'is invalid. Valid flag names are:'
            '\n%s' % (flag_name, _FLAGS_ENV_VAR, valid_flag_names))
      pos = match.end()

  def _get_summary_signatures(self):
    """Verifies and returns the summary signatures.

    Signatures may be given with or without the _TT_PREFIX prefix. Unknown
    signatures are skipped with a warning.

    Returns:
      A dictionary of the signature identifiers {signature: index} that will be
      computed when trace_mode is summary.
    """
    tt_signatures = []
    for signature in self._flag_value_as_list(_FLAG_SUMMARY_SIGNATURES):
      signature_with_prefix = '%s_%s' % (_TT_PREFIX, signature)
      if signature in TT_SUMMARY_SIGNATURES:
        known_signature = signature
      elif signature_with_prefix in TT_SUMMARY_SIGNATURES:
        known_signature = signature_with_prefix
      else:
        logging.warning('Unknown signature:%s. Supported signatures: %s',
                        signature, TT_SUMMARY_SIGNATURES)
        continue
      # Skip repeated signatures so that the indices stay contiguous.
      if known_signature not in tt_signatures:
        tt_signatures.append(known_signature)
    if not tt_signatures:
      # Default case collects norm and max only.
      return {TT_SUMMARY_MAX: 0, TT_SUMMARY_NORM: 1}
    return {signature: idx for idx, signature in enumerate(tt_signatures)}

  def get_signature_to_agg_fn_map(self):
    """Returns a map that contains the aggregate function for each signature."""
    return {TT_SUMMARY_NORM: linalg_ops.norm,
            TT_SUMMARY_MAX: math_ops.reduce_max,
            TT_SUMMARY_MIN: math_ops.reduce_min,
            TT_SUMMARY_MEAN: math_ops.reduce_mean,
            TT_SUMMARY_VAR: math_ops.reduce_max,  # Simply reduce max variance.
            TT_SUMMARY_SIZE: math_ops.reduce_sum}

  def _flag_value_as_list(self, wanted_flag_name):
    """Returns the string list of a TensorTracer flag.

    Args:
      wanted_flag_name: the name of the flag we are looking for.

    Returns:
      The comma separated items of the flag value, or an empty list if the
      flag is absent or has no (or an empty) value.
    """
    found, flag_value = self.get_flag_value(wanted_flag_name)
    # flag_value is None when the flag is given without "=value".
    if not found or not flag_value:
      return []
    return flag_value.split(',')

  def _flag_value_as_int_list(self, wanted_flag_name):
    """Returns the integer list of a TensorTracer flag.

    Args:
      wanted_flag_name: the name of the flag we are looking for.

    Returns:
      The comma separated items of the flag value converted to ints. An empty
      list is returned if the flag is absent, has no value, or (after logging
      a warning) if any item is not an integer.
    """
    found, flag_value = self.get_flag_value(wanted_flag_name)
    if not found or not flag_value:
      return []
    try:
      return [int(int_val) for int_val in flag_value.split(',')]
    except ValueError:
      logging.warning('Cannot convert %s to int for flag %s', flag_value,
                      wanted_flag_name)
      return []

  def _get_flag_int_value(self, wanted_flag_name, default_value):
    """Returns the int value of a TensorTracer flag.

    Args:
      wanted_flag_name: the name of the flag we are looking for.
      default_value: the default value for the flag, if not provided.

    Returns:
      The value of the flag as an int, or default_value if the flag is absent
      or (after logging a warning) cannot be converted to an int.
    """
    found, flag_value = self.get_flag_value(wanted_flag_name)
    if not found:
      return default_value
    try:
      return int(flag_value)
    except (TypeError, ValueError):
      # TypeError covers a flag given without a value (flag_value is None).
      logging.warning('Cannot convert %s to int for flag %s', flag_value,
                      wanted_flag_name)
      return default_value

  def get_flag_value(self, wanted_flag_name):
    """Returns the value of a TensorTracer flag.

    Only the first occurrence of the flag is considered.

    Args:
      wanted_flag_name: the name of the flag we are looking for.

    Returns:
      A pair where the first element indicates if the flag is
      found and the second element is the value of the flag (None if the
      flag is not found or is given without a value).
    """
    tensor_tracer_flags = self._env.get(_FLAGS_ENV_VAR)
    if not tensor_tracer_flags:
      return False, None
    pos = 0
    while True:
      match, has_value = TTParameters.match_next_flag(
          tensor_tracer_flags, pos)
      if not match:
        return False, None
      if match.group(1) == wanted_flag_name:
        return True, match.group(2) if has_value else None
      pos = match.end()

  def _flag_value_to_re_list(self, flag_name):
    """Compiles the comma separated items of a flag value into regexes.

    Args:
      flag_name: the name of the flag we are looking for.

    Returns:
      A list of compiled regular expressions, one per item; empty if the flag
      is absent or has no value.
    """
    found, flag_value = self.get_flag_value(flag_name)
    if not found or not flag_value:
      return []
    return [re.compile(v) for v in flag_value.split(',')]

  def is_flag_on(self, flag_name):
    """Returns True if the given flag is on.

    A flag is on when it is given without a value, or with one of the values
    1, t, true, y, yes (case-insensitive).

    Args:
      flag_name: the name of the flag we are looking for.
    """
    found, flag_value = self.get_flag_value(flag_name)
    if not found:
      return False
    if flag_value is None:
      return True
    # Depends on the flag value.
    return flag_value.lower() in ['1', 't', 'true', 'y', 'yes']

  def is_enabled(self):
    """Returns True if TensorTracer is enabled."""
    if self.is_flag_on(_FLAG_NAME_ENABLE):
      logging.info('Tensor Tracer is enabled with flags %s.',
                   self._env.get(_FLAGS_ENV_VAR))
      return True
    return False

  def use_test_undeclared_outputs_dir(self):
    """Decides the output directory of the report and trace files.

    Returns:
       True if the output files should be written to the
       test-undeclared-outputs-directory defined via an
       env variable.
    """
    return self.is_flag_on(_FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR)