• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.python.framework.ops."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from absl.testing import parameterized
22import gc
23import numpy as np
24import os
25import threading
26import weakref
27
28from tensorflow.core.framework import attr_value_pb2
29from tensorflow.core.protobuf import config_pb2
30from tensorflow.python.autograph.core import ag_ctx
31from tensorflow.python.client import session
32from tensorflow.python.eager import backprop
33from tensorflow.python.eager import context
34from tensorflow.python.eager import def_function
35from tensorflow.python.eager import function as eager_function
36from tensorflow.python.eager import wrap_function
37from tensorflow.python.framework import config
38from tensorflow.python.framework import composite_tensor
39from tensorflow.python.framework import constant_op
40from tensorflow.python.framework import device as pydev
41from tensorflow.python.framework import dtypes
42from tensorflow.python.framework import errors
43from tensorflow.python.framework import function
44from tensorflow.python.framework import indexed_slices
45from tensorflow.python.framework import ops
46from tensorflow.python.framework import sparse_tensor
47from tensorflow.python.framework import tensor_shape
48from tensorflow.python.framework import tensor_spec
49from tensorflow.python.framework import tensor_util
50from tensorflow.python.framework import test_ops
51from tensorflow.python.framework import test_util
52from tensorflow.python.framework import type_spec
53from tensorflow.python.framework import versions
54from tensorflow.python.ops import array_ops
55from tensorflow.python.ops import control_flow_ops
56from tensorflow.python.ops import math_ops
57from tensorflow.python.ops import resource_variable_ops
58from tensorflow.python.ops import resources
59from tensorflow.python.ops import special_math_ops
60from tensorflow.python.ops import variable_scope
61from tensorflow.python.ops import variables
62import tensorflow.python.ops.gradients  # pylint: disable=unused-import
63from tensorflow.python.platform import googletest
64from tensorflow.python.util import compat
65
66
class ResourceTest(test_util.TensorFlowTestCase):
  """Graph-mode tests for creating, registering and initializing resources."""

  @test_util.run_deprecated_v1
  def testBuildGraph(self):
    """A stub resource handle can be passed to a resource create op and run."""
    with self.cached_session():
      resource_handle = test_ops.stub_resource_handle_op(
          container="a", shared_name="b")
      create_op = test_ops.resource_create_op(resource_handle)
      create_op.run()

  @test_util.run_deprecated_v1
  def testInitialize(self):
    """A registered resource is reported uninitialized until initialized."""

    def count_uninitialized():
      # Evaluates the report op over the shared-resource collection and
      # returns how many entries it lists.
      report = resources.report_uninitialized_resources(
          resources.shared_resources())
      return len(report.eval())

    with self.cached_session():
      resource_handle = test_ops.stub_resource_handle_op(
          container="a", shared_name="b")
      resources.register_resource(
          handle=resource_handle,
          create_op=test_ops.resource_create_op(resource_handle),
          is_initialized_op=test_ops.resource_initialized_op(resource_handle))
      self.assertEqual(count_uninitialized(), 1)
      resources.initialize_resources(resources.shared_resources()).run()
      self.assertEqual(count_uninitialized(), 0)
