# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========================================================================
"""Utilities to handle tensor tracer parameters."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import os
import os.path
import re

from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging

# String values accepted by the trace_mode flag.
TRACE_MODE_PART_TENSOR = 'part-tensor'
TRACE_MODE_FULL_TENSOR = 'full-tensor'
TRACE_MODE_FULL_TENSOR_SUMMARY = 'full_tensor_summary'

TRACE_MODE_NAN_INF = 'nan-inf'
TRACE_MODE_NORM = 'norm'
TRACE_MODE_MAX_ABS = 'max-abs'
TRACE_MODE_SUMMARY = 'summary'
# Summary mode collects a finite set of signatures for each traced tensor
# (such as norm, max, min, mean) and dumps them using tb summaries.

# Full tensor mode dumps the whole tensor values for the traced tensors without
# any processing on them, using tb summaries.

# String values accepted by the submode flag.
_SUBMODE_BRIEF = 'brief'
_SUBMODE_DETAILED = 'detailed'

# Patterns for a single flag inside the flags string. In each pattern group 1
# is the flag name; group 2 (when present) is the flag value:
#   --name='value', --name="value", --name=value, and --name (no value).
# NOTE(review): the name group `[^=]+` excludes only '=', so it can also span
# whitespace; a value-less flag followed by another flag may therefore be
# captured as a single name -- confirm this is acceptable for callers.
_FLAG_SINGLE_QUOTE_PAT = re.compile(r"\s*--([^=]+)='([^']*)'")
_FLAG_DOUBLE_QUOTE_PAT = re.compile(r'\s*--([^=]+)="([^"]*)"')
_FLAG_NO_QUOTE_PAT = re.compile(r'\s*--([^=]+)=(\S*)')
_FLAG_NO_EQUAL_PAT = re.compile(r'\s*--([^=]+)\s*')

# Name of the environment variable that holds the flags string.
FLAGS_ENV_VAR = 'TENSOR_TRACER_FLAGS'
# Names of the individual flags.
FLAG_NAME_ENABLE = 'enable'
FLAG_NAME_TRACE_MODE = 'trace_mode'
FLAG_NAME_TRACE_SCALAR_OPS = 'trace_scalar'
FLAG_NAME_SUBMODE = 'submode'
FLAG_NAME_EXCLUDED_OPNAMES = 'excluded_opnames'
FLAG_NAME_EXCLUDED_OPTYPES = 'excluded_optypes'
FLAG_NAME_INCLUDED_OPNAMES = 'included_opnames'
FLAG_NAME_INCLUDED_OPTYPES = 'included_optypes'
FLAG_NAME_TRACE_LEVEL = 'trace_level'
FLAG_NAME_TRACE_DIR = 'trace_dir'
FLAG_NAME_REPORT_FILE = 'report_file'
FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR = 'use_test_undeclared_outputs_dir'
FLAG_NAME_OP_RANGE = 'op_range'
# Folder to dump the pre (before tensor tracer updates) and post graphs (after
# tensor tracer updates).
FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS = 'dump_graphs'
FLAG_NAME_SUMMARY_SIGNATURES = 'signatures'
FLAG_NAME_SUMMARY_PER_CORE = 'collect_summary_per_core'
FLAG_NAME_TEMP_CACHE_VAR = 'use_temp_cache'
FLAG_NAME_INSPECT_TRACE = 'inspect_trace'
FLAG_NAME_FINGERPRINT_DIR = 'use_fingerprint_subdirectory'
FLAG_FLUSH_SUMMARY = 'flush_summaries'

# Matches an op range given as "<first>:<last>" (two non-negative integers).
_OP_RANGE_PAT = re.compile(r'(\d+):(\d+)')
_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR'

_TT_DEFAULT_TRACE_LEVEL = 3
_TT_PREFIX = 'tensor_tracer'

# Short signature names; the summary signature identifiers below are these
# names prefixed with _TT_PREFIX.
_TT_NORM = 'norm'
_TT_MAX = 'max'
_TT_MAX_ABS = 'max-abs'
_TT_MIN = 'min'
_TT_MEAN = 'mean'
_TT_VAR = 'var'
_TT_SIZE = 'size'

TT_SUMMARY_NORM = '%s_%s' % (_TT_PREFIX, _TT_NORM)
TT_SUMMARY_MAX = '%s_%s' % (_TT_PREFIX, _TT_MAX)
TT_SUMMARY_MAX_ABS = '%s_%s' % (_TT_PREFIX, _TT_MAX_ABS)
TT_SUMMARY_MIN = '%s_%s' % (_TT_PREFIX, _TT_MIN)
TT_SUMMARY_MEAN = '%s_%s' % (_TT_PREFIX, _TT_MEAN)
TT_SUMMARY_VAR = '%s_%s' % (_TT_PREFIX, _TT_VAR)
TT_SUMMARY_SIZE = '%s_%s' % (_TT_PREFIX, _TT_SIZE)

TT_SUMMARY_SIGNATURES = (TT_SUMMARY_NORM, TT_SUMMARY_MAX, TT_SUMMARY_MIN,
                         TT_SUMMARY_MEAN, TT_SUMMARY_VAR, TT_SUMMARY_SIZE,
                         TT_SUMMARY_MAX_ABS)


class TTParameters(object):
  """A class that handles the parameters of Tensor Tracer.

  All parameters are parsed from the flags string stored under FLAGS_ENV_VAR
  in the given environment mapping. Flag names are validated and every
  parameter is resolved once, at construction time.
  """

  def __init__(self, env=None):
    """Parses and validates the Tensor Tracer flags.

    Args:
      env: mapping to read the flags (and the test-undeclared-outputs
        directory) from. A falsy value (None or an empty mapping) falls back
        to os.environ.

    Raises:
      ValueError: if a flag name, the trace mode or the submode is invalid, if
        incompatible flags are combined, or if a summary trace mode is used
        without a trace directory.
    """
    self._env = env if env else os.environ
    self._validate_flag_names()
    self.trace_mode = self._get_trace_mode()
    self.submode = self._get_submode()
    self.trace_dir = self._get_trace_dir()
    self.report_file_path = self._get_report_filepath()
    self.op_range = self._get_op_range()
    self.excluded_opname_re_list = self._flag_value_to_re_list(
        FLAG_NAME_EXCLUDED_OPNAMES)
    self.excluded_optype_re_list = self._flag_value_to_re_list(
        FLAG_NAME_EXCLUDED_OPTYPES)

    self.included_opname_re_list = self._flag_value_to_re_list(
        FLAG_NAME_INCLUDED_OPNAMES)
    self.included_optype_re_list = self._flag_value_to_re_list(
        FLAG_NAME_INCLUDED_OPTYPES)

    self.trace_scalar_ops = self.is_flag_on(FLAG_NAME_TRACE_SCALAR_OPS)
    self.use_compact_trace = self.trace_mode in (TRACE_MODE_NAN_INF,
                                                 TRACE_MODE_NORM,
                                                 TRACE_MODE_MAX_ABS,
                                                 TRACE_MODE_SUMMARY)
    self.use_temp_cache_var = self.is_flag_on(FLAG_NAME_TEMP_CACHE_VAR)
    self.inspect_trace = self.is_flag_on(FLAG_NAME_INSPECT_TRACE)
    self.use_fingerprint_subdir = self.is_flag_on(FLAG_NAME_FINGERPRINT_DIR)

    _, self.graph_dump_path = self.get_flag_value(
        FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS)
    self.trace_level = self._get_flag_int_value(FLAG_NAME_TRACE_LEVEL,
                                                _TT_DEFAULT_TRACE_LEVEL)
    self.summary_signatures = self._get_summary_signatures()
    self.collect_summary_per_core = self.is_flag_on(FLAG_NAME_SUMMARY_PER_CORE)
    self.flush_summaries_with_outside_compile = self.is_flag_on(
        FLAG_FLUSH_SUMMARY)
    self._check_flag_errors()

  def _check_flag_errors(self):
    """Raises ValueError if a summary trace mode is used without trace_dir."""
    if self.trace_mode in (TRACE_MODE_SUMMARY, TRACE_MODE_FULL_TENSOR_SUMMARY):
      if not self.trace_dir:
        raise ValueError('trace_dir must be explicitly provided in '
                         'TENSOR_TRACER_FLAGS when summary mode is used.')

  def _get_report_filepath(self):
    """Returns the path of the output report file.

    When the use_test_undeclared_outputs_dir flag is on, the report path must
    be relative and is joined onto the directory named by the
    TEST_UNDECLARED_OUTPUTS_DIR entry of the environment.

    Returns:
      The report file path, or None if the flag is not set.

    Raises:
      ValueError: if an absolute report path is combined with
        use_test_undeclared_outputs_dir.
    """
    found, report_file_path = self.get_flag_value(FLAG_NAME_REPORT_FILE)
    if (found and report_file_path
        and self.use_test_undeclared_outputs_dir()):
      if os.path.isabs(report_file_path):
        raise ValueError('If use_test_undeclared_outputs_dir is set, '
                         'report_file_path cannot be an absolute path (%s)'
                         % report_file_path)
      outputs_dir = self._env.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR)
      report_file_path = os.path.join(outputs_dir, report_file_path)
    return report_file_path

  def _get_op_range(self):
    """Returns the index range of the ops that will be considered for tracing.

    Returns:
      A pair of ints parsed from a "<first>:<last>" flag value, or (-1, -1),
      which means including all ops, when the flag is missing, empty or does
      not match that format.
    """
    all_ops = (-1, -1)  # this means including all ops.
    found, op_range = self.get_flag_value(FLAG_NAME_OP_RANGE)
    if not found or not op_range:
      return all_ops
    match = _OP_RANGE_PAT.match(op_range)
    if not match:
      return all_ops
    return (int(match.group(1)), int(match.group(2)))

  def _get_trace_dir(self):
    """Returns the directory for trace output.

    Returns:
      The TEST_UNDECLARED_OUTPUTS_DIR entry of the environment when the
      use_test_undeclared_outputs_dir flag is on; otherwise the value of the
      trace_dir flag (None if it is not set).

    Raises:
      ValueError: if both trace_dir and use_test_undeclared_outputs_dir are
        given.
    """
    found, trace_dir = self.get_flag_value(FLAG_NAME_TRACE_DIR)
    use_undeclared_outputs_dir = self.use_test_undeclared_outputs_dir()
    if found and trace_dir and use_undeclared_outputs_dir:
      raise ValueError(
          'Cannot use --%s and --%s at the same time' %
          (FLAG_NAME_TRACE_DIR, FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR))
    if use_undeclared_outputs_dir:
      trace_dir = self._env.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR)
    return trace_dir

  def _get_trace_mode(self):
    """Returns the trace mode, checking that it is valid.

    Returns:
      The value of the trace_mode flag, or TRACE_MODE_NORM if it is missing
      or empty.

    Raises:
      ValueError: if the given trace mode is not one of the supported modes.
    """
    found, trace_mode = self.get_flag_value(FLAG_NAME_TRACE_MODE)
    if not found or not trace_mode:
      trace_mode = TRACE_MODE_NORM
    valid_trace_modes = [
        TRACE_MODE_NAN_INF, TRACE_MODE_PART_TENSOR, TRACE_MODE_FULL_TENSOR,
        TRACE_MODE_NORM, TRACE_MODE_MAX_ABS,
        TRACE_MODE_SUMMARY, TRACE_MODE_FULL_TENSOR_SUMMARY
    ]
    if trace_mode not in valid_trace_modes:
      raise ValueError('Invalid trace mode "%s" given to the Tensor_Tracer. '
                       'Valid trace modes are: %s' % (trace_mode,
                                                      valid_trace_modes))
    return trace_mode

  def is_brief_mode(self):
    """Returns True if the submode is 'brief'."""
    return self.submode == _SUBMODE_BRIEF

  def _get_submode(self):
    """Returns the submode, checking that it is valid.

    Returns:
      The value of the submode flag, or _SUBMODE_DETAILED if it is missing or
      empty.

    Raises:
      ValueError: if the given submode is not one of the supported submodes.
    """
    found, submode = self.get_flag_value(FLAG_NAME_SUBMODE)
    if not found or not submode:
      submode = _SUBMODE_DETAILED
    valid_submodes = [_SUBMODE_DETAILED, _SUBMODE_BRIEF]
    if submode not in valid_submodes:
      raise ValueError('Invalid submode "%s" given to the Tensor_Tracer. '
                       'Valid submodes are: %s' % (submode,
                                                   valid_submodes))
    return submode

  @staticmethod
  def match_next_flag(flags, pos):
    """Returns the match for the next TensorTracer flag.

    Args:
      flags: a string that contains the flags.
      pos: where in flags to start the search.

    Returns:
      A pair where the first element is the regular-expression
      match found and the second element indicates if the match
      has a value.
    """
    # Patterns that carry a value are tried first, in this fixed order.
    for pattern in (_FLAG_DOUBLE_QUOTE_PAT, _FLAG_SINGLE_QUOTE_PAT,
                    _FLAG_NO_QUOTE_PAT):
      match = pattern.match(flags, pos)
      if match:
        return match, True
    match = _FLAG_NO_EQUAL_PAT.match(flags, pos)
    if match:
      # The flag is found but is not given a value.
      return match, False
    # The flag is not found.
    return None, False

  def _validate_flag_names(self):
    """Validates the names of the TensorTracer flags that were passed.

    Raises:
      ValueError: if a flag name found in the flags string is not supported.
    """
    valid_flag_names = [
        FLAG_NAME_ENABLE, FLAG_NAME_TRACE_MODE,
        FLAG_NAME_TRACE_SCALAR_OPS,
        FLAG_NAME_SUBMODE, FLAG_NAME_EXCLUDED_OPNAMES,
        FLAG_NAME_EXCLUDED_OPTYPES, FLAG_NAME_INCLUDED_OPNAMES,
        FLAG_NAME_INCLUDED_OPTYPES, FLAG_NAME_TRACE_DIR,
        FLAG_NAME_REPORT_FILE,
        FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR,
        FLAG_NAME_OP_RANGE,
        FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS, FLAG_NAME_TRACE_LEVEL,
        FLAG_NAME_SUMMARY_SIGNATURES, FLAG_NAME_SUMMARY_PER_CORE,
        FLAG_NAME_TEMP_CACHE_VAR, FLAG_NAME_FINGERPRINT_DIR,
        FLAG_NAME_INSPECT_TRACE, FLAG_FLUSH_SUMMARY
    ]
    tensor_tracer_flags = self._env.get(FLAGS_ENV_VAR)
    if not tensor_tracer_flags:
      return
    pos = 0
    while True:
      match, _ = TTParameters.match_next_flag(tensor_tracer_flags, pos)
      if not match:
        break
      flag_name = match.group(1)
      if flag_name not in valid_flag_names:
        raise ValueError(
            'The flag name "%s" passed via the environment variable "%s" '
            'is invalid. Valid flag names are:'
            '\n%s' % (flag_name, FLAGS_ENV_VAR, valid_flag_names))
      pos = match.end()

  def _get_summary_signatures(self):
    """Verifies and returns the summary signatures.

    Signatures may be given with or without the 'tensor_tracer_' prefix.
    Unknown signatures are logged and skipped.

    Returns:
      A dictionary of the signature identifiers {signature: index} that will be
      computed when trace_mode is summary.
    """
    signatures = self._flag_value_as_list(FLAG_NAME_SUMMARY_SIGNATURES)

    tt_signatures = []
    for signature in signatures:
      signature_with_prefix = '%s_%s' % (_TT_PREFIX, signature)
      if signature in TT_SUMMARY_SIGNATURES:
        tt_signatures.append(signature)
      elif signature_with_prefix in TT_SUMMARY_SIGNATURES:
        tt_signatures.append(signature_with_prefix)
      else:
        logging.warning('Unknown signature:%s. Supported signatures: %s',
                        signature, TT_SUMMARY_SIGNATURES)
    if not tt_signatures:
      # Default case collects max-abs and norm only.
      return {TT_SUMMARY_MAX_ABS: 0, TT_SUMMARY_NORM: 1}
    return {signature: idx for idx, signature in enumerate(tt_signatures)}

  def get_signature_to_agg_fn_map(self):
    """Returns a map that contains the aggregate function for each signature."""
    return {TRACE_MODE_NORM: linalg_ops.norm,
            TRACE_MODE_MAX_ABS: math_ops.reduce_max,
            TRACE_MODE_NAN_INF: math_ops.reduce_max,
            TT_SUMMARY_NORM: linalg_ops.norm,
            TT_SUMMARY_MAX: math_ops.reduce_max,
            TT_SUMMARY_MAX_ABS:
                lambda t, axis=0: math_ops.reduce_max(math_ops.abs(t),  # pylint: disable=g-long-lambda
                                                      axis=axis),
            TT_SUMMARY_MIN: math_ops.reduce_min,
            TT_SUMMARY_MEAN: math_ops.reduce_mean,
            TT_SUMMARY_VAR: math_ops.reduce_max,  # Simply reduce max variance.
            TT_SUMMARY_SIZE: math_ops.reduce_sum}

  def _flag_value_as_list(self, wanted_flag_name):
    """Returns the string list of a TensorTracer flag.

    Args:
      wanted_flag_name: the name of the flag we are looking for.

    Returns:
      The comma-separated values of the flag as a list of strings; an empty
      list if the flag is missing or has no value.
    """
    found, flag_value = self.get_flag_value(wanted_flag_name)
    # A flag given without a value yields None, which cannot be split.
    if not found or not flag_value:
      return []
    return flag_value.split(',')

  def _flag_value_as_int_list(self, wanted_flag_name):
    """Returns the integer list of a TensorTracer flag.

    Args:
      wanted_flag_name: the name of the flag we are looking for.

    Returns:
      The comma-separated values of the flag as a list of ints; an empty list
      if the flag is missing, has no value, or any element is not an integer
      (a warning is logged in the last case).
    """
    int_list = []
    found, flag_value = self.get_flag_value(wanted_flag_name)

    if found and flag_value:
      try:
        int_list = [int(int_val) for int_val in flag_value.split(',')]
      except ValueError:
        logging.warning('Cannot convert %s to int for flag %s', flag_value,
                        wanted_flag_name)
    return int_list

  def _get_flag_int_value(self, wanted_flag_name, default_value):
    """Returns the int value of a TensorTracer flag.

    Args:
      wanted_flag_name: the name of the flag we are looking for.
      default_value: the default value for the flag, if not provided.

    Returns:
      The value of the flag as an int, or default_value if the flag is
      missing, has no value, or its value is not an integer (a warning is
      logged in the last two cases).
    """
    found, flag_value = self.get_flag_value(wanted_flag_name)
    if not found:
      return default_value
    try:
      return int(flag_value)
    except (TypeError, ValueError):
      # TypeError covers a flag given without a value (flag_value is None).
      logging.warning('Cannot convert %s to int for flag %s', flag_value,
                      wanted_flag_name)
      return default_value

  def get_flag_value(self, wanted_flag_name):
    """Returns the value of a TensorTracer flag.

    Args:
      wanted_flag_name: the name of the flag we are looking for.

    Returns:
      A pair where the first element indicates if the flag is
      found and the second element is the value of the flag (None if the
      flag is not found or was given without a value).
    """
    tensor_tracer_flags = self._env.get(FLAGS_ENV_VAR)
    if not tensor_tracer_flags:
      return False, None
    pos = 0
    while True:
      match, has_value = TTParameters.match_next_flag(
          tensor_tracer_flags, pos)
      if not match:
        return False, None
      if match.group(1) == wanted_flag_name:
        return True, match.group(2) if has_value else None
      pos = match.end()

  def _flag_value_to_re_list(self, flag_name):
    """Compiles the comma-separated values of a flag into a list of REs."""
    found, flag_value = self.get_flag_value(flag_name)
    if not found or not flag_value:
      return []
    return [re.compile(v) for v in flag_value.split(',')]

  def is_flag_on(self, flag_name):
    """Returns True if the given flag is on.

    A flag is on when it is present without a value, or when its value
    (case-insensitively) is one of '1', 't', 'true', 'y', 'yes'.
    """
    found, flag_value = self.get_flag_value(flag_name)
    if not found:
      return False
    if flag_value is None:
      return True
    # Depends on the flag value.
    return flag_value.lower() in ('1', 't', 'true', 'y', 'yes')

  def is_enabled(self):
    """Returns True if TensorTracer is enabled."""
    if self.is_flag_on(FLAG_NAME_ENABLE):
      logging.info('Tensor Tracer is enabled with flags %s.',
                   self._env.get(FLAGS_ENV_VAR))
      return True
    return False

  def use_test_undeclared_outputs_dir(self):
    """Decides the output directory of the report and trace files.

    Returns:
      True if the output files should be written to the
      test-undeclared-outputs-directory defined via an
      env variable.
    """
    return self.is_flag_on(FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR)