1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for tensorflow.ops.test_util.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import copy 23import random 24import threading 25import unittest 26import weakref 27 28from absl.testing import parameterized 29import numpy as np 30 31from google.protobuf import text_format 32 33from tensorflow.core.framework import graph_pb2 34from tensorflow.core.protobuf import meta_graph_pb2 35from tensorflow.python.compat import compat 36from tensorflow.python.eager import context 37from tensorflow.python.eager import def_function 38from tensorflow.python.framework import combinations 39from tensorflow.python.framework import constant_op 40from tensorflow.python.framework import dtypes 41from tensorflow.python.framework import errors 42from tensorflow.python.framework import ops 43from tensorflow.python.framework import test_ops # pylint: disable=unused-import 44from tensorflow.python.framework import test_util 45from tensorflow.python.ops import control_flow_ops 46from tensorflow.python.ops import lookup_ops 47from tensorflow.python.ops import math_ops 48from tensorflow.python.ops import random_ops 49from tensorflow.python.ops import resource_variable_ops 50from 
tensorflow.python.ops import variable_scope 51from tensorflow.python.ops import variables 52from tensorflow.python.platform import googletest 53 54 55class TestUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase): 56 57 def test_assert_ops_in_graph(self): 58 with ops.Graph().as_default(): 59 constant_op.constant(["hello", "taffy"], name="hello") 60 test_util.assert_ops_in_graph({"hello": "Const"}, ops.get_default_graph()) 61 62 self.assertRaises(ValueError, test_util.assert_ops_in_graph, 63 {"bye": "Const"}, ops.get_default_graph()) 64 65 self.assertRaises(ValueError, test_util.assert_ops_in_graph, 66 {"hello": "Variable"}, ops.get_default_graph()) 67 68 @test_util.run_deprecated_v1 69 def test_session_functions(self): 70 with self.test_session() as sess: 71 sess_ref = weakref.ref(sess) 72 with self.cached_session(graph=None, config=None) as sess2: 73 # We make sure that sess2 is sess. 74 assert sess2 is sess 75 # We make sure we raise an exception if we use cached_session with 76 # different values. 77 with self.assertRaises(ValueError): 78 with self.cached_session(graph=ops.Graph()) as sess2: 79 pass 80 with self.assertRaises(ValueError): 81 with self.cached_session(force_gpu=True) as sess2: 82 pass 83 # We make sure that test_session will cache the session even after the 84 # with scope. 85 assert not sess_ref()._closed 86 with self.session() as unique_sess: 87 unique_sess_ref = weakref.ref(unique_sess) 88 with self.session() as sess2: 89 assert sess2 is not unique_sess 90 # We make sure the session is closed when we leave the with statement. 
91 assert unique_sess_ref()._closed 92 93 def test_assert_equal_graph_def(self): 94 with ops.Graph().as_default() as g: 95 def_empty = g.as_graph_def() 96 constant_op.constant(5, name="five") 97 constant_op.constant(7, name="seven") 98 def_57 = g.as_graph_def() 99 with ops.Graph().as_default() as g: 100 constant_op.constant(7, name="seven") 101 constant_op.constant(5, name="five") 102 def_75 = g.as_graph_def() 103 # Comparing strings is order dependent 104 self.assertNotEqual(str(def_57), str(def_75)) 105 # assert_equal_graph_def doesn't care about order 106 test_util.assert_equal_graph_def(def_57, def_75) 107 # Compare two unequal graphs 108 with self.assertRaisesRegex(AssertionError, 109 r"^Found unexpected node '{{node seven}}"): 110 test_util.assert_equal_graph_def(def_57, def_empty) 111 112 def test_assert_equal_graph_def_hash_table(self): 113 def get_graph_def(): 114 with ops.Graph().as_default() as g: 115 x = constant_op.constant([2, 9], name="x") 116 keys = constant_op.constant([1, 2], name="keys") 117 values = constant_op.constant([3, 4], name="values") 118 default = constant_op.constant(-1, name="default") 119 table = lookup_ops.StaticHashTable( 120 lookup_ops.KeyValueTensorInitializer(keys, values), default) 121 _ = table.lookup(x) 122 return g.as_graph_def() 123 def_1 = get_graph_def() 124 def_2 = get_graph_def() 125 # The unique shared_name of each table makes the graph unequal. 126 with self.assertRaisesRegex(AssertionError, "hash_table_"): 127 test_util.assert_equal_graph_def(def_1, def_2, 128 hash_table_shared_name=False) 129 # That can be ignored. (NOTE: modifies GraphDefs in-place.) 130 test_util.assert_equal_graph_def(def_1, def_2, 131 hash_table_shared_name=True) 132 133 def testIsGoogleCudaEnabled(self): 134 # The test doesn't assert anything. It ensures the py wrapper 135 # function is generated correctly. 
136 if test_util.IsGoogleCudaEnabled(): 137 print("GoogleCuda is enabled") 138 else: 139 print("GoogleCuda is disabled") 140 141 def testIsMklEnabled(self): 142 # This test doesn't assert anything. 143 # It ensures the py wrapper function is generated correctly. 144 if test_util.IsMklEnabled(): 145 print("MKL is enabled") 146 else: 147 print("MKL is disabled") 148 149 @test_util.run_in_graph_and_eager_modes 150 def testAssertProtoEqualsStr(self): 151 152 graph_str = "node { name: 'w1' op: 'params' }" 153 graph_def = graph_pb2.GraphDef() 154 text_format.Merge(graph_str, graph_def) 155 156 # test string based comparison 157 self.assertProtoEquals(graph_str, graph_def) 158 159 # test original comparison 160 self.assertProtoEquals(graph_def, graph_def) 161 162 @test_util.run_in_graph_and_eager_modes 163 def testAssertProtoEqualsAny(self): 164 # Test assertProtoEquals with a protobuf.Any field. 165 meta_graph_def_str = """ 166 meta_info_def { 167 meta_graph_version: "outer" 168 any_info { 169 [type.googleapis.com/tensorflow.MetaGraphDef] { 170 meta_info_def { 171 meta_graph_version: "inner" 172 } 173 } 174 } 175 } 176 """ 177 meta_graph_def_outer = meta_graph_pb2.MetaGraphDef() 178 meta_graph_def_outer.meta_info_def.meta_graph_version = "outer" 179 meta_graph_def_inner = meta_graph_pb2.MetaGraphDef() 180 meta_graph_def_inner.meta_info_def.meta_graph_version = "inner" 181 meta_graph_def_outer.meta_info_def.any_info.Pack(meta_graph_def_inner) 182 self.assertProtoEquals(meta_graph_def_str, meta_graph_def_outer) 183 self.assertProtoEquals(meta_graph_def_outer, meta_graph_def_outer) 184 185 # Check if the assertion failure message contains the content of 186 # the inner proto. 
187 with self.assertRaisesRegex(AssertionError, r'meta_graph_version: "inner"'): 188 self.assertProtoEquals("", meta_graph_def_outer) 189 190 @test_util.run_in_graph_and_eager_modes 191 def testNDArrayNear(self): 192 a1 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 193 a2 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 194 a3 = np.array([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]]) 195 self.assertTrue(self._NDArrayNear(a1, a2, 1e-5)) 196 self.assertFalse(self._NDArrayNear(a1, a3, 1e-5)) 197 198 @test_util.run_in_graph_and_eager_modes 199 def testCheckedThreadSucceeds(self): 200 201 def noop(ev): 202 ev.set() 203 204 event_arg = threading.Event() 205 206 self.assertFalse(event_arg.is_set()) 207 t = self.checkedThread(target=noop, args=(event_arg,)) 208 t.start() 209 t.join() 210 self.assertTrue(event_arg.is_set()) 211 212 @test_util.run_in_graph_and_eager_modes 213 def testCheckedThreadFails(self): 214 215 def err_func(): 216 return 1 // 0 217 218 t = self.checkedThread(target=err_func) 219 t.start() 220 with self.assertRaises(self.failureException) as fe: 221 t.join() 222 self.assertTrue("integer division or modulo by zero" in str(fe.exception)) 223 224 @test_util.run_in_graph_and_eager_modes 225 def testCheckedThreadWithWrongAssertionFails(self): 226 x = 37 227 228 def err_func(): 229 self.assertTrue(x < 10) 230 231 t = self.checkedThread(target=err_func) 232 t.start() 233 with self.assertRaises(self.failureException) as fe: 234 t.join() 235 self.assertTrue("False is not true" in str(fe.exception)) 236 237 @test_util.run_in_graph_and_eager_modes 238 def testMultipleThreadsWithOneFailure(self): 239 240 def err_func(i): 241 self.assertTrue(i != 7) 242 243 threads = [ 244 self.checkedThread( 245 target=err_func, args=(i,)) for i in range(10) 246 ] 247 for t in threads: 248 t.start() 249 for i, t in enumerate(threads): 250 if i == 7: 251 with self.assertRaises(self.failureException): 252 t.join() 253 else: 254 t.join() 255 256 def _WeMustGoDeeper(self, msg): 257 with 
self.assertRaisesOpError(msg): 258 with ops.Graph().as_default(): 259 node_def = ops._NodeDef("IntOutput", "name") 260 node_def_orig = ops._NodeDef("IntOutput", "orig") 261 op_orig = ops.Operation(node_def_orig, ops.get_default_graph()) 262 op = ops.Operation(node_def, ops.get_default_graph(), 263 original_op=op_orig) 264 raise errors.UnauthenticatedError(node_def, op, "true_err") 265 266 @test_util.run_in_graph_and_eager_modes 267 def testAssertRaisesOpErrorDoesNotPassMessageDueToLeakedStack(self): 268 with self.assertRaises(AssertionError): 269 self._WeMustGoDeeper("this_is_not_the_error_you_are_looking_for") 270 271 self._WeMustGoDeeper("true_err") 272 self._WeMustGoDeeper("name") 273 self._WeMustGoDeeper("orig") 274 275 @test_util.run_in_graph_and_eager_modes 276 def testAllCloseTensors(self): 277 a_raw_data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] 278 a = constant_op.constant(a_raw_data) 279 b = math_ops.add(1, constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8]])) 280 self.assertAllClose(a, b) 281 self.assertAllClose(a, a_raw_data) 282 283 a_dict = {"key": a} 284 b_dict = {"key": b} 285 self.assertAllClose(a_dict, b_dict) 286 287 x_list = [a, b] 288 y_list = [a_raw_data, b] 289 self.assertAllClose(x_list, y_list) 290 291 @test_util.run_in_graph_and_eager_modes 292 def testAllCloseScalars(self): 293 self.assertAllClose(7, 7 + 1e-8) 294 with self.assertRaisesRegex(AssertionError, r"Not equal to tolerance"): 295 self.assertAllClose(7, 7 + 1e-5) 296 297 @test_util.run_in_graph_and_eager_modes 298 def testAllCloseList(self): 299 with self.assertRaisesRegex(AssertionError, r"not close dif"): 300 self.assertAllClose([0], [1]) 301 302 @test_util.run_in_graph_and_eager_modes 303 def testAllCloseDictToNonDict(self): 304 with self.assertRaisesRegex(ValueError, r"Can't compare dict to non-dict"): 305 self.assertAllClose(1, {"a": 1}) 306 with self.assertRaisesRegex(ValueError, r"Can't compare dict to non-dict"): 307 self.assertAllClose({"a": 1}, 1) 308 309 
@test_util.run_in_graph_and_eager_modes 310 def testAllCloseNamedtuples(self): 311 a = 7 312 b = (2., 3.) 313 c = np.ones((3, 2, 4)) * 7. 314 expected = {"a": a, "b": b, "c": c} 315 my_named_tuple = collections.namedtuple("MyNamedTuple", ["a", "b", "c"]) 316 317 # Identity. 318 self.assertAllClose(expected, my_named_tuple(a=a, b=b, c=c)) 319 self.assertAllClose( 320 my_named_tuple(a=a, b=b, c=c), my_named_tuple(a=a, b=b, c=c)) 321 322 @test_util.run_in_graph_and_eager_modes 323 def testAllCloseDicts(self): 324 a = 7 325 b = (2., 3.) 326 c = np.ones((3, 2, 4)) * 7. 327 expected = {"a": a, "b": b, "c": c} 328 329 # Identity. 330 self.assertAllClose(expected, expected) 331 self.assertAllClose(expected, dict(expected)) 332 333 # With each item removed. 334 for k in expected: 335 actual = dict(expected) 336 del actual[k] 337 with self.assertRaisesRegex(AssertionError, r"mismatched keys"): 338 self.assertAllClose(expected, actual) 339 340 # With each item changed. 341 with self.assertRaisesRegex(AssertionError, r"Not equal to tolerance"): 342 self.assertAllClose(expected, {"a": a + 1e-5, "b": b, "c": c}) 343 with self.assertRaisesRegex(AssertionError, r"Shape mismatch"): 344 self.assertAllClose(expected, {"a": a, "b": b + (4.,), "c": c}) 345 c_copy = np.array(c) 346 c_copy[1, 1, 1] += 1e-5 347 with self.assertRaisesRegex(AssertionError, r"Not equal to tolerance"): 348 self.assertAllClose(expected, {"a": a, "b": b, "c": c_copy}) 349 350 @test_util.run_in_graph_and_eager_modes 351 def testAllCloseListOfNamedtuples(self): 352 my_named_tuple = collections.namedtuple("MyNamedTuple", ["x", "y"]) 353 l1 = [ 354 my_named_tuple(x=np.array([[2.3, 2.5]]), y=np.array([[0.97, 0.96]])), 355 my_named_tuple(x=np.array([[3.3, 3.5]]), y=np.array([[0.98, 0.99]])) 356 ] 357 l2 = [ 358 ([[2.3, 2.5]], [[0.97, 0.96]]), 359 ([[3.3, 3.5]], [[0.98, 0.99]]), 360 ] 361 self.assertAllClose(l1, l2) 362 363 @test_util.run_in_graph_and_eager_modes 364 def testAllCloseNestedStructure(self): 365 a = 
{"x": np.ones((3, 2, 4)) * 7, "y": (2, [{"nested": {"m": 3, "n": 4}}])} 366 self.assertAllClose(a, a) 367 368 b = copy.deepcopy(a) 369 self.assertAllClose(a, b) 370 371 # Test mismatched values 372 b["y"][1][0]["nested"]["n"] = 4.2 373 with self.assertRaisesRegex(AssertionError, 374 r"\[y\]\[1\]\[0\]\[nested\]\[n\]"): 375 self.assertAllClose(a, b) 376 377 @test_util.run_in_graph_and_eager_modes 378 def testArrayNear(self): 379 a = [1, 2] 380 b = [1, 2, 5] 381 with self.assertRaises(AssertionError): 382 self.assertArrayNear(a, b, 0.001) 383 a = [1, 2] 384 b = [[1, 2], [3, 4]] 385 with self.assertRaises(TypeError): 386 self.assertArrayNear(a, b, 0.001) 387 a = [1, 2] 388 b = [1, 2] 389 self.assertArrayNear(a, b, 0.001) 390 391 @test_util.skip_if(True) # b/117665998 392 def testForceGPU(self): 393 with self.assertRaises(errors.InvalidArgumentError): 394 with self.test_session(force_gpu=True): 395 # this relies on us not having a GPU implementation for assert, which 396 # seems sensible 397 x = constant_op.constant(True) 398 y = [15] 399 control_flow_ops.Assert(x, y).run() 400 401 @test_util.run_in_graph_and_eager_modes 402 def testAssertAllCloseAccordingToType(self): 403 # test plain int 404 self.assertAllCloseAccordingToType(1, 1, rtol=1e-8, atol=1e-8) 405 406 # test float64 407 self.assertAllCloseAccordingToType( 408 np.asarray([1e-8], dtype=np.float64), 409 np.asarray([2e-8], dtype=np.float64), 410 rtol=1e-8, atol=1e-8 411 ) 412 413 self.assertAllCloseAccordingToType( 414 constant_op.constant([1e-8], dtype=dtypes.float64), 415 constant_op.constant([2e-8], dtype=dtypes.float64), 416 rtol=1e-8, 417 atol=1e-8) 418 419 with (self.assertRaises(AssertionError)): 420 self.assertAllCloseAccordingToType( 421 np.asarray([1e-7], dtype=np.float64), 422 np.asarray([2e-7], dtype=np.float64), 423 rtol=1e-8, atol=1e-8 424 ) 425 426 # test float32 427 self.assertAllCloseAccordingToType( 428 np.asarray([1e-7], dtype=np.float32), 429 np.asarray([2e-7], dtype=np.float32), 430 
rtol=1e-8, atol=1e-8, 431 float_rtol=1e-7, float_atol=1e-7 432 ) 433 434 self.assertAllCloseAccordingToType( 435 constant_op.constant([1e-7], dtype=dtypes.float32), 436 constant_op.constant([2e-7], dtype=dtypes.float32), 437 rtol=1e-8, 438 atol=1e-8, 439 float_rtol=1e-7, 440 float_atol=1e-7) 441 442 with (self.assertRaises(AssertionError)): 443 self.assertAllCloseAccordingToType( 444 np.asarray([1e-6], dtype=np.float32), 445 np.asarray([2e-6], dtype=np.float32), 446 rtol=1e-8, atol=1e-8, 447 float_rtol=1e-7, float_atol=1e-7 448 ) 449 450 # test float16 451 self.assertAllCloseAccordingToType( 452 np.asarray([1e-4], dtype=np.float16), 453 np.asarray([2e-4], dtype=np.float16), 454 rtol=1e-8, atol=1e-8, 455 float_rtol=1e-7, float_atol=1e-7, 456 half_rtol=1e-4, half_atol=1e-4 457 ) 458 459 self.assertAllCloseAccordingToType( 460 constant_op.constant([1e-4], dtype=dtypes.float16), 461 constant_op.constant([2e-4], dtype=dtypes.float16), 462 rtol=1e-8, 463 atol=1e-8, 464 float_rtol=1e-7, 465 float_atol=1e-7, 466 half_rtol=1e-4, 467 half_atol=1e-4) 468 469 with (self.assertRaises(AssertionError)): 470 self.assertAllCloseAccordingToType( 471 np.asarray([1e-3], dtype=np.float16), 472 np.asarray([2e-3], dtype=np.float16), 473 rtol=1e-8, atol=1e-8, 474 float_rtol=1e-7, float_atol=1e-7, 475 half_rtol=1e-4, half_atol=1e-4 476 ) 477 478 @test_util.run_in_graph_and_eager_modes 479 def testAssertAllEqual(self): 480 i = variables.Variable([100] * 3, dtype=dtypes.int32, name="i") 481 j = constant_op.constant([20] * 3, dtype=dtypes.int32, name="j") 482 k = math_ops.add(i, j, name="k") 483 484 self.evaluate(variables.global_variables_initializer()) 485 self.assertAllEqual([100] * 3, i) 486 self.assertAllEqual([120] * 3, k) 487 self.assertAllEqual([20] * 3, j) 488 489 with self.assertRaisesRegex(AssertionError, r"not equal lhs"): 490 self.assertAllEqual([0] * 3, k) 491 492 @test_util.run_in_graph_and_eager_modes 493 def testAssertNotAllEqual(self): 494 i = variables.Variable([100], 
dtype=dtypes.int32, name="i") 495 j = constant_op.constant([20], dtype=dtypes.int32, name="j") 496 k = math_ops.add(i, j, name="k") 497 498 self.evaluate(variables.global_variables_initializer()) 499 self.assertNotAllEqual([100] * 3, i) 500 self.assertNotAllEqual([120] * 3, k) 501 self.assertNotAllEqual([20] * 3, j) 502 503 with self.assertRaisesRegex( 504 AssertionError, r"two values are equal at all elements.*extra message"): 505 self.assertNotAllEqual([120], k, msg="extra message") 506 507 @test_util.run_in_graph_and_eager_modes 508 def testAssertNotAllClose(self): 509 # Test with arrays 510 self.assertNotAllClose([0.1], [0.2]) 511 with self.assertRaises(AssertionError): 512 self.assertNotAllClose([-1.0, 2.0], [-1.0, 2.0]) 513 514 # Test with tensors 515 x = constant_op.constant([1.0, 1.0], name="x") 516 y = math_ops.add(x, x) 517 518 self.assertAllClose([2.0, 2.0], y) 519 self.assertNotAllClose([0.9, 1.0], x) 520 521 with self.assertRaises(AssertionError): 522 self.assertNotAllClose([1.0, 1.0], x) 523 524 @test_util.run_in_graph_and_eager_modes 525 def testAssertNotAllCloseRTol(self): 526 # Test with arrays 527 with self.assertRaises(AssertionError): 528 self.assertNotAllClose([1.1, 2.1], [1.0, 2.0], rtol=0.2) 529 530 # Test with tensors 531 x = constant_op.constant([1.0, 1.0], name="x") 532 y = math_ops.add(x, x) 533 534 self.assertAllClose([2.0, 2.0], y) 535 536 with self.assertRaises(AssertionError): 537 self.assertNotAllClose([0.9, 1.0], x, rtol=0.2) 538 539 @test_util.run_in_graph_and_eager_modes 540 def testAssertNotAllCloseATol(self): 541 # Test with arrays 542 with self.assertRaises(AssertionError): 543 self.assertNotAllClose([1.1, 2.1], [1.0, 2.0], atol=0.2) 544 545 # Test with tensors 546 x = constant_op.constant([1.0, 1.0], name="x") 547 y = math_ops.add(x, x) 548 549 self.assertAllClose([2.0, 2.0], y) 550 551 with self.assertRaises(AssertionError): 552 self.assertNotAllClose([0.9, 1.0], x, atol=0.2) 553 554 @test_util.run_in_graph_and_eager_modes 
555 def testAssertAllGreaterLess(self): 556 x = constant_op.constant([100.0, 110.0, 120.0], dtype=dtypes.float32) 557 y = constant_op.constant([10.0] * 3, dtype=dtypes.float32) 558 z = math_ops.add(x, y) 559 560 self.assertAllClose([110.0, 120.0, 130.0], z) 561 562 self.assertAllGreater(x, 95.0) 563 self.assertAllLess(x, 125.0) 564 565 with self.assertRaises(AssertionError): 566 self.assertAllGreater(x, 105.0) 567 with self.assertRaises(AssertionError): 568 self.assertAllGreater(x, 125.0) 569 570 with self.assertRaises(AssertionError): 571 self.assertAllLess(x, 115.0) 572 with self.assertRaises(AssertionError): 573 self.assertAllLess(x, 95.0) 574 575 @test_util.run_in_graph_and_eager_modes 576 def testAssertAllGreaterLessEqual(self): 577 x = constant_op.constant([100.0, 110.0, 120.0], dtype=dtypes.float32) 578 y = constant_op.constant([10.0] * 3, dtype=dtypes.float32) 579 z = math_ops.add(x, y) 580 581 self.assertAllEqual([110.0, 120.0, 130.0], z) 582 583 self.assertAllGreaterEqual(x, 95.0) 584 self.assertAllLessEqual(x, 125.0) 585 586 with self.assertRaises(AssertionError): 587 self.assertAllGreaterEqual(x, 105.0) 588 with self.assertRaises(AssertionError): 589 self.assertAllGreaterEqual(x, 125.0) 590 591 with self.assertRaises(AssertionError): 592 self.assertAllLessEqual(x, 115.0) 593 with self.assertRaises(AssertionError): 594 self.assertAllLessEqual(x, 95.0) 595 596 def testAssertAllInRangeWithNonNumericValuesFails(self): 597 s1 = constant_op.constant("Hello, ", name="s1") 598 c = constant_op.constant([1 + 2j, -3 + 5j], name="c") 599 b = constant_op.constant([False, True], name="b") 600 601 with self.assertRaises(AssertionError): 602 self.assertAllInRange(s1, 0.0, 1.0) 603 with self.assertRaises(AssertionError): 604 self.assertAllInRange(c, 0.0, 1.0) 605 with self.assertRaises(AssertionError): 606 self.assertAllInRange(b, 0, 1) 607 608 @test_util.run_in_graph_and_eager_modes 609 def testAssertAllInRange(self): 610 x = constant_op.constant([10.0, 15.0], 
name="x") 611 self.assertAllInRange(x, 10, 15) 612 613 with self.assertRaises(AssertionError): 614 self.assertAllInRange(x, 10, 15, open_lower_bound=True) 615 with self.assertRaises(AssertionError): 616 self.assertAllInRange(x, 10, 15, open_upper_bound=True) 617 with self.assertRaises(AssertionError): 618 self.assertAllInRange( 619 x, 10, 15, open_lower_bound=True, open_upper_bound=True) 620 621 @test_util.run_in_graph_and_eager_modes 622 def testAssertAllInRangeScalar(self): 623 x = constant_op.constant(10.0, name="x") 624 nan = constant_op.constant(np.nan, name="nan") 625 self.assertAllInRange(x, 5, 15) 626 with self.assertRaises(AssertionError): 627 self.assertAllInRange(nan, 5, 15) 628 629 with self.assertRaises(AssertionError): 630 self.assertAllInRange(x, 10, 15, open_lower_bound=True) 631 with self.assertRaises(AssertionError): 632 self.assertAllInRange(x, 1, 2) 633 634 @test_util.run_in_graph_and_eager_modes 635 def testAssertAllInRangeErrorMessageEllipses(self): 636 x_init = np.array([[10.0, 15.0]] * 12) 637 x = constant_op.constant(x_init, name="x") 638 with self.assertRaises(AssertionError): 639 self.assertAllInRange(x, 5, 10) 640 641 @test_util.run_in_graph_and_eager_modes 642 def testAssertAllInRangeDetectsNaNs(self): 643 x = constant_op.constant( 644 [[np.nan, 0.0], [np.nan, np.inf], [np.inf, np.nan]], name="x") 645 with self.assertRaises(AssertionError): 646 self.assertAllInRange(x, 0.0, 2.0) 647 648 @test_util.run_in_graph_and_eager_modes 649 def testAssertAllInRangeWithInfinities(self): 650 x = constant_op.constant([10.0, np.inf], name="x") 651 self.assertAllInRange(x, 10, np.inf) 652 with self.assertRaises(AssertionError): 653 self.assertAllInRange(x, 10, np.inf, open_upper_bound=True) 654 655 @test_util.run_in_graph_and_eager_modes 656 def testAssertAllInSet(self): 657 b = constant_op.constant([True, False], name="b") 658 x = constant_op.constant([13, 37], name="x") 659 660 self.assertAllInSet(b, [False, True]) 661 self.assertAllInSet(b, (False, 
True)) 662 self.assertAllInSet(b, {False, True}) 663 self.assertAllInSet(x, [0, 13, 37, 42]) 664 self.assertAllInSet(x, (0, 13, 37, 42)) 665 self.assertAllInSet(x, {0, 13, 37, 42}) 666 667 with self.assertRaises(AssertionError): 668 self.assertAllInSet(b, [False]) 669 with self.assertRaises(AssertionError): 670 self.assertAllInSet(x, (42,)) 671 672 def testRandomSeed(self): 673 # Call setUp again for WithCApi case (since it makes a new default graph 674 # after setup). 675 # TODO(skyewm): remove this when C API is permanently enabled. 676 with context.eager_mode(): 677 self.setUp() 678 a = random.randint(1, 1000) 679 a_np_rand = np.random.rand(1) 680 a_rand = random_ops.random_normal([1]) 681 # ensure that randomness in multiple testCases is deterministic. 682 self.setUp() 683 b = random.randint(1, 1000) 684 b_np_rand = np.random.rand(1) 685 b_rand = random_ops.random_normal([1]) 686 self.assertEqual(a, b) 687 self.assertEqual(a_np_rand, b_np_rand) 688 self.assertAllEqual(a_rand, b_rand) 689 690 @test_util.run_in_graph_and_eager_modes 691 def test_callable_evaluate(self): 692 def model(): 693 return resource_variable_ops.ResourceVariable( 694 name="same_name", 695 initial_value=1) + 1 696 with context.eager_mode(): 697 self.assertEqual(2, self.evaluate(model)) 698 699 @test_util.run_in_graph_and_eager_modes 700 def test_nested_tensors_evaluate(self): 701 expected = {"a": 1, "b": 2, "nested": {"d": 3, "e": 4}} 702 nested = {"a": constant_op.constant(1), 703 "b": constant_op.constant(2), 704 "nested": {"d": constant_op.constant(3), 705 "e": constant_op.constant(4)}} 706 707 self.assertEqual(expected, self.evaluate(nested)) 708 709 def test_run_in_graph_and_eager_modes(self): 710 l = [] 711 def inc(self, with_brackets): 712 del self # self argument is required by run_in_graph_and_eager_modes. 
713 mode = "eager" if context.executing_eagerly() else "graph" 714 with_brackets = "with_brackets" if with_brackets else "without_brackets" 715 l.append((with_brackets, mode)) 716 717 f = test_util.run_in_graph_and_eager_modes(inc) 718 f(self, with_brackets=False) 719 f = test_util.run_in_graph_and_eager_modes()(inc) 720 f(self, with_brackets=True) 721 722 self.assertEqual(len(l), 4) 723 self.assertEqual(set(l), { 724 ("with_brackets", "graph"), 725 ("with_brackets", "eager"), 726 ("without_brackets", "graph"), 727 ("without_brackets", "eager"), 728 }) 729 730 def test_get_node_def_from_graph(self): 731 graph_def = graph_pb2.GraphDef() 732 node_foo = graph_def.node.add() 733 node_foo.name = "foo" 734 self.assertIs(test_util.get_node_def_from_graph("foo", graph_def), node_foo) 735 self.assertIsNone(test_util.get_node_def_from_graph("bar", graph_def)) 736 737 def test_run_in_eager_and_graph_modes_test_class(self): 738 msg = "`run_in_graph_and_eager_modes` only supports test methods.*" 739 with self.assertRaisesRegex(ValueError, msg): 740 741 @test_util.run_in_graph_and_eager_modes() 742 class Foo(object): 743 pass 744 del Foo # Make pylint unused happy. 
745 746 def test_run_in_eager_and_graph_modes_skip_graph_runs_eager(self): 747 modes = [] 748 def _test(self): 749 if not context.executing_eagerly(): 750 self.skipTest("Skipping in graph mode") 751 modes.append("eager" if context.executing_eagerly() else "graph") 752 test_util.run_in_graph_and_eager_modes(_test)(self) 753 self.assertEqual(modes, ["eager"]) 754 755 def test_run_in_eager_and_graph_modes_skip_eager_runs_graph(self): 756 modes = [] 757 def _test(self): 758 if context.executing_eagerly(): 759 self.skipTest("Skipping in eager mode") 760 modes.append("eager" if context.executing_eagerly() else "graph") 761 test_util.run_in_graph_and_eager_modes(_test)(self) 762 self.assertEqual(modes, ["graph"]) 763 764 def test_run_in_graph_and_eager_modes_setup_in_same_mode(self): 765 modes = [] 766 mode_name = lambda: "eager" if context.executing_eagerly() else "graph" 767 768 class ExampleTest(test_util.TensorFlowTestCase): 769 770 def runTest(self): 771 pass 772 773 def setUp(self): 774 modes.append("setup_" + mode_name()) 775 776 @test_util.run_in_graph_and_eager_modes 777 def testBody(self): 778 modes.append("run_" + mode_name()) 779 780 e = ExampleTest() 781 e.setUp() 782 e.testBody() 783 784 self.assertEqual(modes[1:2], ["run_graph"]) 785 self.assertEqual(modes[2:], ["setup_eager", "run_eager"]) 786 787 @parameterized.named_parameters(dict(testcase_name="argument", 788 arg=True)) 789 @test_util.run_in_graph_and_eager_modes 790 def test_run_in_graph_and_eager_works_with_parameterized_keyword(self, arg): 791 self.assertEqual(arg, True) 792 793 @combinations.generate(combinations.combine(arg=True)) 794 @test_util.run_in_graph_and_eager_modes 795 def test_run_in_graph_and_eager_works_with_combinations(self, arg): 796 self.assertEqual(arg, True) 797 798 def test_build_as_function_and_v1_graph(self): 799 800 class GraphModeAndFunctionTest(parameterized.TestCase): 801 802 def __init__(inner_self): # pylint: disable=no-self-argument 803 super(GraphModeAndFunctionTest, 
inner_self).__init__() 804 inner_self.graph_mode_tested = False 805 inner_self.inside_function_tested = False 806 807 def runTest(self): 808 del self 809 810 @test_util.build_as_function_and_v1_graph 811 def test_modes(inner_self): # pylint: disable=no-self-argument 812 if ops.inside_function(): 813 self.assertFalse(inner_self.inside_function_tested) 814 inner_self.inside_function_tested = True 815 else: 816 self.assertFalse(inner_self.graph_mode_tested) 817 inner_self.graph_mode_tested = True 818 819 test_object = GraphModeAndFunctionTest() 820 test_object.test_modes_v1_graph() 821 test_object.test_modes_function() 822 self.assertTrue(test_object.graph_mode_tested) 823 self.assertTrue(test_object.inside_function_tested) 824 825 def test_with_forward_compatibility_horizons(self): 826 827 tested_codepaths = set() 828 def some_function_with_forward_compat_behavior(): 829 if compat.forward_compatible(2050, 1, 1): 830 tested_codepaths.add("future") 831 else: 832 tested_codepaths.add("present") 833 834 @test_util.with_forward_compatibility_horizons(None, [2051, 1, 1]) 835 def some_test(self): 836 del self # unused 837 some_function_with_forward_compat_behavior() 838 839 some_test(None) 840 self.assertEqual(tested_codepaths, set(["present", "future"])) 841 842 843class SkipTestTest(test_util.TensorFlowTestCase): 844 845 def _verify_test_in_set_up_or_tear_down(self): 846 with self.assertRaises(unittest.SkipTest): 847 with test_util.skip_if_error(self, ValueError, 848 ["foo bar", "test message"]): 849 raise ValueError("test message") 850 try: 851 with self.assertRaisesRegex(ValueError, "foo bar"): 852 with test_util.skip_if_error(self, ValueError, "test message"): 853 raise ValueError("foo bar") 854 except unittest.SkipTest: 855 raise RuntimeError("Test is not supposed to skip.") 856 857 def setUp(self): 858 super(SkipTestTest, self).setUp() 859 self._verify_test_in_set_up_or_tear_down() 860 861 def tearDown(self): 862 super(SkipTestTest, self).tearDown() 863 
self._verify_test_in_set_up_or_tear_down() 864 865 def test_skip_if_error_should_skip(self): 866 with self.assertRaises(unittest.SkipTest): 867 with test_util.skip_if_error(self, ValueError, "test message"): 868 raise ValueError("test message") 869 870 def test_skip_if_error_should_skip_with_list(self): 871 with self.assertRaises(unittest.SkipTest): 872 with test_util.skip_if_error(self, ValueError, 873 ["foo bar", "test message"]): 874 raise ValueError("test message") 875 876 def test_skip_if_error_should_skip_without_expected_message(self): 877 with self.assertRaises(unittest.SkipTest): 878 with test_util.skip_if_error(self, ValueError): 879 raise ValueError("test message") 880 881 def test_skip_if_error_should_skip_without_error_message(self): 882 with self.assertRaises(unittest.SkipTest): 883 with test_util.skip_if_error(self, ValueError): 884 raise ValueError() 885 886 def test_skip_if_error_should_raise_message_mismatch(self): 887 try: 888 with self.assertRaisesRegex(ValueError, "foo bar"): 889 with test_util.skip_if_error(self, ValueError, "test message"): 890 raise ValueError("foo bar") 891 except unittest.SkipTest: 892 raise RuntimeError("Test is not supposed to skip.") 893 894 def test_skip_if_error_should_raise_no_message(self): 895 try: 896 with self.assertRaisesRegex(ValueError, ""): 897 with test_util.skip_if_error(self, ValueError, "test message"): 898 raise ValueError() 899 except unittest.SkipTest: 900 raise RuntimeError("Test is not supposed to skip.") 901 902 903# Its own test case to reproduce variable sharing issues which only pop up when 904# setUp() is overridden and super() is not called. 
class GraphAndEagerNoVariableSharing(test_util.TensorFlowTestCase):
  """Runs a variable-creating test in both modes without the base setUp()."""

  def setUp(self):
    pass  # Intentionally does not call TensorFlowTestCase's super()

  @test_util.run_in_graph_and_eager_modes
  def test_no_variable_sharing(self):
    # The same variable name is requested in each mode the decorator runs.
    variable_scope.get_variable(
        name="step_size",
        initializer=np.array(1e-5, np.float32),
        use_resource=True,
        trainable=False)


