• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.python.framework.error_interpolation."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22import re
23
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import error_interpolation
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import test_util
28from tensorflow.python.framework import traceable_stack
29from tensorflow.python.ops import math_ops
30from tensorflow.python.platform import test
31from tensorflow.python.util import tf_stack
32
33
def _make_frame_with_filename(op, idx, filename):
  """Build a copy of one frame of `op._traceback` with a different filename.

  Args:
    op: Op whose `_traceback` supplies the frame to copy.
    idx: Index of the frame within `op._traceback`.
    filename: Value written into the frame's `tf_stack.TB_FILENAME` slot.

  Returns:
    A new tuple equal to the original frame except for the filename slot.
  """
  source_frame = op._traceback[idx]
  patched_fields = list(source_frame)
  patched_fields[tf_stack.TB_FILENAME] = filename
  return tuple(patched_fields)
39
40
def _modify_op_stack_with_filenames(op, num_user_frames, user_filename,
                                    num_inner_tf_frames):
  """Replace op._traceback with a new traceback using special filenames.

  The op's traceback has its outermost frame at index 0. The innermost
  `num_inner_tf_frames` frames are renamed to `<idx>` followed by the first
  entry of `error_interpolation._BAD_FILE_SUBSTRINGS`, so they look like
  TensorFlow-internal frames. The `num_user_frames` frames just outside those
  are renamed to `<idx>/<user_filename>`. All remaining outer frames are
  copied unchanged.

  Args:
    op: The op whose `_traceback` attribute is replaced.
    num_user_frames: Number of frames to give user-file names.
    user_filename: Base file name for the user frames. Each frame's index is
      prepended as a directory component so every user filename is distinct.
    num_inner_tf_frames: Number of innermost frames to give TF-internal names.
  """
  num_requested_frames = num_user_frames + num_inner_tf_frames
  num_actual_frames = len(op._traceback)
  # Validate before deriving any frame ranges from these counts.
  assert num_requested_frames <= num_actual_frames, "Too few real frames."

  tf_filename = "%d" + error_interpolation._BAD_FILE_SUBSTRINGS[0]
  num_outer_frames = num_actual_frames - num_requested_frames
  first_tf_frame = num_outer_frames + num_user_frames

  # Outer frames are kept exactly as they were.
  stack = list(op._traceback[:num_outer_frames])
  for idx in range(num_outer_frames, first_tf_frame):
    # Use the caller-supplied filename (previously it was overwritten and
    # ignored), with the frame index as a directory to keep names unique.
    stack.append(
        _make_frame_with_filename(op, idx, os.path.join(str(idx),
                                                        user_filename)))
  for idx in range(first_tf_frame, num_actual_frames):
    stack.append(_make_frame_with_filename(op, idx, tf_filename % idx))
  op._traceback = stack
61
62
class ComputeDeviceSummaryFromOpTest(test.TestCase):
  """Tests for `error_interpolation._compute_device_summary_from_list`."""

  def testCorrectFormatWithActiveDeviceAssignments(self):
    assignments = [
        traceable_stack.TraceableObject(
            "/cpu:0", filename="hope.py", lineno=24),
        traceable_stack.TraceableObject(
            "/gpu:2", filename="please.py", lineno=42),
    ]

    summary = error_interpolation._compute_device_summary_from_list(
        "nodename", assignments, prefix="  ")

    self.assertIn("nodename", summary)
    self.assertIn("tf.device(/cpu:0)", summary)
    self.assertIn("<hope.py:24>", summary)
    self.assertIn("tf.device(/gpu:2)", summary)
    self.assertIn("<please.py:42>", summary)

  def testCorrectFormatWhenNoDeviceAssignmentsWereActive(self):
    # With an empty assignment list the summary still names the node and
    # reports that no device assignments were active.
    device_assignment_list = []
    summary = error_interpolation._compute_device_summary_from_list(
        "nodename", device_assignment_list, prefix="  ")
    self.assertIn("nodename", summary)
    self.assertIn("No device assignments", summary)