92
93
class TensorAndShapeTest(test_util.TensorFlowTestCase):
  """Tests for Tensor shapes, Python-protocol errors, refs and bitwise ops."""

  def _bitwise_error_type(self):
    """Returns the exception type an invalid bitwise op is expected to raise.

    Eager mode reports unsupported operand dtypes as
    `errors.InvalidArgumentError`, whereas graph mode raises `TypeError`, so
    the bitwise error tests select the expected type through this helper.
    """
    if context.executing_eagerly():
      return errors.InvalidArgumentError
    return TypeError

  def testShape(self):
    op = ops.Operation(
        ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
    t = op.outputs[0]
    self.assertEqual(tensor_shape.unknown_shape(), t.get_shape())
    t.set_shape([1, 2, 3])
    self.assertEqual([1, 2, 3], t.get_shape())

  def testIterable(self):
    if not context.executing_eagerly():
      self.skipTest("Eager-mode test")
    op = ops.Operation(
        ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
    t = op.outputs[0]
    with self.assertRaisesRegex(TypeError, "Cannot iterate"):
      iter(t)

  def testIterableGraph(self):
    if context.executing_eagerly():
      self.skipTest("Graph-mode test")

    op = ops.Operation(
        ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
    t = op.outputs[0]
    # The error text differs with the AutoGraph control status: default,
    # explicitly enabled, and explicitly disabled.
    with self.assertRaisesRegex(TypeError, "iterating.*not allowed in Graph"):
      next(iter(t))
    with self.assertRaisesRegex(TypeError, "iterating.*AutoGraph did convert"):
      with ag_ctx.ControlStatusCtx(ag_ctx.Status.ENABLED):
        next(iter(t))
    with self.assertRaisesRegex(TypeError, "iterating.*AutoGraph is disabled"):
      with ag_ctx.ControlStatusCtx(ag_ctx.Status.DISABLED):
        next(iter(t))

  def testImplicitBool(self):
    op = ops.Operation(
        ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.bool])
    t = op.outputs[0]
    # Same three AutoGraph control states as in testIterableGraph.
    with self.assertRaisesRegex(TypeError,
                                "using.*as a.*bool.*not allowed in Graph"):
      bool(t)
    with self.assertRaisesRegex(TypeError,
                                "using.*as a.*bool.*AutoGraph did convert"):
      with ag_ctx.ControlStatusCtx(ag_ctx.Status.ENABLED):
        bool(t)
    with self.assertRaisesRegex(TypeError,
                                "using.*as a.*bool.*AutoGraph is disabled"):
      with ag_ctx.ControlStatusCtx(ag_ctx.Status.DISABLED):
        bool(t)

  def testAddShape(self):
    with self.cached_session():
      a = array_ops.zeros([2, 3])
      b = array_ops.ones([1, 3])
      c = a + b
      self.assertEqual([2, 3], c.shape)

  @test_util.run_deprecated_v1
  def testUnknownDim(self):
    with self.cached_session():
      a = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
      b = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
      c = a + b
      self.assertEqual([2, None, 3], c.shape.as_list())

  @test_util.run_deprecated_v1
  def testUnknownShape(self):
    with self.cached_session():
      a = array_ops.placeholder(dtype=dtypes.float32, shape=None)
      b = array_ops.ones([1, 3])
      c = a + b
      self.assertEqual(tensor_shape.unknown_shape(), c.shape)

  @test_util.run_deprecated_v1
  def testScalarShape(self):
    with self.cached_session():
      a = array_ops.placeholder(dtype=dtypes.float32, shape=[])
      b = array_ops.ones([])
      c = a + b
      self.assertEqual(tensor_shape.TensorShape([]), c.shape)

  @test_util.run_deprecated_v1
  def testShapeFunctionError(self):
    with self.cached_session():
      a = array_ops.ones([1, 2, 3])
      b = array_ops.ones([4, 5, 6])
      with self.assertRaisesRegex(
          ValueError, r"Dimensions must be equal, but are 2 and 5 for .*add"
          r".*Add(V2)?.* with input shapes: \[1,2,3\], \[4,5,6\]."):
        _ = a + b

  def testNumpyArray(self):
    with ops.Graph().as_default():
      x = array_ops.ones((3, 4), name="test_ones")

    # Symbolic (graph) tensors cannot be turned into numpy arrays or sized.
    with self.assertRaisesRegex(NotImplementedError,
                                r"Cannot convert a symbolic.+test_ones"):
      np.array(x)

    with self.assertRaisesRegex(TypeError, "not well defined.+test_ones"):
      len(x)

    # EagerTensors should still behave as numpy arrays.
    with context.eager_mode():
      x = array_ops.ones((3, 4))

    self.assertAllEqual(x, np.ones((3, 4)))
    self.assertAllEqual(np.array(x), np.ones((3, 4)))
    self.assertEqual(len(x), 3)

  def testConstructor(self):
    """Numpy-style attribute lookups on a Tensor raise a helpful hint."""
    a = array_ops.ones([])
    for name in ["T", "astype", "ravel", "transpose", "reshape", "clip", "size",
                 "tolist", "data"]:
      with self.assertRaisesRegex(
          AttributeError, r"If you are looking for numpy-related methods"):
        getattr(a, name)
    # Unrelated missing attributes get the ordinary AttributeError message.
    with self.assertRaisesRegex(
        AttributeError, r"object has no attribute"):
      a.foo_bar()

  def testRef(self):
    x1 = constant_op.constant(3)
    x2 = x1
    y = constant_op.constant(3)
    z = constant_op.constant([6, 10])
    w = variables.Variable(5)

    self.assertEqual(x1.ref(), x1.ref())
    self.assertEqual(x2.ref(), x2.ref())
    self.assertEqual(x1.ref(), x2.ref())
    self.assertEqual(y.ref(), y.ref())
    self.assertEqual(z.ref(), z.ref())
    self.assertEqual(w.ref(), w.ref())

    # Refs compare by object identity, so equal-valued constants differ.
    self.assertNotEqual(x1.ref(), y.ref())
    self.assertNotEqual(x1.ref(), z.ref())
    self.assertNotEqual(x1.ref(), w.ref())
    self.assertNotEqual(y.ref(), z.ref())
    self.assertNotEqual(y.ref(), w.ref())
    self.assertNotEqual(z.ref(), w.ref())

  def testRefDeref(self):
    x1 = constant_op.constant(3)
    x2 = x1
    y = constant_op.constant(3)
    z = constant_op.constant([6, 10])
    w = variables.Variable(5)

    self.assertIs(x1, x1.ref().deref())
    self.assertIs(x2, x2.ref().deref())
    self.assertIs(x1, x2.ref().deref())
    self.assertIs(x2, x1.ref().deref())
    self.assertIs(y, y.ref().deref())
    self.assertIs(z, z.ref().deref())
    # Variables must round-trip through ref()/deref() just like tensors.
    self.assertIs(w, w.ref().deref())

    self.assertIsNot(x1, y.ref().deref())
    self.assertIsNot(x1, z.ref().deref())
    self.assertIsNot(x1, w.ref().deref())
    self.assertIsNot(y, z.ref().deref())
    self.assertIsNot(y, w.ref().deref())
    self.assertIsNot(z, w.ref().deref())

  def testRefInSet(self):
    x1 = constant_op.constant(3)
    x2 = x1
    y = constant_op.constant(3)
    z = constant_op.constant([6, 10])
    w = variables.Variable(5)

    self.assertEqual(x1.ref(), x2.ref())

    tensor_set = {
        x1.ref(),
        x2.ref(),
        y.ref(),
        z.ref(),
        w.ref(),
    }

    # x1 and x2 are the same object, so their refs collapse to one entry.
    self.assertEqual(len(tensor_set), 4)
    self.assertIn(x1.ref(), tensor_set)
    self.assertIn(x2.ref(), tensor_set)
    self.assertIn(y.ref(), tensor_set)
    self.assertIn(z.ref(), tensor_set)
    self.assertIn(w.ref(), tensor_set)

  def testRefInDict(self):
    x1 = constant_op.constant(3)
    x2 = x1
    y = constant_op.constant(3)
    z = constant_op.constant([6, 10])
    w = variables.Variable(5)

    self.assertEqual(x1.ref(), x2.ref())

    tensor_dict = {
        x1.ref(): "x1",
        y.ref(): "y",
        z.ref(): "z",
        w.ref(): "w",
    }

    self.assertEqual(len(tensor_dict), 4)

    # Overwriting x1
    tensor_dict[x2.ref()] = "x2"
    self.assertEqual(len(tensor_dict), 4)

    self.assertEqual(tensor_dict[x1.ref()], "x2")
    self.assertEqual(tensor_dict[x2.ref()], "x2")
    self.assertEqual(tensor_dict[y.ref()], "y")
    self.assertEqual(tensor_dict[z.ref()], "z")
    self.assertEqual(tensor_dict[w.ref()], "w")

  def testTensorRefStrong(self):
    """A ref keeps its tensor reachable after the original name is deleted."""
    x = constant_op.constant(1.)
    x_ref = x.ref()
    del x
    self.assertIsNotNone(x_ref.deref())

  def testVariableRefStrong(self):
    """A ref keeps its variable reachable after the original name is deleted."""
    x = variables.Variable(1.)
    x_ref = x.ref()
    del x
    self.assertIsNotNone(x_ref.deref())

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseAndNumeric(self):
    x = constant_op.constant([0, 1, 3])
    y = constant_op.constant([1, 1, 1])

    z = x & y

    self.assertAllEqual(z, [0, 1, 1])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseAndBool(self):
    x = constant_op.constant([False, False, True, True])
    y = constant_op.constant([False, True, False, True])

    z = x & y

    self.assertAllEqual(z, [False, False, False, True])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseAndErrors(self):
    x_int = constant_op.constant(0)
    x_bool = constant_op.constant(True)
    expected_errtype = self._bitwise_error_type()

    with self.assertRaises(expected_errtype):
      _ = x_int & x_bool
    with self.assertRaises(expected_errtype):
      _ = x_int & constant_op.constant("a")

    with self.assertRaises(expected_errtype):
      _ = x_bool & x_int
    with self.assertRaises(expected_errtype):
      _ = x_bool & constant_op.constant("a")

    with self.assertRaises(expected_errtype):
      _ = constant_op.constant("a") & constant_op.constant("b")

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseOrNumeric(self):
    x = constant_op.constant([0, 1, 2])
    y = constant_op.constant([1, 1, 1])

    z = x | y

    self.assertAllEqual(z, [1, 1, 3])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseOrBool(self):
    x = constant_op.constant([False, False, True, True])
    y = constant_op.constant([False, True, False, True])

    z = x | y

    self.assertAllEqual(z, [False, True, True, True])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseOrErrors(self):
    x_int = constant_op.constant(0)
    x_bool = constant_op.constant(True)
    expected_errtype = self._bitwise_error_type()

    with self.assertRaises(expected_errtype):
      _ = x_int | x_bool
    with self.assertRaises(expected_errtype):
      _ = x_int | constant_op.constant("a")

    with self.assertRaises(expected_errtype):
      _ = x_bool | x_int
    with self.assertRaises(expected_errtype):
      _ = x_bool | constant_op.constant("a")

    with self.assertRaises(expected_errtype):
      _ = constant_op.constant("a") | constant_op.constant("b")

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseXorNumeric(self):
    x = constant_op.constant([0, 1, 3])
    y = constant_op.constant([1, 1, 1])

    z = x ^ y

    self.assertAllEqual(z, [1, 0, 2])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseXorBool(self):
    x = constant_op.constant([False, False, True, True])
    y = constant_op.constant([False, True, False, True])

    z = x ^ y

    self.assertAllEqual(z, [False, True, True, False])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseXorErrors(self):
    x_int = constant_op.constant(0)
    x_bool = constant_op.constant(True)
    expected_errtype = self._bitwise_error_type()

    with self.assertRaises(expected_errtype):
      _ = x_int ^ x_bool
    with self.assertRaises(expected_errtype):
      _ = x_int ^ constant_op.constant("a")

    with self.assertRaises(expected_errtype):
      _ = x_bool ^ x_int
    with self.assertRaises(expected_errtype):
      _ = x_bool ^ constant_op.constant("a")

    with self.assertRaises(expected_errtype):
      _ = constant_op.constant("a") ^ constant_op.constant("b")

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseNotNumeric(self):
    x = constant_op.constant([0, dtypes.int32.min, 1])

    y = ~x

    self.assertAllEqual(y, [-1, dtypes.int32.max, -2])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseNotBool(self):
    x = constant_op.constant([False, True])

    y = ~x

    self.assertAllEqual(y, [True, False])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseNotErrors(self):
    expected_errtype = self._bitwise_error_type()

    with self.assertRaises(expected_errtype):
      _ = ~constant_op.constant("a")
470
471
@test_util.run_all_in_graph_and_eager_modes
class IndexedSlicesTest(test_util.TensorFlowTestCase):
  """Tests for converting and transforming `ops.IndexedSlices`."""

  def testToTensor(self):
    """Densifying fails without a dense_shape and scatters rows with one."""
    slice_values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
    slice_indices = constant_op.constant([0, 2])

    shapeless = ops.IndexedSlices(slice_values, slice_indices)
    with self.assertRaises(ValueError):
      ops.convert_to_tensor(shapeless, name="tensor")
    self.assertEqual(tensor_shape.TensorShape(None), shapeless.shape)

    shaped = ops.IndexedSlices(slice_values, slice_indices,
                               constant_op.constant([3, 2]))
    dense = ops.convert_to_tensor(shaped, name="tensor")
    self.assertAllEqual(dense.shape, shaped.shape)
    self.assertAllEqual(self.evaluate(dense), [[2, 3], [0, 0], [5, 7]])

  @test_util.run_gpu_only
  def testEagerCopy(self):
    """The gradient through nested gathers has values of shape [4, 1]."""
    with context.eager_mode():
      var = variables.Variable([[0.0], [0.0], [0.0], [0.0]], name="tensor")
      with backprop.GradientTape() as tape:
        head = array_ops.gather(array_ops.gather(var, [0, 1]), [0, 1])
        tail = array_ops.gather(array_ops.gather(var, [2, 3]), [0, 1])
        product = special_math_ops.einsum("ij,ij->i", head, tail)
      grad = tape.gradient(product, [var])[0]
      # The gradient may come back either sparse or dense.
      if isinstance(grad, ops.IndexedSlices):
        grad_values = grad.values
      else:
        grad_values = grad
      self.assertAllEqual(grad_values.get_shape(), [4, 1])

  def testNegation(self):
    """Unary minus negates the values and leaves the indices unchanged."""
    slice_values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
    slice_indices = constant_op.constant([0, 2])
    negated = -ops.IndexedSlices(slice_values, slice_indices)
    self.assertAllEqual(negated.values, [[-2, -3], [-5, -7]])
    self.assertAllEqual(negated.indices, [0, 2])

  def testScalarMul(self):
    """scalar_mul scales the values and leaves the indices unchanged."""
    slice_values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
    slice_indices = constant_op.constant([0, 2])
    scaled = math_ops.scalar_mul(-2,
                                 ops.IndexedSlices(slice_values, slice_indices))
    self.assertAllEqual(scaled.values, [[-4, -6], [-10, -14]])
    self.assertAllEqual(scaled.indices, [0, 2])
514
515
@test_util.run_all_in_graph_and_eager_modes
class IndexedSlicesSpecTest(test_util.TensorFlowTestCase,
                            parameterized.TestCase):
  """Tests for `indexed_slices.IndexedSlicesSpec`."""

  def assertAllTensorsEqual(self, list1, list2):
    """Asserts both sequences have the same length and equal elements."""
    self.assertLen(list1, len(list2))
    for (t1, t2) in zip(list1, list2):
      self.assertAllEqual(t1, t2)

  def testConstruction(self):
    spec1 = indexed_slices.IndexedSlicesSpec()
    self.assertEqual(spec1._shape.rank, None)
    self.assertEqual(spec1._values_dtype, dtypes.float32)
    self.assertEqual(spec1._indices_dtype, dtypes.int64)
    self.assertEqual(spec1._dense_shape_dtype, None)
    self.assertEqual(spec1._indices_shape.as_list(), [None])

    spec2 = indexed_slices.IndexedSlicesSpec([None, None], dtypes.string,
                                             dtypes.int32, dtypes.int64, [10])
    self.assertEqual(spec2._shape.as_list(), [None, None])
    self.assertEqual(spec2._values_dtype, dtypes.string)
    self.assertEqual(spec2._indices_dtype, dtypes.int32)
    self.assertEqual(spec2._dense_shape_dtype, dtypes.int64)
    self.assertEqual(spec2._indices_shape.as_list(), [10])

  def testValueType(self):
    spec1 = indexed_slices.IndexedSlicesSpec()
    self.assertEqual(spec1.value_type, ops.IndexedSlices)

  @parameterized.parameters([
      (indexed_slices.IndexedSlicesSpec(),
       (tensor_shape.TensorShape(None), dtypes.float32, dtypes.int64, None,
        tensor_shape.TensorShape([None]))),
      (indexed_slices.IndexedSlicesSpec(shape=[5, None, None]),
       (tensor_shape.TensorShape([5, None, None]), dtypes.float32,
        dtypes.int64, None, tensor_shape.TensorShape([None]))),
      (indexed_slices.IndexedSlicesSpec(
          dtype=dtypes.int32, dense_shape_dtype=dtypes.int64),
       (tensor_shape.TensorShape(None), dtypes.int32, dtypes.int64,
        dtypes.int64, tensor_shape.TensorShape([None]))),
      (indexed_slices.IndexedSlicesSpec(indices_shape=[100]),
       (tensor_shape.TensorShape(None), dtypes.float32, dtypes.int64, None,
        tensor_shape.TensorShape([100]))),
  ])  # pyformat: disable
  def testSerialize(self, spec, expected):
    serialization = spec._serialize()
    # TensorShape has an unconventional definition of equality, so we can't use
    # assertEqual directly here.  But repr() is deterministic and lossless for
    # the expected values, so we can use that instead.
    self.assertEqual(repr(serialization), repr(expected))

  @parameterized.parameters([
      (indexed_slices.IndexedSlicesSpec(dtype=dtypes.string), (
          tensor_spec.TensorSpec(None, dtypes.string),
          tensor_spec.TensorSpec([None], dtypes.int64),
      )),
      (indexed_slices.IndexedSlicesSpec(
          dtype=dtypes.string, dense_shape_dtype=dtypes.int32), (
              tensor_spec.TensorSpec(None, dtypes.string),
              tensor_spec.TensorSpec([None], dtypes.int64),
              tensor_spec.TensorSpec([None], dtypes.int32),
          )),
      (indexed_slices.IndexedSlicesSpec(
          shape=[5, 10, 15], dense_shape_dtype=dtypes.int32), (
              tensor_spec.TensorSpec([None, 10, 15], dtypes.float32),
              tensor_spec.TensorSpec([None], dtypes.int64),
              tensor_spec.TensorSpec([3], dtypes.int32),
          )),
      (indexed_slices.IndexedSlicesSpec(
          shape=[5, 10, 15], dense_shape_dtype=dtypes.int32,
          indices_shape=[20]), (
              tensor_spec.TensorSpec([20, 10, 15], dtypes.float32),
              tensor_spec.TensorSpec([20], dtypes.int64),
              tensor_spec.TensorSpec([3], dtypes.int32),
          )),
  ])
  def testComponentSpecs(self, spec, expected):
    self.assertEqual(spec._component_specs, expected)

  @parameterized.parameters([
      {
          "spec": indexed_slices.IndexedSlicesSpec(),
          "values": [3.0, 5.0],
          "indices": [5, 10]
      },
      {
          "spec":
              indexed_slices.IndexedSlicesSpec(dense_shape_dtype=dtypes.int32),
          "values": [3.0, 5.0],
          "indices": [5, 10],
          "dense_shape": [100]
      },
  ])
  def testToFromComponents(self, spec, indices, values, dense_shape=None):
    """Round-trips an IndexedSlices through the spec's component list."""
    # IndexedSlices takes (values, indices, dense_shape); the components are
    # ordered the same way (compare testComponentSpecs and
    # testFromNumpyComponents).
    x = ops.IndexedSlices(values, indices, dense_shape)
    actual_components = spec._to_components(x)
    if dense_shape is None:
      self.assertAllTensorsEqual(actual_components, [values, indices])
    else:
      self.assertAllTensorsEqual(actual_components,
                                 [values, indices, dense_shape])
    st_reconstructed = spec._from_components(actual_components)
    self.assertAllEqual(x.indices, st_reconstructed.indices)
    self.assertAllEqual(x.values, st_reconstructed.values)
    if dense_shape is None:
      self.assertIsNone(st_reconstructed.dense_shape)
    else:
      self.assertAllEqual(x.dense_shape, st_reconstructed.dense_shape)

  @test_util.run_v1_only("IndexedSlicesValue is deprecated in v2")
  def testFromNumpyComponents(self):
    """Numpy components rebuild into an IndexedSlicesValue in v1."""
    indices = np.array([3, 8])
    values = np.array([1.0, 9.0])
    dense_shape = np.array([100])

    spec1 = indexed_slices.IndexedSlicesSpec(dense_shape_dtype=dtypes.int32)
    st1 = spec1._from_components((values, indices, dense_shape))
    self.assertIsInstance(st1, indexed_slices.IndexedSlicesValue)
    self.assertAllEqual(st1.indices, indices)
    self.assertAllEqual(st1.values, values)
    self.assertAllEqual(st1.dense_shape, dense_shape)

    spec2 = indexed_slices.IndexedSlicesSpec()
    st2 = spec2._from_components((values, indices))
    self.assertIsInstance(st2, indexed_slices.IndexedSlicesValue)
    self.assertAllEqual(st2.indices, indices)
    self.assertAllEqual(st2.values, values)
    self.assertIsNone(st2.dense_shape)
644
645
class NodeDefConstructorTest(test_util.TensorFlowTestCase):
  """Tests for the private `ops._NodeDef` constructor helper."""

  def testNoArgs(self):
    """With only an op type and a name, no other NodeDef fields are set."""
    expected_proto = "op: 'None' name: 'bar'"
    self.assertProtoEquals(expected_proto, ops._NodeDef("None", "bar"))
651
652
653def _apply_op(g, *args, **kwargs):
654  op = g.create_op(*args, **kwargs)
655  if len(op.outputs) == 1:
656    return op.outputs[0]
657  else:
658    return op.outputs
659
660
661class OperationTest(test_util.TensorFlowTestCase):
662
663  @test_util.run_deprecated_v1
664  def testNoInputs(self):
665    op = test_ops.float_output_string_output(name="myop").a.op
666    self.assertEqual(2, len(op.values()))
667    self.assertEqual(0, len(op.inputs))
668    self.assertEqual("myop", op.name)
669
670    float_t, label_str_t = op.values()
671    self.assertEqual(dtypes.float32, float_t.dtype)
672    self.assertEqual(op, float_t.op)
673    self.assertEqual(0, float_t._value_index)
674    self.assertEqual(0, len(float_t.consumers()))
675    self.assertEqual("myop", float_t._as_node_def_input())
676
677    self.assertEqual(dtypes.string, label_str_t.dtype)
678    self.assertEqual(op, label_str_t.op)
679    self.assertEqual(1, label_str_t._value_index)
680    self.assertEqual(0, len(label_str_t.consumers()))
681    self.assertEqual("myop:1", label_str_t._as_node_def_input())
682
683    self.assertProtoEquals("op:'FloatOutputStringOutput' name:'myop'",
684                           op.node_def)
685
686  @test_util.run_deprecated_v1
687  def testNoOutputs(self):
688    op1 = test_ops.float_output(name="myop1").op
689    float_t, = op1.values()
690    op2 = test_ops.float_input(float_t, name="myop2")
691    self.assertEqual(0, len(op2.values()))
692    self.assertEqual(1, len(op2.inputs))
693    self.assertIs(float_t, op2.inputs[0])
694
695    self.assertEqual(1, len(float_t.consumers()))
696    self.assertEqual(op2, float_t.consumers()[0])
697
698    self.assertProtoEquals("op:'FloatOutput' name:'myop1'", op1.node_def)
699    self.assertProtoEquals("op:'FloatInput' name:'myop2' input:'myop1'",
700                           op2.node_def)
701
702  @test_util.run_deprecated_v1
703  def testInputsAndOutputs(self):
704    op1 = test_ops.float_output(name="myop1").op
705    self.assertEqual(1, len(op1.values()))
706    float1_t, = op1.values()
707
708    op2 = test_ops.float_output_string_output(name="myop2").a.op
709    self.assertEqual(2, len(op2.values()))
710    float2_t, label2_str_t = op2.values()
711
712    # Note that we consume label2_str_t twice here.
713    op3 = test_ops.foo2(float1_t, label2_str_t, label2_str_t, name="myop3").d.op
714    self.assertEqual(2, len(op3.values()))
715
716    self.assertEqual(1, len(float1_t.consumers()))
717    self.assertEqual(op3, float1_t.consumers()[0])
718
719    self.assertEqual(0, len(float2_t.consumers()))
720
721    self.assertEqual(2, len(label2_str_t.consumers()))
722    self.assertEqual(op3, label2_str_t.consumers()[0])
723    self.assertEqual(op3, label2_str_t.consumers()[1])
724
725    self.assertProtoEquals("""
726    op:'Foo2' name:'myop3'
727    input:'myop1' input:'myop2:1' input:'myop2:1'
728    """, op3.node_def)
729
730  def testDeviceObject(self):
731    op = ops.Operation(ops._NodeDef("None", "myop"), ops.Graph(), [], [])
732    op._set_device("/job:goo/device:GPU:0")
733    self.assertProtoEquals(
734        "op:'None' name:'myop' device:'/job:goo/device:GPU:0' ", op.node_def)
735    op = ops.Operation(ops._NodeDef("None", "op2"), ops.Graph(), [], [])
736    op._set_device(
737        pydev.DeviceSpec(
738            job="muu", device_type="CPU", device_index=0))
739    self.assertProtoEquals(
740        "op:'None' name:'op2' device:'/job:muu/device:CPU:0'", op.node_def)
741
742  def testReferenceInput(self):
743    g = ops.Graph()
744    op1 = ops.Operation(
745        ops._NodeDef("RefOutputFloatOutput", "op1"), g, [],
746        [dtypes.float32_ref, dtypes.float32])
747    self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def)
748    self.assertEqual([], list(op1.inputs))
749    ref_t, nonref_t = op1.values()
750    # NOTE(mrry): Must specify input_types to preserve ref-typed input.
751    op2 = ops.Operation(
752        ops._NodeDef("RefInputFloatInput", "op2"),
753        g, [ref_t, nonref_t], [],
754        input_types=[dtypes.float32_ref, dtypes.float32])
755    self.assertProtoEquals(
756        "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'",
757        op2.node_def)
758    self.assertEqual([ref_t, nonref_t], list(op2.inputs))
759    op3 = ops.Operation(
760        ops._NodeDef("TwoFloatInputs", "op3"), g, [ref_t, nonref_t], [])
761    self.assertProtoEquals(
762        "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'",
763        op3.node_def)
764
765  def testInvalidNames(self):
766    g = ops.Graph()
767    with self.assertRaises(ValueError):
768      ops.Operation(ops._NodeDef("op", ""), g)
769    with self.assertRaises(ValueError):
770      ops.Operation(ops._NodeDef("op", "_invalid"), g)
771    with self.assertRaises(ValueError):
772      ops.Operation(ops._NodeDef("op", "-invalid"), g)
773    with self.assertRaises(ValueError):
774      ops.Operation(ops._NodeDef("op", "/invalid"), g)
775    with self.assertRaises(ValueError):
776      ops.Operation(ops._NodeDef("op", "invalid:0"), g)
777
778  @test_util.run_deprecated_v1
779  def testNoShapeFunction(self):
780    op = test_ops.a()
781    self.assertEqual(tensor_shape.unknown_shape(), op.get_shape())
782
783  @test_util.run_in_graph_and_eager_modes
784  def testConvertToTensorNestedArray(self):
785    values = [[2], [3], [5], [7]]
786    tensor = ops.convert_to_tensor(values)
787    self.assertAllEqual((4, 1), tensor.get_shape().as_list())
788    self.assertAllEqual(values, self.evaluate(tensor))
789
790  def testShapeTuple(self):
791    with self.cached_session():
792      c = constant_op.constant(1)
793      self.assertEqual(c._shape_tuple(), ())  # pylint: disable=protected-access
794
795  def testConvertToTensorEager(self):
796    with context.eager_mode():
797      t = constant_op.constant(1)
798      self.assertTrue(isinstance(t, ops.EagerTensor))
799      converted = ops.convert_to_tensor(t)
800      self.assertTrue(isinstance(converted, ops.EagerTensor))
801      converted = ops.convert_to_tensor(1)
802      self.assertTrue(isinstance(converted, ops.EagerTensor))
803
804  @test_util.run_in_graph_and_eager_modes
805  def testConvertToTensorNestedTuple(self):
806    values = ((2,), (3,), (5,), (7,))
807    tensor = ops.convert_to_tensor(values)
808    self.assertAllEqual((4, 1), tensor.get_shape().as_list())
809    self.assertAllEqual(values, self.evaluate(ops.convert_to_tensor(values)))
810
811  @test_util.run_in_graph_and_eager_modes
812  def testConvertToTensorNestedTensors(self):
813    values = ((2,), (3,), (5,), (7,))
814    tensor = ops.convert_to_tensor(
815        [constant_op.constant(row) for row in values])
816    self.assertAllEqual((4, 1), tensor.get_shape().as_list())
817    self.assertAllEqual(values, self.evaluate(tensor))
818    tensor = ops.convert_to_tensor(
819        [[constant_op.constant(v) for v in row] for row in values])
820    self.assertAllEqual((4, 1), tensor.get_shape().as_list())
821    self.assertAllEqual(values, self.evaluate(tensor))
822
823  @test_util.run_in_graph_and_eager_modes
824  def testConvertToTensorNestedMix(self):
825    values = ([2], (3,), [constant_op.constant(5)], constant_op.constant([7]))
826    tensor = ops.convert_to_tensor(values)
827    self.assertAllEqual((4, 1), tensor.get_shape().as_list())
828    self.assertAllEqual(((2,), (3,), (5,), (7,)), self.evaluate(tensor))
829
830  @test_util.run_in_graph_and_eager_modes
831  def testConvertToTensorPreferred(self):
832    values = [2, 3, 5, 7]
833    tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.float32)
834    self.assertEqual(dtypes.float32, tensor.dtype)
835
836    # Convert empty tensor to anything.
837    values = []
838    tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64)
839    self.assertEqual(dtypes.int64, tensor.dtype)
840
841    # The preferred dtype is a type error and will convert to
842    # float32 instead.
843    values = [1.23]
844    tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64)
845    self.assertEqual(dtypes.float32, tensor.dtype)
846
847  @test_util.run_in_graph_and_eager_modes
848  def testConvertToInvalidTensorType(self):
849    with self.assertRaises(TypeError):
850      # Forcing an invalid dtype should fail with a type error.
851      values = [1.23]
852      ops.convert_to_tensor(values, dtype=dtypes.int64)
853
854  @test_util.run_in_graph_and_eager_modes
855  def testConvertToLongLongTensorType(self):
856    tensor = ops.convert_to_tensor(
857        # Get a numpy array of dtype NPY_LONGLONG
858        np.prod(constant_op.constant([1])._shape_tuple()),
859        dtype=dtypes.int64)
860    self.assertEqual(dtypes.int64, tensor.dtype)
861
862  @test_util.run_in_graph_and_eager_modes
863  def testConvertToTensorFromInvalidTensor(self):
864    tensor = constant_op.constant(42.0, dtype=dtypes.float32)
865    with self.assertRaises(ValueError):
866      ops.convert_to_tensor(tensor, dtype=dtypes.int32)
867
868  @test_util.run_in_graph_and_eager_modes
869  def testConvertToTensorProtocol(self):
870    class TensorCompatible:
871
872      def __tf_tensor__(self, dtype=None, name=None):
873        return constant_op.constant((1, 2, 3), dtype=dtype, name=name)
874
875    tc = TensorCompatible()
876
877    tensor = ops.convert_to_tensor(tc, dtype=dtypes.int32)
878    self.assertEqual(tensor.dtype, dtypes.int32)
879    self.assertAllEqual((1, 2, 3), self.evaluate(tensor))
880
881  @test_util.run_deprecated_v1
882  def testNoConvert(self):
883    # Operation cannot be converted to Tensor.
884    op = control_flow_ops.no_op()
885    with self.assertRaisesRegex(TypeError,
886                                "can't convert Operation '.+' to Tensor"):
887      ops.convert_to_tensor(op)
888
889  def testStr(self):
890    node_def = ops._NodeDef("None", "op1")
891    op = ops.Operation(node_def, ops.Graph(), [], [dtypes.float32])
892    self.assertEqual(str(node_def), str(op))
893
894  def testRepr(self):
895    op = ops.Operation(
896        ops._NodeDef("None", "op1"), ops.Graph(), [], [dtypes.float32])
897    self.assertEqual("<tf.Operation 'op1' type=None>", repr(op))
898
899  @test_util.run_deprecated_v1
900  def testGetAttr(self):
901    op = test_ops.default_attrs()
902    self.assertEqual(op.get_attr("string_val"), b"abc")
903    self.assertEqual(op.get_attr("string_list_val"), [b"abc", b""])
904    self.assertEqual(op.get_attr("int_val"), 123)
905    self.assertEqual(op.get_attr("int_list_val"), [1, 2, 3])
906    self.assertEqual(op.get_attr("float_val"), 10.0)
907    self.assertEqual(op.get_attr("float_list_val"), [10.0])
908    self.assertEqual(op.get_attr("bool_val"), True)
909    self.assertEqual(op.get_attr("bool_list_val"), [True, False])
910    self.assertEqual(op.get_attr("shape_val"),
911                     tensor_shape.as_shape([2, 1]).as_proto())
912    self.assertEqual(op.get_attr("shape_list_val"),
913                     [tensor_shape.as_shape([]).as_proto(),
914                      tensor_shape.as_shape([1]).as_proto()])
915    self.assertEqual(op.get_attr("tensor_val"),
916                     tensor_util.make_tensor_proto(1, dtypes.int32))
917    self.assertEqual(op.get_attr("tensor_list_val"),
918                     [tensor_util.make_tensor_proto(1, dtypes.int32)])
919
920    type_val = op.get_attr("type_val")
921    # First check that type_val is a DType, because the assertEqual will work
922    # no matter what since DType overrides __eq__
923    self.assertIsInstance(type_val, dtypes.DType)
924    self.assertEqual(type_val, dtypes.int32)
925
926    type_list_val = op.get_attr("type_list_val")
927    self.assertTrue(all(isinstance(x, dtypes.DType) for x in type_list_val))
928    self.assertEqual(type_list_val, [dtypes.int32, dtypes.float32])
929
930    @function.Defun(dtypes.float32, func_name="MyFunc")
931    def func(x):
932      return x
933
934    op = test_ops.func_attr(func)
935    self.assertEqual(op.get_attr("f"),
936                     attr_value_pb2.NameAttrList(name="MyFunc"))
937
938    # Try fetching missing attr
939    with self.assertRaisesRegex(
940        ValueError, "Operation 'FuncAttr' has no attr named 'FakeAttr'."):
941      op.get_attr("FakeAttr")
942
943  # TODO(b/65162920): remove this test when users who are directly mutating the
944  # node_def have been updated to proper usage.
945  @test_util.run_deprecated_v1
946  def testSetAttr(self):
947    op = test_ops.int_attr().op
948    op._set_attr("foo", attr_value_pb2.AttrValue(i=2))
949    # TODO(skyewm): add node_def check
950    self.assertEqual(op.get_attr("foo"), 2)
951
952  # TODO(nolivia): test all error cases
953  def testAddControlInput(self):
954    with ops.Graph().as_default():
955      x = constant_op.constant(1).op
956      y = constant_op.constant(2).op
957      z = constant_op.constant(3).op
958    z._add_control_input(x)  # pylint: disable=protected-access
959    self.assertEqual(z.control_inputs, [x])
960    z._add_control_input(x)  # pylint: disable=protected-access
961    self.assertEqual(z.control_inputs, [x])
962    z._add_control_inputs([x, y, y])  # pylint: disable=protected-access
963    self.assertEqual(z.control_inputs, [x, y])
964    self.assertEqual(x._control_outputs, [z])
965
966  @test_util.run_deprecated_v1
967  def testRemoveAllControlInputs(self):
968    a = constant_op.constant(1)
969    with ops.control_dependencies([a]):
970      b = constant_op.constant(2)
971    c = constant_op.constant(3)
972    d = constant_op.constant(4)
973    e = constant_op.constant(5)
974    with ops.control_dependencies([a, c]):
975      f = d + e
976
977    self.assertEqual(a.op.control_inputs, [])
978    self.assertEqual(b.op.control_inputs, [a.op])
979    self.assertEqual(f.op.control_inputs, [a.op, c.op])
980
981    a.op._remove_all_control_inputs()  # pylint: disable=protected-access
982    self.assertEqual(a.op.control_inputs, [])
983
984    b.op._remove_all_control_inputs()  # pylint: disable=protected-access
985    self.assertEqual(b.op.control_inputs, [])
986
987    f.op._remove_all_control_inputs()  # pylint: disable=protected-access
988    self.assertEqual(f.op.control_inputs, [])
989    self.assertEqual(list(f.op.inputs), [d, e])
990
991  @test_util.run_deprecated_v1
992  def testControlInputCycle(self):
993    graph = ops.Graph()
994    with graph.as_default():
995      z = constant_op.constant(0)
996      x = constant_op.constant(1)
997      y = constant_op.constant(2)
998      y.op._add_control_input(z.op)  # pylint: disable=protected-access
999      y.op._add_control_input(x.op)  # pylint: disable=protected-access
1000      x.op._add_control_input(y.op)  # pylint: disable=protected-access
1001    with self.session(graph=graph) as sess:
1002      with self.assertRaisesRegex(
1003          errors.InvalidArgumentError,
1004          "Graph is invalid, contains a cycle with 2 nodes"):
1005        self.evaluate(x)
1006
1007  def testUpdateInput(self):
1008    g = ops.Graph()
1009    with g.as_default():
1010      x = constant_op.constant(1)
1011      y = constant_op.constant(2)
1012      z = x + y
1013
1014    z.op._update_input(0, y)  # pylint: disable=protected-access
1015    self.assertEqual(list(z.op.inputs), [y, y])
1016    self.assertEqual(x.consumers(), [])
1017    self.assertEqual(y.consumers(), [z.op, z.op])
1018    with session.Session(graph=g) as sess:
1019      self.assertEqual(self.evaluate(z), 4)
1020
1021    z.op._update_input(0, x)  # pylint: disable=protected-access
1022    self.assertEqual(list(z.op.inputs), [x, y])
1023    self.assertEqual(x.consumers(), [z.op])
1024    self.assertEqual(y.consumers(), [z.op])
1025    with session.Session(graph=g) as sess:
1026      self.assertEqual(self.evaluate(z), 3)
1027
1028    z.op._update_input(1, y)  # pylint: disable=protected-access
1029    self.assertEqual(list(z.op.inputs), [x, y])
1030    self.assertEqual(x.consumers(), [z.op])
1031    self.assertEqual(y.consumers(), [z.op])
1032    with session.Session(graph=g) as sess:
1033      self.assertEqual(self.evaluate(z), 3)
1034
1035  def testUpdateInputGraphError(self):
1036    g_0 = ops.Graph()
1037    g_1 = ops.Graph()
1038    with g_0.as_default():
1039      x = constant_op.constant(1)
1040    with g_1.as_default():
1041      y = constant_op.constant(2)
1042      z = y * 2
1043      with self.assertRaisesRegex(ValueError, "must be from the same graph"):
1044        z.op._update_input(0, x)  # pylint: disable=protected-access
1045
1046  def testUpdateInputTypeError(self):
1047    g = ops.Graph()
1048    with g.as_default():
1049      w = constant_op.constant(0)
1050      x = constant_op.constant("")
1051      y = constant_op.constant(1)
1052      z = y + w
1053      z.op._update_input(0, x)  # pylint: disable=protected-access
1054    with session.Session(graph=g) as sess:
1055      with self.assertRaisesRegex(
1056          errors.InvalidArgumentError,
1057          "Input 0 of node add was passed string from Const_1:0 incompatible "
1058          "with expected int32"):
1059        self.evaluate(z)
1060
1061  def testUpdateInputShapeError(self):
1062    g = ops.Graph()
1063    with g.as_default():
1064      w = constant_op.constant(2, shape=[3, 1])
1065      x = constant_op.constant(0, shape=[3, 1])
1066      y = constant_op.constant(1, shape=[2, 2])
1067      z = w + x
1068    with self.assertRaisesRegex(
1069        errors.InvalidArgumentError,
1070        r"Cannot update edge, incompatible shapes: \[2,2\] and \[3,1\]"):
1071      z.op._update_input(0, y)  # pylint: disable=protected-access
1072
1073  def testUpdateInputOutOfRange(self):
1074    g = ops.Graph()
1075    with g.as_default():
1076      x = constant_op.constant(1)
1077    with self.assertRaisesRegex(
1078        errors.OutOfRangeError,
1079        r"Cannot update edge. Input index \[1\] is greater than the number of "
1080        r"total inputs \[0\]."):
1081      x.op._update_input(1, x)  # pylint: disable=protected-access
1082
1083  @test_util.enable_control_flow_v2
1084  @test_util.run_v1_only("b/120545219")
1085  def testAddWhileInput(self):
1086
1087    @eager_function.defun
1088    def test():
1089      output = control_flow_ops.while_loop(lambda x: x < 3, lambda x: x + 1,
1090                                           [1])
1091      while_op = output.op
1092      self.assertEqual(while_op.type, "StatelessWhile")
1093      orig_num_inputs = len(while_op.inputs)
1094
1095      # Make sure we can handle the while op having a control input.
1096      while_op._add_control_input(constant_op.constant(0).op)
1097
1098      new_input1 = constant_op.constant(1.0)
1099      new_input2 = constant_op.constant(True)
1100
1101      # Clear output shapes to bypass shape checking.
1102      while_op._set_shape_list_attr("output_shapes", [])
1103      while_op._set_type_list_attr("T", [t.dtype for t in while_op.inputs] +
1104                                   [new_input1.dtype, new_input2.dtype])
1105
1106      while_op._add_while_inputs([new_input1, new_input2])
1107      # Can't add an edge beyond what's specified by "T"
1108      with self.assertRaises(errors.OutOfRangeError):
1109        while_op._add_while_inputs([new_input2])
1110      self.assertEqual(len(while_op.inputs), orig_num_inputs + 2)  # pylint: disable=g-deprecated-assert
1111
1112      test()
1113
1114  @test_util.run_deprecated_v1
1115  def testOpDef(self):
1116    x = constant_op.constant(0)
1117    y = constant_op.constant(1)
1118    z = x + y
1119
1120    self.assertEqual(x.op.op_def.name, "Const")
1121    self.assertEqual(len(x.op.op_def.input_arg), 0)
1122    self.assertEqual(len(x.op.op_def.output_arg), 1)
1123
1124    self.assertRegex(z.op.op_def.name, "Add(V2)?")
1125    self.assertEqual(len(z.op.op_def.input_arg), 2)
1126    self.assertEqual(len(z.op.op_def.output_arg), 1)
1127
1128  def testInputFromDifferentGraphError(self):
1129    g_0 = ops.Graph()
1130    g_1 = ops.Graph()
1131    with g_0.as_default():
1132      x = constant_op.constant(1)
1133    with g_1.as_default():
1134      y = constant_op.constant(2)
1135      with self.assertRaisesRegex(ValueError, "must be from the same graph"):
1136        y * x  # pylint: disable=pointless-statement
1137
1138  def testInputsAreImmutable(self):
1139    g = ops.Graph()
1140    with g.as_default():
1141      x = test_ops.int_output()
1142      op = test_ops.int_input_int_output(x, name="myop").op
1143    with self.assertRaisesRegex(AttributeError,
1144                                "'tuple' object has no attribute 'append'"):
1145      op.inputs.append(None)
1146
1147
class CreateOpTest(test_util.TensorFlowTestCase):
  """Tests for Graph.create_op."""

  def testNodeDefArgs(self):
    """Names, devices and inputs are reflected in each op's NodeDef."""
    g = ops.Graph()
    op1 = g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1")
    with g.device("/device:GPU:0"):
      op2 = g.create_op(
          "FloatOutputStringOutput", [], [dtypes.float32, dtypes.string], None,
          name="myop2")
    op1_float = list(op1.values())[0]
    op2_float, op2_string = list(op2.values())
    op3 = g.create_op(
        "Foo3", [op1_float, op2_string, op2_float],
        [dtypes.float32, dtypes.int32], None, name="myop3")

    # Only op2 was created under the device scope.
    for expected_device, op in ((None, op1), ("/device:GPU:0", op2),
                                (None, op3)):
      self.assertDeviceEqual(expected_device, op.device)

    self.assertProtoEquals("name:'myop1' op:'FloatOutput'", op1.node_def)
    self.assertProtoEquals(
        "name:'myop2' op:'FloatOutputStringOutput' device:'/device:GPU:0'",
        op2.node_def)
    self.assertProtoEquals(
        "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo3'",
        op3.node_def)

  def testReferenceInput(self):
    """Ref-typed outputs can feed both ref and non-ref inputs."""
    g = ops.Graph()
    producer = g.create_op(
        "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32],
        name="op1")
    self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'",
                           producer.node_def)
    ref_t, nonref_t = producer.values()

    # NOTE(mrry): Must specify input_types to preserve ref-typed input.
    ref_consumer = g.create_op(
        "RefInputFloatInput", [ref_t, nonref_t], [],
        input_types=[dtypes.float32_ref, dtypes.float32],
        name="op2")
    self.assertProtoEquals(
        "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'",
        ref_consumer.node_def)

    plain_consumer = g.create_op(
        "TwoFloatInputs", [ref_t, nonref_t], [], name="op3")
    self.assertProtoEquals(
        "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'",
        plain_consumer.node_def)

  def testFinalized(self):
    """A finalized graph rejects new ops until it is unfinalized."""
    g = ops.Graph()

    def create():
      g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1")

    g.finalize()
    with self.assertRaises(RuntimeError):
      create()

    # Test unfinalize.
    g._unsafe_unfinalize()  # pylint: disable=protected-access
    create()