class GarbageCollectionTest(test_util.TensorFlowTestCase):
  """Tests for the leak-detecting decorators in test_util."""

  def test_no_reference_cycle_decorator(self):

    class ReferenceCycleTest(object):

      def __init__(inner_self):  # pylint: disable=no-self-argument
        # NOTE(review): presumably the decorator calls assertEqual on the
        # decorated method's instance -- confirm against test_util.
        inner_self.assertEqual = self.assertEqual  # pylint: disable=invalid-name

      @test_util.assert_no_garbage_created
      def test_has_cycle(self):
        # A list that contains itself forms a reference cycle.
        a = []
        a.append(a)

      @test_util.assert_no_garbage_created
      def test_has_no_cycle(self):
        pass

    with self.assertRaises(AssertionError):
      ReferenceCycleTest().test_has_cycle()

    ReferenceCycleTest().test_has_no_cycle()

  @test_util.run_in_graph_and_eager_modes
  def test_no_leaked_tensor_decorator(self):

    class LeakedTensorTest(object):

      def __init__(inner_self):  # pylint: disable=no-self-argument
        inner_self.assertEqual = self.assertEqual  # pylint: disable=invalid-name

      @test_util.assert_no_new_tensors
      def test_has_leak(self):
        # Storing the tensor on the instance keeps it alive after the call.
        self.a = constant_op.constant([3.], name="leak")

      @test_util.assert_no_new_tensors
      def test_has_no_leak(self):
        constant_op.constant([3.], name="no-leak")

    with self.assertRaisesRegex(AssertionError, "Tensors not deallocated"):
      LeakedTensorTest().test_has_leak()

    LeakedTensorTest().test_has_no_leak()

  def test_no_new_objects_decorator(self):

    class LeakedObjectTest(unittest.TestCase):

      def __init__(self, *args, **kwargs):
        super(LeakedObjectTest, self).__init__(*args, **kwargs)
        self.accumulation = []

      # expectedFailure turns the decorator's failure into a successful run,
      # which is what the outer assertion below checks.
      @unittest.expectedFailure
      @test_util.assert_no_new_pyobjects_executing_eagerly
      def test_has_leak(self):
        self.accumulation.append([1.])

      @test_util.assert_no_new_pyobjects_executing_eagerly
      def test_has_no_leak(self):
        self.not_accumulating = [1.]

    self.assertTrue(LeakedObjectTest("test_has_leak").run().wasSuccessful())
    self.assertTrue(LeakedObjectTest("test_has_no_leak").run().wasSuccessful())


