• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.python.framework.errors."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import os
23import re
24
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import error_interpolation
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import traceable_stack
29from tensorflow.python.ops import math_ops
30from tensorflow.python.platform import test
31
# A mock for ``tf_stack.FrameSummary`` exposing the attributes the tests read
# (filename, lineno, name, line). The typename matches the module-level name
# so instances have a truthful repr and can be pickled (pickle resolves the
# class by ``module.typename``).
FrameSummary = collections.namedtuple(
    "FrameSummary", ["filename", "lineno", "name", "line"])
35
36
def _make_frame_with_filename(op, idx, filename):
  """Copy frame `idx` of `op._traceback`, substituting `filename`.

  Args:
    op: An op whose `_traceback` supports integer indexing.
    idx: Index of the frame to copy (index 0 is the outermost frame).
    filename: Filename to store in the copied frame.

  Returns:
    A `FrameSummary` whose lineno, name and line come from the original frame
    and whose filename is `filename`.
  """
  original = op._traceback[idx]
  return FrameSummary(filename=filename,
                      lineno=original.lineno,
                      name=original.name,
                      line=original.line)
45
46
def _modify_op_stack_with_filenames(op, num_user_frames, user_filename,
                                    num_inner_tf_frames):
  """Replace op._traceback with a new traceback using special filenames.

  The outermost frames of the real traceback are kept unchanged. The next
  `num_user_frames` frames are relabelled with user (non-framework)
  filenames. The innermost `num_inner_tf_frames` frames are relabelled with
  filenames under the first TF framework path prefix. Code that classifies
  frames by filename therefore sees a controlled mix of user and framework
  frames.

  Args:
    op: The op whose `_traceback` is replaced.
    num_user_frames: Number of frames to relabel with user filenames.
    user_filename: Filename for the user frames. If it contains a "%d"
      placeholder, it is formatted with the frame index. Otherwise it is
      placed in a directory named after the frame index, so each relabelled
      user frame still gets a distinct filename.
    num_inner_tf_frames: Number of innermost frames to relabel with TF
      framework filenames.

  Raises:
    AssertionError: If the real traceback has fewer frames than requested.
  """
  tf_filename = error_interpolation._FRAMEWORK_PATH_PREFIXES[0] + "%d.py"
  # Previously the caller's `user_filename` was unconditionally overwritten
  # here; honor it instead. Any literal "%" is escaped so that the later
  # `% idx` formatting only fills the directory placeholder.
  if "%d" not in user_filename:
    user_filename = os.path.join("%d", user_filename.replace("%", "%%"))

  num_requested_frames = num_user_frames + num_inner_tf_frames
  num_actual_frames = len(op._traceback)
  assert num_requested_frames <= num_actual_frames, "Too few real frames."
  num_outer_frames = num_actual_frames - num_requested_frames
  user_end = num_outer_frames + num_user_frames

  # The op's traceback has outermost frame at index 0.
  stack = [op._traceback[idx] for idx in range(num_outer_frames)]
  for idx in range(num_outer_frames, user_end):
    stack.append(_make_frame_with_filename(op, idx, user_filename % idx))
  for idx in range(user_end, num_actual_frames):
    stack.append(_make_frame_with_filename(op, idx, tf_filename % idx))
  op._traceback = stack
67
68
class ComputeDeviceSummaryFromOpTest(test.TestCase):
  """Tests for error_interpolation._compute_device_summary_from_list."""

  def testCorrectFormatWithActiveDeviceAssignments(self):
    assignments = [
        traceable_stack.TraceableObject(
            "/cpu:0", filename="hope.py", lineno=24),
        traceable_stack.TraceableObject(
            "/gpu:2", filename="please.py", lineno=42),
    ]

    summary = error_interpolation._compute_device_summary_from_list(
        "nodename", assignments, prefix="  ")

    # Each assignment is reported with its device and its <file:line> origin.
    self.assertIn("nodename", summary)
    self.assertIn("tf.device(/cpu:0)", summary)
    self.assertIn("<hope.py:24>", summary)
    self.assertIn("tf.device(/gpu:2)", summary)
    self.assertIn("<please.py:42>", summary)

  # Renamed from testCorrectFormatWhenNoColocationsWereActive: this case
  # covers an empty device-assignment list, not colocations.
  def testCorrectFormatWhenNoDeviceAssignmentsWereActive(self):
    device_assignment_list = []
    summary = error_interpolation._compute_device_summary_from_list(
        "nodename", device_assignment_list, prefix="  ")
    self.assertIn("nodename", summary)
    self.assertIn("No device assignments", summary)