1203
1204
# NOTE(skyewm): these cases test the private Graph._create_op_from_tf_operation
# method. Arguably we should only test the public APIs that depend on this
# method. However, this logic is complex and tricky, and it can be difficult to
# ascertain whether we have adequate coverage (e.g. a graph may run
# successfully if the control flow context isn't set properly, but a more
# complicated use case that might not be obvious to test will fail). Thus we
# instead explicitly test the low-level behavior.
class CreateOpFromTFOperationTest(test_util.TensorFlowTestCase):
  """Low-level tests for wrapping raw C ops as Python Operations."""

  # pylint: disable=protected-access

  def _addIntInputOp(self, g, x, name, control_deps=None):
    """Creates an IntInput C op consuming `x` and registers it with `g`.

    Args:
      g: Graph whose `_add_new_tf_operations` should pick up the new op.
      x: Tensor fed to the new op.
      name: Node name of the new op.
      control_deps: If not None, registration runs inside
        `ops.control_dependencies(control_deps)`.
    """
    ops._create_c_op(ops.get_default_graph(),
                     ops._NodeDef("IntInput", name), [x], [])
    if control_deps is None:
      new_ops = g._add_new_tf_operations()
    else:
      with ops.control_dependencies(control_deps):
        new_ops = g._add_new_tf_operations()
    self.assertEqual(len(new_ops), 1)

  @test_util.run_deprecated_v1
  def testBasic(self):
    """A wrapped op exposes the name, inputs and outputs of the C op."""
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()
      raw_op = ops._create_c_op(
          g, ops._NodeDef("IntInputIntOutput", "myop"), [x], [])
      wrapped = g._create_op_from_tf_operation(raw_op)

    self.assertEqual(wrapped.name, "myop")
    self.assertEqual(wrapped.type, "IntInputIntOutput")
    self.assertEqual(len(wrapped.outputs), 1)
    self.assertEqual(wrapped.outputs[0].shape, tensor_shape.unknown_shape())
    self.assertEqual(list(wrapped.inputs), [x])
    self.assertEqual(wrapped.control_inputs, [])
    self.assertEqual(wrapped.graph, g)
    self.assertEqual(x.consumers(), [wrapped])
    self.assertIsNotNone(wrapped.traceback)
    # The wrapped op and its output are registered with the graph.
    self.assertEqual(g.get_operation_by_name("myop"), wrapped)
    self.assertEqual(g.get_tensor_by_name("myop:0"), wrapped.outputs[0])

  def testShape(self):
    """Static shape inference runs for the wrapped op."""
    g = ops.Graph()
    with g.as_default():
      matrix = constant_op.constant([[1, 2, 3], [4, 5, 6]])
      raw_op = ops._create_c_op(
          g, ops._NodeDef("Identity", "myop"), [matrix], [])
      wrapped = g._create_op_from_tf_operation(raw_op)

    self.assertEqual(wrapped.name, "myop")
    self.assertEqual(wrapped.type, "Identity")
    self.assertEqual(len(wrapped.outputs), 1)
    self.assertEqual(wrapped.outputs[0].shape,
                     tensor_shape.TensorShape([2, 3]))

  def testUniqueName(self):
    """Names taken by wrapped C ops are avoided by later Python ops."""
    g = ops.Graph()
    with g.as_default():
      raw_ops = [ops._create_c_op(g, ops._NodeDef("IntOutput", n), [], [])
                 for n in ("myop", "myop_1")]
      wrapped = [g._create_op_from_tf_operation(r) for r in raw_ops]
      # Create ops with the same names as the wrapped ops. We expect the new
      # names to be uniquified.
      later = [test_ops.int_output(name=n).op for n in ("myop", "myop_1")]

    self.assertEqual(["myop", "myop_1", "myop_2", "myop_1_1"],
                     [op.name for op in wrapped + later])

  @test_util.run_v1_only("b/120545219")
  def testCond(self):
    """An op added inside a cond branch gets the cond's context."""
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()

      def true_fn():
        self._addIntInputOp(g, x, "cond/myop")
        return x

      control_flow_ops.cond(x < 10, true_fn, lambda: x)

    op = g.get_operation_by_name("cond/myop")
    self.assertIsNotNone(op)
    self.assertEqual(op.name, "cond/myop")
    self.assertEqual(op.type, "IntInput")
    self.assertEqual(op.outputs, [])
    # The captured input reaches the op through a Switch.
    switch = op.inputs[0].op
    self.assertEqual(switch.type, "Switch")
    self.assertEqual(switch.inputs[0], x)
    self.assertEqual(op.graph, g)
    flow_context = op._get_control_flow_context()
    self.assertIsNotNone(flow_context)
    self.assertEqual(flow_context.name, "cond/cond_text")

  @test_util.run_v1_only("b/120545219")
  def testWhileLoop(self):
    """An op added inside a while body gets the loop's context."""
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()

      def body(i):
        self._addIntInputOp(g, x, "myloop/myop")
        return i

      control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

    op = g.get_operation_by_name("myloop/myop")
    self.assertIsNotNone(op)
    self.assertEqual(op.name, "myloop/myop")
    self.assertEqual(op.type, "IntInput")
    self.assertEqual(op.outputs, [])
    # The captured input reaches the op through an Enter.
    enter = op.inputs[0].op
    self.assertEqual(enter.type, "Enter")
    self.assertEqual(list(enter.inputs), [x])
    self.assertEqual(op.graph, g)
    flow_context = op._get_control_flow_context()
    self.assertIsNotNone(flow_context)
    self.assertEqual(flow_context.name, "myloop/while_context")

  @test_util.run_v1_only("b/120545219")
  def testWhileLoopWithInternalControlDep(self):
    """A control dep on an op inside the loop body is kept as-is."""
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()

      def body(i):
        c = constant_op.constant(1.0, name="c")
        self._addIntInputOp(g, x, "myloop/myop", control_deps=[c])
        return i

      control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

    op = g.get_operation_by_name("myloop/myop")
    self.assertIsNotNone(op)
    c = g.get_operation_by_name("myloop/c")
    self.assertIsNotNone(c)
    # Internal control dep is preserved
    self.assertEqual(op.control_inputs, [c])

  @test_util.run_v1_only("b/120545219")
  def testWhileLoopWithExternalControlDep(self):
    """A control dep on an op outside the loop is replaced by an inner one."""
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()
      c = constant_op.constant(1.0)

      def body(i):
        self._addIntInputOp(g, x, "myloop/myop", control_deps=[c])
        return i

      control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

    op = g.get_operation_by_name("myloop/myop")
    self.assertIsNotNone(op)
    # External control dep is removed and replaced with internal control dep
    replacement = op.control_inputs[0]
    self.assertNotEqual(replacement, c.op)
    self.assertIsNotNone(replacement._get_control_flow_context())

  # pylint: enable=protected-access
