• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.ops.test_util."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import copy
23import random
24import threading
25import unittest
26import weakref
27
28from absl.testing import parameterized
29import numpy as np
30
31from google.protobuf import text_format
32
33from tensorflow.core.framework import graph_pb2
34from tensorflow.core.protobuf import meta_graph_pb2
35from tensorflow.python.compat import compat
36from tensorflow.python.eager import context
37from tensorflow.python.eager import def_function
38from tensorflow.python.framework import combinations
39from tensorflow.python.framework import constant_op
40from tensorflow.python.framework import dtypes
41from tensorflow.python.framework import errors
42from tensorflow.python.framework import ops
43from tensorflow.python.framework import test_ops  # pylint: disable=unused-import
44from tensorflow.python.framework import test_util
45from tensorflow.python.ops import control_flow_ops
46from tensorflow.python.ops import lookup_ops
47from tensorflow.python.ops import math_ops
48from tensorflow.python.ops import random_ops
49from tensorflow.python.ops import resource_variable_ops
50from tensorflow.python.ops import variable_scope
51from tensorflow.python.ops import variables
52from tensorflow.python.platform import googletest
53
54
55class TestUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase):
56
57  def test_assert_ops_in_graph(self):
58    with ops.Graph().as_default():
59      constant_op.constant(["hello", "taffy"], name="hello")
60      test_util.assert_ops_in_graph({"hello": "Const"}, ops.get_default_graph())
61
62      self.assertRaises(ValueError, test_util.assert_ops_in_graph,
63                        {"bye": "Const"}, ops.get_default_graph())
64
65      self.assertRaises(ValueError, test_util.assert_ops_in_graph,
66                        {"hello": "Variable"}, ops.get_default_graph())
67
68  @test_util.run_deprecated_v1
69  def test_session_functions(self):
70    with self.test_session() as sess:
71      sess_ref = weakref.ref(sess)
72      with self.cached_session(graph=None, config=None) as sess2:
73        # We make sure that sess2 is sess.
74        assert sess2 is sess
75        # We make sure we raise an exception if we use cached_session with
76        # different values.
77        with self.assertRaises(ValueError):
78          with self.cached_session(graph=ops.Graph()) as sess2:
79            pass
80        with self.assertRaises(ValueError):
81          with self.cached_session(force_gpu=True) as sess2:
82            pass
83    # We make sure that test_session will cache the session even after the
84    # with scope.
85    assert not sess_ref()._closed
86    with self.session() as unique_sess:
87      unique_sess_ref = weakref.ref(unique_sess)
88      with self.session() as sess2:
89        assert sess2 is not unique_sess
90    # We make sure the session is closed when we leave the with statement.
91    assert unique_sess_ref()._closed
92
  def test_assert_equal_graph_def(self):
    """assert_equal_graph_def ignores node order but detects extra nodes."""
    # Snapshot an empty GraphDef, then add "five" before "seven".
    with ops.Graph().as_default() as g:
      def_empty = g.as_graph_def()
      constant_op.constant(5, name="five")
      constant_op.constant(7, name="seven")
      def_57 = g.as_graph_def()
    # Build the same two constants in the opposite order.
    with ops.Graph().as_default() as g:
      constant_op.constant(7, name="seven")
      constant_op.constant(5, name="five")
      def_75 = g.as_graph_def()
    # Comparing strings is order dependent
    self.assertNotEqual(str(def_57), str(def_75))
    # assert_equal_graph_def doesn't care about order
    test_util.assert_equal_graph_def(def_57, def_75)
    # Compare two unequal graphs; def_57 has nodes def_empty lacks and the
    # expected failure message names the node "seven".
    with self.assertRaisesRegex(AssertionError,
                                r"^Found unexpected node '{{node seven}}"):
      test_util.assert_equal_graph_def(def_57, def_empty)
