1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for tensorflow.ops.test_util.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import copy 23import random 24import threading 25import unittest 26import weakref 27 28from absl.testing import parameterized 29import numpy as np 30 31from google.protobuf import text_format 32 33from tensorflow.core.framework import graph_pb2 34from tensorflow.core.protobuf import meta_graph_pb2 35from tensorflow.python import pywrap_sanitizers 36from tensorflow.python.compat import compat 37from tensorflow.python.eager import context 38from tensorflow.python.eager import def_function 39from tensorflow.python.framework import combinations 40from tensorflow.python.framework import constant_op 41from tensorflow.python.framework import dtypes 42from tensorflow.python.framework import errors 43from tensorflow.python.framework import ops 44from tensorflow.python.framework import random_seed 45from tensorflow.python.framework import test_ops # pylint: disable=unused-import 46from tensorflow.python.framework import test_util 47from tensorflow.python.ops import control_flow_ops 48from tensorflow.python.ops import lookup_ops 49from tensorflow.python.ops import math_ops 50from tensorflow.python.ops import random_ops 51from tensorflow.python.ops import resource_variable_ops 52from tensorflow.python.ops import variable_scope 53from tensorflow.python.ops import variables 54from tensorflow.python.platform import googletest 55 56 57class TestUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase): 58 59 def test_assert_ops_in_graph(self): 60 with ops.Graph().as_default(): 61 constant_op.constant(["hello", "taffy"], name="hello") 62 test_util.assert_ops_in_graph({"hello": "Const"}, ops.get_default_graph()) 63 64 self.assertRaises(ValueError, test_util.assert_ops_in_graph, 65 {"bye": "Const"}, ops.get_default_graph()) 66 67 self.assertRaises(ValueError, test_util.assert_ops_in_graph, 68 {"hello": "Variable"}, ops.get_default_graph()) 69 70 @test_util.run_deprecated_v1 71 def test_session_functions(self): 72 with self.test_session() as sess: 73 sess_ref = weakref.ref(sess) 74 with self.cached_session(graph=None, config=None) as sess2: 75 # We make sure that sess2 is sess. 76 assert sess2 is sess 77 # We make sure we raise an exception if we use cached_session with 78 # different values. 79 with self.assertRaises(ValueError): 80 with self.cached_session(graph=ops.Graph()) as sess2: 81 pass 82 with self.assertRaises(ValueError): 83 with self.cached_session(force_gpu=True) as sess2: 84 pass 85 # We make sure that test_session will cache the session even after the 86 # with scope. 87 assert not sess_ref()._closed 88 with self.session() as unique_sess: 89 unique_sess_ref = weakref.ref(unique_sess) 90 with self.session() as sess2: 91 assert sess2 is not unique_sess 92 # We make sure the session is closed when we leave the with statement. 93 assert unique_sess_ref()._closed 94 95 def test_assert_equal_graph_def(self): 96 with ops.Graph().as_default() as g: 97 def_empty = g.as_graph_def() 98 constant_op.constant(5, name="five") 99 constant_op.constant(7, name="seven") 100 def_57 = g.as_graph_def() 101 with ops.Graph().as_default() as g: 102 constant_op.constant(7, name="seven") 103 constant_op.constant(5, name="five") 104 def_75 = g.as_graph_def() 105 # Comparing strings is order dependent 106 self.assertNotEqual(str(def_57), str(def_75)) 107 # assert_equal_graph_def doesn't care about order 108 test_util.assert_equal_graph_def(def_57, def_75) 109 # Compare two unequal graphs 110 with self.assertRaisesRegex(AssertionError, 111 r"^Found unexpected node '{{node seven}}"): 112 test_util.assert_equal_graph_def(def_57, def_empty) 113 114 def test_assert_equal_graph_def_hash_table(self): 115 def get_graph_def(): 116 with ops.Graph().as_default() as g: 117 x = constant_op.constant([2, 9], name="x") 118 keys = constant_op.constant([1, 2], name="keys") 119 values = constant_op.constant([3, 4], name="values") 120 default = constant_op.constant(-1, name="default") 121 table = lookup_ops.StaticHashTable( 122 lookup_ops.KeyValueTensorInitializer(keys, values), default) 123 _ = table.lookup(x) 124 return g.as_graph_def() 125 def_1 = get_graph_def() 126 def_2 = get_graph_def() 127 # The unique shared_name of each table makes the graph unequal. 128 with self.assertRaisesRegex(AssertionError, "hash_table_"): 129 test_util.assert_equal_graph_def(def_1, def_2, 130 hash_table_shared_name=False) 131 # That can be ignored. (NOTE: modifies GraphDefs in-place.) 132 test_util.assert_equal_graph_def(def_1, def_2, 133 hash_table_shared_name=True) 134 135 def testIsGoogleCudaEnabled(self): 136 # The test doesn't assert anything. It ensures the py wrapper 137 # function is generated correctly. 138 if test_util.IsGoogleCudaEnabled(): 139 print("GoogleCuda is enabled") 140 else: 141 print("GoogleCuda is disabled") 142 143 def testIsMklEnabled(self): 144 # This test doesn't assert anything. 145 # It ensures the py wrapper function is generated correctly. 146 if test_util.IsMklEnabled(): 147 print("MKL is enabled") 148 else: 149 print("MKL is disabled") 150 151 @test_util.disable_asan("Skip test if ASAN is enabled.") 152 def testDisableAsan(self): 153 self.assertFalse(pywrap_sanitizers.is_asan_enabled()) 154 155 @test_util.disable_msan("Skip test if MSAN is enabled.") 156 def testDisableMsan(self): 157 self.assertFalse(pywrap_sanitizers.is_msan_enabled()) 158 159 @test_util.disable_tsan("Skip test if TSAN is enabled.") 160 def testDisableTsan(self): 161 self.assertFalse(pywrap_sanitizers.is_tsan_enabled()) 162 163 @test_util.disable_ubsan("Skip test if UBSAN is enabled.") 164 def testDisableUbsan(self): 165 self.assertFalse(pywrap_sanitizers.is_ubsan_enabled()) 166 167 @test_util.run_in_graph_and_eager_modes 168 def testAssertProtoEqualsStr(self): 169 170 graph_str = "node { name: 'w1' op: 'params' }" 171 graph_def = graph_pb2.GraphDef() 172 text_format.Merge(graph_str, graph_def) 173 174 # test string based comparison 175 self.assertProtoEquals(graph_str, graph_def) 176 177 # test original comparison 178 self.assertProtoEquals(graph_def, graph_def) 179 180 @test_util.run_in_graph_and_eager_modes 181 def testAssertProtoEqualsAny(self): 182 # Test assertProtoEquals with a protobuf.Any field. 183 meta_graph_def_str = """ 184 meta_info_def { 185 meta_graph_version: "outer" 186 any_info { 187 [type.googleapis.com/tensorflow.MetaGraphDef] { 188 meta_info_def { 189 meta_graph_version: "inner" 190 } 191 } 192 } 193 } 194 """ 195 meta_graph_def_outer = meta_graph_pb2.MetaGraphDef() 196 meta_graph_def_outer.meta_info_def.meta_graph_version = "outer" 197 meta_graph_def_inner = meta_graph_pb2.MetaGraphDef() 198 meta_graph_def_inner.meta_info_def.meta_graph_version = "inner" 199 meta_graph_def_outer.meta_info_def.any_info.Pack(meta_graph_def_inner) 200 self.assertProtoEquals(meta_graph_def_str, meta_graph_def_outer) 201 self.assertProtoEquals(meta_graph_def_outer, meta_graph_def_outer) 202 203 # Check if the assertion failure message contains the content of 204 # the inner proto. 205 with self.assertRaisesRegex(AssertionError, r'meta_graph_version: "inner"'): 206 self.assertProtoEquals("", meta_graph_def_outer) 207 208 @test_util.run_in_graph_and_eager_modes 209 def testNDArrayNear(self): 210 a1 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 211 a2 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 212 a3 = np.array([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]]) 213 self.assertTrue(self._NDArrayNear(a1, a2, 1e-5)) 214 self.assertFalse(self._NDArrayNear(a1, a3, 1e-5)) 215 216 @test_util.run_in_graph_and_eager_modes 217 def testCheckedThreadSucceeds(self): 218 219 def noop(ev): 220 ev.set() 221 222 event_arg = threading.Event() 223 224 self.assertFalse(event_arg.is_set()) 225 t = self.checkedThread(target=noop, args=(event_arg,)) 226 t.start() 227 t.join() 228 self.assertTrue(event_arg.is_set()) 229 230 @test_util.run_in_graph_and_eager_modes 231 def testCheckedThreadFails(self): 232 233 def err_func(): 234 return 1 // 0 235 236 t = self.checkedThread(target=err_func) 237 t.start() 238 with self.assertRaises(self.failureException) as fe: 239 t.join() 240 self.assertTrue("integer division or modulo by zero" in str(fe.exception)) 241 242 @test_util.run_in_graph_and_eager_modes 243 def testCheckedThreadWithWrongAssertionFails(self): 244 x = 37 245 246 def err_func(): 247 self.assertTrue(x < 10) 248 249 t = self.checkedThread(target=err_func) 250 t.start() 251 with self.assertRaises(self.failureException) as fe: 252 t.join() 253 self.assertTrue("False is not true" in str(fe.exception)) 254 255 @test_util.run_in_graph_and_eager_modes 256 def testMultipleThreadsWithOneFailure(self): 257 258 def err_func(i): 259 self.assertTrue(i != 7) 260 261 threads = [ 262 self.checkedThread( 263 target=err_func, args=(i,)) for i in range(10) 264 ] 265 for t in threads: 266 t.start() 267 for i, t in enumerate(threads): 268 if i == 7: 269 with self.assertRaises(self.failureException): 270 t.join() 271 else: 272 t.join() 273 274 def _WeMustGoDeeper(self, msg): 275 with self.assertRaisesOpError(msg): 276 with ops.Graph().as_default(): 277 node_def = ops._NodeDef("IntOutput", "name") 278 node_def_orig = ops._NodeDef("IntOutput", "orig") 279 op_orig = ops.Operation(node_def_orig, ops.get_default_graph()) 280 op = ops.Operation(node_def, ops.get_default_graph(), 281 original_op=op_orig) 282 raise errors.UnauthenticatedError(node_def, op, "true_err") 283 284 @test_util.run_in_graph_and_eager_modes 285 def testAssertRaisesOpErrorDoesNotPassMessageDueToLeakedStack(self): 286 with self.assertRaises(AssertionError): 287 self._WeMustGoDeeper("this_is_not_the_error_you_are_looking_for") 288 289 self._WeMustGoDeeper("true_err") 290 self._WeMustGoDeeper("name") 291 self._WeMustGoDeeper("orig") 292 293 @test_util.run_in_graph_and_eager_modes 294 def testAllCloseTensors(self): 295 a_raw_data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] 296 a = constant_op.constant(a_raw_data) 297 b = math_ops.add(1, constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8]])) 298 self.assertAllClose(a, b) 299 self.assertAllClose(a, a_raw_data) 300 301 a_dict = {"key": a} 302 b_dict = {"key": b} 303 self.assertAllClose(a_dict, b_dict) 304 305 x_list = [a, b] 306 y_list = [a_raw_data, b] 307 self.assertAllClose(x_list, y_list) 308 309 @test_util.run_in_graph_and_eager_modes 310 def testAllCloseScalars(self): 311 self.assertAllClose(7, 7 + 1e-8) 312 with self.assertRaisesRegex(AssertionError, r"Not equal to tolerance"): 313 self.assertAllClose(7, 7 + 1e-5) 314 315 @test_util.run_in_graph_and_eager_modes 316 def testAllCloseList(self): 317 with self.assertRaisesRegex(AssertionError, r"not close dif"): 318 self.assertAllClose([0], [1]) 319 320 @test_util.run_in_graph_and_eager_modes 321 def testAllCloseDictToNonDict(self): 322 with self.assertRaisesRegex(ValueError, r"Can't compare dict to non-dict"): 323 self.assertAllClose(1, {"a": 1}) 324 with self.assertRaisesRegex(ValueError, r"Can't compare dict to non-dict"): 325 self.assertAllClose({"a": 1}, 1) 326 327 @test_util.run_in_graph_and_eager_modes 328 def testAllCloseNamedtuples(self): 329 a = 7 330 b = (2., 3.) 331 c = np.ones((3, 2, 4)) * 7. 332 expected = {"a": a, "b": b, "c": c} 333 my_named_tuple = collections.namedtuple("MyNamedTuple", ["a", "b", "c"]) 334 335 # Identity. 336 self.assertAllClose(expected, my_named_tuple(a=a, b=b, c=c)) 337 self.assertAllClose( 338 my_named_tuple(a=a, b=b, c=c), my_named_tuple(a=a, b=b, c=c)) 339 340 @test_util.run_in_graph_and_eager_modes 341 def testAllCloseDicts(self): 342 a = 7 343 b = (2., 3.) 344 c = np.ones((3, 2, 4)) * 7. 345 expected = {"a": a, "b": b, "c": c} 346 347 # Identity. 348 self.assertAllClose(expected, expected) 349 self.assertAllClose(expected, dict(expected)) 350 351 # With each item removed. 352 for k in expected: 353 actual = dict(expected) 354 del actual[k] 355 with self.assertRaisesRegex(AssertionError, r"mismatched keys"): 356 self.assertAllClose(expected, actual) 357 358 # With each item changed. 359 with self.assertRaisesRegex(AssertionError, r"Not equal to tolerance"): 360 self.assertAllClose(expected, {"a": a + 1e-5, "b": b, "c": c}) 361 with self.assertRaisesRegex(AssertionError, r"Shape mismatch"): 362 self.assertAllClose(expected, {"a": a, "b": b + (4.,), "c": c}) 363 c_copy = np.array(c) 364 c_copy[1, 1, 1] += 1e-5 365 with self.assertRaisesRegex(AssertionError, r"Not equal to tolerance"): 366 self.assertAllClose(expected, {"a": a, "b": b, "c": c_copy}) 367 368 @test_util.run_in_graph_and_eager_modes 369 def testAllCloseListOfNamedtuples(self): 370 my_named_tuple = collections.namedtuple("MyNamedTuple", ["x", "y"]) 371 l1 = [ 372 my_named_tuple(x=np.array([[2.3, 2.5]]), y=np.array([[0.97, 0.96]])), 373 my_named_tuple(x=np.array([[3.3, 3.5]]), y=np.array([[0.98, 0.99]])) 374 ] 375 l2 = [ 376 ([[2.3, 2.5]], [[0.97, 0.96]]), 377 ([[3.3, 3.5]], [[0.98, 0.99]]), 378 ] 379 self.assertAllClose(l1, l2) 380 381 @test_util.run_in_graph_and_eager_modes 382 def testAllCloseNestedStructure(self): 383 a = {"x": np.ones((3, 2, 4)) * 7, "y": (2, [{"nested": {"m": 3, "n": 4}}])} 384 self.assertAllClose(a, a) 385 386 b = copy.deepcopy(a) 387 self.assertAllClose(a, b) 388 389 # Test mismatched values 390 b["y"][1][0]["nested"]["n"] = 4.2 391 with self.assertRaisesRegex(AssertionError, 392 r"\[y\]\[1\]\[0\]\[nested\]\[n\]"): 393 self.assertAllClose(a, b) 394 395 @test_util.run_in_graph_and_eager_modes 396 def testArrayNear(self): 397 a = [1, 2] 398 b = [1, 2, 5] 399 with self.assertRaises(AssertionError): 400 self.assertArrayNear(a, b, 0.001) 401 a = [1, 2] 402 b = [[1, 2], [3, 4]] 403 with self.assertRaises(TypeError): 404 self.assertArrayNear(a, b, 0.001) 405 a = [1, 2] 406 b = [1, 2] 407 self.assertArrayNear(a, b, 0.001) 408 409 @test_util.skip_if(True) # b/117665998 410 def testForceGPU(self): 411 with self.assertRaises(errors.InvalidArgumentError): 412 with self.test_session(force_gpu=True): 413 # this relies on us not having a GPU implementation for assert, which 414 # seems sensible 415 x = constant_op.constant(True) 416 y = [15] 417 control_flow_ops.Assert(x, y).run() 418 419 @test_util.run_in_graph_and_eager_modes 420 def testAssertAllCloseAccordingToType(self): 421 # test plain int 422 self.assertAllCloseAccordingToType(1, 1, rtol=1e-8, atol=1e-8) 423 424 # test float64 425 self.assertAllCloseAccordingToType( 426 np.asarray([1e-8], dtype=np.float64), 427 np.asarray([2e-8], dtype=np.float64), 428 rtol=1e-8, atol=1e-8 429 ) 430 431 self.assertAllCloseAccordingToType( 432 constant_op.constant([1e-8], dtype=dtypes.float64), 433 constant_op.constant([2e-8], dtype=dtypes.float64), 434 rtol=1e-8, 435 atol=1e-8) 436 437 with (self.assertRaises(AssertionError)): 438 self.assertAllCloseAccordingToType( 439 np.asarray([1e-7], dtype=np.float64), 440 np.asarray([2e-7], dtype=np.float64), 441 rtol=1e-8, atol=1e-8 442 ) 443 444 # test float32 445 self.assertAllCloseAccordingToType( 446 np.asarray([1e-7], dtype=np.float32), 447 np.asarray([2e-7], dtype=np.float32), 448 rtol=1e-8, atol=1e-8, 449 float_rtol=1e-7, float_atol=1e-7 450 ) 451 452 self.assertAllCloseAccordingToType( 453 constant_op.constant([1e-7], dtype=dtypes.float32), 454 constant_op.constant([2e-7], dtype=dtypes.float32), 455 rtol=1e-8, 456 atol=1e-8, 457 float_rtol=1e-7, 458 float_atol=1e-7) 459 460 with (self.assertRaises(AssertionError)): 461 self.assertAllCloseAccordingToType( 462 np.asarray([1e-6], dtype=np.float32), 463 np.asarray([2e-6], dtype=np.float32), 464 rtol=1e-8, atol=1e-8, 465 float_rtol=1e-7, float_atol=1e-7 466 ) 467 468 # test float16 469 self.assertAllCloseAccordingToType( 470 np.asarray([1e-4], dtype=np.float16), 471 np.asarray([2e-4], dtype=np.float16), 472 rtol=1e-8, atol=1e-8, 473 float_rtol=1e-7, float_atol=1e-7, 474 half_rtol=1e-4, half_atol=1e-4 475 ) 476 477 self.assertAllCloseAccordingToType( 478 constant_op.constant([1e-4], dtype=dtypes.float16), 479 constant_op.constant([2e-4], dtype=dtypes.float16), 480 rtol=1e-8, 481 atol=1e-8, 482 float_rtol=1e-7, 483 float_atol=1e-7, 484 half_rtol=1e-4, 485 half_atol=1e-4) 486 487 with (self.assertRaises(AssertionError)): 488 self.assertAllCloseAccordingToType( 489 np.asarray([1e-3], dtype=np.float16), 490 np.asarray([2e-3], dtype=np.float16), 491 rtol=1e-8, atol=1e-8, 492 float_rtol=1e-7, float_atol=1e-7, 493 half_rtol=1e-4, half_atol=1e-4 494 ) 495 496 @test_util.run_in_graph_and_eager_modes 497 def testAssertAllEqual(self): 498 i = variables.Variable([100] * 3, dtype=dtypes.int32, name="i") 499 j = constant_op.constant([20] * 3, dtype=dtypes.int32, name="j") 500 k = math_ops.add(i, j, name="k") 501 502 self.evaluate(variables.global_variables_initializer()) 503 self.assertAllEqual([100] * 3, i) 504 self.assertAllEqual([120] * 3, k) 505 self.assertAllEqual([20] * 3, j) 506 507 with self.assertRaisesRegex(AssertionError, r"not equal lhs"): 508 self.assertAllEqual([0] * 3, k) 509 510 @test_util.run_in_graph_and_eager_modes 511 def testAssertNotAllEqual(self): 512 i = variables.Variable([100], dtype=dtypes.int32, name="i") 513 j = constant_op.constant([20], dtype=dtypes.int32, name="j") 514 k = math_ops.add(i, j, name="k") 515 516 self.evaluate(variables.global_variables_initializer()) 517 self.assertNotAllEqual([100] * 3, i) 518 self.assertNotAllEqual([120] * 3, k) 519 self.assertNotAllEqual([20] * 3, j) 520 521 with self.assertRaisesRegex( 522 AssertionError, r"two values are equal at all elements.*extra message"): 523 self.assertNotAllEqual([120], k, msg="extra message") 524 525 @test_util.run_in_graph_and_eager_modes 526 def testAssertNotAllClose(self): 527 # Test with arrays 528 self.assertNotAllClose([0.1], [0.2]) 529 with self.assertRaises(AssertionError): 530 self.assertNotAllClose([-1.0, 2.0], [-1.0, 2.0]) 531 532 # Test with tensors 533 x = constant_op.constant([1.0, 1.0], name="x") 534 y = math_ops.add(x, x) 535 536 self.assertAllClose([2.0, 2.0], y) 537 self.assertNotAllClose([0.9, 1.0], x) 538 539 with self.assertRaises(AssertionError): 540 self.assertNotAllClose([1.0, 1.0], x) 541 542 @test_util.run_in_graph_and_eager_modes 543 def testAssertNotAllCloseRTol(self): 544 # Test with arrays 545 with self.assertRaises(AssertionError): 546 self.assertNotAllClose([1.1, 2.1], [1.0, 2.0], rtol=0.2) 547 548 # Test with tensors 549 x = constant_op.constant([1.0, 1.0], name="x") 550 y = math_ops.add(x, x) 551 552 self.assertAllClose([2.0, 2.0], y) 553 554 with self.assertRaises(AssertionError): 555 self.assertNotAllClose([0.9, 1.0], x, rtol=0.2) 556 557 @test_util.run_in_graph_and_eager_modes 558 def testAssertNotAllCloseATol(self): 559 # Test with arrays 560 with self.assertRaises(AssertionError): 561 self.assertNotAllClose([1.1, 2.1], [1.0, 2.0], atol=0.2) 562 563 # Test with tensors 564 x = constant_op.constant([1.0, 1.0], name="x") 565 y = math_ops.add(x, x) 566 567 self.assertAllClose([2.0, 2.0], y) 568 569 with self.assertRaises(AssertionError): 570 self.assertNotAllClose([0.9, 1.0], x, atol=0.2) 571 572 @test_util.run_in_graph_and_eager_modes 573 def testAssertAllGreaterLess(self): 574 x = constant_op.constant([100.0, 110.0, 120.0], dtype=dtypes.float32) 575 y = constant_op.constant([10.0] * 3, dtype=dtypes.float32) 576 z = math_ops.add(x, y) 577 578 self.assertAllClose([110.0, 120.0, 130.0], z) 579 580 self.assertAllGreater(x, 95.0) 581 self.assertAllLess(x, 125.0) 582 583 with self.assertRaises(AssertionError): 584 self.assertAllGreater(x, 105.0) 585 with self.assertRaises(AssertionError): 586 self.assertAllGreater(x, 125.0) 587 588 with self.assertRaises(AssertionError): 589 self.assertAllLess(x, 115.0) 590 with self.assertRaises(AssertionError): 591 self.assertAllLess(x, 95.0) 592 593 @test_util.run_in_graph_and_eager_modes 594 def testAssertAllGreaterLessEqual(self): 595 x = constant_op.constant([100.0, 110.0, 120.0], dtype=dtypes.float32) 596 y = constant_op.constant([10.0] * 3, dtype=dtypes.float32) 597 z = math_ops.add(x, y) 598 599 self.assertAllEqual([110.0, 120.0, 130.0], z) 600 601 self.assertAllGreaterEqual(x, 95.0) 602 self.assertAllLessEqual(x, 125.0) 603 604 with self.assertRaises(AssertionError): 605 self.assertAllGreaterEqual(x, 105.0) 606 with self.assertRaises(AssertionError): 607 self.assertAllGreaterEqual(x, 125.0) 608 609 with self.assertRaises(AssertionError): 610 self.assertAllLessEqual(x, 115.0) 611 with self.assertRaises(AssertionError): 612 self.assertAllLessEqual(x, 95.0) 613 614 def testAssertAllInRangeWithNonNumericValuesFails(self): 615 s1 = constant_op.constant("Hello, ", name="s1") 616 c = constant_op.constant([1 + 2j, -3 + 5j], name="c") 617 b = constant_op.constant([False, True], name="b") 618 619 with self.assertRaises(AssertionError): 620 self.assertAllInRange(s1, 0.0, 1.0) 621 with self.assertRaises(AssertionError): 622 self.assertAllInRange(c, 0.0, 1.0) 623 with self.assertRaises(AssertionError): 624 self.assertAllInRange(b, 0, 1) 625 626 @test_util.run_in_graph_and_eager_modes 627 def testAssertAllInRange(self): 628 x = constant_op.constant([10.0, 15.0], name="x") 629 self.assertAllInRange(x, 10, 15) 630 631 with self.assertRaises(AssertionError): 632 self.assertAllInRange(x, 10, 15, open_lower_bound=True) 633 with self.assertRaises(AssertionError): 634 self.assertAllInRange(x, 10, 15, open_upper_bound=True) 635 with self.assertRaises(AssertionError): 636 self.assertAllInRange( 637 x, 10, 15, open_lower_bound=True, open_upper_bound=True) 638 639 @test_util.run_in_graph_and_eager_modes 640 def testAssertAllInRangeScalar(self): 641 x = constant_op.constant(10.0, name="x") 642 nan = constant_op.constant(np.nan, name="nan") 643 self.assertAllInRange(x, 5, 15) 644 with self.assertRaises(AssertionError): 645 self.assertAllInRange(nan, 5, 15) 646 647 with self.assertRaises(AssertionError): 648 self.assertAllInRange(x, 10, 15, open_lower_bound=True) 649 with self.assertRaises(AssertionError): 650 self.assertAllInRange(x, 1, 2) 651 652 @test_util.run_in_graph_and_eager_modes 653 def testAssertAllInRangeErrorMessageEllipses(self): 654 x_init = np.array([[10.0, 15.0]] * 12) 655 x = constant_op.constant(x_init, name="x") 656 with self.assertRaises(AssertionError): 657 self.assertAllInRange(x, 5, 10) 658 659 @test_util.run_in_graph_and_eager_modes 660 def testAssertAllInRangeDetectsNaNs(self): 661 x = constant_op.constant( 662 [[np.nan, 0.0], [np.nan, np.inf], [np.inf, np.nan]], name="x") 663 with self.assertRaises(AssertionError): 664 self.assertAllInRange(x, 0.0, 2.0) 665 666 @test_util.run_in_graph_and_eager_modes 667 def testAssertAllInRangeWithInfinities(self): 668 x = constant_op.constant([10.0, np.inf], name="x") 669 self.assertAllInRange(x, 10, np.inf) 670 with self.assertRaises(AssertionError): 671 self.assertAllInRange(x, 10, np.inf, open_upper_bound=True) 672 673 @test_util.run_in_graph_and_eager_modes 674 def testAssertAllInSet(self): 675 b = constant_op.constant([True, False], name="b") 676 x = constant_op.constant([13, 37], name="x") 677 678 self.assertAllInSet(b, [False, True]) 679 self.assertAllInSet(b, (False, True)) 680 self.assertAllInSet(b, {False, True}) 681 self.assertAllInSet(x, [0, 13, 37, 42]) 682 self.assertAllInSet(x, (0, 13, 37, 42)) 683 self.assertAllInSet(x, {0, 13, 37, 42}) 684 685 with self.assertRaises(AssertionError): 686 self.assertAllInSet(b, [False]) 687 with self.assertRaises(AssertionError): 688 self.assertAllInSet(x, (42,)) 689 690 def testRandomSeed(self): 691 # Call setUp again for WithCApi case (since it makes a new default graph 692 # after setup). 693 # TODO(skyewm): remove this when C API is permanently enabled. 694 with context.eager_mode(): 695 self.setUp() 696 a = random.randint(1, 1000) 697 a_np_rand = np.random.rand(1) 698 a_rand = random_ops.random_normal([1]) 699 # ensure that randomness in multiple testCases is deterministic. 700 self.setUp() 701 b = random.randint(1, 1000) 702 b_np_rand = np.random.rand(1) 703 b_rand = random_ops.random_normal([1]) 704 self.assertEqual(a, b) 705 self.assertEqual(a_np_rand, b_np_rand) 706 self.assertAllEqual(a_rand, b_rand) 707 708 @test_util.run_in_graph_and_eager_modes 709 def test_callable_evaluate(self): 710 def model(): 711 return resource_variable_ops.ResourceVariable( 712 name="same_name", 713 initial_value=1) + 1 714 with context.eager_mode(): 715 self.assertEqual(2, self.evaluate(model)) 716 717 @test_util.run_in_graph_and_eager_modes 718 def test_nested_tensors_evaluate(self): 719 expected = {"a": 1, "b": 2, "nested": {"d": 3, "e": 4}} 720 nested = {"a": constant_op.constant(1), 721 "b": constant_op.constant(2), 722 "nested": {"d": constant_op.constant(3), 723 "e": constant_op.constant(4)}} 724 725 self.assertEqual(expected, self.evaluate(nested)) 726 727 def test_run_in_graph_and_eager_modes(self): 728 l = [] 729 def inc(self, with_brackets): 730 del self # self argument is required by run_in_graph_and_eager_modes. 731 mode = "eager" if context.executing_eagerly() else "graph" 732 with_brackets = "with_brackets" if with_brackets else "without_brackets" 733 l.append((with_brackets, mode)) 734 735 f = test_util.run_in_graph_and_eager_modes(inc) 736 f(self, with_brackets=False) 737 f = test_util.run_in_graph_and_eager_modes()(inc) # pylint: disable=assignment-from-no-return 738 f(self, with_brackets=True) 739 740 self.assertEqual(len(l), 4) 741 self.assertEqual(set(l), { 742 ("with_brackets", "graph"), 743 ("with_brackets", "eager"), 744 ("without_brackets", "graph"), 745 ("without_brackets", "eager"), 746 }) 747 748 def test_get_node_def_from_graph(self): 749 graph_def = graph_pb2.GraphDef() 750 node_foo = graph_def.node.add() 751 node_foo.name = "foo" 752 self.assertIs(test_util.get_node_def_from_graph("foo", graph_def), node_foo) 753 self.assertIsNone(test_util.get_node_def_from_graph("bar", graph_def)) 754 755 def test_run_in_eager_and_graph_modes_test_class(self): 756 msg = "`run_in_graph_and_eager_modes` only supports test methods.*" 757 with self.assertRaisesRegex(ValueError, msg): 758 759 @test_util.run_in_graph_and_eager_modes() 760 class Foo(object): 761 pass 762 del Foo # Make pylint unused happy. 763 764 def test_run_in_eager_and_graph_modes_skip_graph_runs_eager(self): 765 modes = [] 766 def _test(self): 767 if not context.executing_eagerly(): 768 self.skipTest("Skipping in graph mode") 769 modes.append("eager" if context.executing_eagerly() else "graph") 770 test_util.run_in_graph_and_eager_modes(_test)(self) 771 self.assertEqual(modes, ["eager"]) 772 773 def test_run_in_eager_and_graph_modes_skip_eager_runs_graph(self): 774 modes = [] 775 def _test(self): 776 if context.executing_eagerly(): 777 self.skipTest("Skipping in eager mode") 778 modes.append("eager" if context.executing_eagerly() else "graph") 779 test_util.run_in_graph_and_eager_modes(_test)(self) 780 self.assertEqual(modes, ["graph"]) 781 782 def test_run_in_graph_and_eager_modes_setup_in_same_mode(self): 783 modes = [] 784 mode_name = lambda: "eager" if context.executing_eagerly() else "graph" 785 786 class ExampleTest(test_util.TensorFlowTestCase): 787 788 def runTest(self): 789 pass 790 791 def setUp(self): 792 modes.append("setup_" + mode_name()) 793 794 @test_util.run_in_graph_and_eager_modes 795 def testBody(self): 796 modes.append("run_" + mode_name()) 797 798 e = ExampleTest() 799 e.setUp() 800 e.testBody() 801 802 self.assertEqual(modes[1:2], ["run_graph"]) 803 self.assertEqual(modes[2:], ["setup_eager", "run_eager"]) 804 805 @parameterized.named_parameters(dict(testcase_name="argument", 806 arg=True)) 807 @test_util.run_in_graph_and_eager_modes 808 def test_run_in_graph_and_eager_works_with_parameterized_keyword(self, arg): 809 self.assertEqual(arg, True) 810 811 @combinations.generate(combinations.combine(arg=True)) 812 @test_util.run_in_graph_and_eager_modes 813 def test_run_in_graph_and_eager_works_with_combinations(self, arg): 814 self.assertEqual(arg, True) 815 816 def test_build_as_function_and_v1_graph(self): 817 818 class GraphModeAndFunctionTest(parameterized.TestCase): 819 820 def __init__(inner_self): # pylint: disable=no-self-argument 821 super(GraphModeAndFunctionTest, inner_self).__init__() 822 inner_self.graph_mode_tested = False 823 inner_self.inside_function_tested = False 824 825 def runTest(self): 826 del self 827 828 @test_util.build_as_function_and_v1_graph 829 def test_modes(inner_self): # pylint: disable=no-self-argument 830 if ops.inside_function(): 831 self.assertFalse(inner_self.inside_function_tested) 832 inner_self.inside_function_tested = True 833 else: 834 self.assertFalse(inner_self.graph_mode_tested) 835 inner_self.graph_mode_tested = True 836 837 test_object = GraphModeAndFunctionTest() 838 test_object.test_modes_v1_graph() 839 test_object.test_modes_function() 840 self.assertTrue(test_object.graph_mode_tested) 841 self.assertTrue(test_object.inside_function_tested) 842 843 @test_util.run_in_graph_and_eager_modes 844 def test_consistent_random_seed_in_assert_all_equal(self): 845 random_seed.set_seed(1066) 846 index = random_ops.random_shuffle([0, 1, 2, 3, 4], seed=2021) 847 # This failed when `a` and `b` were evaluated in separate sessions. 848 self.assertAllEqual(index, index) 849 850 def test_with_forward_compatibility_horizons(self): 851 852 tested_codepaths = set() 853 def some_function_with_forward_compat_behavior(): 854 if compat.forward_compatible(2050, 1, 1): 855 tested_codepaths.add("future") 856 else: 857 tested_codepaths.add("present") 858 859 @test_util.with_forward_compatibility_horizons(None, [2051, 1, 1]) 860 def some_test(self): 861 del self # unused 862 some_function_with_forward_compat_behavior() 863 864 some_test(None) 865 self.assertEqual(tested_codepaths, set(["present", "future"])) 866 867 868class SkipTestTest(test_util.TensorFlowTestCase): 869 870 def _verify_test_in_set_up_or_tear_down(self): 871 with self.assertRaises(unittest.SkipTest): 872 with test_util.skip_if_error(self, ValueError, 873 ["foo bar", "test message"]): 874 raise ValueError("test message") 875 try: 876 with self.assertRaisesRegex(ValueError, "foo bar"): 877 with test_util.skip_if_error(self, ValueError, "test message"): 878 raise ValueError("foo bar") 879 except unittest.SkipTest: 880 raise RuntimeError("Test is not supposed to skip.") 881 882 def setUp(self): 883 super(SkipTestTest, self).setUp() 884 self._verify_test_in_set_up_or_tear_down() 885 886 def tearDown(self): 887 super(SkipTestTest, self).tearDown() 888 self._verify_test_in_set_up_or_tear_down() 889 890 def test_skip_if_error_should_skip(self): 891 with self.assertRaises(unittest.SkipTest): 892 with test_util.skip_if_error(self, ValueError, "test message"): 893 raise ValueError("test message") 894 895 def test_skip_if_error_should_skip_with_list(self): 896 with self.assertRaises(unittest.SkipTest): 897 with test_util.skip_if_error(self, ValueError, 898 ["foo bar", "test message"]): 899 raise ValueError("test message") 900 901 def test_skip_if_error_should_skip_without_expected_message(self): 902 with self.assertRaises(unittest.SkipTest): 903 with test_util.skip_if_error(self, ValueError): 904 raise ValueError("test message") 905 906 def test_skip_if_error_should_skip_without_error_message(self): 907 with self.assertRaises(unittest.SkipTest): 908 with test_util.skip_if_error(self, ValueError): 909 raise ValueError() 910 911 def test_skip_if_error_should_raise_message_mismatch(self): 912 try: 913 with self.assertRaisesRegex(ValueError, "foo bar"): 914 with test_util.skip_if_error(self, ValueError, "test message"): 915 raise ValueError("foo bar") 916 except unittest.SkipTest: 917 raise RuntimeError("Test is not supposed to skip.") 918 919 def test_skip_if_error_should_raise_no_message(self): 920 try: 921 with self.assertRaisesRegex(ValueError, ""): 922 with test_util.skip_if_error(self, ValueError, "test message"): 923 raise ValueError() 924 except unittest.SkipTest: 925 raise RuntimeError("Test is not supposed to skip.") 926 927 928# Its own test case to reproduce variable sharing issues which only pop up when 929# setUp() is overridden and super() is not called. 930class GraphAndEagerNoVariableSharing(test_util.TensorFlowTestCase): 931 932 def setUp(self): 933 pass # Intentionally does not call TensorFlowTestCase's super() 934 935 @test_util.run_in_graph_and_eager_modes 936 def test_no_variable_sharing(self): 937 variable_scope.get_variable( 938 name="step_size", 939 initializer=np.array(1e-5, np.float32), 940 use_resource=True, 941 trainable=False) 942 943 944class GarbageCollectionTest(test_util.TensorFlowTestCase): 945 946 def test_no_reference_cycle_decorator(self): 947 948 class ReferenceCycleTest(object): 949 950 def __init__(inner_self): # pylint: disable=no-self-argument 951 inner_self.assertEqual = self.assertEqual # pylint: disable=invalid-name 952 953 @test_util.assert_no_garbage_created 954 def test_has_cycle(self): 955 a = [] 956 a.append(a) 957 958 @test_util.assert_no_garbage_created 959 def test_has_no_cycle(self): 960 pass 961 962 with self.assertRaises(AssertionError): 963 ReferenceCycleTest().test_has_cycle() 964 965 ReferenceCycleTest().test_has_no_cycle() 966 967 @test_util.run_in_graph_and_eager_modes 968 def test_no_leaked_tensor_decorator(self): 969 970 class LeakedTensorTest(object): 971 972 def __init__(inner_self): # pylint: disable=no-self-argument 973 inner_self.assertEqual = self.assertEqual # pylint: disable=invalid-name 974 975 @test_util.assert_no_new_tensors 976 def test_has_leak(self): 977 self.a = constant_op.constant([3.], name="leak") 978 979 @test_util.assert_no_new_tensors 980 def test_has_no_leak(self): 981 constant_op.constant([3.], name="no-leak") 982 983 with self.assertRaisesRegex(AssertionError, "Tensors not deallocated"): 984 LeakedTensorTest().test_has_leak() 985 986 LeakedTensorTest().test_has_no_leak() 987 988 def test_no_new_objects_decorator(self): 989 990 class LeakedObjectTest(unittest.TestCase): 991 992 def __init__(self, *args, **kwargs): 993 super(LeakedObjectTest, self).__init__(*args, **kwargs) 994 self.accumulation = [] 995 996 @unittest.expectedFailure 997 @test_util.assert_no_new_pyobjects_executing_eagerly 998 def test_has_leak(self): 999 self.accumulation.append([1.]) 1000 1001 @test_util.assert_no_new_pyobjects_executing_eagerly 1002 def test_has_no_leak(self): 1003 self.not_accumulating = [1.] 1004 1005 self.assertTrue(LeakedObjectTest("test_has_leak").run().wasSuccessful()) 1006 self.assertTrue(LeakedObjectTest("test_has_no_leak").run().wasSuccessful()) 1007 1008 1009class RunFunctionsEagerlyInV2Test(test_util.TensorFlowTestCase, 1010 parameterized.TestCase): 1011 @parameterized.named_parameters( 1012 [("_RunEagerly", True), ("_RunGraph", False)]) 1013 def test_run_functions_eagerly(self, run_eagerly): # pylint: disable=g-wrong-blank-lines 1014 results = [] 1015 1016 @def_function.function 1017 def add_two(x): 1018 for _ in range(5): 1019 x += 2 1020 results.append(x) 1021 return x 1022 1023 with test_util.run_functions_eagerly(run_eagerly): 1024 add_two(constant_op.constant(2.)) 1025 if context.executing_eagerly(): 1026 if run_eagerly: 1027 self.assertTrue(isinstance(t, ops.EagerTensor) for t in results) 1028 else: 1029 self.assertTrue(isinstance(t, ops.Tensor) for t in results) 1030 else: 1031 self.assertTrue(isinstance(t, ops.Tensor) for t in results) 1032 1033 1034if __name__ == "__main__": 1035 googletest.main() 1036