89
90
class ComputeColocationSummaryFromOpTest(test.TestCase):
  """Tests for `error_interpolation._compute_colocation_summary_from_dict`."""

  def testCorrectFormatWithActiveColocations(self):
    colocation_dict = {
        "test_node_1": traceable_stack.TraceableObject(
            None, filename="test_1.py", lineno=27),
        "test_node_2": traceable_stack.TraceableObject(
            None, filename="test_2.py", lineno=38),
    }
    summary = error_interpolation._compute_colocation_summary_from_dict(
        "node_name", colocation_dict, prefix="  ")

    # Each colocated node and its source location must appear in the summary.
    expected_fragments = (
        "node_name",
        "colocate_with(test_node_1)",
        "<test_1.py:27>",
        "colocate_with(test_node_2)",
        "<test_2.py:38>",
    )
    for fragment in expected_fragments:
      self.assertIn(fragment, summary)

  def testCorrectFormatWhenNoColocationsWereActive(self):
    summary = error_interpolation._compute_colocation_summary_from_dict(
        "node_name", {}, prefix="  ")
    for fragment in ("node_name", "No node-device colocations"):
      self.assertIn(fragment, summary)
116
117
@test_util.run_deprecated_v1
class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
  """Tests `error_interpolation.interpolate` on `{{node ...}}` tags.

  The regexes below escape "." so that they match only a literal dot in
  "constant_op.py" rather than any character.
  """

  def setUp(self):
    ops.reset_default_graph()
    # Add nodes to the graph for retrieval by name later.
    constant_op.constant(1, name="One")
    constant_op.constant(2, name="Two")
    three = constant_op.constant(3, name="Three")
    self.graph = three.graph

    # Change the list of bad file substrings so that constant_op.py is chosen
    # as the defining stack frame for constant_op.constant ops.
    self.old_bad_strings = error_interpolation._BAD_FILE_SUBSTRINGS
    error_interpolation._BAD_FILE_SUBSTRINGS = [
        "%sops.py" % os.sep,
        "%sutil" % os.sep,
    ]

  def tearDown(self):
    # Restore the module-level list patched in setUp.
    error_interpolation._BAD_FILE_SUBSTRINGS = self.old_bad_strings

  def testFindIndexOfDefiningFrameForOp(self):
    local_op = constant_op.constant(42).op
    user_filename = "hope.py"
    _modify_op_stack_with_filenames(
        local_op,
        num_user_frames=3,
        user_filename=user_filename,
        num_inner_tf_frames=5)
    idx = error_interpolation._find_index_of_defining_frame_for_op(local_op)
    # Expected frame is 6th from the end because there are 5 inner frames with
    # TF filenames.
    expected_frame = len(local_op._traceback) - 6
    self.assertEqual(expected_frame, idx)

  def testFindIndexOfDefiningFrameForOpReturnsZeroOnError(self):
    local_op = constant_op.constant(43).op
    # Truncate stack to known length.
    local_op._traceback = local_op._traceback[:7]
    # Ensure all frames look like TF frames.
    _modify_op_stack_with_filenames(
        local_op,
        num_user_frames=0,
        user_filename="user_file.py",
        num_inner_tf_frames=7)
    idx = error_interpolation._find_index_of_defining_frame_for_op(local_op)
    self.assertEqual(0, idx)

  def testNothingToDo(self):
    normal_string = "This is just a normal string"
    interpolated_string = error_interpolation.interpolate(
        normal_string, self.graph)
    self.assertEqual(interpolated_string, normal_string)

  def testOneTagWithAFakeNameResultsInPlaceholders(self):
    one_tag_string = "{{node MinusOne}}"
    interpolated_string = error_interpolation.interpolate(
        one_tag_string, self.graph)
    self.assertEqual(one_tag_string, interpolated_string)

  def testTwoTagsNoSeps(self):
    two_tags_no_seps = "{{node One}}{{node Three}}"
    interpolated_string = error_interpolation.interpolate(
        two_tags_no_seps, self.graph)
    self.assertRegexpMatches(
        interpolated_string, r"constant_op\.py:[0-9]+.*constant_op\.py:[0-9]+")

  def testTwoTagsWithSeps(self):
    two_tags_with_seps = ";;;{{node Two}},,,{{node Three}};;;"
    interpolated_string = error_interpolation.interpolate(
        two_tags_with_seps, self.graph)
    expected_regex = (
        r"^;;;.*constant_op\.py:[0-9]+\) ,,,.*constant_op\.py:[0-9]+\) ;;;$")
    self.assertRegexpMatches(interpolated_string, expected_regex)

  def testNewLine(self):
    newline = "\n\n{{node One}}"
    interpolated_string = error_interpolation.interpolate(newline, self.graph)
    self.assertRegexpMatches(interpolated_string, r"constant_op\.py:[0-9]+")
