# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Function for interpolating formatted errors from the TensorFlow runtime.

Exposes the function `interpolate` to interpolate messages with tags of the form
{{type name}}.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import itertools
import os
import re
import site
import traceback

import six

from tensorflow.core.protobuf import graph_debug_info_pb2

_NAME_REGEX = r"[A-Za-z0-9_.][A-Za-z0-9_.\-/]*?"
_TAG_REGEX = r"{{{{({name}) ({name})}}}}".format(name=_NAME_REGEX)
_INTERPOLATION_REGEX = r"^(.*?)({tag})".format(tag=_TAG_REGEX)
_INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX, re.DOTALL)
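# Illustrative example (comment only, not executed): matching the message
# "pre {{node Foo/Bar}} post" against _INTERPOLATION_PATTERN yields
# group(1) == "pre ", group(2) == "{{node Foo/Bar}}", group(3) == "node"
# and group(4) == "Foo/Bar".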

_ParseTag = collections.namedtuple("_ParseTag", ["type", "name"])


# Remove the last three path components from this module's file (i.e.
# python/framework/error_interpolation.py) so that we have an absolute path
# prefix to the root of the installation.
_FRAMEWORK_COMMON_PREFIX = os.path.dirname(
    os.path.dirname(os.path.dirname(__file__)))

# Sub-directories under the common prefix that are considered part of the
# framework.
# Note that Keras code lives outside the tensorflow directory, so we need to
# walk up the directory tree to find it.
_FRAMEWORK_PATH_PREFIXES = [
    os.path.join(_FRAMEWORK_COMMON_PREFIX, "python") + os.sep,
    os.path.join(_FRAMEWORK_COMMON_PREFIX, "contrib") + os.sep,
    os.path.join(os.path.dirname(_FRAMEWORK_COMMON_PREFIX),
                 "py", "keras") + os.sep,
]

# Patterns of filenames that should be considered internal to
# the TensorFlow framework.
_FRAMEWORK_FILENAME_PATTERNS = [
    re.compile(r"<embedded"),
]

# This is for OSS Keras: the package is loaded from the local Python
# environment, but we don't know exactly where it is installed, so we match
# on the keyword "keras".
try:
  _FRAMEWORK_PATH_PREFIXES.extend([
      os.path.join(package_path, "keras") + os.sep
      for package_path in site.getsitepackages() + [site.getusersitepackages()]
  ])
except AttributeError:
  # If site.getsitepackages is not available for some reason, fall back to
  # matching on the keyword "keras".
  _FRAMEWORK_FILENAME_PATTERNS.append(re.compile(r"keras"))

# Patterns of filenames that should be considered external to
# TensorFlow regardless of framework prefix match.
_EXTERNAL_FILENAME_PATTERNS = [
    # Explicitly treat test frames as not part of the framework.
    re.compile(r"_test\.py$"),
]


def parse_message(message):
  """Parses the message.

  Splits the message into separators and tags. Tags are named tuples
  representing the string {{type name}} and they are separated by
  separators. For example, in "123{{node Foo}}456{{node Bar}}789", there are
  two tags and three separators. The separators are the numeric characters.

  Args:
    message: String to parse

  Returns:
    (list of separator strings, list of _ParseTags).

    For example, if message is "123{{node Foo}}456" then this function
    returns (["123", "456"], [_ParseTag("node", "Foo")])
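
  Example (doctest-style, illustrative):
    >>> parse_message("123{{node Foo}}456")
    (['123', '456'], [_ParseTag(type='node', name='Foo')])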
  """
  seps = []
  tags = []
  pos = 0
  while pos < len(message):
    match = re.match(_INTERPOLATION_PATTERN, message[pos:])
    if match:
      seps.append(match.group(1))
      tags.append(_ParseTag(match.group(3), match.group(4)))
      pos += match.end()
    else:
      break
  seps.append(message[pos:])
  return seps, tags


def _compute_device_summary_from_list(name, device_assignment_list, prefix=""):
  """Return a summary of an op's device function stack.

  Args:
    name: The name of the op.
    device_assignment_list: The op._device_assignments list.
    prefix:  An optional string prefix used before each line of the multi-
        line string returned by this function.

  Returns:
    A multi-line string similar to:
        Device assignments active during op 'foo' creation:
          with tf.device(/cpu:0): <test_1.py:27>
          with tf.device(some_func<foo.py, 123>): <test_2.py:38>
    The first line will have no padding to its left by default.  Subsequent
    lines will have two spaces of left-padding.  Use the prefix argument
    to increase indentation.
  """
  if not device_assignment_list:
    message = "No device assignments were active during op '%s' creation."
    message %= name
    return prefix + message

  str_list = []
  str_list.append(
      "%sDevice assignments active during op '%s' creation:" % (prefix, name))

  for traceable_obj in device_assignment_list:
    location_summary = "<{file}:{line}>".format(
        file=traceable_obj.filename, line=traceable_obj.lineno)
    subs = {
        "prefix": prefix,
        "indent": "  ",
        "dev_name": traceable_obj.obj,
        "loc": location_summary,
    }
    str_list.append(
        "{prefix}{indent}with tf.device({dev_name}): {loc}".format(**subs))

  return "\n".join(str_list)


def _compute_device_assignment_summary_from_op(op, prefix=""):
  # pylint: disable=protected-access
  return _compute_device_summary_from_list(op.name, op._device_assignments,
                                           prefix)
  # pylint: enable=protected-access


def _compute_colocation_summary_from_dict(name, colocation_dict, prefix=""):
  """Return a summary of an op's colocation stack.

  Args:
    name: The op name.
    colocation_dict: The op._colocation_dict.
    prefix:  An optional string prefix used before each line of the multi-
        line string returned by this function.

  Returns:
    A multi-line string similar to:
        Node-device colocations active during op creation:
          with tf.compat.v1.colocate_with(test_node_1): <test_1.py:27>
          with tf.compat.v1.colocate_with(test_node_2): <test_2.py:38>
    The first line will have no padding to its left by default.  Subsequent
    lines will have two spaces of left-padding.  Use the prefix argument
    to increase indentation.
  """
  if not colocation_dict:
    message = "No node-device colocations were active during op '%s' creation."
    message %= name
    return prefix + message

  str_list = []
  str_list.append("%sNode-device colocations active during op '%s' creation:" %
                  (prefix, name))

  for coloc_name, location in colocation_dict.items():
    location_summary = "<{file}:{line}>".format(
        file=location.filename, line=location.lineno)
    subs = {
        "prefix": prefix,
        "indent": "  ",
        "name": coloc_name,
        "loc": location_summary,
    }
    str_list.append(
        "{prefix}{indent}with tf.colocate_with({name}): {loc}".format(**subs))

  return "\n".join(str_list)


def _compute_colocation_summary_from_op(op, prefix=""):
  """Fetch colocation file, line, and nesting and return a summary string."""
  # pylint: disable=protected-access
  return _compute_colocation_summary_from_dict(op.name, op._colocation_dict,
                                               prefix)
  # pylint: enable=protected-access


def _is_framework_filename(filename):
  """Returns whether a filename should be considered a part of the framework.

  A file is part of the framework if it does not match a pattern in
  _EXTERNAL_FILENAME_PATTERNS and it either matches a pattern in
  _FRAMEWORK_FILENAME_PATTERNS or starts with a _FRAMEWORK_PATH_PREFIXES prefix.
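
  For example (illustrative), a file under the ".../tensorflow/python/"
  prefix is treated as framework code, while a filename ending in "_test.py"
  matches _EXTERNAL_FILENAME_PATTERNS and is treated as external user code.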

  Args:
    filename: A filename string.

  Returns:
    Whether the filename should be considered to be internal to the
    TensorFlow framework for the purposes of reporting errors.
  """
  for pattern in _EXTERNAL_FILENAME_PATTERNS:
    if pattern.search(filename):
      return False
  for pattern in _FRAMEWORK_FILENAME_PATTERNS:
    if pattern.search(filename):
      return True
  for prefix in _FRAMEWORK_PATH_PREFIXES:
    if filename.startswith(prefix):
      return True
  return False


def _find_index_of_defining_frame(tb):
  """Return the index in op.traceback of the first 'useful' frame.

  This method reads through the stack stored in op.traceback looking for the
  innermost frame which (hopefully) belongs to the caller.  It accomplishes this
  by rejecting frames deemed to be part of the TensorFlow framework (by
  pattern matching the filename).

  Args:
    tb: A list of traceback frames (as from Operation.traceback).

  Returns:
    Integer index into op.traceback where the first non-TF file was found
    (innermost to outermost), or 0 (for the outermost stack frame) if all files
    came from TensorFlow.
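
  For example (illustrative), if tb contains frames from
  ["train.py", "my_layers.py", "tensorflow/python/ops/math_ops.py"]
  (outermost to innermost), the returned index is 1, pointing at
  "my_layers.py", the innermost non-framework frame.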
  """
  # Index 0 of traceback is the outermost frame.
  size = len(tb)
  filenames = [frame.filename for frame in tb]
  # We process the filenames from the innermost frame to outermost.
  for idx, filename in enumerate(reversed(filenames)):
    is_framework = _is_framework_filename(filename)
    if not is_framework:
      # Consider this to be the defining frame.
      return size - idx - 1
  return 0


# TODO(feyu): follow up with users of this function (saved model)
# to see what 'useful' means and whether we can obviate this.
def _compute_useful_frames(tb, num):
  """Return a list of frames, which form a 'useful' stack.

  Starting from the defining frame to the outermost one, this method computes
  the contiguous portion of the 'useful' stack trace and returns the selected
  frames.

  Args:
    tb: A list of traceback frames (as from Operation.traceback).
    num: total number of frames to return.

  Returns:
    A list of frames.
  """
  defining_frame_index = _find_index_of_defining_frame(tb)
  # The stack trace is collected from two lines before the defining frame in
  # the model file to the outermost, with at most `num` frames. The two extra
  # lines from the TensorFlow library give context about which node is being
  # defined.
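  # Worked example (illustrative): with defining_frame_index == 5, num == 10
  # and len(tb) == 8, innermost_excluded == min(5 + 2 + 1, 8) == 8 and
  # outermost_included == max(8 - 10, 0) == 0, so tb[0:8] is returned.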
  innermost_excluded = min(defining_frame_index + 2 + 1, len(tb))
  outermost_included = max(innermost_excluded - num, 0)
  return tb[outermost_included:innermost_excluded]


def create_graph_debug_info_def(func_named_operations):
  """Constructs and returns a `GraphDebugInfo` protocol buffer.

  Args:
    func_named_operations: An iterable of (func_name, op.Operation) tuples
      where the Operation instances have a _traceback member. The func_name
      should be the empty string for operations in the top-level Graph.

  Returns:
    GraphDebugInfo protocol buffer.

  Raises:
    TypeError: If the arguments are not of the correct protocol buffer type.
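
  Example (a sketch; assumes `graph` is a tf.Graph whose operations carry
  tracebacks):
    ops = [("", op) for op in graph.get_operations()]
    debug_info = create_graph_debug_info_def(ops)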
  """
  # Creates an empty GraphDebugInfo proto.
  graph_debug_info_def = graph_debug_info_pb2.GraphDebugInfo()

  # Gets the file names and line numbers for the exported node names. Also
  # collects the unique file names.
  all_file_names = set()
  node_to_trace = {}
  for func_name, op in func_named_operations:
    try:
      op_traceback = op.traceback
    except AttributeError:
      # Some ops synthesized as part of function or control flow definitions
      # do not have tracebacks.
      continue

    # Gets the stack trace of the operation and then the file location.
    node_name = op.name + "@" + func_name
    node_to_trace[node_name] = _compute_useful_frames(op_traceback, 10)
    for frame in node_to_trace[node_name]:
      all_file_names.add(frame.filename)

  # Sets the `files` field in the GraphDebugInfo proto.
  graph_debug_info_def.files.extend(all_file_names)

  # Builds a mapping between file names and index of the `files` field, so we
  # only store the indexes for the nodes in the GraphDebugInfo.
  file_to_index = {
      file_name: index
      for index, file_name in enumerate(graph_debug_info_def.files)
  }

  # Creates the FileLineCol proto for each node and sets the value in the
  # GraphDebugInfo proto. We only store the file name index for each node to
  # save storage space.
  for node_name, frames in node_to_trace.items():
    trace_def = graph_debug_info_def.traces[node_name]
    for frame in reversed(frames):
      trace_def.file_line_cols.add(
          file_index=file_to_index[frame.filename],
          line=frame.lineno)

  return graph_debug_info_def


def _compute_field_dict(op):
  r"""Return a dictionary mapping interpolation tokens to values.

  Args:
    op: op.Operation object having a _traceback member.

  Returns:
    A dictionary mapping string tokens to string values.  The keys are shown
    below along with example values.
    {
      "file": "tool_utils.py",
      "lineno": "124",
      "line": "  source code line",
      "defined_at": " (defined at tool_utils.py:124)",
      "colocations":
          '''Node-device colocations active during op creation:
               with tf.compat.v1.colocate_with(test_node_1): <test_1.py:27>
               with tf.compat.v1.colocate_with(test_node_2): <test_2.py:38>'''
      "devices":
          '''Device assignments active during op 'foo' creation:
               with tf.device(/cpu:0): <test_1.py:27>
               with tf.device(some_func<foo.py, 123>): <test_2.py:38>'''
      "devs_and_colocs": A concatenation of colocations and devices, e.g.
          '''Node-device colocations active during op creation:
               with tf.compat.v1.colocate_with(test_node_1): <test_1.py:27>
               with tf.compat.v1.colocate_with(test_node_2): <test_2.py:38>
             Device assignments active during op 'foo' creation:
               with tf.device(/cpu:0): <test_1.py:27>
               with tf.device(some_func<foo.py, 123>): <test_2.py:38>'''
      "definition_traceback": A list of formatted stack frames (as from
          traceback.format_list) for the op's definition site.
    }
  """
  colocation_summary = _compute_colocation_summary_from_op(op)
  device_summary = _compute_device_assignment_summary_from_op(op)
  combined_summary = "\n".join([colocation_summary, device_summary])

  # Optional traceback info.
  try:
    tb = op.traceback
  except AttributeError:
    # Some ops synthesized as part of function or control flow definitions
    # do not have tracebacks.
    filename = "<unknown>"
    lineno = 0
    defined_at = " (defined at <unknown>)"
    definition_traceback = ""
    line = ""
  else:
    frame = tb.last_user_frame()
    filename = frame.filename
    definition_traceback = traceback.format_list(tb.get_user_frames())

    lineno = frame.lineno
    line = frame.line
    defined_at = " (defined at %s:%d)" % (filename, lineno)

  field_dict = {
      "colocations": colocation_summary,
      "devices": device_summary,
      "devs_and_colocs": combined_summary,
      "defined_at": defined_at,
      "file": filename,
      "lineno": lineno,
      "line": line,
      "definition_traceback": definition_traceback,
  }
  return field_dict


def _get_input_ops_for_op(op, graph):
  """Gets the input ops for op.

  This is a best effort and may not always find all input ops.

  Args:
    op: The op node.
    graph: The graph containing the node.

  Returns:
    A list of (ind_inp, op_inp) tuples, where:
    ind_inp: the index in the input list.
    op_inp: the input op at position ind_inp in the input list.
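
  For example (illustrative), an op whose node_def.input is ["x", "^init"]
  may yield [(0, <op "x">), (1, <op "init">)]; the leading "^" marks a
  control dependency and is stripped before the graph lookup.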
  """
  inputs = []
  for ind_inp, name in enumerate(op.node_def.input):
    if name.startswith("^"):
      name = name[1:]
    try:
      tensor = graph.get_tensor_by_name(name)
      op_inp = tensor.op
    except (KeyError, ValueError):
      try:
        op_inp = graph.get_operation_by_name(name)
      except KeyError:
        continue
    inputs.append((ind_inp, op_inp))

  return inputs


def _build_error_message(op, input_ops):
  """Returns the formatted error message for the given op.

  Args:
    op: The node.
    input_ops: The input nodes to the 'op' node.

  Returns:
    The formatted error message for the given op. The error message also
    includes information about the input sources for the given op.
  """
  field_dict = _compute_field_dict(op)
  msg = "node %s\n%s\n" % (op.name, field_dict["defined_at"])
  input_debug_info = []
  # This stores the line numbers that we have already printed.
  done = set()
  done.add(field_dict["defined_at"])
  for ind_inp, op_inp in input_ops:
    field_dict_inp = _compute_field_dict(op_inp)
    if field_dict_inp["defined_at"] not in done:
      input_debug_info.append(
          "In[%d] %s%s" % (ind_inp, op_inp.name, field_dict_inp["defined_at"]))
      done.add(field_dict_inp["defined_at"])
    else:
      input_debug_info.append("In[%d] %s:" % (ind_inp, op_inp.name))

  end_msg = ""
  if input_debug_info:
    end_msg += "\nInput Source operations connected to node %s:\n" % op.name
    end_msg += "\t\n".join(input_debug_info)

  end_msg += "\n\nOperation defined at: (most recent call last)\n"

  definition_traceback = "\n".join(field_dict["definition_traceback"])
  # Adds a prefix to differentiate from a Python Interpreter traceback.
  end_msg += "\n".join([">>> " + s for s in definition_traceback.split("\n")])

  return msg, end_msg


def interpolate(error_message, graph):
  """Interpolates an error message.

  The error message can contain tags of the form `{{type name}}` which will be
  replaced. For example: "{{node <name>}}" would get expanded to:
  "node <name> (defined at <path>)".

  Args:
    error_message: A string to interpolate.
    graph: ops.Graph object containing all nodes referenced in the error
        message.

  Returns:
    The string with tags of the form {{type name}} interpolated.
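
  Example (a sketch; actual output depends on where `foo` was defined):
    interpolate("Gradient error in {{node foo}}", graph)
    # -> "Gradient error in node foo\n (defined at model.py:12)\n..."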
  """
  seps, tags = parse_message(error_message)
  subs = []
  end_msg = collections.defaultdict(list)
  tagged_ops = []
  all_ops = []

  for t in tags:
    try:
      op = graph.get_operation_by_name(t.name)
    except KeyError:
      op = None
    if op is None:
      tagged_ops.append((None, None))
    else:
      op_inps = _get_input_ops_for_op(op, graph)
      tagged_ops.append((op, op_inps))
      for _, op_inp in op_inps:
        all_ops.append(op_inp)

  for tag, (op, op_inps) in zip(tags, tagged_ops):
    msg = "{{%s %s}}" % (tag.type, tag.name)
    if op is not None:
      if tag.type == "node":
        msg, source_msg = _build_error_message(op, op_inps)
        if source_msg:
          end_msg["source_nodes"].append(source_msg)
      elif tag.type == "colocation_node":
        field_dict = _compute_field_dict(op)
        msg = "node %s%s placed on device %s " % (
            op.name, field_dict["defined_at"], field_dict["devices"])
        end_msg["colocations"].append(field_dict["devs_and_colocs"])
    if tag.type == "function_node":
      msg = ""
    subs.append(msg)

  if "source_nodes" in end_msg:
    subs.append("\n\nErrors may have originated from an input operation.")
    subs.append("\n".join(end_msg["source_nodes"]))
    end_msg.pop("source_nodes", None)
  for k, messages in end_msg.items():
    subs.append("Additional information about %s:" % k)
    subs.append("\n".join(messages))

  return "".join(
      itertools.chain(*six.moves.zip_longest(seps, subs, fillvalue="")))