• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Command parsing module for TensorFlow Debugger (tfdbg)."""
16import argparse
17import ast
18import re
19import sys
20
21
# Matches a non-nested bracketed span, e.g. "[1, 2]".
_BRACKETS_PATTERN = re.compile(r"\[[^\]]*\]")
# Matches a span enclosed in a pair of double quotes or a pair of single quotes.
_QUOTES_PATTERN = re.compile(r"(\"[^\"]*\"|\'[^\']*\')")
# Matches a run of one or more whitespace characters.
_WHITESPACE_PATTERN = re.compile(r"\s+")

# Matches an optionally signed decimal number with optional fraction and
# exponent, e.g. "3.0", "-2", ".98", "1e-8".
_NUMBER_PATTERN = re.compile(r"[-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?")
27
28
class Interval(object):
  """Represents an interval between a start and end value.

  Attributes:
    start: Lower bound of the interval.
    start_included: (bool) Whether `start` itself belongs to the interval.
    end: Upper bound of the interval.
    end_included: (bool) Whether `end` itself belongs to the interval.
  """

  def __init__(self, start, start_included, end, end_included):
    self.start = start
    self.start_included = start_included
    self.end = end
    self.end_included = end_included

  def contains(self, value):
    """Check whether a value lies inside the interval.

    Args:
      value: The value to test. It must be comparable with `start` and `end`.

    Returns:
      (bool) True if and only if `value` is within the bounds, honoring
        `start_included` and `end_included` at the two end points.
    """
    if value < self.start or (value == self.start and
                              not self.start_included):
      return False
    if value > self.end or (value == self.end and not self.end_included):
      return False
    return True

  def __eq__(self, other):
    """Compare the bounds and inclusion flags of two intervals.

    Returns `NotImplemented` when `other` lacks the interval attributes, so
    that comparing against an unrelated object (e.g. `None` or a `str`)
    evaluates to unequal rather than raising `AttributeError`.
    """
    try:
      return (self.start == other.start and
              self.start_included == other.start_included and
              self.end == other.end and
              self.end_included == other.end_included)
    except AttributeError:
      return NotImplemented

  def __repr__(self):
    return "Interval(start=%r, start_included=%r, end=%r, end_included=%r)" % (
        self.start, self.start_included, self.end, self.end_included)
50
51
def parse_command(command):
  """Parse command string into a list of arguments.

  - Disregards whitespace inside quotes (double or single) and brackets.
  - Strips paired leading and trailing quotes (double or single) in arguments.
  - Splits the command at whitespace.

  Nested quotes and brackets are not handled.

  Args:
    command: (str) Input command.

  Returns:
    (list of str) List of arguments.
  """

  command = command.strip()
  if not command:
    return []

  # Spans in which whitespace must not be treated as an argument separator.
  protected_spans = [m.span() for m in _BRACKETS_PATTERN.finditer(command)]
  protected_spans.extend(m.span() for m in _QUOTES_PATTERN.finditer(command))

  whitespace_spans = [m.span() for m in _WHITESPACE_PATTERN.finditer(command)]

  arguments = []
  arg_start = 0
  # The trailing sentinel span flushes the last argument. It also handles a
  # command without any whitespace, so that a single quoted token gets its
  # quotes stripped just like any other argument.
  for ws_start, ws_end in whitespace_spans + [(len(command), None)]:
    # Skip whitespace stretches enclosed in brackets or quotes.
    if any(lo < ws_start < hi for lo, hi in protected_spans):
      continue

    argument = command[arg_start:ws_start]

    # Strip the leading and trailing quote only if they form a pair. The
    # length check keeps a lone quote character from being emptied.
    if (len(argument) >= 2 and argument[0] in ("\"", "'") and
        argument[-1] == argument[0]):
      argument = argument[1:-1]
    arguments.append(argument)
    arg_start = ws_end

  return arguments
98
99
def extract_output_file_path(args):
  """Extract output file path from command arguments.

  Recognized redirect forms at the end of the arguments are:
  `... > path`, `... >path`, `... arg>path` and `... arg> path`.
  The input list is never modified; a new list is returned whenever the
  redirect part is stripped.

  Args:
    args: (list of str) command arguments.

  Returns:
    (list of str) Command arguments with the output file path part stripped.
    (str or None) Output file path (if any).

  Raises:
    SyntaxError: If there is no file path after the last ">" character.
  """

  if not args:
    return args, None

  last = args[-1]
  if last.endswith(">"):
    raise SyntaxError("Redirect file path is empty")

  if last.startswith(">"):
    # A value such as ">100" that directly follows a flag and parses as an
    # interval is the flag's value, not a redirect.
    if len(args) > 1 and args[-2].startswith("-"):
      try:
        _parse_interval(last)
      except ValueError:
        pass  # Not an interval: treat it as a redirect below.
      else:
        return args, None
    return args[:-1], last[1:]

  if len(args) > 1 and args[-2] == ">":
    return args[:-2], last

  if last.count(">") == 1:
    gt_index = last.index(">")
    # "=>" is not a redirect.
    if gt_index > 0 and last[gt_index - 1] == "=":
      return args, None
    # Build a new list instead of assigning into `args`, so the caller's list
    # is left untouched.
    return args[:-1] + [last[:gt_index]], last[gt_index + 1:]

  if len(args) > 1 and args[-2].endswith(">"):
    return args[:-2] + [args[-2][:-1]], last

  return args, None
145
146
def parse_tensor_name_with_slicing(in_str):
  """Split a tensor name from an optional trailing slicing string.

  A slicing suffix is recognized only when the input contains exactly one "["
  and ends with "]".

  Args:
    in_str: (str) Tensor name, optionally followed by a slicing string, e.g.
      "hidden/weights/Variable:0" or "hidden/weights/Variable:0[1, :]".

  Returns:
    (str) name of the tensor
    (str) slicing string including its brackets, or "" if there is none.
  """

  has_slicing = in_str.endswith("]") and in_str.count("[") == 1
  if not has_slicing:
    return in_str, ""

  bracket_pos = in_str.find("[")
  return in_str[:bracket_pos], in_str[bracket_pos:]
168
169
def validate_slicing_string(slicing_string):
  """Check that a slicing string uses only the permitted characters.

  The string has to be a pair of brackets enclosing at least one character,
  each of which is a digit, a comma, a colon or whitespace.

  Args:
    slicing_string: (str) Input slicing string to be validated.

  Returns:
    (bool) True if and only if the slicing string is valid.
  """

  matched = re.match(r"^\[(\d|,|\s|:)+\]$", slicing_string)
  return matched is not None
184
185
def _parse_slices(slicing_string):
  """Turn a bracketed slicing string into a tuple of indices and slices.

  The string must be a valid slicing string. Each comma-separated item becomes
  an `int` if it has no colon, or a `slice` if it has one or two colons; empty
  slice fields become None.

  Args:
    slicing_string: (str) Input slicing string to be parsed.

  Returns:
    tuple(slice1, slice2, ...)

  Raises:
    ValueError: If tensor_slicing is not a valid numpy ndarray slicing str.
  """

  def to_bound(text):
    text = text.strip()
    return int(text) if text else None

  result = []
  for dim_spec in slicing_string[1:-1].split(","):
    fields = dim_spec.split(":")
    if len(fields) > 3:
      raise ValueError("Invalid tensor-slicing string.")
    if len(fields) == 1:
      # A plain index; an empty field makes int() raise ValueError.
      result.append(int(fields[0].strip()))
    else:
      result.append(slice(*[to_bound(field) for field in fields]))
  return tuple(result)
213
214
def parse_indices(indices_string):
  """Convert a comma-separated string of integers into a list of ints.

  For example, "[1, 2, 3]" is parsed into [1, 2, 3]. All whitespace is
  ignored.

  Args:
    indices_string: (str) a string representing indices. Can optionally be
      surrounded by a pair of brackets.

  Returns:
    (list of int): Parsed indices.
  """

  compact = re.sub(r"\s+", "", indices_string)

  # Drop one enclosing pair of brackets, if present.
  if compact[:1] == "[" and compact[-1:] == "]":
    compact = compact[1:-1]

  return list(map(int, compact.split(",")))
237
238
def parse_ranges(range_string):
  """Parse a string representing numerical range(s).

  Args:
    range_string: (str) A string representing a numerical range or a list of
      them. For example:
        "[-1.0,1.0]", "[-inf, 0]", "[[-inf, -1.0], [1.0, inf]]"
      A bare pair without brackets, e.g. "0,1e-8", is also accepted.

  Returns:
    (list of list of float) A list of numerical ranges parsed from the input
      string. "inf" is represented by the largest finite float. An empty
      input string or an empty list yields [].

  Raises:
    ValueError: If the input doesn't represent a range or a list of ranges.
      Malformed input may also make `ast.literal_eval` raise (e.g.
      SyntaxError); such exceptions are propagated unchanged.
  """

  range_string = range_string.strip()
  if not range_string:
    return []

  # ast.literal_eval() cannot evaluate "inf", so substitute the largest
  # finite float for it.
  range_string = range_string.replace("inf", repr(sys.float_info.max))

  ranges = ast.literal_eval(range_string)
  # A bracket-less input such as "0,1e-8" evaluates to a tuple.
  if isinstance(ranges, tuple):
    ranges = list(ranges)
  if not isinstance(ranges, list):
    raise ValueError(
        "Input does not represent a range or a list of ranges: %s" %
        type(ranges))
  if not ranges:
    return []
  # A single range is wrapped so that the result is always a list of ranges.
  if not isinstance(ranges[0], list):
    ranges = [ranges]

  # Verify that ranges is a list of list of numbers.
  for item in ranges:
    if not isinstance(item, (list, tuple)):
      raise ValueError("Incorrect type of range: %s" % type(item))
    elif len(item) != 2:
      raise ValueError("Incorrect number of elements in range")
    elif not isinstance(item[0], (int, float)):
      raise ValueError("Incorrect type in the 1st element of range: %s" %
                       type(item[0]))
    elif not isinstance(item[1], (int, float)):
      raise ValueError("Incorrect type in the 2nd element of range: %s" %
                       type(item[1]))

  return ranges
278
279
def parse_memory_interval(interval_str):
  """Convert a human-readable memory interval to an `Interval` in bytes.

  A missing lower bound defaults to 0 and a missing upper bound defaults to
  infinity.

  Args:
    interval_str: (`str`) A human-readable str representing an interval
      (e.g., "[10kB, 20kB]", "<100M", ">100G"). Only the units "kB", "MB", "GB"
      are supported. The "B" character at the end of each value may be
      omitted.

  Returns:
    `Interval` object where start and end are in bytes.

  Raises:
    ValueError: if the input is not valid.
  """
  raw = _parse_interval(interval_str)
  lower = parse_readable_size_str(raw.start) if raw.start else 0
  upper = parse_readable_size_str(raw.end) if raw.end else float("inf")

  if lower > upper:
    raise ValueError(
        "Invalid interval %s. Start of interval must be less than or equal "
        "to end of interval." % interval_str)

  return Interval(lower, raw.start_included, upper, raw.end_included)
308
309
def parse_time_interval(interval_str):
  """Convert a human-readable time interval to an `Interval` in microseconds.

  A missing lower bound defaults to 0 and a missing upper bound defaults to
  infinity.

  Args:
    interval_str: (`str`) A human-readable str representing an interval
      (e.g., "[10us, 20us]", "<100s", ">100ms"). Supported time suffixes are
      us, ms, s.

  Returns:
    `Interval` object where start and end are in microseconds.

  Raises:
    ValueError: if the input is not valid.
  """
  raw = _parse_interval(interval_str)
  lower = parse_readable_time_str(raw.start) if raw.start else 0
  upper = parse_readable_time_str(raw.end) if raw.end else float("inf")

  if lower > upper:
    raise ValueError(
        "Invalid interval %s. Start must be before end of interval." %
        interval_str)

  return Interval(lower, raw.start_included, upper, raw.end_included)
337
338
def _parse_interval(interval_str):
  """Convert a human-readable interval to an `Interval` of value strings.

  Args:
    interval_str: (`str`) A human-readable str representing an interval
      (e.g., "[1M, 2M]", "<100k", ">100ms"). The items following the ">", "<",
      ">=" and "<=" signs have to start with a number (e.g., 3.0, -2, .98).
      The same requirement applies to the items in the parentheses or brackets.

  Returns:
    Interval object whose start and end are the (stripped) value strings.
    start is None for "<N" / "<=N" and end is None for ">N" / ">=N".

  Raises:
    ValueError: if the input is not valid.
  """
  interval_str = interval_str.strip()

  # (prefix, bounds the end?, bound included?, error message). The
  # two-character prefixes are listed first so that "<=" is not taken for "<".
  one_sided_forms = (
      ("<=", True, True, "Invalid value string after <= in '%s'"),
      ("<", True, False, "Invalid value string after < in '%s'"),
      (">=", False, True, "Invalid value string after >= in '%s'"),
      (">", False, False, "Invalid value string after > in '%s'"),
  )
  for prefix, is_upper_bound, included, error_template in one_sided_forms:
    if not interval_str.startswith(prefix):
      continue
    bound = interval_str[len(prefix):].strip()
    if not _NUMBER_PATTERN.match(bound):
      raise ValueError(error_template % interval_str)
    if is_upper_bound:
      return Interval(start=None, start_included=False,
                      end=bound, end_included=included)
    return Interval(start=bound, start_included=included,
                    end=None, end_included=False)

  # Two-sided form: "[min, max]", "(min, max)" or a mix of both delimiters.
  opens_properly = interval_str.startswith(("[", "("))
  closes_properly = interval_str.endswith(("]", ")"))
  if not (opens_properly and closes_properly):
    raise ValueError(
        "Invalid interval format: %s. Valid formats are: [min, max], "
        "(min, max), <max, >min" % interval_str)

  items = interval_str[1:-1].split(",")
  if len(items) != 2:
    raise ValueError(
        "Incorrect interval format: %s. Interval should specify two values: "
        "[min, max] or (min, max)." % interval_str)

  start_item, end_item = [item.strip() for item in items]
  if not _NUMBER_PATTERN.match(start_item):
    raise ValueError("Invalid first item in interval: '%s'" % start_item)
  if not _NUMBER_PATTERN.match(end_item):
    raise ValueError("Invalid second item in interval: '%s'" % end_item)

  return Interval(start=start_item,
                  start_included=interval_str.startswith("["),
                  end=end_item,
                  end_included=interval_str.endswith("]"))
403
404
def parse_readable_size_str(size_str):
  """Convert a human-readable size string to a number of bytes.

  Only the units "kB", "MB", "GB" are supported, each a power of 1024. The
  "B" character at the end of the input `str` may be omitted. A value without
  a unit has to consist of digits only.

  Args:
    size_str: (`str`) A human-readable str representing a number of bytes
      (e.g., "0", "1023", "1.1kB", "24 MB", "23GB", "100 G").

  Returns:
    (`int`) The parsed number of bytes.

  Raises:
    ValueError: on failure to parse the input `size_str`.
  """

  text = size_str.strip()
  if text.endswith("B"):
    text = text[:-1]

  if text.isdigit():
    return int(text)

  bytes_per_unit = {"k": 1024, "M": 1024 ** 2, "G": 1024 ** 3}
  unit = text[-1:]
  if unit not in bytes_per_unit:
    raise ValueError("Failed to parsed human-readable byte size str: \"%s\"" %
                     text)
  return int(float(text[:-1]) * bytes_per_unit[unit])
437
438
def parse_readable_time_str(time_str):
  """Parses a time string in the format N, Nus, Nms, Ns.

  Args:
    time_str: (`str`) string consisting of a non-negative number optionally
      followed by 'us', 'ms', or 's' suffix. If suffix is not specified,
      value is assumed to be in microseconds. (e.g. 100us, 8ms, 5s, 100).

  Returns:
    (`int`) Microseconds value, truncated towards zero.

  Raises:
    ValueError: If the numeric part cannot be parsed or is negative.
  """
  def to_non_negative(number_text):
    number = float(number_text)
    if number < 0:
      raise ValueError(
          "Invalid time %s. Time value must be positive." % number_text)
    return number

  text = time_str.strip()
  # "us" and "ms" have to be tried before the bare "s" suffix.
  for suffix, micros_per_unit in (("us", 1), ("ms", 1e3), ("s", 1e6)):
    if text.endswith(suffix):
      amount = to_non_negative(text[:-len(suffix)])
      return int(amount * micros_per_unit)
  return int(to_non_negative(text))
465
466
def evaluate_tensor_slice(tensor, tensor_slicing):
  """Apply a validated slicing string to a tensor.

  Args:
    tensor: (numpy ndarray) The tensor value.
    tensor_slicing: (str or None) Slicing of the tensor, e.g., "[:, 1]". If
      None, no slicing will be performed on the tensor.

  Returns:
    (numpy ndarray) The sliced tensor, or `tensor` itself if `tensor_slicing`
      is None.

  Raises:
    ValueError: If tensor_slicing is not a valid numpy ndarray slicing str.
  """

  # None means "no slicing" per the contract above; without this guard the
  # validation below would fail on a non-string input.
  if tensor_slicing is None:
    return tensor

  if not validate_slicing_string(tensor_slicing):
    raise ValueError("Invalid tensor-slicing string.")

  return tensor[_parse_slices(tensor_slicing)]
488
489
def get_print_tensor_argparser(description):
  """Build the ArgumentParser shared by commands that print tensor values.

  Examples of such commands include print_tensor and print_feed.

  Args:
    description: Description of the ArgumentParser.

  Returns:
    An instance of argparse.ArgumentParser.
  """

  # Each entry is (names or flags, keyword arguments for add_argument).
  argument_specs = [
      (("tensor_name",),
       dict(type=str,
            help="Name of the tensor, followed by any slicing indices, "
            "e.g., hidden1/Wx_plus_b/MatMul:0, "
            "hidden1/Wx_plus_b/MatMul:0[1, :]")),
      (("-n", "--number"),
       dict(dest="number",
            type=int,
            default=-1,
            help="0-based dump number for the specified tensor. "
            "Required for tensor with multiple dumps.")),
      (("-r", "--ranges"),
       dict(dest="ranges",
            type=str,
            default="",
            help="Numerical ranges to highlight tensor elements in. "
            "Examples: -r 0,1e-8, -r [-0.1,0.1], "
            "-r \"[[-inf, -0.1], [0.1, inf]]\"")),
      (("-a", "--all"),
       dict(dest="print_all",
            action="store_true",
            help="Print the tensor in its entirety, i.e., do not use "
            "ellipses.")),
      (("-s", "--numeric_summary"),
       dict(action="store_true",
            help="Include summary for non-empty tensors of numeric (int*, "
            "float*, complex*) and Boolean types.")),
      (("-w", "--write_path"),
       dict(type=str,
            default="",
            help="Path of the numpy file to write the tensor data to, using "
            "numpy.save().")),
  ]

  parser = argparse.ArgumentParser(
      usage=argparse.SUPPRESS, description=description)
  for names, options in argument_specs:
    parser.add_argument(*names, **options)
  return parser
547