111
112  def test_assert_equal_graph_def_hash_table(self):
113    def get_graph_def():
114      with ops.Graph().as_default() as g:
115        x = constant_op.constant([2, 9], name="x")
116        keys = constant_op.constant([1, 2], name="keys")
117        values = constant_op.constant([3, 4], name="values")
118        default = constant_op.constant(-1, name="default")
119        table = lookup_ops.StaticHashTable(
120            lookup_ops.KeyValueTensorInitializer(keys, values), default)
121        _ = table.lookup(x)
122      return g.as_graph_def()
123    def_1 = get_graph_def()
124    def_2 = get_graph_def()
125    # The unique shared_name of each table makes the graph unequal.
126    with self.assertRaisesRegex(AssertionError, "hash_table_"):
127      test_util.assert_equal_graph_def(def_1, def_2,
128                                       hash_table_shared_name=False)
129    # That can be ignored. (NOTE: modifies GraphDefs in-place.)
130    test_util.assert_equal_graph_def(def_1, def_2,
131                                     hash_table_shared_name=True)
132
133  def testIsGoogleCudaEnabled(self):
134    # The test doesn't assert anything. It ensures the py wrapper
135    # function is generated correctly.
136    if test_util.IsGoogleCudaEnabled():
137      print("GoogleCuda is enabled")
138    else:
139      print("GoogleCuda is disabled")
140
141  def testIsMklEnabled(self):
142    # This test doesn't assert anything.
143    # It ensures the py wrapper function is generated correctly.
144    if test_util.IsMklEnabled():
145      print("MKL is enabled")
146    else:
147      print("MKL is disabled")
148
149  @test_util.run_in_graph_and_eager_modes
150  def testAssertProtoEqualsStr(self):
151
152    graph_str = "node { name: 'w1' op: 'params' }"
153    graph_def = graph_pb2.GraphDef()
154    text_format.Merge(graph_str, graph_def)
155
156    # test string based comparison
157    self.assertProtoEquals(graph_str, graph_def)
158
159    # test original comparison
160    self.assertProtoEquals(graph_def, graph_def)
161
162  @test_util.run_in_graph_and_eager_modes
163  def testAssertProtoEqualsAny(self):
164    # Test assertProtoEquals with a protobuf.Any field.
165    meta_graph_def_str = """
166    meta_info_def {
167      meta_graph_version: "outer"
168      any_info {
169        [type.googleapis.com/tensorflow.MetaGraphDef] {
170          meta_info_def {
171            meta_graph_version: "inner"
172          }
173        }
174      }
175    }
176    """
177    meta_graph_def_outer = meta_graph_pb2.MetaGraphDef()
178    meta_graph_def_outer.meta_info_def.meta_graph_version = "outer"
179    meta_graph_def_inner = meta_graph_pb2.MetaGraphDef()
180    meta_graph_def_inner.meta_info_def.meta_graph_version = "inner"
181    meta_graph_def_outer.meta_info_def.any_info.Pack(meta_graph_def_inner)
182    self.assertProtoEquals(meta_graph_def_str, meta_graph_def_outer)
183    self.assertProtoEquals(meta_graph_def_outer, meta_graph_def_outer)
184
185    # Check if the assertion failure message contains the content of
186    # the inner proto.
187    with self.assertRaisesRegex(AssertionError, r'meta_graph_version: "inner"'):
188      self.assertProtoEquals("", meta_graph_def_outer)
189
190  @test_util.run_in_graph_and_eager_modes
191  def testNDArrayNear(self):
192    a1 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
193    a2 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
194    a3 = np.array([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]])
195    self.assertTrue(self._NDArrayNear(a1, a2, 1e-5))
196    self.assertFalse(self._NDArrayNear(a1, a3, 1e-5))
197
198  @test_util.run_in_graph_and_eager_modes
199  def testCheckedThreadSucceeds(self):
200
201    def noop(ev):
202      ev.set()
203
204    event_arg = threading.Event()
205
206    self.assertFalse(event_arg.is_set())
207    t = self.checkedThread(target=noop, args=(event_arg,))
208    t.start()
209    t.join()
210    self.assertTrue(event_arg.is_set())
211
212  @test_util.run_in_graph_and_eager_modes
213  def testCheckedThreadFails(self):
214
215    def err_func():
216      return 1 // 0
217
218    t = self.checkedThread(target=err_func)
219    t.start()
220    with self.assertRaises(self.failureException) as fe:
221      t.join()
222    self.assertTrue("integer division or modulo by zero" in str(fe.exception))
223
224  @test_util.run_in_graph_and_eager_modes
225  def testCheckedThreadWithWrongAssertionFails(self):
226    x = 37
227
228    def err_func():
229      self.assertTrue(x < 10)
230
231    t = self.checkedThread(target=err_func)
232    t.start()
233    with self.assertRaises(self.failureException) as fe:
234      t.join()
235    self.assertTrue("False is not true" in str(fe.exception))
236
237  @test_util.run_in_graph_and_eager_modes
238  def testMultipleThreadsWithOneFailure(self):
239
240    def err_func(i):
241      self.assertTrue(i != 7)
242
243    threads = [
244        self.checkedThread(
245            target=err_func, args=(i,)) for i in range(10)
246    ]
247    for t in threads:
248      t.start()
249    for i, t in enumerate(threads):
250      if i == 7:
251        with self.assertRaises(self.failureException):
252          t.join()
253      else:
254        t.join()
255
  def _WeMustGoDeeper(self, msg):
    """Raises an op error whose op has an `original_op` and checks for `msg`.

    Builds an op named "name" whose original_op is an op named "orig", then
    raises an UnauthenticatedError ("true_err") from it inside
    assertRaisesOpError(msg).

    Args:
      msg: Expected error text passed to assertRaisesOpError (presumably
        matched as a pattern; confirm against assertRaisesOpError).
    """
    with self.assertRaisesOpError(msg):
      with ops.Graph().as_default():
        node_def = ops._NodeDef("IntOutput", "name")
        node_def_orig = ops._NodeDef("IntOutput", "orig")
        op_orig = ops.Operation(node_def_orig, ops.get_default_graph())
        # Link the two ops so the error's op carries an original_op chain.
        op = ops.Operation(node_def, ops.get_default_graph(),
                           original_op=op_orig)
        raise errors.UnauthenticatedError(node_def, op, "true_err")
