1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for tensorflow.python.framework.errors.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import os 23import re 24 25from tensorflow.python.framework import constant_op 26from tensorflow.python.framework import error_interpolation 27from tensorflow.python.framework import ops 28from tensorflow.python.framework import traceable_stack 29from tensorflow.python.ops import math_ops 30from tensorflow.python.platform import test 31 32# A mock for ``tf_stack.FrameSummary``. 
FrameSummary = collections.namedtuple(
    "FrameSummary", ["filename", "lineno", "name", "line"])


def _make_frame_with_filename(op, idx, filename):
  """Return a copy of an existing stack frame with a new filename.

  Args:
    op: The operation whose `_traceback` supplies the frame to copy.
    idx: Index of the frame within `op._traceback`.
    filename: Filename to put on the copy.

  Returns:
    A `FrameSummary` carrying `filename` together with the `lineno`, `name`
    and `line` of `op._traceback[idx]`.
  """
  frame = op._traceback[idx]
  return FrameSummary(filename, frame.lineno, frame.name, frame.line)


def _modify_op_stack_with_filenames(op, num_user_frames, user_filename,
                                    num_inner_tf_frames):
  """Replace op._traceback with a new traceback using special filenames.

  The new traceback (outermost frame at index 0) consists of three runs:
  the original outer frames, unchanged; `num_user_frames` frames renamed so
  they look like user code; and the innermost `num_inner_tf_frames` frames
  renamed so they look like TensorFlow framework code. Each renamed frame
  keeps the line number, function name and source line of the original frame
  at the same position.

  Args:
    op: The operation whose `_traceback` is replaced in place.
    num_user_frames: Number of frames to rename as user frames.
    user_filename: Base filename for user frames. The user frame at index
      `idx` is named `os.path.join(str(idx), user_filename)`, so every user
      frame gets a distinct filename.
    num_inner_tf_frames: Number of innermost frames to rename as framework
      frames.
  """
  num_requested_frames = num_user_frames + num_inner_tf_frames
  num_actual_frames = len(op._traceback)
  assert num_requested_frames <= num_actual_frames, "Too few real frames."
  num_outer_frames = num_actual_frames - num_requested_frames
  first_tf_frame = num_outer_frames + num_user_frames

  # Prefixing with a framework path makes these frames count as TF frames.
  tf_filename = error_interpolation._FRAMEWORK_PATH_PREFIXES[0] + "%d.py"

  # The op's traceback has outermost frame at index 0.
  stack = [op._traceback[idx] for idx in range(num_outer_frames)]
  # Build user filenames with os.path.join rather than %-formatting so a
  # caller-supplied name containing "%" cannot break the formatting.
  stack.extend(
      _make_frame_with_filename(op, idx, os.path.join(str(idx), user_filename))
      for idx in range(num_outer_frames, first_tf_frame))
  stack.extend(
      _make_frame_with_filename(op, idx, tf_filename % idx)
      for idx in range(first_tf_frame, num_actual_frames))
  op._traceback = stack


class ComputeDeviceSummaryFromOpTest(test.TestCase):
  """Tests for `error_interpolation._compute_device_summary_from_list`."""

  def testCorrectFormatWithActiveDeviceAssignments(self):
    assignments = [
        traceable_stack.TraceableObject(
            "/cpu:0", filename="hope.py", lineno=24),
        traceable_stack.TraceableObject(
            "/gpu:2", filename="please.py", lineno=42),
    ]

    summary = error_interpolation._compute_device_summary_from_list(
        "nodename", assignments, prefix="  ")

    self.assertIn("nodename", summary)
    self.assertIn("tf.device(/cpu:0)", summary)
    self.assertIn("<hope.py:24>", summary)
    self.assertIn("tf.device(/gpu:2)", summary)
    self.assertIn("<please.py:42>", summary)

  # NOTE(review): despite its name, this test covers an empty list of
  # *device assignments*, not colocations.
  def testCorrectFormatWhenNoColocationsWereActive(self):
    device_assignment_list = []
    summary = error_interpolation._compute_device_summary_from_list(
        "nodename", device_assignment_list, prefix="  ")
    self.assertIn("nodename", summary)
    self.assertIn("No device assignments", summary)


