1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for tensorflow.python.framework.ops.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from absl.testing import parameterized 22import gc 23import numpy as np 24import os 25import threading 26import weakref 27 28from tensorflow.core.framework import attr_value_pb2 29from tensorflow.core.protobuf import config_pb2 30from tensorflow.python.autograph.core import ag_ctx 31from tensorflow.python.client import session 32from tensorflow.python.eager import backprop 33from tensorflow.python.eager import context 34from tensorflow.python.eager import def_function 35from tensorflow.python.eager import function as eager_function 36from tensorflow.python.eager import wrap_function 37from tensorflow.python.framework import config 38from tensorflow.python.framework import composite_tensor 39from tensorflow.python.framework import constant_op 40from tensorflow.python.framework import device as pydev 41from tensorflow.python.framework import dtypes 42from tensorflow.python.framework import errors 43from tensorflow.python.framework import function 44from tensorflow.python.framework import indexed_slices 45from tensorflow.python.framework import ops 46from tensorflow.python.framework import sparse_tensor 47from tensorflow.python.framework import tensor_shape 48from tensorflow.python.framework import tensor_spec 49from tensorflow.python.framework import tensor_util 50from tensorflow.python.framework import test_ops 51from tensorflow.python.framework import test_util 52from tensorflow.python.framework import type_spec 53from tensorflow.python.framework import versions 54from tensorflow.python.ops import array_ops 55from tensorflow.python.ops import control_flow_ops 56from tensorflow.python.ops import math_ops 57from tensorflow.python.ops import resource_variable_ops 58from tensorflow.python.ops import resources 59from tensorflow.python.ops import special_math_ops 60from tensorflow.python.ops import variable_scope 61from tensorflow.python.ops import variables 62import tensorflow.python.ops.gradients # pylint: disable=unused-import 63from tensorflow.python.platform import googletest 64from tensorflow.python.util import compat 65 66 67class ResourceTest(test_util.TensorFlowTestCase): 68 69 @test_util.run_deprecated_v1 70 def testBuildGraph(self): 71 with self.cached_session(): 72 pt = test_ops.stub_resource_handle_op(container="a", shared_name="b") 73 test_ops.resource_create_op(pt).run() 74 75 @test_util.run_deprecated_v1 76 def testInitialize(self): 77 with self.cached_session(): 78 handle = test_ops.stub_resource_handle_op(container="a", shared_name="b") 79 resources.register_resource( 80 handle=handle, 81 create_op=test_ops.resource_create_op(handle), 82 is_initialized_op=test_ops.resource_initialized_op(handle)) 83 self.assertEqual( 84 len( 85 resources.report_uninitialized_resources( 86 resources.shared_resources()).eval()), 1) 87 resources.initialize_resources(resources.shared_resources()).run() 88 self.assertEqual( 89 len( 90 resources.report_uninitialized_resources( 91 resources.shared_resources()).eval()), 0) 92 93 94class TensorAndShapeTest(test_util.TensorFlowTestCase): 95 96 def testShape(self): 97 op = ops.Operation( 98 ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32]) 99 t = op.outputs[0] 100 self.assertEqual(tensor_shape.unknown_shape(), t.get_shape()) 101 t.set_shape([1, 2, 3]) 102 self.assertEqual([1, 2, 3], t.get_shape()) 103 104 def testIterable(self): 105 if not context.executing_eagerly(): 106 self.skipTest("Eager-mode test") 107 op = ops.Operation( 108 ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32]) 109 t = op.outputs[0] 110 with self.assertRaisesRegex(TypeError, "Cannot iterate"): 111 iter(t) 112 113 def testIterableGraph(self): 114 if context.executing_eagerly(): 115 self.skipTest("Graph-mode test") 116 117 op = ops.Operation( 118 ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32]) 119 t = op.outputs[0] 120 with self.assertRaisesRegex(TypeError, "iterating.*not allowed in Graph"): 121 next(iter(t)) 122 with self.assertRaisesRegex(TypeError, "iterating.*AutoGraph did convert"): 123 with ag_ctx.ControlStatusCtx(ag_ctx.Status.ENABLED): 124 next(iter(t)) 125 with self.assertRaisesRegex(TypeError, "iterating.*AutoGraph is disabled"): 126 with ag_ctx.ControlStatusCtx(ag_ctx.Status.DISABLED): 127 next(iter(t)) 128 129 def testImplicitBool(self): 130 op = ops.Operation( 131 ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.bool]) 132 t = op.outputs[0] 133 with self.assertRaisesRegex(TypeError, 134 "using.*as a.*bool.*not allowed in Graph"): 135 bool(t) 136 with self.assertRaisesRegex(TypeError, 137 "using.*as a.*bool.*AutoGraph did convert"): 138 with ag_ctx.ControlStatusCtx(ag_ctx.Status.ENABLED): 139 bool(t) 140 with self.assertRaisesRegex(TypeError, 141 "using.*as a.*bool.*AutoGraph is disabled"): 142 with ag_ctx.ControlStatusCtx(ag_ctx.Status.DISABLED): 143 bool(t) 144 145 def testAddShape(self): 146 with self.cached_session(): 147 a = array_ops.zeros([2, 3]) 148 b = array_ops.ones([1, 3]) 149 c = a + b 150 self.assertEqual([2, 3], c.shape) 151 152 @test_util.run_deprecated_v1 153 def testUnknownDim(self): 154 with self.cached_session(): 155 a = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3]) 156 b = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3]) 157 c = a + b 158 self.assertEqual([2, None, 3], c.shape.as_list()) 159 160 @test_util.run_deprecated_v1 161 def testUnknownShape(self): 162 with self.cached_session(): 163 a = array_ops.placeholder(dtype=dtypes.float32, shape=None) 164 b = array_ops.ones([1, 3]) 165 c = a + b 166 self.assertEqual(tensor_shape.unknown_shape(), c.shape) 167 168 @test_util.run_deprecated_v1 169 def testScalarShape(self): 170 with self.cached_session(): 171 a = array_ops.placeholder(dtype=dtypes.float32, shape=[]) 172 b = array_ops.ones([]) 173 c = a + b 174 self.assertEqual(tensor_shape.TensorShape([]), c.shape) 175 176 @test_util.run_deprecated_v1 177 def testShapeFunctionError(self): 178 with self.cached_session(): 179 a = array_ops.ones([1, 2, 3]) 180 b = array_ops.ones([4, 5, 6]) 181 with self.assertRaisesRegex( 182 ValueError, r"Dimensions must be equal, but are 2 and 5 for .*add" 183 r".*Add(V2)?.* with input shapes: \[1,2,3\], \[4,5,6\]."): 184 _ = a + b 185 186 def testNumpyArray(self): 187 with ops.Graph().as_default(): 188 x = array_ops.ones((3, 4), name="test_ones") 189 190 with self.assertRaisesRegex(NotImplementedError, 191 r"Cannot convert a symbolic.+test_ones"): 192 np.array(x) 193 194 with self.assertRaisesRegex(TypeError, "not well defined.+test_ones"): 195 len(x) 196 197 # EagerTensors should still behave as numpy arrays. 198 with context.eager_mode(): 199 x = array_ops.ones((3, 4)) 200 201 self.assertAllEqual(x, np.ones((3, 4))) 202 self.assertAllEqual(np.array(x), np.ones((3, 4))) 203 self.assertEqual(len(x), 3) 204 205 def testConstructor(self): 206 a = array_ops.ones([]) 207 for name in ["T", "astype", "ravel", "transpose", "reshape", "clip", "size", 208 "tolist", "data"]: 209 with self.assertRaisesRegex( 210 AttributeError, r"If you are looking for numpy-related methods"): 211 getattr(a, name) 212 with self.assertRaisesRegex( 213 AttributeError, r"object has no attribute"): 214 a.foo_bar() 215 216 def testRef(self): 217 x1 = constant_op.constant(3) 218 x2 = x1 219 y = constant_op.constant(3) 220 z = constant_op.constant([6, 10]) 221 w = variables.Variable(5) 222 223 self.assertEqual(x1.ref(), x1.ref()) 224 self.assertEqual(x2.ref(), x2.ref()) 225 self.assertEqual(x1.ref(), x2.ref()) 226 self.assertEqual(y.ref(), y.ref()) 227 self.assertEqual(z.ref(), z.ref()) 228 self.assertEqual(w.ref(), w.ref()) 229 230 self.assertNotEqual(x1.ref(), y.ref()) 231 self.assertNotEqual(x1.ref(), z.ref()) 232 self.assertNotEqual(x1.ref(), w.ref()) 233 self.assertNotEqual(y.ref(), z.ref()) 234 self.assertNotEqual(y.ref(), w.ref()) 235 self.assertNotEqual(z.ref(), w.ref()) 236 237 def testRefDeref(self): 238 x1 = constant_op.constant(3) 239 x2 = x1 240 y = constant_op.constant(3) 241 z = constant_op.constant([6, 10]) 242 w = variables.Variable(5) 243 244 self.assertIs(x1, x1.ref().deref()) 245 self.assertIs(x2, x2.ref().deref()) 246 self.assertIs(x1, x2.ref().deref()) 247 self.assertIs(x2, x1.ref().deref()) 248 self.assertIs(y, y.ref().deref()) 249 self.assertIs(z, z.ref().deref()) 250 251 self.assertIsNot(x1, y.ref().deref()) 252 self.assertIsNot(x1, z.ref().deref()) 253 self.assertIsNot(x1, w.ref().deref()) 254 self.assertIsNot(y, z.ref().deref()) 255 self.assertIsNot(y, w.ref().deref()) 256 self.assertIsNot(z, w.ref().deref()) 257 258 def testRefInSet(self): 259 x1 = constant_op.constant(3) 260 x2 = x1 261 y = constant_op.constant(3) 262 z = constant_op.constant([6, 10]) 263 w = variables.Variable(5) 264 265 self.assertEqual(x1.ref(), x2.ref()) 266 267 tensor_set = { 268 x1.ref(), 269 x2.ref(), 270 y.ref(), 271 z.ref(), 272 w.ref(), 273 } 274 275 self.assertEqual(len(tensor_set), 4) 276 self.assertIn(x1.ref(), tensor_set) 277 self.assertIn(x2.ref(), tensor_set) 278 self.assertIn(y.ref(), tensor_set) 279 self.assertIn(z.ref(), tensor_set) 280 self.assertIn(w.ref(), tensor_set) 281 282 def testRefInDict(self): 283 x1 = constant_op.constant(3) 284 x2 = x1 285 y = constant_op.constant(3) 286 z = constant_op.constant([6, 10]) 287 w = variables.Variable(5) 288 289 self.assertEqual(x1.ref(), x2.ref()) 290 291 tensor_dict = { 292 x1.ref(): "x1", 293 y.ref(): "y", 294 z.ref(): "z", 295 w.ref(): "w", 296 } 297 298 self.assertEqual(len(tensor_dict), 4) 299 300 # Overwriting x1 301 tensor_dict[x2.ref()] = "x2" 302 self.assertEqual(len(tensor_dict), 4) 303 304 self.assertEqual(tensor_dict[x1.ref()], "x2") 305 self.assertEqual(tensor_dict[x2.ref()], "x2") 306 self.assertEqual(tensor_dict[y.ref()], "y") 307 self.assertEqual(tensor_dict[z.ref()], "z") 308 self.assertEqual(tensor_dict[w.ref()], "w") 309 310 def testTensorRefStrong(self): 311 x = constant_op.constant(1.) 312 x_ref = x.ref() 313 del x 314 self.assertIsNotNone(x_ref.deref()) 315 316 def testVariableRefStrong(self): 317 x = variables.Variable(1.) 318 x_ref = x.ref() 319 del x 320 self.assertIsNotNone(x_ref.deref()) 321 322 @test_util.run_in_graph_and_eager_modes 323 def testBitwiseAndNumeric(self): 324 x = constant_op.constant([0, 1, 3]) 325 y = constant_op.constant([1, 1, 1]) 326 327 z = x & y 328 329 self.assertAllEqual(z, [0, 1, 1]) 330 331 @test_util.run_in_graph_and_eager_modes 332 def testBitwiseAndBool(self): 333 x = constant_op.constant([False, False, True, True]) 334 y = constant_op.constant([False, True, False, True]) 335 336 z = x & y 337 338 self.assertAllEqual(z, [False, False, False, True]) 339 340 @test_util.run_in_graph_and_eager_modes 341 def testBitwiseAndErrors(self): 342 x_int = constant_op.constant(0) 343 x_bool = constant_op.constant(True) 344 345 if context.executing_eagerly(): # :( 346 expected_errtype = errors.InvalidArgumentError 347 else: 348 expected_errtype = TypeError 349 350 with self.assertRaises(expected_errtype): 351 _ = x_int & x_bool 352 with self.assertRaises(expected_errtype): 353 _ = x_int & constant_op.constant("a") 354 355 with self.assertRaises(expected_errtype): 356 _ = x_bool & x_int 357 with self.assertRaises(expected_errtype): 358 _ = x_bool & constant_op.constant("a") 359 360 with self.assertRaises(expected_errtype): 361 _ = constant_op.constant("a") & constant_op.constant("b") 362 363 @test_util.run_in_graph_and_eager_modes 364 def testBitwiseOrNumeric(self): 365 x = constant_op.constant([0, 1, 2]) 366 y = constant_op.constant([1, 1, 1]) 367 368 z = x | y 369 370 self.assertAllEqual(z, [1, 1, 3]) 371 372 @test_util.run_in_graph_and_eager_modes 373 def testBitwiseOrBool(self): 374 x = constant_op.constant([False, False, True, True]) 375 y = constant_op.constant([False, True, False, True]) 376 377 z = x | y 378 379 self.assertAllEqual(z, [False, True, True, True]) 380 381 @test_util.run_in_graph_and_eager_modes 382 def testBitwiseOrErrors(self): 383 x_int = constant_op.constant(0) 384 x_bool = constant_op.constant(True) 385 386 if context.executing_eagerly(): # :( 387 expected_errtype = errors.InvalidArgumentError 388 else: 389 expected_errtype = TypeError 390 391 with self.assertRaises(expected_errtype): 392 _ = x_int | x_bool 393 with self.assertRaises(expected_errtype): 394 _ = x_int | constant_op.constant("a") 395 396 with self.assertRaises(expected_errtype): 397 _ = x_bool | x_int 398 with self.assertRaises(expected_errtype): 399 _ = x_bool | constant_op.constant("a") 400 401 with self.assertRaises(expected_errtype): 402 _ = constant_op.constant("a") | constant_op.constant("b") 403 404 @test_util.run_in_graph_and_eager_modes 405 def testBitwiseXorNumeric(self): 406 x = constant_op.constant([0, 1, 3]) 407 y = constant_op.constant([1, 1, 1]) 408 409 z = x ^ y 410 411 self.assertAllEqual(z, [1, 0, 2]) 412 413 @test_util.run_in_graph_and_eager_modes 414 def testBitwiseXorBool(self): 415 x = constant_op.constant([False, False, True, True]) 416 y = constant_op.constant([False, True, False, True]) 417 418 z = x ^ y 419 420 self.assertAllEqual(z, [False, True, True, False]) 421 422 @test_util.run_in_graph_and_eager_modes 423 def testBitwiseXorErrors(self): 424 x_int = constant_op.constant(0) 425 x_bool = constant_op.constant(True) 426 427 if context.executing_eagerly(): # :( 428 expected_errtype = errors.InvalidArgumentError 429 else: 430 expected_errtype = TypeError 431 432 with self.assertRaises(expected_errtype): 433 _ = x_int ^ x_bool 434 with self.assertRaises(expected_errtype): 435 _ = x_int ^ constant_op.constant("a") 436 437 with self.assertRaises(expected_errtype): 438 _ = x_bool ^ x_int 439 with self.assertRaises(expected_errtype): 440 _ = x_bool ^ constant_op.constant("a") 441 442 with self.assertRaises(expected_errtype): 443 _ = constant_op.constant("a") ^ constant_op.constant("b") 444 445 @test_util.run_in_graph_and_eager_modes 446 def testBitwiseNotNumeric(self): 447 x = constant_op.constant([0, dtypes.int32.min, 1]) 448 449 y = ~x 450 451 self.assertAllEqual(y, [-1, dtypes.int32.max, -2]) 452 453 @test_util.run_in_graph_and_eager_modes 454 def testBitwiseNotBool(self): 455 x = constant_op.constant([False, True]) 456 457 y = ~x 458 459 self.assertAllEqual(y, [True, False]) 460 461 @test_util.run_in_graph_and_eager_modes 462 def testBitwiseNotErrors(self): 463 if context.executing_eagerly(): # :( 464 expected_errtype = errors.InvalidArgumentError 465 else: 466 expected_errtype = TypeError 467 468 with self.assertRaises(expected_errtype): 469 _ = ~constant_op.constant("a") 470 471 472@test_util.run_all_in_graph_and_eager_modes 473class IndexedSlicesTest(test_util.TensorFlowTestCase): 474 475 def testToTensor(self): 476 values = constant_op.constant([2, 3, 5, 7], shape=[2, 2]) 477 indices = constant_op.constant([0, 2]) 478 x = ops.IndexedSlices(values, indices) 479 with self.assertRaises(ValueError): 480 tensor = ops.convert_to_tensor(x, name="tensor") 481 self.assertEqual(tensor_shape.TensorShape(None), x.shape) 482 483 dense_shape = constant_op.constant([3, 2]) 484 y = ops.IndexedSlices(values, indices, dense_shape) 485 tensor = ops.convert_to_tensor(y, name="tensor") 486 self.assertAllEqual(tensor.shape, y.shape) 487 self.assertAllEqual(self.evaluate(tensor), [[2, 3], [0, 0], [5, 7]]) 488 489 @test_util.run_gpu_only 490 def testEagerCopy(self): 491 with context.eager_mode(): 492 var = variables.Variable([[0.0], [0.0], [0.0], [0.0]], name="tensor") 493 with backprop.GradientTape() as tape: 494 a = array_ops.gather(array_ops.gather(var, [0, 1]), [0, 1]) 495 b = array_ops.gather(array_ops.gather(var, [2, 3]), [0, 1]) 496 r = special_math_ops.einsum("ij,ij->i", a, b) 497 g = tape.gradient(r, [var])[0] 498 values = g.values if isinstance(g, ops.IndexedSlices) else g 499 self.assertAllEqual(values.get_shape(), [4, 1]) 500 501 def testNegation(self): 502 values = constant_op.constant([2, 3, 5, 7], shape=[2, 2]) 503 indices = constant_op.constant([0, 2]) 504 x = -ops.IndexedSlices(values, indices) 505 self.assertAllEqual(x.values, [[-2, -3], [-5, -7]]) 506 self.assertAllEqual(x.indices, [0, 2]) 507 508 def testScalarMul(self): 509 values = constant_op.constant([2, 3, 5, 7], shape=[2, 2]) 510 indices = constant_op.constant([0, 2]) 511 x = math_ops.scalar_mul(-2, ops.IndexedSlices(values, indices)) 512 self.assertAllEqual(x.values, [[-4, -6], [-10, -14]]) 513 self.assertAllEqual(x.indices, [0, 2]) 514 515 516@test_util.run_all_in_graph_and_eager_modes 517class IndexedSlicesSpecTest(test_util.TensorFlowTestCase, 518 parameterized.TestCase): 519 520 def assertAllTensorsEqual(self, list1, list2): 521 self.assertLen(list1, len(list2)) 522 for (t1, t2) in zip(list1, list2): 523 self.assertAllEqual(t1, t2) 524 525 def testConstruction(self): 526 spec1 = indexed_slices.IndexedSlicesSpec() 527 self.assertEqual(spec1._shape.rank, None) 528 self.assertEqual(spec1._values_dtype, dtypes.float32) 529 self.assertEqual(spec1._indices_dtype, dtypes.int64) 530 self.assertEqual(spec1._dense_shape_dtype, None) 531 self.assertEqual(spec1._indices_shape.as_list(), [None]) 532 533 spec2 = indexed_slices.IndexedSlicesSpec([None, None], dtypes.string, 534 dtypes.int32, dtypes.int64, [10]) 535 self.assertEqual(spec2._shape.as_list(), [None, None]) 536 self.assertEqual(spec2._values_dtype, dtypes.string) 537 self.assertEqual(spec2._indices_dtype, dtypes.int32) 538 self.assertEqual(spec2._dense_shape_dtype, dtypes.int64) 539 self.assertEqual(spec2._indices_shape.as_list(), [10]) 540 541 def testValueType(self): 542 spec1 = indexed_slices.IndexedSlicesSpec() 543 self.assertEqual(spec1.value_type, ops.IndexedSlices) 544 545 @parameterized.parameters([ 546 (indexed_slices.IndexedSlicesSpec(), 547 (tensor_shape.TensorShape(None), dtypes.float32, dtypes.int64, None, 548 tensor_shape.TensorShape([None]))), 549 (indexed_slices.IndexedSlicesSpec(shape=[5, None, None]), 550 (tensor_shape.TensorShape([5, None, None]), dtypes.float32, 551 dtypes.int64, None, tensor_shape.TensorShape([None]))), 552 (indexed_slices.IndexedSlicesSpec( 553 dtype=dtypes.int32, dense_shape_dtype=dtypes.int64), 554 (tensor_shape.TensorShape(None), dtypes.int32, dtypes.int64, 555 dtypes.int64, tensor_shape.TensorShape([None]))), 556 (indexed_slices.IndexedSlicesSpec(indices_shape=[100]), 557 (tensor_shape.TensorShape(None), dtypes.float32, dtypes.int64, None, 558 tensor_shape.TensorShape([100]))), 559 ]) # pyformat: disable 560 def testSerialize(self, spec, expected): 561 serialization = spec._serialize() 562 # TensorShape has an unconventional definition of equality, so we can't use 563 # assertEqual directly here. But repr() is deterministic and lossless for 564 # the expected values, so we can use that instead. 565 self.assertEqual(repr(serialization), repr(expected)) 566 567 @parameterized.parameters([ 568 (indexed_slices.IndexedSlicesSpec(dtype=dtypes.string), ( 569 tensor_spec.TensorSpec(None, dtypes.string), 570 tensor_spec.TensorSpec([None], dtypes.int64), 571 )), 572 (indexed_slices.IndexedSlicesSpec( 573 dtype=dtypes.string, dense_shape_dtype=dtypes.int32), ( 574 tensor_spec.TensorSpec(None, dtypes.string), 575 tensor_spec.TensorSpec([None], dtypes.int64), 576 tensor_spec.TensorSpec([None], dtypes.int32), 577 )), 578 (indexed_slices.IndexedSlicesSpec( 579 shape=[5, 10, 15], dense_shape_dtype=dtypes.int32), ( 580 tensor_spec.TensorSpec([None, 10, 15], dtypes.float32), 581 tensor_spec.TensorSpec([None], dtypes.int64), 582 tensor_spec.TensorSpec([3], dtypes.int32), 583 )), 584 (indexed_slices.IndexedSlicesSpec( 585 shape=[5, 10, 15], dense_shape_dtype=dtypes.int32, 586 indices_shape=[20]), ( 587 tensor_spec.TensorSpec([20, 10, 15], dtypes.float32), 588 tensor_spec.TensorSpec([20], dtypes.int64), 589 tensor_spec.TensorSpec([3], dtypes.int32), 590 )), 591 ]) 592 def testComponentSpecs(self, spec, expected): 593 self.assertEqual(spec._component_specs, expected) 594 595 @parameterized.parameters([ 596 { 597 "spec": indexed_slices.IndexedSlicesSpec(), 598 "values": [3.0, 5.0], 599 "indices": [5, 10] 600 }, 601 { 602 "spec": 603 indexed_slices.IndexedSlicesSpec(dense_shape_dtype=dtypes.int32), 604 "values": [3.0, 5.0], 605 "indices": [5, 10], 606 "dense_shape": [100] 607 }, 608 ]) 609 def testToFromComponents(self, spec, indices, values, dense_shape=None): 610 x = ops.IndexedSlices(indices, values, dense_shape) 611 actual_components = spec._to_components(x) 612 if dense_shape is None: 613 self.assertAllTensorsEqual(actual_components, [indices, values]) 614 else: 615 self.assertAllTensorsEqual(actual_components, 616 [indices, values, dense_shape]) 617 st_reconstructed = spec._from_components(actual_components) 618 self.assertAllEqual(x.indices, st_reconstructed.indices) 619 self.assertAllEqual(x.values, st_reconstructed.values) 620 if dense_shape is None: 621 self.assertIs(st_reconstructed.dense_shape, None) 622 else: 623 self.assertAllEqual(x.dense_shape, st_reconstructed.dense_shape) 624 625 @test_util.run_v1_only("IndexedSlicesValue is deprecated in v2") 626 def testFromNumpyComponents(self): 627 indices = np.array([3, 8]) 628 values = np.array([1.0, 9.0]) 629 dense_shape = np.array([100]) 630 631 spec1 = indexed_slices.IndexedSlicesSpec(dense_shape_dtype=dtypes.int32) 632 st1 = spec1._from_components((values, indices, dense_shape)) 633 self.assertIsInstance(st1, indexed_slices.IndexedSlicesValue) 634 self.assertAllEqual(st1.indices, indices) 635 self.assertAllEqual(st1.values, values) 636 self.assertAllEqual(st1.dense_shape, dense_shape) 637 638 spec2 = indexed_slices.IndexedSlicesSpec() 639 st2 = spec2._from_components((values, indices)) 640 self.assertIsInstance(st2, indexed_slices.IndexedSlicesValue) 641 self.assertAllEqual(st2.indices, indices) 642 self.assertAllEqual(st2.values, values) 643 self.assertIs(st2.dense_shape, None) 644 645 646class NodeDefConstructorTest(test_util.TensorFlowTestCase): 647 648 def testNoArgs(self): 649 nodedef = ops._NodeDef("None", "bar") 650 self.assertProtoEquals("op: 'None' name: 'bar'", nodedef) 651 652 653def _apply_op(g, *args, **kwargs): 654 op = g.create_op(*args, **kwargs) 655 if len(op.outputs) == 1: 656 return op.outputs[0] 657 else: 658 return op.outputs 659 660 661class OperationTest(test_util.TensorFlowTestCase): 662 663 @test_util.run_deprecated_v1 664 def testNoInputs(self): 665 op = test_ops.float_output_string_output(name="myop").a.op 666 self.assertEqual(2, len(op.values())) 667 self.assertEqual(0, len(op.inputs)) 668 self.assertEqual("myop", op.name) 669 670 float_t, label_str_t = op.values() 671 self.assertEqual(dtypes.float32, float_t.dtype) 672 self.assertEqual(op, float_t.op) 673 self.assertEqual(0, float_t._value_index) 674 self.assertEqual(0, len(float_t.consumers())) 675 self.assertEqual("myop", float_t._as_node_def_input()) 676 677 self.assertEqual(dtypes.string, label_str_t.dtype) 678 self.assertEqual(op, label_str_t.op) 679 self.assertEqual(1, label_str_t._value_index) 680 self.assertEqual(0, len(label_str_t.consumers())) 681 self.assertEqual("myop:1", label_str_t._as_node_def_input()) 682 683 self.assertProtoEquals("op:'FloatOutputStringOutput' name:'myop'", 684 op.node_def) 685 686 @test_util.run_deprecated_v1 687 def testNoOutputs(self): 688 op1 = test_ops.float_output(name="myop1").op 689 float_t, = op1.values() 690 op2 = test_ops.float_input(float_t, name="myop2") 691 self.assertEqual(0, len(op2.values())) 692 self.assertEqual(1, len(op2.inputs)) 693 self.assertIs(float_t, op2.inputs[0]) 694 695 self.assertEqual(1, len(float_t.consumers())) 696 self.assertEqual(op2, float_t.consumers()[0]) 697 698 self.assertProtoEquals("op:'FloatOutput' name:'myop1'", op1.node_def) 699 self.assertProtoEquals("op:'FloatInput' name:'myop2' input:'myop1'", 700 op2.node_def) 701 702 @test_util.run_deprecated_v1 703 def testInputsAndOutputs(self): 704 op1 = test_ops.float_output(name="myop1").op 705 self.assertEqual(1, len(op1.values())) 706 float1_t, = op1.values() 707 708 op2 = test_ops.float_output_string_output(name="myop2").a.op 709 self.assertEqual(2, len(op2.values())) 710 float2_t, label2_str_t = op2.values() 711 712 # Note that we consume label2_str_t twice here. 713 op3 = test_ops.foo2(float1_t, label2_str_t, label2_str_t, name="myop3").d.op 714 self.assertEqual(2, len(op3.values())) 715 716 self.assertEqual(1, len(float1_t.consumers())) 717 self.assertEqual(op3, float1_t.consumers()[0]) 718 719 self.assertEqual(0, len(float2_t.consumers())) 720 721 self.assertEqual(2, len(label2_str_t.consumers())) 722 self.assertEqual(op3, label2_str_t.consumers()[0]) 723 self.assertEqual(op3, label2_str_t.consumers()[1]) 724 725 self.assertProtoEquals(""" 726 op:'Foo2' name:'myop3' 727 input:'myop1' input:'myop2:1' input:'myop2:1' 728 """, op3.node_def) 729 730 def testDeviceObject(self): 731 op = ops.Operation(ops._NodeDef("None", "myop"), ops.Graph(), [], []) 732 op._set_device("/job:goo/device:GPU:0") 733 self.assertProtoEquals( 734 "op:'None' name:'myop' device:'/job:goo/device:GPU:0' ", op.node_def) 735 op = ops.Operation(ops._NodeDef("None", "op2"), ops.Graph(), [], []) 736 op._set_device( 737 pydev.DeviceSpec( 738 job="muu", device_type="CPU", device_index=0)) 739 self.assertProtoEquals( 740 "op:'None' name:'op2' device:'/job:muu/device:CPU:0'", op.node_def) 741 742 def testReferenceInput(self): 743 g = ops.Graph() 744 op1 = ops.Operation( 745 ops._NodeDef("RefOutputFloatOutput", "op1"), g, [], 746 [dtypes.float32_ref, dtypes.float32]) 747 self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def) 748 self.assertEqual([], list(op1.inputs)) 749 ref_t, nonref_t = op1.values() 750 # NOTE(mrry): Must specify input_types to preserve ref-typed input. 751 op2 = ops.Operation( 752 ops._NodeDef("RefInputFloatInput", "op2"), 753 g, [ref_t, nonref_t], [], 754 input_types=[dtypes.float32_ref, dtypes.float32]) 755 self.assertProtoEquals( 756 "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'", 757 op2.node_def) 758 self.assertEqual([ref_t, nonref_t], list(op2.inputs)) 759 op3 = ops.Operation( 760 ops._NodeDef("TwoFloatInputs", "op3"), g, [ref_t, nonref_t], []) 761 self.assertProtoEquals( 762 "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'", 763 op3.node_def) 764 765 def testInvalidNames(self): 766 g = ops.Graph() 767 with self.assertRaises(ValueError): 768 ops.Operation(ops._NodeDef("op", ""), g) 769 with self.assertRaises(ValueError): 770 ops.Operation(ops._NodeDef("op", "_invalid"), g) 771 with self.assertRaises(ValueError): 772 ops.Operation(ops._NodeDef("op", "-invalid"), g) 773 with self.assertRaises(ValueError): 774 ops.Operation(ops._NodeDef("op", "/invalid"), g) 775 with self.assertRaises(ValueError): 776 ops.Operation(ops._NodeDef("op", "invalid:0"), g) 777 778 @test_util.run_deprecated_v1 779 def testNoShapeFunction(self): 780 op = test_ops.a() 781 self.assertEqual(tensor_shape.unknown_shape(), op.get_shape()) 782 783 @test_util.run_in_graph_and_eager_modes 784 def testConvertToTensorNestedArray(self): 785 values = [[2], [3], [5], [7]] 786 tensor = ops.convert_to_tensor(values) 787 self.assertAllEqual((4, 1), tensor.get_shape().as_list()) 788 self.assertAllEqual(values, self.evaluate(tensor)) 789 790 def testShapeTuple(self): 791 with self.cached_session(): 792 c = constant_op.constant(1) 793 self.assertEqual(c._shape_tuple(), ()) # pylint: disable=protected-access 794 795 def testConvertToTensorEager(self): 796 with context.eager_mode(): 797 t = constant_op.constant(1) 798 self.assertTrue(isinstance(t, ops.EagerTensor)) 799 converted = ops.convert_to_tensor(t) 800 self.assertTrue(isinstance(converted, ops.EagerTensor)) 801 converted = ops.convert_to_tensor(1) 802 self.assertTrue(isinstance(converted, ops.EagerTensor)) 803 804 @test_util.run_in_graph_and_eager_modes 805 def testConvertToTensorNestedTuple(self): 806 values = ((2,), (3,), (5,), (7,)) 807 tensor = ops.convert_to_tensor(values) 808 self.assertAllEqual((4, 1), tensor.get_shape().as_list()) 809 self.assertAllEqual(values, self.evaluate(ops.convert_to_tensor(values))) 810 811 @test_util.run_in_graph_and_eager_modes 812 def testConvertToTensorNestedTensors(self): 813 values = ((2,), (3,), (5,), (7,)) 814 tensor = ops.convert_to_tensor( 815 [constant_op.constant(row) for row in values]) 816 self.assertAllEqual((4, 1), tensor.get_shape().as_list()) 817 self.assertAllEqual(values, self.evaluate(tensor)) 818 tensor = ops.convert_to_tensor( 819 [[constant_op.constant(v) for v in row] for row in values]) 820 self.assertAllEqual((4, 1), tensor.get_shape().as_list()) 821 self.assertAllEqual(values, self.evaluate(tensor)) 822 823 @test_util.run_in_graph_and_eager_modes 824 def testConvertToTensorNestedMix(self): 825 values = ([2], (3,), [constant_op.constant(5)], constant_op.constant([7])) 826 tensor = ops.convert_to_tensor(values) 827 self.assertAllEqual((4, 1), tensor.get_shape().as_list()) 828 self.assertAllEqual(((2,), (3,), (5,), (7,)), self.evaluate(tensor)) 829 830 @test_util.run_in_graph_and_eager_modes 831 def testConvertToTensorPreferred(self): 832 values = [2, 3, 5, 7] 833 tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.float32) 834 self.assertEqual(dtypes.float32, tensor.dtype) 835 836 # Convert empty tensor to anything. 837 values = [] 838 tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64) 839 self.assertEqual(dtypes.int64, tensor.dtype) 840 841 # The preferred dtype is a type error and will convert to 842 # float32 instead. 843 values = [1.23] 844 tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64) 845 self.assertEqual(dtypes.float32, tensor.dtype) 846 847 @test_util.run_in_graph_and_eager_modes 848 def testConvertToInvalidTensorType(self): 849 with self.assertRaises(TypeError): 850 # Forcing an invalid dtype should fail with a type error. 851 values = [1.23] 852 ops.convert_to_tensor(values, dtype=dtypes.int64) 853 854 @test_util.run_in_graph_and_eager_modes 855 def testConvertToLongLongTensorType(self): 856 tensor = ops.convert_to_tensor( 857 # Get a numpy array of dtype NPY_LONGLONG 858 np.prod(constant_op.constant([1])._shape_tuple()), 859 dtype=dtypes.int64) 860 self.assertEqual(dtypes.int64, tensor.dtype) 861 862 @test_util.run_in_graph_and_eager_modes 863 def testConvertToTensorFromInvalidTensor(self): 864 tensor = constant_op.constant(42.0, dtype=dtypes.float32) 865 with self.assertRaises(ValueError): 866 ops.convert_to_tensor(tensor, dtype=dtypes.int32) 867 868 @test_util.run_in_graph_and_eager_modes 869 def testConvertToTensorProtocol(self): 870 class TensorCompatible: 871 872 def __tf_tensor__(self, dtype=None, name=None): 873 return constant_op.constant((1, 2, 3), dtype=dtype, name=name) 874 875 tc = TensorCompatible() 876 877 tensor = ops.convert_to_tensor(tc, dtype=dtypes.int32) 878 self.assertEqual(tensor.dtype, dtypes.int32) 879 self.assertAllEqual((1, 2, 3), self.evaluate(tensor)) 880 881 @test_util.run_deprecated_v1 882 def testNoConvert(self): 883 # Operation cannot be converted to Tensor. 884 op = control_flow_ops.no_op() 885 with self.assertRaisesRegex(TypeError, 886 "can't convert Operation '.+' to Tensor"): 887 ops.convert_to_tensor(op) 888 889 def testStr(self): 890 node_def = ops._NodeDef("None", "op1") 891 op = ops.Operation(node_def, ops.Graph(), [], [dtypes.float32]) 892 self.assertEqual(str(node_def), str(op)) 893 894 def testRepr(self): 895 op = ops.Operation( 896 ops._NodeDef("None", "op1"), ops.Graph(), [], [dtypes.float32]) 897 self.assertEqual("<tf.Operation 'op1' type=None>", repr(op)) 898 899 @test_util.run_deprecated_v1 900 def testGetAttr(self): 901 op = test_ops.default_attrs() 902 self.assertEqual(op.get_attr("string_val"), b"abc") 903 self.assertEqual(op.get_attr("string_list_val"), [b"abc", b""]) 904 self.assertEqual(op.get_attr("int_val"), 123) 905 self.assertEqual(op.get_attr("int_list_val"), [1, 2, 3]) 906 self.assertEqual(op.get_attr("float_val"), 10.0) 907 self.assertEqual(op.get_attr("float_list_val"), [10.0]) 908 self.assertEqual(op.get_attr("bool_val"), True) 909 self.assertEqual(op.get_attr("bool_list_val"), [True, False]) 910 self.assertEqual(op.get_attr("shape_val"), 911 tensor_shape.as_shape([2, 1]).as_proto()) 912 self.assertEqual(op.get_attr("shape_list_val"), 913 [tensor_shape.as_shape([]).as_proto(), 914 tensor_shape.as_shape([1]).as_proto()]) 915 self.assertEqual(op.get_attr("tensor_val"), 916 tensor_util.make_tensor_proto(1, dtypes.int32)) 917 self.assertEqual(op.get_attr("tensor_list_val"), 918 [tensor_util.make_tensor_proto(1, dtypes.int32)]) 919 920 type_val = op.get_attr("type_val") 921 # First check that type_val is a DType, because the assertEqual will work 922 # no matter what since DType overrides __eq__ 923 self.assertIsInstance(type_val, dtypes.DType) 924 self.assertEqual(type_val, dtypes.int32) 925 926 type_list_val = op.get_attr("type_list_val") 927 self.assertTrue(all(isinstance(x, dtypes.DType) for x in type_list_val)) 928 self.assertEqual(type_list_val, [dtypes.int32, dtypes.float32]) 929 930 @function.Defun(dtypes.float32, func_name="MyFunc") 931 def func(x): 932 return x 933 934 op = test_ops.func_attr(func) 935 self.assertEqual(op.get_attr("f"), 936 attr_value_pb2.NameAttrList(name="MyFunc")) 937 938 # Try fetching missing attr 939 with self.assertRaisesRegex( 940 ValueError, "Operation 'FuncAttr' has no attr named 'FakeAttr'."): 941 op.get_attr("FakeAttr") 942 943 # TODO(b/65162920): remove this test when users who are directly mutating the 944 # node_def have been updated to proper usage. 945 @test_util.run_deprecated_v1 946 def testSetAttr(self): 947 op = test_ops.int_attr().op 948 op._set_attr("foo", attr_value_pb2.AttrValue(i=2)) 949 # TODO(skyewm): add node_def check 950 self.assertEqual(op.get_attr("foo"), 2) 951 952 # TODO(nolivia): test all error cases 953 def testAddControlInput(self): 954 with ops.Graph().as_default(): 955 x = constant_op.constant(1).op 956 y = constant_op.constant(2).op 957 z = constant_op.constant(3).op 958 z._add_control_input(x) # pylint: disable=protected-access 959 self.assertEqual(z.control_inputs, [x]) 960 z._add_control_input(x) # pylint: disable=protected-access 961 self.assertEqual(z.control_inputs, [x]) 962 z._add_control_inputs([x, y, y]) # pylint: disable=protected-access 963 self.assertEqual(z.control_inputs, [x, y]) 964 self.assertEqual(x._control_outputs, [z]) 965 966 @test_util.run_deprecated_v1 967 def testRemoveAllControlInputs(self): 968 a = constant_op.constant(1) 969 with ops.control_dependencies([a]): 970 b = constant_op.constant(2) 971 c = constant_op.constant(3) 972 d = constant_op.constant(4) 973 e = constant_op.constant(5) 974 with ops.control_dependencies([a, c]): 975 f = d + e 976 977 self.assertEqual(a.op.control_inputs, []) 978 self.assertEqual(b.op.control_inputs, [a.op]) 979 self.assertEqual(f.op.control_inputs, [a.op, c.op]) 980 981 a.op._remove_all_control_inputs() # pylint: disable=protected-access 982 self.assertEqual(a.op.control_inputs, []) 983 984 b.op._remove_all_control_inputs() # pylint: disable=protected-access 985 self.assertEqual(b.op.control_inputs, []) 986 987 f.op._remove_all_control_inputs() # pylint: disable=protected-access 988 self.assertEqual(f.op.control_inputs, []) 989 self.assertEqual(list(f.op.inputs), [d, e]) 990 991 @test_util.run_deprecated_v1 992 def testControlInputCycle(self): 993 graph = ops.Graph() 994 with graph.as_default(): 995 z = constant_op.constant(0) 996 x = constant_op.constant(1) 997 y = constant_op.constant(2) 998 y.op._add_control_input(z.op) # pylint: disable=protected-access 999 y.op._add_control_input(x.op) # pylint: disable=protected-access 1000 x.op._add_control_input(y.op) # pylint: disable=protected-access 1001 with self.session(graph=graph) as sess: 1002 with self.assertRaisesRegex( 1003 errors.InvalidArgumentError, 1004 "Graph is invalid, contains a cycle with 2 nodes"): 1005 self.evaluate(x) 1006 1007 def testUpdateInput(self): 1008 g = ops.Graph() 1009 with g.as_default(): 1010 x = constant_op.constant(1) 1011 y = constant_op.constant(2) 1012 z = x + y 1013 1014 z.op._update_input(0, y) # pylint: disable=protected-access 1015 self.assertEqual(list(z.op.inputs), [y, y]) 1016 self.assertEqual(x.consumers(), []) 1017 self.assertEqual(y.consumers(), [z.op, z.op]) 1018 with session.Session(graph=g) as sess: 1019 self.assertEqual(self.evaluate(z), 4) 1020 1021 z.op._update_input(0, x) # pylint: disable=protected-access 1022 self.assertEqual(list(z.op.inputs), [x, y]) 1023 self.assertEqual(x.consumers(), [z.op]) 1024 self.assertEqual(y.consumers(), [z.op]) 1025 with session.Session(graph=g) as sess: 1026 self.assertEqual(self.evaluate(z), 3) 1027 1028 z.op._update_input(1, y) # pylint: disable=protected-access 1029 self.assertEqual(list(z.op.inputs), [x, y]) 1030 self.assertEqual(x.consumers(), [z.op]) 1031 self.assertEqual(y.consumers(), [z.op]) 1032 with session.Session(graph=g) as sess: 1033 self.assertEqual(self.evaluate(z), 3) 1034 1035 def testUpdateInputGraphError(self): 1036 g_0 = ops.Graph() 1037 g_1 = ops.Graph() 1038 with g_0.as_default(): 1039 x = constant_op.constant(1) 1040 with g_1.as_default(): 1041 y = constant_op.constant(2) 1042 z = y * 2 1043 with self.assertRaisesRegex(ValueError, "must be from the same graph"): 1044 z.op._update_input(0, x) # pylint: disable=protected-access 1045 1046 def testUpdateInputTypeError(self): 1047 g = ops.Graph() 1048 with g.as_default(): 1049 w = constant_op.constant(0) 1050 x = constant_op.constant("") 1051 y = constant_op.constant(1) 1052 z = y + w 1053 z.op._update_input(0, x) # pylint: disable=protected-access 1054 with session.Session(graph=g) as sess: 1055 with self.assertRaisesRegex( 1056 errors.InvalidArgumentError, 1057 "Input 0 of node add was passed string from Const_1:0 incompatible " 1058 "with expected int32"): 1059 self.evaluate(z) 1060 1061 def testUpdateInputShapeError(self): 1062 g = ops.Graph() 1063 with g.as_default(): 1064 w = constant_op.constant(2, shape=[3, 1]) 1065 x = constant_op.constant(0, shape=[3, 1]) 1066 y = constant_op.constant(1, shape=[2, 2]) 1067 z = w + x 1068 with self.assertRaisesRegex( 1069 errors.InvalidArgumentError, 1070 r"Cannot update edge, incompatible shapes: \[2,2\] and \[3,1\]"): 1071 z.op._update_input(0, y) # pylint: disable=protected-access 1072 1073 def testUpdateInputOutOfRange(self): 1074 g = ops.Graph() 1075 with g.as_default(): 1076 x = constant_op.constant(1) 1077 with self.assertRaisesRegex( 1078 errors.OutOfRangeError, 1079 r"Cannot update edge. Input index \[1\] is greater than the number of " 1080 r"total inputs \[0\]."): 1081 x.op._update_input(1, x) # pylint: disable=protected-access 1082 1083 @test_util.enable_control_flow_v2 1084 @test_util.run_v1_only("b/120545219") 1085 def testAddWhileInput(self): 1086 1087 @eager_function.defun 1088 def test(): 1089 output = control_flow_ops.while_loop(lambda x: x < 3, lambda x: x + 1, 1090 [1]) 1091 while_op = output.op 1092 self.assertEqual(while_op.type, "StatelessWhile") 1093 orig_num_inputs = len(while_op.inputs) 1094 1095 # Make sure we can handle the while op having a control input. 1096 while_op._add_control_input(constant_op.constant(0).op) 1097 1098 new_input1 = constant_op.constant(1.0) 1099 new_input2 = constant_op.constant(True) 1100 1101 # Clear output shapes to bypass shape checking. 1102 while_op._set_shape_list_attr("output_shapes", []) 1103 while_op._set_type_list_attr("T", [t.dtype for t in while_op.inputs] + 1104 [new_input1.dtype, new_input2.dtype]) 1105 1106 while_op._add_while_inputs([new_input1, new_input2]) 1107 # Can't add an edge beyond what's specified by "T" 1108 with self.assertRaises(errors.OutOfRangeError): 1109 while_op._add_while_inputs([new_input2]) 1110 self.assertEqual(len(while_op.inputs), orig_num_inputs + 2) # pylint: disable=g-deprecated-assert 1111 1112 test() 1113 1114 @test_util.run_deprecated_v1 1115 def testOpDef(self): 1116 x = constant_op.constant(0) 1117 y = constant_op.constant(1) 1118 z = x + y 1119 1120 self.assertEqual(x.op.op_def.name, "Const") 1121 self.assertEqual(len(x.op.op_def.input_arg), 0) 1122 self.assertEqual(len(x.op.op_def.output_arg), 1) 1123 1124 self.assertRegex(z.op.op_def.name, "Add(V2)?") 1125 self.assertEqual(len(z.op.op_def.input_arg), 2) 1126 self.assertEqual(len(z.op.op_def.output_arg), 1) 1127 1128 def testInputFromDifferentGraphError(self): 1129 g_0 = ops.Graph() 1130 g_1 = ops.Graph() 1131 with g_0.as_default(): 1132 x = constant_op.constant(1) 1133 with g_1.as_default(): 1134 y = constant_op.constant(2) 1135 with self.assertRaisesRegex(ValueError, "must be from the same graph"): 1136 y * x # pylint: disable=pointless-statement 1137 1138 def testInputsAreImmutable(self): 1139 g = ops.Graph() 1140 with g.as_default(): 1141 x = test_ops.int_output() 1142 op = test_ops.int_input_int_output(x, name="myop").op 1143 with self.assertRaisesRegex(AttributeError, 1144 "'tuple' object has no attribute 'append'"): 1145 op.inputs.append(None) 1146 1147 1148class CreateOpTest(test_util.TensorFlowTestCase): 1149 1150 def testNodeDefArgs(self): 1151 g = ops.Graph() 1152 op1 = g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1") 1153 with g.device("/device:GPU:0"): 1154 op2 = g.create_op( 1155 "FloatOutputStringOutput", [], [dtypes.float32, dtypes.string], None, 1156 name="myop2") 1157 op3 = g.create_op( 1158 "Foo3", 1159 [list(op1.values())[0], list(op2.values())[1], list(op2.values())[0]], 1160 [dtypes.float32, dtypes.int32], 1161 None, 1162 name="myop3") 1163 self.assertDeviceEqual(None, op1.device) 1164 self.assertDeviceEqual("/device:GPU:0", op2.device) 1165 self.assertDeviceEqual(None, op3.device) 1166 self.assertProtoEquals("name:'myop1' op:'FloatOutput'", op1.node_def) 1167 self.assertProtoEquals( 1168 "name:'myop2' op:'FloatOutputStringOutput' device:'/device:GPU:0'", 1169 op2.node_def) 1170 self.assertProtoEquals( 1171 "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo3'", 1172 op3.node_def) 1173 1174 def testReferenceInput(self): 1175 g = ops.Graph() 1176 op1 = g.create_op( 1177 "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32], 1178 name="op1") 1179 self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def) 1180 ref_t, nonref_t = op1.values() 1181 # NOTE(mrry): Must specify input_types to preserve ref-typed input. 1182 op2 = g.create_op( 1183 "RefInputFloatInput", [ref_t, nonref_t], [], 1184 input_types=[dtypes.float32_ref, dtypes.float32], 1185 name="op2") 1186 self.assertProtoEquals( 1187 "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'", 1188 op2.node_def) 1189 op3 = g.create_op("TwoFloatInputs", [ref_t, nonref_t], [], name="op3") 1190 self.assertProtoEquals( 1191 "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'", 1192 op3.node_def) 1193 1194 def testFinalized(self): 1195 g = ops.Graph() 1196 g.finalize() 1197 with self.assertRaises(RuntimeError): 1198 g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1") 1199 1200 # Test unfinalize. 1201 g._unsafe_unfinalize() 1202 g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1") 1203 1204 1205# NOTE(skyewm): these cases test the private Graph._create_op_from_tf_operation 1206# method. Arguably we should only test the public APIs that depend on this 1207# method. However, this logic is complex and tricky, and it can be difficult to 1208# ascertain if we have adequate coverage (e.g. a graph may run successfully if 1209# the control flow context isn't set properly, but a more complicated use case 1210# that might not be obvious to test will fail). Thus we instead explicitly test 1211# the low-level behavior. 1212class CreateOpFromTFOperationTest(test_util.TensorFlowTestCase): 1213 1214 @test_util.run_deprecated_v1 1215 def testBasic(self): 1216 g = ops.Graph() 1217 with g.as_default(): 1218 x = test_ops.int_output() 1219 c_op = ops._create_c_op( 1220 g, ops._NodeDef("IntInputIntOutput", "myop"), [x], []) 1221 op = g._create_op_from_tf_operation(c_op) 1222 1223 self.assertEqual(op.name, "myop") 1224 self.assertEqual(op.type, "IntInputIntOutput") 1225 self.assertEqual(len(op.outputs), 1) 1226 self.assertEqual(op.outputs[0].shape, tensor_shape.unknown_shape()) 1227 self.assertEqual(list(op.inputs), [x]) 1228 self.assertEqual(op.control_inputs, []) 1229 self.assertEqual(op.graph, g) 1230 self.assertEqual(x.consumers(), [op]) 1231 self.assertIsNotNone(op.traceback) 1232 self.assertEqual(g.get_operation_by_name("myop"), op) 1233 self.assertEqual(g.get_tensor_by_name("myop:0"), op.outputs[0]) 1234 1235 def testShape(self): 1236 g = ops.Graph() 1237 with g.as_default(): 1238 x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) 1239 c_op = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), [x], []) 1240 op = g._create_op_from_tf_operation(c_op) 1241 1242 self.assertEqual(op.name, "myop") 1243 self.assertEqual(op.type, "Identity") 1244 self.assertEqual(len(op.outputs), 1) 1245 self.assertEqual(op.outputs[0].shape, tensor_shape.TensorShape([2, 3])) 1246 1247 def testUniqueName(self): 1248 g = ops.Graph() 1249 with g.as_default(): 1250 c_op = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop"), [], []) 1251 c_op2 = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop_1"), [], []) 1252 op = g._create_op_from_tf_operation(c_op) 1253 op2 = g._create_op_from_tf_operation(c_op2) 1254 1255 # Create ops with same names as op1 and op2. We expect the new names to be 1256 # uniquified. 1257 op3 = test_ops.int_output(name="myop").op 1258 op4 = test_ops.int_output(name="myop_1").op 1259 1260 self.assertEqual(op.name, "myop") 1261 self.assertEqual(op2.name, "myop_1") 1262 self.assertEqual(op3.name, "myop_2") 1263 self.assertEqual(op4.name, "myop_1_1") 1264 1265 @test_util.run_v1_only("b/120545219") 1266 def testCond(self): 1267 g = ops.Graph() 1268 with g.as_default(): 1269 x = test_ops.int_output() 1270 1271 def true_fn(): 1272 ops._create_c_op(ops.get_default_graph(), 1273 ops._NodeDef("IntInput", "cond/myop"), [x], []) 1274 new_ops = g._add_new_tf_operations() 1275 self.assertEqual(len(new_ops), 1) 1276 return x 1277 1278 control_flow_ops.cond(x < 10, true_fn, lambda: x) 1279 1280 op = g.get_operation_by_name("cond/myop") 1281 self.assertIsNotNone(op) 1282 self.assertEqual(op.name, "cond/myop") 1283 self.assertEqual(op.type, "IntInput") 1284 self.assertEqual(op.outputs, []) 1285 op_input = op.inputs[0].op 1286 self.assertEqual(op_input.type, "Switch") 1287 self.assertEqual(op_input.inputs[0], x) 1288 self.assertEqual(op.graph, g) 1289 # pylint: disable=protected-access 1290 self.assertIsNotNone(op._get_control_flow_context()) 1291 self.assertEqual(op._get_control_flow_context().name, 1292 "cond/cond_text") 1293 # pylint: enable=protected-access 1294 1295 @test_util.run_v1_only("b/120545219") 1296 def testWhileLoop(self): 1297 g = ops.Graph() 1298 with g.as_default(): 1299 x = test_ops.int_output() 1300 1301 def body(i): 1302 ops._create_c_op(ops.get_default_graph(), 1303 ops._NodeDef("IntInput", "myloop/myop"), [x], []) 1304 new_ops = g._add_new_tf_operations() 1305 self.assertEqual(len(new_ops), 1) 1306 return i 1307 1308 control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") 1309 1310 op = g.get_operation_by_name("myloop/myop") 1311 self.assertIsNotNone(op) 1312 self.assertEqual(op.name, "myloop/myop") 1313 self.assertEqual(op.type, "IntInput") 1314 self.assertEqual(op.outputs, []) 1315 op_input = op.inputs[0].op 1316 self.assertEqual(op_input.type, "Enter") 1317 self.assertEqual(list(op_input.inputs), [x]) 1318 self.assertEqual(op.graph, g) 1319 # pylint: disable=protected-access 1320 self.assertIsNotNone(op._get_control_flow_context()) 1321 self.assertEqual(op._get_control_flow_context().name, 1322 "myloop/while_context") 1323 # pylint: enable=protected-access 1324 1325 @test_util.run_v1_only("b/120545219") 1326 def testWhileLoopWithInternalControlDep(self): 1327 g = ops.Graph() 1328 with g.as_default(): 1329 x = test_ops.int_output() 1330 1331 def body(i): 1332 c = constant_op.constant(1.0, name="c") 1333 ops._create_c_op(ops.get_default_graph(), 1334 ops._NodeDef("IntInput", "myloop/myop"), [x], []) 1335 with ops.control_dependencies([c]): 1336 new_ops = g._add_new_tf_operations() 1337 self.assertEqual(len(new_ops), 1) 1338 return i 1339 1340 control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") 1341 1342 op = g.get_operation_by_name("myloop/myop") 1343 self.assertIsNotNone(op) 1344 c = g.get_operation_by_name("myloop/c") 1345 self.assertIsNotNone(c) 1346 # Internal control dep is preserved 1347 self.assertEqual(op.control_inputs, [c]) 1348 1349 @test_util.run_v1_only("b/120545219") 1350 def testWhileLoopWithExternalControlDep(self): 1351 g = ops.Graph() 1352 with g.as_default(): 1353 x = test_ops.int_output() 1354 c = constant_op.constant(1.0) 1355 1356 def body(i): 1357 ops._create_c_op(ops.get_default_graph(), 1358 ops._NodeDef("IntInput", "myloop/myop"), [x], []) 1359 with ops.control_dependencies([c]): 1360 new_ops = g._add_new_tf_operations() 1361 self.assertEqual(len(new_ops), 1) 1362 return i 1363 1364 control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") 1365 1366 op = g.get_operation_by_name("myloop/myop") 1367 self.assertIsNotNone(op) 1368 # External control dep is removed and replaced with internal control dep 1369 self.assertNotEqual(op.control_inputs[0], c.op) 1370 self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context()) 1371 1372 1373class ApplyOpTest(test_util.TensorFlowTestCase): 1374 1375 def testNodeDefArgs(self): 1376 g = ops.Graph() 1377 t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1") 1378 with g.device("/device:GPU:0"): 1379 t2 = _apply_op( 1380 g, "TwoIntOutputs", [], [dtypes.int32, dtypes.int32], name="myop2") 1381 t3 = _apply_op( 1382 g, 1383 "Foo1", [t1, t2[1], t2[0]], [dtypes.float32, dtypes.int32], 1384 name="myop3") 1385 self.assertTrue(isinstance(t1, ops.Tensor)) 1386 self.assertTrue(isinstance(t2, list)) 1387 self.assertTrue(isinstance(t3, list)) 1388 self.assertTrue(isinstance(t3[0], ops.Tensor)) 1389 self.assertEqual("myop1", t1._as_node_def_input()) 1390 self.assertEqual("myop2", t2[0]._as_node_def_input()) 1391 self.assertEqual("myop2:1", t2[1]._as_node_def_input()) 1392 self.assertEqual("myop3", t3[0]._as_node_def_input()) 1393 # Validate that we got the right ops as well 1394 self.assertProtoEquals("name:'myop1' op:'FloatOutput'", t1.op.node_def) 1395 self.assertProtoEquals( 1396 "name:'myop2' op:'TwoIntOutputs' device:'/device:GPU:0'", 1397 t2[0].op.node_def) 1398 self.assertProtoEquals( 1399 "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo1'", 1400 t3[0].op.node_def) 1401 1402 def testReferenceInput(self): 1403 g = ops.Graph() 1404 ref_t, nonref_t = _apply_op( 1405 g, "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32], 1406 name="op1") 1407 self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", 1408 ref_t.op.node_def) 1409 # NOTE(mrry): Must specify input_types to preserve ref-typed input. 1410 out_2 = _apply_op( 1411 g, 1412 "RefInputFloatInputIntOutput", [ref_t, nonref_t], [dtypes.int32], 1413 input_types=[dtypes.float32_ref, dtypes.float32], 1414 name="op2") 1415 self.assertProtoEquals( 1416 "op:'RefInputFloatInputIntOutput' name:'op2' input:'op1' input:'op1:1'", 1417 out_2.op.node_def) 1418 out_3 = _apply_op( 1419 g, "TwoFloatInputsIntOutput", [ref_t, nonref_t], [dtypes.int32], 1420 name="op3") 1421 self.assertProtoEquals( 1422 "op:'TwoFloatInputsIntOutput' name:'op3' input:'op1' input:'op1:1'", 1423 out_3.op.node_def) 1424 1425 1426class NameStackTest(test_util.TensorFlowTestCase): 1427 1428 def testBasics(self): 1429 g = ops.Graph() 1430 self.assertEqual("foo", g.unique_name("foo", mark_as_used=False)) 1431 self.assertEqual("foo", g.unique_name("foo", mark_as_used=False)) 1432 self.assertEqual("foo", g.unique_name("foo")) 1433 self.assertEqual("foo_1", g.unique_name("foo", mark_as_used=False)) 1434 self.assertEqual("foo_1", g.unique_name("foo")) 1435 self.assertEqual("foo_2", g.unique_name("foo", mark_as_used=False)) 1436 self.assertEqual("foo_2", g.unique_name("foo")) 1437 self.assertEqual("foo_1_1", g.unique_name("foo_1", mark_as_used=False)) 1438 self.assertEqual("foo_1_1", g.unique_name("foo_1")) 1439 self.assertEqual("foo_1_2", g.unique_name("foo_1", mark_as_used=False)) 1440 self.assertEqual("foo_1_2", g.unique_name("foo_1")) 1441 self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2", mark_as_used=False)) 1442 self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2")) 1443 with g.name_scope("bar"): 1444 self.assertEqual("bar/foo", g.unique_name("foo", mark_as_used=False)) 1445 self.assertEqual("bar/foo", g.unique_name("foo")) 1446 self.assertEqual("bar/foo_1", g.unique_name("foo", mark_as_used=False)) 1447 self.assertEqual("bar/foo_1", g.unique_name("foo")) 1448 with g.name_scope(None): 1449 self.assertEqual("foo_3", g.unique_name("foo", mark_as_used=False)) 1450 self.assertEqual("foo_3", g.unique_name("foo")) 1451 with g.name_scope("baz"): 1452 self.assertEqual( 1453 "bar/baz/foo", g.unique_name( 1454 "foo", mark_as_used=False)) 1455 self.assertEqual("bar/baz/foo", g.unique_name("foo")) 1456 self.assertEqual( 1457 "bar/baz/foo_1", g.unique_name( 1458 "foo", mark_as_used=False)) 1459 self.assertEqual("bar/baz/foo_1", g.unique_name("foo")) 1460 with g.name_scope("baz"): 1461 self.assertEqual( 1462 "bar/baz_1/foo", g.unique_name( 1463 "foo", mark_as_used=False)) 1464 self.assertEqual("bar/baz_1/foo", g.unique_name("foo")) 1465 self.assertEqual( 1466 "bar/baz_1/foo_1", g.unique_name( 1467 "foo", mark_as_used=False)) 1468 self.assertEqual("bar/baz_1/foo_1", g.unique_name("foo")) 1469 with g.name_scope("quux"): 1470 self.assertEqual("quux/foo", g.unique_name("foo", mark_as_used=False)) 1471 self.assertEqual("quux/foo", g.unique_name("foo")) 1472 with g.name_scope("bar"): 1473 with g.name_scope("baz"): 1474 self.assertEqual( 1475 "bar_1/baz/foo", g.unique_name( 1476 "foo", mark_as_used=False)) 1477 self.assertEqual("bar_1/baz/foo", g.unique_name("foo")) 1478 self.assertEqual("foo_4", g.unique_name("foo", mark_as_used=False)) 1479 self.assertEqual("foo_4", g.unique_name("foo")) 1480 self.assertEqual("bar_2", g.unique_name("bar", mark_as_used=False)) 1481 self.assertEqual("bar_2", g.unique_name("bar")) 1482 1483 def testBackslashAndDashRegex(self): 1484 # GitHub issue 39019, all should pass 1485 g = ops.Graph() 1486 with g.name_scope("n_CatCntc-campaign\\c_campaign"): 1487 pass 1488 with g.name_scope("foo"): 1489 with g.name_scope("n_CatCntc-campaign\\c_campaign"): 1490 pass 1491 with g.name_scope("n_CatCntc-campaign\\c_campaign"): 1492 with g.name_scope("foo"): 1493 pass 1494 1495 @test_util.run_deprecated_v1 1496 def testNameAndVariableScope(self): 1497 with self.cached_session() as sess: 1498 with sess.graph.name_scope("l0"): 1499 with variable_scope.variable_scope("l1"): 1500 with sess.graph.name_scope("l1") as scope: 1501 self.assertEqual("l0/l1/l1/", scope) 1502 self.assertEqual( 1503 "l0/l1/l1/foo", 1504 sess.graph.unique_name( 1505 "foo", mark_as_used=False)) 1506 self.assertEqual("l0/l1/l1/foo", sess.graph.unique_name("foo")) 1507 with sess.graph.name_scope("l2") as scope: 1508 self.assertEqual("l0/l1/l2/", scope) 1509 self.assertEqual( 1510 "l0/l1/l2/foo", 1511 sess.graph.unique_name( 1512 "foo", mark_as_used=False)) 1513 self.assertEqual("l0/l1/l2/foo", sess.graph.unique_name("foo")) 1514 1515 def testOutOfOrderUniqueName(self): 1516 g = ops.Graph() 1517 self.assertEqual("foo_2", g.unique_name("foo_2")) 1518 self.assertEqual("foo", g.unique_name("foo")) 1519 self.assertEqual("foo_1", g.unique_name("foo")) 1520 self.assertEqual("foo_3", g.unique_name("foo")) 1521 1522 def testUniqueNameCaseInsensitivity(self): 1523 g = ops.Graph() 1524 self.assertEqual("foo", g.unique_name("foo")) 1525 self.assertEqual("Foo_1", g.unique_name("Foo")) 1526 with g.name_scope("bar"): 1527 self.assertEqual("bar/foo", g.unique_name("foo")) 1528 with g.name_scope("Bar"): 1529 self.assertEqual("Bar_1/foo", g.unique_name("foo")) 1530 1531 def testInvalidNameRaisesError(self): 1532 g = ops.Graph() 1533 with g.name_scope(""): # Should not raise 1534 pass 1535 with g.name_scope("foo/"): # Should not raise 1536 with g.name_scope("_bar"): # Should not raise 1537 pass 1538 with self.assertRaises(ValueError): 1539 with g.name_scope("foo:0"): 1540 pass 1541 with self.assertRaises(ValueError): 1542 with g.name_scope("_bar"): 1543 pass 1544 1545 1546class NameTest(test_util.TensorFlowTestCase): 1547 1548 def testGenerateName(self): 1549 g = ops.Graph() 1550 op0 = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32]) 1551 self.assertEqual("TwoFloatOutputs", op0.name) 1552 self.assertEqual("TwoFloatOutputs:0", op0.outputs[0].name) 1553 self.assertEqual("TwoFloatOutputs:1", op0.outputs[1].name) 1554 1555 op1 = g.create_op("FloatOutput", [], [dtypes.float32]) 1556 self.assertEqual("FloatOutput", op1.name) 1557 self.assertEqual("FloatOutput:0", op1.outputs[0].name) 1558 1559 op2 = g.create_op("FloatOutput", [], [dtypes.float32]) 1560 self.assertEqual("FloatOutput_1", op2.name) 1561 self.assertEqual("FloatOutput_1:0", op2.outputs[0].name) 1562 1563 op3 = g.create_op("FloatOutput", [], [dtypes.float32], name="my_op") 1564 self.assertEqual("my_op", op3.name) 1565 self.assertEqual("my_op:0", op3.outputs[0].name) 1566 1567 def testNameScope(self): 1568 g = ops.Graph() 1569 1570 with g.name_scope("foo") as foo: 1571 self.assertEqual("foo/", foo) 1572 with g.name_scope("foo2") as foo2: 1573 self.assertEqual("foo/foo2/", foo2) 1574 with g.name_scope(None) as empty1: 1575 self.assertEqual("", empty1) 1576 with g.name_scope("foo3") as foo3: 1577 self.assertEqual("foo3/", foo3) 1578 with g.name_scope("") as empty2: 1579 self.assertEqual("", empty2) 1580 1581 self.assertEqual("FloatOutput", 1582 g.create_op("FloatOutput", [], [dtypes.float32]).name) 1583 with g.name_scope("bar") as scope: 1584 self.assertEqual("bar/FloatOutput", 1585 g.create_op("FloatOutput", [], [dtypes.float32]).name) 1586 self.assertEqual("bar/FloatOutput_1", 1587 g.create_op("FloatOutput", [], [dtypes.float32]).name) 1588 # If you use the value from "with .. as", that values is used as-is. 1589 self.assertEqual( 1590 "bar", g.create_op( 1591 "FloatOutput", [], [dtypes.float32], name=scope).name) 1592 with g.name_scope("baz") as scope: 1593 with g.name_scope("quux"): 1594 self.assertEqual("baz/quux/FloatOutput", 1595 g.create_op("FloatOutput", [], [dtypes.float32]).name) 1596 # If you use the value from the enclosing "with .. as", nothing is pushed. 1597 with g.name_scope(scope): 1598 self.assertEqual("baz/FloatOutput", 1599 g.create_op("FloatOutput", [], [dtypes.float32]).name) 1600 self.assertEqual( 1601 "baz", g.create_op( 1602 "FloatOutput", [], [dtypes.float32], name=scope).name) 1603 self.assertEqual( 1604 "trailing", 1605 g.create_op( 1606 "FloatOutput", [], [dtypes.float32], name="trailing/").name) 1607 with g.name_scope("bar"): 1608 self.assertEqual("bar_1/FloatOutput", 1609 g.create_op("FloatOutput", [], [dtypes.float32]).name) 1610 with g.name_scope("bar/"): 1611 self.assertEqual("bar/FloatOutput_2", 1612 g.create_op("FloatOutput", [], [dtypes.float32]).name) 1613 1614 1615class DeviceTest(test_util.TensorFlowTestCase): 1616 1617 def testNoDevice(self): 1618 g = ops.Graph() 1619 op = g.create_op("FloatOutput", [], [dtypes.float32]) 1620 self.assertDeviceEqual(None, op.device) 1621 gd = g.as_graph_def() 1622 self.assertProtoEqualsVersion(""" 1623 node { name: "FloatOutput" op: "FloatOutput" } 1624 """, gd) 1625 1626 def testEagerBackingDevice(self): 1627 with context.eager_mode(): 1628 with ops.device("/device:CPU:0"): 1629 t = constant_op.constant(1.0) 1630 self.assertRegex(t.device, "/device:CPU:0") 1631 self.assertRegex(t.backing_device, "/device:CPU:0") 1632 1633 def testDevicePartialString(self): 1634 g = ops.Graph() 1635 with g.device("/job:worker/replica:2"): 1636 g.create_op("FloatOutput", [], [dtypes.float32]) 1637 gd = g.as_graph_def() 1638 self.assertProtoEqualsVersion(""" 1639 node { name: "FloatOutput" op: "FloatOutput" 1640 device: "/job:worker/replica:2" } 1641 """, gd) 1642 1643 def testDeviceFull(self): 1644 g = ops.Graph() 1645 with g.device( 1646 pydev.DeviceSpec( 1647 job="worker", replica=2, task=0, device_type="CPU", 1648 device_index=3)): 1649 g.create_op("FloatOutput", [], [dtypes.float32]) 1650 gd = g.as_graph_def() 1651 self.assertProtoEqualsVersion(""" 1652 node { name: "FloatOutput" op: "FloatOutput" 1653 device: "/job:worker/replica:2/task:0/device:CPU:3" } 1654 """, gd) 1655 1656 def testNesting(self): 1657 g = ops.Graph() 1658 with g.device("/job:worker/replica:2"): 1659 g.create_op("FloatOutput", [], [dtypes.float32]) 1660 with g.device("/job:worker/replica:3/task:0"): 1661 g.create_op("FloatOutput", [], [dtypes.float32]) 1662 g.create_op("FloatOutput", [], [dtypes.float32]) 1663 gd = g.as_graph_def() 1664 self.assertProtoEqualsVersion(""" 1665 node { name: "FloatOutput" op: "FloatOutput" 1666 device: "/job:worker/replica:2" } 1667 node { name: "FloatOutput_1" op: "FloatOutput" 1668 device: "/job:worker/replica:3/task:0" } 1669 node { name: "FloatOutput_2" op: "FloatOutput" 1670 device: "/job:worker/replica:2" } 1671 """, gd) 1672 1673 def testNestingString(self): 1674 g = ops.Graph() 1675 with g.device("/job:worker/replica:2"): 1676 g.create_op("FloatOutput", [], [dtypes.float32]) 1677 with g.device("/job:worker/replica:3/task:0"): 1678 g.create_op("FloatOutput", [], [dtypes.float32]) 1679 g.create_op("FloatOutput", [], [dtypes.float32]) 1680 gd = g.as_graph_def() 1681 self.assertProtoEqualsVersion(""" 1682 node { name: "FloatOutput" op: "FloatOutput" 1683 device: "/job:worker/replica:2" } 1684 node { name: "FloatOutput_1" op: "FloatOutput" 1685 device: "/job:worker/replica:3/task:0" } 1686 node { name: "FloatOutput_2" op: "FloatOutput" 1687 device: "/job:worker/replica:2" } 1688 """, gd) 1689 1690 def testNestingOverrideGpuCpu(self): 1691 g = ops.Graph() 1692 with g.device("/job:worker/replica:2/device:CPU:1"): 1693 g.create_op("FloatOutput", [], [dtypes.float32]) 1694 with g.device("/job:worker/replica:2/device:GPU:2"): 1695 g.create_op("FloatOutput", [], [dtypes.float32]) 1696 g.create_op("FloatOutput", [], [dtypes.float32]) 1697 gd = g.as_graph_def() 1698 self.assertProtoEqualsVersion(""" 1699 node { name: "FloatOutput" op: "FloatOutput" 1700 device: "/job:worker/replica:2/device:CPU:1" } 1701 node { name: "FloatOutput_1" op: "FloatOutput" 1702 device: "/job:worker/replica:2/device:GPU:2" } 1703 node { name: "FloatOutput_2" op: "FloatOutput" 1704 device: "/job:worker/replica:2/device:CPU:1" } 1705 """, gd) 1706 1707 def testNestingWithMergeDeviceFunction(self): 1708 g = ops.Graph() 1709 1710 with g.device(pydev.merge_device("/device:GPU:0")): 1711 g.create_op("FloatOutput", [], [dtypes.float32]) 1712 with g.device(pydev.merge_device("/job:worker")): 1713 g.create_op("FloatOutput", [], [dtypes.float32]) 1714 with g.device(pydev.merge_device("/device:CPU:0")): 1715 g.create_op("FloatOutput", [], [dtypes.float32]) 1716 with g.device(pydev.merge_device("/job:ps")): 1717 g.create_op("FloatOutput", [], [dtypes.float32]) 1718 with g.device(pydev.merge_device(None)): 1719 g.create_op("FloatOutput", [], [dtypes.float32]) 1720 1721 gd = g.as_graph_def() 1722 self.assertProtoEqualsVersion(""" 1723 node { name: "FloatOutput" op: "FloatOutput" 1724 device: "/device:GPU:0" } 1725 node { name: "FloatOutput_1" op: "FloatOutput" 1726 device: "/job:worker/device:GPU:0" } 1727 node { name: "FloatOutput_2" op: "FloatOutput" 1728 device: "/job:worker/device:CPU:0" } 1729 node { name: "FloatOutput_3" op: "FloatOutput" 1730 device: "/job:ps/device:CPU:0" } 1731 node { name: "FloatOutput_4" op: "FloatOutput" 1732 device: "/job:ps/device:CPU:0" } 1733 """, gd) 1734 1735 def testNestingWithDeviceStrings(self): 1736 g = ops.Graph() 1737 1738 with g.device("/device:GPU:0"): 1739 g.create_op("FloatOutput", [], [dtypes.float32]) 1740 with g.device("/job:worker"): 1741 g.create_op("FloatOutput", [], [dtypes.float32]) 1742 with g.device("/device:CPU:0"): 1743 g.create_op("FloatOutput", [], [dtypes.float32]) 1744 with g.device("/job:ps"): 1745 g.create_op("FloatOutput", [], [dtypes.float32]) 1746 with g.device(""): 1747 g.create_op("FloatOutput", [], [dtypes.float32]) 1748 1749 gd = g.as_graph_def() 1750 self.assertProtoEqualsVersion(""" 1751 node { name: "FloatOutput" op: "FloatOutput" 1752 device: "/device:GPU:0" } 1753 node { name: "FloatOutput_1" op: "FloatOutput" 1754 device: "/job:worker/device:GPU:0" } 1755 node { name: "FloatOutput_2" op: "FloatOutput" 1756 device: "/job:worker/device:CPU:0" } 1757 node { name: "FloatOutput_3" op: "FloatOutput" 1758 device: "/job:ps/device:CPU:0" } 1759 node { name: "FloatOutput_4" op: "FloatOutput" 1760 device: "/job:ps/device:CPU:0" } 1761 """, gd) 1762 1763 def testNestingWithDeviceStringWildcard(self): 1764 g = ops.Graph() 1765 1766 with g.device("/device:GPU:7"): 1767 g.create_op("FloatOutput", [], [dtypes.float32]) 1768 with g.device("/device:GPU:*"): 1769 g.create_op("FloatOutput", [], [dtypes.float32]) 1770 1771 with g.device("/device:CPU:*"): 1772 g.create_op("FloatOutput", [], [dtypes.float32]) 1773 with g.device("/device:CPU:5"): 1774 g.create_op("FloatOutput", [], [dtypes.float32]) 1775 1776 gd = g.as_graph_def() 1777 self.assertProtoEqualsVersion(""" 1778 node { name: "FloatOutput" op: "FloatOutput" 1779 device: "/device:GPU:7" } 1780 node { name: "FloatOutput_1" op: "FloatOutput" 1781 device: "/device:GPU:7" } 1782 node { name: "FloatOutput_2" op: "FloatOutput" 1783 device: "/device:CPU:*" } 1784 node { name: "FloatOutput_3" op: "FloatOutput" 1785 device: "/device:CPU:5" } 1786 """, gd) 1787 1788 def testNestingErrorGraph(self): 1789 g = ops.Graph() 1790 scope = g.device("/device:GPU:8") 1791 scope.__enter__() 1792 with g.device("/device:GPU:9"): 1793 with self.assertRaises(RuntimeError): 1794 scope.__exit__(None, None, None) 1795 1796 def testNestingErrorEager(self): 1797 with context.eager_mode(): 1798 scope = ops.device("/device:CPU:0") 1799 scope.__enter__() 1800 with ops.device(None): 1801 with self.assertRaises(RuntimeError): 1802 scope.__exit__(None, None, None) 1803 1804 def testNoneClearsDefault(self): 1805 g = ops.Graph() 1806 with g.device("/job:worker/replica:2/device:CPU:1"): 1807 g.create_op("FloatOutput", [], [dtypes.float32]) 1808 with g.device(None): 1809 g.create_op("FloatOutput", [], [dtypes.float32]) 1810 g.create_op("FloatOutput", [], [dtypes.float32]) 1811 gd = g.as_graph_def() 1812 self.assertProtoEqualsVersion(""" 1813 node { name: "FloatOutput" op: "FloatOutput" 1814 device: "/job:worker/replica:2/device:CPU:1" } 1815 node { name: "FloatOutput_1" op: "FloatOutput" } 1816 node { name: "FloatOutput_2" op: "FloatOutput" 1817 device: "/job:worker/replica:2/device:CPU:1" } 1818 """, gd) 1819 1820 def testNoneIgnoresOuterDeviceFunction(self): 1821 g = ops.Graph() 1822 with g.device(lambda op: "/job:worker/replica:2/device:CPU:1"): 1823 g.create_op("FloatOutput", [], [dtypes.float32]) 1824 with g.device(None): 1825 g.create_op("FloatOutput", [], [dtypes.float32]) 1826 g.create_op("FloatOutput", [], [dtypes.float32]) 1827 gd = g.as_graph_def() 1828 self.assertProtoEqualsVersion(""" 1829 node { name: "FloatOutput" op: "FloatOutput" 1830 device: "/job:worker/replica:2/device:CPU:1" } 1831 node { name: "FloatOutput_1" op: "FloatOutput" } 1832 node { name: "FloatOutput_2" op: "FloatOutput" 1833 device: "/job:worker/replica:2/device:CPU:1" } 1834 """, gd) 1835 1836 def _overwritingDeviceFunction(self, unused_op): 1837 # This device function unconditionally overwrites the device of ops. 1838 # 1839 # NOTE(mrry): Writing device functions like this is not 1840 # recommended. Instead, in most cases you should use 1841 # `pydev.merge_device("/job:ps")` or simply `"/job:ps"` as the 1842 # argument to `tf.device()` and the device component will be merged in. 1843 return "/job:overwrite" 1844 1845 def testOverwritingBehavior(self): 1846 g = ops.Graph() 1847 with g.device(self._overwritingDeviceFunction): 1848 g.create_op("FloatOutput", [], [dtypes.float32]) 1849 with g.device("/job:ps"): # Will be overwritten. 1850 g.create_op("FloatOutput", [], [dtypes.float32]) 1851 with g.device(pydev.merge_device("/job:ps")): # Will be overwritten. 1852 g.create_op("FloatOutput", [], [dtypes.float32]) 1853 with g.device(None): # Disables overwriting device function 1854 with g.device("/job:ps"): 1855 g.create_op("FloatOutput", [], [dtypes.float32]) 1856 with g.device(None): # Disables overwriting device function 1857 with g.device(pydev.merge_device("/job:ps")): 1858 g.create_op("FloatOutput", [], [dtypes.float32]) 1859 gd = g.as_graph_def() 1860 self.assertProtoEqualsVersion(""" 1861 node { name: "FloatOutput" op: "FloatOutput" 1862 device: "/job:overwrite" } 1863 node { name: "FloatOutput_1" op: "FloatOutput" 1864 device: "/job:overwrite" } 1865 node { name: "FloatOutput_2" op: "FloatOutput" 1866 device: "/job:overwrite" } 1867 node { name: "FloatOutput_3" op: "FloatOutput" 1868 device: "/job:ps" } 1869 node { name: "FloatOutput_4" op: "FloatOutput" 1870 device: "/job:ps" } 1871 """, gd) 1872 1873 1874class MultithreadedGraphStateTest(test_util.TensorFlowTestCase): 1875 1876 class TestThread(threading.Thread): 1877 1878 def __init__(self, graph, replica_id): 1879 super(MultithreadedGraphStateTest.TestThread, self).__init__() 1880 self._graph = graph 1881 self._replica_id = replica_id 1882 # This thread sets this event when it mutated the graph. The caller can 1883 # wait for that. 1884 self.has_mutated_graph = threading.Event() 1885 # This thread waits for when it should continue. The caller can set this 1886 # event. 1887 self.should_continue = threading.Event() 1888 1889 def run(self): 1890 # Mutate a graph's stack, then set `has_mutated_graph`, then wait for 1891 # `should_continue`, then add an op to the graph affected by the graph's 1892 # stack. 1893 raise NotImplementedError("must be implemented in descendants") 1894 1895 def testDeviceFunctionStack(self): 1896 1897 class DeviceSettingThread(self.TestThread): 1898 1899 def run(self): 1900 with g.device("/job:worker/replica:{}".format(self._replica_id)): 1901 self.has_mutated_graph.set() 1902 self.should_continue.wait() 1903 self.should_continue.clear() 1904 g.create_op( 1905 "FloatOutput", [], [dtypes.float32], 1906 name="FloatOutput_{}".format(self._replica_id)) 1907 1908 g = ops.Graph() 1909 # If `switch_to_thread` isn't called, then device placement of the ops 1910 # below is not deterministic. 1911 g.switch_to_thread_local() 1912 threads = [DeviceSettingThread(g, i) for i in range(3)] 1913 for t in threads: 1914 t.start() 1915 t.has_mutated_graph.wait() 1916 t.has_mutated_graph.clear() 1917 for t in threads: 1918 t.should_continue.set() 1919 t.join() 1920 1921 gd = g.as_graph_def() 1922 self.assertProtoEqualsVersion(""" 1923 node { name: "FloatOutput_0" op: "FloatOutput" 1924 device: "/job:worker/replica:0" } 1925 node { name: "FloatOutput_1" op: "FloatOutput" 1926 device: "/job:worker/replica:1" } 1927 node { name: "FloatOutput_2" op: "FloatOutput" 1928 device: "/job:worker/replica:2" } 1929 """, gd) 1930 1931 def testColocateWith(self): 1932 1933 class ColocatingThread(self.TestThread): 1934 1935 def __init__(self, graph, replica_id, op_to_colocate_with): 1936 super(ColocatingThread, self).__init__(graph, replica_id) 1937 self._op_to_colocate_with = op_to_colocate_with 1938 1939 def run(self): 1940 with g.colocate_with(self._op_to_colocate_with): 1941 self.has_mutated_graph.set() 1942 self.should_continue.wait() 1943 self.should_continue.clear() 1944 g.create_op( 1945 "FloatOutput", [], [dtypes.float32], 1946 name="FloatOutput_{}".format(self._replica_id)) 1947 1948 g = ops.Graph() 1949 ops_to_colocate_with = [] 1950 for i in range(3): 1951 with g.device("/job:worker/replica:{}".format(i)): 1952 ops_to_colocate_with.append( 1953 g.create_op( 1954 "FloatOutput", [], [dtypes.float32], 1955 name="ColocateWithMe_{}".format(i))) 1956 1957 # If `switch_to_thread` isn't called, then `device` and `attr` values for 1958 # the ops below are not deterministic. 1959 g.switch_to_thread_local() 1960 threads = [ 1961 ColocatingThread(g, i, ops_to_colocate_with[i]) for i in range(3) 1962 ] 1963 for t in threads: 1964 t.start() 1965 t.has_mutated_graph.wait() 1966 t.has_mutated_graph.clear() 1967 for t in threads: 1968 t.should_continue.set() 1969 t.join() 1970 1971 gd = g.as_graph_def() 1972 self.assertProtoEqualsVersion(""" 1973 node { name: "ColocateWithMe_0" op: "FloatOutput" 1974 device: "/job:worker/replica:0" } 1975 node { name: "ColocateWithMe_1" op: "FloatOutput" 1976 device: "/job:worker/replica:1" } 1977 node { name: "ColocateWithMe_2" op: "FloatOutput" 1978 device: "/job:worker/replica:2" } 1979 node { name: "FloatOutput_0" op: "FloatOutput" 1980 device: "/job:worker/replica:0" 1981 attr { key: "_class" 1982 value { list { 1983 s: "loc:@ColocateWithMe_0"}}}} 1984 node { name: "FloatOutput_1" op: "FloatOutput" 1985 device: "/job:worker/replica:1" 1986 attr { key: "_class" 1987 value { list { 1988 s: "loc:@ColocateWithMe_1"}}}} 1989 node { name: "FloatOutput_2" op: "FloatOutput" 1990 device: "/job:worker/replica:2" 1991 attr { key: "_class" 1992 value { list { 1993 s: "loc:@ColocateWithMe_2"}}}} 1994 """, gd) 1995 1996 def testControlDependencies(self): 1997 1998 class DependingThread(self.TestThread): 1999 2000 def __init__(self, graph, replica_id, dependency_op): 2001 super(DependingThread, self).__init__(graph, replica_id) 2002 self._dependency_op = dependency_op 2003 2004 def run(self): 2005 with g.control_dependencies([self._dependency_op]): 2006 self.has_mutated_graph.set() 2007 self.should_continue.wait() 2008 self.should_continue.clear() 2009 g.create_op( 2010 "FloatOutput", [], [dtypes.float32], 2011 name="FloatOutput_{}".format(self._replica_id)) 2012 2013 g = ops.Graph() 2014 dependency_ops = [] 2015 for i in range(3): 2016 dependency_ops.append( 2017 g.create_op( 2018 "FloatOutput", [], [dtypes.float32], 2019 name="ColocateWithMe_{}".format(i))) 2020 2021 # If `switch_to_thread` isn't called, then `input` values for the ops below 2022 # are not deterministic. 2023 g.switch_to_thread_local() 2024 threads = [DependingThread(g, i, dependency_ops[i]) for i in range(3)] 2025 for t in threads: 2026 t.start() 2027 t.has_mutated_graph.wait() 2028 t.has_mutated_graph.clear() 2029 for t in threads: 2030 t.should_continue.set() 2031 t.join() 2032 2033 gd = g.as_graph_def() 2034 self.assertProtoEqualsVersion(""" 2035 node { name: "ColocateWithMe_0" op: "FloatOutput" } 2036 node { name: "ColocateWithMe_1" op: "FloatOutput" } 2037 node { name: "ColocateWithMe_2" op: "FloatOutput" } 2038 node { name: "FloatOutput_0" op: "FloatOutput" 2039 input: "^ColocateWithMe_0" } 2040 node { name: "FloatOutput_1" op: "FloatOutput" 2041 input: "^ColocateWithMe_1" } 2042 node { name: "FloatOutput_2" op: "FloatOutput" 2043 input: "^ColocateWithMe_2" } 2044 """, gd) 2045 2046 def testNameStack(self): 2047 2048 class NameSettingThread(self.TestThread): 2049 2050 def run(self): 2051 with g.name_scope("foo"): 2052 op1 = g.create_op("FloatOutput", [], [dtypes.float32]) 2053 self.has_mutated_graph.set() 2054 self.should_continue.wait() 2055 self.should_continue.clear() 2056 op2 = g.create_op("FloatOutput", [], [dtypes.float32]) 2057 self.result = (op1, op2) 2058 2059 g = ops.Graph() 2060 threads = [NameSettingThread(g, i) for i in range(3)] 2061 for t in threads: 2062 t.start() 2063 t.has_mutated_graph.wait() 2064 t.has_mutated_graph.clear() 2065 2066 for t in threads: 2067 t.should_continue.set() 2068 t.join() 2069 2070 suffixes = ["", "_1", "_2"] 2071 for t, s in zip(threads, suffixes): 2072 self.assertEqual("foo" + s + "/FloatOutput", t.result[0].name) 2073 self.assertEqual("foo" + s + "/FloatOutput_1", t.result[1].name) 2074 2075 2076class ObjectWithName(object): 2077 2078 def __init__(self, name): 2079 self._name = name 2080 2081 @property 2082 def name(self): 2083 return self._name 2084 2085 2086class CollectionTest(test_util.TensorFlowTestCase): 2087 2088 def test_get_collections(self): 2089 g = ops.Graph() 2090 self.assertSequenceEqual(g.collections, []) 2091 g.add_to_collection("key", 12) 2092 g.add_to_collection("key", 15) 2093 self.assertSequenceEqual(g.collections, ["key"]) 2094 g.add_to_collection("other", "foo") 2095 self.assertSequenceEqual(sorted(g.collections), ["key", "other"]) 2096 self.assertSequenceEqual( 2097 sorted(g.get_all_collection_keys()), ["key", "other"]) 2098 2099 def test_add_to_collection(self): 2100 g = ops.Graph() 2101 g.add_to_collection("key", 12) 2102 g.add_to_collection("other", "foo") 2103 g.add_to_collection("key", 34) 2104 2105 # Note that only blank1 is returned. 2106 g.add_to_collection("blah", 27) 2107 blank1 = ObjectWithName("prefix/foo") 2108 g.add_to_collection("blah", blank1) 2109 blank2 = ObjectWithName("junk/foo") 2110 g.add_to_collection("blah", blank2) 2111 2112 self.assertEqual([12, 34], g.get_collection("key")) 2113 self.assertEqual([], g.get_collection("nothing")) 2114 self.assertEqual([27, blank1, blank2], g.get_collection("blah")) 2115 self.assertEqual([blank1], g.get_collection("blah", "prefix")) 2116 self.assertEqual([blank1], g.get_collection("blah", ".*x")) 2117 2118 # Make sure that get_collection() returns a first-level 2119 # copy of the collection, while get_collection_ref() returns 2120 # the original list. 2121 other_collection_snapshot = g.get_collection("other") 2122 other_collection_ref = g.get_collection_ref("other") 2123 self.assertEqual(["foo"], other_collection_snapshot) 2124 self.assertEqual(["foo"], other_collection_ref) 2125 g.add_to_collection("other", "bar") 2126 self.assertEqual(["foo"], other_collection_snapshot) 2127 self.assertEqual(["foo", "bar"], other_collection_ref) 2128 self.assertEqual(["foo", "bar"], g.get_collection("other")) 2129 self.assertTrue(other_collection_ref is g.get_collection_ref("other")) 2130 2131 # Verify that getting an empty collection ref returns a modifiable list. 2132 empty_coll_ref = g.get_collection_ref("empty") 2133 self.assertEqual([], empty_coll_ref) 2134 empty_coll = g.get_collection("empty") 2135 self.assertEqual([], empty_coll) 2136 self.assertFalse(empty_coll is empty_coll_ref) 2137 empty_coll_ref2 = g.get_collection_ref("empty") 2138 self.assertTrue(empty_coll_ref2 is empty_coll_ref) 2139 # Add to the collection. 2140 empty_coll_ref.append("something") 2141 self.assertEqual(["something"], empty_coll_ref) 2142 self.assertEqual(["something"], empty_coll_ref2) 2143 self.assertEqual([], empty_coll) 2144 self.assertEqual(["something"], g.get_collection("empty")) 2145 empty_coll_ref3 = g.get_collection_ref("empty") 2146 self.assertTrue(empty_coll_ref3 is empty_coll_ref) 2147 2148 def test_add_to_collections_uniquify(self): 2149 g = ops.Graph() 2150 g.add_to_collections([1, 2, 1], "key") 2151 # Make sure "key" is not added twice 2152 self.assertEqual(["key"], g.get_collection(1)) 2153 2154 def test_add_to_collections_from_list(self): 2155 g = ops.Graph() 2156 g.add_to_collections(["abc", "123"], "key") 2157 self.assertEqual(["key"], g.get_collection("abc")) 2158 self.assertEqual(["key"], g.get_collection("123")) 2159 2160 def test_add_to_collections_from_tuple(self): 2161 g = ops.Graph() 2162 g.add_to_collections(("abc", "123"), "key") 2163 self.assertEqual(["key"], g.get_collection("abc")) 2164 self.assertEqual(["key"], g.get_collection("123")) 2165 2166 def test_add_to_collections_from_generator(self): 2167 g = ops.Graph() 2168 2169 def generator(): 2170 yield "abc" 2171 yield "123" 2172 2173 g.add_to_collections(generator(), "key") 2174 self.assertEqual(["key"], g.get_collection("abc")) 2175 self.assertEqual(["key"], g.get_collection("123")) 2176 2177 def test_add_to_collections_from_set(self): 2178 g = ops.Graph() 2179 g.add_to_collections(set(["abc", "123"]), "key") 2180 self.assertEqual(["key"], g.get_collection("abc")) 2181 self.assertEqual(["key"], g.get_collection("123")) 2182 2183 def test_add_to_collections_from_string(self): 2184 g = ops.Graph() 2185 g.add_to_collections("abc", "key") 2186 self.assertEqual(["key"], g.get_collection("abc")) 2187 2188 def test_default_graph(self): 2189 with ops.Graph().as_default(): 2190 ops.add_to_collection("key", 90) 2191 ops.add_to_collection("key", 100) 2192 # Collections are ordered. 2193 self.assertEqual([90, 100], ops.get_collection("key")) 2194 2195 def test_defun(self): 2196 with context.eager_mode(): 2197 2198 @eager_function.defun 2199 def defun(): 2200 ops.add_to_collection("int", 1) 2201 ops.add_to_collection("tensor", constant_op.constant(2)) 2202 2203 @eager_function.defun 2204 def inner_defun(): 2205 self.assertEqual(ops.get_collection("int"), [1]) 2206 three = ops.get_collection("tensor")[0] + ops.get_collection("int")[0] 2207 ops.add_to_collection("int", 2) 2208 self.assertEqual(ops.get_collection("int"), [1, 2]) 2209 ops.add_to_collection("foo", "bar") 2210 self.assertEqual(ops.get_collection("foo"), ["bar"]) 2211 return three 2212 2213 self.assertEqual(ops.get_collection("int"), [1]) 2214 three = inner_defun() 2215 self.assertEqual(ops.get_collection("int"), [1]) 2216 self.assertEqual(ops.get_collection("foo"), []) 2217 return three 2218 2219 three = defun() 2220 self.assertEqual(three.numpy(), 3) 2221 2222 2223ops.NotDifferentiable("FloatOutput") 2224 2225 2226@ops.RegisterGradient("CopyOp") 2227def _CopyGrad(op, x_grad): # pylint: disable=invalid-name 2228 _ = op 2229 return x_grad 2230 2231 2232@ops.RegisterGradient("copy_override") 2233def _CopyOverrideGrad(op, x_grad): # pylint: disable=invalid-name 2234 _ = op 2235 return x_grad 2236 2237 2238class RegistrationTest(test_util.TensorFlowTestCase): 2239 2240 @test_util.run_deprecated_v1 2241 def testRegisterGradients(self): 2242 x = test_ops.float_output() 2243 y = test_ops.copy_op(x) 2244 fn = ops.get_gradient_function(y.op) 2245 self.assertEqual(_CopyGrad, fn) 2246 2247 def testOverrideGradients(self): 2248 g = ops.Graph() 2249 with g.as_default(): 2250 x = test_ops.float_output() 2251 with g.gradient_override_map({"CopyOp": "copy_override"}): 2252 y = test_ops.copy_op(x) 2253 fn = ops.get_gradient_function(y.op) 2254 self.assertEqual(_CopyOverrideGrad, fn) 2255 2256 def testNonExistentOverride(self): 2257 g = ops.Graph() 2258 with g.as_default(): 2259 x = test_ops.float_output() 2260 with g.gradient_override_map({"CopyOp": "unknown_override"}): 2261 y = test_ops.copy_op(x) 2262 with self.assertRaisesRegex(LookupError, "unknown_override"): 2263 ops.get_gradient_function(y.op) 2264 2265 2266class ComparisonTest(test_util.TensorFlowTestCase): 2267 2268 def testMembershipAllowed(self): 2269 g = ops.Graph() 2270 t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1") 2271 t2 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop2") 2272 self.assertTrue(isinstance(t1, ops.Tensor)) 2273 self.assertTrue(isinstance(t2, ops.Tensor)) 2274 self.assertTrue(t1 in [t1]) 2275 self.assertTrue(t1 not in [t2]) 2276 2277 2278class ControlDependenciesTest(test_util.TensorFlowTestCase): 2279 2280 @test_util.run_deprecated_v1 2281 def testBasic(self): 2282 g = ops.Graph() 2283 with g.as_default(): 2284 # Creating unregistered ops with _apply_op() doesn't work with the C API 2285 # TODO(skyewm): address this more consistently. Possible solutions are 2286 # to use registered ops in all tests, create a way to register ops in 2287 # Python tests, or conditionally disable the op registration check in 2288 # the C API. 2289 a = constant_op.constant(1.0) 2290 b = constant_op.constant(1.0) 2291 with g.control_dependencies([a]): 2292 c = constant_op.constant(1.0) 2293 d = array_ops.identity(b) 2294 e = array_ops.identity(c) 2295 2296 self.assertEqual(c.op.control_inputs, [a.op]) 2297 self.assertEqual(d.op.control_inputs, [a.op]) 2298 # e should be dominated by c. 2299 self.assertEqual(e.op.control_inputs, []) 2300 2301 @test_util.run_in_graph_and_eager_modes 2302 def testEager(self): 2303 def future(): 2304 future.calls += 1 2305 return constant_op.constant(2.0) 2306 future.calls = 0 2307 2308 if context.executing_eagerly(): 2309 a = constant_op.constant(1.0) 2310 b = future 2311 with ops.control_dependencies([a, b]): 2312 c = constant_op.constant(3.0) 2313 self.assertEqual(future.calls, 1) 2314 else: 2315 g = ops.Graph() 2316 with g.as_default(): 2317 a = constant_op.constant(1.0) 2318 b = future() 2319 with g.control_dependencies([a, b]): 2320 c = constant_op.constant(3.0) 2321 self.assertEqual(c.op.control_inputs, [a.op, b.op]) 2322 self.assertEqual(future.calls, 1) 2323 2324 def testBasicWithConversion(self): 2325 g = ops.Graph() 2326 a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2327 2328 class ConvertibleObj(object): 2329 2330 def _as_graph_element(self): 2331 return a 2332 2333 with g.control_dependencies([ConvertibleObj()]): 2334 c = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2335 2336 self.assertEqual(c.op.control_inputs, [a.op]) 2337 2338 def testNested(self): 2339 g = ops.Graph() 2340 a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2341 a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2342 a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2343 a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2344 2345 with g.control_dependencies([a_1, a_2, a_3, a_4]): 2346 b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2347 2348 with g.control_dependencies([a_1]): 2349 with g.control_dependencies([a_2]): 2350 with g.control_dependencies([a_3]): 2351 with g.control_dependencies([a_4]): 2352 b_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2353 2354 self.assertItemsEqual([a_1.op, a_2.op, a_3.op, a_4.op], 2355 b_1.op.control_inputs) 2356 self.assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs) 2357 2358 def testClear(self): 2359 g = ops.Graph() 2360 a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2361 a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2362 a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2363 a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2364 2365 with g.control_dependencies([a_1]): 2366 with g.control_dependencies([a_2]): 2367 with g.control_dependencies(None): 2368 with g.control_dependencies([a_3]): 2369 with g.control_dependencies([a_4]): 2370 # deps [a_3, a_4] 2371 b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2372 # deps = [a_3] 2373 b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2374 # deps back to None 2375 b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2376 # deps back to [a_1, a_2] 2377 b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2378 # deps back to [a_1] 2379 b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2380 with g.control_dependencies(None): 2381 # deps are None again 2382 b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2383 2384 self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs) 2385 self.assertItemsEqual([a_3.op], b_3.op.control_inputs) 2386 self.assertItemsEqual([], b_none.op.control_inputs) 2387 self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs) 2388 self.assertItemsEqual([a_1.op], b_1.op.control_inputs) 2389 self.assertItemsEqual([], b_none2.op.control_inputs) 2390 2391 def testComplex(self): 2392 g = ops.Graph() 2393 2394 # Usage pattern: 2395 # * Nodes a_i are constants defined at the outermost scope, and are used 2396 # as control inputs for the ith nested scope. 2397 # * Nodes b_i are defined as Mul(a_3, a_4) at each scope. 2398 # * Nodes c_i are defined as Mul(a_1, b_1) at each scope. 2399 # * Nodes d_i are defined as Mul(b_i, c_i) at each scope. 2400 # * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1. 2401 2402 a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2403 a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2404 a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2405 a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2406 2407 with g.control_dependencies([a_1]): 2408 b_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], 2409 [dtypes.float32]) 2410 c_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], 2411 [dtypes.float32]) 2412 d_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_1, c_1], 2413 [dtypes.float32]) 2414 e_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2415 with g.control_dependencies([a_2]): 2416 b_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], 2417 [dtypes.float32]) 2418 c_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], 2419 [dtypes.float32]) 2420 d_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_2, c_2], 2421 [dtypes.float32]) 2422 e_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_1, e_1], 2423 [dtypes.float32]) 2424 with g.control_dependencies([a_3]): 2425 b_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], 2426 [dtypes.float32]) 2427 c_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], 2428 [dtypes.float32]) 2429 d_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_3, c_3], 2430 [dtypes.float32]) 2431 e_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_2, e_2], 2432 [dtypes.float32]) 2433 with g.control_dependencies([a_4]): 2434 b_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], 2435 [dtypes.float32]) 2436 c_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], 2437 [dtypes.float32]) 2438 d_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_4, c_4], 2439 [dtypes.float32]) 2440 e_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_3, e_3], 2441 [dtypes.float32]) 2442 2443 self.assertItemsEqual([a_1.op], b_1.op.control_inputs) 2444 self.assertItemsEqual([a_1.op, a_2.op], b_2.op.control_inputs) 2445 self.assertItemsEqual([a_1.op, a_2.op], b_3.op.control_inputs) 2446 self.assertItemsEqual([a_1.op, a_2.op], b_4.op.control_inputs) 2447 2448 self.assertItemsEqual([], c_1.op.control_inputs) 2449 self.assertItemsEqual([a_2.op], c_2.op.control_inputs) 2450 self.assertItemsEqual([a_2.op, a_3.op], c_3.op.control_inputs) 2451 self.assertItemsEqual([a_2.op, a_3.op, a_4.op], c_4.op.control_inputs) 2452 2453 self.assertItemsEqual([], d_1.op.control_inputs) 2454 self.assertItemsEqual([], d_2.op.control_inputs) 2455 self.assertItemsEqual([], d_3.op.control_inputs) 2456 self.assertItemsEqual([], d_4.op.control_inputs) 2457 2458 self.assertItemsEqual([a_1.op], e_1.op.control_inputs) 2459 self.assertItemsEqual([a_2.op], e_2.op.control_inputs) 2460 self.assertItemsEqual([a_3.op], e_3.op.control_inputs) 2461 self.assertItemsEqual([a_4.op], e_4.op.control_inputs) 2462 2463 def testRepeatedDependency(self): 2464 g = ops.Graph() 2465 a = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32]) 2466 a_0, a_1 = a.outputs 2467 with g.control_dependencies([a_0]): 2468 b = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2469 with g.control_dependencies([a_1]): 2470 c = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2471 2472 self.assertEqual(b.op.control_inputs, [a]) 2473 self.assertEqual(c.op.control_inputs, [a]) 2474 2475 def testNoControlDependencyWithDataDependency(self): 2476 g = ops.Graph() 2477 a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2478 with g.control_dependencies([a]): 2479 b = _apply_op(g, "Identity", [a], [dtypes.float32]) 2480 2481 self.assertEqual(b.op.control_inputs, []) 2482 2483 2484class OpScopeTest(test_util.TensorFlowTestCase): 2485 2486 @test_util.run_in_graph_and_eager_modes 2487 def testNames(self): 2488 with ops.name_scope("foo", skip_on_eager=False) as foo: 2489 self.assertEqual("foo/", foo) 2490 with ops.name_scope("foo2", skip_on_eager=False) as foo2: 2491 self.assertEqual("foo/foo2/", foo2) 2492 with ops.name_scope(None, skip_on_eager=False) as empty1: 2493 self.assertEqual("", empty1) 2494 with ops.name_scope("foo3", skip_on_eager=False) as foo3: 2495 self.assertEqual("foo3/", foo3) 2496 with ops.name_scope("", skip_on_eager=False) as empty2: 2497 self.assertEqual("", empty2) 2498 with ops.name_scope("foo/", skip_on_eager=False) as outer_foo: 2499 self.assertEqual("foo/", outer_foo) 2500 with ops.name_scope("", skip_on_eager=False) as empty3: 2501 self.assertEqual("", empty3) 2502 with ops.name_scope("foo4", skip_on_eager=False) as foo4: 2503 self.assertEqual("foo/foo4/", foo4) 2504 with ops.name_scope("foo5//", skip_on_eager=False) as foo5: 2505 self.assertEqual("foo5//", foo5) 2506 with ops.name_scope("foo6", skip_on_eager=False) as foo6: 2507 self.assertEqual("foo5//foo6/", foo6) 2508 with ops.name_scope("/", skip_on_eager=False) as foo7: 2509 self.assertEqual("/", foo7) 2510 with ops.name_scope("//", skip_on_eager=False) as foo8: 2511 self.assertEqual("//", foo8) 2512 with ops.name_scope("a//b/c", skip_on_eager=False) as foo9: 2513 self.assertEqual("foo/a//b/c/", foo9) 2514 with ops.name_scope("a//b/c", skip_on_eager=False) as foo10: 2515 self.assertEqual("a//b/c/", foo10) 2516 2517 @test_util.run_in_graph_and_eager_modes 2518 def testEagerDefaultScopeName(self): 2519 with ops.name_scope(None, "default", skip_on_eager=False) as scope: 2520 self.assertEqual(scope, "default/") 2521 with ops.name_scope(None, "default2", skip_on_eager=False) as scope2: 2522 self.assertEqual(scope2, "default/default2/") 2523 2524 @test_util.run_in_graph_and_eager_modes 2525 def testNameScopeV2IsReEntrant(self): 2526 foo = ops.name_scope_v2("foo") 2527 bar = ops.name_scope_v2("bar") 2528 with foo as scope_name: 2529 self.assertEqual("foo/", scope_name) 2530 with foo as scope_name: 2531 self.assertEqual("foo/foo/", scope_name) 2532 with bar as scope_name: 2533 self.assertEqual("foo/bar/", scope_name) 2534 with foo as scope_name: 2535 self.assertEqual("foo/bar/foo/", scope_name) 2536 with bar as scope_name: 2537 self.assertEqual("bar/", scope_name) 2538 2539 @test_util.run_deprecated_v1 2540 def testNoScopeName(self): 2541 g0 = ops.Graph() 2542 values = [ 2543 g0.create_op("A", [], [dtypes.float32]), 2544 g0.create_op("B", [], [dtypes.float32]) 2545 ] 2546 with self.assertRaises(ValueError): 2547 with ops.name_scope(None, values=values): 2548 pass 2549 with self.assertRaises(ValueError): 2550 with ops.name_scope(None, None, values): 2551 pass 2552 2553 @test_util.run_deprecated_v1 2554 def testEmptyScopeName(self): 2555 g0 = ops.Graph() 2556 a = g0.create_op("A", [], [dtypes.float32]) 2557 b = g0.create_op("B", [], [dtypes.float32]) 2558 with ops.name_scope("", values=[a, b]) as scope: 2559 self.assertEqual("", scope) 2560 self.assertEqual(g0, ops.get_default_graph()) 2561 with ops.name_scope("", "my_default_scope", [a, b]) as scope: 2562 self.assertEqual("", scope) 2563 self.assertEqual(g0, ops.get_default_graph()) 2564 2565 @test_util.run_deprecated_v1 2566 def testDefaultScopeName(self): 2567 g0 = ops.Graph() 2568 a = g0.create_op("A", [], [dtypes.float32]) 2569 b = g0.create_op("B", [], [dtypes.float32]) 2570 scope_name = "my_scope" 2571 default_scope_name = "my_default_scope" 2572 with ops.name_scope(scope_name, default_scope_name, [a, b]) as scope: 2573 self.assertEqual("%s/" % scope_name, scope) 2574 self.assertEqual(g0, ops.get_default_graph()) 2575 with ops.name_scope(None, default_scope_name, [a, b]) as scope: 2576 self.assertEqual("%s/" % default_scope_name, scope) 2577 self.assertEqual(g0, ops.get_default_graph()) 2578 with self.assertRaises(TypeError): 2579 with ops.name_scope(scope_name, [a, b]): 2580 pass 2581 2582 def _testGraphElements(self, graph_elements): 2583 scope_name = "my_scope" 2584 with ops.name_scope(scope_name, values=graph_elements) as scope: 2585 self.assertEqual("%s/" % scope_name, scope) 2586 self.assertEqual(graph_elements[0].graph, ops.get_default_graph()) 2587 g1 = ops.Graph() 2588 a = g1.create_op("A", [], [dtypes.float32]) 2589 with self.assertRaises(ValueError): 2590 with ops.name_scope(scope_name, values=graph_elements + [a]): 2591 pass 2592 2593 @test_util.run_deprecated_v1 2594 def testTensor(self): 2595 g0 = ops.Graph() 2596 a = g0.create_op("A", [], [dtypes.float32]) 2597 b = g0.create_op("B", [], [dtypes.float32]) 2598 self._testGraphElements([a, b]) 2599 2600 @test_util.run_deprecated_v1 2601 def testSparseTensor(self): 2602 g0 = ops.Graph() 2603 a = g0.create_op("A", [], [dtypes.float32]) 2604 b = g0.create_op("B", [], [dtypes.float32]) 2605 sparse = sparse_tensor.SparseTensor( 2606 _apply_op(g0, "Int64Output", [], [dtypes.int64]), 2607 _apply_op(g0, "FloatOutput", [], [dtypes.float32]), 2608 _apply_op(g0, "Int64Output", [], [dtypes.int64])) 2609 self._testGraphElements([a, sparse, b]) 2610 2611 @test_util.run_deprecated_v1 2612 def testVariable(self): 2613 g0 = ops.Graph() 2614 with g0.as_default(): 2615 variable = variables.Variable([1.0]) 2616 a = g0.create_op("A", [], [dtypes.float32]) 2617 b = g0.create_op("B", [], [dtypes.float32]) 2618 self._testGraphElements([a, variable, b]) 2619 2620 2621class InitScopeTest(test_util.TensorFlowTestCase): 2622 2623 def testClearsControlDependencies(self): 2624 g = ops.Graph() 2625 a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2626 a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2627 a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2628 a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2629 2630 with g.as_default(): 2631 with g.control_dependencies([a_1]): 2632 with g.control_dependencies([a_2]): 2633 with ops.init_scope(): 2634 with g.control_dependencies([a_3]): 2635 with g.control_dependencies([a_4]): 2636 # deps [a_3, a_4] 2637 b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2638 # deps = [a_3] 2639 b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2640 # deps back to None 2641 b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2642 # deps back to [a_1, a_2] 2643 b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2644 # deps back to [a_1] 2645 b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2646 with ops.init_scope(): 2647 # deps are None again 2648 b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2649 2650 self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs) 2651 self.assertItemsEqual([a_3.op], b_3.op.control_inputs) 2652 self.assertItemsEqual([], b_none.op.control_inputs) 2653 self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs) 2654 self.assertItemsEqual([a_1.op], b_1.op.control_inputs) 2655 self.assertItemsEqual([], b_none2.op.control_inputs) 2656 2657 def testLiftsOpsFromFunctions(self): 2658 g0 = ops.Graph() 2659 g1 = ops.Graph() 2660 g1._building_function = True # pylint: disable=protected-access 2661 g2 = ops.Graph() 2662 g2._building_function = True # pylint: disable=protected-access 2663 2664 with g0.as_default(): 2665 with g1.as_default(): 2666 with g2.as_default(): 2667 with ops.init_scope(): 2668 _ = constant_op.constant(1.0) 2669 2670 self.assertEqual(len(g2.get_operations()), 0) 2671 self.assertEqual(len(g1.get_operations()), 0) 2672 self.assertEqual(len(g0.get_operations()), 1) 2673 2674 def testPreservesDevices(self): 2675 g0 = ops.Graph() 2676 with g0.as_default(), ops.device("CPU:0"): 2677 g1 = ops.Graph() 2678 g1._building_function = True # pylint: disable=protected-access 2679 with g1.as_default(): 2680 with ops.device("GPU:0"): 2681 with ops.init_scope(): 2682 # init_scope should preserve device set under `g1`. 2683 on_gpu = constant_op.constant(1.0) 2684 self.assertEqual(on_gpu.device, "/device:GPU:0") 2685 still_on_gpu = constant_op.constant(1.0) 2686 self.assertEqual(still_on_gpu.device, "/device:GPU:0") 2687 blank = constant_op.constant(1.0) 2688 self.assertEqual(blank.device, "") 2689 with ops.init_scope(): 2690 now_on_cpu = constant_op.constant(1.0) 2691 self.assertEqual(now_on_cpu.device, "/device:CPU:0") 2692 on_cpu = constant_op.constant(1.0) 2693 self.assertEqual(on_cpu.device, "/device:CPU:0") 2694 2695 def testComposes(self): 2696 g0 = ops.Graph() 2697 g1 = ops.Graph() 2698 g1._building_function = True # pylint: disable=protected-access 2699 g2 = ops.Graph() 2700 g2._building_function = True # pylint: disable=protected-access 2701 g3 = ops.Graph() 2702 g3._building_function = False # pylint: disable=protected-access 2703 2704 with g0.as_default(): 2705 with g1.as_default(): 2706 with ops.init_scope(): 2707 # This op should be lifted into g0. 2708 _ = constant_op.constant(1.0) 2709 self.assertIs(g0, ops.get_default_graph()) 2710 self.assertEqual(len(g2.get_operations()), 0) 2711 self.assertEqual(len(g1.get_operations()), 0) 2712 self.assertEqual(len(g0.get_operations()), 1) 2713 with g2.as_default(): 2714 with ops.init_scope(): 2715 # This op should be lifted into g0. 2716 _ = constant_op.constant(1.0) 2717 self.assertIs(g0, ops.get_default_graph()) 2718 with g3.as_default(): 2719 with ops.init_scope(): 2720 # This op should be lifted into g3, because g3 is not building a 2721 # function. 2722 _ = constant_op.constant(1.0) 2723 self.assertIs(g3, ops.get_default_graph()) 2724 2725 self.assertEqual(len(g3.get_operations()), 1) 2726 self.assertEqual(len(g2.get_operations()), 0) 2727 self.assertEqual(len(g1.get_operations()), 0) 2728 self.assertEqual(len(g0.get_operations()), 2) 2729 2730 def testEscapesToEagerContext(self): 2731 g = ops.Graph() 2732 g._building_function = True # pylint: disable=protected-access 2733 with context.eager_mode(): 2734 with context.graph_mode(): 2735 with g.as_default(): 2736 with ops.init_scope(): 2737 # Because g is building a function, init_scope should 2738 # escape out to the eager context. 2739 self.assertTrue(context.executing_eagerly()) 2740 # g should be reinstated as the default graph, and the 2741 # graph context should be re-entered. 2742 self.assertIs(g, ops.get_default_graph()) 2743 self.assertFalse(context.executing_eagerly()) 2744 2745 def testStaysInEagerWhenOnlyEagerContextActive(self): 2746 with context.eager_mode(): 2747 with ops.init_scope(): 2748 self.assertTrue(context.eager_mode()) 2749 self.assertTrue(context.eager_mode()) 2750 2751 def testEscapesDefunWhenInEagerMode(self): 2752 2753 def function_with_variables(): 2754 with ops.init_scope(): 2755 self.v = resource_variable_ops.ResourceVariable(3) 2756 return self.v.assign_add(1) 2757 2758 with context.eager_mode(): 2759 # Each invocation of function_with_variables recreates a variable. 2760 self.assertEqual(4, int(function_with_variables())) 2761 self.assertEqual(4, int(function_with_variables())) 2762 2763 compiled = eager_function.defun(function_with_variables) 2764 # The init_scope in function_with_variables lifts the variable out 2765 # of the graph function constructed by defun; hence, 2766 # compiled now appears to be stateful. 2767 self.assertEqual(4, int(compiled())) 2768 self.assertEqual(5, int(compiled())) 2769 2770 def testEscapesDefunWhenInGraphMode(self): 2771 def function_with_variables(name): 2772 with ops.init_scope(): 2773 _ = variable_scope.get_variable(name, shape=(1,)) 2774 2775 g = ops.Graph() 2776 with g.as_default(): 2777 with self.cached_session(): 2778 # First ensure that graphs that are not building functions are 2779 # not escaped. 2780 function_with_variables("foo") 2781 with self.assertRaisesRegex(ValueError, 2782 r"Variable foo already exists.*"): 2783 # This will fail because reuse is not set to True. 2784 function_with_variables("foo") 2785 2786 compiled = eager_function.defun(function_with_variables) 2787 compiled("bar") 2788 self.assertEqual( 2789 len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2) 2790 2791 # The second call to `compiled` should not create variables: the 2792 # init_scope has lifted the variable creation code out of the defun. 2793 compiled("bar") 2794 self.assertEqual( 2795 len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2) 2796 2797 def testEscapesNestedDefun(self): 2798 2799 def inner_function(): 2800 with ops.init_scope(): 2801 self.v = resource_variable_ops.ResourceVariable(1) 2802 return self.v.assign_add(2) 2803 2804 def outer_function(inner=None): 2805 with ops.init_scope(): 2806 self.v0 = resource_variable_ops.ResourceVariable(0) 2807 return self.v0.assign_add(1) + inner() 2808 2809 with context.eager_mode(): 2810 # Each invocation of outer_function recreates variables. 2811 self.assertEqual(4, int(outer_function(inner=inner_function))) 2812 self.assertEqual(4, int(outer_function(inner=inner_function))) 2813 2814 compiled_inner = eager_function.defun(inner_function) 2815 compiled_outer = eager_function.defun(outer_function) 2816 # The init_scope lifts variables out of the graph functions 2817 # constructed by defun; hence, compiled_outer should now appear to be 2818 # stateful. 2819 self.assertEqual(4, int(compiled_outer(inner=compiled_inner))) 2820 self.assertEqual(7, int(compiled_outer(inner=compiled_inner))) 2821 2822 @test_util.run_v1_only("b/120545219") 2823 def testFallsBackToGlobalGraphWhenAllGraphsAreBuildingFunctions(self): 2824 with context.graph_mode(): 2825 ops.reset_default_graph() 2826 # This doesn't push anything onto the graph stack, but it does 2827 # set the stack's global graph. 2828 global_graph = ops.get_default_graph() 2829 fn_graph = ops.Graph() 2830 2831 # pylint: disable=protected-access 2832 fn_graph._building_function = True 2833 self.assertEqual(len(ops._default_graph_stack.stack), 0) 2834 with fn_graph.as_default(): 2835 self.assertEqual(len(ops._default_graph_stack.stack), 1) 2836 with ops.init_scope(): 2837 self.assertGreater(len(ops._default_graph_stack.stack), 1) 2838 dummy = constant_op.constant(1.0) 2839 self.assertEqual(len(ops._default_graph_stack.stack), 1) 2840 # Note that the global graph is _not_ on the graph stack. 2841 self.assertEqual(len(ops._default_graph_stack.stack), 0) 2842 # Ensure that `dummy` was added to the global graph. 2843 self.assertEqual(global_graph, dummy.graph) 2844 # pylint: enable=protected-access 2845 2846 def testInstallsDefaultGraphWhenGraphStackIsEmptyInGraphMode(self): 2847 with context.graph_mode(): 2848 # pylint: disable=protected-access 2849 self.assertEqual(len(ops._default_graph_stack.stack), 0) 2850 with ops.init_scope(): 2851 self.assertGreater(len(ops._default_graph_stack.stack), 0) 2852 self.assertEqual(len(ops._default_graph_stack.stack), 0) 2853 # pylint: enable=protected-access 2854 2855 def testPreservesNameScopeInGraphConstruction(self): 2856 with ops.Graph().as_default(): 2857 function_graph = ops.Graph() 2858 with function_graph.as_default(): 2859 with ops.name_scope("inner", skip_on_eager=False), ops.init_scope(): 2860 self.assertEqual(ops.get_name_scope(), "inner") 2861 self.assertEqual(ops.get_name_scope(), "") 2862 2863 def testEnteringGraphFromEagerIsSticky(self): 2864 with context.eager_mode(): 2865 g = ops.Graph() 2866 with g.as_default(): 2867 with ops.init_scope(): 2868 self.assertFalse(context.executing_eagerly()) 2869 self.assertEqual(g, ops.get_default_graph()) 2870 2871 def testMixGraphEager(self): 2872 with context.eager_mode(): 2873 c = constant_op.constant(1.0) 2874 with ops.Graph().as_default(): 2875 with self.assertRaisesRegex(RuntimeError, 2876 "Attempting to capture an EagerTensor"): 2877 math_ops.add(c, c) 2878 c2 = constant_op.constant(2.0) 2879 with self.assertRaisesRegex(TypeError, "Graph tensors"): 2880 math_ops.add(c2, c2) 2881 2882 def testPreservesNameScopeInEagerExecution(self): 2883 with context.eager_mode(): 2884 def foo(): 2885 with ops.name_scope("inner", skip_on_eager=False), ops.init_scope(): 2886 if context.executing_eagerly(): 2887 # A trailing slash is always appended when eager execution is 2888 # enabled. 2889 self.assertEqual(context.context().scope_name, "inner/") 2890 else: 2891 self.assertEqual(ops.get_name_scope(), "inner") 2892 2893 foo() 2894 self.assertEqual(ops.get_name_scope(), "") 2895 foo_compiled = eager_function.defun(foo) 2896 foo_compiled() 2897 self.assertEqual(ops.get_name_scope(), "") 2898 2899 def testExecutingEagerlyOutsideFunctions(self): 2900 2901 @def_function.function 2902 def f(): 2903 return ops.executing_eagerly_outside_functions() 2904 2905 with context.graph_mode(): 2906 self.assertFalse(ops.executing_eagerly_outside_functions()) 2907 with session.Session(): 2908 # Need self.evaluate for these as the return type of functions is 2909 # tensors. 2910 self.assertFalse(self.evaluate(f())) 2911 2912 with context.eager_mode(): 2913 self.assertTrue(ops.executing_eagerly_outside_functions()) 2914 self.assertTrue(f()) 2915 2916 with ops.Graph().as_default(): 2917 self.assertFalse(ops.executing_eagerly_outside_functions()) 2918 with session.Session(): 2919 self.assertFalse(self.evaluate(f())) 2920 2921 2922class GraphTest(test_util.TensorFlowTestCase): 2923 2924 def setUp(self): 2925 ops.reset_default_graph() 2926 2927 def _AssertDefault(self, expected): 2928 self.assertIs(expected, ops.get_default_graph()) 2929 2930 def testResetDefaultGraphNesting(self): 2931 g0 = ops.Graph() 2932 with self.assertRaises(AssertionError): 2933 with g0.as_default(): 2934 ops.reset_default_graph() 2935 2936 def testGraphContextManagerCancelsEager(self): 2937 with context.eager_mode(): 2938 with ops.Graph().as_default(): 2939 self.assertFalse(context.executing_eagerly()) 2940 2941 def testGraphContextManager(self): 2942 g0 = ops.Graph() 2943 with g0.as_default() as g1: 2944 self.assertIs(g0, g1) 2945 2946 def testDefaultGraph(self): 2947 orig = ops.get_default_graph() 2948 self.assertFalse(ops.has_default_graph()) 2949 self._AssertDefault(orig) 2950 g0 = ops.Graph() 2951 self.assertFalse(ops.has_default_graph()) 2952 self._AssertDefault(orig) 2953 context_manager_0 = g0.as_default() 2954 self.assertFalse(ops.has_default_graph()) 2955 self._AssertDefault(orig) 2956 with context_manager_0 as g0: 2957 self._AssertDefault(g0) 2958 with ops.Graph().as_default() as g1: 2959 self.assertTrue(ops.has_default_graph()) 2960 self._AssertDefault(g1) 2961 self._AssertDefault(g0) 2962 self._AssertDefault(orig) 2963 self.assertFalse(ops.has_default_graph()) 2964 2965 def testPreventFeeding(self): 2966 g = ops.Graph() 2967 a = constant_op.constant(2.0) 2968 self.assertTrue(g.is_feedable(a)) 2969 g.prevent_feeding(a) 2970 self.assertFalse(g.is_feedable(a)) 2971 2972 @test_util.run_deprecated_v1 2973 def testPreventFetching(self): 2974 g = ops.Graph() 2975 a = constant_op.constant(2.0) 2976 self.assertTrue(g.is_fetchable(a)) 2977 g.prevent_fetching(a.op) 2978 self.assertFalse(g.is_fetchable(a)) 2979 2980 def testAsGraphElementConversions(self): 2981 2982 class ConvertibleObj(object): 2983 2984 def _as_graph_element(self): 2985 return "FloatOutput:0" 2986 2987 class NonConvertibleObj(object): 2988 2989 pass 2990 2991 g = ops.Graph() 2992 a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2993 self.assertEqual(a, g.as_graph_element(ConvertibleObj())) 2994 with self.assertRaises(TypeError): 2995 g.as_graph_element(NonConvertibleObj()) 2996 2997 # Regression test against creating custom __del__ functions in classes 2998 # involved in cyclic references, e.g. Graph and Operation. (Python won't gc 2999 # cycles that require calling a __del__ method, because the __del__ method can 3000 # theoretically increase the object's refcount to "save" it from gc, and any 3001 # already-deleted objects in the cycle would have be to restored.) 3002 def testGarbageCollected(self): 3003 # Create a graph we can delete and a weak reference to monitor if it's gc'd 3004 g = ops.Graph() 3005 g_ref = weakref.ref(g) 3006 # Create some ops 3007 with g.as_default(): 3008 a = constant_op.constant(2.0) 3009 b = constant_op.constant(3.0) 3010 c = math_ops.add(a, b) 3011 # Create a session we can delete 3012 with session.Session(graph=g) as sess: 3013 self.evaluate(c) 3014 # Delete all references and trigger gc 3015 del g 3016 del a 3017 del b 3018 del c 3019 del sess 3020 gc.collect() 3021 self.assertIsNone(g_ref()) 3022 3023 def testRunnableAfterInvalidShape(self): 3024 with ops.Graph().as_default(): 3025 with self.assertRaises(ValueError): 3026 math_ops.add([1, 2], [1, 2, 3]) 3027 a = constant_op.constant(1) 3028 with session.Session() as sess: 3029 self.evaluate(a) 3030 3031 def testRunnableAfterInvalidShapeWithKernelLabelMap(self): 3032 g = ops.Graph() 3033 with g.as_default(): 3034 with g._kernel_label_map({"KernelLabelRequired": "overload_1"}): 3035 with self.assertRaises(ValueError): 3036 test_ops.kernel_label_required(1) 3037 a = constant_op.constant(1) 3038 with session.Session() as sess: 3039 self.evaluate(a) 3040 3041 3042class AttrScopeTest(test_util.TensorFlowTestCase): 3043 3044 def _get_test_attrs(self): 3045 x = control_flow_ops.no_op() 3046 try: 3047 a = compat.as_text(x.get_attr("_A")) 3048 except ValueError: 3049 a = None 3050 try: 3051 b = compat.as_text(x.get_attr("_B")) 3052 except ValueError: 3053 b = None 3054 return (a, b) 3055 3056 @test_util.run_deprecated_v1 3057 def testNoLabel(self): 3058 with self.cached_session(): 3059 self.assertAllEqual((None, None), self._get_test_attrs()) 3060 3061 @test_util.run_deprecated_v1 3062 def testLabelMap(self): 3063 with self.cached_session() as sess: 3064 a1 = self._get_test_attrs() 3065 with sess.graph._attr_scope({ 3066 "_A": attr_value_pb2.AttrValue(s=compat.as_bytes("foo")) 3067 }): 3068 a2 = self._get_test_attrs() 3069 with sess.graph._attr_scope({ 3070 "_A": None, 3071 "_B": attr_value_pb2.AttrValue(s=compat.as_bytes("bar")) 3072 }): 3073 a3 = self._get_test_attrs() 3074 with sess.graph._attr_scope({ 3075 "_A": attr_value_pb2.AttrValue(s=compat.as_bytes("baz")) 3076 }): 3077 a4 = self._get_test_attrs() 3078 a5 = self._get_test_attrs() 3079 a6 = self._get_test_attrs() 3080 a7 = self._get_test_attrs() 3081 3082 self.assertAllEqual((None, None), a1) 3083 self.assertAllEqual(("foo", None), a2) 3084 self.assertAllEqual((None, "bar"), a3) 3085 self.assertAllEqual(("baz", "bar"), a4) 3086 self.assertAllEqual((None, "bar"), a5) 3087 self.assertAllEqual(("foo", None), a6) 3088 self.assertAllEqual((None, None), a7) 3089 3090 3091class KernelLabelTest(test_util.TensorFlowTestCase): 3092 3093 @test_util.run_deprecated_v1 3094 def testNoLabel(self): 3095 with self.cached_session(): 3096 self.assertAllEqual(b"My label is: default", 3097 test_ops.kernel_label().eval()) 3098 3099 @test_util.run_deprecated_v1 3100 def testLabelMap(self): 3101 with self.cached_session() as sess: 3102 default_1 = test_ops.kernel_label() 3103 # pylint: disable=protected-access 3104 with sess.graph._kernel_label_map({"KernelLabel": "overload_1"}): 3105 overload_1_1 = test_ops.kernel_label() 3106 with sess.graph._kernel_label_map({"KernelLabel": "overload_2"}): 3107 overload_2 = test_ops.kernel_label() 3108 with sess.graph._kernel_label_map({"KernelLabel": ""}): 3109 default_2 = test_ops.kernel_label() 3110 overload_1_2 = test_ops.kernel_label() 3111 # pylint: enable=protected-access 3112 default_3 = test_ops.kernel_label() 3113 3114 self.assertAllEqual(b"My label is: default", self.evaluate(default_1)) 3115 self.assertAllEqual(b"My label is: default", self.evaluate(default_2)) 3116 self.assertAllEqual(b"My label is: default", self.evaluate(default_3)) 3117 self.assertAllEqual(b"My label is: overload_1", 3118 self.evaluate(overload_1_1)) 3119 self.assertAllEqual(b"My label is: overload_1", 3120 self.evaluate(overload_1_2)) 3121 self.assertAllEqual(b"My label is: overload_2", self.evaluate(overload_2)) 3122 3123 3124class AsGraphDefTest(test_util.TensorFlowTestCase): 3125 3126 def testGraphDefVersion(self): 3127 """Test that the graphdef version is plumbed through to kernels.""" 3128 with ops.Graph().as_default() as g: 3129 version = g.graph_def_versions.producer 3130 with self.session(graph=g): 3131 v = test_ops.graph_def_version().eval() 3132 self.assertEqual(version, v) 3133 3134 def testAddShapes(self): 3135 with ops.Graph().as_default() as g: 3136 t1, t2, t3, t4, t5 = _apply_op(g, "FiveFloatOutputs", [], 3137 [dtypes.float32] * 5) 3138 t1.set_shape(None) 3139 t2.set_shape([]) 3140 t3.set_shape([None]) 3141 t4.set_shape([43, 37]) 3142 t5.set_shape([43, None]) 3143 3144 b = constant_op.constant(1.0) # pylint: disable=unused-variable 3145 3146 gd = g.as_graph_def(add_shapes=True) 3147 self.assertProtoEqualsVersion(""" 3148 node { name: "FiveFloatOutputs" op: "FiveFloatOutputs" 3149 attr { 3150 key: "_output_shapes" 3151 value { 3152 list { 3153 shape { unknown_rank: true } 3154 shape { } 3155 shape { dim { size: -1 } } 3156 shape { dim { size: 43 } dim { size: 37 } } 3157 shape { dim { size: 43 } dim { size: -1 } } 3158 } 3159 } 3160 } 3161 } 3162 node { name: "Const" op: "Const" 3163 attr { 3164 key: "_output_shapes" 3165 value { 3166 list { 3167 shape { } 3168 } 3169 } 3170 } 3171 attr { 3172 key: "dtype" 3173 value { type: DT_FLOAT } 3174 } 3175 attr { 3176 key: "value" 3177 value { 3178 tensor { 3179 dtype: DT_FLOAT 3180 tensor_shape { } 3181 float_val: 1.0 } } } } 3182 """, gd) 3183 3184 3185@ops.RegisterStatistics("a", "flops") 3186def _calc_a_forward_flops(unused_graph, unused_node): 3187 return ops.OpStats("flops", 20) 3188 3189 3190class StatisticsTest(test_util.TensorFlowTestCase): 3191 3192 def testRegisteredNode(self): 3193 graph = ops.Graph() 3194 node = ops._NodeDef("a", "an_a") 3195 flops = ops.get_stats_for_node_def(graph, node, "flops") 3196 self.assertEqual(20, flops.value) 3197 missing_stat = ops.get_stats_for_node_def(graph, node, "missing_stat") 3198 self.assertEqual(None, missing_stat.value) 3199 3200 def testUnregisteredNode(self): 3201 graph = ops.Graph() 3202 node = ops._NodeDef("b", "a_b") 3203 weight_params = ops.get_stats_for_node_def(graph, node, "weight_params") 3204 self.assertEqual(None, weight_params.value) 3205 3206 def testAccumulateStatistics(self): 3207 flops_total = ops.OpStats("flops") 3208 self.assertEqual(None, flops_total.value) 3209 second_flops = ops.OpStats("flops", 3) 3210 flops_total += second_flops 3211 self.assertEqual(3, flops_total.value) 3212 3213 3214class DeviceStackTest(test_util.TensorFlowTestCase): 3215 3216 @test_util.run_deprecated_v1 3217 def testBasicDeviceAssignmentMetadata(self): 3218 3219 def device_func(unused_op): 3220 return "/cpu:*" 3221 3222 const_zero = constant_op.constant([0.0], name="zero") 3223 with ops.device("/cpu"): 3224 const_one = constant_op.constant([1.0], name="one") 3225 with ops.device("/cpu:0"): 3226 const_two = constant_op.constant([2.0], name="two") 3227 with ops.device(device_func): 3228 const_three = constant_op.constant(3.0, name="three") 3229 3230 self.assertEqual(0, len(const_zero.op._device_assignments)) 3231 3232 one_list = const_one.op._device_assignments 3233 self.assertEqual(1, len(one_list)) 3234 self.assertEqual("/cpu", one_list[0].obj) 3235 self.assertEqual("ops_test.py", os.path.basename(one_list[0].filename)) 3236 3237 two_list = const_two.op._device_assignments 3238 self.assertEqual(2, len(two_list)) 3239 devices = [t.obj for t in two_list] 3240 self.assertEqual(set(["/cpu", "/cpu:0"]), set(devices)) 3241 3242 three_list = const_three.op._device_assignments 3243 self.assertEqual(1, len(three_list)) 3244 func_description = three_list[0].obj 3245 expected_regex = r"device_func<.*ops_test.py, [0-9]+" 3246 self.assertRegex(func_description, expected_regex) 3247 3248 @test_util.run_deprecated_v1 3249 def testDeviceAssignmentMetadataForGraphDeviceAndTfDeviceFunctions(self): 3250 3251 with ops.device("/cpu"): 3252 const_one = constant_op.constant([1.0], name="one") 3253 with ops.get_default_graph().device("/cpu"): 3254 const_two = constant_op.constant([2.0], name="two") 3255 3256 one_metadata = const_one.op._device_assignments[0] 3257 two_metadata = const_two.op._device_assignments[0] 3258 3259 # Verify both types of device assignment return the right stack info. 3260 self.assertRegex("ops_test.py", os.path.basename(one_metadata.filename)) 3261 self.assertEqual(one_metadata.filename, two_metadata.filename) 3262 self.assertEqual(one_metadata.lineno + 2, two_metadata.lineno) 3263 3264 3265class ColocationGroupTest(test_util.TensorFlowTestCase): 3266 3267 @test_util.run_deprecated_v1 3268 def testBasic(self): 3269 a = constant_op.constant([2.0], name="a") 3270 with ops.colocate_with(a.op): 3271 b = constant_op.constant(3.0) 3272 c = constant_op.constant(4.0) 3273 self.assertEqual([b"loc:@a"], a.op.colocation_groups()) 3274 self.assertEqual([b"loc:@a"], b.op.colocation_groups()) 3275 with self.assertRaises(ValueError): 3276 c.op.get_attr("_class") 3277 3278 @test_util.run_deprecated_v1 3279 def testBasicColocationMetadata(self): 3280 const_two = constant_op.constant([2.0], name="two") 3281 with ops.colocate_with(const_two.op): 3282 const_three = constant_op.constant(3.0, name="three") 3283 locations_dict = const_three.op._colocation_dict 3284 self.assertIn("two", locations_dict) 3285 metadata = locations_dict["two"] 3286 self.assertIsNone(metadata.obj) 3287 # Check that this test's filename is recorded as the file containing the 3288 # colocation statement. 3289 self.assertEqual("ops_test.py", os.path.basename(metadata.filename)) 3290 3291 @test_util.run_deprecated_v1 3292 def testColocationDeviceInteraction(self): 3293 with ops.device("/cpu:0"): 3294 with ops.device("/device:GPU:0"): 3295 a = constant_op.constant([2.0], name="a") 3296 with ops.colocate_with(a.op): 3297 # 'b' is created in the scope of /cpu:0, but it is 3298 # colocated with 'a', which is on '/device:GPU:0'. colocate_with 3299 # overrides devices because it is a stronger constraint. 3300 b = constant_op.constant(3.0) 3301 self.assertEqual([b"loc:@a"], b.op.colocation_groups()) 3302 self.assertEqual(a.op.device, b.op.device) 3303 3304 @test_util.run_deprecated_v1 3305 def testColocationCanonicalization(self): 3306 with ops.device("/device:GPU:0"): 3307 _ = constant_op.constant(2.0) 3308 with ops.device(lambda op: "/device:GPU:0"): 3309 b = constant_op.constant(3.0) 3310 with ops.get_default_graph().colocate_with(b): 3311 with ops.device("/device:GPU:0"): 3312 c = constant_op.constant(4.0) 3313 3314 # A's device will be /device:GPU:0 3315 # B's device will be /device:GPU:0 3316 # C's device will be /device:GPU:0 because it 3317 # inherits B's device name, after canonicalizing the names. 3318 self.assertEqual(b.op.device, c.op.device) 3319 3320 @test_util.run_deprecated_v1 3321 def testLocationOverrides(self): 3322 with ops.device("/cpu:0"): 3323 with ops.device("/device:GPU:0"): 3324 a = constant_op.constant([2.0], name="a") 3325 # Note that this colocation is "redundant", since we are 3326 # within the scope of "/device:GPU:0". However, we would like to 3327 # preserve in the GraphDef that these two ops should be 3328 # colocated in a portable way. 3329 with ops.colocate_with(a.op): 3330 b = constant_op.constant(3.0) 3331 c = constant_op.constant(4.0) 3332 d = constant_op.constant(5.0) 3333 3334 self.assertEqual([b"loc:@a"], b.op.colocation_groups()) 3335 self.assertEqual("/device:GPU:0", a.op.device) 3336 self.assertEqual(a.op.device, b.op.device) 3337 3338 # Test that device function stack is restored. 3339 self.assertEqual("/device:GPU:0", c.op.device) 3340 self.assertEqual("/device:CPU:0", d.op.device) 3341 3342 @test_util.run_deprecated_v1 3343 def testNestedColocateWith(self): 3344 a = constant_op.constant([2.0], name="a") 3345 with ops.colocate_with(a.op): 3346 b = constant_op.constant(3.0) 3347 with ops.colocate_with(b.op): 3348 c = constant_op.constant(4.0) 3349 self.assertEqual([b"loc:@a"], b.op.colocation_groups()) 3350 self.assertEqual([b"loc:@a"], c.op.colocation_groups()) 3351 3352 @test_util.run_deprecated_v1 3353 def testMultiColocationGroups(self): 3354 a = constant_op.constant([2.0], name="a") 3355 b = constant_op.constant(3.0, name="b") 3356 with ops.colocate_with(a.op): 3357 with ops.colocate_with(b.op): 3358 c = constant_op.constant(4.0) 3359 self.assertEqual(set([b"loc:@a", b"loc:@b"]), set(c.op.colocation_groups())) 3360 3361 @test_util.run_deprecated_v1 3362 def testColocationIgnoreStack(self): 3363 a = constant_op.constant([2.0], name="a") 3364 b = constant_op.constant(3.0, name="b") 3365 with ops.colocate_with(a.op): 3366 with ops.colocate_with(b.op, ignore_existing=True): 3367 c = constant_op.constant(4.0) 3368 self.assertEqual(set([b"loc:@b"]), set(c.op.colocation_groups())) 3369 3370 @test_util.run_deprecated_v1 3371 def testColocateWithReset(self): 3372 a = constant_op.constant([2.0], name="a") 3373 with ops.colocate_with(a.op): 3374 b = constant_op.constant(3.0, name="b") 3375 with ops.colocate_with(None, ignore_existing=True): 3376 c = constant_op.constant(4.0, name="c") 3377 self.assertEqual([b"loc:@a"], b.op.colocation_groups()) 3378 self.assertEqual([b"loc:@c"], c.op.colocation_groups()) 3379 3380 @test_util.run_deprecated_v1 3381 def testColocateWithInitialNoneThenNested(self): 3382 a = constant_op.constant([2.0], name="a") 3383 with ops.colocate_with(a.op): 3384 with ops.colocate_with(None, ignore_existing=True): 3385 b = constant_op.constant(3.0, name="b") 3386 with ops.colocate_with(b.op): 3387 c = constant_op.constant(4.0, name="c") 3388 self.assertEqual([b"loc:@b"], b.op.colocation_groups()) 3389 self.assertEqual([b"loc:@b"], c.op.colocation_groups()) 3390 3391 @test_util.run_deprecated_v1 3392 def testColocateVariables(self): 3393 a = variables.Variable([2.0], name="a") 3394 with ops.colocate_with(a.op): 3395 b = variables.Variable([3.0], name="b") 3396 self.assertEqual([b"loc:@a"], b.op.colocation_groups()) 3397 3398 @test_util.run_deprecated_v1 3399 def testColocateResourceVariablesInFunction(self): 3400 with ops.device("/device:CPU:0"): 3401 a = resource_variable_ops.ResourceVariable(1.0) 3402 3403 @def_function.function 3404 def f(): 3405 with ops.colocate_with(a): 3406 b = array_ops.ones([], name="output") 3407 self.assertEqual("/device:CPU:0", b.op.device) 3408 f() 3409 3410 def testColocateWithVariableInFunction(self): 3411 v = variables.Variable(1.) 3412 3413 @def_function.function 3414 def f(): 3415 with ops.colocate_with(v): 3416 return array_ops.ones([], name="output") 3417 3418 f() 3419 graph_def = f.get_concrete_function().graph.as_graph_def() 3420 wrap_function.function_from_graph_def(graph_def, [], ["output"]) 3421 3422 3423class DeprecatedTest(test_util.TensorFlowTestCase): 3424 3425 def testSuccess(self): 3426 with ops.Graph().as_default() as g: 3427 test_util.set_producer_version(g, 7) 3428 old = test_ops.old() 3429 with self.session(graph=g): 3430 old.run() 3431 3432 def _error(self): 3433 return ((r"Op Old is not available in GraphDef version %d\. " 3434 r"It has been removed in version 8\. For reasons\.") % 3435 versions.GRAPH_DEF_VERSION) 3436 3437 def testGraphConstructionFail(self): 3438 with ops.Graph().as_default(): 3439 with self.assertRaisesRegex(NotImplementedError, self._error()): 3440 test_ops.old() 3441 3442 3443class NameScopeTest(test_util.TensorFlowTestCase): 3444 3445 def testStripAndPrependScope(self): 3446 strs = [ 3447 "hidden1/hidden1/weights", # Same prefix. Should strip. 3448 "hidden1///hidden1/weights", # Extra "/". Should strip. 3449 "^hidden1/hidden1/weights", # Same prefix. Should strip. 3450 "loc:@hidden1/hidden1/weights", # Same prefix. Should strip. 3451 "hhidden1/hidden1/weights", # Different prefix. Should keep. 3452 "hidden1" 3453 ] # Not a prefix. Should keep. 3454 expected_striped = [ 3455 "hidden1/weights", "hidden1/weights", "^hidden1/weights", 3456 "loc:@hidden1/weights", "hhidden1/hidden1/weights", "hidden1" 3457 ] 3458 expected_prepended = [ 3459 "hidden2/hidden1/weights", "hidden2/hidden1/weights", 3460 "^hidden2/hidden1/weights", "loc:@hidden2/hidden1/weights", 3461 "hidden2/hhidden1/hidden1/weights", "hidden2/hidden1" 3462 ] 3463 name_scope_to_strip = "hidden1" 3464 name_scope_to_add = "hidden2" 3465 for es, ep, s in zip(expected_striped, expected_prepended, strs): 3466 striped = ops.strip_name_scope(s, name_scope_to_strip) 3467 self.assertEqual(es, striped) 3468 self.assertEqual(ep, ops.prepend_name_scope(striped, name_scope_to_add)) 3469 3470 def testGetNameScope(self): 3471 with ops.Graph().as_default() as g: 3472 with ops.name_scope("scope1"): 3473 with ops.name_scope("scope2"): 3474 with ops.name_scope("scope3"): 3475 self.assertEqual("scope1/scope2/scope3", g.get_name_scope()) 3476 self.assertEqual("scope1/scope2", g.get_name_scope()) 3477 self.assertEqual("scope1", g.get_name_scope()) 3478 self.assertEqual("", g.get_name_scope()) 3479 3480 def testTwoGraphs(self): 3481 3482 def f(): 3483 g1 = ops.Graph() 3484 g2 = ops.Graph() 3485 with g1.as_default(): 3486 with g2.as_default(): 3487 with ops.name_scope("_"): 3488 pass 3489 3490 self.assertRaisesRegex(ValueError, "'_' is not a valid scope name", f) 3491 3492 3493class EnableEagerExecutionTest(test_util.TensorFlowTestCase): 3494 3495 @test_util.run_v1_only("b/120545219") 3496 def testBadArgumentsToEnableEagerExecution(self): 3497 with self.assertRaisesRegex(TypeError, "config must be a tf.ConfigProto"): 3498 ops.enable_eager_execution(context.DEVICE_PLACEMENT_SILENT) 3499 with self.assertRaisesRegex(ValueError, "device_policy must be one of"): 3500 c = config_pb2.ConfigProto() 3501 ops.enable_eager_execution(c, c) 3502 with self.assertRaisesRegex(ValueError, "execution_mode must be one of"): 3503 c = config_pb2.ConfigProto() 3504 ops.enable_eager_execution(c, execution_mode=c) 3505 3506 3507class _TupleTensor(composite_tensor.CompositeTensor): 3508 """`Tensor`-like `tuple`-like for custom `Tensor` conversion masquerading.""" 3509 3510 def __init__(self, components): 3511 super(_TupleTensor, self).__init__() 3512 self._components = tuple(ops.convert_to_tensor(c) for c in components) 3513 3514 @property 3515 def _type_spec(self): 3516 return _TupleTensorSpec(type_spec.from_value(c) for c in self._components) 3517 3518 def __getitem__(self, key): 3519 return self._components[key] 3520 3521 def __len__(self): 3522 return len(self._components) 3523 3524 def __iter__(self): 3525 return iter(self._components) 3526 3527 3528class _TupleTensorSpec(type_spec.TypeSpec): 3529 3530 def __init__(self, specs): 3531 self._specs = specs 3532 3533 value_type = property(lambda self: _TupleTensor) 3534 _component_specs = property(lambda self: self._specs) 3535 3536 def _to_components(self, value): 3537 return value._components 3538 3539 def _from_components(self, components): 3540 return _TupleTensor(*components) 3541 3542 def _serialize(self): 3543 return (self._specs,) 3544 3545 3546class _MyTuple(object): 3547 """Pretend user-side class for `ConvertToCompositeTensorTest .""" 3548 3549 def __init__(self, components): 3550 super(_MyTuple, self).__init__() 3551 self._components = tuple(components) 3552 3553 def __getitem__(self, key): 3554 return self._components[key] 3555 3556 def __len__(self): 3557 return len(self._components) 3558 3559 def __iter__(self): 3560 return iter(self._components) 3561 3562 3563ops.register_tensor_conversion_function( 3564 _MyTuple, conversion_func=lambda x, *_, **__: _TupleTensor(x)) 3565 3566 3567class CustomConvertToCompositeTensorTest(test_util.TensorFlowTestCase): 3568 3569 @test_util.disable_tfrt("TODO(kkb): This makes Kokoro tests fail.") 3570 def testCompositeTensorConversion(self): 3571 """Tests that a user can register a CompositeTensor converter.""" 3572 x = _MyTuple((1, [2., 3.], [[4, 5], [6, 7]])) 3573 y = ops.convert_to_tensor_or_composite(x) 3574 self.assertFalse(tensor_util.is_tf_type(y)) 3575 self.assertIsInstance(y, _TupleTensor) 3576 self.assertLen(y, len(x)) 3577 for x_, y_ in zip(x, y): 3578 self.assertIsInstance(y_, ops.Tensor) 3579 self.assertTrue(tensor_util.is_tf_type(y_)) 3580 self.assertAllEqual(x_, tensor_util.constant_value(y_)) 3581 3582 3583@test_util.disable_tfrt("Packing EagerTensors is not supported yet.") 3584class PackEagerTensorTest(test_util.TensorFlowTestCase): 3585 3586 def setUp(self): 3587 super(PackEagerTensorTest, self).setUp() 3588 context._reset_context() 3589 cpus = config.list_physical_devices("CPU") 3590 # Set 2 virtual CPUs 3591 config.set_logical_device_configuration(cpus[0], [ 3592 context.LogicalDeviceConfiguration(), 3593 context.LogicalDeviceConfiguration(), 3594 ]) 3595 3596 def testPack(self): 3597 with context.eager_mode(): 3598 with ops.device("CPU:0"): 3599 var0 = resource_variable_ops.ResourceVariable(1.0) 3600 c0 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 3601 with ops.device("CPU:1"): 3602 var1 = resource_variable_ops.ResourceVariable(2.0) 3603 var2 = resource_variable_ops.ResourceVariable([3.0]) 3604 c1 = constant_op.constant([9.0]) 3605 3606 packed_var0 = ops.pack_eager_tensors([var0.handle, var1.handle]) 3607 self.assertTrue(packed_var0.is_packed) 3608 self.assertEqual(packed_var0.dtype, var0.handle.dtype) 3609 self.assertEqual(packed_var0.shape, var0.handle.shape) 3610 self.assertEqual(packed_var0._handle_data, var0.handle._handle_data) 3611 self.assertIn("COMPOSITE:0", packed_var0.device) 3612 self.assertIn("COMPOSITE:0", packed_var0.backing_device) 3613 with self.assertRaises(errors.InvalidArgumentError): 3614 packed_var0.numpy() 3615 3616 # Different dtypes 3617 with self.assertRaises(ValueError): 3618 ops.pack_eager_tensors([var0.handle, c1]) 3619 3620 # Different shapes 3621 with self.assertRaises(ValueError): 3622 ops.pack_eager_tensors([c0, c1]) 3623 3624 # Different handle data 3625 with self.assertRaises(ValueError): 3626 ops.pack_eager_tensors([var0.handle, var2.handle]) 3627 3628 3629if __name__ == "__main__": 3630 googletest.main() 3631