265
266  @test_util.run_in_graph_and_eager_modes
267  def testAssertRaisesOpErrorDoesNotPassMessageDueToLeakedStack(self):
268    with self.assertRaises(AssertionError):
269      self._WeMustGoDeeper("this_is_not_the_error_you_are_looking_for")
270
271    self._WeMustGoDeeper("true_err")
272    self._WeMustGoDeeper("name")
273    self._WeMustGoDeeper("orig")
274
275  @test_util.run_in_graph_and_eager_modes
276  def testAllCloseTensors(self):
277    a_raw_data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
278    a = constant_op.constant(a_raw_data)
279    b = math_ops.add(1, constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8]]))
280    self.assertAllClose(a, b)
281    self.assertAllClose(a, a_raw_data)
282
283    a_dict = {"key": a}
284    b_dict = {"key": b}
285    self.assertAllClose(a_dict, b_dict)
286
287    x_list = [a, b]
288    y_list = [a_raw_data, b]
289    self.assertAllClose(x_list, y_list)
290
291  @test_util.run_in_graph_and_eager_modes
292  def testAllCloseScalars(self):
293    self.assertAllClose(7, 7 + 1e-8)
294    with self.assertRaisesRegex(AssertionError, r"Not equal to tolerance"):
295      self.assertAllClose(7, 7 + 1e-5)
296
297  @test_util.run_in_graph_and_eager_modes
298  def testAllCloseList(self):
299    with self.assertRaisesRegex(AssertionError, r"not close dif"):
300      self.assertAllClose([0], [1])
301
302  @test_util.run_in_graph_and_eager_modes
303  def testAllCloseDictToNonDict(self):
304    with self.assertRaisesRegex(ValueError, r"Can't compare dict to non-dict"):
305      self.assertAllClose(1, {"a": 1})
306    with self.assertRaisesRegex(ValueError, r"Can't compare dict to non-dict"):
307      self.assertAllClose({"a": 1}, 1)
308
309  @test_util.run_in_graph_and_eager_modes
310  def testAllCloseNamedtuples(self):
311    a = 7
312    b = (2., 3.)
313    c = np.ones((3, 2, 4)) * 7.
314    expected = {"a": a, "b": b, "c": c}
315    my_named_tuple = collections.namedtuple("MyNamedTuple", ["a", "b", "c"])
316
317    # Identity.
318    self.assertAllClose(expected, my_named_tuple(a=a, b=b, c=c))
319    self.assertAllClose(
320        my_named_tuple(a=a, b=b, c=c), my_named_tuple(a=a, b=b, c=c))
321
322  @test_util.run_in_graph_and_eager_modes
323  def testAllCloseDicts(self):
324    a = 7
325    b = (2., 3.)
326    c = np.ones((3, 2, 4)) * 7.
327    expected = {"a": a, "b": b, "c": c}
328
329    # Identity.
330    self.assertAllClose(expected, expected)
331    self.assertAllClose(expected, dict(expected))
332
333    # With each item removed.
334    for k in expected:
335      actual = dict(expected)
336      del actual[k]
337      with self.assertRaisesRegex(AssertionError, r"mismatched keys"):
338        self.assertAllClose(expected, actual)
339
340    # With each item changed.
341    with self.assertRaisesRegex(AssertionError, r"Not equal to tolerance"):
342      self.assertAllClose(expected, {"a": a + 1e-5, "b": b, "c": c})
343    with self.assertRaisesRegex(AssertionError, r"Shape mismatch"):
344      self.assertAllClose(expected, {"a": a, "b": b + (4.,), "c": c})
345    c_copy = np.array(c)
346    c_copy[1, 1, 1] += 1e-5
347    with self.assertRaisesRegex(AssertionError, r"Not equal to tolerance"):
348      self.assertAllClose(expected, {"a": a, "b": b, "c": c_copy})
349
350  @test_util.run_in_graph_and_eager_modes
351  def testAllCloseListOfNamedtuples(self):
352    my_named_tuple = collections.namedtuple("MyNamedTuple", ["x", "y"])
353    l1 = [
354        my_named_tuple(x=np.array([[2.3, 2.5]]), y=np.array([[0.97, 0.96]])),
355        my_named_tuple(x=np.array([[3.3, 3.5]]), y=np.array([[0.98, 0.99]]))
356    ]
357    l2 = [
358        ([[2.3, 2.5]], [[0.97, 0.96]]),
359        ([[3.3, 3.5]], [[0.98, 0.99]]),
360    ]
361    self.assertAllClose(l1, l2)
362
363  @test_util.run_in_graph_and_eager_modes
364  def testAllCloseNestedStructure(self):
365    a = {"x": np.ones((3, 2, 4)) * 7, "y": (2, [{"nested": {"m": 3, "n": 4}}])}
366    self.assertAllClose(a, a)
367
368    b = copy.deepcopy(a)
369    self.assertAllClose(a, b)
370
371    # Test mismatched values
372    b["y"][1][0]["nested"]["n"] = 4.2
373    with self.assertRaisesRegex(AssertionError,
374                                r"\[y\]\[1\]\[0\]\[nested\]\[n\]"):
375      self.assertAllClose(a, b)
376
377  @test_util.run_in_graph_and_eager_modes
378  def testArrayNear(self):
379    a = [1, 2]
380    b = [1, 2, 5]
381    with self.assertRaises(AssertionError):
382      self.assertArrayNear(a, b, 0.001)
383    a = [1, 2]
384    b = [[1, 2], [3, 4]]
385    with self.assertRaises(TypeError):
386      self.assertArrayNear(a, b, 0.001)
387    a = [1, 2]
388    b = [1, 2]
389    self.assertArrayNear(a, b, 0.001)
390
391  @test_util.skip_if(True)  # b/117665998
392  def testForceGPU(self):
393    with self.assertRaises(errors.InvalidArgumentError):
394      with self.test_session(force_gpu=True):
395        # this relies on us not having a GPU implementation for assert, which
396        # seems sensible
397        x = constant_op.constant(True)
398        y = [15]
399        control_flow_ops.Assert(x, y).run()
400
401  @test_util.run_in_graph_and_eager_modes
402  def testAssertAllCloseAccordingToType(self):
403    # test plain int
404    self.assertAllCloseAccordingToType(1, 1, rtol=1e-8, atol=1e-8)
405
406    # test float64
407    self.assertAllCloseAccordingToType(
408        np.asarray([1e-8], dtype=np.float64),
409        np.asarray([2e-8], dtype=np.float64),
410        rtol=1e-8, atol=1e-8
411    )
412
413    self.assertAllCloseAccordingToType(
414        constant_op.constant([1e-8], dtype=dtypes.float64),
415        constant_op.constant([2e-8], dtype=dtypes.float64),
416        rtol=1e-8,
417        atol=1e-8)
418
419    with (self.assertRaises(AssertionError)):
420      self.assertAllCloseAccordingToType(
421          np.asarray([1e-7], dtype=np.float64),
422          np.asarray([2e-7], dtype=np.float64),
423          rtol=1e-8, atol=1e-8
424      )
425
426    # test float32
427    self.assertAllCloseAccordingToType(
428        np.asarray([1e-7], dtype=np.float32),
429        np.asarray([2e-7], dtype=np.float32),
430        rtol=1e-8, atol=1e-8,
431        float_rtol=1e-7, float_atol=1e-7
432    )
433
434    self.assertAllCloseAccordingToType(
435        constant_op.constant([1e-7], dtype=dtypes.float32),
436        constant_op.constant([2e-7], dtype=dtypes.float32),
437        rtol=1e-8,
438        atol=1e-8,
439        float_rtol=1e-7,
440        float_atol=1e-7)
441
442    with (self.assertRaises(AssertionError)):
443      self.assertAllCloseAccordingToType(
444          np.asarray([1e-6], dtype=np.float32),
445          np.asarray([2e-6], dtype=np.float32),
446          rtol=1e-8, atol=1e-8,
447          float_rtol=1e-7, float_atol=1e-7
448      )
449
450    # test float16
451    self.assertAllCloseAccordingToType(
452        np.asarray([1e-4], dtype=np.float16),
453        np.asarray([2e-4], dtype=np.float16),
454        rtol=1e-8, atol=1e-8,
455        float_rtol=1e-7, float_atol=1e-7,
456        half_rtol=1e-4, half_atol=1e-4
457    )
458
459    self.assertAllCloseAccordingToType(
460        constant_op.constant([1e-4], dtype=dtypes.float16),
461        constant_op.constant([2e-4], dtype=dtypes.float16),
462        rtol=1e-8,
463        atol=1e-8,
464        float_rtol=1e-7,
465        float_atol=1e-7,
466        half_rtol=1e-4,
467        half_atol=1e-4)
468
469    with (self.assertRaises(AssertionError)):
470      self.assertAllCloseAccordingToType(
471          np.asarray([1e-3], dtype=np.float16),
472          np.asarray([2e-3], dtype=np.float16),
473          rtol=1e-8, atol=1e-8,
474          float_rtol=1e-7, float_atol=1e-7,
475          half_rtol=1e-4, half_atol=1e-4
476      )
477
478  @test_util.run_in_graph_and_eager_modes
479  def testAssertAllEqual(self):
480    i = variables.Variable([100] * 3, dtype=dtypes.int32, name="i")
481    j = constant_op.constant([20] * 3, dtype=dtypes.int32, name="j")
482    k = math_ops.add(i, j, name="k")
483
484    self.evaluate(variables.global_variables_initializer())
485    self.assertAllEqual([100] * 3, i)
486    self.assertAllEqual([120] * 3, k)
487    self.assertAllEqual([20] * 3, j)
488
489    with self.assertRaisesRegex(AssertionError, r"not equal lhs"):
490      self.assertAllEqual([0] * 3, k)
491
492  @test_util.run_in_graph_and_eager_modes
493  def testAssertNotAllEqual(self):
494    i = variables.Variable([100], dtype=dtypes.int32, name="i")
495    j = constant_op.constant([20], dtype=dtypes.int32, name="j")
496    k = math_ops.add(i, j, name="k")
497
498    self.evaluate(variables.global_variables_initializer())
499    self.assertNotAllEqual([100] * 3, i)
500    self.assertNotAllEqual([120] * 3, k)
501    self.assertNotAllEqual([20] * 3, j)
502
503    with self.assertRaisesRegex(
504        AssertionError, r"two values are equal at all elements.*extra message"):
505      self.assertNotAllEqual([120], k, msg="extra message")
506
507  @test_util.run_in_graph_and_eager_modes
508  def testAssertNotAllClose(self):
509    # Test with arrays
510    self.assertNotAllClose([0.1], [0.2])
511    with self.assertRaises(AssertionError):
512      self.assertNotAllClose([-1.0, 2.0], [-1.0, 2.0])
513
514    # Test with tensors
515    x = constant_op.constant([1.0, 1.0], name="x")
516    y = math_ops.add(x, x)
517
518    self.assertAllClose([2.0, 2.0], y)
519    self.assertNotAllClose([0.9, 1.0], x)
520
521    with self.assertRaises(AssertionError):
522      self.assertNotAllClose([1.0, 1.0], x)
523
524  @test_util.run_in_graph_and_eager_modes
525  def testAssertNotAllCloseRTol(self):
526    # Test with arrays
527    with self.assertRaises(AssertionError):
528      self.assertNotAllClose([1.1, 2.1], [1.0, 2.0], rtol=0.2)
529
530    # Test with tensors
531    x = constant_op.constant([1.0, 1.0], name="x")
532    y = math_ops.add(x, x)
533
534    self.assertAllClose([2.0, 2.0], y)
535
536    with self.assertRaises(AssertionError):
537      self.assertNotAllClose([0.9, 1.0], x, rtol=0.2)
538
539  @test_util.run_in_graph_and_eager_modes
540  def testAssertNotAllCloseATol(self):
541    # Test with arrays
542    with self.assertRaises(AssertionError):
543      self.assertNotAllClose([1.1, 2.1], [1.0, 2.0], atol=0.2)
544
545    # Test with tensors
546    x = constant_op.constant([1.0, 1.0], name="x")
547    y = math_ops.add(x, x)
548
549    self.assertAllClose([2.0, 2.0], y)
550
551    with self.assertRaises(AssertionError):
552      self.assertNotAllClose([0.9, 1.0], x, atol=0.2)
553
554  @test_util.run_in_graph_and_eager_modes
555  def testAssertAllGreaterLess(self):
556    x = constant_op.constant([100.0, 110.0, 120.0], dtype=dtypes.float32)
557    y = constant_op.constant([10.0] * 3, dtype=dtypes.float32)
558    z = math_ops.add(x, y)
559
560    self.assertAllClose([110.0, 120.0, 130.0], z)
561
562    self.assertAllGreater(x, 95.0)
563    self.assertAllLess(x, 125.0)
564
565    with self.assertRaises(AssertionError):
566      self.assertAllGreater(x, 105.0)
567    with self.assertRaises(AssertionError):
568      self.assertAllGreater(x, 125.0)
569
570    with self.assertRaises(AssertionError):
571      self.assertAllLess(x, 115.0)
572    with self.assertRaises(AssertionError):
573      self.assertAllLess(x, 95.0)
574
575  @test_util.run_in_graph_and_eager_modes
576  def testAssertAllGreaterLessEqual(self):
577    x = constant_op.constant([100.0, 110.0, 120.0], dtype=dtypes.float32)
578    y = constant_op.constant([10.0] * 3, dtype=dtypes.float32)
579    z = math_ops.add(x, y)
580
581    self.assertAllEqual([110.0, 120.0, 130.0], z)
582
583    self.assertAllGreaterEqual(x, 95.0)
584    self.assertAllLessEqual(x, 125.0)
585
586    with self.assertRaises(AssertionError):
587      self.assertAllGreaterEqual(x, 105.0)
588    with self.assertRaises(AssertionError):
589      self.assertAllGreaterEqual(x, 125.0)
590
591    with self.assertRaises(AssertionError):
592      self.assertAllLessEqual(x, 115.0)
593    with self.assertRaises(AssertionError):
594      self.assertAllLessEqual(x, 95.0)
595
596  def testAssertAllInRangeWithNonNumericValuesFails(self):
597    s1 = constant_op.constant("Hello, ", name="s1")
598    c = constant_op.constant([1 + 2j, -3 + 5j], name="c")
599    b = constant_op.constant([False, True], name="b")
600
601    with self.assertRaises(AssertionError):
602      self.assertAllInRange(s1, 0.0, 1.0)
603    with self.assertRaises(AssertionError):
604      self.assertAllInRange(c, 0.0, 1.0)
605    with self.assertRaises(AssertionError):
606      self.assertAllInRange(b, 0, 1)
607
608  @test_util.run_in_graph_and_eager_modes
609  def testAssertAllInRange(self):
610    x = constant_op.constant([10.0, 15.0], name="x")
611    self.assertAllInRange(x, 10, 15)
612
613    with self.assertRaises(AssertionError):
614      self.assertAllInRange(x, 10, 15, open_lower_bound=True)
615    with self.assertRaises(AssertionError):
616      self.assertAllInRange(x, 10, 15, open_upper_bound=True)
617    with self.assertRaises(AssertionError):
618      self.assertAllInRange(
619          x, 10, 15, open_lower_bound=True, open_upper_bound=True)
620
621  @test_util.run_in_graph_and_eager_modes
622  def testAssertAllInRangeScalar(self):
623    x = constant_op.constant(10.0, name="x")
624    nan = constant_op.constant(np.nan, name="nan")
625    self.assertAllInRange(x, 5, 15)
626    with self.assertRaises(AssertionError):
627      self.assertAllInRange(nan, 5, 15)
628
629    with self.assertRaises(AssertionError):
630      self.assertAllInRange(x, 10, 15, open_lower_bound=True)
631    with self.assertRaises(AssertionError):
632      self.assertAllInRange(x, 1, 2)
633
634  @test_util.run_in_graph_and_eager_modes
635  def testAssertAllInRangeErrorMessageEllipses(self):
636    x_init = np.array([[10.0, 15.0]] * 12)
637    x = constant_op.constant(x_init, name="x")
638    with self.assertRaises(AssertionError):
639      self.assertAllInRange(x, 5, 10)
640
641  @test_util.run_in_graph_and_eager_modes
642  def testAssertAllInRangeDetectsNaNs(self):
643    x = constant_op.constant(
644        [[np.nan, 0.0], [np.nan, np.inf], [np.inf, np.nan]], name="x")
645    with self.assertRaises(AssertionError):
646      self.assertAllInRange(x, 0.0, 2.0)
647
648  @test_util.run_in_graph_and_eager_modes
649  def testAssertAllInRangeWithInfinities(self):
650    x = constant_op.constant([10.0, np.inf], name="x")
651    self.assertAllInRange(x, 10, np.inf)
652    with self.assertRaises(AssertionError):
653      self.assertAllInRange(x, 10, np.inf, open_upper_bound=True)
654
655  @test_util.run_in_graph_and_eager_modes
656  def testAssertAllInSet(self):
657    b = constant_op.constant([True, False], name="b")
658    x = constant_op.constant([13, 37], name="x")
659
660    self.assertAllInSet(b, [False, True])
661    self.assertAllInSet(b, (False, True))
662    self.assertAllInSet(b, {False, True})
663    self.assertAllInSet(x, [0, 13, 37, 42])
664    self.assertAllInSet(x, (0, 13, 37, 42))
665    self.assertAllInSet(x, {0, 13, 37, 42})
666
667    with self.assertRaises(AssertionError):
668      self.assertAllInSet(b, [False])
669    with self.assertRaises(AssertionError):
670      self.assertAllInSet(x, (42,))
671
672  def testRandomSeed(self):
673    # Call setUp again for WithCApi case (since it makes a new default graph
674    # after setup).
675    # TODO(skyewm): remove this when C API is permanently enabled.
676    with context.eager_mode():
677      self.setUp()
678      a = random.randint(1, 1000)
679      a_np_rand = np.random.rand(1)
680      a_rand = random_ops.random_normal([1])
681      # ensure that randomness in multiple testCases is deterministic.
682      self.setUp()
683      b = random.randint(1, 1000)
684      b_np_rand = np.random.rand(1)
685      b_rand = random_ops.random_normal([1])
686      self.assertEqual(a, b)
687      self.assertEqual(a_np_rand, b_np_rand)
688      self.assertAllEqual(a_rand, b_rand)
689
690  @test_util.run_in_graph_and_eager_modes
691  def test_callable_evaluate(self):
692    def model():
693      return resource_variable_ops.ResourceVariable(
694          name="same_name",
695          initial_value=1) + 1
696    with context.eager_mode():
697      self.assertEqual(2, self.evaluate(model))
698
699  @test_util.run_in_graph_and_eager_modes
700  def test_nested_tensors_evaluate(self):
701    expected = {"a": 1, "b": 2, "nested": {"d": 3, "e": 4}}
702    nested = {"a": constant_op.constant(1),
703              "b": constant_op.constant(2),
704              "nested": {"d": constant_op.constant(3),
705                         "e": constant_op.constant(4)}}
706
707    self.assertEqual(expected, self.evaluate(nested))
708
709  def test_run_in_graph_and_eager_modes(self):
710    l = []
711    def inc(self, with_brackets):
712      del self  # self argument is required by run_in_graph_and_eager_modes.
713      mode = "eager" if context.executing_eagerly() else "graph"
714      with_brackets = "with_brackets" if with_brackets else "without_brackets"
715      l.append((with_brackets, mode))
716
717    f = test_util.run_in_graph_and_eager_modes(inc)
718    f(self, with_brackets=False)
719    f = test_util.run_in_graph_and_eager_modes()(inc)
720    f(self, with_brackets=True)
721
722    self.assertEqual(len(l), 4)
723    self.assertEqual(set(l), {
724        ("with_brackets", "graph"),
725        ("with_brackets", "eager"),
726        ("without_brackets", "graph"),
727        ("without_brackets", "eager"),
728    })
729
730  def test_get_node_def_from_graph(self):
731    graph_def = graph_pb2.GraphDef()
732    node_foo = graph_def.node.add()
733    node_foo.name = "foo"
734    self.assertIs(test_util.get_node_def_from_graph("foo", graph_def), node_foo)
735    self.assertIsNone(test_util.get_node_def_from_graph("bar", graph_def))
736
737  def test_run_in_eager_and_graph_modes_test_class(self):
738    msg = "`run_in_graph_and_eager_modes` only supports test methods.*"
739    with self.assertRaisesRegex(ValueError, msg):
740
741      @test_util.run_in_graph_and_eager_modes()
742      class Foo(object):
743        pass
744      del Foo  # Make pylint unused happy.
745
746  def test_run_in_eager_and_graph_modes_skip_graph_runs_eager(self):
747    modes = []
748    def _test(self):
749      if not context.executing_eagerly():
750        self.skipTest("Skipping in graph mode")
751      modes.append("eager" if context.executing_eagerly() else "graph")
752    test_util.run_in_graph_and_eager_modes(_test)(self)
753    self.assertEqual(modes, ["eager"])
754
755  def test_run_in_eager_and_graph_modes_skip_eager_runs_graph(self):
756    modes = []
757    def _test(self):
758      if context.executing_eagerly():
759        self.skipTest("Skipping in eager mode")
760      modes.append("eager" if context.executing_eagerly() else "graph")
761    test_util.run_in_graph_and_eager_modes(_test)(self)
762    self.assertEqual(modes, ["graph"])
763
764  def test_run_in_graph_and_eager_modes_setup_in_same_mode(self):
765    modes = []
766    mode_name = lambda: "eager" if context.executing_eagerly() else "graph"
767
768    class ExampleTest(test_util.TensorFlowTestCase):
769
770      def runTest(self):
771        pass
772
773      def setUp(self):
774        modes.append("setup_" + mode_name())
775
776      @test_util.run_in_graph_and_eager_modes
777      def testBody(self):
778        modes.append("run_" + mode_name())
779
780    e = ExampleTest()
781    e.setUp()
782    e.testBody()
783
784    self.assertEqual(modes[1:2], ["run_graph"])
785    self.assertEqual(modes[2:], ["setup_eager", "run_eager"])
786
787  @parameterized.named_parameters(dict(testcase_name="argument",
788                                       arg=True))
789  @test_util.run_in_graph_and_eager_modes
790  def test_run_in_graph_and_eager_works_with_parameterized_keyword(self, arg):
791    self.assertEqual(arg, True)
792
793  @combinations.generate(combinations.combine(arg=True))
794  @test_util.run_in_graph_and_eager_modes
795  def test_run_in_graph_and_eager_works_with_combinations(self, arg):
796    self.assertEqual(arg, True)
797
  def test_build_as_function_and_v1_graph(self):
    """build_as_function_and_v1_graph generates one test method per mode.

    Calls the generated `test_modes_v1_graph` and `test_modes_function`
    methods and checks that each mode's branch ran exactly once.
    """

    class GraphModeAndFunctionTest(parameterized.TestCase):

      def __init__(inner_self):  # pylint: disable=no-self-argument
        super(GraphModeAndFunctionTest, inner_self).__init__()
        # Flags set by test_modes to record which mode actually ran.
        inner_self.graph_mode_tested = False
        inner_self.inside_function_tested = False

      def runTest(self):
        # Presumably present so the default methodName ('runTest') resolves.
        del self

      @test_util.build_as_function_and_v1_graph
      def test_modes(inner_self):  # pylint: disable=no-self-argument
        # `self` is the outer test case; it is used only for its assertions.
        if ops.inside_function():
          # Each branch must run only once.
          self.assertFalse(inner_self.inside_function_tested)
          inner_self.inside_function_tested = True
        else:
          self.assertFalse(inner_self.graph_mode_tested)
          inner_self.graph_mode_tested = True

    test_object = GraphModeAndFunctionTest()
    # The decorator is expected to have produced these two mode variants.
    test_object.test_modes_v1_graph()
    test_object.test_modes_function()
    self.assertTrue(test_object.graph_mode_tested)
    self.assertTrue(test_object.inside_function_tested)