95
96
class ComputeColocationSummaryFromOpTest(test.TestCase):
  """Checks the text produced by _compute_colocation_summary_from_dict."""

  def testCorrectFormatWithActiveColocations(self):
    colocation_dict = {
        "test_node_1": traceable_stack.TraceableObject(
            None, filename="test_1.py", lineno=27),
        "test_node_2": traceable_stack.TraceableObject(
            None, filename="test_2.py", lineno=38),
    }
    summary = error_interpolation._compute_colocation_summary_from_dict(
        "node_name", colocation_dict, prefix="  ")
    # The node name plus, for each colocation, its target and <file:line>.
    expected_fragments = (
        "node_name",
        "colocate_with(test_node_1)",
        "<test_1.py:27>",
        "colocate_with(test_node_2)",
        "<test_2.py:38>",
    )
    for fragment in expected_fragments:
      self.assertIn(fragment, summary)

  def testCorrectFormatWhenNoColocationsWereActive(self):
    summary = error_interpolation._compute_colocation_summary_from_dict(
        "node_name", {}, prefix="  ")
    self.assertIn("node_name", summary)
    self.assertIn("No node-device colocations", summary)
122
123
# create_graph_debug_info_def operates on graph-mode ops, so this test is not
# run under eager execution. Even in eager mode the function is only reached
# through FunctionGraphs, so exercising it directly in graph mode is the
# narrowest way to unit test it.
class CreateGraphDebugInfoDefTest(test.TestCase):

  def _getFirstStackTraceForFile(self, graph_debug_info, key, file_index):
    """Return the first file/line/col entry of trace `key` in `file_index`."""
    self.assertIn(key, graph_debug_info.traces)
    candidates = (flc for flc in graph_debug_info.traces[key].file_line_cols
                  if flc.file_index == file_index)
    found_flc = next(candidates, None)
    self.assertIsNotNone(found_flc,
                         "Could not find a stack trace entry for file")
    return found_flc

  def testStackTraceExtraction(self):
    # The stack traces under test are attached in graph mode only.
    with ops.Graph().as_default():
      # create_graph_debug_info_def() handles functions only by mangling
      # names, so a loose op with a hand-picked function name covers it.
      # The Global/One/Two ops below MUST stay on consecutive lines: the final
      # assertions compare their recorded line numbers.
      # pyformat: disable
      global_op = constant_op.constant(0, name="Global").op
      op1 = constant_op.constant(1, name="One").op
      op2 = constant_op.constant(2, name="Two").op
      non_traceback_op = constant_op.constant(3, name="NonTraceback").op
      # An op without a traceback must not make the export fail.
      del non_traceback_op._traceback
      # pyformat: enable

      export_ops = [
          ("", global_op),
          ("func1", op1),
          ("func2", op2),
          ("func2", non_traceback_op),
      ]
      graph_debug_info = error_interpolation.create_graph_debug_info_def(
          export_ops)

      this_file_suffix = "{}error_interpolation_test.py".format(os.sep)
      matching_indices = [
          file_index
          for file_index, file_name in enumerate(graph_debug_info.files)
          if this_file_suffix in file_name
      ]
      # The last matching file index wins; -1 means "not found".
      this_file_index = matching_indices[-1] if matching_indices else -1
      self.assertGreaterEqual(
          this_file_index, 0,
          "Could not find this file in trace:" + repr(graph_debug_info))

      # Every traced op must have an entry pointing into this file.
      global_flc, op1_flc, op2_flc = (
          self._getFirstStackTraceForFile(graph_debug_info, trace_key,
                                          this_file_index)
          for trace_key in ("Global@", "One@func1", "Two@func2"))

      base_line = global_flc.line
      self.assertEqual(op1_flc.line, base_line + 1, "op1 not on next line")
      self.assertEqual(op2_flc.line, base_line + 2, "op2 not on next line")