1371
1372
class ApplyOpTest(test_util.TensorFlowTestCase):
  """Tests for the `_apply_op` helper, which returns output tensors."""

  def testNodeDefArgs(self):
    """One output yields a Tensor; several outputs yield a list of Tensors."""
    g = ops.Graph()
    t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1")
    with g.device("/device:GPU:0"):
      t2 = _apply_op(
          g, "TwoIntOutputs", [], [dtypes.int32, dtypes.int32], name="myop2")
    t3 = _apply_op(
        g,
        "Foo1", [t1, t2[1], t2[0]], [dtypes.float32, dtypes.int32],
        name="myop3")
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)).
    self.assertIsInstance(t1, ops.Tensor)
    self.assertIsInstance(t2, list)
    self.assertIsInstance(t3, list)
    self.assertIsInstance(t3[0], ops.Tensor)
    # pylint: disable=protected-access
    self.assertEqual("myop1", t1._as_node_def_input())
    self.assertEqual("myop2", t2[0]._as_node_def_input())
    self.assertEqual("myop2:1", t2[1]._as_node_def_input())
    self.assertEqual("myop3", t3[0]._as_node_def_input())
    # pylint: enable=protected-access
    # Validate that we got the right ops as well
    self.assertProtoEquals("name:'myop1' op:'FloatOutput'", t1.op.node_def)
    self.assertProtoEquals(
        "name:'myop2' op:'TwoIntOutputs' device:'/device:GPU:0'",
        t2[0].op.node_def)
    self.assertProtoEquals(
        "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo1'",
        t3[0].op.node_def)

  def testReferenceInput(self):
    """Ref-typed outputs can feed both ref and non-ref inputs."""
    g = ops.Graph()
    ref_t, nonref_t = _apply_op(
        g, "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32],
        name="op1")
    self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'",
                           ref_t.op.node_def)
    # NOTE(mrry): Must specify input_types to preserve ref-typed input.
    out_2 = _apply_op(
        g,
        "RefInputFloatInputIntOutput", [ref_t, nonref_t], [dtypes.int32],
        input_types=[dtypes.float32_ref, dtypes.float32],
        name="op2")
    self.assertProtoEquals(
        "op:'RefInputFloatInputIntOutput' name:'op2' input:'op1' input:'op1:1'",
        out_2.op.node_def)
    out_3 = _apply_op(
        g, "TwoFloatInputsIntOutput", [ref_t, nonref_t], [dtypes.int32],
        name="op3")
    self.assertProtoEquals(
        "op:'TwoFloatInputsIntOutput' name:'op3' input:'op1' input:'op1:1'",
        out_3.op.node_def)
1424
1425
class NameStackTest(test_util.TensorFlowTestCase):
  """Tests for Graph.unique_name and Graph.name_scope."""

  def _assertPeekThenTake(self, g, name, expected):
    """Asserts `name` uniquifies to `expected`, first peeking, then claiming.

    The first call passes mark_as_used=False and must return the same name as
    the second call, which actually marks it as used.
    """
    self.assertEqual(expected, g.unique_name(name, mark_as_used=False))
    self.assertEqual(expected, g.unique_name(name))

  def testBasics(self):
    g = ops.Graph()
    # An extra peek: repeated peeks never consume the name.
    self.assertEqual("foo", g.unique_name("foo", mark_as_used=False))
    for name, expected in (("foo", "foo"), ("foo", "foo_1"), ("foo", "foo_2"),
                           ("foo_1", "foo_1_1"), ("foo_1", "foo_1_2"),
                           ("foo_1_2", "foo_1_2_1")):
      self._assertPeekThenTake(g, name, expected)
    with g.name_scope("bar"):
      self._assertPeekThenTake(g, "foo", "bar/foo")
      self._assertPeekThenTake(g, "foo", "bar/foo_1")
      with g.name_scope(None):
        self._assertPeekThenTake(g, "foo", "foo_3")
      with g.name_scope("baz"):
        self._assertPeekThenTake(g, "foo", "bar/baz/foo")
        self._assertPeekThenTake(g, "foo", "bar/baz/foo_1")
      with g.name_scope("baz"):
        self._assertPeekThenTake(g, "foo", "bar/baz_1/foo")
        self._assertPeekThenTake(g, "foo", "bar/baz_1/foo_1")
    with g.name_scope("quux"):
      self._assertPeekThenTake(g, "foo", "quux/foo")
    with g.name_scope("bar"):
      with g.name_scope("baz"):
        self._assertPeekThenTake(g, "foo", "bar_1/baz/foo")
    self._assertPeekThenTake(g, "foo", "foo_4")
    self._assertPeekThenTake(g, "bar", "bar_2")

  def testBackslashAndDashRegex(self):
    # GitHub issue 39019, all should pass
    g = ops.Graph()
    odd_name = "n_CatCntc-campaign\\c_campaign"
    with g.name_scope(odd_name):
      pass
    with g.name_scope("foo"):
      with g.name_scope(odd_name):
        pass
    with g.name_scope(odd_name):
      with g.name_scope("foo"):
        pass

  @test_util.run_deprecated_v1
  def testNameAndVariableScope(self):
    with self.cached_session() as sess:
      graph = sess.graph
      with graph.name_scope("l0"):
        with variable_scope.variable_scope("l1"):
          for inner in ("l1", "l2"):
            with graph.name_scope(inner) as scope:
              self.assertEqual("l0/l1/%s/" % inner, scope)
              self._assertPeekThenTake(graph, "foo", "l0/l1/%s/foo" % inner)

  def testOutOfOrderUniqueName(self):
    g = ops.Graph()
    # Claiming "foo_2" up front makes the counter for "foo" skip it later.
    for requested, expected in (("foo_2", "foo_2"), ("foo", "foo"),
                                ("foo", "foo_1"), ("foo", "foo_3")):
      self.assertEqual(expected, g.unique_name(requested))

  def testUniqueNameCaseInsensitivity(self):
    g = ops.Graph()
    self.assertEqual("foo", g.unique_name("foo"))
    # "Foo" collides with the already-used "foo".
    self.assertEqual("Foo_1", g.unique_name("Foo"))
    for scope, expected in (("bar", "bar/foo"), ("Bar", "Bar_1/foo")):
      with g.name_scope(scope):
        self.assertEqual(expected, g.unique_name("foo"))

  def testInvalidNameRaisesError(self):
    g = ops.Graph()
    # Accepted: an empty scope, a trailing-slash scope, and a leading
    # underscore below a non-root scope.
    with g.name_scope(""):
      pass
    with g.name_scope("foo/"):
      with g.name_scope("_bar"):
        pass
    # Rejected: a ":" in the name, and a leading underscore at the root.
    for bad_name in ("foo:0", "_bar"):
      with self.assertRaises(ValueError):
        with g.name_scope(bad_name):
          pass
1544
1545
class NameTest(test_util.TensorFlowTestCase):
  """Tests automatic op naming and the prefixes produced by `name_scope`."""

  @staticmethod
  def _newFloatOutput(graph, **kwargs):
    """Adds a single-output "FloatOutput" op to `graph` and returns it."""
    return graph.create_op("FloatOutput", [], [dtypes.float32], **kwargs)

  def testGenerateName(self):
    g = ops.Graph()

    # Output tensors are named "<op name>:<output index>".
    multi = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32])
    self.assertEqual("TwoFloatOutputs", multi.name)
    for index, tensor_name in enumerate(("TwoFloatOutputs:0",
                                         "TwoFloatOutputs:1")):
      self.assertEqual(tensor_name, multi.outputs[index].name)

    # Unnamed ops are named after their type and uniquified with a numeric
    # suffix; an explicit `name` is used as given.
    cases = (
        ("FloatOutput", "FloatOutput:0", {}),
        ("FloatOutput_1", "FloatOutput_1:0", {}),
        ("my_op", "my_op:0", {"name": "my_op"}),
    )
    for op_name, tensor_name, kwargs in cases:
      op = self._newFloatOutput(g, **kwargs)
      self.assertEqual(op_name, op.name)
      self.assertEqual(tensor_name, op.outputs[0].name)

  def testNameScope(self):
    g = ops.Graph()

    # The value yielded by name_scope() is the full scope prefix ending in
    # "/"; both None and "" reset to the root scope.
    with g.name_scope("foo") as foo_scope:
      self.assertEqual("foo/", foo_scope)
      with g.name_scope("foo2") as nested_scope:
        self.assertEqual("foo/foo2/", nested_scope)
      with g.name_scope(None) as reset_scope:
        self.assertEqual("", reset_scope)
        with g.name_scope("foo3") as root_child_scope:
          self.assertEqual("foo3/", root_child_scope)
      with g.name_scope("") as cleared_scope:
        self.assertEqual("", cleared_scope)

    self.assertEqual("FloatOutput", self._newFloatOutput(g).name)
    with g.name_scope("bar") as bar_scope:
      self.assertEqual("bar/FloatOutput", self._newFloatOutput(g).name)
      self.assertEqual("bar/FloatOutput_1", self._newFloatOutput(g).name)
      # The string from "with ... as" used as an op name is taken as-is
      # (without its trailing "/").
      self.assertEqual("bar", self._newFloatOutput(g, name=bar_scope).name)
    with g.name_scope("baz") as baz_scope:
      with g.name_scope("quux"):
        self.assertEqual("baz/quux/FloatOutput", self._newFloatOutput(g).name)
      # Re-entering the enclosing scope's own string pushes no new level.
      with g.name_scope(baz_scope):
        self.assertEqual("baz/FloatOutput", self._newFloatOutput(g).name)
        self.assertEqual("baz", self._newFloatOutput(g, name=baz_scope).name)
        self.assertEqual("trailing",
                         self._newFloatOutput(g, name="trailing/").name)
    # A second "bar" scope is uniquified, whereas "bar/" re-opens the
    # original scope.
    with g.name_scope("bar"):
      self.assertEqual("bar_1/FloatOutput", self._newFloatOutput(g).name)
    with g.name_scope("bar/"):
      self.assertEqual("bar/FloatOutput_2", self._newFloatOutput(g).name)