198
199
@test_util.run_deprecated_v1
class InputNodesTest(test.TestCase):
  """Tests that interpolated `{{node ...}}` tags also describe input ops."""

  def setUp(self):
    # Start from a fresh default graph, as the other graph-building test cases
    # in this file do. This ensures that the name lookups ("One", "Two",
    # "Three") resolve to the ops created here rather than to same-named ops
    # left in the default graph by earlier tests.
    ops.reset_default_graph()
    # Add nodes to the graph for retrieval by name later.
    one = constant_op.constant(1, name="One")
    two = constant_op.constant(2, name="Two")
    three = math_ops.add(one, two, name="Three")
    self.graph = three.graph

    # Change the list of bad file substrings so that constant_op.py is chosen
    # as the defining stack frame for constant_op.constant ops.
    self.old_bad_strings = error_interpolation._BAD_FILE_SUBSTRINGS
    error_interpolation._BAD_FILE_SUBSTRINGS = [
        "%sops.py" % os.sep,
        "%sutil" % os.sep,
    ]

  def tearDown(self):
    # Restore the module-level list patched in setUp.
    error_interpolation._BAD_FILE_SUBSTRINGS = self.old_bad_strings

  def testNoInputs(self):
    two_tags_with_seps = ";;;{{node One}},,,{{node Two}};;;"
    interpolated_string = error_interpolation.interpolate(
        two_tags_with_seps, self.graph)
    # "." is escaped so only a literal "constant_op.py" matches.
    expected_regex = (
        r"^;;;.*constant_op\.py:[0-9]+\) ,,,.*constant_op\.py:[0-9]+\) ;;;$")
    self.assertRegexpMatches(interpolated_string, expected_regex)

  def testBasicInputs(self):
    tag = ";;;{{node Three}};;;"
    interpolated_string = error_interpolation.interpolate(tag, self.graph)
    # DOTALL lets ".*" span the lines that describe the op's inputs.
    expected_regex = re.compile(
        r"^;;;.*op_def_library\.py:[0-9]+\) ;;;"
        r".*Input.*constant_op\.py:[0-9]+\)", re.DOTALL)
    self.assertRegexpMatches(interpolated_string, expected_regex)