class ComputeColocationSummaryFromOpTest(test.TestCase):
  """Tests for `error_interpolation._compute_colocation_summary_from_dict`."""

  def testCorrectFormatWithActiveColocations(self):
    t_obj_1 = traceable_stack.TraceableObject(
        None, filename="test_1.py", lineno=27)
    t_obj_2 = traceable_stack.TraceableObject(
        None, filename="test_2.py", lineno=38)
    colocation_dict = {
        "test_node_1": t_obj_1,
        "test_node_2": t_obj_2,
    }
    summary = error_interpolation._compute_colocation_summary_from_dict(
        "node_name", colocation_dict, prefix="  ")
    self.assertIn("node_name", summary)
    self.assertIn("colocate_with(test_node_1)", summary)
    self.assertIn("<test_1.py:27>", summary)
    self.assertIn("colocate_with(test_node_2)", summary)
    self.assertIn("<test_2.py:38>", summary)

  def testCorrectFormatWhenNoColocationsWereActive(self):
    colocation_dict = {}
    summary = error_interpolation._compute_colocation_summary_from_dict(
        "node_name", colocation_dict, prefix="  ")
    self.assertIn("node_name", summary)
    self.assertIn("No node-device colocations", summary)


# Note that the create_graph_debug_info_def needs to run on graph mode ops,
# so it is excluded from eager tests. Even when used in eager mode, it is
# via FunctionGraphs, and directly verifying in graph mode is the narrowest
# way to unit test the functionality.
class CreateGraphDebugInfoDefTest(test.TestCase):
  """Tests for `error_interpolation.create_graph_debug_info_def`."""

  def _getFirstStackTraceForFile(self, graph_debug_info, key, file_index):
    """Return the first trace entry of `key` that refers to `file_index`.

    Args:
      graph_debug_info: The object returned by `create_graph_debug_info_def`.
      key: Trace key to look up in `graph_debug_info.traces`.
      file_index: Index into `graph_debug_info.files` to search for.

    Returns:
      The first element of the trace's `file_line_cols` whose `file_index`
      equals `file_index`. The test fails if `key` has no trace or if no
      entry matches.
    """
    self.assertIn(key, graph_debug_info.traces)
    stack_trace = graph_debug_info.traces[key]
    found_flc = next(
        (flc for flc in stack_trace.file_line_cols
         if flc.file_index == file_index), None)
    self.assertIsNotNone(found_flc,
                         "Could not find a stack trace entry for file")
    return found_flc

  def testStackTraceExtraction(self):
    # This test is verifying stack trace information added in graph mode, so
    # only makes sense in graph mode.
    with ops.Graph().as_default():
      # Since the create_graph_debug_info_def() function does not actually
      # do anything special with functions except name mangling, just verify
      # it with a loose op and manually provided function name.
      # The following ops *must* be on consecutive lines (it will be verified
      # in the resulting trace).
      # pyformat: disable
      global_op = constant_op.constant(0, name="Global").op
      op1 = constant_op.constant(1, name="One").op
      op2 = constant_op.constant(2, name="Two").op
      non_traceback_op = constant_op.constant(3, name="NonTraceback").op
      # Ensure op without traceback does not fail
      del non_traceback_op._traceback
      # pyformat: enable

      export_ops = [("", global_op), ("func1", op1), ("func2", op2),
                    ("func2", non_traceback_op)]
      graph_debug_info = error_interpolation.create_graph_debug_info_def(
          export_ops)
      this_file_index = -1
      for file_index, file_name in enumerate(graph_debug_info.files):
        if "{}error_interpolation_test.py".format(os.sep) in file_name:
          this_file_index = file_index
      self.assertGreaterEqual(
          this_file_index, 0,
          "Could not find this file in trace:" + repr(graph_debug_info))

      # Verify the traces exist for each op.
      global_flc = self._getFirstStackTraceForFile(graph_debug_info, "Global@",
                                                   this_file_index)
      op1_flc = self._getFirstStackTraceForFile(graph_debug_info, "One@func1",
                                                this_file_index)
      op2_flc = self._getFirstStackTraceForFile(graph_debug_info, "Two@func2",
                                                this_file_index)

      global_line = global_flc.line
      self.assertEqual(op1_flc.line, global_line + 1, "op1 not on next line")
      self.assertEqual(op2_flc.line, global_line + 2, "op2 not on next line")