class RunFunctionsEagerlyInV2Test(test_util.TensorFlowTestCase,
                                  parameterized.TestCase):
  """Tests for the `test_util.run_functions_eagerly` context manager."""

  @parameterized.named_parameters(
      [("_RunEagerly", True), ("_RunGraph", False)])
  def test_run_functions_eagerly(self, run_eagerly):  # pylint: disable=g-wrong-blank-lines
    # Collects every intermediate value produced inside the tf.function so
    # that its type can be inspected after the call.
    results = []

    @def_function.function
    def add_two(x):
      for _ in range(5):
        x += 2
        results.append(x)
      return x

    with test_util.run_functions_eagerly(run_eagerly):
      add_two(constant_op.constant(2.))
      # Eager tensors are only expected when the test itself runs eagerly AND
      # functions are forced to run eagerly; otherwise any Tensor will do.
      if context.executing_eagerly() and run_eagerly:
        expected_type = ops.EagerTensor
      else:
        expected_type = ops.Tensor
      # Guard against a vacuous pass if the function body never ran.
      self.assertTrue(results)
      # The previous code passed a generator expression to assertTrue, which
      # is always truthy and therefore never checked anything. Check each
      # collected value individually instead.
      for t in results:
        self.assertIsInstance(t, expected_type)


if __name__ == "__main__":
  googletest.main()