236
237
@test_util.run_deprecated_v1
class InterpolateDeviceSummaryTest(test.TestCase):
  """Tests device summaries produced for `{{colocation_node ...}}` tags."""

  def _fancy_device_function(self, unused_op):
    return "/cpu:*"

  def setUp(self):
    ops.reset_default_graph()
    self.zero = constant_op.constant([0.0], name="zero")
    with ops.device("/cpu"):
      self.one = constant_op.constant([1.0], name="one")
      with ops.device("/cpu:0"):
        self.two = constant_op.constant([2.0], name="two")
    # A device function (rather than a device string) is used for node three.
    with ops.device(self._fancy_device_function):
      self.three = constant_op.constant(3.0, name="three")

    self.graph = self.three.graph

  def testNodeZeroHasNoDeviceSummaryInfo(self):
    message = "{{colocation_node zero}}"
    result = error_interpolation.interpolate(message, self.graph)
    self.assertIn("No device assignments were active", result)

  def testNodeOneHasExactlyOneInterpolatedDevice(self):
    message = "{{colocation_node one}}"
    result = error_interpolation.interpolate(message, self.graph)
    self.assertEqual(2, result.count("tf.device(/cpu)"))

  def testNodeTwoHasTwoInterpolatedDevice(self):
    message = "{{colocation_node two}}"
    result = error_interpolation.interpolate(message, self.graph)
    self.assertEqual(2, result.count("tf.device(/cpu)"))
    self.assertEqual(2, result.count("tf.device(/cpu:0)"))

  def testNodeThreeHasFancyFunctionDisplayNameForInterpolatedDevice(self):
    message = "{{colocation_node three}}"
    result = error_interpolation.interpolate(message, self.graph)
    num_devices = result.count("tf.device")
    self.assertEqual(2, num_devices)
    # "." is escaped so it matches only a literal dot in the file name and in
    # "tf.device".
    name_re = r"_fancy_device_function<.*error_interpolation_test\.py, [0-9]+>"
    expected_re = r"with tf\.device\(.*%s\)" % name_re
    self.assertRegexpMatches(result, expected_re)
280
281
@test_util.run_deprecated_v1
class InterpolateColocationSummaryTest(test.TestCase):
  """Tests colocation summaries produced for `{{colocation_node ...}}` tags."""

  def setUp(self):
    ops.reset_default_graph()
    # Build a small graph whose colocation constraints the tests inspect by
    # node name.
    one_op = constant_op.constant(1, name="One")
    two_op = constant_op.constant(2, name="Two")

    # Three_with_one is colocated with One only.
    with ops.colocate_with(one_op):
      three_op = constant_op.constant(3, name="Three_with_one")

    # Four_with_three has a single colocation group even though
    # Three_with_one is itself (transitively) colocated with One.
    with ops.colocate_with(three_op):
      constant_op.constant(4, name="Four_with_three")

    # Five_with_one_with_two has two colocation groups because One and Two are
    # not colocated with each other.
    with ops.colocate_with(two_op), ops.colocate_with(one_op):
      constant_op.constant(5, name="Five_with_one_with_two")

    self.graph = three_op.graph

  def _interpolate(self, message):
    """Interpolates `message` against the graph built in setUp."""
    return error_interpolation.interpolate(message, self.graph)

  def testNodeThreeHasColocationInterpolation(self):
    summary = self._interpolate("{{colocation_node Three_with_one}}")
    self.assertIn("colocate_with(One)", summary)

  def testNodeFourHasColocationInterpolationForNodeThreeOnly(self):
    summary = self._interpolate("{{colocation_node Four_with_three}}")
    self.assertIn("colocate_with(Three_with_one)", summary)
    failure_msg = (
        "Node One should not appear in Four_with_three's summary:\n%s" %
        summary)
    self.assertNotIn("One", summary, failure_msg)

  def testNodeFiveHasColocationInterpolationForNodeOneAndTwo(self):
    summary = self._interpolate("{{colocation_node Five_with_one_with_two}}")
    for expected in ("colocate_with(One)", "colocate_with(Two)"):
      self.assertIn(expected, summary)

  def testColocationInterpolationForNodeLackingColocation(self):
    summary = self._interpolate("{{colocation_node One}}")
    self.assertIn("No node-device colocations", summary)
    self.assertNotIn("Two", summary)
331
332
if __name__ == "__main__":
  # Hand control to the TensorFlow test runner
  # (tensorflow.python.platform.test) when this file is executed directly.
  test.main()
335