• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.framework.errors."""
16
17import gc
18import pickle
19import warnings
20
21from tensorflow.core.lib.core import error_codes_pb2
22from tensorflow.python.framework import _errors_test_helper
23from tensorflow.python.framework import c_api_util
24from tensorflow.python.framework import errors
25from tensorflow.python.framework import errors_impl
26from tensorflow.python.lib.io import _pywrap_file_io
27from tensorflow.python.platform import test
28from tensorflow.python.util import compat
29
30
class ErrorsTest(test.TestCase):
  """Checks error-code/exception mapping, pickling and payloads of errors."""

  # (error code constant, exception class from `errors_impl`) pairs shared by
  # the mapping test and the pickling test.
  _CODES_AND_IMPL_TYPES = (
      (errors.CANCELLED, errors_impl.CancelledError),
      (errors.UNKNOWN, errors_impl.UnknownError),
      (errors.INVALID_ARGUMENT, errors_impl.InvalidArgumentError),
      (errors.DEADLINE_EXCEEDED, errors_impl.DeadlineExceededError),
      (errors.NOT_FOUND, errors_impl.NotFoundError),
      (errors.ALREADY_EXISTS, errors_impl.AlreadyExistsError),
      (errors.PERMISSION_DENIED, errors_impl.PermissionDeniedError),
      (errors.UNAUTHENTICATED, errors_impl.UnauthenticatedError),
      (errors.RESOURCE_EXHAUSTED, errors_impl.ResourceExhaustedError),
      (errors.FAILED_PRECONDITION, errors_impl.FailedPreconditionError),
      (errors.ABORTED, errors_impl.AbortedError),
      (errors.OUT_OF_RANGE, errors_impl.OutOfRangeError),
      (errors.UNIMPLEMENTED, errors_impl.UnimplementedError),
      (errors.INTERNAL, errors_impl.InternalError),
      (errors.UNAVAILABLE, errors_impl.UnavailableError),
      (errors.DATA_LOSS, errors_impl.DataLossError),
  )

  # (integer code handed to `_errors_test_helper`, exception class expected to
  # be raised for it) pairs shared by the payload tests.
  _STATUS_CODES_AND_TYPES = (
      (1, errors.CancelledError),
      (2, errors.UnknownError),
      (3, errors.InvalidArgumentError),
      (4, errors.DeadlineExceededError),
      (5, errors.NotFoundError),
      (6, errors.AlreadyExistsError),
      (7, errors.PermissionDeniedError),
      (16, errors.UnauthenticatedError),
      (8, errors.ResourceExhaustedError),
      (9, errors.FailedPreconditionError),
      (10, errors.AbortedError),
      (11, errors.OutOfRangeError),
      (12, errors.UnimplementedError),
      (13, errors.InternalError),
      (14, errors.UnavailableError),
      (15, errors.DataLossError),
  )

  def _CountReferences(self, typeof):
    """Counts the objects tracked by the garbage collector that are `typeof`.

    Args:
      typeof: Class (or tuple of classes) to test each object against with
        `isinstance`.

    Returns:
      The number of objects returned by `gc.get_objects()` that are instances
      of `typeof`.
    """
    ref_count = 0
    for o in gc.get_objects():
      try:
        if isinstance(o, typeof):
          ref_count += 1
      # Certain versions of Python keep a weakref to deleted objects.
      except ReferenceError:
        pass
    return ref_count

  def _AssertPayloadsRaisedBy(self, raise_fn):
    """Checks `raise_fn(code)` raises the mapped error carrying both payloads.

    Args:
      raise_fn: Callable taking one integer code from
        `_STATUS_CODES_AND_TYPES`; it is expected to raise the paired
        exception class.
    """
    for code, expected_exception in self._STATUS_CODES_AND_TYPES:
      with self.assertRaises(expected_exception) as error:
        raise_fn(code)
      payloads = error.exception.experimental_payloads
      self.assertEqual(payloads[b"key1"], b"value1")
      self.assertEqual(payloads[b"key2"], b"value2")

  def testUniqueClassForEachErrorCode(self):
    for error_code, exc_type in self._CODES_AND_IMPL_TYPES:
      # pylint: disable=protected-access
      self.assertIsInstance(
          errors_impl._make_specific_exception(None, None, None, error_code),
          exc_type)
      # pylint: enable=protected-access
      # error_code_from_exception_type and exception_type_from_error_code should
      # be consistent with operation result.
      self.assertEqual(error_code,
                       errors_impl.error_code_from_exception_type(exc_type))
      self.assertIs(exc_type,
                    errors_impl.exception_type_from_error_code(error_code))

  def testKnownErrorClassForEachErrorCodeInProto(self):
    # Proto codes that are skipped by this check.
    # pylint: disable=line-too-long
    skipped_codes = (
        error_codes_pb2.OK,
        error_codes_pb2.DO_NOT_USE_RESERVED_FOR_FUTURE_EXPANSION_USE_DEFAULT_IN_SWITCH_INSTEAD_,
    )
    # pylint: enable=line-too-long
    for error_code in error_codes_pb2.Code.values():
      if error_code in skipped_codes:
        continue
      with warnings.catch_warnings(record=True) as w:
        # pylint: disable=protected-access
        exc = errors_impl._make_specific_exception(None, None, None, error_code)
        # pylint: enable=protected-access
      self.assertEqual(0, len(w))  # No warning is raised.
      self.assertIsInstance(exc, errors_impl.OpError)
      # The concrete class must derive directly from OpError.
      self.assertIn(errors_impl.OpError, exc.__class__.__bases__)

  def testUnknownErrorCodeCausesWarning(self):
    with warnings.catch_warnings(record=True) as w:
      # pylint: disable=protected-access
      exc = errors_impl._make_specific_exception(None, None, None, 37)
      # pylint: enable=protected-access
    self.assertEqual(1, len(w))
    self.assertIn("Unknown error code: 37", str(w[0].message))
    self.assertIsInstance(exc, errors_impl.OpError)

    # An unrecognised "exception type" is also expected to warn, and the value
    # handed back is checked to be an OpError instance.
    with warnings.catch_warnings(record=True) as w:
      exc = errors_impl.error_code_from_exception_type("Unknown")
    self.assertEqual(1, len(w))
    self.assertIn("Unknown class exception", str(w[0].message))
    self.assertIsInstance(exc, errors_impl.OpError)

  def testStatusDoesNotLeak(self):
    try:
      _pywrap_file_io.DeleteFile(compat.as_bytes("/DOES_NOT_EXIST/"))
    except Exception:  # pylint: disable=broad-except
      # The call is only made to provoke an error (the path name suggests it
      # does not exist); which error is raised is not under test. Catching
      # Exception rather than everything lets KeyboardInterrupt and SystemExit
      # still propagate.
      pass
    gc.collect()
    self.assertEqual(0, self._CountReferences(c_api_util.ScopedTFStatus))

  def testPickleable(self):
    for error_code, _ in self._CODES_AND_IMPL_TYPES:
      # pylint: disable=protected-access
      exc = errors_impl._make_specific_exception(None, None, None, error_code)
      # pylint: enable=protected-access
      unpickled = pickle.loads(pickle.dumps(exc))
      self.assertEqual(exc.node_def, unpickled.node_def)
      self.assertEqual(exc.op, unpickled.op)
      self.assertEqual(exc.message, unpickled.message)
      self.assertEqual(exc.error_code, unpickled.error_code)

  def testErrorPayloadsFromStatus(self):
    self._AssertPayloadsRaisedBy(_errors_test_helper.TestRaiseFromStatus)

  def testErrorPayloadsFromTFStatus(self):
    self._AssertPayloadsRaisedBy(_errors_test_helper.TestRaiseFromTFStatus)

  def testErrorPayloadsDefaultValue(self):
    for _, exception_type in self._STATUS_CODES_AND_TYPES:
      e = exception_type(None, None, None)
      # Without explicit payloads the attribute must be an empty plain dict.
      self.assertIs(type(e.experimental_payloads), dict)
      self.assertEqual(len(e.experimental_payloads), 0)
220
221
# Script entry point: delegate to the `test` module's `main()`.
if __name__ == "__main__":
  test.main()
224