824
825  def test_with_forward_compatibility_horizons(self):
826
827    tested_codepaths = set()
828    def some_function_with_forward_compat_behavior():
829      if compat.forward_compatible(2050, 1, 1):
830        tested_codepaths.add("future")
831      else:
832        tested_codepaths.add("present")
833
834    @test_util.with_forward_compatibility_horizons(None, [2051, 1, 1])
835    def some_test(self):
836      del self  # unused
837      some_function_with_forward_compat_behavior()
838
839    some_test(None)
840    self.assertEqual(tested_codepaths, set(["present", "future"]))
841
842
class SkipTestTest(test_util.TensorFlowTestCase):
  """Tests `test_util.skip_if_error`, also from within setUp() and tearDown()."""

  def _check_skipped(self, error_args, *expected_messages):
    """Asserts that raising ValueError(*error_args) under skip_if_error skips.

    Args:
      error_args: Positional arguments used to build the raised ValueError.
      *expected_messages: Optional message filter (a str or list of str)
        forwarded to `skip_if_error`; omitted from the call when not given.
    """
    with self.assertRaises(unittest.SkipTest):
      with test_util.skip_if_error(self, ValueError, *expected_messages):
        raise ValueError(*error_args)

  def _check_not_skipped(self, error_args, expected_messages, pattern):
    """Asserts the ValueError propagates (matching `pattern`) with no skip.

    Args:
      error_args: Positional arguments used to build the raised ValueError.
      expected_messages: Message filter forwarded to `skip_if_error`.
      pattern: Regex the propagated ValueError's message must match.

    Raises:
      RuntimeError: If skip_if_error turned the error into a skip.
    """
    try:
      with self.assertRaisesRegex(ValueError, pattern):
        with test_util.skip_if_error(self, ValueError, expected_messages):
          raise ValueError(*error_args)
    except unittest.SkipTest:
      raise RuntimeError("Test is not supposed to skip.")

  def _verify_test_in_set_up_or_tear_down(self):
    """Runs one skip check and one no-skip check; shared by setUp/tearDown."""
    self._check_skipped(("test message",), ["foo bar", "test message"])
    self._check_not_skipped(("foo bar",), "test message", "foo bar")

  def setUp(self):
    super(SkipTestTest, self).setUp()
    self._verify_test_in_set_up_or_tear_down()

  def tearDown(self):
    super(SkipTestTest, self).tearDown()
    self._verify_test_in_set_up_or_tear_down()

  def test_skip_if_error_should_skip(self):
    self._check_skipped(("test message",), "test message")

  def test_skip_if_error_should_skip_with_list(self):
    self._check_skipped(("test message",), ["foo bar", "test message"])

  def test_skip_if_error_should_skip_without_expected_message(self):
    self._check_skipped(("test message",))

  def test_skip_if_error_should_skip_without_error_message(self):
    self._check_skipped(())

  def test_skip_if_error_should_raise_message_mismatch(self):
    self._check_not_skipped(("foo bar",), "test message", "foo bar")

  def test_skip_if_error_should_raise_no_message(self):
    self._check_not_skipped((), "test message", "")