183
184
class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
  """Tests for `{{node ...}}` interpolation and defining-frame lookup."""

  def testFindIndexOfDefiningFrameForOp(self):
    with ops.Graph().as_default():
      local_op = constant_op.constant(42).op
      num_inner_tf_frames = 5
      _modify_op_stack_with_filenames(
          local_op,
          num_user_frames=3,
          user_filename="hope.py",
          num_inner_tf_frames=num_inner_tf_frames)
      defining_idx = error_interpolation._find_index_of_defining_frame(
          local_op._traceback)
      # The expected frame sits immediately outside the block of TF-named
      # inner frames at the end of the stack (6th from the end here).
      expected_idx = len(local_op._traceback) - (num_inner_tf_frames + 1)
      self.assertEqual(expected_idx, defining_idx)

  def testFindIndexOfDefiningFrameForOpReturnsZeroOnError(self):
    with ops.Graph().as_default():
      local_op = constant_op.constant(43).op
      stack_depth = 7
      # Truncate the stack to a known length, then give every frame a TF
      # framework filename.
      local_op._traceback = local_op._traceback[:stack_depth]
      _modify_op_stack_with_filenames(
          local_op,
          num_user_frames=0,
          user_filename="user_file.py",
          num_inner_tf_frames=stack_depth)
      self.assertEqual(
          0,
          error_interpolation._find_index_of_defining_frame(
              local_op._traceback))

  def testNothingToDo(self):
    with ops.Graph().as_default():
      constant_op.constant(1, name="One")
      plain_text = "This is just a normal string"
      graph = ops.get_default_graph()
      self.assertEqual(
          error_interpolation.interpolate(plain_text, graph), plain_text)

  def testOneTagWithAFakeNameResultsInPlaceholders(self):
    with ops.Graph().as_default():
      tag_text = "{{node MinusOne}}"
      # No op named MinusOne exists; the tag is expected back unchanged.
      result = error_interpolation.interpolate(
          tag_text, ops.get_default_graph())
      self.assertEqual(tag_text, result)

  def testTwoTagsNoSeps(self):
    with ops.Graph().as_default():
      constant_op.constant(1, name="One")
      constant_op.constant(2, name="Two")
      constant_op.constant(3, name="Three")
      result = error_interpolation.interpolate(
          "{{node One}}{{node Three}}", ops.get_default_graph())
      file_and_line = r"error_interpolation_test\.py:[0-9]+"
      self.assertRegex(result, file_and_line + ".*" + file_and_line)

  def testTwoTagsWithSeps(self):
    with ops.Graph().as_default():
      constant_op.constant(1, name="One")
      constant_op.constant(2, name="Two")
      constant_op.constant(3, name="Three")
      result = error_interpolation.interpolate(
          ";;;{{node Two}},,,{{node Three}};;;", ops.get_default_graph())
      # The separators around each tag must survive interpolation.
      pattern = (r"^;;;.*error_interpolation_test\.py:[0-9]+\) "
                 r",,,.*error_interpolation_test\.py:[0-9]+\) ;;;$")
      self.assertRegex(result, pattern)

  def testNewLine(self):
    with ops.Graph().as_default():
      constant_op.constant(1, name="One")
      constant_op.constant(2, name="Two")
      result = error_interpolation.interpolate(
          "\n\n{{node One}}", ops.get_default_graph())
      self.assertRegex(result, r"error_interpolation_test\.py:[0-9]+.*")
266
267
class InputNodesTest(test.TestCase):
  """Checks how interpolation reports the inputs of a tagged node."""

  def testNoInputs(self):
    with ops.Graph().as_default():
      lhs = constant_op.constant(1, name="One")
      rhs = constant_op.constant(2, name="Two")
      math_ops.add(lhs, rhs, name="Three")
      result = error_interpolation.interpolate(
          ";;;{{node One}},,,{{node Two}};;;", ops.get_default_graph())
      # Constants have no inputs, so only their own locations appear.
      pattern = (r"^;;;.*error_interpolation_test\.py:[0-9]+\) "
                 r",,,.*error_interpolation_test\.py:[0-9]+\) ;;;$")
      self.assertRegex(result, pattern)

  def testBasicInputs(self):
    with ops.Graph().as_default():
      lhs = constant_op.constant(1, name="One")
      rhs = constant_op.constant(2, name="Two")
      math_ops.add(lhs, rhs, name="Three")
      result = error_interpolation.interpolate(
          ";;;{{node Three}};;;", ops.get_default_graph())
      # DOTALL lets ".*" span the newlines between the node and its inputs.
      pattern = re.compile(
          r"^;;;.*error_interpolation_test\.py:[0-9]+\) "
          r";;;.*Input.*error_interpolation_test\.py:[0-9]+\)", re.DOTALL)
      self.assertRegex(result, pattern)
