1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Command parsing module for TensorFlow Debugger (tfdbg).""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import argparse 21import ast 22import re 23import sys 24 25 26_BRACKETS_PATTERN = re.compile(r"\[[^\]]*\]") 27_QUOTES_PATTERN = re.compile(r"(\"[^\"]*\"|\'[^\']*\')") 28_WHITESPACE_PATTERN = re.compile(r"\s+") 29 30_NUMBER_PATTERN = re.compile(r"[-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?") 31 32 33class Interval(object): 34 """Represents an interval between a start and end value.""" 35 36 def __init__(self, start, start_included, end, end_included): 37 self.start = start 38 self.start_included = start_included 39 self.end = end 40 self.end_included = end_included 41 42 def contains(self, value): 43 if value < self.start or value == self.start and not self.start_included: 44 return False 45 if value > self.end or value == self.end and not self.end_included: 46 return False 47 return True 48 49 def __eq__(self, other): 50 return (self.start == other.start and 51 self.start_included == other.start_included and 52 self.end == other.end and 53 self.end_included == other.end_included) 54 55 56def parse_command(command): 57 """Parse command string into a list of arguments. 58 59 - Disregards whitespace inside double quotes and brackets. 60 - Strips paired leading and trailing double quotes in arguments. 61 - Splits the command at whitespace. 62 63 Nested double quotes and brackets are not handled. 64 65 Args: 66 command: (str) Input command. 67 68 Returns: 69 (list of str) List of arguments. 70 """ 71 72 command = command.strip() 73 if not command: 74 return [] 75 76 brackets_intervals = [f.span() for f in _BRACKETS_PATTERN.finditer(command)] 77 quotes_intervals = [f.span() for f in _QUOTES_PATTERN.finditer(command)] 78 whitespaces_intervals = [ 79 f.span() for f in _WHITESPACE_PATTERN.finditer(command) 80 ] 81 82 if not whitespaces_intervals: 83 return [command] 84 85 arguments = [] 86 idx0 = 0 87 for start, end in whitespaces_intervals + [(len(command), None)]: 88 # Skip whitespace stretches enclosed in brackets or double quotes. 89 90 if not any(interval[0] < start < interval[1] 91 for interval in brackets_intervals + quotes_intervals): 92 argument = command[idx0:start] 93 94 # Strip leading and trailing double quote if they are paired. 95 if (argument.startswith("\"") and argument.endswith("\"") or 96 argument.startswith("'") and argument.endswith("'")): 97 argument = argument[1:-1] 98 arguments.append(argument) 99 idx0 = end 100 101 return arguments 102 103 104def extract_output_file_path(args): 105 """Extract output file path from command arguments. 106 107 Args: 108 args: (list of str) command arguments. 109 110 Returns: 111 (list of str) Command arguments with the output file path part stripped. 112 (str or None) Output file path (if any). 113 114 Raises: 115 SyntaxError: If there is no file path after the last ">" character. 116 """ 117 118 if args and args[-1].endswith(">"): 119 raise SyntaxError("Redirect file path is empty") 120 elif args and args[-1].startswith(">"): 121 try: 122 _parse_interval(args[-1]) 123 if len(args) > 1 and args[-2].startswith("-"): 124 output_file_path = None 125 else: 126 output_file_path = args[-1][1:] 127 args = args[:-1] 128 except ValueError: 129 output_file_path = args[-1][1:] 130 args = args[:-1] 131 elif len(args) > 1 and args[-2] == ">": 132 output_file_path = args[-1] 133 args = args[:-2] 134 elif args and args[-1].count(">") == 1: 135 gt_index = args[-1].index(">") 136 if gt_index > 0 and args[-1][gt_index - 1] == "=": 137 output_file_path = None 138 else: 139 output_file_path = args[-1][gt_index + 1:] 140 args[-1] = args[-1][:gt_index] 141 elif len(args) > 1 and args[-2].endswith(">"): 142 output_file_path = args[-1] 143 args = args[:-1] 144 args[-1] = args[-1][:-1] 145 else: 146 output_file_path = None 147 148 return args, output_file_path 149 150 151def parse_tensor_name_with_slicing(in_str): 152 """Parse tensor name, potentially suffixed by slicing string. 153 154 Args: 155 in_str: (str) Input name of the tensor, potentially followed by a slicing 156 string. E.g.: Without slicing string: "hidden/weights/Variable:0", with 157 slicing string: "hidden/weights/Variable:0[1, :]" 158 159 Returns: 160 (str) name of the tensor 161 (str) slicing string, if any. If no slicing string is present, return "". 162 """ 163 164 if in_str.count("[") == 1 and in_str.endswith("]"): 165 tensor_name = in_str[:in_str.index("[")] 166 tensor_slicing = in_str[in_str.index("["):] 167 else: 168 tensor_name = in_str 169 tensor_slicing = "" 170 171 return tensor_name, tensor_slicing 172 173 174def validate_slicing_string(slicing_string): 175 """Validate a slicing string. 176 177 Check if the input string contains only brackets, digits, commas and 178 colons that are valid characters in numpy-style array slicing. 179 180 Args: 181 slicing_string: (str) Input slicing string to be validated. 182 183 Returns: 184 (bool) True if and only if the slicing string is valid. 185 """ 186 187 return bool(re.search(r"^\[(\d|,|\s|:)+\]$", slicing_string)) 188 189 190def _parse_slices(slicing_string): 191 """Construct a tuple of slices from the slicing string. 192 193 The string must be a valid slicing string. 194 195 Args: 196 slicing_string: (str) Input slicing string to be parsed. 197 198 Returns: 199 tuple(slice1, slice2, ...) 200 201 Raises: 202 ValueError: If tensor_slicing is not a valid numpy ndarray slicing str. 203 """ 204 parsed = [] 205 for slice_string in slicing_string[1:-1].split(","): 206 indices = slice_string.split(":") 207 if len(indices) == 1: 208 parsed.append(int(indices[0].strip())) 209 elif 2 <= len(indices) <= 3: 210 parsed.append( 211 slice(*[ 212 int(index.strip()) if index.strip() else None for index in indices 213 ])) 214 else: 215 raise ValueError("Invalid tensor-slicing string.") 216 return tuple(parsed) 217 218 219def parse_indices(indices_string): 220 """Parse a string representing indices. 221 222 For example, if the input is "[1, 2, 3]", the return value will be a list of 223 indices: [1, 2, 3] 224 225 Args: 226 indices_string: (str) a string representing indices. Can optionally be 227 surrounded by a pair of brackets. 228 229 Returns: 230 (list of int): Parsed indices. 231 """ 232 233 # Strip whitespace. 234 indices_string = re.sub(r"\s+", "", indices_string) 235 236 # Strip any brackets at the two ends. 237 if indices_string.startswith("[") and indices_string.endswith("]"): 238 indices_string = indices_string[1:-1] 239 240 return [int(element) for element in indices_string.split(",")] 241 242 243def parse_ranges(range_string): 244 """Parse a string representing numerical range(s). 245 246 Args: 247 range_string: (str) A string representing a numerical range or a list of 248 them. For example: 249 "[-1.0,1.0]", "[-inf, 0]", "[[-inf, -1.0], [1.0, inf]]" 250 251 Returns: 252 (list of list of float) A list of numerical ranges parsed from the input 253 string. 254 255 Raises: 256 ValueError: If the input doesn't represent a range or a list of ranges. 257 """ 258 259 range_string = range_string.strip() 260 if not range_string: 261 return [] 262 263 if "inf" in range_string: 264 range_string = re.sub(r"inf", repr(sys.float_info.max), range_string) 265 266 ranges = ast.literal_eval(range_string) 267 if isinstance(ranges, list) and not isinstance(ranges[0], list): 268 ranges = [ranges] 269 270 # Verify that ranges is a list of list of numbers. 271 for item in ranges: 272 if len(item) != 2: 273 raise ValueError("Incorrect number of elements in range") 274 elif not isinstance(item[0], (int, float)): 275 raise ValueError("Incorrect type in the 1st element of range: %s" % 276 type(item[0])) 277 elif not isinstance(item[1], (int, float)): 278 raise ValueError("Incorrect type in the 2nd element of range: %s" % 279 type(item[0])) 280 281 return ranges 282 283 284def parse_memory_interval(interval_str): 285 """Convert a human-readable memory interval to a tuple of start and end value. 286 287 Args: 288 interval_str: (`str`) A human-readable str representing an interval 289 (e.g., "[10kB, 20kB]", "<100M", ">100G"). Only the units "kB", "MB", "GB" 290 are supported. The "B character at the end of the input `str` may be 291 omitted. 292 293 Returns: 294 `Interval` object where start and end are in bytes. 295 296 Raises: 297 ValueError: if the input is not valid. 298 """ 299 str_interval = _parse_interval(interval_str) 300 interval_start = 0 301 interval_end = float("inf") 302 if str_interval.start: 303 interval_start = parse_readable_size_str(str_interval.start) 304 if str_interval.end: 305 interval_end = parse_readable_size_str(str_interval.end) 306 if interval_start > interval_end: 307 raise ValueError( 308 "Invalid interval %s. Start of interval must be less than or equal " 309 "to end of interval." % interval_str) 310 return Interval(interval_start, str_interval.start_included, 311 interval_end, str_interval.end_included) 312 313 314def parse_time_interval(interval_str): 315 """Convert a human-readable time interval to a tuple of start and end value. 316 317 Args: 318 interval_str: (`str`) A human-readable str representing an interval 319 (e.g., "[10us, 20us]", "<100s", ">100ms"). Supported time suffixes are 320 us, ms, s. 321 322 Returns: 323 `Interval` object where start and end are in microseconds. 324 325 Raises: 326 ValueError: if the input is not valid. 327 """ 328 str_interval = _parse_interval(interval_str) 329 interval_start = 0 330 interval_end = float("inf") 331 if str_interval.start: 332 interval_start = parse_readable_time_str(str_interval.start) 333 if str_interval.end: 334 interval_end = parse_readable_time_str(str_interval.end) 335 if interval_start > interval_end: 336 raise ValueError( 337 "Invalid interval %s. Start must be before end of interval." % 338 interval_str) 339 return Interval(interval_start, str_interval.start_included, 340 interval_end, str_interval.end_included) 341 342 343def _parse_interval(interval_str): 344 """Convert a human-readable interval to a tuple of start and end value. 345 346 Args: 347 interval_str: (`str`) A human-readable str representing an interval 348 (e.g., "[1M, 2M]", "<100k", ">100ms"). The items following the ">", "<", 349 ">=" and "<=" signs have to start with a number (e.g., 3.0, -2, .98). 350 The same requirement applies to the items in the parentheses or brackets. 351 352 Returns: 353 Interval object where start or end can be None 354 if the range is specified as "<N" or ">N" respectively. 355 356 Raises: 357 ValueError: if the input is not valid. 358 """ 359 interval_str = interval_str.strip() 360 if interval_str.startswith("<="): 361 if _NUMBER_PATTERN.match(interval_str[2:].strip()): 362 return Interval(start=None, start_included=False, 363 end=interval_str[2:].strip(), end_included=True) 364 else: 365 raise ValueError("Invalid value string after <= in '%s'" % interval_str) 366 if interval_str.startswith("<"): 367 if _NUMBER_PATTERN.match(interval_str[1:].strip()): 368 return Interval(start=None, start_included=False, 369 end=interval_str[1:].strip(), end_included=False) 370 else: 371 raise ValueError("Invalid value string after < in '%s'" % interval_str) 372 if interval_str.startswith(">="): 373 if _NUMBER_PATTERN.match(interval_str[2:].strip()): 374 return Interval(start=interval_str[2:].strip(), start_included=True, 375 end=None, end_included=False) 376 else: 377 raise ValueError("Invalid value string after >= in '%s'" % interval_str) 378 if interval_str.startswith(">"): 379 if _NUMBER_PATTERN.match(interval_str[1:].strip()): 380 return Interval(start=interval_str[1:].strip(), start_included=False, 381 end=None, end_included=False) 382 else: 383 raise ValueError("Invalid value string after > in '%s'" % interval_str) 384 385 if (not interval_str.startswith(("[", "(")) 386 or not interval_str.endswith(("]", ")"))): 387 raise ValueError( 388 "Invalid interval format: %s. Valid formats are: [min, max], " 389 "(min, max), <max, >min" % interval_str) 390 interval = interval_str[1:-1].split(",") 391 if len(interval) != 2: 392 raise ValueError( 393 "Incorrect interval format: %s. Interval should specify two values: " 394 "[min, max] or (min, max)." % interval_str) 395 396 start_item = interval[0].strip() 397 if not _NUMBER_PATTERN.match(start_item): 398 raise ValueError("Invalid first item in interval: '%s'" % start_item) 399 end_item = interval[1].strip() 400 if not _NUMBER_PATTERN.match(end_item): 401 raise ValueError("Invalid second item in interval: '%s'" % end_item) 402 403 return Interval(start=start_item, 404 start_included=(interval_str[0] == "["), 405 end=end_item, 406 end_included=(interval_str[-1] == "]")) 407 408 409def parse_readable_size_str(size_str): 410 """Convert a human-readable str representation to number of bytes. 411 412 Only the units "kB", "MB", "GB" are supported. The "B character at the end 413 of the input `str` may be omitted. 414 415 Args: 416 size_str: (`str`) A human-readable str representing a number of bytes 417 (e.g., "0", "1023", "1.1kB", "24 MB", "23GB", "100 G". 418 419 Returns: 420 (`int`) The parsed number of bytes. 421 422 Raises: 423 ValueError: on failure to parse the input `size_str`. 424 """ 425 426 size_str = size_str.strip() 427 if size_str.endswith("B"): 428 size_str = size_str[:-1] 429 430 if size_str.isdigit(): 431 return int(size_str) 432 elif size_str.endswith("k"): 433 return int(float(size_str[:-1]) * 1024) 434 elif size_str.endswith("M"): 435 return int(float(size_str[:-1]) * 1048576) 436 elif size_str.endswith("G"): 437 return int(float(size_str[:-1]) * 1073741824) 438 else: 439 raise ValueError("Failed to parsed human-readable byte size str: \"%s\"" % 440 size_str) 441 442 443def parse_readable_time_str(time_str): 444 """Parses a time string in the format N, Nus, Nms, Ns. 445 446 Args: 447 time_str: (`str`) string consisting of an integer time value optionally 448 followed by 'us', 'ms', or 's' suffix. If suffix is not specified, 449 value is assumed to be in microseconds. (e.g. 100us, 8ms, 5s, 100). 450 451 Returns: 452 Microseconds value. 453 """ 454 def parse_positive_float(value_str): 455 value = float(value_str) 456 if value < 0: 457 raise ValueError( 458 "Invalid time %s. Time value must be positive." % value_str) 459 return value 460 461 time_str = time_str.strip() 462 if time_str.endswith("us"): 463 return int(parse_positive_float(time_str[:-2])) 464 elif time_str.endswith("ms"): 465 return int(parse_positive_float(time_str[:-2]) * 1e3) 466 elif time_str.endswith("s"): 467 return int(parse_positive_float(time_str[:-1]) * 1e6) 468 return int(parse_positive_float(time_str)) 469 470 471def evaluate_tensor_slice(tensor, tensor_slicing): 472 """Call eval on the slicing of a tensor, with validation. 473 474 Args: 475 tensor: (numpy ndarray) The tensor value. 476 tensor_slicing: (str or None) Slicing of the tensor, e.g., "[:, 1]". If 477 None, no slicing will be performed on the tensor. 478 479 Returns: 480 (numpy ndarray) The sliced tensor. 481 482 Raises: 483 ValueError: If tensor_slicing is not a valid numpy ndarray slicing str. 484 """ 485 486 _ = tensor 487 488 if not validate_slicing_string(tensor_slicing): 489 raise ValueError("Invalid tensor-slicing string.") 490 491 return tensor[_parse_slices(tensor_slicing)] 492 493 494def get_print_tensor_argparser(description): 495 """Get an ArgumentParser for a command that prints tensor values. 496 497 Examples of such commands include print_tensor and print_feed. 498 499 Args: 500 description: Description of the ArgumentParser. 501 502 Returns: 503 An instance of argparse.ArgumentParser. 504 """ 505 506 ap = argparse.ArgumentParser( 507 description=description, usage=argparse.SUPPRESS) 508 ap.add_argument( 509 "tensor_name", 510 type=str, 511 help="Name of the tensor, followed by any slicing indices, " 512 "e.g., hidden1/Wx_plus_b/MatMul:0, " 513 "hidden1/Wx_plus_b/MatMul:0[1, :]") 514 ap.add_argument( 515 "-n", 516 "--number", 517 dest="number", 518 type=int, 519 default=-1, 520 help="0-based dump number for the specified tensor. " 521 "Required for tensor with multiple dumps.") 522 ap.add_argument( 523 "-r", 524 "--ranges", 525 dest="ranges", 526 type=str, 527 default="", 528 help="Numerical ranges to highlight tensor elements in. " 529 "Examples: -r 0,1e-8, -r [-0.1,0.1], " 530 "-r \"[[-inf, -0.1], [0.1, inf]]\"") 531 ap.add_argument( 532 "-a", 533 "--all", 534 dest="print_all", 535 action="store_true", 536 help="Print the tensor in its entirety, i.e., do not use ellipses.") 537 ap.add_argument( 538 "-s", 539 "--numeric_summary", 540 action="store_true", 541 help="Include summary for non-empty tensors of numeric (int*, float*, " 542 "complex*) and Boolean types.") 543 ap.add_argument( 544 "-w", 545 "--write_path", 546 type=str, 547 default="", 548 help="Path of the numpy file to write the tensor data to, using " 549 "numpy.save().") 550 return ap 551