class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
  """Tests for defining-frame lookup and `{{node ...}}` tag interpolation."""

  def testFindIndexOfDefiningFrameForOp(self):
    with ops.Graph().as_default():
      local_op = constant_op.constant(42).op
      user_filename = "hope.py"
      _modify_op_stack_with_filenames(
          local_op,
          num_user_frames=3,
          user_filename=user_filename,
          num_inner_tf_frames=5)
      idx = error_interpolation._find_index_of_defining_frame(
          local_op._traceback)
      # Expected frame is 6th from the end because there are 5 inner frames
      # with TF filenames.
      expected_frame = len(local_op._traceback) - 6
      self.assertEqual(expected_frame, idx)

  def testFindIndexOfDefiningFrameForOpReturnsZeroOnError(self):
    with ops.Graph().as_default():
      local_op = constant_op.constant(43).op
      # Truncate stack to known length.
      local_op._traceback = local_op._traceback[:7]
      # Ensure all frames look like TF frames.
      _modify_op_stack_with_filenames(
          local_op,
          num_user_frames=0,
          user_filename="user_file.py",
          num_inner_tf_frames=7)
      idx = error_interpolation._find_index_of_defining_frame(
          local_op._traceback)
      self.assertEqual(0, idx)

  def testNothingToDo(self):
    with ops.Graph().as_default():
      constant_op.constant(1, name="One")
      normal_string = "This is just a normal string"
      interpolated_string = error_interpolation.interpolate(
          normal_string, ops.get_default_graph())
      self.assertEqual(interpolated_string, normal_string)

  def testOneTagWithAFakeNameResultsInPlaceholders(self):
    with ops.Graph().as_default():
      one_tag_string = "{{node MinusOne}}"
      interpolated_string = error_interpolation.interpolate(
          one_tag_string, ops.get_default_graph())
      self.assertEqual(one_tag_string, interpolated_string)

  def testTwoTagsNoSeps(self):
    with ops.Graph().as_default():
      constant_op.constant(1, name="One")
      constant_op.constant(2, name="Two")
      constant_op.constant(3, name="Three")
      two_tags_no_seps = "{{node One}}{{node Three}}"
      interpolated_string = error_interpolation.interpolate(
          two_tags_no_seps, ops.get_default_graph())
      # Two file:line references, anything (possibly nothing) between them.
      self.assertRegex(
          interpolated_string,
          r"error_interpolation_test\.py:[0-9]+"
          r".*error_interpolation_test\.py:[0-9]+")

  def testTwoTagsWithSeps(self):
    with ops.Graph().as_default():
      constant_op.constant(1, name="One")
      constant_op.constant(2, name="Two")
      constant_op.constant(3, name="Three")
      two_tags_with_seps = ";;;{{node Two}},,,{{node Three}};;;"
      interpolated_string = error_interpolation.interpolate(
          two_tags_with_seps, ops.get_default_graph())
      expected_regex = (r"^;;;.*error_interpolation_test\.py:[0-9]+\) "
                        r",,,.*error_interpolation_test\.py:[0-9]+\) ;;;$")
      self.assertRegex(interpolated_string, expected_regex)

  def testNewLine(self):
    with ops.Graph().as_default():
      constant_op.constant(1, name="One")
      constant_op.constant(2, name="Two")
      newline = "\n\n{{node One}}"
      interpolated_string = error_interpolation.interpolate(
          newline, ops.get_default_graph())
      self.assertRegex(interpolated_string,
                       r"error_interpolation_test\.py:[0-9]+.*")


class InputNodesTest(test.TestCase):
  """Tests that interpolated node tags mention the node's inputs."""

  def testNoInputs(self):
    with ops.Graph().as_default():
      one = constant_op.constant(1, name="One")
      two = constant_op.constant(2, name="Two")
      _ = math_ops.add(one, two, name="Three")
      two_tags_with_seps = ";;;{{node One}},,,{{node Two}};;;"
      interpolated_string = error_interpolation.interpolate(
          two_tags_with_seps, ops.get_default_graph())
      expected_regex = (r"^;;;.*error_interpolation_test\.py:[0-9]+\) "
                        r",,,.*error_interpolation_test\.py:[0-9]+\) ;;;$")
      self.assertRegex(interpolated_string, expected_regex)

  def testBasicInputs(self):
    with ops.Graph().as_default():
      one = constant_op.constant(1, name="One")
      two = constant_op.constant(2, name="Two")
      _ = math_ops.add(one, two, name="Three")
      tag = ";;;{{node Three}};;;"
      interpolated_string = error_interpolation.interpolate(
          tag, ops.get_default_graph())
      expected_regex = re.compile(
          r"^;;;.*error_interpolation_test\.py:[0-9]+\) "
          r";;;.*Input.*error_interpolation_test\.py:[0-9]+\)", re.DOTALL)
      self.assertRegex(interpolated_string, expected_regex)