901
902
# A separate test case is used to reproduce variable-sharing issues that only
# show up when setUp() is overridden without calling super().
class GraphAndEagerNoVariableSharing(test_util.TensorFlowTestCase):
  """Creates a named resource variable in each mode with a bare setUp()."""

  def setUp(self):
    # Deliberately skips TensorFlowTestCase.setUp(): no super() call here.
    pass

  @test_util.run_in_graph_and_eager_modes
  def test_no_variable_sharing(self):
    """Creating "step_size" in every mode must not raise."""
    init_value = np.array(1e-5, np.float32)
    variable_scope.get_variable(
        name="step_size",
        initializer=init_value,
        trainable=False,
        use_resource=True)
917
918
class GarbageCollectionTest(test_util.TensorFlowTestCase):
  """Tests for the leak/garbage-detection decorators in test_util."""

  def test_no_reference_cycle_decorator(self):
    """assert_no_garbage_created must flag a reference cycle, not a no-op."""

    class ReferenceCycleTest(object):

      def __init__(inner_self):  # pylint: disable=no-self-argument
        # Borrow the outer test's assertEqual; presumably the decorator
        # reports through self.assertEqual on the wrapped object.
        inner_self.assertEqual = self.assertEqual  # pylint: disable=invalid-name

      # A list that contains itself forms a reference cycle; the decorator is
      # expected to reject it with an AssertionError (checked below).
      @test_util.assert_no_garbage_created
      def test_has_cycle(self):
        a = []
        a.append(a)

      # Allocates nothing, so the decorator is expected to pass.
      @test_util.assert_no_garbage_created
      def test_has_no_cycle(self):
        pass

    with self.assertRaises(AssertionError):
      ReferenceCycleTest().test_has_cycle()

    ReferenceCycleTest().test_has_no_cycle()

  @test_util.run_in_graph_and_eager_modes
  def test_no_leaked_tensor_decorator(self):
    """assert_no_new_tensors must report a tensor kept alive past the call."""

    class LeakedTensorTest(object):

      def __init__(inner_self):  # pylint: disable=no-self-argument
        # Borrow the outer test's assertEqual; presumably the decorator
        # reports through self.assertEqual on the wrapped object.
        inner_self.assertEqual = self.assertEqual  # pylint: disable=invalid-name

      # Stores the constant on the instance so it stays referenced after the
      # call; expected to fail with "Tensors not deallocated" (checked below).
      @test_util.assert_no_new_tensors
      def test_has_leak(self):
        self.a = constant_op.constant([3.], name="leak")

      # The constant is not stored anywhere, so no leak is expected.
      @test_util.assert_no_new_tensors
      def test_has_no_leak(self):
        constant_op.constant([3.], name="no-leak")

    with self.assertRaisesRegex(AssertionError, "Tensors not deallocated"):
      LeakedTensorTest().test_has_leak()

    LeakedTensorTest().test_has_no_leak()

  def test_no_new_objects_decorator(self):
    """assert_no_new_pyobjects_executing_eagerly must catch growing state."""

    class LeakedObjectTest(unittest.TestCase):

      def __init__(self, *args, **kwargs):
        super(LeakedObjectTest, self).__init__(*args, **kwargs)
        # Instance-level list that test_has_leak appends to on every call.
        self.accumulation = []

      # Each call appends a new list that stays reachable from the instance.
      # The decorator should fail this; expectedFailure then makes the run
      # count as successful.
      @unittest.expectedFailure
      @test_util.assert_no_new_pyobjects_executing_eagerly
      def test_has_leak(self):
        self.accumulation.append([1.])

      # Rebinds the same attribute on each call, so nothing accumulates.
      @test_util.assert_no_new_pyobjects_executing_eagerly
      def test_has_no_leak(self):
        self.not_accumulating = [1.]

    # run() without an explicit result object returns a fresh TestResult.
    self.assertTrue(LeakedObjectTest("test_has_leak").run().wasSuccessful())
    self.assertTrue(LeakedObjectTest("test_has_no_leak").run().wasSuccessful())