294
295
class InterpolateDeviceSummaryTest(test.TestCase):
  """Tests device summaries produced for `{{colocation_node ...}}` tags."""

  def _fancy_device_function(self, unused_op):
    """Device function whose name should appear in the device summary."""
    return "/cpu:*"

  def testNodeZeroHasNoDeviceSummaryInfo(self):
    with ops.Graph().as_default():
      constant_op.constant([0.0], name="zero")
      message = "{{colocation_node zero}}"
      result = error_interpolation.interpolate(message, ops.get_default_graph())
      self.assertIn("No device assignments were active", result)

  def testNodeOneHasExactlyOneInterpolatedDevice(self):
    with ops.Graph().as_default():
      with ops.device("/cpu"):
        constant_op.constant([1.0], name="one")
      message = "{{colocation_node one}}"
      result = error_interpolation.interpolate(message, ops.get_default_graph())
      # NOTE(review): each active device is expected to appear twice in the
      # interpolated text; confirm which two sections emit it.
      self.assertEqual(2, result.count("tf.device(/cpu)"))

  def testNodeTwoHasTwoInterpolatedDevice(self):
    with ops.Graph().as_default():
      with ops.device("/cpu"):
        with ops.device("/cpu:0"):
          constant_op.constant([2.0], name="two")
      message = "{{colocation_node two}}"
      result = error_interpolation.interpolate(message, ops.get_default_graph())
      self.assertEqual(2, result.count("tf.device(/cpu)"))
      self.assertEqual(2, result.count("tf.device(/cpu:0)"))

  def testNodeThreeHasFancyFunctionDisplayNameForInterpolatedDevice(self):
    with ops.Graph().as_default():
      with ops.device(self._fancy_device_function):
        constant_op.constant(3.0, name="three")
      message = "{{colocation_node three}}"
      result = error_interpolation.interpolate(message, ops.get_default_graph())
      num_devices = result.count("tf.device")
      self.assertEqual(2, num_devices)
      # Dots are escaped so they match literal "." only, as elsewhere in
      # this file.
      name_re = (r"_fancy_device_function<.*error_interpolation_test\.py, "
                 r"[0-9]+>")
      expected_re = r"with tf\.device\(.*%s\)" % name_re
      self.assertRegex(result, expected_re)
337
338
class InterpolateColocationSummaryTest(test.TestCase):
  """Tests colocation summaries produced for `{{colocation_node ...}}`."""

  def _set_up_graph(self):
    """Add ops with known colocations to the default graph.

    Creates:
      One, Two: plain constants.
      Three_with_one: colocated with One.
      Four_with_three: colocated with Three_with_one (and so, transitively,
        with One).
      Five_with_one_with_two: built under colocate_with(Two) and
        colocate_with(One).
    """
    one = constant_op.constant(1, name="One")
    two = constant_op.constant(2, name="Two")

    with ops.colocate_with(one):
      three = constant_op.constant(3, name="Three_with_one")

    # Only the direct colocation with Three_with_one is expected in the
    # summary, even though Three_with_one is itself colocated with One.
    with ops.colocate_with(three):
      constant_op.constant(4, name="Four_with_three")

    # One and Two are not colocated with each other, so this op ends up with
    # two colocation groups.
    with ops.colocate_with(two), ops.colocate_with(one):
      constant_op.constant(5, name="Five_with_one_with_two")

  def _summarize(self, node_name):
    """Interpolate the colocation summary for `node_name`."""
    tag = "{{colocation_node %s}}" % node_name
    return error_interpolation.interpolate(tag, ops.get_default_graph())

  def testNodeThreeHasColocationInterpolation(self):
    with ops.Graph().as_default():
      self._set_up_graph()
      summary = self._summarize("Three_with_one")
      self.assertIn("colocate_with(One)", summary)

  def testNodeFourHasColocationInterpolationForNodeThreeOnly(self):
    with ops.Graph().as_default():
      self._set_up_graph()
      summary = self._summarize("Four_with_three")
      self.assertIn("colocate_with(Three_with_one)", summary)
      self.assertNotIn(
          "One", summary,
          "Node One should not appear in Four_with_three's summary:\n%s" %
          summary)

  def testNodeFiveHasColocationInterpolationForNodeOneAndTwo(self):
    with ops.Graph().as_default():
      self._set_up_graph()
      summary = self._summarize("Five_with_one_with_two")
      self.assertIn("colocate_with(One)", summary)
      self.assertIn("colocate_with(Two)", summary)

  def testColocationInterpolationForNodeLackingColocation(self):
    with ops.Graph().as_default():
      self._set_up_graph()
      summary = self._summarize("One")
      self.assertIn("No node-device colocations", summary)
      self.assertNotIn("Two", summary)
393
394
class IsFrameworkFilenameTest(test.TestCase):
  """Tests for error_interpolation._is_framework_filename."""

  def testAllowsUnitTests(self):
    # A *_test.py file under the framework prefix counts as user code.
    unit_test_path = (
        error_interpolation._FRAMEWORK_PATH_PREFIXES[0] + "foobar_test.py")
    is_framework = error_interpolation._is_framework_filename(unit_test_path)
    self.assertFalse(is_framework)

  def testFrameworkPythonFile(self):
    # The error_interpolation module itself is framework code.
    module_path = error_interpolation.__file__
    self.assertTrue(error_interpolation._is_framework_filename(module_path))

  def testEmbedded(self):
    embedded_path = "<embedded stdlib>/context_lib.py"
    self.assertTrue(error_interpolation._is_framework_filename(embedded_path))
411
412
if __name__ == "__main__":
  # Hand control to TensorFlow's test runner when executed as a script.
  test.main()
415