1613
1614
class DeviceTest(test_util.TensorFlowTestCase):
  """Tests for device placement via `Graph.device` / `ops.device` scopes."""

  def testNoDevice(self):
    g = ops.Graph()
    op = g.create_op("FloatOutput", [], [dtypes.float32])
    self.assertDeviceEqual(None, op.device)
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput" }
    """, gd)

  def testEagerBackingDevice(self):
    with context.eager_mode():
      with ops.device("/device:CPU:0"):
        t = constant_op.constant(1.0)
        self.assertRegex(t.device, "/device:CPU:0")
        self.assertRegex(t.backing_device, "/device:CPU:0")

  def testDevicePartialString(self):
    g = ops.Graph()
    with g.device("/job:worker/replica:2"):
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:worker/replica:2" }
    """, gd)

  def testDeviceFull(self):
    g = ops.Graph()
    with g.device(
        pydev.DeviceSpec(
            job="worker", replica=2, task=0, device_type="CPU",
            device_index=3)):
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:worker/replica:2/task:0/device:CPU:3" }
    """, gd)

  def testNesting(self):
    # Same layout and expectations as testNestingString, but the scopes are
    # built from `DeviceSpec` objects so nesting of DeviceSpec-based scopes is
    # covered too (previously both tests used identical device strings).
    g = ops.Graph()
    with g.device(pydev.DeviceSpec(job="worker", replica=2)):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(pydev.DeviceSpec(job="worker", replica=3, task=0)):
        g.create_op("FloatOutput", [], [dtypes.float32])
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:worker/replica:2" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:worker/replica:3/task:0" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/replica:2" }
    """, gd)

  def testNestingString(self):
    g = ops.Graph()
    with g.device("/job:worker/replica:2"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/job:worker/replica:3/task:0"):
        g.create_op("FloatOutput", [], [dtypes.float32])
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:worker/replica:2" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:worker/replica:3/task:0" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/replica:2" }
    """, gd)

  def testNestingOverrideGpuCpu(self):
    g = ops.Graph()
    with g.device("/job:worker/replica:2/device:CPU:1"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/job:worker/replica:2/device:GPU:2"):
        g.create_op("FloatOutput", [], [dtypes.float32])
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:worker/replica:2/device:CPU:1"  }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:worker/replica:2/device:GPU:2" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/replica:2/device:CPU:1" }
    """, gd)

  def testNestingWithMergeDeviceFunction(self):
    g = ops.Graph()

    with g.device(pydev.merge_device("/device:GPU:0")):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(pydev.merge_device("/job:worker")):
        g.create_op("FloatOutput", [], [dtypes.float32])
        with g.device(pydev.merge_device("/device:CPU:0")):
          g.create_op("FloatOutput", [], [dtypes.float32])
          with g.device(pydev.merge_device("/job:ps")):
            g.create_op("FloatOutput", [], [dtypes.float32])
            with g.device(pydev.merge_device(None)):
              g.create_op("FloatOutput", [], [dtypes.float32])

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/device:GPU:0" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:worker/device:GPU:0" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/device:CPU:0" }
      node { name: "FloatOutput_3" op: "FloatOutput"
             device: "/job:ps/device:CPU:0" }
      node { name: "FloatOutput_4" op: "FloatOutput"
             device: "/job:ps/device:CPU:0" }
    """, gd)

  def testNestingWithDeviceStrings(self):
    g = ops.Graph()

    with g.device("/device:GPU:0"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/job:worker"):
        g.create_op("FloatOutput", [], [dtypes.float32])
        with g.device("/device:CPU:0"):
          g.create_op("FloatOutput", [], [dtypes.float32])
          with g.device("/job:ps"):
            g.create_op("FloatOutput", [], [dtypes.float32])
            with g.device(""):
              g.create_op("FloatOutput", [], [dtypes.float32])

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/device:GPU:0" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:worker/device:GPU:0" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/device:CPU:0" }
      node { name: "FloatOutput_3" op: "FloatOutput"
             device: "/job:ps/device:CPU:0" }
      node { name: "FloatOutput_4" op: "FloatOutput"
             device: "/job:ps/device:CPU:0" }
    """, gd)

  def testNestingWithDeviceStringWildcard(self):
    g = ops.Graph()

    with g.device("/device:GPU:7"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/device:GPU:*"):
        g.create_op("FloatOutput", [], [dtypes.float32])

    with g.device("/device:CPU:*"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/device:CPU:5"):
        g.create_op("FloatOutput", [], [dtypes.float32])

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/device:GPU:7" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/device:GPU:7" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/device:CPU:*" }
      node { name: "FloatOutput_3" op: "FloatOutput"
             device: "/device:CPU:5" }
    """, gd)

  def testNestingErrorGraph(self):
    # Exiting an outer device scope while an inner one is still active must
    # raise RuntimeError.
    g = ops.Graph()
    scope = g.device("/device:GPU:8")
    scope.__enter__()
    with g.device("/device:GPU:9"):
      with self.assertRaises(RuntimeError):
        scope.__exit__(None, None, None)

  def testNestingErrorEager(self):
    # Eager-mode counterpart of testNestingErrorGraph.
    with context.eager_mode():
      scope = ops.device("/device:CPU:0")
      scope.__enter__()
      with ops.device(None):
        with self.assertRaises(RuntimeError):
          scope.__exit__(None, None, None)

  def testNoneClearsDefault(self):
    g = ops.Graph()
    with g.device("/job:worker/replica:2/device:CPU:1"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(None):
        g.create_op("FloatOutput", [], [dtypes.float32])
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:worker/replica:2/device:CPU:1" }
      node { name: "FloatOutput_1" op: "FloatOutput" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/replica:2/device:CPU:1" }
    """, gd)

  def testNoneIgnoresOuterDeviceFunction(self):
    g = ops.Graph()
    with g.device(lambda op: "/job:worker/replica:2/device:CPU:1"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(None):
        g.create_op("FloatOutput", [], [dtypes.float32])
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:worker/replica:2/device:CPU:1" }
      node { name: "FloatOutput_1" op: "FloatOutput" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/replica:2/device:CPU:1" }
    """, gd)

  def _overwritingDeviceFunction(self, unused_op):
    # This device function unconditionally overwrites the device of ops.
    #
    # NOTE(mrry): Writing device functions like this is not
    # recommended. Instead, in most cases you should use
    # `pydev.merge_device("/job:ps")` or simply `"/job:ps"` as the
    # argument to `tf.device()` and the device component will be merged in.
    return "/job:overwrite"

  def testOverwritingBehavior(self):
    g = ops.Graph()
    with g.device(self._overwritingDeviceFunction):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/job:ps"):  # Will be overwritten.
        g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(pydev.merge_device("/job:ps")):  # Will be overwritten.
        g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(None):  # Disables overwriting device function
        with g.device("/job:ps"):
          g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(None):  # Disables overwriting device function
        with g.device(pydev.merge_device("/job:ps")):
          g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:overwrite" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:overwrite" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:overwrite" }
      node { name: "FloatOutput_3" op: "FloatOutput"
             device: "/job:ps" }
      node { name: "FloatOutput_4" op: "FloatOutput"
             device: "/job:ps" }
    """, gd)
1872
1873
class MultithreadedGraphStateTest(test_util.TensorFlowTestCase):
  """Tests that graph-building state pushed in one thread stays in that thread.

  Each test starts several threads. Every thread pushes some graph state (a
  device, a colocation target, control dependencies or a name scope) and then
  pauses, so that all threads hold their state at the same time. Only then
  does each thread create an op. The assertions check that every op picked up
  only the state pushed by its own thread.
  """

  class TestThread(threading.Thread):
    """Thread that mutates graph state, pauses, then adds an op.

    Subclasses implement `run()` to: push some state onto `self._graph`, set
    `has_mutated_graph`, wait for `should_continue`, and then create an op
    that is affected by the pushed state.
    """

    def __init__(self, graph, replica_id):
      super(MultithreadedGraphStateTest.TestThread, self).__init__()
      self._graph = graph
      self._replica_id = replica_id
      # This thread sets this event once it has mutated the graph. The caller
      # can wait for that.
      self.has_mutated_graph = threading.Event()
      # This thread waits on this event before continuing. The caller sets it.
      self.should_continue = threading.Event()

    def run(self):
      # Mutate a graph's stack, then set `has_mutated_graph`, then wait for
      # `should_continue`, then add an op to the graph affected by the graph's
      # stack.
      raise NotImplementedError("must be implemented in descendants")

  def _runInterleaved(self, threads):
    """Runs `threads` so that all of them hold graph state simultaneously.

    Threads are started one at a time, and each is allowed to mutate the
    graph before the next one starts. Afterwards each thread is released via
    `should_continue` and joined, in the same order.
    """
    for t in threads:
      t.start()
      t.has_mutated_graph.wait()
      t.has_mutated_graph.clear()
    for t in threads:
      t.should_continue.set()
      t.join()

  def testDeviceFunctionStack(self):

    class DeviceSettingThread(self.TestThread):

      def run(self):
        graph = self._graph
        with graph.device("/job:worker/replica:{}".format(self._replica_id)):
          self.has_mutated_graph.set()
          self.should_continue.wait()
          self.should_continue.clear()
          graph.create_op(
              "FloatOutput", [], [dtypes.float32],
              name="FloatOutput_{}".format(self._replica_id))

    g = ops.Graph()
    # If `switch_to_thread_local` isn't called, then device placement of the
    # ops below is not deterministic.
    g.switch_to_thread_local()
    threads = [DeviceSettingThread(g, i) for i in range(3)]
    self._runInterleaved(threads)

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput_0" op: "FloatOutput"
             device: "/job:worker/replica:0" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:worker/replica:1" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/replica:2" }
    """, gd)

  def testColocateWith(self):

    class ColocatingThread(self.TestThread):

      def __init__(self, graph, replica_id, op_to_colocate_with):
        super(ColocatingThread, self).__init__(graph, replica_id)
        self._op_to_colocate_with = op_to_colocate_with

      def run(self):
        graph = self._graph
        with graph.colocate_with(self._op_to_colocate_with):
          self.has_mutated_graph.set()
          self.should_continue.wait()
          self.should_continue.clear()
          graph.create_op(
              "FloatOutput", [], [dtypes.float32],
              name="FloatOutput_{}".format(self._replica_id))

    g = ops.Graph()
    ops_to_colocate_with = []
    for i in range(3):
      with g.device("/job:worker/replica:{}".format(i)):
        ops_to_colocate_with.append(
            g.create_op(
                "FloatOutput", [], [dtypes.float32],
                name="ColocateWithMe_{}".format(i)))

    # If `switch_to_thread_local` isn't called, then `device` and `attr`
    # values for the ops below are not deterministic.
    g.switch_to_thread_local()
    threads = [
        ColocatingThread(g, i, ops_to_colocate_with[i]) for i in range(3)
    ]
    self._runInterleaved(threads)

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "ColocateWithMe_0" op: "FloatOutput"
             device: "/job:worker/replica:0" }
      node { name: "ColocateWithMe_1" op: "FloatOutput"
             device: "/job:worker/replica:1" }
      node { name: "ColocateWithMe_2" op: "FloatOutput"
             device: "/job:worker/replica:2" }
      node { name: "FloatOutput_0" op: "FloatOutput"
             device: "/job:worker/replica:0"
             attr { key: "_class"
               value { list {
                 s: "loc:@ColocateWithMe_0"}}}}
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:worker/replica:1"
             attr { key: "_class"
               value { list {
                 s: "loc:@ColocateWithMe_1"}}}}
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/replica:2"
             attr { key: "_class"
               value { list {
                 s: "loc:@ColocateWithMe_2"}}}}
    """, gd)

  def testControlDependencies(self):

    class DependingThread(self.TestThread):

      def __init__(self, graph, replica_id, dependency_op):
        super(DependingThread, self).__init__(graph, replica_id)
        self._dependency_op = dependency_op

      def run(self):
        graph = self._graph
        with graph.control_dependencies([self._dependency_op]):
          self.has_mutated_graph.set()
          self.should_continue.wait()
          self.should_continue.clear()
          graph.create_op(
              "FloatOutput", [], [dtypes.float32],
              name="FloatOutput_{}".format(self._replica_id))

    g = ops.Graph()
    # NOTE: the dependency ops reuse the "ColocateWithMe_" naming from
    # testColocateWith; here they only serve as control inputs.
    dependency_ops = []
    for i in range(3):
      dependency_ops.append(
          g.create_op(
              "FloatOutput", [], [dtypes.float32],
              name="ColocateWithMe_{}".format(i)))

    # If `switch_to_thread_local` isn't called, then `input` values for the
    # ops below are not deterministic.
    g.switch_to_thread_local()
    threads = [DependingThread(g, i, dependency_ops[i]) for i in range(3)]
    self._runInterleaved(threads)

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "ColocateWithMe_0" op: "FloatOutput" }
      node { name: "ColocateWithMe_1" op: "FloatOutput" }
      node { name: "ColocateWithMe_2" op: "FloatOutput" }
      node { name: "FloatOutput_0" op: "FloatOutput"
             input: "^ColocateWithMe_0" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             input: "^ColocateWithMe_1" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             input: "^ColocateWithMe_2" }
    """, gd)

  def testNameStack(self):

    class NameSettingThread(self.TestThread):

      def run(self):
        graph = self._graph
        with graph.name_scope("foo"):
          op1 = graph.create_op("FloatOutput", [], [dtypes.float32])
          self.has_mutated_graph.set()
          self.should_continue.wait()
          self.should_continue.clear()
          op2 = graph.create_op("FloatOutput", [], [dtypes.float32])
          self.result = (op1, op2)

    g = ops.Graph()
    threads = [NameSettingThread(g, i) for i in range(3)]
    self._runInterleaved(threads)

    # Each thread's "foo" scope is uniquified, and its second op lands in its
    # own scope rather than in another thread's.
    suffixes = ["", "_1", "_2"]
    for t, s in zip(threads, suffixes):
      self.assertEqual("foo" + s + "/FloatOutput", t.result[0].name)
      self.assertEqual("foo" + s + "/FloatOutput_1", t.result[1].name)
2074
2075
class ObjectWithName(object):
  """Plain object with a read-only `name`, used as a named collection item."""

  def __init__(self, name):
    self._name = name

  # Read-only: there is no setter, so assigning `name` raises AttributeError.
  name = property(lambda self: self._name,
                  doc="The name passed to the constructor.")
2084
2085
class CollectionTest(test_util.TensorFlowTestCase):
  """Tests for graph collections (add_to_collection(s), get_collection(_ref))."""

  def test_get_collections(self):
    g = ops.Graph()
    self.assertSequenceEqual(g.collections, [])
    g.add_to_collection("key", 12)
    g.add_to_collection("key", 15)
    self.assertSequenceEqual(g.collections, ["key"])
    g.add_to_collection("other", "foo")
    self.assertSequenceEqual(sorted(g.collections), ["key", "other"])
    self.assertSequenceEqual(
        sorted(g.get_all_collection_keys()), ["key", "other"])

  def test_add_to_collection(self):
    g = ops.Graph()
    g.add_to_collection("key", 12)
    g.add_to_collection("other", "foo")
    g.add_to_collection("key", 34)

    # "blah" mixes an unnamed value with two named objects; of these, only
    # blank1 matches the scope filters used further below.
    g.add_to_collection("blah", 27)
    blank1 = ObjectWithName("prefix/foo")
    g.add_to_collection("blah", blank1)
    blank2 = ObjectWithName("junk/foo")
    g.add_to_collection("blah", blank2)

    self.assertEqual([12, 34], g.get_collection("key"))
    self.assertEqual([], g.get_collection("nothing"))
    self.assertEqual([27, blank1, blank2], g.get_collection("blah"))
    # With a scope argument only items whose `name` matches are returned; the
    # scope is treated as a regular expression (".*x" matches "prefix/foo").
    self.assertEqual([blank1], g.get_collection("blah", "prefix"))
    self.assertEqual([blank1], g.get_collection("blah", ".*x"))

    # Make sure that get_collection() returns a first-level
    # copy of the collection, while get_collection_ref() returns
    # the original list.
    other_collection_snapshot = g.get_collection("other")
    other_collection_ref = g.get_collection_ref("other")
    self.assertEqual(["foo"], other_collection_snapshot)
    self.assertEqual(["foo"], other_collection_ref)
    g.add_to_collection("other", "bar")
    self.assertEqual(["foo"], other_collection_snapshot)
    self.assertEqual(["foo", "bar"], other_collection_ref)
    self.assertEqual(["foo", "bar"], g.get_collection("other"))
    self.assertIs(other_collection_ref, g.get_collection_ref("other"))

    # Verify that getting an empty collection ref returns a modifiable list
    # that later get_collection_ref() calls hand back again.
    empty_coll_ref = g.get_collection_ref("empty")
    self.assertEqual([], empty_coll_ref)
    empty_coll = g.get_collection("empty")
    self.assertEqual([], empty_coll)
    self.assertIsNot(empty_coll, empty_coll_ref)
    empty_coll_ref2 = g.get_collection_ref("empty")
    self.assertIs(empty_coll_ref2, empty_coll_ref)
    # Appending through the ref is visible through every ref and through new
    # get_collection() calls, but not in the earlier copy.
    empty_coll_ref.append("something")
    self.assertEqual(["something"], empty_coll_ref)
    self.assertEqual(["something"], empty_coll_ref2)
    self.assertEqual([], empty_coll)
    self.assertEqual(["something"], g.get_collection("empty"))
    empty_coll_ref3 = g.get_collection_ref("empty")
    self.assertIs(empty_coll_ref3, empty_coll_ref)

  def test_add_to_collections_uniquify(self):
    g = ops.Graph()
    g.add_to_collections([1, 2, 1], "key")
    # Make sure "key" is not added twice
    self.assertEqual(["key"], g.get_collection(1))

  def test_add_to_collections_from_list(self):
    g = ops.Graph()
    g.add_to_collections(["abc", "123"], "key")
    self.assertEqual(["key"], g.get_collection("abc"))
    self.assertEqual(["key"], g.get_collection("123"))

  def test_add_to_collections_from_tuple(self):
    g = ops.Graph()
    g.add_to_collections(("abc", "123"), "key")
    self.assertEqual(["key"], g.get_collection("abc"))
    self.assertEqual(["key"], g.get_collection("123"))

  def test_add_to_collections_from_generator(self):
    g = ops.Graph()

    def generator():
      yield "abc"
      yield "123"

    g.add_to_collections(generator(), "key")
    self.assertEqual(["key"], g.get_collection("abc"))
    self.assertEqual(["key"], g.get_collection("123"))

  def test_add_to_collections_from_set(self):
    g = ops.Graph()
    g.add_to_collections(set(["abc", "123"]), "key")
    self.assertEqual(["key"], g.get_collection("abc"))
    self.assertEqual(["key"], g.get_collection("123"))

  def test_add_to_collections_from_string(self):
    # A single string is treated as one collection name, not as a sequence
    # of characters.
    g = ops.Graph()
    g.add_to_collections("abc", "key")
    self.assertEqual(["key"], g.get_collection("abc"))

  def test_default_graph(self):
    with ops.Graph().as_default():
      ops.add_to_collection("key", 90)
      ops.add_to_collection("key", 100)
      # Collections are ordered.
      self.assertEqual([90, 100], ops.get_collection("key"))

  def test_defun(self):
    # Collections seen inside a defun start from the caller's collections,
    # and additions made inside a nested defun do not leak back out.
    with context.eager_mode():

      @eager_function.defun
      def defun():
        ops.add_to_collection("int", 1)
        ops.add_to_collection("tensor", constant_op.constant(2))

        @eager_function.defun
        def inner_defun():
          self.assertEqual(ops.get_collection("int"), [1])
          three = ops.get_collection("tensor")[0] + ops.get_collection("int")[0]
          ops.add_to_collection("int", 2)
          self.assertEqual(ops.get_collection("int"), [1, 2])
          ops.add_to_collection("foo", "bar")
          self.assertEqual(ops.get_collection("foo"), ["bar"])
          return three

        self.assertEqual(ops.get_collection("int"), [1])
        three = inner_defun()
        self.assertEqual(ops.get_collection("int"), [1])
        self.assertEqual(ops.get_collection("foo"), [])
        return three

      three = defun()
      self.assertEqual(three.numpy(), 3)
2221
2222
# Register the "FloatOutput" op type as not differentiable (no gradient
# function), so gradient lookups on it do not need a registered gradient.
ops.NotDifferentiable("FloatOutput")
2224
2225
@ops.RegisterGradient("CopyOp")
def _CopyGrad(op, x_grad):  # pylint: disable=invalid-name
  """Gradient registered for "CopyOp": returns `x_grad` unchanged."""
  del op  # Unused: the gradient does not depend on the forward op.
  return x_grad
2230
2231
@ops.RegisterGradient("copy_override")
def _CopyOverrideGrad(op, x_grad):  # pylint: disable=invalid-name
  """Gradient registered as "copy_override": returns `x_grad` unchanged."""
  del op  # Unused: the gradient does not depend on the forward op.
  return x_grad
2236
2237
class RegistrationTest(test_util.TensorFlowTestCase):
  """Tests gradient registration lookup and `gradient_override_map`."""

  @test_util.run_deprecated_v1
  def testRegisterGradients(self):
    copied = test_ops.copy_op(test_ops.float_output())
    self.assertEqual(_CopyGrad, ops.get_gradient_function(copied.op))

  def testOverrideGradients(self):
    graph = ops.Graph()
    with graph.as_default():
      source = test_ops.float_output()
      # Ops created inside the map use the "copy_override" gradient.
      with graph.gradient_override_map({"CopyOp": "copy_override"}):
        overridden = test_ops.copy_op(source)
      self.assertEqual(_CopyOverrideGrad,
                       ops.get_gradient_function(overridden.op))

  def testNonExistentOverride(self):
    graph = ops.Graph()
    with graph.as_default():
      source = test_ops.float_output()
      with graph.gradient_override_map({"CopyOp": "unknown_override"}):
        overridden = test_ops.copy_op(source)
      # The missing override is only reported when the gradient is looked up.
      with self.assertRaisesRegex(LookupError, "unknown_override"):
        ops.get_gradient_function(overridden.op)
2264
2265
class ComparisonTest(test_util.TensorFlowTestCase):
  """Tests Python container membership checks on graph tensors."""

  def testMembershipAllowed(self):
    g = ops.Graph()
    t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1")
    t2 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop2")
    self.assertIsInstance(t1, ops.Tensor)
    self.assertIsInstance(t2, ops.Tensor)
    # `in` / `not in` over lists of graph tensors are expected to work.
    self.assertIn(t1, [t1])
    self.assertNotIn(t1, [t2])
2276
2277
2278class ControlDependenciesTest(test_util.TensorFlowTestCase):
2279
2280  @test_util.run_deprecated_v1
2281  def testBasic(self):
2282    g = ops.Graph()
2283    with g.as_default():
2284      # Creating unregistered ops with _apply_op() doesn't work with the C API
2285      # TODO(skyewm): address this more consistently. Possible solutions are
2286      # to use registered ops in all tests, create a way to register ops in
2287      # Python tests, or conditionally disable the op registration check in
2288      # the C API.
2289      a = constant_op.constant(1.0)
2290      b = constant_op.constant(1.0)
2291      with g.control_dependencies([a]):
2292        c = constant_op.constant(1.0)
2293        d = array_ops.identity(b)
2294        e = array_ops.identity(c)
2295
2296    self.assertEqual(c.op.control_inputs, [a.op])
2297    self.assertEqual(d.op.control_inputs, [a.op])
2298    # e should be dominated by c.
2299    self.assertEqual(e.op.control_inputs, [])
2300
2301  @test_util.run_in_graph_and_eager_modes
2302  def testEager(self):
2303    def future():
2304      future.calls += 1
2305      return constant_op.constant(2.0)
2306    future.calls = 0
2307
2308    if context.executing_eagerly():
2309      a = constant_op.constant(1.0)
2310      b = future
2311      with ops.control_dependencies([a, b]):
2312        c = constant_op.constant(3.0)
2313      self.assertEqual(future.calls, 1)
2314    else:
2315      g = ops.Graph()
2316      with g.as_default():
2317        a = constant_op.constant(1.0)
2318        b = future()
2319        with g.control_dependencies([a, b]):
2320          c = constant_op.constant(3.0)
2321      self.assertEqual(c.op.control_inputs, [a.op, b.op])
2322      self.assertEqual(future.calls, 1)
2323
2324  def testBasicWithConversion(self):
2325    g = ops.Graph()
2326    a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
2327
2328    class ConvertibleObj(object):
2329
2330      def _as_graph_element(self):
2331        return a
2332
2333    with g.control_dependencies([ConvertibleObj()]):
2334      c = _apply_op(g, "FloatOutput", [], [dtypes.float32])
2335
2336    self.assertEqual(c.op.control_inputs, [a.op])
2337
2338  def testNested(self):
2339    g = ops.Graph()
2340    a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
2341    a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
2342    a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
2343    a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
2344
2345    with g.control_dependencies([a_1, a_2, a_3, a_4]):
2346      b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
2347
2348    with g.control_dependencies([a_1]):
2349      with g.control_dependencies([a_2]):
2350        with g.control_dependencies([a_3]):
2351          with g.control_dependencies([a_4]):
2352            b_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
2353
2354    self.assertItemsEqual([a_1.op, a_2.op, a_3.op, a_4.op],
2355                          b_1.op.control_inputs)
2356    self.assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs)
2357
  def testClear(self):
    """control_dependencies(None) drops all outer deps until its scope exits.

    Dependencies opened inside a `None` scope apply normally. Leaving the
    `None` scope restores the outer dependencies that were active before it.
    """
    g = ops.Graph()
    a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    with g.control_dependencies([a_1]):
      with g.control_dependencies([a_2]):
        with g.control_dependencies(None):
          with g.control_dependencies([a_3]):
            with g.control_dependencies([a_4]):
              # deps [a_3, a_4]
              b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
            # deps = [a_3]
            b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
          # deps cleared by the None scope (no control inputs)
          b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32])
        # deps back to [a_1, a_2]
        b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
      # deps back to [a_1]
      b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
      with g.control_dependencies(None):
        # deps are cleared again
        b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs)
    self.assertItemsEqual([a_3.op], b_3.op.control_inputs)
    self.assertItemsEqual([], b_none.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs)
    self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
    self.assertItemsEqual([], b_none2.op.control_inputs)
2390
  def testComplex(self):
    """Control inputs implied by data inputs are pruned in nested scopes."""
    g = ops.Graph()

    # Usage pattern:
    # * Nodes a_i are FloatOutput ops defined at the outermost scope, and are
    #   used as control inputs for the ith nested scope.
    # * Nodes b_i are defined as TwoFloatInputsFloatOutput(a_3, a_4) at each
    #   scope.
    # * Nodes c_i are defined as TwoFloatInputsFloatOutput(a_1, b_1) at each
    #   scope.
    # * Nodes d_i are defined as TwoFloatInputsFloatOutput(b_i, c_i) at each
    #   scope.
    # * Node e_1 is a FloatOutput op; nodes e_i for i > 1 are defined as
    #   TwoFloatInputsFloatOutput(e_i-1, e_i-1).

    a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    with g.control_dependencies([a_1]):
      b_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                      [dtypes.float32])
      c_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                      [dtypes.float32])
      d_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_1, c_1],
                      [dtypes.float32])
      e_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
      with g.control_dependencies([a_2]):
        b_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                        [dtypes.float32])
        c_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                        [dtypes.float32])
        d_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_2, c_2],
                        [dtypes.float32])
        e_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_1, e_1],
                        [dtypes.float32])
        with g.control_dependencies([a_3]):
          b_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                          [dtypes.float32])
          c_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                          [dtypes.float32])
          d_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_3, c_3],
                          [dtypes.float32])
          e_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_2, e_2],
                          [dtypes.float32])
          with g.control_dependencies([a_4]):
            b_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                            [dtypes.float32])
            c_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                            [dtypes.float32])
            d_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_4, c_4],
                            [dtypes.float32])
            e_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_3, e_3],
                            [dtypes.float32])

    # The assertions below pin down two pruning rules. An active scope's
    # dependency is not added as a control input when:
    #   (1) that dependency is itself one of the op's data inputs
    #       (e.g. a_3 for b_3, and a_3/a_4 for b_4); or
    #   (2) one of the op's data inputs was created inside that scope
    #       (e.g. every d_i, and e_i keeping only its innermost a_i).
    self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_2.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_3.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_4.op.control_inputs)

    self.assertItemsEqual([], c_1.op.control_inputs)
    self.assertItemsEqual([a_2.op], c_2.op.control_inputs)
    self.assertItemsEqual([a_2.op, a_3.op], c_3.op.control_inputs)
    self.assertItemsEqual([a_2.op, a_3.op, a_4.op], c_4.op.control_inputs)

    self.assertItemsEqual([], d_1.op.control_inputs)
    self.assertItemsEqual([], d_2.op.control_inputs)
    self.assertItemsEqual([], d_3.op.control_inputs)
    self.assertItemsEqual([], d_4.op.control_inputs)

    self.assertItemsEqual([a_1.op], e_1.op.control_inputs)
    self.assertItemsEqual([a_2.op], e_2.op.control_inputs)
    self.assertItemsEqual([a_3.op], e_3.op.control_inputs)
    self.assertItemsEqual([a_4.op], e_4.op.control_inputs)
2462
2463  def testRepeatedDependency(self):
2464    g = ops.Graph()
2465    a = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32])
2466    a_0, a_1 = a.outputs
2467    with g.control_dependencies([a_0]):
2468      b = _apply_op(g, "FloatOutput", [], [dtypes.float32])
2469      with g.control_dependencies([a_1]):
2470        c = _apply_op(g, "FloatOutput", [], [dtypes.float32])
2471
2472    self.assertEqual(b.op.control_inputs, [a])
2473    self.assertEqual(c.op.control_inputs, [a])
2474
2475  def testNoControlDependencyWithDataDependency(self):
2476    g = ops.Graph()
2477    a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
2478    with g.control_dependencies([a]):
2479      b = _apply_op(g, "Identity", [a], [dtypes.float32])
2480
2481    self.assertEqual(b.op.control_inputs, [])
2482
2483
class OpScopeTest(test_util.TensorFlowTestCase):
  """Tests for how `ops.name_scope` / `ops.name_scope_v2` resolve names."""

  @test_util.run_in_graph_and_eager_modes
  def testNames(self):
    """Checks relative, absolute ("/"-terminated) and reset scope names."""
    with ops.name_scope("foo", skip_on_eager=False) as foo:
      self.assertEqual("foo/", foo)
      with ops.name_scope("foo2", skip_on_eager=False) as foo2:
        self.assertEqual("foo/foo2/", foo2)
      # A None name resets to the root scope; children nest from the root.
      with ops.name_scope(None, skip_on_eager=False) as empty1:
        self.assertEqual("", empty1)
        with ops.name_scope("foo3", skip_on_eager=False) as foo3:
          self.assertEqual("foo3/", foo3)
      # An empty-string name also yields the root scope.
      with ops.name_scope("", skip_on_eager=False) as empty2:
        self.assertEqual("", empty2)
    # A name ending in "/" is used verbatim, as an absolute scope.
    with ops.name_scope("foo/", skip_on_eager=False) as outer_foo:
      self.assertEqual("foo/", outer_foo)
      with ops.name_scope("", skip_on_eager=False) as empty3:
        self.assertEqual("", empty3)
      with ops.name_scope("foo4", skip_on_eager=False) as foo4:
        self.assertEqual("foo/foo4/", foo4)
      with ops.name_scope("foo5//", skip_on_eager=False) as foo5:
        self.assertEqual("foo5//", foo5)
        with ops.name_scope("foo6", skip_on_eager=False) as foo6:
          self.assertEqual("foo5//foo6/", foo6)
      with ops.name_scope("/", skip_on_eager=False) as foo7:
        self.assertEqual("/", foo7)
      with ops.name_scope("//", skip_on_eager=False) as foo8:
        self.assertEqual("//", foo8)
      # Inner "//" does not make a name absolute; only a trailing "/" does.
      with ops.name_scope("a//b/c", skip_on_eager=False) as foo9:
        self.assertEqual("foo/a//b/c/", foo9)
    with ops.name_scope("a//b/c", skip_on_eager=False) as foo10:
      self.assertEqual("a//b/c/", foo10)

  @test_util.run_in_graph_and_eager_modes
  def testEagerDefaultScopeName(self):
    """With name=None, the default_name argument names the scope."""
    with ops.name_scope(None, "default", skip_on_eager=False) as scope:
      self.assertEqual(scope, "default/")
      with ops.name_scope(None, "default2", skip_on_eager=False) as scope2:
        self.assertEqual(scope2, "default/default2/")

  @test_util.run_in_graph_and_eager_modes
  def testNameScopeV2IsReEntrant(self):
    """One name_scope_v2 object can be entered repeatedly, even nested.

    Each entry nests relative to the scope active at the time of entry.
    """
    foo = ops.name_scope_v2("foo")
    bar = ops.name_scope_v2("bar")
    with foo as scope_name:
      self.assertEqual("foo/", scope_name)
      with foo as scope_name:
        self.assertEqual("foo/foo/", scope_name)
      with bar as scope_name:
        self.assertEqual("foo/bar/", scope_name)
        with foo as scope_name:
          self.assertEqual("foo/bar/foo/", scope_name)
    with bar as scope_name:
      self.assertEqual("bar/", scope_name)

  @test_util.run_deprecated_v1
  def testNoScopeName(self):
    """name=None with no default_name raises ValueError when values are given."""
    g0 = ops.Graph()
    values = [
        g0.create_op("A", [], [dtypes.float32]),
        g0.create_op("B", [], [dtypes.float32])
    ]
    with self.assertRaises(ValueError):
      with ops.name_scope(None, values=values):
        pass
    with self.assertRaises(ValueError):
      with ops.name_scope(None, None, values):
        pass

  @test_util.run_deprecated_v1
  def testEmptyScopeName(self):
    """An empty name yields "" and makes the values' graph the default."""
    g0 = ops.Graph()
    a = g0.create_op("A", [], [dtypes.float32])
    b = g0.create_op("B", [], [dtypes.float32])
    with ops.name_scope("", values=[a, b]) as scope:
      self.assertEqual("", scope)
      self.assertEqual(g0, ops.get_default_graph())
    # "" takes precedence over the default_name argument.
    with ops.name_scope("", "my_default_scope", [a, b]) as scope:
      self.assertEqual("", scope)
      self.assertEqual(g0, ops.get_default_graph())

  @test_util.run_deprecated_v1
  def testDefaultScopeName(self):
    """default_name is used only when name is None."""
    g0 = ops.Graph()
    a = g0.create_op("A", [], [dtypes.float32])
    b = g0.create_op("B", [], [dtypes.float32])
    scope_name = "my_scope"
    default_scope_name = "my_default_scope"
    with ops.name_scope(scope_name, default_scope_name, [a, b]) as scope:
      self.assertEqual("%s/" % scope_name, scope)
      self.assertEqual(g0, ops.get_default_graph())
    with ops.name_scope(None, default_scope_name, [a, b]) as scope:
      self.assertEqual("%s/" % default_scope_name, scope)
      self.assertEqual(g0, ops.get_default_graph())
    # Passing the values list positionally as default_name is a TypeError.
    with self.assertRaises(TypeError):
      with ops.name_scope(scope_name, [a, b]):
        pass

  def _testGraphElements(self, graph_elements):
    """Checks name_scope with `graph_elements` as `values`.

    The scope name is applied and the elements' graph becomes the default.
    Adding an op from a different graph raises ValueError.
    """
    scope_name = "my_scope"
    with ops.name_scope(scope_name, values=graph_elements) as scope:
      self.assertEqual("%s/" % scope_name, scope)
      self.assertEqual(graph_elements[0].graph, ops.get_default_graph())
    g1 = ops.Graph()
    a = g1.create_op("A", [], [dtypes.float32])
    with self.assertRaises(ValueError):
      with ops.name_scope(scope_name, values=graph_elements + [a]):
        pass

  @test_util.run_deprecated_v1
  def testTensor(self):
    g0 = ops.Graph()
    a = g0.create_op("A", [], [dtypes.float32])
    b = g0.create_op("B", [], [dtypes.float32])
    self._testGraphElements([a, b])

  @test_util.run_deprecated_v1
  def testSparseTensor(self):
    g0 = ops.Graph()
    a = g0.create_op("A", [], [dtypes.float32])
    b = g0.create_op("B", [], [dtypes.float32])
    sparse = sparse_tensor.SparseTensor(
        _apply_op(g0, "Int64Output", [], [dtypes.int64]),
        _apply_op(g0, "FloatOutput", [], [dtypes.float32]),
        _apply_op(g0, "Int64Output", [], [dtypes.int64]))
    self._testGraphElements([a, sparse, b])

  @test_util.run_deprecated_v1
  def testVariable(self):
    g0 = ops.Graph()
    with g0.as_default():
      variable = variables.Variable([1.0])
    a = g0.create_op("A", [], [dtypes.float32])
    b = g0.create_op("B", [], [dtypes.float32])
    self._testGraphElements([a, variable, b])
