• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.ops.test_util."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import copy
23import random
24import threading
25import unittest
26import weakref
27
28from absl.testing import parameterized
29import numpy as np
30
31from google.protobuf import text_format
32
33from tensorflow.core.framework import graph_pb2
34from tensorflow.core.protobuf import meta_graph_pb2
35from tensorflow.python import pywrap_sanitizers
36from tensorflow.python.compat import compat
37from tensorflow.python.eager import context
38from tensorflow.python.eager import def_function
39from tensorflow.python.framework import combinations
40from tensorflow.python.framework import constant_op
41from tensorflow.python.framework import dtypes
42from tensorflow.python.framework import errors
43from tensorflow.python.framework import ops
44from tensorflow.python.framework import random_seed
45from tensorflow.python.framework import test_ops  # pylint: disable=unused-import
46from tensorflow.python.framework import test_util
47from tensorflow.python.ops import control_flow_ops
48from tensorflow.python.ops import lookup_ops
49from tensorflow.python.ops import math_ops
50from tensorflow.python.ops import random_ops
51from tensorflow.python.ops import resource_variable_ops
52from tensorflow.python.ops import variable_scope
53from tensorflow.python.ops import variables
54from tensorflow.python.platform import googletest
55
56
class TestUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Tests for the helpers, decorators and assertions provided by `test_util`."""

  def test_assert_ops_in_graph(self):
    with ops.Graph().as_default():
      constant_op.constant(["hello", "taffy"], name="hello")
      test_util.assert_ops_in_graph({"hello": "Const"}, ops.get_default_graph())

      self.assertRaises(ValueError, test_util.assert_ops_in_graph,
                        {"bye": "Const"}, ops.get_default_graph())

      self.assertRaises(ValueError, test_util.assert_ops_in_graph,
                        {"hello": "Variable"}, ops.get_default_graph())

  @test_util.run_deprecated_v1
  def test_session_functions(self):
    # unittest assertions are used instead of bare `assert` statements so the
    # checks are still executed when Python runs with -O.
    with self.test_session() as sess:
      sess_ref = weakref.ref(sess)
      with self.cached_session(graph=None, config=None) as sess2:
        # We make sure that sess2 is sess.
        self.assertIs(sess2, sess)
        # We make sure we raise an exception if we use cached_session with
        # different values.
        with self.assertRaises(ValueError):
          with self.cached_session(graph=ops.Graph()):
            pass
        with self.assertRaises(ValueError):
          with self.cached_session(force_gpu=True):
            pass
    # We make sure that test_session will cache the session even after the
    # with scope.
    self.assertFalse(sess_ref()._closed)
    with self.session() as unique_sess:
      unique_sess_ref = weakref.ref(unique_sess)
      with self.session() as sess2:
        self.assertIsNot(sess2, unique_sess)
    # We make sure the session is closed when we leave the with statement.
    self.assertTrue(unique_sess_ref()._closed)

  def test_assert_equal_graph_def(self):
    with ops.Graph().as_default() as g:
      def_empty = g.as_graph_def()
      constant_op.constant(5, name="five")
      constant_op.constant(7, name="seven")
      def_57 = g.as_graph_def()
    with ops.Graph().as_default() as g:
      constant_op.constant(7, name="seven")
      constant_op.constant(5, name="five")
      def_75 = g.as_graph_def()
    # Comparing strings is order dependent
    self.assertNotEqual(str(def_57), str(def_75))
    # assert_equal_graph_def doesn't care about order
    test_util.assert_equal_graph_def(def_57, def_75)
    # Compare two unequal graphs
    with self.assertRaisesRegex(AssertionError,
                                r"^Found unexpected node '{{node seven}}"):
      test_util.assert_equal_graph_def(def_57, def_empty)

  def test_assert_equal_graph_def_hash_table(self):
    def get_graph_def():
      with ops.Graph().as_default() as g:
        x = constant_op.constant([2, 9], name="x")
        keys = constant_op.constant([1, 2], name="keys")
        values = constant_op.constant([3, 4], name="values")
        default = constant_op.constant(-1, name="default")
        table = lookup_ops.StaticHashTable(
            lookup_ops.KeyValueTensorInitializer(keys, values), default)
        _ = table.lookup(x)
      return g.as_graph_def()
    def_1 = get_graph_def()
    def_2 = get_graph_def()
    # The unique shared_name of each table makes the graph unequal.
    with self.assertRaisesRegex(AssertionError, "hash_table_"):
      test_util.assert_equal_graph_def(def_1, def_2,
                                       hash_table_shared_name=False)
    # That can be ignored. (NOTE: modifies GraphDefs in-place.)
    test_util.assert_equal_graph_def(def_1, def_2,
                                     hash_table_shared_name=True)

  def testIsGoogleCudaEnabled(self):
    # The test doesn't assert anything. It ensures the py wrapper
    # function is generated correctly.
    if test_util.IsGoogleCudaEnabled():
      print("GoogleCuda is enabled")
    else:
      print("GoogleCuda is disabled")

  def testIsMklEnabled(self):
    # This test doesn't assert anything.
    # It ensures the py wrapper function is generated correctly.
    if test_util.IsMklEnabled():
      print("MKL is enabled")
    else:
      print("MKL is disabled")

  @test_util.disable_asan("Skip test if ASAN is enabled.")
  def testDisableAsan(self):
    self.assertFalse(pywrap_sanitizers.is_asan_enabled())

  @test_util.disable_msan("Skip test if MSAN is enabled.")
  def testDisableMsan(self):
    self.assertFalse(pywrap_sanitizers.is_msan_enabled())

  @test_util.disable_tsan("Skip test if TSAN is enabled.")
  def testDisableTsan(self):
    self.assertFalse(pywrap_sanitizers.is_tsan_enabled())

  @test_util.disable_ubsan("Skip test if UBSAN is enabled.")
  def testDisableUbsan(self):
    self.assertFalse(pywrap_sanitizers.is_ubsan_enabled())

  @test_util.run_in_graph_and_eager_modes
  def testAssertProtoEqualsStr(self):

    graph_str = "node { name: 'w1' op: 'params' }"
    graph_def = graph_pb2.GraphDef()
    text_format.Merge(graph_str, graph_def)

    # test string based comparison
    self.assertProtoEquals(graph_str, graph_def)

    # test original comparison
    self.assertProtoEquals(graph_def, graph_def)

  @test_util.run_in_graph_and_eager_modes
  def testAssertProtoEqualsAny(self):
    # Test assertProtoEquals with a protobuf.Any field.
    meta_graph_def_str = """
    meta_info_def {
      meta_graph_version: "outer"
      any_info {
        [type.googleapis.com/tensorflow.MetaGraphDef] {
          meta_info_def {
            meta_graph_version: "inner"
          }
        }
      }
    }
    """
    meta_graph_def_outer = meta_graph_pb2.MetaGraphDef()
    meta_graph_def_outer.meta_info_def.meta_graph_version = "outer"
    meta_graph_def_inner = meta_graph_pb2.MetaGraphDef()
    meta_graph_def_inner.meta_info_def.meta_graph_version = "inner"
    meta_graph_def_outer.meta_info_def.any_info.Pack(meta_graph_def_inner)
    self.assertProtoEquals(meta_graph_def_str, meta_graph_def_outer)
    self.assertProtoEquals(meta_graph_def_outer, meta_graph_def_outer)

    # Check if the assertion failure message contains the content of
    # the inner proto.
    with self.assertRaisesRegex(AssertionError, r'meta_graph_version: "inner"'):
      self.assertProtoEquals("", meta_graph_def_outer)

  @test_util.run_in_graph_and_eager_modes
  def testNDArrayNear(self):
    a1 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    a2 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    a3 = np.array([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]])
    self.assertTrue(self._NDArrayNear(a1, a2, 1e-5))
    self.assertFalse(self._NDArrayNear(a1, a3, 1e-5))

  @test_util.run_in_graph_and_eager_modes
  def testCheckedThreadSucceeds(self):

    def noop(ev):
      ev.set()

    event_arg = threading.Event()

    self.assertFalse(event_arg.is_set())
    t = self.checkedThread(target=noop, args=(event_arg,))
    t.start()
    t.join()
    self.assertTrue(event_arg.is_set())

  @test_util.run_in_graph_and_eager_modes
  def testCheckedThreadFails(self):

    def err_func():
      return 1 // 0

    t = self.checkedThread(target=err_func)
    t.start()
    with self.assertRaises(self.failureException) as fe:
      t.join()
    self.assertIn("integer division or modulo by zero", str(fe.exception))

  @test_util.run_in_graph_and_eager_modes
  def testCheckedThreadWithWrongAssertionFails(self):
    x = 37

    def err_func():
      # assertTrue is intentional: its failure text is checked below.
      self.assertTrue(x < 10)

    t = self.checkedThread(target=err_func)
    t.start()
    with self.assertRaises(self.failureException) as fe:
      t.join()
    self.assertIn("False is not true", str(fe.exception))

  @test_util.run_in_graph_and_eager_modes
  def testMultipleThreadsWithOneFailure(self):

    def err_func(i):
      self.assertTrue(i != 7)

    threads = [
        self.checkedThread(
            target=err_func, args=(i,)) for i in range(10)
    ]
    for t in threads:
      t.start()
    for i, t in enumerate(threads):
      if i == 7:
        with self.assertRaises(self.failureException):
          t.join()
      else:
        t.join()

  def _WeMustGoDeeper(self, msg):
    """Asserts that an op error raised from a nested op matches `msg`.

    Builds an op whose `original_op` is another op and raises an
    `UnauthenticatedError` for it inside `assertRaisesOpError(msg)`. The test
    that calls this helper expects `msg` to match the error text ("true_err")
    as well as the op name ("name") and the original op name ("orig").

    Args:
      msg: the message pattern passed to `assertRaisesOpError`.
    """
    with self.assertRaisesOpError(msg):
      with ops.Graph().as_default():
        node_def = ops._NodeDef("IntOutput", "name")
        node_def_orig = ops._NodeDef("IntOutput", "orig")
        op_orig = ops.Operation(node_def_orig, ops.get_default_graph())
        op = ops.Operation(node_def, ops.get_default_graph(),
                           original_op=op_orig)
        raise errors.UnauthenticatedError(node_def, op, "true_err")

  @test_util.run_in_graph_and_eager_modes
  def testAssertRaisesOpErrorDoesNotPassMessageDueToLeakedStack(self):
    with self.assertRaises(AssertionError):
      self._WeMustGoDeeper("this_is_not_the_error_you_are_looking_for")

    self._WeMustGoDeeper("true_err")
    self._WeMustGoDeeper("name")
    self._WeMustGoDeeper("orig")

  @test_util.run_in_graph_and_eager_modes
  def testAllCloseTensors(self):
    a_raw_data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    a = constant_op.constant(a_raw_data)
    b = math_ops.add(1, constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8]]))
    self.assertAllClose(a, b)
    self.assertAllClose(a, a_raw_data)

    a_dict = {"key": a}
    b_dict = {"key": b}
    self.assertAllClose(a_dict, b_dict)

    x_list = [a, b]
    y_list = [a_raw_data, b]
    self.assertAllClose(x_list, y_list)

  @test_util.run_in_graph_and_eager_modes
  def testAllCloseScalars(self):
    self.assertAllClose(7, 7 + 1e-8)
    with self.assertRaisesRegex(AssertionError, r"Not equal to tolerance"):
      self.assertAllClose(7, 7 + 1e-5)

  @test_util.run_in_graph_and_eager_modes
  def testAllCloseList(self):
    with self.assertRaisesRegex(AssertionError, r"not close dif"):
      self.assertAllClose([0], [1])

  @test_util.run_in_graph_and_eager_modes
  def testAllCloseDictToNonDict(self):
    with self.assertRaisesRegex(ValueError, r"Can't compare dict to non-dict"):
      self.assertAllClose(1, {"a": 1})
    with self.assertRaisesRegex(ValueError, r"Can't compare dict to non-dict"):
      self.assertAllClose({"a": 1}, 1)

  @test_util.run_in_graph_and_eager_modes
  def testAllCloseNamedtuples(self):
    a = 7
    b = (2., 3.)
    c = np.ones((3, 2, 4)) * 7.
    expected = {"a": a, "b": b, "c": c}
    my_named_tuple = collections.namedtuple("MyNamedTuple", ["a", "b", "c"])

    # Identity.
    self.assertAllClose(expected, my_named_tuple(a=a, b=b, c=c))
    self.assertAllClose(
        my_named_tuple(a=a, b=b, c=c), my_named_tuple(a=a, b=b, c=c))

  @test_util.run_in_graph_and_eager_modes
  def testAllCloseDicts(self):
    a = 7
    b = (2., 3.)
    c = np.ones((3, 2, 4)) * 7.
    expected = {"a": a, "b": b, "c": c}

    # Identity.
    self.assertAllClose(expected, expected)
    self.assertAllClose(expected, dict(expected))

    # With each item removed.
    for k in expected:
      actual = dict(expected)
      del actual[k]
      with self.assertRaisesRegex(AssertionError, r"mismatched keys"):
        self.assertAllClose(expected, actual)

    # With each item changed.
    with self.assertRaisesRegex(AssertionError, r"Not equal to tolerance"):
      self.assertAllClose(expected, {"a": a + 1e-5, "b": b, "c": c})
    with self.assertRaisesRegex(AssertionError, r"Shape mismatch"):
      self.assertAllClose(expected, {"a": a, "b": b + (4.,), "c": c})
    c_copy = np.array(c)
    c_copy[1, 1, 1] += 1e-5
    with self.assertRaisesRegex(AssertionError, r"Not equal to tolerance"):
      self.assertAllClose(expected, {"a": a, "b": b, "c": c_copy})

  @test_util.run_in_graph_and_eager_modes
  def testAllCloseListOfNamedtuples(self):
    my_named_tuple = collections.namedtuple("MyNamedTuple", ["x", "y"])
    l1 = [
        my_named_tuple(x=np.array([[2.3, 2.5]]), y=np.array([[0.97, 0.96]])),
        my_named_tuple(x=np.array([[3.3, 3.5]]), y=np.array([[0.98, 0.99]]))
    ]
    l2 = [
        ([[2.3, 2.5]], [[0.97, 0.96]]),
        ([[3.3, 3.5]], [[0.98, 0.99]]),
    ]
    self.assertAllClose(l1, l2)

  @test_util.run_in_graph_and_eager_modes
  def testAllCloseNestedStructure(self):
    a = {"x": np.ones((3, 2, 4)) * 7, "y": (2, [{"nested": {"m": 3, "n": 4}}])}
    self.assertAllClose(a, a)

    b = copy.deepcopy(a)
    self.assertAllClose(a, b)

    # Test mismatched values
    b["y"][1][0]["nested"]["n"] = 4.2
    with self.assertRaisesRegex(AssertionError,
                                r"\[y\]\[1\]\[0\]\[nested\]\[n\]"):
      self.assertAllClose(a, b)

  @test_util.run_in_graph_and_eager_modes
  def testArrayNear(self):
    a = [1, 2]
    b = [1, 2, 5]
    with self.assertRaises(AssertionError):
      self.assertArrayNear(a, b, 0.001)
    a = [1, 2]
    b = [[1, 2], [3, 4]]
    with self.assertRaises(TypeError):
      self.assertArrayNear(a, b, 0.001)
    a = [1, 2]
    b = [1, 2]
    self.assertArrayNear(a, b, 0.001)

  @test_util.skip_if(True)  # b/117665998
  def testForceGPU(self):
    with self.assertRaises(errors.InvalidArgumentError):
      with self.test_session(force_gpu=True):
        # this relies on us not having a GPU implementation for assert, which
        # seems sensible
        x = constant_op.constant(True)
        y = [15]
        control_flow_ops.Assert(x, y).run()

  @test_util.run_in_graph_and_eager_modes
  def testAssertAllCloseAccordingToType(self):
    # test plain int
    self.assertAllCloseAccordingToType(1, 1, rtol=1e-8, atol=1e-8)

    # test float64
    self.assertAllCloseAccordingToType(
        np.asarray([1e-8], dtype=np.float64),
        np.asarray([2e-8], dtype=np.float64),
        rtol=1e-8, atol=1e-8
    )

    self.assertAllCloseAccordingToType(
        constant_op.constant([1e-8], dtype=dtypes.float64),
        constant_op.constant([2e-8], dtype=dtypes.float64),
        rtol=1e-8,
        atol=1e-8)

    with self.assertRaises(AssertionError):
      self.assertAllCloseAccordingToType(
          np.asarray([1e-7], dtype=np.float64),
          np.asarray([2e-7], dtype=np.float64),
          rtol=1e-8, atol=1e-8
      )

    # test float32
    self.assertAllCloseAccordingToType(
        np.asarray([1e-7], dtype=np.float32),
        np.asarray([2e-7], dtype=np.float32),
        rtol=1e-8, atol=1e-8,
        float_rtol=1e-7, float_atol=1e-7
    )

    self.assertAllCloseAccordingToType(
        constant_op.constant([1e-7], dtype=dtypes.float32),
        constant_op.constant([2e-7], dtype=dtypes.float32),
        rtol=1e-8,
        atol=1e-8,
        float_rtol=1e-7,
        float_atol=1e-7)

    with self.assertRaises(AssertionError):
      self.assertAllCloseAccordingToType(
          np.asarray([1e-6], dtype=np.float32),
          np.asarray([2e-6], dtype=np.float32),
          rtol=1e-8, atol=1e-8,
          float_rtol=1e-7, float_atol=1e-7
      )

    # test float16
    self.assertAllCloseAccordingToType(
        np.asarray([1e-4], dtype=np.float16),
        np.asarray([2e-4], dtype=np.float16),
        rtol=1e-8, atol=1e-8,
        float_rtol=1e-7, float_atol=1e-7,
        half_rtol=1e-4, half_atol=1e-4
    )

    self.assertAllCloseAccordingToType(
        constant_op.constant([1e-4], dtype=dtypes.float16),
        constant_op.constant([2e-4], dtype=dtypes.float16),
        rtol=1e-8,
        atol=1e-8,
        float_rtol=1e-7,
        float_atol=1e-7,
        half_rtol=1e-4,
        half_atol=1e-4)

    with self.assertRaises(AssertionError):
      self.assertAllCloseAccordingToType(
          np.asarray([1e-3], dtype=np.float16),
          np.asarray([2e-3], dtype=np.float16),
          rtol=1e-8, atol=1e-8,
          float_rtol=1e-7, float_atol=1e-7,
          half_rtol=1e-4, half_atol=1e-4
      )

  @test_util.run_in_graph_and_eager_modes
  def testAssertAllEqual(self):
    i = variables.Variable([100] * 3, dtype=dtypes.int32, name="i")
    j = constant_op.constant([20] * 3, dtype=dtypes.int32, name="j")
    k = math_ops.add(i, j, name="k")

    self.evaluate(variables.global_variables_initializer())
    self.assertAllEqual([100] * 3, i)
    self.assertAllEqual([120] * 3, k)
    self.assertAllEqual([20] * 3, j)

    with self.assertRaisesRegex(AssertionError, r"not equal lhs"):
      self.assertAllEqual([0] * 3, k)

  @test_util.run_in_graph_and_eager_modes
  def testAssertNotAllEqual(self):
    i = variables.Variable([100], dtype=dtypes.int32, name="i")
    j = constant_op.constant([20], dtype=dtypes.int32, name="j")
    k = math_ops.add(i, j, name="k")

    self.evaluate(variables.global_variables_initializer())
    self.assertNotAllEqual([100] * 3, i)
    self.assertNotAllEqual([120] * 3, k)
    self.assertNotAllEqual([20] * 3, j)

    with self.assertRaisesRegex(
        AssertionError, r"two values are equal at all elements.*extra message"):
      self.assertNotAllEqual([120], k, msg="extra message")

  @test_util.run_in_graph_and_eager_modes
  def testAssertNotAllClose(self):
    # Test with arrays
    self.assertNotAllClose([0.1], [0.2])
    with self.assertRaises(AssertionError):
      self.assertNotAllClose([-1.0, 2.0], [-1.0, 2.0])

    # Test with tensors
    x = constant_op.constant([1.0, 1.0], name="x")
    y = math_ops.add(x, x)

    self.assertAllClose([2.0, 2.0], y)
    self.assertNotAllClose([0.9, 1.0], x)

    with self.assertRaises(AssertionError):
      self.assertNotAllClose([1.0, 1.0], x)

  @test_util.run_in_graph_and_eager_modes
  def testAssertNotAllCloseRTol(self):
    # Test with arrays
    with self.assertRaises(AssertionError):
      self.assertNotAllClose([1.1, 2.1], [1.0, 2.0], rtol=0.2)

    # Test with tensors
    x = constant_op.constant([1.0, 1.0], name="x")
    y = math_ops.add(x, x)

    self.assertAllClose([2.0, 2.0], y)

    with self.assertRaises(AssertionError):
      self.assertNotAllClose([0.9, 1.0], x, rtol=0.2)

  @test_util.run_in_graph_and_eager_modes
  def testAssertNotAllCloseATol(self):
    # Test with arrays
    with self.assertRaises(AssertionError):
      self.assertNotAllClose([1.1, 2.1], [1.0, 2.0], atol=0.2)

    # Test with tensors
    x = constant_op.constant([1.0, 1.0], name="x")
    y = math_ops.add(x, x)

    self.assertAllClose([2.0, 2.0], y)

    with self.assertRaises(AssertionError):
      self.assertNotAllClose([0.9, 1.0], x, atol=0.2)

  @test_util.run_in_graph_and_eager_modes
  def testAssertAllGreaterLess(self):
    x = constant_op.constant([100.0, 110.0, 120.0], dtype=dtypes.float32)
    y = constant_op.constant([10.0] * 3, dtype=dtypes.float32)
    z = math_ops.add(x, y)

    self.assertAllClose([110.0, 120.0, 130.0], z)

    self.assertAllGreater(x, 95.0)
    self.assertAllLess(x, 125.0)

    with self.assertRaises(AssertionError):
      self.assertAllGreater(x, 105.0)
    with self.assertRaises(AssertionError):
      self.assertAllGreater(x, 125.0)

    with self.assertRaises(AssertionError):
      self.assertAllLess(x, 115.0)
    with self.assertRaises(AssertionError):
      self.assertAllLess(x, 95.0)

  @test_util.run_in_graph_and_eager_modes
  def testAssertAllGreaterLessEqual(self):
    x = constant_op.constant([100.0, 110.0, 120.0], dtype=dtypes.float32)
    y = constant_op.constant([10.0] * 3, dtype=dtypes.float32)
    z = math_ops.add(x, y)

    self.assertAllEqual([110.0, 120.0, 130.0], z)

    self.assertAllGreaterEqual(x, 95.0)
    self.assertAllLessEqual(x, 125.0)

    with self.assertRaises(AssertionError):
      self.assertAllGreaterEqual(x, 105.0)
    with self.assertRaises(AssertionError):
      self.assertAllGreaterEqual(x, 125.0)

    with self.assertRaises(AssertionError):
      self.assertAllLessEqual(x, 115.0)
    with self.assertRaises(AssertionError):
      self.assertAllLessEqual(x, 95.0)

  def testAssertAllInRangeWithNonNumericValuesFails(self):
    s1 = constant_op.constant("Hello, ", name="s1")
    c = constant_op.constant([1 + 2j, -3 + 5j], name="c")
    b = constant_op.constant([False, True], name="b")

    with self.assertRaises(AssertionError):
      self.assertAllInRange(s1, 0.0, 1.0)
    with self.assertRaises(AssertionError):
      self.assertAllInRange(c, 0.0, 1.0)
    with self.assertRaises(AssertionError):
      self.assertAllInRange(b, 0, 1)

  @test_util.run_in_graph_and_eager_modes
  def testAssertAllInRange(self):
    x = constant_op.constant([10.0, 15.0], name="x")
    self.assertAllInRange(x, 10, 15)

    with self.assertRaises(AssertionError):
      self.assertAllInRange(x, 10, 15, open_lower_bound=True)
    with self.assertRaises(AssertionError):
      self.assertAllInRange(x, 10, 15, open_upper_bound=True)
    with self.assertRaises(AssertionError):
      self.assertAllInRange(
          x, 10, 15, open_lower_bound=True, open_upper_bound=True)

  @test_util.run_in_graph_and_eager_modes
  def testAssertAllInRangeScalar(self):
    x = constant_op.constant(10.0, name="x")
    nan = constant_op.constant(np.nan, name="nan")
    self.assertAllInRange(x, 5, 15)
    with self.assertRaises(AssertionError):
      self.assertAllInRange(nan, 5, 15)

    with self.assertRaises(AssertionError):
      self.assertAllInRange(x, 10, 15, open_lower_bound=True)
    with self.assertRaises(AssertionError):
      self.assertAllInRange(x, 1, 2)

  @test_util.run_in_graph_and_eager_modes
  def testAssertAllInRangeErrorMessageEllipses(self):
    x_init = np.array([[10.0, 15.0]] * 12)
    x = constant_op.constant(x_init, name="x")
    with self.assertRaises(AssertionError):
      self.assertAllInRange(x, 5, 10)

  @test_util.run_in_graph_and_eager_modes
  def testAssertAllInRangeDetectsNaNs(self):
    x = constant_op.constant(
        [[np.nan, 0.0], [np.nan, np.inf], [np.inf, np.nan]], name="x")
    with self.assertRaises(AssertionError):
      self.assertAllInRange(x, 0.0, 2.0)

  @test_util.run_in_graph_and_eager_modes
  def testAssertAllInRangeWithInfinities(self):
    x = constant_op.constant([10.0, np.inf], name="x")
    self.assertAllInRange(x, 10, np.inf)
    with self.assertRaises(AssertionError):
      self.assertAllInRange(x, 10, np.inf, open_upper_bound=True)

  @test_util.run_in_graph_and_eager_modes
  def testAssertAllInSet(self):
    b = constant_op.constant([True, False], name="b")
    x = constant_op.constant([13, 37], name="x")

    self.assertAllInSet(b, [False, True])
    self.assertAllInSet(b, (False, True))
    self.assertAllInSet(b, {False, True})
    self.assertAllInSet(x, [0, 13, 37, 42])
    self.assertAllInSet(x, (0, 13, 37, 42))
    self.assertAllInSet(x, {0, 13, 37, 42})

    with self.assertRaises(AssertionError):
      self.assertAllInSet(b, [False])
    with self.assertRaises(AssertionError):
      self.assertAllInSet(x, (42,))

  def testRandomSeed(self):
    # Call setUp again for WithCApi case (since it makes a new default graph
    # after setup).
    # TODO(skyewm): remove this when C API is permanently enabled.
    with context.eager_mode():
      self.setUp()
      a = random.randint(1, 1000)
      a_np_rand = np.random.rand(1)
      a_rand = random_ops.random_normal([1])
      # ensure that randomness in multiple testCases is deterministic.
      self.setUp()
      b = random.randint(1, 1000)
      b_np_rand = np.random.rand(1)
      b_rand = random_ops.random_normal([1])
      self.assertEqual(a, b)
      # Compare the numpy arrays element-wise instead of relying on the
      # truthiness of a one-element `==` result.
      self.assertAllEqual(a_np_rand, b_np_rand)
      self.assertAllEqual(a_rand, b_rand)

  @test_util.run_in_graph_and_eager_modes
  def test_callable_evaluate(self):
    def model():
      return resource_variable_ops.ResourceVariable(
          name="same_name",
          initial_value=1) + 1
    with context.eager_mode():
      self.assertEqual(2, self.evaluate(model))

  @test_util.run_in_graph_and_eager_modes
  def test_nested_tensors_evaluate(self):
    expected = {"a": 1, "b": 2, "nested": {"d": 3, "e": 4}}
    nested = {"a": constant_op.constant(1),
              "b": constant_op.constant(2),
              "nested": {"d": constant_op.constant(3),
                         "e": constant_op.constant(4)}}

    self.assertEqual(expected, self.evaluate(nested))

  def test_run_in_graph_and_eager_modes(self):
    # One (bracket style, mode) entry is recorded per execution of `inc`.
    calls = []
    def inc(self, with_brackets):
      del self  # self argument is required by run_in_graph_and_eager_modes.
      mode = "eager" if context.executing_eagerly() else "graph"
      with_brackets = "with_brackets" if with_brackets else "without_brackets"
      calls.append((with_brackets, mode))

    f = test_util.run_in_graph_and_eager_modes(inc)
    f(self, with_brackets=False)
    f = test_util.run_in_graph_and_eager_modes()(inc)  # pylint: disable=assignment-from-no-return
    f(self, with_brackets=True)

    self.assertEqual(len(calls), 4)
    self.assertEqual(set(calls), {
        ("with_brackets", "graph"),
        ("with_brackets", "eager"),
        ("without_brackets", "graph"),
        ("without_brackets", "eager"),
    })

  def test_get_node_def_from_graph(self):
    graph_def = graph_pb2.GraphDef()
    node_foo = graph_def.node.add()
    node_foo.name = "foo"
    self.assertIs(test_util.get_node_def_from_graph("foo", graph_def), node_foo)
    self.assertIsNone(test_util.get_node_def_from_graph("bar", graph_def))

  def test_run_in_eager_and_graph_modes_test_class(self):
    msg = "`run_in_graph_and_eager_modes` only supports test methods.*"
    with self.assertRaisesRegex(ValueError, msg):

      @test_util.run_in_graph_and_eager_modes()
      class Foo(object):
        pass
      del Foo  # Make pylint unused happy.

  def test_run_in_eager_and_graph_modes_skip_graph_runs_eager(self):
    modes = []
    def _test(self):
      if not context.executing_eagerly():
        self.skipTest("Skipping in graph mode")
      modes.append("eager" if context.executing_eagerly() else "graph")
    test_util.run_in_graph_and_eager_modes(_test)(self)
    self.assertEqual(modes, ["eager"])

  def test_run_in_eager_and_graph_modes_skip_eager_runs_graph(self):
    modes = []
    def _test(self):
      if context.executing_eagerly():
        self.skipTest("Skipping in eager mode")
      modes.append("eager" if context.executing_eagerly() else "graph")
    test_util.run_in_graph_and_eager_modes(_test)(self)
    self.assertEqual(modes, ["graph"])

  def test_run_in_graph_and_eager_modes_setup_in_same_mode(self):
    modes = []
    mode_name = lambda: "eager" if context.executing_eagerly() else "graph"

    class ExampleTest(test_util.TensorFlowTestCase):

      def runTest(self):
        pass

      def setUp(self):
        modes.append("setup_" + mode_name())

      @test_util.run_in_graph_and_eager_modes
      def testBody(self):
        modes.append("run_" + mode_name())

    e = ExampleTest()
    e.setUp()
    e.testBody()

    self.assertEqual(modes[1:2], ["run_graph"])
    self.assertEqual(modes[2:], ["setup_eager", "run_eager"])

  @parameterized.named_parameters(dict(testcase_name="argument",
                                       arg=True))
  @test_util.run_in_graph_and_eager_modes
  def test_run_in_graph_and_eager_works_with_parameterized_keyword(self, arg):
    self.assertEqual(arg, True)

  @combinations.generate(combinations.combine(arg=True))
  @test_util.run_in_graph_and_eager_modes
  def test_run_in_graph_and_eager_works_with_combinations(self, arg):
    self.assertEqual(arg, True)

  def test_build_as_function_and_v1_graph(self):

    class GraphModeAndFunctionTest(parameterized.TestCase):

      def __init__(inner_self):  # pylint: disable=no-self-argument
        super(GraphModeAndFunctionTest, inner_self).__init__()
        inner_self.graph_mode_tested = False
        inner_self.inside_function_tested = False

      def runTest(self):
        del self

      @test_util.build_as_function_and_v1_graph
      def test_modes(inner_self):  # pylint: disable=no-self-argument
        if ops.inside_function():
          self.assertFalse(inner_self.inside_function_tested)
          inner_self.inside_function_tested = True
        else:
          self.assertFalse(inner_self.graph_mode_tested)
          inner_self.graph_mode_tested = True

    test_object = GraphModeAndFunctionTest()
    test_object.test_modes_v1_graph()
    test_object.test_modes_function()
    self.assertTrue(test_object.graph_mode_tested)
    self.assertTrue(test_object.inside_function_tested)

  @test_util.run_in_graph_and_eager_modes
  def test_consistent_random_seed_in_assert_all_equal(self):
    random_seed.set_seed(1066)
    index = random_ops.random_shuffle([0, 1, 2, 3, 4], seed=2021)
    # This failed when `a` and `b` were evaluated in separate sessions.
    self.assertAllEqual(index, index)

  def test_with_forward_compatibility_horizons(self):

    tested_codepaths = set()
    def some_function_with_forward_compat_behavior():
      if compat.forward_compatible(2050, 1, 1):
        tested_codepaths.add("future")
      else:
        tested_codepaths.add("present")

    @test_util.with_forward_compatibility_horizons(None, [2051, 1, 1])
    def some_test(self):
      del self  # unused
      some_function_with_forward_compat_behavior()

    some_test(None)
    self.assertEqual(tested_codepaths, {"present", "future"})
866
867
class SkipTestTest(test_util.TensorFlowTestCase):
  """Tests for `test_util.skip_if_error`.

  A subset of the checks also runs from `setUp` and `tearDown`, so the
  context manager is exercised outside of a test method body as well.
  """

  def _assert_skips(self, error, *skip_args):
    """Asserts that raising `error` inside `skip_if_error` skips the test.

    Args:
      error: The `ValueError` instance raised inside the context manager.
      *skip_args: Positional arguments forwarded to `skip_if_error` after the
        error type (a message string or a list of messages), if any.
    """
    with self.assertRaises(unittest.SkipTest):
      with test_util.skip_if_error(self, ValueError, *skip_args):
        raise error

  def _assert_raises_without_skip(self, error, pattern, *skip_args):
    """Asserts that `error` propagates out of `skip_if_error` unchanged.

    Args:
      error: The `ValueError` instance raised inside the context manager.
      pattern: Regex the propagated `ValueError` message must match.
      *skip_args: Positional arguments forwarded to `skip_if_error` after the
        error type.

    Raises:
      RuntimeError: If `skip_if_error` skipped the test instead.
    """
    try:
      with self.assertRaisesRegex(ValueError, pattern):
        with test_util.skip_if_error(self, ValueError, *skip_args):
          raise error
    except unittest.SkipTest:
      raise RuntimeError("Test is not supposed to skip.")

  def _verify_test_in_set_up_or_tear_down(self):
    self._assert_skips(ValueError("test message"),
                       ["foo bar", "test message"])
    self._assert_raises_without_skip(ValueError("foo bar"), "foo bar",
                                     "test message")

  def setUp(self):
    super(SkipTestTest, self).setUp()
    self._verify_test_in_set_up_or_tear_down()

  def tearDown(self):
    super(SkipTestTest, self).tearDown()
    self._verify_test_in_set_up_or_tear_down()

  def test_skip_if_error_should_skip(self):
    self._assert_skips(ValueError("test message"), "test message")

  def test_skip_if_error_should_skip_with_list(self):
    self._assert_skips(ValueError("test message"),
                       ["foo bar", "test message"])

  def test_skip_if_error_should_skip_without_expected_message(self):
    self._assert_skips(ValueError("test message"))

  def test_skip_if_error_should_skip_without_error_message(self):
    self._assert_skips(ValueError())

  def test_skip_if_error_should_raise_message_mismatch(self):
    self._assert_raises_without_skip(ValueError("foo bar"), "foo bar",
                                     "test message")

  def test_skip_if_error_should_raise_no_message(self):
    self._assert_raises_without_skip(ValueError(), "", "test message")
926
927
# This is its own test case so that it can reproduce variable-sharing issues
# that only show up when setUp() is overridden without calling super().setUp().
class GraphAndEagerNoVariableSharing(test_util.TensorFlowTestCase):
  """Creates the variable "step_size" in both graph and eager passes.

  `setUp` is overridden without delegating to `TensorFlowTestCase.setUp`.
  The test presumably fails if the two passes of
  `run_in_graph_and_eager_modes` end up sharing variable state.
  """

  def setUp(self):
    # Intentionally does not call TensorFlowTestCase's super().
    pass

  @test_util.run_in_graph_and_eager_modes
  def test_no_variable_sharing(self):
    initial_value = np.array(1e-5, np.float32)
    variable_scope.get_variable(
        name="step_size",
        initializer=initial_value,
        use_resource=True,
        trainable=False)
942
943
class GarbageCollectionTest(test_util.TensorFlowTestCase):
  """Tests for the garbage/leak-detecting decorators in `test_util`."""

  def test_no_reference_cycle_decorator(self):
    """`assert_no_garbage_created` rejects a self-referencing list."""

    class ReferenceCycleTest(object):

      def __init__(inner_self):  # pylint: disable=no-self-argument
        # Borrow the outer test's assertEqual; presumably the decorator
        # reports through it.
        inner_self.assertEqual = self.assertEqual  # pylint: disable=invalid-name

      @test_util.assert_no_garbage_created
      def test_has_cycle(self):
        cycle = []
        cycle.append(cycle)  # A list that contains itself.

      @test_util.assert_no_garbage_created
      def test_has_no_cycle(self):
        pass

    with self.assertRaises(AssertionError):
      ReferenceCycleTest().test_has_cycle()

    ReferenceCycleTest().test_has_no_cycle()

  @test_util.run_in_graph_and_eager_modes
  def test_no_leaked_tensor_decorator(self):
    """`assert_no_new_tensors` rejects a tensor kept on the instance."""

    class LeakedTensorTest(object):

      def __init__(inner_self):  # pylint: disable=no-self-argument
        # Borrow the outer test's assertEqual; presumably the decorator
        # reports through it.
        inner_self.assertEqual = self.assertEqual  # pylint: disable=invalid-name

      @test_util.assert_no_new_tensors
      def test_has_leak(self):
        # Stored on the instance, so the tensor outlives this call.
        self.a = constant_op.constant([3.], name="leak")

      @test_util.assert_no_new_tensors
      def test_has_no_leak(self):
        constant_op.constant([3.], name="no-leak")

    with self.assertRaisesRegex(AssertionError, "Tensors not deallocated"):
      LeakedTensorTest().test_has_leak()

    LeakedTensorTest().test_has_no_leak()

  def test_no_new_objects_decorator(self):
    """`assert_no_new_pyobjects_executing_eagerly` detects accumulation."""

    class LeakedObjectTest(unittest.TestCase):

      def __init__(self, *args, **kwargs):
        super(LeakedObjectTest, self).__init__(*args, **kwargs)
        self.accumulation = []

      @unittest.expectedFailure
      @test_util.assert_no_new_pyobjects_executing_eagerly
      def test_has_leak(self):
        # Grows a list that lives on the instance across calls.
        self.accumulation.append([1.])

      @test_util.assert_no_new_pyobjects_executing_eagerly
      def test_has_no_leak(self):
        # Rebinding replaces the previous list instead of growing one.
        self.not_accumulating = [1.]

    # test_has_leak is an expected failure, so both runs report success.
    for method_name in ("test_has_leak", "test_has_no_leak"):
      outcome = LeakedObjectTest(method_name).run()
      self.assertTrue(outcome.wasSuccessful())
1007
1008
class RunFunctionsEagerlyInV2Test(test_util.TensorFlowTestCase,
                                  parameterized.TestCase):
  """Tests for the `test_util.run_functions_eagerly` context manager."""

  @parameterized.named_parameters(
      [("_RunEagerly", True), ("_RunGraph", False)])
  def test_run_functions_eagerly(self, run_eagerly):
    """Checks which tensor types a `def_function.function` body produces.

    Args:
      run_eagerly: Flag passed to `test_util.run_functions_eagerly`.
    """
    results = []

    @def_function.function
    def add_two(x):
      for _ in range(5):
        x += 2
        results.append(x)
      return x

    with test_util.run_functions_eagerly(run_eagerly):
      add_two(constant_op.constant(2.))
      # all() over an empty list is True, so make sure the body actually ran.
      self.assertNotEmpty(results)
      # Each check wraps its generator in all()/any(). A bare generator
      # passed to assertTrue is always truthy, so it would check nothing.
      if context.executing_eagerly() and run_eagerly:
        self.assertTrue(all(isinstance(t, ops.EagerTensor) for t in results))
      else:
        self.assertTrue(all(isinstance(t, ops.Tensor) for t in results))
        self.assertFalse(any(isinstance(t, ops.EagerTensor) for t in results))
1032
1033
if __name__ == "__main__":
  # Entry point when this file is executed directly as a script.
  googletest.main()
1036