class InterpolateDeviceSummaryTest(test.TestCase):
  """Tests device-assignment summaries in `{{colocation_node ...}}` tags."""

  def _fancy_device_function(self, unused_op):
    """Device function used to check that its name appears in summaries."""
    return "/cpu:*"

  def testNodeZeroHasNoDeviceSummaryInfo(self):
    with ops.Graph().as_default():
      constant_op.constant([0.0], name="zero")
      message = "{{colocation_node zero}}"
      result = error_interpolation.interpolate(message, ops.get_default_graph())
      self.assertIn("No device assignments were active", result)

  def testNodeOneHasExactlyOneInterpolatedDevice(self):
    with ops.Graph().as_default():
      with ops.device("/cpu"):
        constant_op.constant([1.0], name="one")
      message = "{{colocation_node one}}"
      result = error_interpolation.interpolate(message, ops.get_default_graph())
      self.assertEqual(2, result.count("tf.device(/cpu)"))

  def testNodeTwoHasTwoInterpolatedDevice(self):
    with ops.Graph().as_default():
      with ops.device("/cpu"):
        with ops.device("/cpu:0"):
          constant_op.constant([2.0], name="two")
      message = "{{colocation_node two}}"
      result = error_interpolation.interpolate(message, ops.get_default_graph())
      self.assertEqual(2, result.count("tf.device(/cpu)"))
      self.assertEqual(2, result.count("tf.device(/cpu:0)"))

  def testNodeThreeHasFancyFunctionDisplayNameForInterpolatedDevice(self):
    with ops.Graph().as_default():
      with ops.device(self._fancy_device_function):
        constant_op.constant(3.0, name="three")
      message = "{{colocation_node three}}"
      result = error_interpolation.interpolate(message, ops.get_default_graph())
      num_devices = result.count("tf.device")
      self.assertEqual(2, num_devices)
      # Escape literal dots so they cannot match arbitrary characters.
      name_re = r"_fancy_device_function<.*error_interpolation_test\.py, [0-9]+>"
      expected_re = r"with tf\.device\(.*%s\)" % name_re
      self.assertRegex(result, expected_re)


class InterpolateColocationSummaryTest(test.TestCase):
  """Tests colocation summaries produced for `{{colocation_node ...}}` tags."""

  def _set_up_graph(self):
    """Add named constants with known colocation groups to the default graph.

    "One" and "Two" are created without any colocation. "Three_with_one" is
    colocated with One. "Four_with_three" is colocated with Three_with_one
    (and therefore, transitively, with One), yet should report only that one
    colocation group. "Five_with_one_with_two" is created under both
    colocate_with(Two) and colocate_with(One); since One and Two are not
    colocated with each other, it has two colocation groups.
    """
    one = constant_op.constant(1, name="One")
    two = constant_op.constant(2, name="Two")

    with ops.colocate_with(one):
      three = constant_op.constant(3, name="Three_with_one")

    with ops.colocate_with(three):
      constant_op.constant(4, name="Four_with_three")

    with ops.colocate_with(two), ops.colocate_with(one):
      constant_op.constant(5, name="Five_with_one_with_two")

  def _interpolate_in_fresh_graph(self, message):
    """Build the test graph in a new default graph and interpolate `message`."""
    with ops.Graph().as_default():
      self._set_up_graph()
      return error_interpolation.interpolate(message, ops.get_default_graph())

  def testNodeThreeHasColocationInterpolation(self):
    summary = self._interpolate_in_fresh_graph(
        "{{colocation_node Three_with_one}}")
    self.assertIn("colocate_with(One)", summary)

  def testNodeFourHasColocationInterpolationForNodeThreeOnly(self):
    summary = self._interpolate_in_fresh_graph(
        "{{colocation_node Four_with_three}}")
    self.assertIn("colocate_with(Three_with_one)", summary)
    failure_message = (
        "Node One should not appear in Four_with_three's summary:\n%s" %
        summary)
    self.assertNotIn("One", summary, failure_message)

  def testNodeFiveHasColocationInterpolationForNodeOneAndTwo(self):
    summary = self._interpolate_in_fresh_graph(
        "{{colocation_node Five_with_one_with_two}}")
    for expected in ("colocate_with(One)", "colocate_with(Two)"):
      self.assertIn(expected, summary)

  def testColocationInterpolationForNodeLackingColocation(self):
    summary = self._interpolate_in_fresh_graph("{{colocation_node One}}")
    self.assertIn("No node-device colocations", summary)
    self.assertNotIn("Two", summary)


class IsFrameworkFilenameTest(test.TestCase):
  """Tests for `error_interpolation._is_framework_filename`."""

  def testAllowsUnitTests(self):
    unit_test_path = (
        error_interpolation._FRAMEWORK_PATH_PREFIXES[0] + "foobar_test.py")
    self.assertFalse(error_interpolation._is_framework_filename(unit_test_path))

  def testFrameworkPythonFile(self):
    framework_path = error_interpolation.__file__
    self.assertTrue(error_interpolation._is_framework_filename(framework_path))

  def testEmbedded(self):
    embedded_path = "<embedded stdlib>/context_lib.py"
    self.assertTrue(error_interpolation._is_framework_filename(embedded_path))


if __name__ == "__main__":
  test.main()