2619
2620
class InitScopeTest(test_util.TensorFlowTestCase):
  """Tests for `ops.init_scope`.

  The tests below check that init_scope:
    - lifts op creation out of function-building graphs;
    - clears control dependencies;
    - preserves devices and name scopes.
  """

  def testClearsControlDependencies(self):
    g = ops.Graph()
    a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    with g.as_default():
      with g.control_dependencies([a_1]):
        with g.control_dependencies([a_2]):
          with ops.init_scope():
            with g.control_dependencies([a_3]):
              with g.control_dependencies([a_4]):
                # deps [a_3, a_4]
                b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
              # deps = [a_3]
              b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
            # deps cleared by init_scope (no control inputs)
            b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32])
          # deps back to [a_1, a_2]
          b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
        # deps back to [a_1]
        b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
        with ops.init_scope():
          # deps are cleared again
          b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs)
    self.assertItemsEqual([a_3.op], b_3.op.control_inputs)
    self.assertItemsEqual([], b_none.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs)
    self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
    self.assertItemsEqual([], b_none2.op.control_inputs)

  def testLiftsOpsFromFunctions(self):
    """Ops go to the outermost graph that is not building a function."""
    g0 = ops.Graph()
    g1 = ops.Graph()
    g1._building_function = True  # pylint: disable=protected-access
    g2 = ops.Graph()
    g2._building_function = True  # pylint: disable=protected-access

    with g0.as_default():
      with g1.as_default():
        with g2.as_default():
          with ops.init_scope():
            _ = constant_op.constant(1.0)

    self.assertEqual(len(g2.get_operations()), 0)
    self.assertEqual(len(g1.get_operations()), 0)
    self.assertEqual(len(g0.get_operations()), 1)

  def testPreservesDevices(self):
    g0 = ops.Graph()
    with g0.as_default(), ops.device("CPU:0"):
      g1 = ops.Graph()
      g1._building_function = True  # pylint: disable=protected-access
      with g1.as_default():
        with ops.device("GPU:0"):
          with ops.init_scope():
            # init_scope should preserve device set under `g1`.
            on_gpu = constant_op.constant(1.0)
            self.assertEqual(on_gpu.device, "/device:GPU:0")
          still_on_gpu = constant_op.constant(1.0)
          self.assertEqual(still_on_gpu.device, "/device:GPU:0")
        blank = constant_op.constant(1.0)
        self.assertEqual(blank.device, "")
        with ops.init_scope():
          now_on_cpu = constant_op.constant(1.0)
          self.assertEqual(now_on_cpu.device, "/device:CPU:0")
      on_cpu = constant_op.constant(1.0)
      self.assertEqual(on_cpu.device, "/device:CPU:0")

  def testComposes(self):
    g0 = ops.Graph()
    g1 = ops.Graph()
    g1._building_function = True  # pylint: disable=protected-access
    g2 = ops.Graph()
    g2._building_function = True  # pylint: disable=protected-access
    g3 = ops.Graph()
    g3._building_function = False  # pylint: disable=protected-access

    with g0.as_default():
      with g1.as_default():
        with ops.init_scope():
          # This op should be lifted into g0.
          _ = constant_op.constant(1.0)
          self.assertIs(g0, ops.get_default_graph())
          self.assertEqual(len(g2.get_operations()), 0)
          self.assertEqual(len(g1.get_operations()), 0)
          self.assertEqual(len(g0.get_operations()), 1)
        with g2.as_default():
          with ops.init_scope():
            # This op should be lifted into g0.
            _ = constant_op.constant(1.0)
            self.assertIs(g0, ops.get_default_graph())
            with g3.as_default():
              with ops.init_scope():
                # This op should be lifted into g3, because g3 is not building a
                # function.
                _ = constant_op.constant(1.0)
                self.assertIs(g3, ops.get_default_graph())

    self.assertEqual(len(g3.get_operations()), 1)
    self.assertEqual(len(g2.get_operations()), 0)
    self.assertEqual(len(g1.get_operations()), 0)
    self.assertEqual(len(g0.get_operations()), 2)

  def testEscapesToEagerContext(self):
    g = ops.Graph()
    g._building_function = True  # pylint: disable=protected-access
    with context.eager_mode():
      with context.graph_mode():
        with g.as_default():
          with ops.init_scope():
            # Because g is building a function, init_scope should
            # escape out to the eager context.
            self.assertTrue(context.executing_eagerly())
          # g should be reinstated as the default graph, and the
          # graph context should be re-entered.
          self.assertIs(g, ops.get_default_graph())
          self.assertFalse(context.executing_eagerly())

  def testStaysInEagerWhenOnlyEagerContextActive(self):
    # Assert on `executing_eagerly()`: `context.eager_mode()` returns a
    # context-manager object, which is always truthy, so asserting on it
    # could never fail.
    with context.eager_mode():
      with ops.init_scope():
        self.assertTrue(context.executing_eagerly())
      self.assertTrue(context.executing_eagerly())

  def testEscapesDefunWhenInEagerMode(self):

    def function_with_variables():
      with ops.init_scope():
        self.v = resource_variable_ops.ResourceVariable(3)
      return self.v.assign_add(1)

    with context.eager_mode():
      # Each invocation of function_with_variables recreates a variable.
      self.assertEqual(4, int(function_with_variables()))
      self.assertEqual(4, int(function_with_variables()))

      compiled = eager_function.defun(function_with_variables)
      # The init_scope in function_with_variables lifts the variable out
      # of the graph function constructed by defun; hence,
      # compiled now appears to be stateful.
      self.assertEqual(4, int(compiled()))
      self.assertEqual(5, int(compiled()))

  def testEscapesDefunWhenInGraphMode(self):
    def function_with_variables(name):
      with ops.init_scope():
        _ = variable_scope.get_variable(name, shape=(1,))

    g = ops.Graph()
    with g.as_default():
      with self.cached_session():
        # First ensure that graphs that are not building functions are
        # not escaped.
        function_with_variables("foo")
        with self.assertRaisesRegex(ValueError,
                                    r"Variable foo already exists.*"):
          # This will fail because reuse is not set to True.
          function_with_variables("foo")

        compiled = eager_function.defun(function_with_variables)
        compiled("bar")
        self.assertEqual(
            len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2)

        # The second call to `compiled` should not create variables: the
        # init_scope has lifted the variable creation code out of the defun.
        compiled("bar")
        self.assertEqual(
            len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2)

  def testEscapesNestedDefun(self):

    def inner_function():
      with ops.init_scope():
        self.v = resource_variable_ops.ResourceVariable(1)
      return self.v.assign_add(2)

    def outer_function(inner=None):
      with ops.init_scope():
        self.v0 = resource_variable_ops.ResourceVariable(0)
      return self.v0.assign_add(1) + inner()

    with context.eager_mode():
      # Each invocation of outer_function recreates variables.
      self.assertEqual(4, int(outer_function(inner=inner_function)))
      self.assertEqual(4, int(outer_function(inner=inner_function)))

      compiled_inner = eager_function.defun(inner_function)
      compiled_outer = eager_function.defun(outer_function)
      # The init_scope lifts variables out of the graph functions
      # constructed by defun; hence, compiled_outer should now appear to be
      # stateful.
      self.assertEqual(4, int(compiled_outer(inner=compiled_inner)))
      self.assertEqual(7, int(compiled_outer(inner=compiled_inner)))

  @test_util.run_v1_only("b/120545219")
  def testFallsBackToGlobalGraphWhenAllGraphsAreBuildingFunctions(self):
    with context.graph_mode():
      ops.reset_default_graph()
      # This doesn't push anything onto the graph stack, but it does
      # set the stack's global graph.
      global_graph = ops.get_default_graph()
      fn_graph = ops.Graph()

      # pylint: disable=protected-access
      fn_graph._building_function = True
      self.assertEqual(len(ops._default_graph_stack.stack), 0)
      with fn_graph.as_default():
        self.assertEqual(len(ops._default_graph_stack.stack), 1)
        with ops.init_scope():
          self.assertGreater(len(ops._default_graph_stack.stack), 1)
          dummy = constant_op.constant(1.0)
        self.assertEqual(len(ops._default_graph_stack.stack), 1)
      # Note that the global graph is _not_ on the graph stack.
      self.assertEqual(len(ops._default_graph_stack.stack), 0)
      # Ensure that `dummy` was added to the global graph.
      self.assertEqual(global_graph, dummy.graph)
      # pylint: enable=protected-access

  def testInstallsDefaultGraphWhenGraphStackIsEmptyInGraphMode(self):
    with context.graph_mode():
      # pylint: disable=protected-access
      self.assertEqual(len(ops._default_graph_stack.stack), 0)
      with ops.init_scope():
        self.assertGreater(len(ops._default_graph_stack.stack), 0)
      self.assertEqual(len(ops._default_graph_stack.stack), 0)
      # pylint: enable=protected-access

  def testPreservesNameScopeInGraphConstruction(self):
    with ops.Graph().as_default():
      function_graph = ops.Graph()
      with function_graph.as_default():
        with ops.name_scope("inner", skip_on_eager=False), ops.init_scope():
          self.assertEqual(ops.get_name_scope(), "inner")
      self.assertEqual(ops.get_name_scope(), "")

  def testEnteringGraphFromEagerIsSticky(self):
    with context.eager_mode():
      g = ops.Graph()
      with g.as_default():
        with ops.init_scope():
          self.assertFalse(context.executing_eagerly())
          self.assertEqual(g, ops.get_default_graph())

  def testMixGraphEager(self):
    with context.eager_mode():
      c = constant_op.constant(1.0)
      with ops.Graph().as_default():
        with self.assertRaisesRegex(RuntimeError,
                                    "Attempting to capture an EagerTensor"):
          math_ops.add(c, c)
        c2 = constant_op.constant(2.0)
      with self.assertRaisesRegex(TypeError, "Graph tensors"):
        math_ops.add(c2, c2)

  def testPreservesNameScopeInEagerExecution(self):
    with context.eager_mode():
      def foo():
        with ops.name_scope("inner", skip_on_eager=False), ops.init_scope():
          if context.executing_eagerly():
            # A trailing slash is always appended when eager execution is
            # enabled.
            self.assertEqual(context.context().scope_name, "inner/")
          else:
            self.assertEqual(ops.get_name_scope(), "inner")

      foo()
      self.assertEqual(ops.get_name_scope(), "")
      foo_compiled = eager_function.defun(foo)
      foo_compiled()
      self.assertEqual(ops.get_name_scope(), "")

  def testExecutingEagerlyOutsideFunctions(self):

    @def_function.function
    def f():
      return ops.executing_eagerly_outside_functions()

    with context.graph_mode():
      self.assertFalse(ops.executing_eagerly_outside_functions())
      with session.Session():
        # Need self.evaluate for these as the return type of functions is
        # tensors.
        self.assertFalse(self.evaluate(f()))

    with context.eager_mode():
      self.assertTrue(ops.executing_eagerly_outside_functions())
      self.assertTrue(f())

      with ops.Graph().as_default():
        self.assertFalse(ops.executing_eagerly_outside_functions())
        with session.Session():
          self.assertFalse(self.evaluate(f()))