982
983
class RunFunctionsEagerlyInV2Test(test_util.TensorFlowTestCase,
                                  parameterized.TestCase):
  """Tests for the `test_util.run_functions_eagerly` context manager."""

  @parameterized.named_parameters(
      [("_RunEagerly", True), ("_RunGraph", False)])
  def test_run_functions_eagerly(self, run_eagerly):  # pylint: disable=g-wrong-blank-lines
    """Checks which kind of tensor the tf.function body produces.

    When functions run eagerly and the test itself executes eagerly, the body
    should produce `EagerTensor`s. In every other case it should produce
    symbolic (non-eager) `Tensor`s.

    Args:
      run_eagerly: Value forwarded to `test_util.run_functions_eagerly`.
    """
    results = []

    @def_function.function
    def add_two(x):
      for _ in range(5):
        x += 2
        results.append(x)
      return x

    with test_util.run_functions_eagerly(run_eagerly):
      add_two(constant_op.constant(2.))
      # Without this, the per-tensor checks below would pass vacuously.
      self.assertNotEmpty(results)
      # assertTrue(<generator expression>) can never fail because a generator
      # object is always truthy, so each tensor is checked individually.
      expect_eager = context.executing_eagerly() and run_eagerly
      for t in results:
        if expect_eager:
          self.assertIsInstance(t, ops.EagerTensor)
        else:
          # EagerTensor is itself a Tensor, so also rule it out explicitly.
          self.assertIsInstance(t, ops.Tensor)
          self.assertNotIsInstance(t, ops.EagerTensor)
1007
1008
if __name__ == "__main__":
  # Entry point when this module is executed directly.
  googletest.main()
1011