# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.framework.errors."""

import gc
import pickle
import warnings

from tensorflow.core.lib.core import error_codes_pb2
from tensorflow.python.framework import _errors_test_helper
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl
from tensorflow.python.lib.io import _pywrap_file_io
from tensorflow.python.platform import test
from tensorflow.python.util import compat

# Every non-OK error code exercised by these tests, paired with the exception
# class it is expected to map to. Kept in one place so the individual tests
# cannot drift out of sync with each other.
_ERROR_CODES_AND_TYPES = (
    (errors.CANCELLED, errors_impl.CancelledError),
    (errors.UNKNOWN, errors_impl.UnknownError),
    (errors.INVALID_ARGUMENT, errors_impl.InvalidArgumentError),
    (errors.DEADLINE_EXCEEDED, errors_impl.DeadlineExceededError),
    (errors.NOT_FOUND, errors_impl.NotFoundError),
    (errors.ALREADY_EXISTS, errors_impl.AlreadyExistsError),
    (errors.PERMISSION_DENIED, errors_impl.PermissionDeniedError),
    (errors.UNAUTHENTICATED, errors_impl.UnauthenticatedError),
    (errors.RESOURCE_EXHAUSTED, errors_impl.ResourceExhaustedError),
    (errors.FAILED_PRECONDITION, errors_impl.FailedPreconditionError),
    (errors.ABORTED, errors_impl.AbortedError),
    (errors.OUT_OF_RANGE, errors_impl.OutOfRangeError),
    (errors.UNIMPLEMENTED, errors_impl.UnimplementedError),
    (errors.INTERNAL, errors_impl.InternalError),
    (errors.UNAVAILABLE, errors_impl.UnavailableError),
    (errors.DATA_LOSS, errors_impl.DataLossError),
)


class ErrorsTest(test.TestCase):

  def _CountReferences(self, typeof):
    """Counts objects tracked by the garbage collector that are `typeof`.

    Args:
      typeof: The type (or tuple of types) passed to `isinstance`.

    Returns:
      The number of objects in `gc.get_objects()` that are instances of
      `typeof`.
    """
    count = 0
    for obj in gc.get_objects():
      try:
        if isinstance(obj, typeof):
          count += 1
      except ReferenceError:
        # Certain versions of Python keep a weakref to deleted objects, and
        # inspecting such a dead reference raises ReferenceError.
        pass
    return count

  def _AssertRaisesWithPayloads(self, raise_fn):
    """Checks that `raise_fn(code)` raises the mapped type with test payloads.

    Args:
      raise_fn: Callable taking an integer error code and raising the
        corresponding exception with payloads key1->value1, key2->value2
        attached.
    """
    for error_code, exc_type in _ERROR_CODES_AND_TYPES:
      with self.subTest(error_code=error_code):
        with self.assertRaises(exc_type) as ctx:
          raise_fn(error_code)
        payloads = ctx.exception.experimental_payloads
        self.assertEqual(payloads[b"key1"], b"value1")
        self.assertEqual(payloads[b"key2"], b"value2")

  def testUniqueClassForEachErrorCode(self):
    for error_code, exc_type in _ERROR_CODES_AND_TYPES:
      with self.subTest(error_code=error_code):
        # pylint: disable=protected-access
        exc = errors_impl._make_specific_exception(None, None, None,
                                                   error_code)
        # pylint: enable=protected-access
        self.assertIsInstance(exc, exc_type)
        # error_code_from_exception_type and exception_type_from_error_code
        # should be consistent with operation result.
        self.assertEqual(error_code,
                         errors_impl.error_code_from_exception_type(exc_type))

  def testKnownErrorClassForEachErrorCodeInProto(self):
    # pylint: disable=line-too-long
    skipped_codes = (
        error_codes_pb2.OK,
        error_codes_pb2.DO_NOT_USE_RESERVED_FOR_FUTURE_EXPANSION_USE_DEFAULT_IN_SWITCH_INSTEAD_,
    )
    # pylint: enable=line-too-long
    for error_code in error_codes_pb2.Code.values():
      if error_code in skipped_codes:
        continue
      with self.subTest(error_code=error_code):
        with warnings.catch_warnings(record=True) as w:
          # Record every warning regardless of the runner's active filters,
          # so "no warning" really means none was emitted.
          warnings.simplefilter("always")
          # pylint: disable=protected-access
          exc = errors_impl._make_specific_exception(None, None, None,
                                                     error_code)
          # pylint: enable=protected-access
        self.assertEmpty(w)  # No warning is raised.
        self.assertIsInstance(exc, errors_impl.OpError)
        self.assertIn(errors_impl.OpError, type(exc).__bases__)

  def testUnknownErrorCodeCausesWarning(self):
    with warnings.catch_warnings(record=True) as w:
      # Ensure the warning is recorded even if the runner's filters would
      # ignore it or turn it into an error.
      warnings.simplefilter("always")
      # pylint: disable=protected-access
      exc = errors_impl._make_specific_exception(None, None, None, 37)
      # pylint: enable=protected-access
    self.assertLen(w, 1)
    self.assertIn("Unknown error code: 37", str(w[0].message))
    self.assertIsInstance(exc, errors_impl.OpError)

    with warnings.catch_warnings(record=True) as w:
      warnings.simplefilter("always")
      # For an unrecognized exception type the function does not return an
      # error code; the assertion below expects an OpError instance instead.
      result = errors_impl.error_code_from_exception_type("Unknown")
    self.assertLen(w, 1)
    self.assertIn("Unknown class exception", str(w[0].message))
    self.assertIsInstance(result, errors_impl.OpError)

  def testStatusDoesNotLeak(self):
    try:
      _pywrap_file_io.DeleteFile(compat.as_bytes("/DOES_NOT_EXIST/"))
    except errors.OpError:
      # Deleting a nonexistent path is expected to fail; what is under test
      # is that the status object used to report the failure is released.
      pass
    gc.collect()
    self.assertEqual(0, self._CountReferences(c_api_util.ScopedTFStatus))

  def testPickleable(self):
    for error_code, _ in _ERROR_CODES_AND_TYPES:
      with self.subTest(error_code=error_code):
        # pylint: disable=protected-access
        exc = errors_impl._make_specific_exception(None, None, None,
                                                   error_code)
        # pylint: enable=protected-access
        unpickled = pickle.loads(pickle.dumps(exc))
        self.assertEqual(exc.node_def, unpickled.node_def)
        self.assertEqual(exc.op, unpickled.op)
        self.assertEqual(exc.message, unpickled.message)
        self.assertEqual(exc.error_code, unpickled.error_code)

  def testErrorPayloadsFromStatus(self):
    self._AssertRaisesWithPayloads(_errors_test_helper.TestRaiseFromStatus)

  def testErrorPayloadsFromTFStatus(self):
    self._AssertRaisesWithPayloads(_errors_test_helper.TestRaiseFromTFStatus)

  def testErrorPayloadsDefaultValue(self):
    for _, exc_type in _ERROR_CODES_AND_TYPES:
      with self.subTest(exception_type=exc_type.__name__):
        e = exc_type(None, None, None)
        self.assertIs(type(e.experimental_payloads), dict)
        self.assertEmpty(e.experimental_payloads)


if __name__ == "__main__":
  test.main()