2921
class GraphTest(test_util.TensorFlowTestCase):
  """Tests for default-graph management and `ops.Graph` utilities."""

  def setUp(self):
    # Run the base-class per-test setup before resetting the default graph;
    # the original override skipped it entirely.
    super(GraphTest, self).setUp()
    ops.reset_default_graph()

  def _AssertDefault(self, expected):
    """Asserts that `expected` is the current default graph."""
    self.assertIs(expected, ops.get_default_graph())

  def testResetDefaultGraphNesting(self):
    g0 = ops.Graph()
    with self.assertRaises(AssertionError):
      with g0.as_default():
        ops.reset_default_graph()

  def testGraphContextManagerCancelsEager(self):
    with context.eager_mode():
      with ops.Graph().as_default():
        self.assertFalse(context.executing_eagerly())

  def testGraphContextManager(self):
    g0 = ops.Graph()
    with g0.as_default() as g1:
      self.assertIs(g0, g1)

  def testDefaultGraph(self):
    orig = ops.get_default_graph()
    self.assertFalse(ops.has_default_graph())
    self._AssertDefault(orig)
    g0 = ops.Graph()
    self.assertFalse(ops.has_default_graph())
    self._AssertDefault(orig)
    # Creating the context manager does not install the graph; entering it does.
    context_manager_0 = g0.as_default()
    self.assertFalse(ops.has_default_graph())
    self._AssertDefault(orig)
    with context_manager_0 as g0:
      self._AssertDefault(g0)
      with ops.Graph().as_default() as g1:
        self.assertTrue(ops.has_default_graph())
        self._AssertDefault(g1)
      self._AssertDefault(g0)
    self._AssertDefault(orig)
    self.assertFalse(ops.has_default_graph())

  def testPreventFeeding(self):
    g = ops.Graph()
    a = constant_op.constant(2.0)
    self.assertTrue(g.is_feedable(a))
    g.prevent_feeding(a)
    self.assertFalse(g.is_feedable(a))

  @test_util.run_deprecated_v1
  def testPreventFetching(self):
    g = ops.Graph()
    a = constant_op.constant(2.0)
    self.assertTrue(g.is_fetchable(a))
    g.prevent_fetching(a.op)
    self.assertFalse(g.is_fetchable(a))

  def testAsGraphElementConversions(self):

    class ConvertibleObj(object):

      def _as_graph_element(self):
        return "FloatOutput:0"

    class NonConvertibleObj(object):

      pass

    g = ops.Graph()
    a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    self.assertEqual(a, g.as_graph_element(ConvertibleObj()))
    with self.assertRaises(TypeError):
      g.as_graph_element(NonConvertibleObj())

  # Regression test against creating custom __del__ functions in classes
  # involved in cyclic references, e.g. Graph and Operation. (Before PEP 442 /
  # Python 3.4, Python would not gc cycles that require calling a __del__
  # method, because the __del__ method can theoretically increase the
  # object's refcount to "save" it from gc, and any already-deleted objects
  # in the cycle would have to be restored.)
  def testGarbageCollected(self):
    # Create a graph we can delete and a weak reference to monitor if it's gc'd
    g = ops.Graph()
    g_ref = weakref.ref(g)
    # Create some ops
    with g.as_default():
      a = constant_op.constant(2.0)
      b = constant_op.constant(3.0)
      c = math_ops.add(a, b)
    # Create a session we can delete
    with session.Session(graph=g) as sess:
      self.evaluate(c)
    # Delete all references and trigger gc
    del g
    del a
    del b
    del c
    del sess
    gc.collect()
    self.assertIsNone(g_ref())

  def testRunnableAfterInvalidShape(self):
    with ops.Graph().as_default():
      with self.assertRaises(ValueError):
        math_ops.add([1, 2], [1, 2, 3])
      a = constant_op.constant(1)
      # The session only needs to be the default session for self.evaluate.
      with session.Session():
        self.evaluate(a)

  def testRunnableAfterInvalidShapeWithKernelLabelMap(self):
    g = ops.Graph()
    with g.as_default():
      with g._kernel_label_map({"KernelLabelRequired": "overload_1"}):  # pylint: disable=protected-access
        with self.assertRaises(ValueError):
          test_ops.kernel_label_required(1)
      a = constant_op.constant(1)
      # The session only needs to be the default session for self.evaluate.
      with session.Session():
        self.evaluate(a)
3040
3041
class AttrScopeTest(test_util.TensorFlowTestCase):
  """Tests that `Graph._attr_scope` sets and clears attrs on new ops."""

  def _get_test_attrs(self):
    """Builds a no-op and returns its `_A` / `_B` attrs as text, or None."""
    probe = control_flow_ops.no_op()
    found = []
    for attr_name in ("_A", "_B"):
      try:
        found.append(compat.as_text(probe.get_attr(attr_name)))
      except ValueError:
        # A missing attr surfaces as ValueError here; record it as None.
        found.append(None)
    return tuple(found)

  @test_util.run_deprecated_v1
  def testNoLabel(self):
    with self.cached_session():
      self.assertAllEqual((None, None), self._get_test_attrs())

  @test_util.run_deprecated_v1
  def testLabelMap(self):
    with self.cached_session() as sess:
      graph = sess.graph
      observed = [self._get_test_attrs()]
      with graph._attr_scope({
          "_A": attr_value_pb2.AttrValue(s=compat.as_bytes("foo"))
      }):
        observed.append(self._get_test_attrs())
        # Mapping "_A" to None removes it inside this scope.
        with graph._attr_scope({
            "_A": None,
            "_B": attr_value_pb2.AttrValue(s=compat.as_bytes("bar"))
        }):
          observed.append(self._get_test_attrs())
          with graph._attr_scope({
              "_A": attr_value_pb2.AttrValue(s=compat.as_bytes("baz"))
          }):
            observed.append(self._get_test_attrs())
          observed.append(self._get_test_attrs())
        observed.append(self._get_test_attrs())
      observed.append(self._get_test_attrs())

      expected = [
          (None, None),
          ("foo", None),
          (None, "bar"),
          ("baz", "bar"),
          (None, "bar"),
          ("foo", None),
          (None, None),
      ]
      for want, got in zip(expected, observed):
        self.assertAllEqual(want, got)
3089
3090
class KernelLabelTest(test_util.TensorFlowTestCase):
  """Tests that `Graph._kernel_label_map` selects the labelled kernel."""

  @test_util.run_deprecated_v1
  def testNoLabel(self):
    with self.cached_session():
      label = test_ops.kernel_label()
      self.assertAllEqual(b"My label is: default", label.eval())

  @test_util.run_deprecated_v1
  def testLabelMap(self):
    with self.cached_session() as sess:
      graph = sess.graph
      before = test_ops.kernel_label()
      # pylint: disable=protected-access
      with graph._kernel_label_map({"KernelLabel": "overload_1"}):
        first_outer = test_ops.kernel_label()
        with graph._kernel_label_map({"KernelLabel": "overload_2"}):
          second = test_ops.kernel_label()
          # An empty label falls back to the default kernel.
          with graph._kernel_label_map({"KernelLabel": ""}):
            reset = test_ops.kernel_label()
        first_again = test_ops.kernel_label()
      # pylint: enable=protected-access
      after = test_ops.kernel_label()

      expectations = (
          (b"My label is: default", before),
          (b"My label is: default", reset),
          (b"My label is: default", after),
          (b"My label is: overload_1", first_outer),
          (b"My label is: overload_1", first_again),
          (b"My label is: overload_2", second),
      )
      for expected, label_tensor in expectations:
        self.assertAllEqual(expected, self.evaluate(label_tensor))
3122
3123
class AsGraphDefTest(test_util.TensorFlowTestCase):
  """Tests for `Graph.as_graph_def` and GraphDef version plumbing."""

  def testGraphDefVersion(self):
    """Test that the graphdef version is plumbed through to kernels."""
    with ops.Graph().as_default() as g:
      producer_version = g.graph_def_versions.producer
      with self.session(graph=g):
        kernel_version = test_ops.graph_def_version().eval()
        self.assertEqual(producer_version, kernel_version)

  def testAddShapes(self):
    """Checks `add_shapes=True` emits `_output_shapes` attrs for each node."""
    # Static shapes given, in order, to the five FiveFloatOutputs outputs.
    shapes = [None, [], [None], [43, 37], [43, None]]
    with ops.Graph().as_default() as g:
      outputs = _apply_op(g, "FiveFloatOutputs", [], [dtypes.float32] * 5)
      for tensor, shape in zip(outputs, shapes):
        tensor.set_shape(shape)

      # Adds the scalar "Const" node expected in the GraphDef below.
      constant_op.constant(1.0)

      gd = g.as_graph_def(add_shapes=True)
      self.assertProtoEqualsVersion("""
      node { name: "FiveFloatOutputs" op: "FiveFloatOutputs"
        attr {
          key: "_output_shapes"
          value {
            list {
              shape { unknown_rank: true }
              shape { }
              shape { dim { size: -1 } }
              shape { dim { size: 43 } dim { size: 37 } }
              shape { dim { size: 43 } dim { size: -1 } }
            }
          }
        }
      }
    node { name: "Const" op: "Const"
      attr {
        key: "_output_shapes"
        value {
          list {
            shape { }
          }
        }
      }
      attr {
        key: "dtype"
        value { type: DT_FLOAT }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_FLOAT
            tensor_shape { }
         float_val: 1.0  } } } }
      """, gd)
3183
3184
@ops.RegisterStatistics("a", "flops")
def _calc_a_forward_flops(unused_graph, unused_node):
  """Test statistic for op "a": always reports 20 "flops"."""
  flops = ops.OpStats("flops", 20)
  return flops
3188
3189
class StatisticsTest(test_util.TensorFlowTestCase):
  """Tests for `ops.get_stats_for_node_def` and `ops.OpStats`."""

  def testRegisteredNode(self):
    # `_calc_a_forward_flops` is registered for op "a" / statistic "flops".
    graph = ops.Graph()
    node = ops._NodeDef("a", "an_a")  # pylint: disable=protected-access
    flops = ops.get_stats_for_node_def(graph, node, "flops")
    self.assertEqual(20, flops.value)
    # A statistic with no registered function comes back with value None.
    missing_stat = ops.get_stats_for_node_def(graph, node, "missing_stat")
    self.assertIsNone(missing_stat.value)

  def testUnregisteredNode(self):
    graph = ops.Graph()
    node = ops._NodeDef("b", "a_b")  # pylint: disable=protected-access
    weight_params = ops.get_stats_for_node_def(graph, node, "weight_params")
    self.assertIsNone(weight_params.value)

  def testAccumulateStatistics(self):
    # An OpStats built without a value starts at None...
    flops_total = ops.OpStats("flops")
    self.assertIsNone(flops_total.value)
    # ...and `+=` with a valued OpStats takes on that value.
    second_flops = ops.OpStats("flops", 3)
    flops_total += second_flops
    self.assertEqual(3, flops_total.value)
