1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for tensorflow.python.framework.ops.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import gc 22import os 23import threading 24import weakref 25 26from tensorflow.core.framework import attr_value_pb2 27from tensorflow.core.protobuf import config_pb2 28from tensorflow.python.client import session 29from tensorflow.python.eager import context 30from tensorflow.python.eager import function as eager_function 31from tensorflow.python.framework import common_shapes 32from tensorflow.python.framework import constant_op 33from tensorflow.python.framework import device as pydev 34from tensorflow.python.framework import dtypes 35from tensorflow.python.framework import errors 36from tensorflow.python.framework import function 37from tensorflow.python.framework import ops 38from tensorflow.python.framework import sparse_tensor 39from tensorflow.python.framework import tensor_shape 40from tensorflow.python.framework import tensor_util 41from tensorflow.python.framework import test_ops 42from tensorflow.python.framework import test_util 43from tensorflow.python.framework import versions 44from tensorflow.python.ops import array_ops 45from tensorflow.python.ops import control_flow_ops 46from tensorflow.python.ops import math_ops 47from tensorflow.python.ops import resource_variable_ops 48from tensorflow.python.ops import resources 49from tensorflow.python.ops import variable_scope 50from tensorflow.python.ops import variables 51import tensorflow.python.ops.gradients # pylint: disable=unused-import 52from tensorflow.python.platform import googletest 53from tensorflow.python.util import compat 54 55ops._set_call_cpp_shape_fn(common_shapes.call_cpp_shape_fn) 56 57 58class ResourceTest(test_util.TensorFlowTestCase): 59 60 @test_util.run_deprecated_v1 61 def testBuildGraph(self): 62 with self.cached_session(): 63 pt = test_ops.stub_resource_handle_op(container="a", shared_name="b") 64 test_ops.resource_create_op(pt).run() 65 66 @test_util.run_deprecated_v1 67 def testInitialize(self): 68 with self.cached_session(): 69 handle = test_ops.stub_resource_handle_op(container="a", shared_name="b") 70 resources.register_resource( 71 handle=handle, 72 create_op=test_ops.resource_create_op(handle), 73 is_initialized_op=test_ops.resource_initialized_op(handle)) 74 self.assertEquals( 75 len( 76 resources.report_uninitialized_resources( 77 resources.shared_resources()).eval()), 1) 78 resources.initialize_resources(resources.shared_resources()).run() 79 self.assertEquals( 80 len( 81 resources.report_uninitialized_resources( 82 resources.shared_resources()).eval()), 0) 83 84 85class TensorAndShapeTest(test_util.TensorFlowTestCase): 86 87 def testShape(self): 88 op = ops.Operation( 89 ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32]) 90 t = op.outputs[0] 91 self.assertEqual(tensor_shape.unknown_shape(), t.get_shape()) 92 t.set_shape([1, 2, 3]) 93 self.assertEqual([1, 2, 3], t.get_shape()) 94 95 def testIterable(self): 96 op = ops.Operation( 97 ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32]) 98 t = op.outputs[0] 99 self.assertTrue(isinstance(t, ops.Tensor)) 100 with self.assertRaisesRegexp(TypeError, "iter"): 101 for _ in t: 102 pass 103 104 def testAddShape(self): 105 with self.cached_session(): 106 a = array_ops.zeros([2, 3]) 107 b = array_ops.ones([1, 3]) 108 c = a + b 109 self.assertEqual([2, 3], c.shape) 110 111 @test_util.run_deprecated_v1 112 def testUnknownDim(self): 113 with self.cached_session(): 114 a = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3]) 115 b = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3]) 116 c = a + b 117 self.assertEqual([2, None, 3], c.shape.as_list()) 118 119 @test_util.run_deprecated_v1 120 def testUnknownShape(self): 121 with self.cached_session(): 122 a = array_ops.placeholder(dtype=dtypes.float32, shape=None) 123 b = array_ops.ones([1, 3]) 124 c = a + b 125 self.assertEqual(tensor_shape.unknown_shape(), c.shape) 126 127 @test_util.run_deprecated_v1 128 def testScalarShape(self): 129 with self.cached_session(): 130 a = array_ops.placeholder(dtype=dtypes.float32, shape=[]) 131 b = array_ops.ones([]) 132 c = a + b 133 self.assertEqual(tensor_shape.scalar(), c.shape) 134 135 @test_util.run_deprecated_v1 136 def testShapeFunctionError(self): 137 with self.cached_session(): 138 a = array_ops.ones([1, 2, 3]) 139 b = array_ops.ones([4, 5, 6]) 140 with self.assertRaisesRegexp( 141 ValueError, 142 r"Dimensions must be equal, but are 2 and 5 for 'add' \(op: 'Add'\) " 143 r"with input shapes: \[1,2,3\], \[4,5,6\]."): 144 _ = a + b 145 146 147class IndexedSlicesTest(test_util.TensorFlowTestCase): 148 149 @test_util.run_in_graph_and_eager_modes 150 def testToTensor(self): 151 values = constant_op.constant([2, 3, 5, 7], shape=[2, 2]) 152 indices = constant_op.constant([0, 2]) 153 dense_shape = constant_op.constant([3, 2]) 154 x = ops.IndexedSlices(values, indices, dense_shape) 155 tensor = ops.convert_to_tensor(x, name="tensor") 156 self.assertAllEqual(self.evaluate(tensor), [[2, 3], [0, 0], [5, 7]]) 157 158 @test_util.run_deprecated_v1 159 def testNegation(self): 160 with self.cached_session(): 161 values = constant_op.constant([2, 3, 5, 7], shape=[2, 2]) 162 indices = constant_op.constant([0, 2]) 163 x = -ops.IndexedSlices(values, indices) 164 self.assertAllEqual(x.values.eval(), [[-2, -3], [-5, -7]]) 165 self.assertAllEqual(x.indices.eval(), [0, 2]) 166 167 @test_util.run_deprecated_v1 168 def testScalarMul(self): 169 with self.cached_session(): 170 values = constant_op.constant([2, 3, 5, 7], shape=[2, 2]) 171 indices = constant_op.constant([0, 2]) 172 x = math_ops.scalar_mul(-2, ops.IndexedSlices(values, indices)) 173 self.assertAllEqual(x.values.eval(), [[-4, -6], [-10, -14]]) 174 self.assertAllEqual(x.indices.eval(), [0, 2]) 175 176 177class NodeDefConstructorTest(test_util.TensorFlowTestCase): 178 179 def testNoArgs(self): 180 nodedef = ops._NodeDef("None", "bar") 181 self.assertProtoEquals("op: 'None' name: 'bar'", nodedef) 182 183 def testArgs(self): 184 nodedef = ops._NodeDef("foo", "bar", device="/device:baz:*") 185 self.assertProtoEquals("op:'foo' name:'bar' device:'/device:baz:*'", 186 nodedef) 187 nodedef = ops._NodeDef("foo", "bar", device=pydev.DeviceSpec(job="j")) 188 self.assertProtoEquals("op:'foo' name:'bar' device:'/job:j'", nodedef) 189 190 191def _apply_op(g, *args, **kwargs): 192 op = g.create_op(*args, **kwargs) 193 if len(op.outputs) == 1: 194 return op.outputs[0] 195 else: 196 return op.outputs 197 198 199class OperationTest(test_util.TensorFlowTestCase): 200 201 @test_util.run_deprecated_v1 202 def testNoInputs(self): 203 op = test_ops.float_output_string_output(name="myop").a.op 204 self.assertEqual(2, len(op.values())) 205 self.assertEqual(0, len(op.inputs)) 206 self.assertEqual("myop", op.name) 207 208 float_t, label_str_t = op.values() 209 self.assertEqual(dtypes.float32, float_t.dtype) 210 self.assertEqual(op, float_t.op) 211 self.assertEqual(0, float_t._value_index) 212 self.assertEqual(0, len(float_t.consumers())) 213 self.assertEqual("myop", float_t._as_node_def_input()) 214 215 self.assertEqual(dtypes.string, label_str_t.dtype) 216 self.assertEqual(op, label_str_t.op) 217 self.assertEqual(1, label_str_t._value_index) 218 self.assertEqual(0, len(label_str_t.consumers())) 219 self.assertEqual("myop:1", label_str_t._as_node_def_input()) 220 221 self.assertProtoEquals("op:'FloatOutputStringOutput' name:'myop'", 222 op.node_def) 223 224 @test_util.run_deprecated_v1 225 def testNoOutputs(self): 226 op1 = test_ops.float_output(name="myop1").op 227 float_t, = op1.values() 228 op2 = test_ops.float_input(float_t, name="myop2") 229 self.assertEqual(0, len(op2.values())) 230 self.assertEqual(1, len(op2.inputs)) 231 self.assertIs(float_t, op2.inputs[0]) 232 233 self.assertEqual(1, len(float_t.consumers())) 234 self.assertEqual(op2, float_t.consumers()[0]) 235 236 self.assertProtoEquals("op:'FloatOutput' name:'myop1'", op1.node_def) 237 self.assertProtoEquals("op:'FloatInput' name:'myop2' input:'myop1'", 238 op2.node_def) 239 240 @test_util.run_deprecated_v1 241 def testInputsAndOutputs(self): 242 op1 = test_ops.float_output(name="myop1").op 243 self.assertEqual(1, len(op1.values())) 244 float1_t, = op1.values() 245 246 op2 = test_ops.float_output_string_output(name="myop2").a.op 247 self.assertEqual(2, len(op2.values())) 248 float2_t, label2_str_t = op2.values() 249 250 # Note that we consume label2_str_t twice here. 251 op3 = test_ops.foo2(float1_t, label2_str_t, label2_str_t, name="myop3").d.op 252 self.assertEqual(2, len(op3.values())) 253 254 self.assertEqual(1, len(float1_t.consumers())) 255 self.assertEqual(op3, float1_t.consumers()[0]) 256 257 self.assertEqual(0, len(float2_t.consumers())) 258 259 self.assertEqual(2, len(label2_str_t.consumers())) 260 self.assertEqual(op3, label2_str_t.consumers()[0]) 261 self.assertEqual(op3, label2_str_t.consumers()[1]) 262 263 self.assertProtoEquals(""" 264 op:'Foo2' name:'myop3' 265 input:'myop1' input:'myop2:1' input:'myop2:1' 266 """, op3.node_def) 267 268 def testDeviceFromNodeDef(self): 269 op = ops.Operation( 270 ops._NodeDef("None", "myop", device="/job:goo/device:GPU:0"), 271 ops.Graph(), [], []) 272 self.assertEqual("/job:goo/device:GPU:0", op.device) 273 274 def testDeviceObject(self): 275 op = ops.Operation(ops._NodeDef("None", "myop"), ops.Graph(), [], []) 276 op._set_device("/job:goo/device:GPU:0") 277 self.assertProtoEquals( 278 "op:'None' name:'myop' device:'/job:goo/device:GPU:0' ", op.node_def) 279 op = ops.Operation(ops._NodeDef("None", "op2"), ops.Graph(), [], []) 280 op._set_device( 281 pydev.DeviceSpec( 282 job="muu", device_type="CPU", device_index=0)) 283 self.assertProtoEquals( 284 "op:'None' name:'op2' device:'/job:muu/device:CPU:0'", op.node_def) 285 286 def testReferenceInput(self): 287 g = ops.Graph() 288 op1 = ops.Operation( 289 ops._NodeDef("RefOutputFloatOutput", "op1"), g, [], 290 [dtypes.float32_ref, dtypes.float32]) 291 self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def) 292 self.assertEquals([], list(op1.inputs)) 293 ref_t, nonref_t = op1.values() 294 # NOTE(mrry): Must specify input_types to preserve ref-typed input. 295 op2 = ops.Operation( 296 ops._NodeDef("RefInputFloatInput", "op2"), 297 g, [ref_t, nonref_t], [], 298 input_types=[dtypes.float32_ref, dtypes.float32]) 299 self.assertProtoEquals( 300 "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'", 301 op2.node_def) 302 self.assertEquals([ref_t, nonref_t], list(op2.inputs)) 303 op3 = ops.Operation( 304 ops._NodeDef("TwoFloatInputs", "op3"), g, [ref_t, nonref_t], []) 305 self.assertProtoEquals( 306 "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'", 307 op3.node_def) 308 309 def testInvalidNames(self): 310 g = ops.Graph() 311 with self.assertRaises(ValueError): 312 ops.Operation(ops._NodeDef("op", ""), g) 313 with self.assertRaises(ValueError): 314 ops.Operation(ops._NodeDef("op", "_invalid"), g) 315 with self.assertRaises(ValueError): 316 ops.Operation(ops._NodeDef("op", "-invalid"), g) 317 with self.assertRaises(ValueError): 318 ops.Operation(ops._NodeDef("op", "/invalid"), g) 319 with self.assertRaises(ValueError): 320 ops.Operation(ops._NodeDef("op", "invalid:0"), g) 321 322 @test_util.run_deprecated_v1 323 def testNoShapeFunction(self): 324 op = test_ops.a() 325 self.assertEqual(tensor_shape.unknown_shape(), op.get_shape()) 326 327 @test_util.run_in_graph_and_eager_modes 328 def testConvertToTensorNestedArray(self): 329 values = [[2], [3], [5], [7]] 330 tensor = ops.convert_to_tensor(values) 331 self.assertAllEqual((4, 1), tensor.get_shape().as_list()) 332 self.assertAllEqual(values, self.evaluate(tensor)) 333 334 def testShapeTuple(self): 335 with self.cached_session(): 336 c = constant_op.constant(1) 337 self.assertEqual(c._shape_tuple(), ()) # pylint: disable=protected-access 338 339 def testConvertToTensorEager(self): 340 with context.eager_mode(): 341 t = constant_op.constant(1) 342 self.assertTrue(isinstance(t, ops.EagerTensor)) 343 converted = ops.convert_to_tensor(t) 344 self.assertTrue(isinstance(converted, ops.EagerTensor)) 345 converted = ops.convert_to_tensor(1) 346 self.assertTrue(isinstance(converted, ops.EagerTensor)) 347 348 @test_util.run_in_graph_and_eager_modes 349 def testConvertToTensorNestedTuple(self): 350 values = ((2,), (3,), (5,), (7,)) 351 tensor = ops.convert_to_tensor(values) 352 self.assertAllEqual((4, 1), tensor.get_shape().as_list()) 353 self.assertAllEqual(values, self.evaluate(ops.convert_to_tensor(values))) 354 355 @test_util.run_in_graph_and_eager_modes 356 def testConvertToTensorNestedTensors(self): 357 values = ((2,), (3,), (5,), (7,)) 358 tensor = ops.convert_to_tensor( 359 [constant_op.constant(row) for row in values]) 360 self.assertAllEqual((4, 1), tensor.get_shape().as_list()) 361 self.assertAllEqual(values, self.evaluate(tensor)) 362 tensor = ops.convert_to_tensor( 363 [[constant_op.constant(v) for v in row] for row in values]) 364 self.assertAllEqual((4, 1), tensor.get_shape().as_list()) 365 self.assertAllEqual(values, self.evaluate(tensor)) 366 367 @test_util.run_in_graph_and_eager_modes 368 def testConvertToTensorNestedMix(self): 369 values = ([2], (3,), [constant_op.constant(5)], constant_op.constant([7])) 370 tensor = ops.convert_to_tensor(values) 371 self.assertAllEqual((4, 1), tensor.get_shape().as_list()) 372 self.assertAllEqual(((2,), (3,), (5,), (7,)), self.evaluate(tensor)) 373 374 @test_util.run_in_graph_and_eager_modes 375 def testConvertToTensorPreferred(self): 376 values = [2, 3, 5, 7] 377 tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.float32) 378 self.assertEqual(dtypes.float32, tensor.dtype) 379 380 # Convert empty tensor to anything. 381 values = [] 382 tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64) 383 self.assertEqual(dtypes.int64, tensor.dtype) 384 385 # The preferred dtype is a type error and will convert to 386 # float32 instead. 387 values = [1.23] 388 tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64) 389 self.assertEqual(dtypes.float32, tensor.dtype) 390 391 @test_util.run_in_graph_and_eager_modes 392 def testConvertToInvalidTensorType(self): 393 with self.assertRaises(TypeError): 394 # Forcing an invalid dtype should fail with a type error. 395 values = [1.23] 396 ops.convert_to_tensor(values, dtype=dtypes.int64) 397 398 @test_util.run_in_graph_and_eager_modes 399 def testConvertToTensorFromInvalidTensor(self): 400 tensor = constant_op.constant(42.0, dtype=dtypes.float32) 401 with self.assertRaises(ValueError): 402 ops.convert_to_tensor(tensor, dtype=dtypes.int32) 403 404 @test_util.run_deprecated_v1 405 def testNoConvert(self): 406 # Operation cannot be converted to Tensor. 407 op = control_flow_ops.no_op() 408 with self.assertRaisesRegexp(TypeError, 409 r"Can't convert Operation '.*' to Tensor"): 410 ops.convert_to_tensor(op) 411 412 def testStr(self): 413 node_def = ops._NodeDef("None", "op1") 414 op = ops.Operation(node_def, ops.Graph(), [], [dtypes.float32]) 415 self.assertEqual(str(node_def), str(op)) 416 417 def testRepr(self): 418 op = ops.Operation( 419 ops._NodeDef("None", "op1"), ops.Graph(), [], [dtypes.float32]) 420 self.assertEqual("<tf.Operation 'op1' type=None>", repr(op)) 421 422 @test_util.run_deprecated_v1 423 def testGetAttr(self): 424 op = test_ops.default_attrs() 425 self.assertEqual(op.get_attr("string_val"), b"abc") 426 self.assertEqual(op.get_attr("string_list_val"), [b"abc", b""]) 427 self.assertEqual(op.get_attr("int_val"), 123) 428 self.assertEqual(op.get_attr("int_list_val"), [1, 2, 3]) 429 self.assertEqual(op.get_attr("float_val"), 10.0) 430 self.assertEqual(op.get_attr("float_list_val"), [10.0]) 431 self.assertEqual(op.get_attr("bool_val"), True) 432 self.assertEqual(op.get_attr("bool_list_val"), [True, False]) 433 self.assertEqual(op.get_attr("shape_val"), 434 tensor_shape.as_shape([2, 1]).as_proto()) 435 self.assertEqual(op.get_attr("shape_list_val"), 436 [tensor_shape.as_shape([]).as_proto(), 437 tensor_shape.as_shape([1]).as_proto()]) 438 self.assertEqual(op.get_attr("tensor_val"), 439 tensor_util.make_tensor_proto(1, dtypes.int32)) 440 self.assertEqual(op.get_attr("tensor_list_val"), 441 [tensor_util.make_tensor_proto(1, dtypes.int32)]) 442 443 type_val = op.get_attr("type_val") 444 # First check that type_val is a DType, because the assertEquals will work 445 # no matter what since DType overrides __eq__ 446 self.assertIsInstance(type_val, dtypes.DType) 447 self.assertEqual(type_val, dtypes.int32) 448 449 type_list_val = op.get_attr("type_list_val") 450 self.assertTrue(all(isinstance(x, dtypes.DType) for x in type_list_val)) 451 self.assertEqual(type_list_val, [dtypes.int32, dtypes.float32]) 452 453 @function.Defun(dtypes.float32, func_name="MyFunc") 454 def func(x): 455 return x 456 457 op = test_ops.func_attr(func) 458 self.assertEqual(op.get_attr("f"), 459 attr_value_pb2.NameAttrList(name="MyFunc")) 460 461 # Try fetching missing attr 462 with self.assertRaisesRegexp( 463 ValueError, "Operation 'FuncAttr' has no attr named 'FakeAttr'."): 464 op.get_attr("FakeAttr") 465 466 # TODO(b/65162920): remove this test when users who are directly mutating the 467 # node_def have been updated to proper usage. 468 @test_util.run_deprecated_v1 469 def testSetAttr(self): 470 op = test_ops.int_attr().op 471 op._set_attr("foo", attr_value_pb2.AttrValue(i=2)) 472 # TODO(skyewm): add node_def check 473 self.assertEqual(op.get_attr("foo"), 2) 474 475 # TODO(nolivia): test all error cases 476 def testAddControlInput(self): 477 with ops.Graph().as_default(): 478 x = constant_op.constant(1).op 479 y = constant_op.constant(2).op 480 z = constant_op.constant(3).op 481 z._add_control_input(x) # pylint: disable=protected-access 482 self.assertEqual(z.control_inputs, [x]) 483 z._add_control_input(x) # pylint: disable=protected-access 484 self.assertEqual(z.control_inputs, [x]) 485 z._add_control_inputs([x, y, y]) # pylint: disable=protected-access 486 self.assertEqual(z.control_inputs, [x, y]) 487 self.assertEqual(x._control_outputs, [z]) 488 489 @test_util.run_deprecated_v1 490 def testRemoveAllControlInputs(self): 491 a = constant_op.constant(1) 492 with ops.control_dependencies([a]): 493 b = constant_op.constant(2) 494 c = constant_op.constant(3) 495 d = constant_op.constant(4) 496 e = constant_op.constant(5) 497 with ops.control_dependencies([a, c]): 498 f = d + e 499 500 self.assertEqual(a.op.control_inputs, []) 501 self.assertEqual(b.op.control_inputs, [a.op]) 502 self.assertEqual(f.op.control_inputs, [a.op, c.op]) 503 504 a.op._remove_all_control_inputs() # pylint: disable=protected-access 505 self.assertEqual(a.op.control_inputs, []) 506 507 b.op._remove_all_control_inputs() # pylint: disable=protected-access 508 self.assertEqual(b.op.control_inputs, []) 509 510 f.op._remove_all_control_inputs() # pylint: disable=protected-access 511 self.assertEqual(f.op.control_inputs, []) 512 self.assertEqual(list(f.op.inputs), [d, e]) 513 514 @test_util.run_deprecated_v1 515 def testControlInputCycle(self): 516 graph = ops.Graph() 517 with graph.as_default(): 518 z = constant_op.constant(0) 519 x = constant_op.constant(1) 520 y = constant_op.constant(2) 521 y.op._add_control_input(z.op) # pylint: disable=protected-access 522 y.op._add_control_input(x.op) # pylint: disable=protected-access 523 x.op._add_control_input(y.op) # pylint: disable=protected-access 524 with self.session(graph=graph) as sess: 525 with self.assertRaisesRegexp( 526 errors.InvalidArgumentError, 527 "Graph is invalid, contains a cycle with 2 nodes"): 528 self.evaluate(x) 529 530 def testUpdateInput(self): 531 g = ops.Graph() 532 with g.as_default(): 533 x = constant_op.constant(1) 534 y = constant_op.constant(2) 535 z = x + y 536 537 z.op._update_input(0, y) # pylint: disable=protected-access 538 self.assertEquals(list(z.op.inputs), [y, y]) 539 self.assertEquals(x.consumers(), []) 540 self.assertEquals(y.consumers(), [z.op, z.op]) 541 with session.Session(graph=g) as sess: 542 self.assertEquals(self.evaluate(z), 4) 543 544 z.op._update_input(0, x) # pylint: disable=protected-access 545 self.assertEquals(list(z.op.inputs), [x, y]) 546 self.assertEquals(x.consumers(), [z.op]) 547 self.assertEquals(y.consumers(), [z.op]) 548 with session.Session(graph=g) as sess: 549 self.assertEquals(self.evaluate(z), 3) 550 551 z.op._update_input(1, y) # pylint: disable=protected-access 552 self.assertEquals(list(z.op.inputs), [x, y]) 553 self.assertEquals(x.consumers(), [z.op]) 554 self.assertEquals(y.consumers(), [z.op]) 555 with session.Session(graph=g) as sess: 556 self.assertEquals(self.evaluate(z), 3) 557 558 def testUpdateInputGraphError(self): 559 g_0 = ops.Graph() 560 g_1 = ops.Graph() 561 with g_0.as_default(): 562 x = constant_op.constant(1) 563 with g_1.as_default(): 564 y = constant_op.constant(2) 565 z = y * 2 566 with self.assertRaisesRegexp(ValueError, "must be from the same graph"): 567 z.op._update_input(0, x) # pylint: disable=protected-access 568 569 def testUpdateInputTypeError(self): 570 g = ops.Graph() 571 with g.as_default(): 572 w = constant_op.constant(0) 573 x = constant_op.constant("") 574 y = constant_op.constant(1) 575 z = y + w 576 z.op._update_input(0, x) # pylint: disable=protected-access 577 with session.Session(graph=g) as sess: 578 with self.assertRaisesRegexp( 579 errors.InvalidArgumentError, 580 "Input 0 of node add was passed string from Const_1:0 incompatible " 581 "with expected int32"): 582 self.evaluate(z) 583 584 def testUpdateInputShapeError(self): 585 g = ops.Graph() 586 with g.as_default(): 587 w = constant_op.constant(2, shape=[3, 1]) 588 x = constant_op.constant(0, shape=[3, 1]) 589 y = constant_op.constant(1, shape=[2, 2]) 590 z = w + x 591 with self.assertRaisesRegexp( 592 errors.InvalidArgumentError, 593 r"Cannot update edge, incompatible shapes: \[2,2\] and \[3,1\]"): 594 z.op._update_input(0, y) # pylint: disable=protected-access 595 596 def testUpdateInputOutOfRange(self): 597 g = ops.Graph() 598 with g.as_default(): 599 x = constant_op.constant(1) 600 with self.assertRaisesRegexp( 601 errors.OutOfRangeError, 602 r"Cannot update edge. Input index \[1\] is greater than the number of " 603 r"total inputs \[0\]." 604 ): 605 x.op._update_input(1, x) # pylint: disable=protected-access 606 607 @test_util.enable_control_flow_v2 608 @test_util.run_v1_only("b/120545219") 609 def testAddWhileInput(self): 610 @eager_function.defun 611 def test(): 612 output = control_flow_ops.while_loop(lambda x: x < 3, lambda x: x + 1, 613 [1]) 614 while_op = output.op.inputs[0].op 615 self.assertEqual(while_op.type, "While") 616 orig_num_inputs = len(while_op.inputs) 617 618 # Make sure we can handle the while op having a control input. 619 while_op._add_control_input(constant_op.constant(0).op) 620 621 new_input1 = constant_op.constant(1.0) 622 new_input2 = constant_op.constant(True) 623 624 while_op._set_type_list_attr("T", 625 [t.dtype for t in while_op.inputs] + 626 [new_input1.dtype, new_input2.dtype]) 627 628 while_op._add_while_inputs([new_input1, new_input2]) 629 # Can't add an edge beyond what's specified by "T" 630 with self.assertRaises(errors.OutOfRangeError): 631 while_op._add_while_inputs([new_input2]) 632 self.assertEqual(len(while_op.inputs), orig_num_inputs + 2) # pylint: disable=g-deprecated-assert 633 634 test() 635 636 @test_util.run_deprecated_v1 637 def testOpDef(self): 638 x = constant_op.constant(0) 639 y = constant_op.constant(1) 640 z = x + y 641 642 self.assertEqual(x.op.op_def.name, "Const") 643 self.assertEqual(len(x.op.op_def.input_arg), 0) 644 self.assertEqual(len(x.op.op_def.output_arg), 1) 645 646 self.assertEqual(z.op.op_def.name, "Add") 647 self.assertEqual(len(z.op.op_def.input_arg), 2) 648 self.assertEqual(len(z.op.op_def.output_arg), 1) 649 650 def testInputFromDifferentGraphError(self): 651 g_0 = ops.Graph() 652 g_1 = ops.Graph() 653 with g_0.as_default(): 654 x = constant_op.constant(1) 655 with g_1.as_default(): 656 y = constant_op.constant(2) 657 with self.assertRaisesRegexp(ValueError, "must be from the same graph"): 658 y * x # pylint: disable=pointless-statement 659 660 def testInputsAreImmutable(self): 661 g = ops.Graph() 662 with g.as_default(): 663 x = test_ops.int_output() 664 op = test_ops.int_input_int_output(x, name="myop").op 665 with self.assertRaisesRegexp( 666 AttributeError, "'_InputList' object has no attribute 'append'"): 667 op.inputs.append(None) 668 669 670class CreateOpTest(test_util.TensorFlowTestCase): 671 672 def testNodeDefArgs(self): 673 g = ops.Graph() 674 op1 = g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1") 675 with g.device("/device:GPU:0"): 676 op2 = g.create_op( 677 "FloatOutputStringOutput", [], [dtypes.float32, dtypes.string], None, 678 name="myop2") 679 op3 = g.create_op( 680 "Foo3", 681 [list(op1.values())[0], list(op2.values())[1], list(op2.values())[0]], 682 [dtypes.float32, dtypes.int32], 683 None, 684 name="myop3") 685 self.assertDeviceEqual(None, op1.device) 686 self.assertDeviceEqual("/device:GPU:0", op2.device) 687 self.assertDeviceEqual(None, op3.device) 688 self.assertProtoEquals("name:'myop1' op:'FloatOutput'", op1.node_def) 689 self.assertProtoEquals( 690 "name:'myop2' op:'FloatOutputStringOutput' device:'/device:GPU:0'", 691 op2.node_def) 692 self.assertProtoEquals( 693 "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo3'", 694 op3.node_def) 695 696 def testReferenceInput(self): 697 g = ops.Graph() 698 op1 = g.create_op( 699 "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32], 700 name="op1") 701 self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def) 702 ref_t, nonref_t = op1.values() 703 # NOTE(mrry): Must specify input_types to preserve ref-typed input. 704 op2 = g.create_op( 705 "RefInputFloatInput", [ref_t, nonref_t], [], 706 input_types=[dtypes.float32_ref, dtypes.float32], 707 name="op2") 708 self.assertProtoEquals( 709 "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'", 710 op2.node_def) 711 op3 = g.create_op("TwoFloatInputs", [ref_t, nonref_t], [], name="op3") 712 self.assertProtoEquals( 713 "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'", 714 op3.node_def) 715 716 def testFinalized(self): 717 g = ops.Graph() 718 g.finalize() 719 with self.assertRaises(RuntimeError): 720 g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1") 721 722 # Test unfinalize. 723 g._unsafe_unfinalize() 724 g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1") 725 726 727# NOTE(skyewm): these cases test the private Graph._create_op_from_tf_operation 728# method. Arguably we should only test the public APIs that depend on this 729# method. However, this logic is complex and tricky, and it can be difficult to 730# ascertain if we have adequate coverage (e.g. a graph may run successfully if 731# the control flow context isn't set properly, but a more complicated use case 732# that might not be obvious to test will fail). Thus we instead explicitly test 733# the low-level behavior. 734class CreateOpFromTFOperationTest(test_util.TensorFlowTestCase): 735 736 @test_util.run_deprecated_v1 737 def testBasic(self): 738 g = ops.Graph() 739 with g.as_default(): 740 x = test_ops.int_output() 741 c_op = ops._create_c_op( 742 g, ops._NodeDef("IntInputIntOutput", "myop"), [x], []) 743 op = g._create_op_from_tf_operation(c_op) 744 745 self.assertEqual(op.name, "myop") 746 self.assertEqual(op.type, "IntInputIntOutput") 747 self.assertEqual(len(op.outputs), 1) 748 self.assertEqual(op.outputs[0].shape, tensor_shape.unknown_shape()) 749 self.assertEqual(list(op.inputs), [x]) 750 self.assertEqual(op.control_inputs, []) 751 self.assertEqual(op.graph, g) 752 self.assertEqual(x.consumers(), [op]) 753 self.assertIsNotNone(op.traceback) 754 self.assertEqual(g.get_operation_by_name("myop"), op) 755 self.assertEqual(g.get_tensor_by_name("myop:0"), op.outputs[0]) 756 757 def testShape(self): 758 g = ops.Graph() 759 with g.as_default(): 760 x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) 761 c_op = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), [x], []) 762 op = g._create_op_from_tf_operation(c_op) 763 764 self.assertEqual(op.name, "myop") 765 self.assertEqual(op.type, "Identity") 766 self.assertEqual(len(op.outputs), 1) 767 self.assertEqual(op.outputs[0].shape, tensor_shape.matrix(2, 3)) 768 769 def testUniqueName(self): 770 g = ops.Graph() 771 with g.as_default(): 772 c_op = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop"), [], []) 773 c_op2 = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop_1"), [], []) 774 op = g._create_op_from_tf_operation(c_op) 775 op2 = g._create_op_from_tf_operation(c_op2) 776 777 # Create ops with same names as op1 and op2. We expect the new names to be 778 # uniquified. 779 op3 = test_ops.int_output(name="myop").op 780 op4 = test_ops.int_output(name="myop_1").op 781 782 self.assertEqual(op.name, "myop") 783 self.assertEqual(op2.name, "myop_1") 784 self.assertEqual(op3.name, "myop_2") 785 self.assertEqual(op4.name, "myop_1_1") 786 787 @test_util.run_v1_only("b/120545219") 788 def testCond(self): 789 g = ops.Graph() 790 with g.as_default(): 791 x = test_ops.int_output() 792 793 def true_fn(): 794 ops._create_c_op(ops.get_default_graph(), 795 ops._NodeDef("IntInput", "cond/myop"), [x], []) 796 new_ops = g._add_new_tf_operations() 797 self.assertEqual(len(new_ops), 1) 798 return x 799 800 control_flow_ops.cond(x < 10, true_fn, lambda: x) 801 802 op = g.get_operation_by_name("cond/myop") 803 self.assertIsNotNone(op) 804 self.assertEqual(op.name, "cond/myop") 805 self.assertEqual(op.type, "IntInput") 806 self.assertEqual(op.outputs, []) 807 op_input = op.inputs[0].op 808 self.assertEqual(op_input.type, "Switch") 809 self.assertEqual(op_input.inputs[0], x) 810 self.assertEqual(op.graph, g) 811 # pylint: disable=protected-access 812 self.assertIsNotNone(op._get_control_flow_context()) 813 self.assertEqual(op._get_control_flow_context().name, 814 "cond/cond_text") 815 # pylint: enable=protected-access 816 817 @test_util.run_v1_only("b/120545219") 818 def testWhileLoop(self): 819 g = ops.Graph() 820 with g.as_default(): 821 x = test_ops.int_output() 822 823 def body(i): 824 ops._create_c_op(ops.get_default_graph(), 825 ops._NodeDef("IntInput", "myloop/myop"), [x], []) 826 new_ops = g._add_new_tf_operations() 827 self.assertEqual(len(new_ops), 1) 828 return i 829 830 control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") 831 832 op = g.get_operation_by_name("myloop/myop") 833 self.assertIsNotNone(op) 834 self.assertEqual(op.name, "myloop/myop") 835 self.assertEqual(op.type, "IntInput") 836 self.assertEqual(op.outputs, []) 837 op_input = op.inputs[0].op 838 self.assertEqual(op_input.type, "Enter") 839 self.assertEqual(list(op_input.inputs), [x]) 840 self.assertEqual(op.graph, g) 841 # pylint: disable=protected-access 842 self.assertIsNotNone(op._get_control_flow_context()) 843 self.assertEqual(op._get_control_flow_context().name, 844 "myloop/while_context") 845 # pylint: enable=protected-access 846 847 @test_util.run_v1_only("b/120545219") 848 def testWhileLoopWithInternalControlDep(self): 849 g = ops.Graph() 850 with g.as_default(): 851 x = test_ops.int_output() 852 853 def body(i): 854 c = constant_op.constant(1.0, name="c") 855 ops._create_c_op(ops.get_default_graph(), 856 ops._NodeDef("IntInput", "myloop/myop"), [x], []) 857 with ops.control_dependencies([c]): 858 new_ops = g._add_new_tf_operations() 859 self.assertEqual(len(new_ops), 1) 860 return i 861 862 control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") 863 864 op = g.get_operation_by_name("myloop/myop") 865 self.assertIsNotNone(op) 866 c = g.get_operation_by_name("myloop/c") 867 self.assertIsNotNone(c) 868 # Internal control dep is preserved 869 self.assertEqual(op.control_inputs, [c]) 870 871 @test_util.run_v1_only("b/120545219") 872 def testWhileLoopWithExternalControlDep(self): 873 g = ops.Graph() 874 with g.as_default(): 875 x = test_ops.int_output() 876 c = constant_op.constant(1.0) 877 878 def body(i): 879 ops._create_c_op(ops.get_default_graph(), 880 ops._NodeDef("IntInput", "myloop/myop"), [x], []) 881 with ops.control_dependencies([c]): 882 new_ops = g._add_new_tf_operations() 883 self.assertEqual(len(new_ops), 1) 884 return i 885 886 control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") 887 888 op = g.get_operation_by_name("myloop/myop") 889 self.assertIsNotNone(op) 890 # External control dep is removed and replaced with internal control dep 891 self.assertNotEqual(op.control_inputs[0], c.op) 892 self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context()) 893 894 895class ApplyOpTest(test_util.TensorFlowTestCase): 896 897 def testNodeDefArgs(self): 898 g = ops.Graph() 899 t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1") 900 with g.device("/device:GPU:0"): 901 t2 = _apply_op( 902 g, "TwoIntOutputs", [], [dtypes.int32, dtypes.int32], name="myop2") 903 t3 = _apply_op( 904 g, 905 "Foo1", [t1, t2[1], t2[0]], [dtypes.float32, dtypes.int32], 906 name="myop3") 907 self.assertTrue(isinstance(t1, ops.Tensor)) 908 self.assertTrue(isinstance(t2, list)) 909 self.assertTrue(isinstance(t3, list)) 910 self.assertTrue(isinstance(t3[0], ops.Tensor)) 911 self.assertEqual("myop1", t1._as_node_def_input()) 912 self.assertEqual("myop2", t2[0]._as_node_def_input()) 913 self.assertEqual("myop2:1", t2[1]._as_node_def_input()) 914 self.assertEqual("myop3", t3[0]._as_node_def_input()) 915 # Validate that we got the right ops as well 916 self.assertProtoEquals("name:'myop1' op:'FloatOutput'", t1.op.node_def) 917 self.assertProtoEquals( 918 "name:'myop2' op:'TwoIntOutputs' device:'/device:GPU:0'", 919 t2[0].op.node_def) 920 self.assertProtoEquals( 921 "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo1'", 922 t3[0].op.node_def) 923 924 def testReferenceInput(self): 925 g = ops.Graph() 926 ref_t, nonref_t = _apply_op( 927 g, "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32], 928 name="op1") 929 self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", 930 ref_t.op.node_def) 931 # NOTE(mrry): Must specify input_types to preserve ref-typed input. 932 out_2 = _apply_op( 933 g, 934 "RefInputFloatInputIntOutput", [ref_t, nonref_t], [dtypes.int32], 935 input_types=[dtypes.float32_ref, dtypes.float32], 936 name="op2") 937 self.assertProtoEquals( 938 "op:'RefInputFloatInputIntOutput' name:'op2' input:'op1' input:'op1:1'", 939 out_2.op.node_def) 940 out_3 = _apply_op( 941 g, "TwoFloatInputsIntOutput", [ref_t, nonref_t], [dtypes.int32], 942 name="op3") 943 self.assertProtoEquals( 944 "op:'TwoFloatInputsIntOutput' name:'op3' input:'op1' input:'op1:1'", 945 out_3.op.node_def) 946 947 948class NameStackTest(test_util.TensorFlowTestCase): 949 950 def testBasics(self): 951 g = ops.Graph() 952 self.assertEqual("foo", g.unique_name("foo", mark_as_used=False)) 953 self.assertEqual("foo", g.unique_name("foo", mark_as_used=False)) 954 self.assertEqual("foo", g.unique_name("foo")) 955 self.assertEqual("foo_1", g.unique_name("foo", mark_as_used=False)) 956 self.assertEqual("foo_1", g.unique_name("foo")) 957 self.assertEqual("foo_2", g.unique_name("foo", mark_as_used=False)) 958 self.assertEqual("foo_2", g.unique_name("foo")) 959 self.assertEqual("foo_1_1", g.unique_name("foo_1", mark_as_used=False)) 960 self.assertEqual("foo_1_1", g.unique_name("foo_1")) 961 self.assertEqual("foo_1_2", g.unique_name("foo_1", mark_as_used=False)) 962 self.assertEqual("foo_1_2", g.unique_name("foo_1")) 963 self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2", mark_as_used=False)) 964 self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2")) 965 with g.name_scope("bar"): 966 self.assertEqual("bar/foo", g.unique_name("foo", mark_as_used=False)) 967 self.assertEqual("bar/foo", g.unique_name("foo")) 968 self.assertEqual("bar/foo_1", g.unique_name("foo", mark_as_used=False)) 969 self.assertEqual("bar/foo_1", g.unique_name("foo")) 970 with g.name_scope(None): 971 self.assertEqual("foo_3", g.unique_name("foo", mark_as_used=False)) 972 self.assertEqual("foo_3", g.unique_name("foo")) 973 with g.name_scope("baz"): 974 self.assertEqual( 975 "bar/baz/foo", g.unique_name( 976 "foo", mark_as_used=False)) 977 self.assertEqual("bar/baz/foo", g.unique_name("foo")) 978 self.assertEqual( 979 "bar/baz/foo_1", g.unique_name( 980 "foo", mark_as_used=False)) 981 self.assertEqual("bar/baz/foo_1", g.unique_name("foo")) 982 with g.name_scope("baz"): 983 self.assertEqual( 984 "bar/baz_1/foo", g.unique_name( 985 "foo", mark_as_used=False)) 986 self.assertEqual("bar/baz_1/foo", g.unique_name("foo")) 987 self.assertEqual( 988 "bar/baz_1/foo_1", g.unique_name( 989 "foo", mark_as_used=False)) 990 self.assertEqual("bar/baz_1/foo_1", g.unique_name("foo")) 991 with g.name_scope("quux"): 992 self.assertEqual("quux/foo", g.unique_name("foo", mark_as_used=False)) 993 self.assertEqual("quux/foo", g.unique_name("foo")) 994 with g.name_scope("bar"): 995 with g.name_scope("baz"): 996 self.assertEqual( 997 "bar_1/baz/foo", g.unique_name( 998 "foo", mark_as_used=False)) 999 self.assertEqual("bar_1/baz/foo", g.unique_name("foo")) 1000 self.assertEqual("foo_4", g.unique_name("foo", mark_as_used=False)) 1001 self.assertEqual("foo_4", g.unique_name("foo")) 1002 self.assertEqual("bar_2", g.unique_name("bar", mark_as_used=False)) 1003 self.assertEqual("bar_2", g.unique_name("bar")) 1004 1005 @test_util.run_deprecated_v1 1006 def testNameAndVariableScope(self): 1007 with self.cached_session() as sess: 1008 with sess.graph.name_scope("l0"): 1009 with variable_scope.variable_scope("l1"): 1010 with sess.graph.name_scope("l1") as scope: 1011 self.assertEqual("l0/l1/l1/", scope) 1012 self.assertEqual( 1013 "l0/l1/l1/foo", 1014 sess.graph.unique_name( 1015 "foo", mark_as_used=False)) 1016 self.assertEqual("l0/l1/l1/foo", sess.graph.unique_name("foo")) 1017 with sess.graph.name_scope("l2") as scope: 1018 self.assertEqual("l0/l1/l2/", scope) 1019 self.assertEqual( 1020 "l0/l1/l2/foo", 1021 sess.graph.unique_name( 1022 "foo", mark_as_used=False)) 1023 self.assertEqual("l0/l1/l2/foo", sess.graph.unique_name("foo")) 1024 1025 def testOutOfOrderUniqueName(self): 1026 g = ops.Graph() 1027 self.assertEqual("foo_2", g.unique_name("foo_2")) 1028 self.assertEqual("foo", g.unique_name("foo")) 1029 self.assertEqual("foo_1", g.unique_name("foo")) 1030 self.assertEqual("foo_3", g.unique_name("foo")) 1031 1032 def testUniqueNameCaseInsensitivity(self): 1033 g = ops.Graph() 1034 self.assertEqual("foo", g.unique_name("foo")) 1035 self.assertEqual("Foo_1", g.unique_name("Foo")) 1036 with g.name_scope("bar"): 1037 self.assertEqual("bar/foo", g.unique_name("foo")) 1038 with g.name_scope("Bar"): 1039 self.assertEqual("Bar_1/foo", g.unique_name("foo")) 1040 1041 def testInvalidNameRaisesError(self): 1042 g = ops.Graph() 1043 with g.name_scope(""): # Should not raise 1044 pass 1045 with g.name_scope("foo/"): # Should not raise 1046 with g.name_scope("_bar"): # Should not raise 1047 pass 1048 with self.assertRaises(ValueError): 1049 with g.name_scope("foo:0"): 1050 pass 1051 with self.assertRaises(ValueError): 1052 with g.name_scope("_bar"): 1053 pass 1054 1055 1056class NameTest(test_util.TensorFlowTestCase): 1057 1058 def testGenerateName(self): 1059 g = ops.Graph() 1060 op0 = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32]) 1061 self.assertEqual("TwoFloatOutputs", op0.name) 1062 self.assertEqual("TwoFloatOutputs:0", op0.outputs[0].name) 1063 self.assertEqual("TwoFloatOutputs:1", op0.outputs[1].name) 1064 1065 op1 = g.create_op("FloatOutput", [], [dtypes.float32]) 1066 self.assertEqual("FloatOutput", op1.name) 1067 self.assertEqual("FloatOutput:0", op1.outputs[0].name) 1068 1069 op2 = g.create_op("FloatOutput", [], [dtypes.float32]) 1070 self.assertEqual("FloatOutput_1", op2.name) 1071 self.assertEqual("FloatOutput_1:0", op2.outputs[0].name) 1072 1073 op3 = g.create_op("FloatOutput", [], [dtypes.float32], name="my_op") 1074 self.assertEqual("my_op", op3.name) 1075 self.assertEqual("my_op:0", op3.outputs[0].name) 1076 1077 def testNameScope(self): 1078 g = ops.Graph() 1079 1080 with g.name_scope("foo") as foo: 1081 self.assertEqual("foo/", foo) 1082 with g.name_scope("foo2") as foo2: 1083 self.assertEqual("foo/foo2/", foo2) 1084 with g.name_scope(None) as empty1: 1085 self.assertEqual("", empty1) 1086 with g.name_scope("foo3") as foo3: 1087 self.assertEqual("foo3/", foo3) 1088 with g.name_scope("") as empty2: 1089 self.assertEqual("", empty2) 1090 1091 self.assertEqual("FloatOutput", 1092 g.create_op("FloatOutput", [], [dtypes.float32]).name) 1093 with g.name_scope("bar") as scope: 1094 self.assertEqual("bar/FloatOutput", 1095 g.create_op("FloatOutput", [], [dtypes.float32]).name) 1096 self.assertEqual("bar/FloatOutput_1", 1097 g.create_op("FloatOutput", [], [dtypes.float32]).name) 1098 # If you use the value from "with .. as", that values is used as-is. 1099 self.assertEqual( 1100 "bar", g.create_op( 1101 "FloatOutput", [], [dtypes.float32], name=scope).name) 1102 with g.name_scope("baz") as scope: 1103 with g.name_scope("quux"): 1104 self.assertEqual("baz/quux/FloatOutput", 1105 g.create_op("FloatOutput", [], [dtypes.float32]).name) 1106 # If you use the value from the enclosing "with .. as", nothing is pushed. 1107 with g.name_scope(scope): 1108 self.assertEqual("baz/FloatOutput", 1109 g.create_op("FloatOutput", [], [dtypes.float32]).name) 1110 self.assertEqual( 1111 "baz", g.create_op( 1112 "FloatOutput", [], [dtypes.float32], name=scope).name) 1113 self.assertEqual( 1114 "trailing", 1115 g.create_op( 1116 "FloatOutput", [], [dtypes.float32], name="trailing/").name) 1117 with g.name_scope("bar"): 1118 self.assertEqual("bar_1/FloatOutput", 1119 g.create_op("FloatOutput", [], [dtypes.float32]).name) 1120 with g.name_scope("bar/"): 1121 self.assertEqual("bar/FloatOutput_2", 1122 g.create_op("FloatOutput", [], [dtypes.float32]).name) 1123 1124 1125class DeviceTest(test_util.TensorFlowTestCase): 1126 1127 def testNoDevice(self): 1128 g = ops.Graph() 1129 op = g.create_op("FloatOutput", [], [dtypes.float32]) 1130 self.assertDeviceEqual(None, op.device) 1131 gd = g.as_graph_def() 1132 self.assertProtoEqualsVersion(""" 1133 node { name: "FloatOutput" op: "FloatOutput" } 1134 """, gd) 1135 1136 def testEagerBackingDevice(self): 1137 with context.eager_mode(): 1138 with ops.device("/device:CPU:0"): 1139 t = constant_op.constant(1.0) 1140 self.assertRegexpMatches(t.device, "/device:CPU:0") 1141 self.assertRegexpMatches(t.backing_device, "/device:CPU:0") 1142 1143 def testDevicePartialString(self): 1144 g = ops.Graph() 1145 with g.device("/job:worker/replica:2"): 1146 g.create_op("FloatOutput", [], [dtypes.float32]) 1147 gd = g.as_graph_def() 1148 self.assertProtoEqualsVersion(""" 1149 node { name: "FloatOutput" op: "FloatOutput" 1150 device: "/job:worker/replica:2" } 1151 """, gd) 1152 1153 def testDeviceFull(self): 1154 g = ops.Graph() 1155 with g.device( 1156 pydev.DeviceSpec( 1157 job="worker", replica=2, task=0, device_type="CPU", 1158 device_index=3)): 1159 g.create_op("FloatOutput", [], [dtypes.float32]) 1160 gd = g.as_graph_def() 1161 self.assertProtoEqualsVersion(""" 1162 node { name: "FloatOutput" op: "FloatOutput" 1163 device: "/job:worker/replica:2/task:0/device:CPU:3" } 1164 """, gd) 1165 1166 def testNesting(self): 1167 g = ops.Graph() 1168 with g.device("/job:worker/replica:2"): 1169 g.create_op("FloatOutput", [], [dtypes.float32]) 1170 with g.device("/job:worker/replica:3/task:0"): 1171 g.create_op("FloatOutput", [], [dtypes.float32]) 1172 g.create_op("FloatOutput", [], [dtypes.float32]) 1173 gd = g.as_graph_def() 1174 self.assertProtoEqualsVersion(""" 1175 node { name: "FloatOutput" op: "FloatOutput" 1176 device: "/job:worker/replica:2" } 1177 node { name: "FloatOutput_1" op: "FloatOutput" 1178 device: "/job:worker/replica:3/task:0" } 1179 node { name: "FloatOutput_2" op: "FloatOutput" 1180 device: "/job:worker/replica:2" } 1181 """, gd) 1182 1183 def testNestingString(self): 1184 g = ops.Graph() 1185 with g.device("/job:worker/replica:2"): 1186 g.create_op("FloatOutput", [], [dtypes.float32]) 1187 with g.device("/job:worker/replica:3/task:0"): 1188 g.create_op("FloatOutput", [], [dtypes.float32]) 1189 g.create_op("FloatOutput", [], [dtypes.float32]) 1190 gd = g.as_graph_def() 1191 self.assertProtoEqualsVersion(""" 1192 node { name: "FloatOutput" op: "FloatOutput" 1193 device: "/job:worker/replica:2" } 1194 node { name: "FloatOutput_1" op: "FloatOutput" 1195 device: "/job:worker/replica:3/task:0" } 1196 node { name: "FloatOutput_2" op: "FloatOutput" 1197 device: "/job:worker/replica:2" } 1198 """, gd) 1199 1200 def testNestingOverrideGpuCpu(self): 1201 g = ops.Graph() 1202 with g.device("/job:worker/replica:2/device:CPU:1"): 1203 g.create_op("FloatOutput", [], [dtypes.float32]) 1204 with g.device("/job:worker/replica:2/device:GPU:2"): 1205 g.create_op("FloatOutput", [], [dtypes.float32]) 1206 g.create_op("FloatOutput", [], [dtypes.float32]) 1207 gd = g.as_graph_def() 1208 self.assertProtoEqualsVersion(""" 1209 node { name: "FloatOutput" op: "FloatOutput" 1210 device: "/job:worker/replica:2/device:CPU:1" } 1211 node { name: "FloatOutput_1" op: "FloatOutput" 1212 device: "/job:worker/replica:2/device:GPU:2" } 1213 node { name: "FloatOutput_2" op: "FloatOutput" 1214 device: "/job:worker/replica:2/device:CPU:1" } 1215 """, gd) 1216 1217 def testNestingWithMergeDeviceFunction(self): 1218 g = ops.Graph() 1219 1220 with g.device(pydev.merge_device("/device:GPU:0")): 1221 g.create_op("FloatOutput", [], [dtypes.float32]) 1222 with g.device(pydev.merge_device("/job:worker")): 1223 g.create_op("FloatOutput", [], [dtypes.float32]) 1224 with g.device(pydev.merge_device("/device:CPU:0")): 1225 g.create_op("FloatOutput", [], [dtypes.float32]) 1226 with g.device(pydev.merge_device("/job:ps")): 1227 g.create_op("FloatOutput", [], [dtypes.float32]) 1228 with g.device(pydev.merge_device(None)): 1229 g.create_op("FloatOutput", [], [dtypes.float32]) 1230 1231 gd = g.as_graph_def() 1232 self.assertProtoEqualsVersion(""" 1233 node { name: "FloatOutput" op: "FloatOutput" 1234 device: "/device:GPU:0" } 1235 node { name: "FloatOutput_1" op: "FloatOutput" 1236 device: "/job:worker/device:GPU:0" } 1237 node { name: "FloatOutput_2" op: "FloatOutput" 1238 device: "/job:worker/device:CPU:0" } 1239 node { name: "FloatOutput_3" op: "FloatOutput" 1240 device: "/job:ps/device:CPU:0" } 1241 node { name: "FloatOutput_4" op: "FloatOutput" 1242 device: "/job:ps/device:CPU:0" } 1243 """, gd) 1244 1245 def testNestingWithDeviceStrings(self): 1246 g = ops.Graph() 1247 1248 with g.device("/device:GPU:0"): 1249 g.create_op("FloatOutput", [], [dtypes.float32]) 1250 with g.device("/job:worker"): 1251 g.create_op("FloatOutput", [], [dtypes.float32]) 1252 with g.device("/device:CPU:0"): 1253 g.create_op("FloatOutput", [], [dtypes.float32]) 1254 with g.device("/job:ps"): 1255 g.create_op("FloatOutput", [], [dtypes.float32]) 1256 with g.device(""): 1257 g.create_op("FloatOutput", [], [dtypes.float32]) 1258 1259 gd = g.as_graph_def() 1260 self.assertProtoEqualsVersion(""" 1261 node { name: "FloatOutput" op: "FloatOutput" 1262 device: "/device:GPU:0" } 1263 node { name: "FloatOutput_1" op: "FloatOutput" 1264 device: "/job:worker/device:GPU:0" } 1265 node { name: "FloatOutput_2" op: "FloatOutput" 1266 device: "/job:worker/device:CPU:0" } 1267 node { name: "FloatOutput_3" op: "FloatOutput" 1268 device: "/job:ps/device:CPU:0" } 1269 node { name: "FloatOutput_4" op: "FloatOutput" 1270 device: "/job:ps/device:CPU:0" } 1271 """, gd) 1272 1273 def testNestingWithDeviceStringWildcard(self): 1274 g = ops.Graph() 1275 1276 with g.device("/device:GPU:7"): 1277 g.create_op("FloatOutput", [], [dtypes.float32]) 1278 with g.device("/device:GPU:*"): 1279 g.create_op("FloatOutput", [], [dtypes.float32]) 1280 1281 with g.device("/device:CPU:*"): 1282 g.create_op("FloatOutput", [], [dtypes.float32]) 1283 with g.device("/device:CPU:5"): 1284 g.create_op("FloatOutput", [], [dtypes.float32]) 1285 1286 gd = g.as_graph_def() 1287 self.assertProtoEqualsVersion(""" 1288 node { name: "FloatOutput" op: "FloatOutput" 1289 device: "/device:GPU:7" } 1290 node { name: "FloatOutput_1" op: "FloatOutput" 1291 device: "/device:GPU:7" } 1292 node { name: "FloatOutput_2" op: "FloatOutput" 1293 device: "/device:CPU:*" } 1294 node { name: "FloatOutput_3" op: "FloatOutput" 1295 device: "/device:CPU:5" } 1296 """, gd) 1297 1298 def testNoneClearsDefault(self): 1299 g = ops.Graph() 1300 with g.device("/job:worker/replica:2/device:CPU:1"): 1301 g.create_op("FloatOutput", [], [dtypes.float32]) 1302 with g.device(None): 1303 g.create_op("FloatOutput", [], [dtypes.float32]) 1304 g.create_op("FloatOutput", [], [dtypes.float32]) 1305 gd = g.as_graph_def() 1306 self.assertProtoEqualsVersion(""" 1307 node { name: "FloatOutput" op: "FloatOutput" 1308 device: "/job:worker/replica:2/device:CPU:1" } 1309 node { name: "FloatOutput_1" op: "FloatOutput" } 1310 node { name: "FloatOutput_2" op: "FloatOutput" 1311 device: "/job:worker/replica:2/device:CPU:1" } 1312 """, gd) 1313 1314 def testNoneIgnoresOuterDeviceFunction(self): 1315 g = ops.Graph() 1316 with g.device(lambda op: "/job:worker/replica:2/device:CPU:1"): 1317 g.create_op("FloatOutput", [], [dtypes.float32]) 1318 with g.device(None): 1319 g.create_op("FloatOutput", [], [dtypes.float32]) 1320 g.create_op("FloatOutput", [], [dtypes.float32]) 1321 gd = g.as_graph_def() 1322 self.assertProtoEqualsVersion(""" 1323 node { name: "FloatOutput" op: "FloatOutput" 1324 device: "/job:worker/replica:2/device:CPU:1" } 1325 node { name: "FloatOutput_1" op: "FloatOutput" } 1326 node { name: "FloatOutput_2" op: "FloatOutput" 1327 device: "/job:worker/replica:2/device:CPU:1" } 1328 """, gd) 1329 1330 def _overwritingDeviceFunction(self, unused_op): 1331 # This device function unconditionally overwrites the device of ops. 1332 # 1333 # NOTE(mrry): Writing device functions like this is not 1334 # recommended. Instead, in most cases you should use 1335 # `pydev.merge_device("/job:ps")` or simply `"/job:ps"` as the 1336 # argument to `tf.device()` and the device component will be merged in. 1337 return "/job:overwrite" 1338 1339 def testOverwritingBehavior(self): 1340 g = ops.Graph() 1341 with g.device(self._overwritingDeviceFunction): 1342 g.create_op("FloatOutput", [], [dtypes.float32]) 1343 with g.device("/job:ps"): # Will be overwritten. 1344 g.create_op("FloatOutput", [], [dtypes.float32]) 1345 with g.device(pydev.merge_device("/job:ps")): # Will be overwritten. 1346 g.create_op("FloatOutput", [], [dtypes.float32]) 1347 with g.device(None): # Disables overwriting device function 1348 with g.device("/job:ps"): 1349 g.create_op("FloatOutput", [], [dtypes.float32]) 1350 with g.device(None): # Disables overwriting device function 1351 with g.device(pydev.merge_device("/job:ps")): 1352 g.create_op("FloatOutput", [], [dtypes.float32]) 1353 gd = g.as_graph_def() 1354 self.assertProtoEqualsVersion(""" 1355 node { name: "FloatOutput" op: "FloatOutput" 1356 device: "/job:overwrite" } 1357 node { name: "FloatOutput_1" op: "FloatOutput" 1358 device: "/job:overwrite" } 1359 node { name: "FloatOutput_2" op: "FloatOutput" 1360 device: "/job:overwrite" } 1361 node { name: "FloatOutput_3" op: "FloatOutput" 1362 device: "/job:ps" } 1363 node { name: "FloatOutput_4" op: "FloatOutput" 1364 device: "/job:ps" } 1365 """, gd) 1366 1367 1368class MultithreadedGraphStateTest(test_util.TensorFlowTestCase): 1369 1370 class TestThread(threading.Thread): 1371 1372 def __init__(self, graph, replica_id): 1373 super(MultithreadedGraphStateTest.TestThread, self).__init__() 1374 self._graph = graph 1375 self._replica_id = replica_id 1376 # This thread sets this event when it mutated the graph. The caller can 1377 # wait for that. 1378 self.has_mutated_graph = threading.Event() 1379 # This thread waits for when it should continue. The caller can set this 1380 # event. 1381 self.should_continue = threading.Event() 1382 1383 def run(self): 1384 # Mutate a graph's stack, then set `has_mutated_graph`, then wait for 1385 # `should_continue`, then add an op to the graph affected by the graph's 1386 # stack. 1387 raise NotImplementedError("must be implemented in descendants") 1388 1389 def testDeviceFunctionStack(self): 1390 1391 class DeviceSettingThread(self.TestThread): 1392 1393 def run(self): 1394 with g.device("/job:worker/replica:{}".format(self._replica_id)): 1395 self.has_mutated_graph.set() 1396 self.should_continue.wait() 1397 self.should_continue.clear() 1398 g.create_op( 1399 "FloatOutput", [], [dtypes.float32], 1400 name="FloatOutput_{}".format(self._replica_id)) 1401 1402 g = ops.Graph() 1403 # If `switch_to_thread` isn't called, then device placement of the ops 1404 # below is not deterministic. 1405 g.switch_to_thread_local() 1406 threads = [DeviceSettingThread(g, i) for i in range(3)] 1407 for t in threads: 1408 t.start() 1409 t.has_mutated_graph.wait() 1410 t.has_mutated_graph.clear() 1411 for t in threads: 1412 t.should_continue.set() 1413 t.join() 1414 1415 gd = g.as_graph_def() 1416 self.assertProtoEqualsVersion(""" 1417 node { name: "FloatOutput_0" op: "FloatOutput" 1418 device: "/job:worker/replica:0" } 1419 node { name: "FloatOutput_1" op: "FloatOutput" 1420 device: "/job:worker/replica:1" } 1421 node { name: "FloatOutput_2" op: "FloatOutput" 1422 device: "/job:worker/replica:2" } 1423 """, gd) 1424 1425 def testColocateWith(self): 1426 1427 class ColocatingThread(self.TestThread): 1428 1429 def __init__(self, graph, replica_id, op_to_colocate_with): 1430 super(ColocatingThread, self).__init__(graph, replica_id) 1431 self._op_to_colocate_with = op_to_colocate_with 1432 1433 def run(self): 1434 with g.colocate_with(self._op_to_colocate_with): 1435 self.has_mutated_graph.set() 1436 self.should_continue.wait() 1437 self.should_continue.clear() 1438 g.create_op( 1439 "FloatOutput", [], [dtypes.float32], 1440 name="FloatOutput_{}".format(self._replica_id)) 1441 1442 g = ops.Graph() 1443 ops_to_colocate_with = [] 1444 for i in range(3): 1445 with g.device("/job:worker/replica:{}".format(i)): 1446 ops_to_colocate_with.append( 1447 g.create_op( 1448 "FloatOutput", [], [dtypes.float32], 1449 name="ColocateWithMe_{}".format(i))) 1450 1451 # If `switch_to_thread` isn't called, then `device` and `attr` values for 1452 # the ops below are not deterministic. 1453 g.switch_to_thread_local() 1454 threads = [ 1455 ColocatingThread(g, i, ops_to_colocate_with[i]) for i in range(3) 1456 ] 1457 for t in threads: 1458 t.start() 1459 t.has_mutated_graph.wait() 1460 t.has_mutated_graph.clear() 1461 for t in threads: 1462 t.should_continue.set() 1463 t.join() 1464 1465 gd = g.as_graph_def() 1466 self.assertProtoEqualsVersion(""" 1467 node { name: "ColocateWithMe_0" op: "FloatOutput" 1468 device: "/job:worker/replica:0" } 1469 node { name: "ColocateWithMe_1" op: "FloatOutput" 1470 device: "/job:worker/replica:1" } 1471 node { name: "ColocateWithMe_2" op: "FloatOutput" 1472 device: "/job:worker/replica:2" } 1473 node { name: "FloatOutput_0" op: "FloatOutput" 1474 device: "/job:worker/replica:0" 1475 attr { key: "_class" 1476 value { list { 1477 s: "loc:@ColocateWithMe_0"}}}} 1478 node { name: "FloatOutput_1" op: "FloatOutput" 1479 device: "/job:worker/replica:1" 1480 attr { key: "_class" 1481 value { list { 1482 s: "loc:@ColocateWithMe_1"}}}} 1483 node { name: "FloatOutput_2" op: "FloatOutput" 1484 device: "/job:worker/replica:2" 1485 attr { key: "_class" 1486 value { list { 1487 s: "loc:@ColocateWithMe_2"}}}} 1488 """, gd) 1489 1490 def testControlDependencies(self): 1491 1492 class DependingThread(self.TestThread): 1493 1494 def __init__(self, graph, replica_id, dependency_op): 1495 super(DependingThread, self).__init__(graph, replica_id) 1496 self._dependency_op = dependency_op 1497 1498 def run(self): 1499 with g.control_dependencies([self._dependency_op]): 1500 self.has_mutated_graph.set() 1501 self.should_continue.wait() 1502 self.should_continue.clear() 1503 g.create_op( 1504 "FloatOutput", [], [dtypes.float32], 1505 name="FloatOutput_{}".format(self._replica_id)) 1506 1507 g = ops.Graph() 1508 dependency_ops = [] 1509 for i in range(3): 1510 dependency_ops.append( 1511 g.create_op( 1512 "FloatOutput", [], [dtypes.float32], 1513 name="ColocateWithMe_{}".format(i))) 1514 1515 # If `switch_to_thread` isn't called, then `input` values for the ops below 1516 # are not deterministic. 1517 g.switch_to_thread_local() 1518 threads = [DependingThread(g, i, dependency_ops[i]) for i in range(3)] 1519 for t in threads: 1520 t.start() 1521 t.has_mutated_graph.wait() 1522 t.has_mutated_graph.clear() 1523 for t in threads: 1524 t.should_continue.set() 1525 t.join() 1526 1527 gd = g.as_graph_def() 1528 self.assertProtoEqualsVersion(""" 1529 node { name: "ColocateWithMe_0" op: "FloatOutput" } 1530 node { name: "ColocateWithMe_1" op: "FloatOutput" } 1531 node { name: "ColocateWithMe_2" op: "FloatOutput" } 1532 node { name: "FloatOutput_0" op: "FloatOutput" 1533 input: "^ColocateWithMe_0" } 1534 node { name: "FloatOutput_1" op: "FloatOutput" 1535 input: "^ColocateWithMe_1" } 1536 node { name: "FloatOutput_2" op: "FloatOutput" 1537 input: "^ColocateWithMe_2" } 1538 """, gd) 1539 1540 def testNameStack(self): 1541 1542 class NameSettingThread(self.TestThread): 1543 1544 def run(self): 1545 with g.name_scope("foo"): 1546 op1 = g.create_op("FloatOutput", [], [dtypes.float32]) 1547 self.has_mutated_graph.set() 1548 self.should_continue.wait() 1549 self.should_continue.clear() 1550 op2 = g.create_op("FloatOutput", [], [dtypes.float32]) 1551 self.result = (op1, op2) 1552 1553 g = ops.Graph() 1554 threads = [NameSettingThread(g, i) for i in range(3)] 1555 for t in threads: 1556 t.start() 1557 t.has_mutated_graph.wait() 1558 t.has_mutated_graph.clear() 1559 1560 for t in threads: 1561 t.should_continue.set() 1562 t.join() 1563 1564 suffixes = ["", "_1", "_2"] 1565 for t, s in zip(threads, suffixes): 1566 self.assertEquals("foo" + s + "/FloatOutput", t.result[0].name) 1567 self.assertEquals("foo" + s + "/FloatOutput_1", t.result[1].name) 1568 1569 1570class ObjectWithName(object): 1571 1572 def __init__(self, name): 1573 self._name = name 1574 1575 @property 1576 def name(self): 1577 return self._name 1578 1579 1580class CollectionTest(test_util.TensorFlowTestCase): 1581 1582 def test_get_collections(self): 1583 g = ops.Graph() 1584 self.assertSequenceEqual(g.collections, []) 1585 g.add_to_collection("key", 12) 1586 g.add_to_collection("key", 15) 1587 self.assertSequenceEqual(g.collections, ["key"]) 1588 g.add_to_collection("other", "foo") 1589 self.assertSequenceEqual(sorted(g.collections), ["key", "other"]) 1590 self.assertSequenceEqual( 1591 sorted(g.get_all_collection_keys()), ["key", "other"]) 1592 1593 def test_add_to_collection(self): 1594 g = ops.Graph() 1595 g.add_to_collection("key", 12) 1596 g.add_to_collection("other", "foo") 1597 g.add_to_collection("key", 34) 1598 1599 # Note that only blank1 is returned. 1600 g.add_to_collection("blah", 27) 1601 blank1 = ObjectWithName("prefix/foo") 1602 g.add_to_collection("blah", blank1) 1603 blank2 = ObjectWithName("junk/foo") 1604 g.add_to_collection("blah", blank2) 1605 1606 self.assertEqual([12, 34], g.get_collection("key")) 1607 self.assertEqual([], g.get_collection("nothing")) 1608 self.assertEqual([27, blank1, blank2], g.get_collection("blah")) 1609 self.assertEqual([blank1], g.get_collection("blah", "prefix")) 1610 self.assertEqual([blank1], g.get_collection("blah", ".*x")) 1611 1612 # Make sure that get_collection() returns a first-level 1613 # copy of the collection, while get_collection_ref() returns 1614 # the original list. 1615 other_collection_snapshot = g.get_collection("other") 1616 other_collection_ref = g.get_collection_ref("other") 1617 self.assertEqual(["foo"], other_collection_snapshot) 1618 self.assertEqual(["foo"], other_collection_ref) 1619 g.add_to_collection("other", "bar") 1620 self.assertEqual(["foo"], other_collection_snapshot) 1621 self.assertEqual(["foo", "bar"], other_collection_ref) 1622 self.assertEqual(["foo", "bar"], g.get_collection("other")) 1623 self.assertTrue(other_collection_ref is g.get_collection_ref("other")) 1624 1625 # Verify that getting an empty collection ref returns a modifiable list. 1626 empty_coll_ref = g.get_collection_ref("empty") 1627 self.assertEqual([], empty_coll_ref) 1628 empty_coll = g.get_collection("empty") 1629 self.assertEqual([], empty_coll) 1630 self.assertFalse(empty_coll is empty_coll_ref) 1631 empty_coll_ref2 = g.get_collection_ref("empty") 1632 self.assertTrue(empty_coll_ref2 is empty_coll_ref) 1633 # Add to the collection. 1634 empty_coll_ref.append("something") 1635 self.assertEqual(["something"], empty_coll_ref) 1636 self.assertEqual(["something"], empty_coll_ref2) 1637 self.assertEqual([], empty_coll) 1638 self.assertEqual(["something"], g.get_collection("empty")) 1639 empty_coll_ref3 = g.get_collection_ref("empty") 1640 self.assertTrue(empty_coll_ref3 is empty_coll_ref) 1641 1642 def test_add_to_collections_uniquify(self): 1643 g = ops.Graph() 1644 g.add_to_collections([1, 2, 1], "key") 1645 # Make sure "key" is not added twice 1646 self.assertEqual(["key"], g.get_collection(1)) 1647 1648 def test_add_to_collections_from_list(self): 1649 g = ops.Graph() 1650 g.add_to_collections(["abc", "123"], "key") 1651 self.assertEqual(["key"], g.get_collection("abc")) 1652 self.assertEqual(["key"], g.get_collection("123")) 1653 1654 def test_add_to_collections_from_tuple(self): 1655 g = ops.Graph() 1656 g.add_to_collections(("abc", "123"), "key") 1657 self.assertEqual(["key"], g.get_collection("abc")) 1658 self.assertEqual(["key"], g.get_collection("123")) 1659 1660 def test_add_to_collections_from_generator(self): 1661 g = ops.Graph() 1662 1663 def generator(): 1664 yield "abc" 1665 yield "123" 1666 1667 g.add_to_collections(generator(), "key") 1668 self.assertEqual(["key"], g.get_collection("abc")) 1669 self.assertEqual(["key"], g.get_collection("123")) 1670 1671 def test_add_to_collections_from_set(self): 1672 g = ops.Graph() 1673 g.add_to_collections(set(["abc", "123"]), "key") 1674 self.assertEqual(["key"], g.get_collection("abc")) 1675 self.assertEqual(["key"], g.get_collection("123")) 1676 1677 def test_add_to_collections_from_string(self): 1678 g = ops.Graph() 1679 g.add_to_collections("abc", "key") 1680 self.assertEqual(["key"], g.get_collection("abc")) 1681 1682 def test_default_graph(self): 1683 with ops.Graph().as_default(): 1684 ops.add_to_collection("key", 90) 1685 ops.add_to_collection("key", 100) 1686 # Collections are ordered. 1687 self.assertEqual([90, 100], ops.get_collection("key")) 1688 1689 def test_defun(self): 1690 with context.eager_mode(): 1691 1692 @eager_function.defun 1693 def defun(): 1694 ops.add_to_collection("int", 1) 1695 ops.add_to_collection("tensor", constant_op.constant(2)) 1696 1697 @eager_function.defun 1698 def inner_defun(): 1699 self.assertEqual(ops.get_collection("int"), [1]) 1700 three = ops.get_collection("tensor")[0] + ops.get_collection("int")[0] 1701 ops.add_to_collection("int", 2) 1702 self.assertEqual(ops.get_collection("int"), [1, 2]) 1703 ops.add_to_collection("foo", "bar") 1704 self.assertEqual(ops.get_collection("foo"), ["bar"]) 1705 return three 1706 1707 self.assertEqual(ops.get_collection("int"), [1]) 1708 three = inner_defun() 1709 self.assertEqual(ops.get_collection("int"), [1]) 1710 self.assertEqual(ops.get_collection("foo"), []) 1711 return three 1712 1713 three = defun() 1714 self.assertEqual(three.numpy(), 3) 1715 1716 1717ops.NotDifferentiable("FloatOutput") 1718 1719 1720@ops.RegisterGradient("CopyOp") 1721def _CopyGrad(op, x_grad): # pylint: disable=invalid-name 1722 _ = op 1723 return x_grad 1724 1725 1726@ops.RegisterGradient("copy_override") 1727def _CopyOverrideGrad(op, x_grad): # pylint: disable=invalid-name 1728 _ = op 1729 return x_grad 1730 1731 1732class RegistrationTest(test_util.TensorFlowTestCase): 1733 1734 @test_util.run_deprecated_v1 1735 def testRegisterGradients(self): 1736 x = test_ops.float_output() 1737 y = test_ops.copy_op(x) 1738 fn = ops.get_gradient_function(y.op) 1739 self.assertEqual(_CopyGrad, fn) 1740 1741 def testOverrideGradients(self): 1742 g = ops.Graph() 1743 with g.as_default(): 1744 x = test_ops.float_output() 1745 with g.gradient_override_map({"CopyOp": "copy_override"}): 1746 y = test_ops.copy_op(x) 1747 fn = ops.get_gradient_function(y.op) 1748 self.assertEqual(_CopyOverrideGrad, fn) 1749 1750 def testNonExistentOverride(self): 1751 g = ops.Graph() 1752 with g.as_default(): 1753 x = test_ops.float_output() 1754 with g.gradient_override_map({"CopyOp": "unknown_override"}): 1755 y = test_ops.copy_op(x) 1756 with self.assertRaisesRegexp(LookupError, "unknown_override"): 1757 ops.get_gradient_function(y.op) 1758 1759 1760class ComparisonTest(test_util.TensorFlowTestCase): 1761 1762 def testMembershipAllowed(self): 1763 g = ops.Graph() 1764 t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1") 1765 t2 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop2") 1766 self.assertTrue(isinstance(t1, ops.Tensor)) 1767 self.assertTrue(isinstance(t2, ops.Tensor)) 1768 self.assertTrue(t1 in [t1]) 1769 self.assertTrue(t1 not in [t2]) 1770 1771 1772class ControlDependenciesTest(test_util.TensorFlowTestCase): 1773 1774 @test_util.run_deprecated_v1 1775 def testBasic(self): 1776 g = ops.Graph() 1777 with g.as_default(): 1778 # Creating unregistered ops with _apply_op() doesn't work with the C API 1779 # TODO(skyewm): address this more consistently. Possible solutions are 1780 # to use registered ops in all tests, create a way to register ops in 1781 # Python tests, or conditionally disable the op registration check in 1782 # the C API. 1783 a = constant_op.constant(1.0) 1784 b = constant_op.constant(1.0) 1785 with g.control_dependencies([a]): 1786 c = constant_op.constant(1.0) 1787 d = array_ops.identity(b) 1788 e = array_ops.identity(c) 1789 1790 self.assertEqual(c.op.control_inputs, [a.op]) 1791 self.assertEqual(d.op.control_inputs, [a.op]) 1792 # e should be dominated by c. 1793 self.assertEqual(e.op.control_inputs, []) 1794 1795 @test_util.run_in_graph_and_eager_modes 1796 def testEager(self): 1797 def future(): 1798 future.calls += 1 1799 return constant_op.constant(2.0) 1800 future.calls = 0 1801 1802 if context.executing_eagerly(): 1803 a = constant_op.constant(1.0) 1804 b = future 1805 with ops.control_dependencies([a, b]): 1806 c = constant_op.constant(3.0) 1807 self.assertEqual(future.calls, 1) 1808 else: 1809 g = ops.Graph() 1810 with g.as_default(): 1811 a = constant_op.constant(1.0) 1812 b = future() 1813 with g.control_dependencies([a, b]): 1814 c = constant_op.constant(3.0) 1815 self.assertEqual(c.op.control_inputs, [a.op, b.op]) 1816 self.assertEqual(future.calls, 1) 1817 1818 def testBasicWithConversion(self): 1819 g = ops.Graph() 1820 a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1821 1822 class ConvertibleObj(object): 1823 1824 def _as_graph_element(self): 1825 return a 1826 1827 with g.control_dependencies([ConvertibleObj()]): 1828 c = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1829 1830 self.assertEqual(c.op.control_inputs, [a.op]) 1831 1832 def testNested(self): 1833 g = ops.Graph() 1834 a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1835 a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1836 a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1837 a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1838 1839 with g.control_dependencies([a_1, a_2, a_3, a_4]): 1840 b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1841 1842 with g.control_dependencies([a_1]): 1843 with g.control_dependencies([a_2]): 1844 with g.control_dependencies([a_3]): 1845 with g.control_dependencies([a_4]): 1846 b_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1847 1848 self.assertItemsEqual([a_1.op, a_2.op, a_3.op, a_4.op], 1849 b_1.op.control_inputs) 1850 self.assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs) 1851 1852 def testClear(self): 1853 g = ops.Graph() 1854 a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1855 a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1856 a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1857 a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1858 1859 with g.control_dependencies([a_1]): 1860 with g.control_dependencies([a_2]): 1861 with g.control_dependencies(None): 1862 with g.control_dependencies([a_3]): 1863 with g.control_dependencies([a_4]): 1864 # deps [a_3, a_4] 1865 b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1866 # deps = [a_3] 1867 b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1868 # deps back to None 1869 b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1870 # deps back to [a_1, a_2] 1871 b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1872 # deps back to [a_1] 1873 b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1874 with g.control_dependencies(None): 1875 # deps are None again 1876 b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1877 1878 self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs) 1879 self.assertItemsEqual([a_3.op], b_3.op.control_inputs) 1880 self.assertItemsEqual([], b_none.op.control_inputs) 1881 self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs) 1882 self.assertItemsEqual([a_1.op], b_1.op.control_inputs) 1883 self.assertItemsEqual([], b_none2.op.control_inputs) 1884 1885 def testComplex(self): 1886 g = ops.Graph() 1887 1888 # Usage pattern: 1889 # * Nodes a_i are constants defined at the outermost scope, and are used 1890 # as control inputs for the ith nested scope. 1891 # * Nodes b_i are defined as Mul(a_3, a_4) at each scope. 1892 # * Nodes c_i are defined as Mul(a_1, b_1) at each scope. 1893 # * Nodes d_i are defined as Mul(b_i, c_i) at each scope. 1894 # * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1. 1895 1896 a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1897 a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1898 a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1899 a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1900 1901 with g.control_dependencies([a_1]): 1902 b_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], 1903 [dtypes.float32]) 1904 c_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], 1905 [dtypes.float32]) 1906 d_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_1, c_1], 1907 [dtypes.float32]) 1908 e_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1909 with g.control_dependencies([a_2]): 1910 b_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], 1911 [dtypes.float32]) 1912 c_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], 1913 [dtypes.float32]) 1914 d_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_2, c_2], 1915 [dtypes.float32]) 1916 e_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_1, e_1], 1917 [dtypes.float32]) 1918 with g.control_dependencies([a_3]): 1919 b_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], 1920 [dtypes.float32]) 1921 c_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], 1922 [dtypes.float32]) 1923 d_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_3, c_3], 1924 [dtypes.float32]) 1925 e_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_2, e_2], 1926 [dtypes.float32]) 1927 with g.control_dependencies([a_4]): 1928 b_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], 1929 [dtypes.float32]) 1930 c_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], 1931 [dtypes.float32]) 1932 d_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_4, c_4], 1933 [dtypes.float32]) 1934 e_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_3, e_3], 1935 [dtypes.float32]) 1936 1937 self.assertItemsEqual([a_1.op], b_1.op.control_inputs) 1938 self.assertItemsEqual([a_1.op, a_2.op], b_2.op.control_inputs) 1939 self.assertItemsEqual([a_1.op, a_2.op], b_3.op.control_inputs) 1940 self.assertItemsEqual([a_1.op, a_2.op], b_4.op.control_inputs) 1941 1942 self.assertItemsEqual([], c_1.op.control_inputs) 1943 self.assertItemsEqual([a_2.op], c_2.op.control_inputs) 1944 self.assertItemsEqual([a_2.op, a_3.op], c_3.op.control_inputs) 1945 self.assertItemsEqual([a_2.op, a_3.op, a_4.op], c_4.op.control_inputs) 1946 1947 self.assertItemsEqual([], d_1.op.control_inputs) 1948 self.assertItemsEqual([], d_2.op.control_inputs) 1949 self.assertItemsEqual([], d_3.op.control_inputs) 1950 self.assertItemsEqual([], d_4.op.control_inputs) 1951 1952 self.assertItemsEqual([a_1.op], e_1.op.control_inputs) 1953 self.assertItemsEqual([a_2.op], e_2.op.control_inputs) 1954 self.assertItemsEqual([a_3.op], e_3.op.control_inputs) 1955 self.assertItemsEqual([a_4.op], e_4.op.control_inputs) 1956 1957 def testRepeatedDependency(self): 1958 g = ops.Graph() 1959 a = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32]) 1960 a_0, a_1 = a.outputs 1961 with g.control_dependencies([a_0]): 1962 b = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1963 with g.control_dependencies([a_1]): 1964 c = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1965 1966 self.assertEqual(b.op.control_inputs, [a]) 1967 self.assertEqual(c.op.control_inputs, [a]) 1968 1969 def testNoControlDependencyWithDataDependency(self): 1970 g = ops.Graph() 1971 a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 1972 with g.control_dependencies([a]): 1973 b = _apply_op(g, "Identity", [a], [dtypes.float32]) 1974 1975 self.assertEqual(b.op.control_inputs, []) 1976 1977 1978class OpScopeTest(test_util.TensorFlowTestCase): 1979 1980 @test_util.run_in_graph_and_eager_modes 1981 def testNames(self): 1982 with ops.name_scope("foo") as foo: 1983 self.assertEqual("foo/", foo) 1984 with ops.name_scope("foo2") as foo2: 1985 self.assertEqual("foo/foo2/", foo2) 1986 with ops.name_scope(None) as empty1: 1987 self.assertEqual("", empty1) 1988 with ops.name_scope("foo3") as foo3: 1989 self.assertEqual("foo3/", foo3) 1990 with ops.name_scope("") as empty2: 1991 self.assertEqual("", empty2) 1992 with ops.name_scope("foo/") as outer_foo: 1993 self.assertEqual("foo/", outer_foo) 1994 with ops.name_scope("") as empty3: 1995 self.assertEqual("", empty3) 1996 with ops.name_scope("foo4") as foo4: 1997 self.assertEqual("foo/foo4/", foo4) 1998 with ops.name_scope("foo5//") as foo5: 1999 self.assertEqual("foo5//", foo5) 2000 with ops.name_scope("foo6") as foo6: 2001 self.assertEqual("foo5//foo6/", foo6) 2002 with ops.name_scope("/") as foo7: 2003 self.assertEqual("/", foo7) 2004 with ops.name_scope("//") as foo8: 2005 self.assertEqual("//", foo8) 2006 with ops.name_scope("a//b/c") as foo9: 2007 self.assertEqual("foo/a//b/c/", foo9) 2008 with ops.name_scope("a//b/c") as foo10: 2009 self.assertEqual("a//b/c/", foo10) 2010 2011 @test_util.run_in_graph_and_eager_modes 2012 def testEagerDefaultScopeName(self): 2013 with ops.name_scope(None, "default") as scope: 2014 self.assertEqual(scope, "default/") 2015 with ops.name_scope(None, "default2") as scope2: 2016 self.assertEqual(scope2, "default/default2/") 2017 2018 @test_util.run_deprecated_v1 2019 def testNoScopeName(self): 2020 g0 = ops.Graph() 2021 values = [ 2022 g0.create_op("A", [], [dtypes.float32]), 2023 g0.create_op("B", [], [dtypes.float32]) 2024 ] 2025 with self.assertRaises(ValueError): 2026 with ops.name_scope(None, values=values): 2027 pass 2028 with self.assertRaises(ValueError): 2029 with ops.name_scope(None, None, values): 2030 pass 2031 2032 @test_util.run_deprecated_v1 2033 def testEmptyScopeName(self): 2034 g0 = ops.Graph() 2035 a = g0.create_op("A", [], [dtypes.float32]) 2036 b = g0.create_op("B", [], [dtypes.float32]) 2037 with ops.name_scope("", values=[a, b]) as scope: 2038 self.assertEqual("", scope) 2039 self.assertEqual(g0, ops.get_default_graph()) 2040 with ops.name_scope("", "my_default_scope", [a, b]) as scope: 2041 self.assertEqual("", scope) 2042 self.assertEqual(g0, ops.get_default_graph()) 2043 2044 @test_util.run_deprecated_v1 2045 def testDefaultScopeName(self): 2046 g0 = ops.Graph() 2047 a = g0.create_op("A", [], [dtypes.float32]) 2048 b = g0.create_op("B", [], [dtypes.float32]) 2049 scope_name = "my_scope" 2050 default_scope_name = "my_default_scope" 2051 with ops.name_scope(scope_name, default_scope_name, [a, b]) as scope: 2052 self.assertEqual("%s/" % scope_name, scope) 2053 self.assertEqual(g0, ops.get_default_graph()) 2054 with ops.name_scope(None, default_scope_name, [a, b]) as scope: 2055 self.assertEqual("%s/" % default_scope_name, scope) 2056 self.assertEqual(g0, ops.get_default_graph()) 2057 with self.assertRaises(TypeError): 2058 with ops.name_scope(scope_name, [a, b]): 2059 pass 2060 2061 def _testGraphElements(self, graph_elements): 2062 scope_name = "my_scope" 2063 with ops.name_scope(scope_name, values=graph_elements) as scope: 2064 self.assertEqual("%s/" % scope_name, scope) 2065 self.assertEqual(graph_elements[0].graph, ops.get_default_graph()) 2066 g1 = ops.Graph() 2067 a = g1.create_op("A", [], [dtypes.float32]) 2068 with self.assertRaises(ValueError): 2069 with ops.name_scope(scope_name, values=graph_elements + [a]): 2070 pass 2071 2072 @test_util.run_deprecated_v1 2073 def testTensor(self): 2074 g0 = ops.Graph() 2075 a = g0.create_op("A", [], [dtypes.float32]) 2076 b = g0.create_op("B", [], [dtypes.float32]) 2077 self._testGraphElements([a, b]) 2078 2079 @test_util.run_deprecated_v1 2080 def testSparseTensor(self): 2081 g0 = ops.Graph() 2082 a = g0.create_op("A", [], [dtypes.float32]) 2083 b = g0.create_op("B", [], [dtypes.float32]) 2084 sparse = sparse_tensor.SparseTensor( 2085 _apply_op(g0, "Int64Output", [], [dtypes.int64]), 2086 _apply_op(g0, "FloatOutput", [], [dtypes.float32]), 2087 _apply_op(g0, "Int64Output", [], [dtypes.int64])) 2088 self._testGraphElements([a, sparse, b]) 2089 2090 @test_util.run_deprecated_v1 2091 def testVariable(self): 2092 g0 = ops.Graph() 2093 with g0.as_default(): 2094 variable = variables.Variable([1.0]) 2095 a = g0.create_op("A", [], [dtypes.float32]) 2096 b = g0.create_op("B", [], [dtypes.float32]) 2097 self._testGraphElements([a, variable, b]) 2098 2099 2100class InitScopeTest(test_util.TensorFlowTestCase): 2101 2102 def testClearsControlDependencies(self): 2103 g = ops.Graph() 2104 a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2105 a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2106 a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2107 a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2108 2109 with g.as_default(): 2110 with g.control_dependencies([a_1]): 2111 with g.control_dependencies([a_2]): 2112 with ops.init_scope(): 2113 with g.control_dependencies([a_3]): 2114 with g.control_dependencies([a_4]): 2115 # deps [a_3, a_4] 2116 b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2117 # deps = [a_3] 2118 b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2119 # deps back to None 2120 b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2121 # deps back to [a_1, a_2] 2122 b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2123 # deps back to [a_1] 2124 b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2125 with ops.init_scope(): 2126 # deps are None again 2127 b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2128 2129 self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs) 2130 self.assertItemsEqual([a_3.op], b_3.op.control_inputs) 2131 self.assertItemsEqual([], b_none.op.control_inputs) 2132 self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs) 2133 self.assertItemsEqual([a_1.op], b_1.op.control_inputs) 2134 self.assertItemsEqual([], b_none2.op.control_inputs) 2135 2136 def testLiftsOpsFromFunctions(self): 2137 g0 = ops.Graph() 2138 g1 = ops.Graph() 2139 g1._building_function = True # pylint: disable=protected-access 2140 g2 = ops.Graph() 2141 g2._building_function = True # pylint: disable=protected-access 2142 2143 with g0.as_default(): 2144 with g1.as_default(): 2145 with g2.as_default(): 2146 with ops.init_scope(): 2147 _ = constant_op.constant(1.0) 2148 2149 self.assertEqual(len(g2.get_operations()), 0) 2150 self.assertEqual(len(g1.get_operations()), 0) 2151 self.assertEqual(len(g0.get_operations()), 1) 2152 2153 def testPreservesDevices(self): 2154 g0 = ops.Graph() 2155 with g0.as_default(), ops.device("CPU:0"): 2156 g1 = ops.Graph() 2157 g1._building_function = True # pylint: disable=protected-access 2158 with g1.as_default(): 2159 with ops.device("GPU:0"): 2160 with ops.init_scope(): 2161 # init_scope should preserve device set under `g1`. 2162 on_gpu = constant_op.constant(1.0) 2163 self.assertEqual(on_gpu.device, "/device:GPU:0") 2164 still_on_gpu = constant_op.constant(1.0) 2165 self.assertEqual(still_on_gpu.device, "/device:GPU:0") 2166 blank = constant_op.constant(1.0) 2167 self.assertEqual(blank.device, "") 2168 with ops.init_scope(): 2169 now_on_cpu = constant_op.constant(1.0) 2170 self.assertEqual(now_on_cpu.device, "/device:CPU:0") 2171 on_cpu = constant_op.constant(1.0) 2172 self.assertEqual(on_cpu.device, "/device:CPU:0") 2173 2174 def testComposes(self): 2175 g0 = ops.Graph() 2176 g1 = ops.Graph() 2177 g1._building_function = True # pylint: disable=protected-access 2178 g2 = ops.Graph() 2179 g2._building_function = True # pylint: disable=protected-access 2180 g3 = ops.Graph() 2181 g3._building_function = False # pylint: disable=protected-access 2182 2183 with g0.as_default(): 2184 with g1.as_default(): 2185 with ops.init_scope(): 2186 # This op should be lifted into g0. 2187 _ = constant_op.constant(1.0) 2188 self.assertIs(g0, ops.get_default_graph()) 2189 self.assertEqual(len(g2.get_operations()), 0) 2190 self.assertEqual(len(g1.get_operations()), 0) 2191 self.assertEqual(len(g0.get_operations()), 1) 2192 with g2.as_default(): 2193 with ops.init_scope(): 2194 # This op should be lifted into g0. 2195 _ = constant_op.constant(1.0) 2196 self.assertIs(g0, ops.get_default_graph()) 2197 with g3.as_default(): 2198 with ops.init_scope(): 2199 # This op should be lifted into g3, because g3 is not building a 2200 # function. 2201 _ = constant_op.constant(1.0) 2202 self.assertIs(g3, ops.get_default_graph()) 2203 2204 self.assertEqual(len(g3.get_operations()), 1) 2205 self.assertEqual(len(g2.get_operations()), 0) 2206 self.assertEqual(len(g1.get_operations()), 0) 2207 self.assertEqual(len(g0.get_operations()), 2) 2208 2209 def testEscapesToEagerContext(self): 2210 g = ops.Graph() 2211 g._building_function = True # pylint: disable=protected-access 2212 with context.eager_mode(): 2213 with context.graph_mode(): 2214 with g.as_default(): 2215 with ops.init_scope(): 2216 # Because g is building a function, init_scope should 2217 # escape out to the eager context. 2218 self.assertTrue(context.executing_eagerly()) 2219 # g should be reinstated as the default graph, and the 2220 # graph context should be re-entered. 2221 self.assertIs(g, ops.get_default_graph()) 2222 self.assertFalse(context.executing_eagerly()) 2223 2224 def testStaysInEagerWhenOnlyEagerContextActive(self): 2225 with context.eager_mode(): 2226 with ops.init_scope(): 2227 self.assertTrue(context.eager_mode()) 2228 self.assertTrue(context.eager_mode()) 2229 2230 def testEscapesDefunWhenInEagerMode(self): 2231 2232 def function_with_variables(): 2233 with ops.init_scope(): 2234 self.v = resource_variable_ops.ResourceVariable(3) 2235 return self.v.assign_add(1) 2236 2237 with context.eager_mode(): 2238 # Each invocation of function_with_variables recreates a variable. 2239 self.assertEqual(4, int(function_with_variables())) 2240 self.assertEqual(4, int(function_with_variables())) 2241 2242 compiled = eager_function.defun(function_with_variables) 2243 # The init_scope in function_with_variables lifts the variable out 2244 # of the graph function constructed by defun; hence, 2245 # compiled now appears to be stateful. 2246 self.assertEqual(4, int(compiled())) 2247 self.assertEqual(5, int(compiled())) 2248 2249 def testEscapesDefunWhenInGraphMode(self): 2250 def function_with_variables(name): 2251 with ops.init_scope(): 2252 _ = variable_scope.get_variable(name, shape=(1,)) 2253 2254 g = ops.Graph() 2255 with g.as_default(): 2256 with self.cached_session(): 2257 # First ensure that graphs that are not building functions are 2258 # not escaped. 2259 function_with_variables("foo") 2260 with self.assertRaisesRegexp(ValueError, 2261 r"Variable foo already exists.*"): 2262 # This will fail because reuse is not set to True. 2263 function_with_variables("foo") 2264 2265 compiled = eager_function.defun(function_with_variables) 2266 compiled("bar") 2267 self.assertEqual( 2268 len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2) 2269 2270 # The second call to `compiled` should not create variables: the 2271 # init_scope has lifted the variable creation code out of the defun. 2272 compiled("bar") 2273 self.assertEqual( 2274 len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2) 2275 2276 def testEscapesNestedDefun(self): 2277 2278 def inner_function(): 2279 with ops.init_scope(): 2280 self.v = resource_variable_ops.ResourceVariable(1) 2281 return self.v.assign_add(2) 2282 2283 def outer_function(inner=None): 2284 with ops.init_scope(): 2285 self.v0 = resource_variable_ops.ResourceVariable(0) 2286 return self.v0.assign_add(1) + inner() 2287 2288 with context.eager_mode(): 2289 # Each invocation of outer_function recreates variables. 2290 self.assertEqual(4, int(outer_function(inner=inner_function))) 2291 self.assertEqual(4, int(outer_function(inner=inner_function))) 2292 2293 compiled_inner = eager_function.defun(inner_function) 2294 compiled_outer = eager_function.defun(outer_function) 2295 # The init_scope lifts variables out of the graph functions 2296 # constructed by defun; hence, compiled_outer should now appear to be 2297 # stateful. 2298 self.assertEqual(4, int(compiled_outer(inner=compiled_inner))) 2299 self.assertEqual(7, int(compiled_outer(inner=compiled_inner))) 2300 2301 @test_util.run_v1_only("b/120545219") 2302 def testFallsBackToGlobalGraphWhenAllGraphsAreBuildingFunctions(self): 2303 with context.graph_mode(): 2304 ops.reset_default_graph() 2305 # This doesn't push anything onto the graph stack, but it does 2306 # set the stack's global graph. 2307 global_graph = ops.get_default_graph() 2308 fn_graph = ops.Graph() 2309 2310 # pylint: disable=protected-access 2311 fn_graph._building_function = True 2312 self.assertEqual(len(ops._default_graph_stack.stack), 0) 2313 with fn_graph.as_default(): 2314 self.assertEqual(len(ops._default_graph_stack.stack), 1) 2315 with ops.init_scope(): 2316 self.assertGreater(len(ops._default_graph_stack.stack), 1) 2317 dummy = constant_op.constant(1.0) 2318 self.assertEqual(len(ops._default_graph_stack.stack), 1) 2319 # Note that the global graph is _not_ on the graph stack. 2320 self.assertEqual(len(ops._default_graph_stack.stack), 0) 2321 # Ensure that `dummy` was added to the global graph. 2322 self.assertEqual(global_graph, dummy.graph) 2323 # pylint: enable=protected-access 2324 2325 def testInstallsDefaultGraphWhenGraphStackIsEmptyInGraphMode(self): 2326 with context.graph_mode(): 2327 # pylint: disable=protected-access 2328 self.assertEqual(len(ops._default_graph_stack.stack), 0) 2329 with ops.init_scope(): 2330 self.assertGreater(len(ops._default_graph_stack.stack), 0) 2331 self.assertEqual(len(ops._default_graph_stack.stack), 0) 2332 # pylint: enable=protected-access 2333 2334 def testPreservesNameScopeInGraphConstruction(self): 2335 with ops.Graph().as_default(): 2336 function_graph = ops.Graph() 2337 with function_graph.as_default(): 2338 with ops.name_scope("inner"), ops.init_scope(): 2339 self.assertEqual(ops.get_name_scope(), "inner") 2340 self.assertEqual(ops.get_name_scope(), "") 2341 2342 def testEnteringGraphFromEagerIsSticky(self): 2343 with context.eager_mode(): 2344 g = ops.Graph() 2345 with g.as_default(): 2346 with ops.init_scope(): 2347 self.assertFalse(context.executing_eagerly()) 2348 self.assertEqual(g, ops.get_default_graph()) 2349 2350 def testMixGraphEager(self): 2351 with context.eager_mode(): 2352 c = constant_op.constant(1.0) 2353 with ops.Graph().as_default(): 2354 with self.assertRaisesRegexp( 2355 RuntimeError, "Attempting to capture an EagerTensor"): 2356 math_ops.add(c, c) 2357 c2 = constant_op.constant(2.0) 2358 with self.assertRaisesRegexp( 2359 TypeError, "Graph tensors"): 2360 math_ops.add(c2, c2) 2361 2362 def testPreservesNameScopeInEagerExecution(self): 2363 with context.eager_mode(): 2364 def foo(): 2365 with ops.name_scope("inner"), ops.init_scope(): 2366 if context.executing_eagerly(): 2367 # A trailing slash is always appended when eager execution is 2368 # enabled. 2369 self.assertEqual(context.context().scope_name, "inner/") 2370 else: 2371 self.assertEqual(ops.get_name_scope(), "inner") 2372 2373 foo() 2374 self.assertEqual(ops.get_name_scope(), "") 2375 foo_compiled = eager_function.defun(foo) 2376 foo_compiled() 2377 self.assertEqual(ops.get_name_scope(), "") 2378 2379 def testExecutingEagerlyOutsideFunctions(self): 2380 2381 @eager_function.defun 2382 def f(): 2383 return ops.executing_eagerly_outside_functions() 2384 2385 with context.eager_mode(): 2386 self.assertTrue(ops.executing_eagerly_outside_functions()) 2387 self.assertTrue(f()) 2388 g = ops.Graph() 2389 with g.as_default(): 2390 self.assertFalse(ops.executing_eagerly_outside_functions()) 2391 2392 2393class GraphTest(test_util.TensorFlowTestCase): 2394 2395 def setUp(self): 2396 ops.reset_default_graph() 2397 2398 def _AssertDefault(self, expected): 2399 self.assertIs(expected, ops.get_default_graph()) 2400 2401 def testResetDefaultGraphNesting(self): 2402 g0 = ops.Graph() 2403 with self.assertRaises(AssertionError): 2404 with g0.as_default(): 2405 ops.reset_default_graph() 2406 2407 def testGraphContextManagerCancelsEager(self): 2408 with context.eager_mode(): 2409 with ops.Graph().as_default(): 2410 self.assertFalse(context.executing_eagerly()) 2411 2412 def testGraphContextManager(self): 2413 g0 = ops.Graph() 2414 with g0.as_default() as g1: 2415 self.assertIs(g0, g1) 2416 2417 def testDefaultGraph(self): 2418 orig = ops.get_default_graph() 2419 self.assertFalse(ops.has_default_graph()) 2420 self._AssertDefault(orig) 2421 g0 = ops.Graph() 2422 self.assertFalse(ops.has_default_graph()) 2423 self._AssertDefault(orig) 2424 context_manager_0 = g0.as_default() 2425 self.assertFalse(ops.has_default_graph()) 2426 self._AssertDefault(orig) 2427 with context_manager_0 as g0: 2428 self._AssertDefault(g0) 2429 with ops.Graph().as_default() as g1: 2430 self.assertTrue(ops.has_default_graph()) 2431 self._AssertDefault(g1) 2432 self._AssertDefault(g0) 2433 self._AssertDefault(orig) 2434 self.assertFalse(ops.has_default_graph()) 2435 2436 def testPreventFeeding(self): 2437 g = ops.Graph() 2438 a = constant_op.constant(2.0) 2439 self.assertTrue(g.is_feedable(a)) 2440 g.prevent_feeding(a) 2441 self.assertFalse(g.is_feedable(a)) 2442 2443 @test_util.run_deprecated_v1 2444 def testPreventFetching(self): 2445 g = ops.Graph() 2446 a = constant_op.constant(2.0) 2447 self.assertTrue(g.is_fetchable(a)) 2448 g.prevent_fetching(a.op) 2449 self.assertFalse(g.is_fetchable(a)) 2450 2451 def testAsGraphElementConversions(self): 2452 2453 class ConvertibleObj(object): 2454 2455 def _as_graph_element(self): 2456 return "FloatOutput:0" 2457 2458 class NonConvertibleObj(object): 2459 2460 pass 2461 2462 g = ops.Graph() 2463 a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) 2464 self.assertEqual(a, g.as_graph_element(ConvertibleObj())) 2465 with self.assertRaises(TypeError): 2466 g.as_graph_element(NonConvertibleObj()) 2467 2468 # Regression test against creating custom __del__ functions in classes 2469 # involved in cyclic references, e.g. Graph and Operation. (Python won't gc 2470 # cycles that require calling a __del__ method, because the __del__ method can 2471 # theoretically increase the object's refcount to "save" it from gc, and any 2472 # already-deleted objects in the cycle would have be to restored.) 2473 def testGarbageCollected(self): 2474 # Create a graph we can delete and a weak reference to monitor if it's gc'd 2475 g = ops.Graph() 2476 g_ref = weakref.ref(g) 2477 # Create some ops 2478 with g.as_default(): 2479 a = constant_op.constant(2.0) 2480 b = constant_op.constant(3.0) 2481 c = math_ops.add(a, b) 2482 # Create a session we can delete 2483 with session.Session(graph=g) as sess: 2484 self.evaluate(c) 2485 # Delete all references and trigger gc 2486 del g 2487 del a 2488 del b 2489 del c 2490 del sess 2491 gc.collect() 2492 self.assertIsNone(g_ref()) 2493 2494 def testRunnableAfterInvalidShape(self): 2495 with ops.Graph().as_default(): 2496 with self.assertRaises(ValueError): 2497 math_ops.add([1, 2], [1, 2, 3]) 2498 a = constant_op.constant(1) 2499 with session.Session() as sess: 2500 self.evaluate(a) 2501 2502 def testRunnableAfterInvalidShapeWithKernelLabelMap(self): 2503 g = ops.Graph() 2504 with g.as_default(): 2505 with g._kernel_label_map({"KernelLabelRequired": "overload_1"}): 2506 with self.assertRaises(ValueError): 2507 test_ops.kernel_label_required(1) 2508 a = constant_op.constant(1) 2509 with session.Session() as sess: 2510 self.evaluate(a) 2511 2512 2513class AttrScopeTest(test_util.TensorFlowTestCase): 2514 2515 def _get_test_attrs(self): 2516 x = control_flow_ops.no_op() 2517 try: 2518 a = compat.as_text(x.get_attr("_A")) 2519 except ValueError: 2520 a = None 2521 try: 2522 b = compat.as_text(x.get_attr("_B")) 2523 except ValueError: 2524 b = None 2525 return (a, b) 2526 2527 @test_util.run_deprecated_v1 2528 def testNoLabel(self): 2529 with self.cached_session(): 2530 self.assertAllEqual((None, None), self._get_test_attrs()) 2531 2532 @test_util.run_deprecated_v1 2533 def testLabelMap(self): 2534 with self.cached_session() as sess: 2535 a1 = self._get_test_attrs() 2536 with sess.graph._attr_scope({ 2537 "_A": attr_value_pb2.AttrValue(s=compat.as_bytes("foo")) 2538 }): 2539 a2 = self._get_test_attrs() 2540 with sess.graph._attr_scope({ 2541 "_A": None, 2542 "_B": attr_value_pb2.AttrValue(s=compat.as_bytes("bar")) 2543 }): 2544 a3 = self._get_test_attrs() 2545 with sess.graph._attr_scope({ 2546 "_A": attr_value_pb2.AttrValue(s=compat.as_bytes("baz")) 2547 }): 2548 a4 = self._get_test_attrs() 2549 a5 = self._get_test_attrs() 2550 a6 = self._get_test_attrs() 2551 a7 = self._get_test_attrs() 2552 2553 self.assertAllEqual((None, None), a1) 2554 self.assertAllEqual(("foo", None), a2) 2555 self.assertAllEqual((None, "bar"), a3) 2556 self.assertAllEqual(("baz", "bar"), a4) 2557 self.assertAllEqual((None, "bar"), a5) 2558 self.assertAllEqual(("foo", None), a6) 2559 self.assertAllEqual((None, None), a7) 2560 2561 2562ops.RegisterShape("KernelLabel")(common_shapes.scalar_shape) 2563 2564 2565class KernelLabelTest(test_util.TensorFlowTestCase): 2566 2567 @test_util.run_deprecated_v1 2568 def testNoLabel(self): 2569 with self.cached_session(): 2570 self.assertAllEqual(b"My label is: default", 2571 test_ops.kernel_label().eval()) 2572 2573 @test_util.run_deprecated_v1 2574 def testLabelMap(self): 2575 with self.cached_session() as sess: 2576 default_1 = test_ops.kernel_label() 2577 # pylint: disable=protected-access 2578 with sess.graph._kernel_label_map({"KernelLabel": "overload_1"}): 2579 overload_1_1 = test_ops.kernel_label() 2580 with sess.graph._kernel_label_map({"KernelLabel": "overload_2"}): 2581 overload_2 = test_ops.kernel_label() 2582 with sess.graph._kernel_label_map({"KernelLabel": ""}): 2583 default_2 = test_ops.kernel_label() 2584 overload_1_2 = test_ops.kernel_label() 2585 # pylint: enable=protected-access 2586 default_3 = test_ops.kernel_label() 2587 2588 self.assertAllEqual(b"My label is: default", self.evaluate(default_1)) 2589 self.assertAllEqual(b"My label is: default", self.evaluate(default_2)) 2590 self.assertAllEqual(b"My label is: default", self.evaluate(default_3)) 2591 self.assertAllEqual(b"My label is: overload_1", 2592 self.evaluate(overload_1_1)) 2593 self.assertAllEqual(b"My label is: overload_1", 2594 self.evaluate(overload_1_2)) 2595 self.assertAllEqual(b"My label is: overload_2", self.evaluate(overload_2)) 2596 2597 2598class AsGraphDefTest(test_util.TensorFlowTestCase): 2599 2600 def testGraphDefVersion(self): 2601 """Test that the graphdef version is plumbed through to kernels.""" 2602 with ops.Graph().as_default() as g: 2603 version = g.graph_def_versions.producer 2604 with self.session(graph=g): 2605 v = test_ops.graph_def_version().eval() 2606 self.assertEqual(version, v) 2607 2608 def testAddShapes(self): 2609 with ops.Graph().as_default() as g: 2610 t1, t2, t3, t4, t5 = _apply_op(g, "FiveFloatOutputs", [], 2611 [dtypes.float32] * 5) 2612 t1.set_shape(None) 2613 t2.set_shape([]) 2614 t3.set_shape([None]) 2615 t4.set_shape([43, 37]) 2616 t5.set_shape([43, None]) 2617 2618 b = constant_op.constant(1.0) # pylint: disable=unused-variable 2619 2620 gd = g.as_graph_def(add_shapes=True) 2621 self.assertProtoEqualsVersion(""" 2622 node { name: "FiveFloatOutputs" op: "FiveFloatOutputs" 2623 attr { 2624 key: "_output_shapes" 2625 value { 2626 list { 2627 shape { unknown_rank: true } 2628 shape { } 2629 shape { dim { size: -1 } } 2630 shape { dim { size: 43 } dim { size: 37 } } 2631 shape { dim { size: 43 } dim { size: -1 } } 2632 } 2633 } 2634 } 2635 } 2636 node { name: "Const" op: "Const" 2637 attr { 2638 key: "_output_shapes" 2639 value { 2640 list { 2641 shape { } 2642 } 2643 } 2644 } 2645 attr { 2646 key: "dtype" 2647 value { type: DT_FLOAT } 2648 } 2649 attr { 2650 key: "value" 2651 value { 2652 tensor { 2653 dtype: DT_FLOAT 2654 tensor_shape { } 2655 float_val: 1.0 } } } } 2656 """, gd) 2657 2658 2659@ops.RegisterStatistics("a", "flops") 2660def _calc_a_forward_flops(unused_graph, unused_node): 2661 return ops.OpStats("flops", 20) 2662 2663 2664class StatisticsTest(test_util.TensorFlowTestCase): 2665 2666 def testRegisteredNode(self): 2667 graph = ops.Graph() 2668 node = ops._NodeDef("a", "an_a") 2669 flops = ops.get_stats_for_node_def(graph, node, "flops") 2670 self.assertEqual(20, flops.value) 2671 missing_stat = ops.get_stats_for_node_def(graph, node, "missing_stat") 2672 self.assertEqual(None, missing_stat.value) 2673 2674 def testUnregisteredNode(self): 2675 graph = ops.Graph() 2676 node = ops._NodeDef("b", "a_b") 2677 weight_params = ops.get_stats_for_node_def(graph, node, "weight_params") 2678 self.assertEqual(None, weight_params.value) 2679 2680 def testAccumulateStatistics(self): 2681 flops_total = ops.OpStats("flops") 2682 self.assertEqual(None, flops_total.value) 2683 second_flops = ops.OpStats("flops", 3) 2684 flops_total += second_flops 2685 self.assertEqual(3, flops_total.value) 2686 2687 2688class DeviceStackTest(test_util.TensorFlowTestCase): 2689 2690 @test_util.run_deprecated_v1 2691 def testBasicDeviceAssignmentMetadata(self): 2692 2693 def device_func(unused_op): 2694 return "/cpu:*" 2695 2696 const_zero = constant_op.constant([0.0], name="zero") 2697 with ops.device("/cpu"): 2698 const_one = constant_op.constant([1.0], name="one") 2699 with ops.device("/cpu:0"): 2700 const_two = constant_op.constant([2.0], name="two") 2701 with ops.device(device_func): 2702 const_three = constant_op.constant(3.0, name="three") 2703 2704 self.assertEqual(0, len(const_zero.op._device_assignments)) 2705 2706 one_list = const_one.op._device_assignments 2707 self.assertEqual(1, len(one_list)) 2708 self.assertEqual("/cpu", one_list[0].obj) 2709 self.assertEqual("ops_test.py", os.path.basename(one_list[0].filename)) 2710 2711 two_list = const_two.op._device_assignments 2712 self.assertEqual(2, len(two_list)) 2713 devices = [t.obj for t in two_list] 2714 self.assertEqual(set(["/cpu", "/cpu:0"]), set(devices)) 2715 2716 three_list = const_three.op._device_assignments 2717 self.assertEqual(1, len(three_list)) 2718 func_description = three_list[0].obj 2719 expected_regex = r"device_func<.*ops_test.py, [0-9]+" 2720 self.assertRegexpMatches(func_description, expected_regex) 2721 2722 @test_util.run_deprecated_v1 2723 def testDeviceAssignmentMetadataForGraphDeviceAndTfDeviceFunctions(self): 2724 2725 with ops.device("/cpu"): 2726 const_one = constant_op.constant([1.0], name="one") 2727 with ops.get_default_graph().device("/cpu"): 2728 const_two = constant_op.constant([2.0], name="two") 2729 2730 one_metadata = const_one.op._device_assignments[0] 2731 two_metadata = const_two.op._device_assignments[0] 2732 2733 # Verify both types of device assignment return the right stack info. 2734 self.assertRegexpMatches("ops_test.py", 2735 os.path.basename(one_metadata.filename)) 2736 self.assertEqual(one_metadata.filename, two_metadata.filename) 2737 self.assertEqual(one_metadata.lineno + 2, two_metadata.lineno) 2738 2739 2740class ColocationGroupTest(test_util.TensorFlowTestCase): 2741 2742 @test_util.run_deprecated_v1 2743 def testBasic(self): 2744 a = constant_op.constant([2.0], name="a") 2745 with ops.colocate_with(a.op): 2746 b = constant_op.constant(3.0) 2747 c = constant_op.constant(4.0) 2748 self.assertEqual([b"loc:@a"], a.op.colocation_groups()) 2749 self.assertEqual([b"loc:@a"], b.op.colocation_groups()) 2750 with self.assertRaises(ValueError): 2751 c.op.get_attr("_class") 2752 2753 @test_util.run_deprecated_v1 2754 def testBasicColocationMetadata(self): 2755 const_two = constant_op.constant([2.0], name="two") 2756 with ops.colocate_with(const_two.op): 2757 const_three = constant_op.constant(3.0, name="three") 2758 locations_dict = const_three.op._colocation_dict 2759 self.assertIn("two", locations_dict) 2760 metadata = locations_dict["two"] 2761 self.assertIsNone(metadata.obj) 2762 # Check that this test's filename is recorded as the file containing the 2763 # colocation statement. 2764 self.assertEqual("ops_test.py", os.path.basename(metadata.filename)) 2765 2766 @test_util.run_deprecated_v1 2767 def testColocationDeviceInteraction(self): 2768 with ops.device("/cpu:0"): 2769 with ops.device("/device:GPU:0"): 2770 a = constant_op.constant([2.0], name="a") 2771 with ops.colocate_with(a.op): 2772 # 'b' is created in the scope of /cpu:0, but it is 2773 # colocated with 'a', which is on '/device:GPU:0'. colocate_with 2774 # overrides devices because it is a stronger constraint. 2775 b = constant_op.constant(3.0) 2776 self.assertEqual([b"loc:@a"], b.op.colocation_groups()) 2777 self.assertEqual(a.op.device, b.op.device) 2778 2779 @test_util.run_deprecated_v1 2780 def testColocationCanonicalization(self): 2781 with ops.device("/device:GPU:0"): 2782 _ = constant_op.constant(2.0) 2783 with ops.device(lambda op: "/device:GPU:0"): 2784 b = constant_op.constant(3.0) 2785 with ops.get_default_graph().colocate_with(b): 2786 with ops.device("/device:GPU:0"): 2787 c = constant_op.constant(4.0) 2788 2789 # A's device will be /device:GPU:0 2790 # B's device will be /device:GPU:0 2791 # C's device will be /device:GPU:0 because it 2792 # inherits B's device name, after canonicalizing the names. 2793 self.assertEqual(b.op.device, c.op.device) 2794 2795 @test_util.run_deprecated_v1 2796 def testLocationOverrides(self): 2797 with ops.device("/cpu:0"): 2798 with ops.device("/device:GPU:0"): 2799 a = constant_op.constant([2.0], name="a") 2800 # Note that this colocation is "redundant", since we are 2801 # within the scope of "/device:GPU:0". However, we would like to 2802 # preserve in the GraphDef that these two ops should be 2803 # colocated in a portable way. 2804 with ops.colocate_with(a.op): 2805 b = constant_op.constant(3.0) 2806 c = constant_op.constant(4.0) 2807 d = constant_op.constant(5.0) 2808 2809 self.assertEqual([b"loc:@a"], b.op.colocation_groups()) 2810 self.assertEqual("/device:GPU:0", a.op.device) 2811 self.assertEqual(a.op.device, b.op.device) 2812 2813 # Test that device function stack is restored. 2814 self.assertEqual("/device:GPU:0", c.op.device) 2815 self.assertEqual("/device:CPU:0", d.op.device) 2816 2817 @test_util.run_deprecated_v1 2818 def testNestedColocateWith(self): 2819 a = constant_op.constant([2.0], name="a") 2820 with ops.colocate_with(a.op): 2821 b = constant_op.constant(3.0) 2822 with ops.colocate_with(b.op): 2823 c = constant_op.constant(4.0) 2824 self.assertEqual([b"loc:@a"], b.op.colocation_groups()) 2825 self.assertEqual([b"loc:@a"], c.op.colocation_groups()) 2826 2827 @test_util.run_deprecated_v1 2828 def testMultiColocationGroups(self): 2829 a = constant_op.constant([2.0], name="a") 2830 b = constant_op.constant(3.0, name="b") 2831 with ops.colocate_with(a.op): 2832 with ops.colocate_with(b.op): 2833 c = constant_op.constant(4.0) 2834 self.assertEqual(set([b"loc:@a", b"loc:@b"]), set(c.op.colocation_groups())) 2835 2836 @test_util.run_deprecated_v1 2837 def testColocationIgnoreStack(self): 2838 a = constant_op.constant([2.0], name="a") 2839 b = constant_op.constant(3.0, name="b") 2840 with ops.colocate_with(a.op): 2841 with ops.colocate_with(b.op, ignore_existing=True): 2842 c = constant_op.constant(4.0) 2843 self.assertEqual(set([b"loc:@b"]), set(c.op.colocation_groups())) 2844 2845 @test_util.run_deprecated_v1 2846 def testColocateWithReset(self): 2847 a = constant_op.constant([2.0], name="a") 2848 with ops.colocate_with(a.op): 2849 b = constant_op.constant(3.0, name="b") 2850 with ops.colocate_with(None, ignore_existing=True): 2851 c = constant_op.constant(4.0, name="c") 2852 self.assertEqual([b"loc:@a"], b.op.colocation_groups()) 2853 self.assertEqual([b"loc:@c"], c.op.colocation_groups()) 2854 2855 @test_util.run_deprecated_v1 2856 def testColocateWithInitialNoneThenNested(self): 2857 a = constant_op.constant([2.0], name="a") 2858 with ops.colocate_with(a.op): 2859 with ops.colocate_with(None, ignore_existing=True): 2860 b = constant_op.constant(3.0, name="b") 2861 with ops.colocate_with(b.op): 2862 c = constant_op.constant(4.0, name="c") 2863 self.assertEqual([b"loc:@b"], b.op.colocation_groups()) 2864 self.assertEqual([b"loc:@b"], c.op.colocation_groups()) 2865 2866 @test_util.run_deprecated_v1 2867 def testColocateVariables(self): 2868 a = variables.Variable([2.0], name="a") 2869 with ops.colocate_with(a.op): 2870 b = variables.Variable([3.0], name="b") 2871 self.assertEqual([b"loc:@a"], b.op.colocation_groups()) 2872 2873 2874class DeprecatedTest(test_util.TensorFlowTestCase): 2875 2876 def testSuccess(self): 2877 with ops.Graph().as_default() as g: 2878 test_util.set_producer_version(g, 7) 2879 old = test_ops.old() 2880 with self.session(graph=g): 2881 old.run() 2882 2883 def _error(self): 2884 return ((r"Op Old is not available in GraphDef version %d\. " 2885 r"It has been removed in version 8\. For reasons\.") % 2886 versions.GRAPH_DEF_VERSION) 2887 2888 def testGraphConstructionFail(self): 2889 with ops.Graph().as_default(): 2890 with self.assertRaisesRegexp(NotImplementedError, self._error()): 2891 test_ops.old() 2892 2893 2894class DenseTensorLikeTypeTest(test_util.TensorFlowTestCase): 2895 2896 def testSuccess(self): 2897 op = ops.Operation( 2898 ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32]) 2899 t = op.outputs[0] 2900 self.assertTrue(ops.is_dense_tensor_like(t)) 2901 2902 v = variables.Variable([17]) 2903 self.assertTrue(ops.is_dense_tensor_like(v)) 2904 2905 class BadClassNoName(object): 2906 pass 2907 2908 class BadClassBadName(object): 2909 2910 def name(self): 2911 pass 2912 2913 class BadClassNoDtype(object): 2914 2915 @property 2916 def name(self): 2917 pass 2918 2919 class BadClassBadDtype(object): 2920 2921 @property 2922 def name(self): 2923 pass 2924 2925 def dtype(self): 2926 pass 2927 2928 def testBadClass(self): 2929 with self.assertRaisesRegexp(TypeError, "`name`"): 2930 ops.register_dense_tensor_like_type( 2931 DenseTensorLikeTypeTest.BadClassNoName) 2932 with self.assertRaisesRegexp(TypeError, "`name`"): 2933 ops.register_dense_tensor_like_type( 2934 DenseTensorLikeTypeTest.BadClassBadName) 2935 with self.assertRaisesRegexp(TypeError, "`dtype`"): 2936 ops.register_dense_tensor_like_type( 2937 DenseTensorLikeTypeTest.BadClassNoDtype) 2938 with self.assertRaisesRegexp(TypeError, "`dtype`"): 2939 ops.register_dense_tensor_like_type( 2940 DenseTensorLikeTypeTest.BadClassBadDtype) 2941 2942 2943class NameScopeTest(test_util.TensorFlowTestCase): 2944 2945 def testStripAndPrependScope(self): 2946 strs = [ 2947 "hidden1/hidden1/weights", # Same prefix. Should strip. 2948 "hidden1///hidden1/weights", # Extra "/". Should strip. 2949 "^hidden1/hidden1/weights", # Same prefix. Should strip. 2950 "loc:@hidden1/hidden1/weights", # Same prefix. Should strip. 2951 "hhidden1/hidden1/weights", # Different prefix. Should keep. 2952 "hidden1" 2953 ] # Not a prefix. Should keep. 2954 expected_striped = [ 2955 "hidden1/weights", "hidden1/weights", "^hidden1/weights", 2956 "loc:@hidden1/weights", "hhidden1/hidden1/weights", "hidden1" 2957 ] 2958 expected_prepended = [ 2959 "hidden2/hidden1/weights", "hidden2/hidden1/weights", 2960 "^hidden2/hidden1/weights", "loc:@hidden2/hidden1/weights", 2961 "hidden2/hhidden1/hidden1/weights", "hidden2/hidden1" 2962 ] 2963 name_scope_to_strip = "hidden1" 2964 name_scope_to_add = "hidden2" 2965 for es, ep, s in zip(expected_striped, expected_prepended, strs): 2966 striped = ops.strip_name_scope(s, name_scope_to_strip) 2967 self.assertEqual(es, striped) 2968 self.assertEqual(ep, ops.prepend_name_scope(striped, name_scope_to_add)) 2969 2970 def testGetNameScope(self): 2971 with ops.Graph().as_default() as g: 2972 with ops.name_scope("scope1"): 2973 with ops.name_scope("scope2"): 2974 with ops.name_scope("scope3"): 2975 self.assertEqual("scope1/scope2/scope3", g.get_name_scope()) 2976 self.assertEqual("scope1/scope2", g.get_name_scope()) 2977 self.assertEqual("scope1", g.get_name_scope()) 2978 self.assertEqual("", g.get_name_scope()) 2979 2980 def testTwoGraphs(self): 2981 2982 def f(): 2983 g1 = ops.Graph() 2984 g2 = ops.Graph() 2985 with g1.as_default(): 2986 with g2.as_default(): 2987 with ops.name_scope("_"): 2988 pass 2989 2990 self.assertRaisesRegexp(ValueError, "'_' is not a valid scope name", f) 2991 2992 2993class TracebackTest(test_util.TensorFlowTestCase): 2994 2995 @test_util.run_deprecated_v1 2996 def testTracebackWithStartLines(self): 2997 with self.cached_session() as sess: 2998 a = constant_op.constant(2.0) 2999 sess.run( 3000 a, 3001 options=config_pb2.RunOptions( 3002 trace_level=config_pb2.RunOptions.FULL_TRACE)) 3003 self.assertTrue(sess.graph.get_operations()) 3004 3005 # Tests that traceback_with_start_lines is the same as traceback 3006 # but includes one more element at the end. 3007 for op in sess.graph.get_operations(): 3008 self.assertEquals(len(op.traceback), len(op.traceback_with_start_lines)) 3009 for frame, frame_with_start_line in zip( 3010 op.traceback, op.traceback_with_start_lines): 3011 self.assertEquals(5, len(frame_with_start_line)) 3012 self.assertEquals(frame, frame_with_start_line[:-1]) 3013 3014 3015class EnableEagerExecutionTest(test_util.TensorFlowTestCase): 3016 3017 @test_util.run_v1_only("b/120545219") 3018 def testBadArgumentsToEnableEagerExecution(self): 3019 with self.assertRaisesRegexp(TypeError, "config must be a tf.ConfigProto"): 3020 ops.enable_eager_execution(context.DEVICE_PLACEMENT_SILENT) 3021 with self.assertRaisesRegexp(ValueError, "device_policy must be one of"): 3022 c = config_pb2.ConfigProto() 3023 ops.enable_eager_execution(c, c) 3024 with self.assertRaisesRegexp(ValueError, "execution_mode must be one of"): 3025 c = config_pb2.ConfigProto() 3026 ops.enable_eager_execution(c, execution_mode=c) 3027 3028 3029if __name__ == "__main__": 3030 googletest.main() 3031