3212
3213
class DeviceStackTest(test_util.TensorFlowTestCase):
  """Tests the device-assignment metadata that ops record."""

  @test_util.run_deprecated_v1
  def testBasicDeviceAssignmentMetadata(self):
    """Checks recorded device specs and call-site info for nested scopes."""

    def device_func(unused_op):
      return "/cpu:*"

    const_zero = constant_op.constant([0.0], name="zero")
    with ops.device("/cpu"):
      const_one = constant_op.constant([1.0], name="one")
      with ops.device("/cpu:0"):
        const_two = constant_op.constant([2.0], name="two")
    with ops.device(device_func):
      const_three = constant_op.constant(3.0, name="three")

    # pylint: disable=protected-access
    # No enclosing device scope: nothing recorded.
    self.assertEqual(0, len(const_zero.op._device_assignments))

    one_list = const_one.op._device_assignments
    self.assertEqual(1, len(one_list))
    self.assertEqual("/cpu", one_list[0].obj)
    self.assertEqual("ops_test.py", os.path.basename(one_list[0].filename))

    # Both nested scopes are recorded.
    two_list = const_two.op._device_assignments
    self.assertEqual(2, len(two_list))
    devices = [t.obj for t in two_list]
    self.assertEqual(set(["/cpu", "/cpu:0"]), set(devices))

    # A device function is recorded as a description string containing the
    # function name, this file's name and a line number.
    three_list = const_three.op._device_assignments
    # pylint: enable=protected-access
    self.assertEqual(1, len(three_list))
    func_description = three_list[0].obj
    expected_regex = r"device_func<.*ops_test\.py, [0-9]+"
    self.assertRegex(func_description, expected_regex)

  @test_util.run_deprecated_v1
  def testDeviceAssignmentMetadataForGraphDeviceAndTfDeviceFunctions(self):
    """Checks ops.device and Graph.device record the same call-site info."""
    # NOTE: the `lineno + 2` assertion below relies on these two device
    # scopes being opened exactly two lines apart; keep that spacing.
    with ops.device("/cpu"):
      const_one = constant_op.constant([1.0], name="one")
    with ops.get_default_graph().device("/cpu"):
      const_two = constant_op.constant([2.0], name="two")

    # pylint: disable=protected-access
    one_metadata = const_one.op._device_assignments[0]
    two_metadata = const_two.op._device_assignments[0]
    # pylint: enable=protected-access

    # Verify both types of device assignment return the right stack info.
    # (The filename is the value under test, not a regex, so compare it
    # exactly instead of passing it to assertRegex as the pattern.)
    self.assertEqual("ops_test.py", os.path.basename(one_metadata.filename))
    self.assertEqual(one_metadata.filename, two_metadata.filename)
    self.assertEqual(one_metadata.lineno + 2, two_metadata.lineno)
3263
3264
class ColocationGroupTest(test_util.TensorFlowTestCase):
  """Tests `colocate_with` and the colocation groups it records on ops."""

  @test_util.run_deprecated_v1
  def testBasic(self):
    anchor = constant_op.constant([2.0], name="a")
    with ops.colocate_with(anchor.op):
      colocated = constant_op.constant(3.0)
    unconstrained = constant_op.constant(4.0)
    for tensor in (anchor, colocated):
      self.assertEqual([b"loc:@a"], tensor.op.colocation_groups())
    # An op created outside any colocation scope has no "_class" attr.
    with self.assertRaises(ValueError):
      unconstrained.op.get_attr("_class")

  @test_util.run_deprecated_v1
  def testBasicColocationMetadata(self):
    two = constant_op.constant([2.0], name="two")
    with ops.colocate_with(two.op):
      three = constant_op.constant(3.0, name="three")
    # pylint: disable=protected-access
    colocation_dict = three.op._colocation_dict
    # pylint: enable=protected-access
    self.assertIn("two", colocation_dict)
    entry = colocation_dict["two"]
    self.assertIsNone(entry.obj)
    # The recorded file is the one holding the colocate_with statement.
    self.assertEqual("ops_test.py", os.path.basename(entry.filename))

  @test_util.run_deprecated_v1
  def testColocationDeviceInteraction(self):
    with ops.device("/cpu:0"):
      with ops.device("/device:GPU:0"):
        gpu_op = constant_op.constant([2.0], name="a")
      with ops.colocate_with(gpu_op.op):
        # Created inside the /cpu:0 scope, but colocate_with is the stronger
        # constraint, so this op takes the device of `gpu_op`.
        follower = constant_op.constant(3.0)
    self.assertEqual([b"loc:@a"], follower.op.colocation_groups())
    self.assertEqual(gpu_op.op.device, follower.op.device)

  @test_util.run_deprecated_v1
  def testColocationCanonicalization(self):
    with ops.device("/device:GPU:0"):
      constant_op.constant(2.0)
    with ops.device(lambda op: "/device:GPU:0"):
      via_function = constant_op.constant(3.0)
    with ops.get_default_graph().colocate_with(via_function):
      with ops.device("/device:GPU:0"):
        colocated = constant_op.constant(4.0)

    # `colocated` inherits the device name of `via_function`; once the names
    # are canonicalized the two must match.
    self.assertEqual(via_function.op.device, colocated.op.device)

  @test_util.run_deprecated_v1
  def testLocationOverrides(self):
    with ops.device("/cpu:0"):
      with ops.device("/device:GPU:0"):
        gpu_anchor = constant_op.constant([2.0], name="a")
        # Redundant with the enclosing /device:GPU:0 scope, but it preserves
        # in the GraphDef, portably, that these two ops belong together.
        with ops.colocate_with(gpu_anchor.op):
          colocated = constant_op.constant(3.0)
        after_colocation = constant_op.constant(4.0)
      back_on_cpu = constant_op.constant(5.0)

    self.assertEqual([b"loc:@a"], colocated.op.colocation_groups())
    self.assertEqual("/device:GPU:0", gpu_anchor.op.device)
    self.assertEqual(gpu_anchor.op.device, colocated.op.device)

    # The device function stack is restored as each scope exits.
    self.assertEqual("/device:GPU:0", after_colocation.op.device)
    self.assertEqual("/device:CPU:0", back_on_cpu.op.device)

  @test_util.run_deprecated_v1
  def testNestedColocateWith(self):
    root = constant_op.constant([2.0], name="a")
    with ops.colocate_with(root.op):
      child = constant_op.constant(3.0)
      with ops.colocate_with(child.op):
        grandchild = constant_op.constant(4.0)
    # Colocating with an already-colocated op still yields the root's group.
    for tensor in (child, grandchild):
      self.assertEqual([b"loc:@a"], tensor.op.colocation_groups())

  @test_util.run_deprecated_v1
  def testMultiColocationGroups(self):
    first = constant_op.constant([2.0], name="a")
    second = constant_op.constant(3.0, name="b")
    with ops.colocate_with(first.op):
      with ops.colocate_with(second.op):
        both = constant_op.constant(4.0)
    self.assertEqual({b"loc:@a", b"loc:@b"}, set(both.op.colocation_groups()))

  @test_util.run_deprecated_v1
  def testColocationIgnoreStack(self):
    first = constant_op.constant([2.0], name="a")
    second = constant_op.constant(3.0, name="b")
    with ops.colocate_with(first.op):
      # ignore_existing drops the enclosing "a" constraint.
      with ops.colocate_with(second.op, ignore_existing=True):
        only_b = constant_op.constant(4.0)
    self.assertEqual({b"loc:@b"}, set(only_b.op.colocation_groups()))

  @test_util.run_deprecated_v1
  def testColocateWithReset(self):
    anchor = constant_op.constant([2.0], name="a")
    with ops.colocate_with(anchor.op):
      colocated = constant_op.constant(3.0, name="b")
      # Colocating with None while ignoring the stack leaves no constraint,
      # so the new op only lists itself.
      with ops.colocate_with(None, ignore_existing=True):
        reset = constant_op.constant(4.0, name="c")
    self.assertEqual([b"loc:@a"], colocated.op.colocation_groups())
    self.assertEqual([b"loc:@c"], reset.op.colocation_groups())

  @test_util.run_deprecated_v1
  def testColocateWithInitialNoneThenNested(self):
    anchor = constant_op.constant([2.0], name="a")
    with ops.colocate_with(anchor.op):
      with ops.colocate_with(None, ignore_existing=True):
        fresh = constant_op.constant(3.0, name="b")
        with ops.colocate_with(fresh.op):
          nested = constant_op.constant(4.0, name="c")
    for tensor in (fresh, nested):
      self.assertEqual([b"loc:@b"], tensor.op.colocation_groups())

  @test_util.run_deprecated_v1
  def testColocateVariables(self):
    first = variables.Variable([2.0], name="a")
    with ops.colocate_with(first.op):
      second = variables.Variable([3.0], name="b")
    self.assertEqual([b"loc:@a"], second.op.colocation_groups())

  @test_util.run_deprecated_v1
  def testColocateResourceVariablesInFunction(self):
    with ops.device("/device:CPU:0"):
      var = resource_variable_ops.ResourceVariable(1.0)

    @def_function.function
    def fn():
      with ops.colocate_with(var):
        out = array_ops.ones([], name="output")
        # Checked while `fn` is being traced.
        self.assertEqual("/device:CPU:0", out.op.device)

    fn()

  def testColocateWithVariableInFunction(self):
    var = variables.Variable(1.)

    @def_function.function
    def fn():
      with ops.colocate_with(var):
        return array_ops.ones([], name="output")

    fn()
    graph_def = fn.get_concrete_function().graph.as_graph_def()
    # Passes as long as re-importing the traced graph does not raise.
    wrap_function.function_from_graph_def(graph_def, [], ["output"])
3421
3422
class DeprecatedTest(test_util.TensorFlowTestCase):
  """Tests for the test op `Old`, removed in GraphDef version 8."""

  def testSuccess(self):
    # Producer version 7 is older than the version-8 removal, so the op runs.
    with ops.Graph().as_default() as g:
      test_util.set_producer_version(g, 7)
      old_op = test_ops.old()
      with self.session(graph=g):
        old_op.run()

  def _error(self):
    pattern = (r"Op Old is not available in GraphDef version %d\. "
               r"It has been removed in version 8\. For reasons\.")
    return pattern % versions.GRAPH_DEF_VERSION

  def testGraphConstructionFail(self):
    expected_message = self._error()
    with ops.Graph().as_default():
      with self.assertRaisesRegex(NotImplementedError, expected_message):
        test_ops.old()
3441
3442
class NameScopeTest(test_util.TensorFlowTestCase):
  """Tests name-scope stripping/prepending and scope bookkeeping."""

  def testStripAndPrependScope(self):
    # Each case: (name, expected after stripping "hidden1",
    #             expected after then prepending "hidden2").
    cases = [
        # Same prefix. Should strip.
        ("hidden1/hidden1/weights", "hidden1/weights",
         "hidden2/hidden1/weights"),
        # Extra "/". Should strip.
        ("hidden1///hidden1/weights", "hidden1/weights",
         "hidden2/hidden1/weights"),
        # Same prefix. Should strip.
        ("^hidden1/hidden1/weights", "^hidden1/weights",
         "^hidden2/hidden1/weights"),
        # Same prefix. Should strip.
        ("loc:@hidden1/hidden1/weights", "loc:@hidden1/weights",
         "loc:@hidden2/hidden1/weights"),
        # Different prefix. Should keep.
        ("hhidden1/hidden1/weights", "hhidden1/hidden1/weights",
         "hidden2/hhidden1/hidden1/weights"),
        # Not a prefix. Should keep.
        ("hidden1", "hidden1", "hidden2/hidden1"),
    ]
    scope_to_strip = "hidden1"
    scope_to_add = "hidden2"
    for name, expected_stripped, expected_prepended in cases:
      stripped = ops.strip_name_scope(name, scope_to_strip)
      self.assertEqual(expected_stripped, stripped)
      self.assertEqual(expected_prepended,
                       ops.prepend_name_scope(stripped, scope_to_add))

  def testGetNameScope(self):
    with ops.Graph().as_default() as graph:
      with ops.name_scope("scope1"):
        with ops.name_scope("scope2"):
          with ops.name_scope("scope3"):
            self.assertEqual("scope1/scope2/scope3", graph.get_name_scope())
          self.assertEqual("scope1/scope2", graph.get_name_scope())
        self.assertEqual("scope1", graph.get_name_scope())
      # Leaving every scope brings back the empty root scope.
      self.assertEqual("", graph.get_name_scope())

  def testTwoGraphs(self):
    with self.assertRaisesRegex(ValueError, "'_' is not a valid scope name"):
      outer = ops.Graph()
      inner = ops.Graph()
      with outer.as_default():
        with inner.as_default():
          with ops.name_scope("_"):
            pass
3491
3492
class EnableEagerExecutionTest(test_util.TensorFlowTestCase):
  """Argument validation of `ops.enable_eager_execution`."""

  @test_util.run_v1_only("b/120545219")
  def testBadArgumentsToEnableEagerExecution(self):
    # A device policy passed positionally lands in the `config` argument.
    with self.assertRaisesRegex(TypeError, "config must be a tf.ConfigProto"):
      ops.enable_eager_execution(context.DEVICE_PLACEMENT_SILENT)
    # A ConfigProto is not a valid device_policy...
    with self.assertRaisesRegex(ValueError, "device_policy must be one of"):
      proto = config_pb2.ConfigProto()
      ops.enable_eager_execution(proto, proto)
    # ...nor a valid execution_mode.
    with self.assertRaisesRegex(ValueError, "execution_mode must be one of"):
      proto = config_pb2.ConfigProto()
      ops.enable_eager_execution(proto, execution_mode=proto)
3505
3506
class _TupleTensor(composite_tensor.CompositeTensor):
  """`Tensor`-like `tuple`-like for custom `Tensor` conversion masquerading.

  Args:
    components: iterable of values; each is passed through
      `ops.convert_to_tensor` and the results are stored as a tuple.
  """

  def __init__(self, components):
    super(_TupleTensor, self).__init__()
    self._components = tuple(ops.convert_to_tensor(c) for c in components)

  @property
  def _type_spec(self):
    # Build a tuple, not a generator: the spec hands these specs out via
    # `_component_specs` and `_serialize`, and a generator would be exhausted
    # after the first use.
    return _TupleTensorSpec(
        tuple(type_spec.from_value(c) for c in self._components))

  def __getitem__(self, key):
    return self._components[key]

  def __len__(self):
    return len(self._components)

  def __iter__(self):
    return iter(self._components)
3526
3527
class _TupleTensorSpec(type_spec.TypeSpec):
  """`TypeSpec` for `_TupleTensor`: one component spec per tuple element."""

  def __init__(self, specs):
    # Materialize `specs` so a generator argument is not exhausted by the
    # first read of `_component_specs` or `_serialize`.
    self._specs = tuple(specs)

  value_type = property(lambda self: _TupleTensor)
  _component_specs = property(lambda self: self._specs)

  def _to_components(self, value):
    return value._components  # pylint: disable=protected-access

  def _from_components(self, components):
    # `_TupleTensor.__init__` takes all components as ONE iterable argument;
    # unpacking them with `*` would pass each tensor positionally and fail.
    return _TupleTensor(components)

  def _serialize(self):
    return (self._specs,)
3544
3545
3546class _MyTuple(object):
3547  """Pretend user-side class for `ConvertToCompositeTensorTest ."""
3548
3549  def __init__(self, components):
3550    super(_MyTuple, self).__init__()
3551    self._components = tuple(components)
3552
3553  def __getitem__(self, key):
3554    return self._components[key]
3555
3556  def __len__(self):
3557    return len(self._components)
3558
3559  def __iter__(self):
3560    return iter(self._components)
3561
3562
def _my_tuple_to_tuple_tensor(value, *unused_args, **unused_kwargs):
  """Tensor conversion function that wraps a `_MyTuple` in a `_TupleTensor`.

  Any extra positional or keyword arguments supplied by the conversion
  machinery are accepted and ignored.
  """
  return _TupleTensor(value)


ops.register_tensor_conversion_function(
    _MyTuple, conversion_func=_my_tuple_to_tuple_tensor)
3565
3566
class CustomConvertToCompositeTensorTest(test_util.TensorFlowTestCase):

  @test_util.disable_tfrt("TODO(kkb): This makes Kokoro tests fail.")
  def testCompositeTensorConversion(self):
    """Tests that a user can register a CompositeTensor converter."""
    user_value = _MyTuple((1, [2., 3.], [[4, 5], [6, 7]]))
    converted = ops.convert_to_tensor_or_composite(user_value)
    # The conversion yields the composite wrapper, not a single tensor...
    self.assertFalse(tensor_util.is_tf_type(converted))
    self.assertIsInstance(converted, _TupleTensor)
    self.assertLen(converted, len(user_value))
    # ...whose components are tensors holding the original values.
    for original, component in zip(user_value, converted):
      self.assertIsInstance(component, ops.Tensor)
      self.assertTrue(tensor_util.is_tf_type(component))
      self.assertAllEqual(original, tensor_util.constant_value(component))
3581
3582
@test_util.disable_tfrt("Packing EagerTensors is not supported yet.")
class PackEagerTensorTest(test_util.TensorFlowTestCase):
  """Tests `ops.pack_eager_tensors` across two virtual CPU devices."""

  def setUp(self):
    super(PackEagerTensorTest, self).setUp()
    context._reset_context()  # pylint: disable=protected-access
    physical_cpus = config.list_physical_devices("CPU")
    # Split the first physical CPU into 2 virtual CPUs; testPack places ops
    # on "CPU:0" and "CPU:1".
    config.set_logical_device_configuration(
        physical_cpus[0],
        [context.LogicalDeviceConfiguration() for _ in range(2)])

  def testPack(self):
    with context.eager_mode():
      with ops.device("CPU:0"):
        var0 = resource_variable_ops.ResourceVariable(1.0)
        c0 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
      with ops.device("CPU:1"):
        var1 = resource_variable_ops.ResourceVariable(2.0)
        var2 = resource_variable_ops.ResourceVariable([3.0])
        c1 = constant_op.constant([9.0])

      handle0 = var0.handle
      packed = ops.pack_eager_tensors([handle0, var1.handle])
      self.assertTrue(packed.is_packed)
      self.assertEqual(packed.dtype, handle0.dtype)
      self.assertEqual(packed.shape, handle0.shape)
      # pylint: disable=protected-access
      self.assertEqual(packed._handle_data, handle0._handle_data)
      # pylint: enable=protected-access
      self.assertIn("COMPOSITE:0", packed.device)
      self.assertIn("COMPOSITE:0", packed.backing_device)
      # A packed tensor cannot be read back as a single numpy value.
      with self.assertRaises(errors.InvalidArgumentError):
        packed.numpy()

      mismatched_inputs = (
          [var0.handle, c1],  # Different dtypes.
          [c0, c1],  # Different shapes.
          [var0.handle, var2.handle],  # Different handle data.
      )
      for tensors in mismatched_inputs:
        with self.assertRaises(ValueError):
          ops.pack_eager_tensors(tensors)
3627
3628
if __name__ == "__main__":
  # Entry point when this file is run directly: hand off to googletest.
  googletest.main()
3631