• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.python.framework.ops."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import gc
22import os
23import threading
24import weakref
25
26from tensorflow.core.framework import attr_value_pb2
27from tensorflow.core.protobuf import config_pb2
28from tensorflow.python.client import session
29from tensorflow.python.eager import context
30from tensorflow.python.eager import function as eager_function
31from tensorflow.python.framework import common_shapes
32from tensorflow.python.framework import constant_op
33from tensorflow.python.framework import device as pydev
34from tensorflow.python.framework import dtypes
35from tensorflow.python.framework import errors
36from tensorflow.python.framework import function
37from tensorflow.python.framework import ops
38from tensorflow.python.framework import sparse_tensor
39from tensorflow.python.framework import tensor_shape
40from tensorflow.python.framework import tensor_util
41from tensorflow.python.framework import test_ops
42from tensorflow.python.framework import test_util
43from tensorflow.python.framework import versions
44from tensorflow.python.ops import array_ops
45from tensorflow.python.ops import control_flow_ops
46from tensorflow.python.ops import math_ops
47from tensorflow.python.ops import resource_variable_ops
48from tensorflow.python.ops import resources
49from tensorflow.python.ops import variable_scope
50from tensorflow.python.ops import variables
51import tensorflow.python.ops.gradients  # pylint: disable=unused-import
52from tensorflow.python.platform import googletest
53from tensorflow.python.util import compat
54
# Register common_shapes.call_cpp_shape_fn with the ops module at import time,
# before any test in this file builds a graph.
# NOTE(review): presumably this routes Python-side shape inference through the
# C++ shape functions; confirm against ops._set_call_cpp_shape_fn.
ops._set_call_cpp_shape_fn(common_shapes.call_cpp_shape_fn)
56
57
class ResourceTest(test_util.TensorFlowTestCase):
  """Tests for building and initializing stub resource ops."""

  @test_util.run_deprecated_v1
  def testBuildGraph(self):
    """A stub resource handle can be passed to a create op and run."""
    with self.cached_session():
      pt = test_ops.stub_resource_handle_op(container="a", shared_name="b")
      test_ops.resource_create_op(pt).run()

  @test_util.run_deprecated_v1
  def testInitialize(self):
    """A registered resource is reported uninitialized until initialized."""
    with self.cached_session():
      handle = test_ops.stub_resource_handle_op(container="a", shared_name="b")
      resources.register_resource(
          handle=handle,
          create_op=test_ops.resource_create_op(handle),
          is_initialized_op=test_ops.resource_initialized_op(handle))

      def num_uninitialized():
        # Re-queries the shared resource collection on every call so the
        # count reflects the current session state.
        return len(
            resources.report_uninitialized_resources(
                resources.shared_resources()).eval())

      # assertEqual rather than the deprecated assertEquals alias, which was
      # removed from unittest in Python 3.12.
      self.assertEqual(num_uninitialized(), 1)
      resources.initialize_resources(resources.shared_resources()).run()
      self.assertEqual(num_uninitialized(), 0)
83
84
class TensorAndShapeTest(test_util.TensorFlowTestCase):
  """Tests for static shape handling on graph tensors."""

  def testShape(self):
    """set_shape refines an initially unknown static shape."""
    op = ops.Operation(
        ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
    t = op.outputs[0]
    self.assertEqual(tensor_shape.unknown_shape(), t.get_shape())
    t.set_shape([1, 2, 3])
    self.assertEqual([1, 2, 3], t.get_shape())

  def testIterable(self):
    """Iterating a graph tensor raises a TypeError mentioning 'iter'."""
    op = ops.Operation(
        ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
    t = op.outputs[0]
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)).
    self.assertIsInstance(t, ops.Tensor)
    with self.assertRaisesRegexp(TypeError, "iter"):
      for _ in t:
        pass

  def testAddShape(self):
    """Broadcasting [2, 3] + [1, 3] yields a static shape of [2, 3]."""
    with self.cached_session():
      a = array_ops.zeros([2, 3])
      b = array_ops.ones([1, 3])
      c = a + b
      self.assertEqual([2, 3], c.shape)

  @test_util.run_deprecated_v1
  def testUnknownDim(self):
    """Unknown dimensions propagate through elementwise addition."""
    with self.cached_session():
      a = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
      b = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
      c = a + b
      self.assertEqual([2, None, 3], c.shape.as_list())

  @test_util.run_deprecated_v1
  def testUnknownShape(self):
    """An operand of unknown rank makes the sum's shape unknown."""
    with self.cached_session():
      a = array_ops.placeholder(dtype=dtypes.float32, shape=None)
      b = array_ops.ones([1, 3])
      c = a + b
      self.assertEqual(tensor_shape.unknown_shape(), c.shape)

  @test_util.run_deprecated_v1
  def testScalarShape(self):
    """Adding two scalars yields a scalar shape."""
    with self.cached_session():
      a = array_ops.placeholder(dtype=dtypes.float32, shape=[])
      b = array_ops.ones([])
      c = a + b
      self.assertEqual(tensor_shape.scalar(), c.shape)

  @test_util.run_deprecated_v1
  def testShapeFunctionError(self):
    """Incompatible shapes surface the shape function's error at build time."""
    with self.cached_session():
      a = array_ops.ones([1, 2, 3])
      b = array_ops.ones([4, 5, 6])
      # The final period is escaped so it matches a literal '.', not any
      # character.
      with self.assertRaisesRegexp(
          ValueError,
          r"Dimensions must be equal, but are 2 and 5 for 'add' \(op: 'Add'\) "
          r"with input shapes: \[1,2,3\], \[4,5,6\]\."):
        _ = a + b
145
146
class IndexedSlicesTest(test_util.TensorFlowTestCase):
  """Tests for ops.IndexedSlices conversion and arithmetic."""

  def _slices_components(self):
    """Build the (values, indices) constants shared by the tests below."""
    slice_values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
    slice_indices = constant_op.constant([0, 2])
    return slice_values, slice_indices

  @test_util.run_in_graph_and_eager_modes
  def testToTensor(self):
    """Densifying scatters rows to their indices and zero-fills the rest."""
    slice_values, slice_indices = self._slices_components()
    dense_shape = constant_op.constant([3, 2])
    dense = ops.convert_to_tensor(
        ops.IndexedSlices(slice_values, slice_indices, dense_shape),
        name="tensor")
    self.assertAllEqual(self.evaluate(dense), [[2, 3], [0, 0], [5, 7]])

  @test_util.run_deprecated_v1
  def testNegation(self):
    """Unary minus negates the values and leaves the indices untouched."""
    with self.cached_session():
      slice_values, slice_indices = self._slices_components()
      negated = -ops.IndexedSlices(slice_values, slice_indices)
      self.assertAllEqual(negated.values.eval(), [[-2, -3], [-5, -7]])
      self.assertAllEqual(negated.indices.eval(), [0, 2])

  @test_util.run_deprecated_v1
  def testScalarMul(self):
    """scalar_mul scales the values and leaves the indices untouched."""
    with self.cached_session():
      slice_values, slice_indices = self._slices_components()
      scaled = math_ops.scalar_mul(
          -2, ops.IndexedSlices(slice_values, slice_indices))
      self.assertAllEqual(scaled.values.eval(), [[-4, -6], [-10, -14]])
      self.assertAllEqual(scaled.indices.eval(), [0, 2])
175
176
class NodeDefConstructorTest(test_util.TensorFlowTestCase):
  """Tests for the ops._NodeDef helper."""

  def testNoArgs(self):
    """Only op and name are set when no optional arguments are passed."""
    self.assertProtoEquals("op: 'None' name: 'bar'",
                           ops._NodeDef("None", "bar"))

  def testArgs(self):
    """The device may be given either as a string or as a DeviceSpec."""
    cases = (
        ("/device:baz:*", "op:'foo' name:'bar' device:'/device:baz:*'"),
        (pydev.DeviceSpec(job="j"), "op:'foo' name:'bar' device:'/job:j'"),
    )
    for device, expected_proto in cases:
      node = ops._NodeDef("foo", "bar", device=device)
      self.assertProtoEquals(expected_proto, node)
189
190
191def _apply_op(g, *args, **kwargs):
192  op = g.create_op(*args, **kwargs)
193  if len(op.outputs) == 1:
194    return op.outputs[0]
195  else:
196    return op.outputs
197
198
class OperationTest(test_util.TensorFlowTestCase):
  """Tests for ops.Operation construction, attributes, and graph editing."""

  @test_util.run_deprecated_v1
  def testNoInputs(self):
    op = test_ops.float_output_string_output(name="myop").a.op
    self.assertEqual(2, len(op.values()))
    self.assertEqual(0, len(op.inputs))
    self.assertEqual("myop", op.name)

    float_t, label_str_t = op.values()
    self.assertEqual(dtypes.float32, float_t.dtype)
    self.assertEqual(op, float_t.op)
    self.assertEqual(0, float_t._value_index)
    self.assertEqual(0, len(float_t.consumers()))
    self.assertEqual("myop", float_t._as_node_def_input())

    self.assertEqual(dtypes.string, label_str_t.dtype)
    self.assertEqual(op, label_str_t.op)
    self.assertEqual(1, label_str_t._value_index)
    self.assertEqual(0, len(label_str_t.consumers()))
    self.assertEqual("myop:1", label_str_t._as_node_def_input())

    self.assertProtoEquals("op:'FloatOutputStringOutput' name:'myop'",
                           op.node_def)

  @test_util.run_deprecated_v1
  def testNoOutputs(self):
    op1 = test_ops.float_output(name="myop1").op
    float_t, = op1.values()
    op2 = test_ops.float_input(float_t, name="myop2")
    self.assertEqual(0, len(op2.values()))
    self.assertEqual(1, len(op2.inputs))
    self.assertIs(float_t, op2.inputs[0])

    self.assertEqual(1, len(float_t.consumers()))
    self.assertEqual(op2, float_t.consumers()[0])

    self.assertProtoEquals("op:'FloatOutput' name:'myop1'", op1.node_def)
    self.assertProtoEquals("op:'FloatInput' name:'myop2' input:'myop1'",
                           op2.node_def)

  @test_util.run_deprecated_v1
  def testInputsAndOutputs(self):
    op1 = test_ops.float_output(name="myop1").op
    self.assertEqual(1, len(op1.values()))
    float1_t, = op1.values()

    op2 = test_ops.float_output_string_output(name="myop2").a.op
    self.assertEqual(2, len(op2.values()))
    float2_t, label2_str_t = op2.values()

    # Note that we consume label2_str_t twice here.
    op3 = test_ops.foo2(float1_t, label2_str_t, label2_str_t, name="myop3").d.op
    self.assertEqual(2, len(op3.values()))

    self.assertEqual(1, len(float1_t.consumers()))
    self.assertEqual(op3, float1_t.consumers()[0])

    self.assertEqual(0, len(float2_t.consumers()))

    self.assertEqual(2, len(label2_str_t.consumers()))
    self.assertEqual(op3, label2_str_t.consumers()[0])
    self.assertEqual(op3, label2_str_t.consumers()[1])

    self.assertProtoEquals("""
    op:'Foo2' name:'myop3'
    input:'myop1' input:'myop2:1' input:'myop2:1'
    """, op3.node_def)

  def testDeviceFromNodeDef(self):
    op = ops.Operation(
        ops._NodeDef("None", "myop", device="/job:goo/device:GPU:0"),
        ops.Graph(), [], [])
    self.assertEqual("/job:goo/device:GPU:0", op.device)

  def testDeviceObject(self):
    op = ops.Operation(ops._NodeDef("None", "myop"), ops.Graph(), [], [])
    op._set_device("/job:goo/device:GPU:0")
    self.assertProtoEquals(
        "op:'None' name:'myop' device:'/job:goo/device:GPU:0' ", op.node_def)
    op = ops.Operation(ops._NodeDef("None", "op2"), ops.Graph(), [], [])
    op._set_device(
        pydev.DeviceSpec(
            job="muu", device_type="CPU", device_index=0))
    self.assertProtoEquals(
        "op:'None' name:'op2' device:'/job:muu/device:CPU:0'", op.node_def)

  def testReferenceInput(self):
    g = ops.Graph()
    op1 = ops.Operation(
        ops._NodeDef("RefOutputFloatOutput", "op1"), g, [],
        [dtypes.float32_ref, dtypes.float32])
    self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def)
    # assertEqual rather than the deprecated assertEquals alias (removed from
    # unittest in Python 3.12).
    self.assertEqual([], list(op1.inputs))
    ref_t, nonref_t = op1.values()
    # NOTE(mrry): Must specify input_types to preserve ref-typed input.
    op2 = ops.Operation(
        ops._NodeDef("RefInputFloatInput", "op2"),
        g, [ref_t, nonref_t], [],
        input_types=[dtypes.float32_ref, dtypes.float32])
    self.assertProtoEquals(
        "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'",
        op2.node_def)
    self.assertEqual([ref_t, nonref_t], list(op2.inputs))
    op3 = ops.Operation(
        ops._NodeDef("TwoFloatInputs", "op3"), g, [ref_t, nonref_t], [])
    self.assertProtoEquals(
        "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'",
        op3.node_def)

  def testInvalidNames(self):
    g = ops.Graph()
    with self.assertRaises(ValueError):
      ops.Operation(ops._NodeDef("op", ""), g)
    with self.assertRaises(ValueError):
      ops.Operation(ops._NodeDef("op", "_invalid"), g)
    with self.assertRaises(ValueError):
      ops.Operation(ops._NodeDef("op", "-invalid"), g)
    with self.assertRaises(ValueError):
      ops.Operation(ops._NodeDef("op", "/invalid"), g)
    with self.assertRaises(ValueError):
      ops.Operation(ops._NodeDef("op", "invalid:0"), g)

  @test_util.run_deprecated_v1
  def testNoShapeFunction(self):
    op = test_ops.a()
    self.assertEqual(tensor_shape.unknown_shape(), op.get_shape())

  @test_util.run_in_graph_and_eager_modes
  def testConvertToTensorNestedArray(self):
    values = [[2], [3], [5], [7]]
    tensor = ops.convert_to_tensor(values)
    self.assertAllEqual((4, 1), tensor.get_shape().as_list())
    self.assertAllEqual(values, self.evaluate(tensor))

  def testShapeTuple(self):
    with self.cached_session():
      c = constant_op.constant(1)
      self.assertEqual(c._shape_tuple(), ())  # pylint: disable=protected-access

  def testConvertToTensorEager(self):
    with context.eager_mode():
      t = constant_op.constant(1)
      self.assertIsInstance(t, ops.EagerTensor)
      converted = ops.convert_to_tensor(t)
      self.assertIsInstance(converted, ops.EagerTensor)
      converted = ops.convert_to_tensor(1)
      self.assertIsInstance(converted, ops.EagerTensor)

  @test_util.run_in_graph_and_eager_modes
  def testConvertToTensorNestedTuple(self):
    values = ((2,), (3,), (5,), (7,))
    tensor = ops.convert_to_tensor(values)
    self.assertAllEqual((4, 1), tensor.get_shape().as_list())
    # Evaluate the tensor whose shape was just checked rather than converting
    # `values` a second time.
    self.assertAllEqual(values, self.evaluate(tensor))

  @test_util.run_in_graph_and_eager_modes
  def testConvertToTensorNestedTensors(self):
    values = ((2,), (3,), (5,), (7,))
    tensor = ops.convert_to_tensor(
        [constant_op.constant(row) for row in values])
    self.assertAllEqual((4, 1), tensor.get_shape().as_list())
    self.assertAllEqual(values, self.evaluate(tensor))
    tensor = ops.convert_to_tensor(
        [[constant_op.constant(v) for v in row] for row in values])
    self.assertAllEqual((4, 1), tensor.get_shape().as_list())
    self.assertAllEqual(values, self.evaluate(tensor))

  @test_util.run_in_graph_and_eager_modes
  def testConvertToTensorNestedMix(self):
    values = ([2], (3,), [constant_op.constant(5)], constant_op.constant([7]))
    tensor = ops.convert_to_tensor(values)
    self.assertAllEqual((4, 1), tensor.get_shape().as_list())
    self.assertAllEqual(((2,), (3,), (5,), (7,)), self.evaluate(tensor))

  @test_util.run_in_graph_and_eager_modes
  def testConvertToTensorPreferred(self):
    values = [2, 3, 5, 7]
    tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.float32)
    self.assertEqual(dtypes.float32, tensor.dtype)

    # Convert empty tensor to anything.
    values = []
    tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64)
    self.assertEqual(dtypes.int64, tensor.dtype)

    # The preferred dtype is a type error and will convert to
    # float32 instead.
    values = [1.23]
    tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64)
    self.assertEqual(dtypes.float32, tensor.dtype)

  @test_util.run_in_graph_and_eager_modes
  def testConvertToInvalidTensorType(self):
    with self.assertRaises(TypeError):
      # Forcing an invalid dtype should fail with a type error.
      values = [1.23]
      ops.convert_to_tensor(values, dtype=dtypes.int64)

  @test_util.run_in_graph_and_eager_modes
  def testConvertToTensorFromInvalidTensor(self):
    tensor = constant_op.constant(42.0, dtype=dtypes.float32)
    with self.assertRaises(ValueError):
      ops.convert_to_tensor(tensor, dtype=dtypes.int32)

  @test_util.run_deprecated_v1
  def testNoConvert(self):
    # Operation cannot be converted to Tensor.
    op = control_flow_ops.no_op()
    with self.assertRaisesRegexp(TypeError,
                                 r"Can't convert Operation '.*' to Tensor"):
      ops.convert_to_tensor(op)

  def testStr(self):
    node_def = ops._NodeDef("None", "op1")
    op = ops.Operation(node_def, ops.Graph(), [], [dtypes.float32])
    self.assertEqual(str(node_def), str(op))

  def testRepr(self):
    op = ops.Operation(
        ops._NodeDef("None", "op1"), ops.Graph(), [], [dtypes.float32])
    self.assertEqual("<tf.Operation 'op1' type=None>", repr(op))

  @test_util.run_deprecated_v1
  def testGetAttr(self):
    op = test_ops.default_attrs()
    self.assertEqual(op.get_attr("string_val"), b"abc")
    self.assertEqual(op.get_attr("string_list_val"), [b"abc", b""])
    self.assertEqual(op.get_attr("int_val"), 123)
    self.assertEqual(op.get_attr("int_list_val"), [1, 2, 3])
    self.assertEqual(op.get_attr("float_val"), 10.0)
    self.assertEqual(op.get_attr("float_list_val"), [10.0])
    self.assertEqual(op.get_attr("bool_val"), True)
    self.assertEqual(op.get_attr("bool_list_val"), [True, False])
    self.assertEqual(op.get_attr("shape_val"),
                     tensor_shape.as_shape([2, 1]).as_proto())
    self.assertEqual(op.get_attr("shape_list_val"),
                     [tensor_shape.as_shape([]).as_proto(),
                      tensor_shape.as_shape([1]).as_proto()])
    self.assertEqual(op.get_attr("tensor_val"),
                     tensor_util.make_tensor_proto(1, dtypes.int32))
    self.assertEqual(op.get_attr("tensor_list_val"),
                     [tensor_util.make_tensor_proto(1, dtypes.int32)])

    type_val = op.get_attr("type_val")
    # First check that type_val is a DType, because the assertEqual will work
    # no matter what since DType overrides __eq__
    self.assertIsInstance(type_val, dtypes.DType)
    self.assertEqual(type_val, dtypes.int32)

    type_list_val = op.get_attr("type_list_val")
    self.assertTrue(all(isinstance(x, dtypes.DType) for x in type_list_val))
    self.assertEqual(type_list_val, [dtypes.int32, dtypes.float32])

    @function.Defun(dtypes.float32, func_name="MyFunc")
    def func(x):
      return x

    op = test_ops.func_attr(func)
    self.assertEqual(op.get_attr("f"),
                     attr_value_pb2.NameAttrList(name="MyFunc"))

    # Try fetching missing attr
    with self.assertRaisesRegexp(
        ValueError, "Operation 'FuncAttr' has no attr named 'FakeAttr'."):
      op.get_attr("FakeAttr")

  # TODO(b/65162920): remove this test when users who are directly mutating the
  # node_def have been updated to proper usage.
  @test_util.run_deprecated_v1
  def testSetAttr(self):
    op = test_ops.int_attr().op
    op._set_attr("foo", attr_value_pb2.AttrValue(i=2))
    # TODO(skyewm): add node_def check
    self.assertEqual(op.get_attr("foo"), 2)

  # TODO(nolivia): test all error cases
  def testAddControlInput(self):
    """Adding control inputs de-duplicates them and records the reverse edge."""
    with ops.Graph().as_default():
      x = constant_op.constant(1).op
      y = constant_op.constant(2).op
      z = constant_op.constant(3).op
    z._add_control_input(x)  # pylint: disable=protected-access
    self.assertEqual(z.control_inputs, [x])
    z._add_control_input(x)  # pylint: disable=protected-access
    self.assertEqual(z.control_inputs, [x])
    z._add_control_inputs([x, y, y])  # pylint: disable=protected-access
    self.assertEqual(z.control_inputs, [x, y])
    self.assertEqual(x._control_outputs, [z])

  @test_util.run_deprecated_v1
  def testRemoveAllControlInputs(self):
    """Removing control inputs leaves the op's data inputs intact."""
    a = constant_op.constant(1)
    with ops.control_dependencies([a]):
      b = constant_op.constant(2)
    c = constant_op.constant(3)
    d = constant_op.constant(4)
    e = constant_op.constant(5)
    with ops.control_dependencies([a, c]):
      f = d + e

    self.assertEqual(a.op.control_inputs, [])
    self.assertEqual(b.op.control_inputs, [a.op])
    self.assertEqual(f.op.control_inputs, [a.op, c.op])

    a.op._remove_all_control_inputs()  # pylint: disable=protected-access
    self.assertEqual(a.op.control_inputs, [])

    b.op._remove_all_control_inputs()  # pylint: disable=protected-access
    self.assertEqual(b.op.control_inputs, [])

    f.op._remove_all_control_inputs()  # pylint: disable=protected-access
    self.assertEqual(f.op.control_inputs, [])
    self.assertEqual(list(f.op.inputs), [d, e])

  @test_util.run_deprecated_v1
  def testControlInputCycle(self):
    """A control-dependency cycle is rejected when the graph is run."""
    graph = ops.Graph()
    with graph.as_default():
      z = constant_op.constant(0)
      x = constant_op.constant(1)
      y = constant_op.constant(2)
      y.op._add_control_input(z.op)  # pylint: disable=protected-access
      y.op._add_control_input(x.op)  # pylint: disable=protected-access
      x.op._add_control_input(y.op)  # pylint: disable=protected-access
    # The session only needs to be the default for self.evaluate; the bound
    # name was never used.
    with self.session(graph=graph):
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          "Graph is invalid, contains a cycle with 2 nodes"):
        self.evaluate(x)

  def testUpdateInput(self):
    """_update_input rewires data edges and keeps consumer lists in sync."""
    g = ops.Graph()
    with g.as_default():
      x = constant_op.constant(1)
      y = constant_op.constant(2)
      z = x + y

    z.op._update_input(0, y)  # pylint: disable=protected-access
    self.assertEqual(list(z.op.inputs), [y, y])
    self.assertEqual(x.consumers(), [])
    self.assertEqual(y.consumers(), [z.op, z.op])
    # Entering the Session makes it the default, which self.evaluate uses.
    with session.Session(graph=g):
      self.assertEqual(self.evaluate(z), 4)

    z.op._update_input(0, x)  # pylint: disable=protected-access
    self.assertEqual(list(z.op.inputs), [x, y])
    self.assertEqual(x.consumers(), [z.op])
    self.assertEqual(y.consumers(), [z.op])
    with session.Session(graph=g):
      self.assertEqual(self.evaluate(z), 3)

    z.op._update_input(1, y)  # pylint: disable=protected-access
    self.assertEqual(list(z.op.inputs), [x, y])
    self.assertEqual(x.consumers(), [z.op])
    self.assertEqual(y.consumers(), [z.op])
    with session.Session(graph=g):
      self.assertEqual(self.evaluate(z), 3)

  def testUpdateInputGraphError(self):
    g_0 = ops.Graph()
    g_1 = ops.Graph()
    with g_0.as_default():
      x = constant_op.constant(1)
    with g_1.as_default():
      y = constant_op.constant(2)
      z = y * 2
      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
        z.op._update_input(0, x)  # pylint: disable=protected-access

  def testUpdateInputTypeError(self):
    """A dtype-mismatched edge is only rejected when the graph is run."""
    g = ops.Graph()
    with g.as_default():
      w = constant_op.constant(0)
      x = constant_op.constant("")
      y = constant_op.constant(1)
      z = y + w
      z.op._update_input(0, x)  # pylint: disable=protected-access
    with session.Session(graph=g):
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          "Input 0 of node add was passed string from Const_1:0 incompatible "
          "with expected int32"):
        self.evaluate(z)

  def testUpdateInputShapeError(self):
    g = ops.Graph()
    with g.as_default():
      w = constant_op.constant(2, shape=[3, 1])
      x = constant_op.constant(0, shape=[3, 1])
      y = constant_op.constant(1, shape=[2, 2])
      z = w + x
    with self.assertRaisesRegexp(
        errors.InvalidArgumentError,
        r"Cannot update edge, incompatible shapes: \[2,2\] and \[3,1\]"):
      z.op._update_input(0, y)  # pylint: disable=protected-access

  def testUpdateInputOutOfRange(self):
    g = ops.Graph()
    with g.as_default():
      x = constant_op.constant(1)
    with self.assertRaisesRegexp(
        errors.OutOfRangeError,
        r"Cannot update edge. Input index \[1\] is greater than the number of "
        r"total inputs \[0\]."
    ):
      x.op._update_input(1, x)  # pylint: disable=protected-access

  @test_util.enable_control_flow_v2
  @test_util.run_v1_only("b/120545219")
  def testAddWhileInput(self):
    """Inputs can be appended to a While op only up to what "T" declares."""
    @eager_function.defun
    def test():
      output = control_flow_ops.while_loop(lambda x: x < 3, lambda x: x + 1,
                                           [1])
      while_op = output.op.inputs[0].op
      self.assertEqual(while_op.type, "While")
      orig_num_inputs = len(while_op.inputs)

      # Make sure we can handle the while op having a control input.
      while_op._add_control_input(constant_op.constant(0).op)

      new_input1 = constant_op.constant(1.0)
      new_input2 = constant_op.constant(True)

      while_op._set_type_list_attr("T",
                                   [t.dtype for t in while_op.inputs] +
                                   [new_input1.dtype, new_input2.dtype])

      while_op._add_while_inputs([new_input1, new_input2])
      # Can't add an edge beyond what's specified by "T"
      with self.assertRaises(errors.OutOfRangeError):
        while_op._add_while_inputs([new_input2])
      self.assertEqual(len(while_op.inputs), orig_num_inputs + 2)

    test()

  @test_util.run_deprecated_v1
  def testOpDef(self):
    x = constant_op.constant(0)
    y = constant_op.constant(1)
    z = x + y

    self.assertEqual(x.op.op_def.name, "Const")
    self.assertEqual(len(x.op.op_def.input_arg), 0)
    self.assertEqual(len(x.op.op_def.output_arg), 1)

    self.assertEqual(z.op.op_def.name, "Add")
    self.assertEqual(len(z.op.op_def.input_arg), 2)
    self.assertEqual(len(z.op.op_def.output_arg), 1)

  def testInputFromDifferentGraphError(self):
    g_0 = ops.Graph()
    g_1 = ops.Graph()
    with g_0.as_default():
      x = constant_op.constant(1)
    with g_1.as_default():
      y = constant_op.constant(2)
      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
        y * x  # pylint: disable=pointless-statement

  def testInputsAreImmutable(self):
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()
      op = test_ops.int_input_int_output(x, name="myop").op
    with self.assertRaisesRegexp(
        AttributeError, "'_InputList' object has no attribute 'append'"):
      op.inputs.append(None)
668
669
class CreateOpTest(test_util.TensorFlowTestCase):
  """Tests for Graph.create_op."""

  def testNodeDefArgs(self):
    """Device scopes and input wiring are reflected in each op's NodeDef."""
    graph = ops.Graph()
    first = graph.create_op(
        "FloatOutput", [], [dtypes.float32], None, name="myop1")
    with graph.device("/device:GPU:0"):
      second = graph.create_op(
          "FloatOutputStringOutput", [], [dtypes.float32, dtypes.string], None,
          name="myop2")
    first_outs = list(first.values())
    second_outs = list(second.values())
    third = graph.create_op(
        "Foo3",
        [first_outs[0], second_outs[1], second_outs[0]],
        [dtypes.float32, dtypes.int32],
        None,
        name="myop3")

    expected_devices = (
        (None, first),
        ("/device:GPU:0", second),
        (None, third),
    )
    for device, op in expected_devices:
      self.assertDeviceEqual(device, op.device)

    self.assertProtoEquals("name:'myop1' op:'FloatOutput'", first.node_def)
    self.assertProtoEquals(
        "name:'myop2' op:'FloatOutputStringOutput' device:'/device:GPU:0'",
        second.node_def)
    self.assertProtoEquals(
        "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo3'",
        third.node_def)

  def testReferenceInput(self):
    """Ref-typed outputs feed ref and non-ref inputs with the right names."""
    graph = ops.Graph()
    producer = graph.create_op(
        "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32],
        name="op1")
    self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'",
                           producer.node_def)
    ref_out, plain_out = producer.values()
    # NOTE(mrry): Must specify input_types to preserve ref-typed input.
    ref_consumer = graph.create_op(
        "RefInputFloatInput", [ref_out, plain_out], [],
        input_types=[dtypes.float32_ref, dtypes.float32],
        name="op2")
    self.assertProtoEquals(
        "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'",
        ref_consumer.node_def)
    plain_consumer = graph.create_op(
        "TwoFloatInputs", [ref_out, plain_out], [], name="op3")
    self.assertProtoEquals(
        "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'",
        plain_consumer.node_def)

  def testFinalized(self):
    """create_op raises on a finalized graph and works after unfinalizing."""
    graph = ops.Graph()
    graph.finalize()

    def create_float_output():
      return graph.create_op(
          "FloatOutput", [], [dtypes.float32], None, name="myop1")

    with self.assertRaises(RuntimeError):
      create_float_output()

    # Test unfinalize.
    graph._unsafe_unfinalize()
    create_float_output()
725
726
727# NOTE(skyewm): these cases test the private Graph._create_op_from_tf_operation
728# method. Arguably we should only test the public APIs that depend on this
729# method. However, this logic is complex and tricky, and it can be difficult to
730# ascertain if we have adequate coverage (e.g. a graph may run successfully if
731# the control flow context isn't set properly, but a more complicated use case
732# that might not be obvious to test will fail). Thus we instead explicitly test
733# the low-level behavior.
class CreateOpFromTFOperationTest(test_util.TensorFlowTestCase):
  """Tests wrapping C-API-created operations as Python `Operation`s.

  Each test creates a raw operation with `ops._create_c_op` and then wraps it,
  either directly with `Graph._create_op_from_tf_operation` or, from inside a
  cond/while_loop body, with `Graph._add_new_tf_operations`. It then checks the
  Python-side state of the resulting op.
  """

  @test_util.run_deprecated_v1
  def testBasic(self):
    """The wrapped op mirrors the C op's name, type, inputs and outputs."""
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()
      c_op = ops._create_c_op(
          g, ops._NodeDef("IntInputIntOutput", "myop"), [x], [])
      op = g._create_op_from_tf_operation(c_op)

    self.assertEqual(op.name, "myop")
    self.assertEqual(op.type, "IntInputIntOutput")
    self.assertEqual(len(op.outputs), 1)
    self.assertEqual(op.outputs[0].shape, tensor_shape.unknown_shape())
    self.assertEqual(list(op.inputs), [x])
    self.assertEqual(op.control_inputs, [])
    self.assertEqual(op.graph, g)
    # The wrapped op is recorded as a consumer of its input tensor.
    self.assertEqual(x.consumers(), [op])
    self.assertIsNotNone(op.traceback)
    # The op and its output are retrievable by name from the graph.
    self.assertEqual(g.get_operation_by_name("myop"), op)
    self.assertEqual(g.get_tensor_by_name("myop:0"), op.outputs[0])

  def testShape(self):
    """The static output shape of a wrapped op is inferred."""
    g = ops.Graph()
    with g.as_default():
      x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
      c_op = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), [x], [])
      op = g._create_op_from_tf_operation(c_op)

    self.assertEqual(op.name, "myop")
    self.assertEqual(op.type, "Identity")
    self.assertEqual(len(op.outputs), 1)
    # Identity of a 2x3 constant: the output is expected to be a 2x3 matrix.
    self.assertEqual(op.outputs[0].shape, tensor_shape.matrix(2, 3))

  def testUniqueName(self):
    """Names of wrapped C ops are taken into account by name uniquification."""
    g = ops.Graph()
    with g.as_default():
      c_op = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop"), [], [])
      c_op2 = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop_1"), [], [])
      op = g._create_op_from_tf_operation(c_op)
      op2 = g._create_op_from_tf_operation(c_op2)

      # Create ops with the same names as `op` and `op2`. We expect the new
      # names to be uniquified.
      op3 = test_ops.int_output(name="myop").op
      op4 = test_ops.int_output(name="myop_1").op

    self.assertEqual(op.name, "myop")
    self.assertEqual(op2.name, "myop_1")
    # "myop_1" is already taken, so the next free suffix for "myop" is "_2".
    self.assertEqual(op3.name, "myop_2")
    self.assertEqual(op4.name, "myop_1_1")

  @test_util.run_v1_only("b/120545219")
  def testCond(self):
    """An op wrapped inside a cond branch ends up in the cond's context."""
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()

      def true_fn():
        # Create the C op and wrap it while still inside the branch; the
        # assertions below expect it to belong to the cond's control flow.
        ops._create_c_op(ops.get_default_graph(),
                         ops._NodeDef("IntInput", "cond/myop"), [x], [])
        new_ops = g._add_new_tf_operations()
        self.assertEqual(len(new_ops), 1)
        return x

      control_flow_ops.cond(x < 10, true_fn, lambda: x)

    op = g.get_operation_by_name("cond/myop")
    self.assertIsNotNone(op)
    self.assertEqual(op.name, "cond/myop")
    self.assertEqual(op.type, "IntInput")
    self.assertEqual(op.outputs, [])
    # The external input `x` reaches the op through a Switch op, not directly.
    op_input = op.inputs[0].op
    self.assertEqual(op_input.type, "Switch")
    self.assertEqual(op_input.inputs[0], x)
    self.assertEqual(op.graph, g)
    # pylint: disable=protected-access
    self.assertIsNotNone(op._get_control_flow_context())
    self.assertEqual(op._get_control_flow_context().name,
                     "cond/cond_text")
    # pylint: enable=protected-access

  @test_util.run_v1_only("b/120545219")
  def testWhileLoop(self):
    """An op wrapped inside a while_loop body ends up in the loop's context."""
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()

      def body(i):
        # Create and wrap the C op from inside the loop body.
        ops._create_c_op(ops.get_default_graph(),
                         ops._NodeDef("IntInput", "myloop/myop"), [x], [])
        new_ops = g._add_new_tf_operations()
        self.assertEqual(len(new_ops), 1)
        return i

      control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

    op = g.get_operation_by_name("myloop/myop")
    self.assertIsNotNone(op)
    self.assertEqual(op.name, "myloop/myop")
    self.assertEqual(op.type, "IntInput")
    self.assertEqual(op.outputs, [])
    # The external input `x` reaches the op through an Enter op.
    op_input = op.inputs[0].op
    self.assertEqual(op_input.type, "Enter")
    self.assertEqual(list(op_input.inputs), [x])
    self.assertEqual(op.graph, g)
    # pylint: disable=protected-access
    self.assertIsNotNone(op._get_control_flow_context())
    self.assertEqual(op._get_control_flow_context().name,
                     "myloop/while_context")
    # pylint: enable=protected-access

  @test_util.run_v1_only("b/120545219")
  def testWhileLoopWithInternalControlDep(self):
    """A control dependency on an op inside the loop body is kept."""
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()

      def body(i):
        c = constant_op.constant(1.0, name="c")
        ops._create_c_op(ops.get_default_graph(),
                         ops._NodeDef("IntInput", "myloop/myop"), [x], [])
        # Wrapping under control_dependencies attaches `c` to the new op.
        with ops.control_dependencies([c]):
          new_ops = g._add_new_tf_operations()
          self.assertEqual(len(new_ops), 1)
        return i

      control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

    op = g.get_operation_by_name("myloop/myop")
    self.assertIsNotNone(op)
    c = g.get_operation_by_name("myloop/c")
    self.assertIsNotNone(c)
    # Internal control dep is preserved
    self.assertEqual(op.control_inputs, [c])

  @test_util.run_v1_only("b/120545219")
  def testWhileLoopWithExternalControlDep(self):
    """A control dependency on an op outside the loop is not used directly."""
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()
      c = constant_op.constant(1.0)

      def body(i):
        ops._create_c_op(ops.get_default_graph(),
                         ops._NodeDef("IntInput", "myloop/myop"), [x], [])
        # `c` was created outside the loop.
        with ops.control_dependencies([c]):
          new_ops = g._add_new_tf_operations()
          self.assertEqual(len(new_ops), 1)
        return i

      control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

    op = g.get_operation_by_name("myloop/myop")
    self.assertIsNotNone(op)
    # External control dep is removed and replaced with internal control dep
    self.assertNotEqual(op.control_inputs[0], c.op)
    self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context())
893
894
class ApplyOpTest(test_util.TensorFlowTestCase):
  """Tests for ops created through this module's `_apply_op` helper."""

  def testNodeDefArgs(self):
    """Checks return types, input strings, names and devices of applied ops."""
    g = ops.Graph()
    t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1")
    with g.device("/device:GPU:0"):
      t2 = _apply_op(
          g, "TwoIntOutputs", [], [dtypes.int32, dtypes.int32], name="myop2")
    t3 = _apply_op(
        g,
        "Foo1", [t1, t2[1], t2[0]], [dtypes.float32, dtypes.int32],
        name="myop3")
    # A single-output op comes back as a bare Tensor; multi-output ops as a
    # list of Tensors. assertIsInstance reports the actual type on failure.
    self.assertIsInstance(t1, ops.Tensor)
    self.assertIsInstance(t2, list)
    self.assertIsInstance(t3, list)
    self.assertIsInstance(t3[0], ops.Tensor)
    # Output 0 is referenced by the bare op name; other outputs as "name:idx".
    self.assertEqual("myop1", t1._as_node_def_input())
    self.assertEqual("myop2", t2[0]._as_node_def_input())
    self.assertEqual("myop2:1", t2[1]._as_node_def_input())
    self.assertEqual("myop3", t3[0]._as_node_def_input())
    # Validate that we got the right ops as well
    self.assertProtoEquals("name:'myop1' op:'FloatOutput'", t1.op.node_def)
    self.assertProtoEquals(
        "name:'myop2' op:'TwoIntOutputs' device:'/device:GPU:0'",
        t2[0].op.node_def)
    self.assertProtoEquals(
        "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo1'",
        t3[0].op.node_def)

  def testReferenceInput(self):
    """Ref-typed and plain outputs can feed both ref and non-ref inputs."""
    g = ops.Graph()
    ref_t, nonref_t = _apply_op(
        g, "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32],
        name="op1")
    self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'",
                           ref_t.op.node_def)
    # NOTE(mrry): Must specify input_types to preserve ref-typed input.
    out_2 = _apply_op(
        g,
        "RefInputFloatInputIntOutput", [ref_t, nonref_t], [dtypes.int32],
        input_types=[dtypes.float32_ref, dtypes.float32],
        name="op2")
    self.assertProtoEquals(
        "op:'RefInputFloatInputIntOutput' name:'op2' input:'op1' input:'op1:1'",
        out_2.op.node_def)
    # Without `input_types`, the same tensors also feed two plain float inputs.
    out_3 = _apply_op(
        g, "TwoFloatInputsIntOutput", [ref_t, nonref_t], [dtypes.int32],
        name="op3")
    self.assertProtoEquals(
        "op:'TwoFloatInputsIntOutput' name:'op3' input:'op1' input:'op1:1'",
        out_3.op.node_def)
946
947
class NameStackTest(test_util.TensorFlowTestCase):
  """Tests for `Graph.unique_name` and its interaction with name scopes."""

  def testBasics(self):
    """Exercises uniquification across nested, reset and reopened scopes.

    A call with `mark_as_used=False` must not reserve the returned name, so the
    following call for the same base name returns the same result.
    """
    g = ops.Graph()
    self.assertEqual("foo", g.unique_name("foo", mark_as_used=False))
    self.assertEqual("foo", g.unique_name("foo", mark_as_used=False))
    self.assertEqual("foo", g.unique_name("foo"))
    self.assertEqual("foo_1", g.unique_name("foo", mark_as_used=False))
    self.assertEqual("foo_1", g.unique_name("foo"))
    self.assertEqual("foo_2", g.unique_name("foo", mark_as_used=False))
    self.assertEqual("foo_2", g.unique_name("foo"))
    # Already-suffixed names get their own suffix counters.
    self.assertEqual("foo_1_1", g.unique_name("foo_1", mark_as_used=False))
    self.assertEqual("foo_1_1", g.unique_name("foo_1"))
    self.assertEqual("foo_1_2", g.unique_name("foo_1", mark_as_used=False))
    self.assertEqual("foo_1_2", g.unique_name("foo_1"))
    self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2", mark_as_used=False))
    self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2"))
    with g.name_scope("bar"):
      self.assertEqual("bar/foo", g.unique_name("foo", mark_as_used=False))
      self.assertEqual("bar/foo", g.unique_name("foo"))
      self.assertEqual("bar/foo_1", g.unique_name("foo", mark_as_used=False))
      self.assertEqual("bar/foo_1", g.unique_name("foo"))
      # A `None` scope resets to the root, continuing the root-level counter.
      with g.name_scope(None):
        self.assertEqual("foo_3", g.unique_name("foo", mark_as_used=False))
        self.assertEqual("foo_3", g.unique_name("foo"))
      with g.name_scope("baz"):
        self.assertEqual(
            "bar/baz/foo", g.unique_name(
                "foo", mark_as_used=False))
        self.assertEqual("bar/baz/foo", g.unique_name("foo"))
        self.assertEqual(
            "bar/baz/foo_1", g.unique_name(
                "foo", mark_as_used=False))
        self.assertEqual("bar/baz/foo_1", g.unique_name("foo"))
      # Reopening "baz" under "bar" yields a uniquified scope, "bar/baz_1".
      with g.name_scope("baz"):
        self.assertEqual(
            "bar/baz_1/foo", g.unique_name(
                "foo", mark_as_used=False))
        self.assertEqual("bar/baz_1/foo", g.unique_name("foo"))
        self.assertEqual(
            "bar/baz_1/foo_1", g.unique_name(
                "foo", mark_as_used=False))
        self.assertEqual("bar/baz_1/foo_1", g.unique_name("foo"))
    with g.name_scope("quux"):
      self.assertEqual("quux/foo", g.unique_name("foo", mark_as_used=False))
      self.assertEqual("quux/foo", g.unique_name("foo"))
    # Reopening "bar" at the root likewise yields "bar_1".
    with g.name_scope("bar"):
      with g.name_scope("baz"):
        self.assertEqual(
            "bar_1/baz/foo", g.unique_name(
                "foo", mark_as_used=False))
        self.assertEqual("bar_1/baz/foo", g.unique_name("foo"))
    self.assertEqual("foo_4", g.unique_name("foo", mark_as_used=False))
    self.assertEqual("foo_4", g.unique_name("foo"))
    # Scope names count as used names: "bar" and "bar_1" are already taken.
    self.assertEqual("bar_2", g.unique_name("bar", mark_as_used=False))
    self.assertEqual("bar_2", g.unique_name("bar"))

  @test_util.run_deprecated_v1
  def testNameAndVariableScope(self):
    """A graph name scope opened in a variable scope extends its prefix."""
    with self.cached_session() as sess:
      with sess.graph.name_scope("l0"):
        # Per the assertions below, variable_scope("l1") also contributes the
        # "l1/" name-scope component under "l0/".
        with variable_scope.variable_scope("l1"):
          with sess.graph.name_scope("l1") as scope:
            self.assertEqual("l0/l1/l1/", scope)
            self.assertEqual(
                "l0/l1/l1/foo",
                sess.graph.unique_name(
                    "foo", mark_as_used=False))
            self.assertEqual("l0/l1/l1/foo", sess.graph.unique_name("foo"))
          with sess.graph.name_scope("l2") as scope:
            self.assertEqual("l0/l1/l2/", scope)
            self.assertEqual(
                "l0/l1/l2/foo",
                sess.graph.unique_name(
                    "foo", mark_as_used=False))
            self.assertEqual("l0/l1/l2/foo", sess.graph.unique_name("foo"))

  def testOutOfOrderUniqueName(self):
    """A suffixed name reserved up front is skipped by later suffixing."""
    g = ops.Graph()
    self.assertEqual("foo_2", g.unique_name("foo_2"))
    self.assertEqual("foo", g.unique_name("foo"))
    self.assertEqual("foo_1", g.unique_name("foo"))
    # "foo_2" was reserved first, so the counter jumps to "foo_3".
    self.assertEqual("foo_3", g.unique_name("foo"))

  def testUniqueNameCaseInsensitivity(self):
    """Names differing only in case are treated as colliding."""
    g = ops.Graph()
    self.assertEqual("foo", g.unique_name("foo"))
    self.assertEqual("Foo_1", g.unique_name("Foo"))
    with g.name_scope("bar"):
      self.assertEqual("bar/foo", g.unique_name("foo"))
    with g.name_scope("Bar"):
      self.assertEqual("Bar_1/foo", g.unique_name("foo"))

  def testInvalidNameRaisesError(self):
    """`name_scope` raises ValueError for invalid scope names."""
    g = ops.Graph()
    with g.name_scope(""):  # Should not raise
      pass
    with g.name_scope("foo/"):  # Should not raise
      # A leading underscore is accepted for a nested scope...
      with g.name_scope("_bar"):  # Should not raise
        pass
    with self.assertRaises(ValueError):
      with g.name_scope("foo:0"):
        pass
    # ...but rejected for a root-level scope.
    with self.assertRaises(ValueError):
      with g.name_scope("_bar"):
        pass
1054
1055
class NameTest(test_util.TensorFlowTestCase):
  """Tests for op and tensor names generated by `Graph.create_op`."""

  def testGenerateName(self):
    """Default op names come from the op type; outputs are "name:idx"."""
    g = ops.Graph()
    op0 = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32])
    self.assertEqual("TwoFloatOutputs", op0.name)
    self.assertEqual("TwoFloatOutputs:0", op0.outputs[0].name)
    self.assertEqual("TwoFloatOutputs:1", op0.outputs[1].name)

    op1 = g.create_op("FloatOutput", [], [dtypes.float32])
    self.assertEqual("FloatOutput", op1.name)
    self.assertEqual("FloatOutput:0", op1.outputs[0].name)

    # A second op of the same type gets a "_1" suffix.
    op2 = g.create_op("FloatOutput", [], [dtypes.float32])
    self.assertEqual("FloatOutput_1", op2.name)
    self.assertEqual("FloatOutput_1:0", op2.outputs[0].name)

    # An explicit `name` is used instead of the op type.
    op3 = g.create_op("FloatOutput", [], [dtypes.float32], name="my_op")
    self.assertEqual("my_op", op3.name)
    self.assertEqual("my_op:0", op3.outputs[0].name)

  def testNameScope(self):
    """Name scopes prefix generated op names; "/"-terminated names are kept."""
    g = ops.Graph()

    # The value bound by "with ... as" is the full scope with a trailing "/";
    # None and "" both yield the root scope "".
    with g.name_scope("foo") as foo:
      self.assertEqual("foo/", foo)
      with g.name_scope("foo2") as foo2:
        self.assertEqual("foo/foo2/", foo2)
      with g.name_scope(None) as empty1:
        self.assertEqual("", empty1)
        with g.name_scope("foo3") as foo3:
          self.assertEqual("foo3/", foo3)
      with g.name_scope("") as empty2:
        self.assertEqual("", empty2)

    self.assertEqual("FloatOutput",
                     g.create_op("FloatOutput", [], [dtypes.float32]).name)
    with g.name_scope("bar") as scope:
      self.assertEqual("bar/FloatOutput",
                       g.create_op("FloatOutput", [], [dtypes.float32]).name)
      self.assertEqual("bar/FloatOutput_1",
                       g.create_op("FloatOutput", [], [dtypes.float32]).name)
      # If you use the value from "with .. as", that value is used as-is
      # (without its trailing "/").
      self.assertEqual(
          "bar", g.create_op(
              "FloatOutput", [], [dtypes.float32], name=scope).name)
    with g.name_scope("baz") as scope:
      with g.name_scope("quux"):
        self.assertEqual("baz/quux/FloatOutput",
                         g.create_op("FloatOutput", [], [dtypes.float32]).name)
      # If you use the value from the enclosing "with .. as", nothing is pushed.
      with g.name_scope(scope):
        self.assertEqual("baz/FloatOutput",
                         g.create_op("FloatOutput", [], [dtypes.float32]).name)
        self.assertEqual(
            "baz", g.create_op(
                "FloatOutput", [], [dtypes.float32], name=scope).name)
        # Any name ending in "/" is taken as-is, ignoring the current scope.
        self.assertEqual(
            "trailing",
            g.create_op(
                "FloatOutput", [], [dtypes.float32], name="trailing/").name)
    # Re-entering "bar" by plain name creates a uniquified scope "bar_1".
    with g.name_scope("bar"):
      self.assertEqual("bar_1/FloatOutput",
                       g.create_op("FloatOutput", [], [dtypes.float32]).name)
    # "bar/" (with trailing slash) re-enters the existing "bar" scope instead.
    with g.name_scope("bar/"):
      self.assertEqual("bar/FloatOutput_2",
                       g.create_op("FloatOutput", [], [dtypes.float32]).name)
1123
1124
class DeviceTest(test_util.TensorFlowTestCase):
  """Tests for device assignment via `Graph.device` / `ops.device` scopes.

  Most tests create ops in nested device scopes and compare the resulting
  GraphDef, so node names depend on the order in which ops are created.
  """

  def testNoDevice(self):
    """Without a device scope, ops carry no device."""
    g = ops.Graph()
    op = g.create_op("FloatOutput", [], [dtypes.float32])
    self.assertDeviceEqual(None, op.device)
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput" }
    """, gd)

  def testEagerBackingDevice(self):
    """In eager mode, a tensor's device and backing device follow the scope."""
    with context.eager_mode():
      with ops.device("/device:CPU:0"):
        t = constant_op.constant(1.0)
        # NOTE(review): assertRegexpMatches is a deprecated alias of
        # assertRegex (removed in Python 3.12); switch once Python 2
        # compatibility is no longer required.
        self.assertRegexpMatches(t.device, "/device:CPU:0")
        self.assertRegexpMatches(t.backing_device, "/device:CPU:0")

  def testDevicePartialString(self):
    """A partial device string is recorded unchanged."""
    g = ops.Graph()
    with g.device("/job:worker/replica:2"):
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:worker/replica:2" }
    """, gd)

  def testDeviceFull(self):
    """A fully specified `DeviceSpec` is rendered as a full device string."""
    g = ops.Graph()
    with g.device(
        pydev.DeviceSpec(
            job="worker", replica=2, task=0, device_type="CPU",
            device_index=3)):
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:worker/replica:2/task:0/device:CPU:3" }
    """, gd)

  def testNesting(self):
    """An inner scope overrides the outer one; the outer is restored after."""
    g = ops.Graph()
    with g.device("/job:worker/replica:2"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/job:worker/replica:3/task:0"):
        g.create_op("FloatOutput", [], [dtypes.float32])
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:worker/replica:2" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:worker/replica:3/task:0" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/replica:2" }
    """, gd)

  def testNestingString(self):
    """Nested device strings; see testNesting."""
    # NOTE(review): this body is identical to testNesting. Presumably one of
    # the two was meant to use `pydev.DeviceSpec` objects — confirm and either
    # differentiate or remove the duplicate.
    g = ops.Graph()
    with g.device("/job:worker/replica:2"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/job:worker/replica:3/task:0"):
        g.create_op("FloatOutput", [], [dtypes.float32])
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:worker/replica:2" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:worker/replica:3/task:0" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/replica:2" }
    """, gd)

  def testNestingOverrideGpuCpu(self):
    """An inner GPU device overrides an outer CPU device, then reverts."""
    g = ops.Graph()
    with g.device("/job:worker/replica:2/device:CPU:1"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/job:worker/replica:2/device:GPU:2"):
        g.create_op("FloatOutput", [], [dtypes.float32])
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:worker/replica:2/device:CPU:1"  }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:worker/replica:2/device:GPU:2" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/replica:2/device:CPU:1" }
    """, gd)

  def testNestingWithMergeDeviceFunction(self):
    """`merge_device` overrides only the fields it specifies."""
    g = ops.Graph()

    with g.device(pydev.merge_device("/device:GPU:0")):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(pydev.merge_device("/job:worker")):
        g.create_op("FloatOutput", [], [dtypes.float32])
        with g.device(pydev.merge_device("/device:CPU:0")):
          g.create_op("FloatOutput", [], [dtypes.float32])
          with g.device(pydev.merge_device("/job:ps")):
            g.create_op("FloatOutput", [], [dtypes.float32])
            # merge_device(None) leaves the enclosing device unchanged.
            with g.device(pydev.merge_device(None)):
              g.create_op("FloatOutput", [], [dtypes.float32])

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/device:GPU:0" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:worker/device:GPU:0" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/device:CPU:0" }
      node { name: "FloatOutput_3" op: "FloatOutput"
             device: "/job:ps/device:CPU:0" }
      node { name: "FloatOutput_4" op: "FloatOutput"
             device: "/job:ps/device:CPU:0" }
    """, gd)

  def testNestingWithDeviceStrings(self):
    """Plain device strings merge like `merge_device`; "" inherits."""
    g = ops.Graph()

    with g.device("/device:GPU:0"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/job:worker"):
        g.create_op("FloatOutput", [], [dtypes.float32])
        with g.device("/device:CPU:0"):
          g.create_op("FloatOutput", [], [dtypes.float32])
          with g.device("/job:ps"):
            g.create_op("FloatOutput", [], [dtypes.float32])
            with g.device(""):
              g.create_op("FloatOutput", [], [dtypes.float32])

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/device:GPU:0" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:worker/device:GPU:0" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/device:CPU:0" }
      node { name: "FloatOutput_3" op: "FloatOutput"
             device: "/job:ps/device:CPU:0" }
      node { name: "FloatOutput_4" op: "FloatOutput"
             device: "/job:ps/device:CPU:0" }
    """, gd)

  def testNestingWithDeviceStringWildcard(self):
    """A "*" index never overrides a concrete one, but is overridden by one."""
    g = ops.Graph()

    with g.device("/device:GPU:7"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/device:GPU:*"):
        g.create_op("FloatOutput", [], [dtypes.float32])

    with g.device("/device:CPU:*"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/device:CPU:5"):
        g.create_op("FloatOutput", [], [dtypes.float32])

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/device:GPU:7" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/device:GPU:7" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/device:CPU:*" }
      node { name: "FloatOutput_3" op: "FloatOutput"
             device: "/device:CPU:5" }
    """, gd)

  def testNoneClearsDefault(self):
    """`device(None)` clears the enclosing device for ops inside it."""
    g = ops.Graph()
    with g.device("/job:worker/replica:2/device:CPU:1"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(None):
        g.create_op("FloatOutput", [], [dtypes.float32])
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:worker/replica:2/device:CPU:1" }
      node { name: "FloatOutput_1" op: "FloatOutput" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/replica:2/device:CPU:1" }
    """, gd)

  def testNoneIgnoresOuterDeviceFunction(self):
    """`device(None)` also disables an enclosing device function."""
    g = ops.Graph()
    with g.device(lambda op: "/job:worker/replica:2/device:CPU:1"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(None):
        g.create_op("FloatOutput", [], [dtypes.float32])
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:worker/replica:2/device:CPU:1" }
      node { name: "FloatOutput_1" op: "FloatOutput" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/replica:2/device:CPU:1" }
    """, gd)

  def _overwritingDeviceFunction(self, unused_op):
    # This device function unconditionally overwrites the device of ops.
    #
    # NOTE(mrry): Writing device functions like this is not
    # recommended. Instead, in most cases you should use
    # `pydev.merge_device("/job:ps")` or simply `"/job:ps"` as the
    # argument to `tf.device()` and the device component will be merged in.
    return "/job:overwrite"

  def testOverwritingBehavior(self):
    """An outer overwriting device function wins unless reset by None."""
    g = ops.Graph()
    with g.device(self._overwritingDeviceFunction):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/job:ps"):  # Will be overwritten.
        g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(pydev.merge_device("/job:ps")):  # Will be overwritten.
        g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(None):  # Disables overwriting device function
        with g.device("/job:ps"):
          g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(None):  # Disables overwriting device function
        with g.device(pydev.merge_device("/job:ps")):
          g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput" op: "FloatOutput"
             device: "/job:overwrite" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:overwrite" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:overwrite" }
      node { name: "FloatOutput_3" op: "FloatOutput"
             device: "/job:ps" }
      node { name: "FloatOutput_4" op: "FloatOutput"
             device: "/job:ps" }
    """, gd)
1366
1367
class MultithreadedGraphStateTest(test_util.TensorFlowTestCase):
  """Checks that graph stacks pushed in one thread do not leak into others.

  Every test runs three threads in lockstep (see `_runThreadsInLockstep`).
  Each thread pushes some graph state (device, colocation, control dependency
  or name scope) and then pauses, so all threads have pushed their state
  before any of them creates its op. The expected GraphDef/op names only come
  out if each thread sees its own stack.
  """

  class TestThread(threading.Thread):
    """Base class for a thread that mutates a graph in two phases.

    Args:
      graph: the `ops.Graph` to mutate, available to subclasses as
        `self._graph`.
      replica_id: integer that subclasses use to build device strings and op
        names.
    """

    def __init__(self, graph, replica_id):
      super(MultithreadedGraphStateTest.TestThread, self).__init__()
      self._graph = graph
      self._replica_id = replica_id
      # This thread sets this event when it mutated the graph.  The caller can
      # wait for that.
      self.has_mutated_graph = threading.Event()
      # This thread waits for when it should continue.  The caller can set this
      # event.
      self.should_continue = threading.Event()

    def run(self):
      # Mutate a graph's stack, then set `has_mutated_graph`, then wait for
      # `should_continue`, then add an op to the graph affected by the graph's
      # stack.
      raise NotImplementedError("must be implemented in descendants")

  def _runThreadsInLockstep(self, threads):
    """Starts `threads` one at a time, then releases and joins them in order.

    Each thread must signal `has_mutated_graph` before the next one is
    started, so all threads hold their pushed graph state simultaneously by
    the time they are released through `should_continue`.

    Args:
      threads: `TestThread` instances that have not been started yet.
    """
    for t in threads:
      t.start()
      t.has_mutated_graph.wait()
      t.has_mutated_graph.clear()
    for t in threads:
      t.should_continue.set()
      t.join()

  def testDeviceFunctionStack(self):
    """Device scopes entered in different threads stay separate."""

    class DeviceSettingThread(self.TestThread):

      def run(self):
        with self._graph.device(
            "/job:worker/replica:{}".format(self._replica_id)):
          self.has_mutated_graph.set()
          self.should_continue.wait()
          self.should_continue.clear()
          self._graph.create_op(
              "FloatOutput", [], [dtypes.float32],
              name="FloatOutput_{}".format(self._replica_id))

    g = ops.Graph()
    # If `switch_to_thread_local()` isn't called, then device placement of the
    # ops below is not deterministic.
    g.switch_to_thread_local()
    self._runThreadsInLockstep([DeviceSettingThread(g, i) for i in range(3)])

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput_0" op: "FloatOutput"
             device: "/job:worker/replica:0" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:worker/replica:1" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/replica:2" }
    """, gd)

  def testColocateWith(self):
    """Colocation scopes entered in different threads stay separate."""

    class ColocatingThread(self.TestThread):

      def __init__(self, graph, replica_id, op_to_colocate_with):
        super(ColocatingThread, self).__init__(graph, replica_id)
        self._op_to_colocate_with = op_to_colocate_with

      def run(self):
        with self._graph.colocate_with(self._op_to_colocate_with):
          self.has_mutated_graph.set()
          self.should_continue.wait()
          self.should_continue.clear()
          self._graph.create_op(
              "FloatOutput", [], [dtypes.float32],
              name="FloatOutput_{}".format(self._replica_id))

    g = ops.Graph()
    ops_to_colocate_with = []
    for i in range(3):
      with g.device("/job:worker/replica:{}".format(i)):
        ops_to_colocate_with.append(
            g.create_op(
                "FloatOutput", [], [dtypes.float32],
                name="ColocateWithMe_{}".format(i)))

    # If `switch_to_thread_local()` isn't called, then `device` and `attr`
    # values for the ops below are not deterministic.
    g.switch_to_thread_local()
    self._runThreadsInLockstep([
        ColocatingThread(g, i, ops_to_colocate_with[i]) for i in range(3)
    ])

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "ColocateWithMe_0" op: "FloatOutput"
             device: "/job:worker/replica:0" }
      node { name: "ColocateWithMe_1" op: "FloatOutput"
             device: "/job:worker/replica:1" }
      node { name: "ColocateWithMe_2" op: "FloatOutput"
             device: "/job:worker/replica:2" }
      node { name: "FloatOutput_0" op: "FloatOutput"
             device: "/job:worker/replica:0"
             attr { key: "_class"
               value { list {
                 s: "loc:@ColocateWithMe_0"}}}}
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:worker/replica:1"
             attr { key: "_class"
               value { list {
                 s: "loc:@ColocateWithMe_1"}}}}
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/replica:2"
             attr { key: "_class"
               value { list {
                 s: "loc:@ColocateWithMe_2"}}}}
    """, gd)

  def testControlDependencies(self):
    """Control-dependency scopes entered in different threads stay separate."""

    class DependingThread(self.TestThread):

      def __init__(self, graph, replica_id, dependency_op):
        super(DependingThread, self).__init__(graph, replica_id)
        self._dependency_op = dependency_op

      def run(self):
        with self._graph.control_dependencies([self._dependency_op]):
          self.has_mutated_graph.set()
          self.should_continue.wait()
          self.should_continue.clear()
          self._graph.create_op(
              "FloatOutput", [], [dtypes.float32],
              name="FloatOutput_{}".format(self._replica_id))

    g = ops.Graph()
    # These ops are only control-dependency targets (no colocation involved).
    dependency_ops = [
        g.create_op(
            "FloatOutput", [], [dtypes.float32],
            name="DependOnMe_{}".format(i)) for i in range(3)
    ]

    # If `switch_to_thread_local()` isn't called, then `input` values for the
    # ops below are not deterministic.
    g.switch_to_thread_local()
    self._runThreadsInLockstep(
        [DependingThread(g, i, dependency_ops[i]) for i in range(3)])

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "DependOnMe_0" op: "FloatOutput" }
      node { name: "DependOnMe_1" op: "FloatOutput" }
      node { name: "DependOnMe_2" op: "FloatOutput" }
      node { name: "FloatOutput_0" op: "FloatOutput"
             input: "^DependOnMe_0" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             input: "^DependOnMe_1" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             input: "^DependOnMe_2" }
    """, gd)

  def testNameStack(self):
    """Each thread keeps its own name scope across the pause."""

    class NameSettingThread(self.TestThread):

      def run(self):
        with self._graph.name_scope("foo"):
          op1 = self._graph.create_op("FloatOutput", [], [dtypes.float32])
          self.has_mutated_graph.set()
          self.should_continue.wait()
          self.should_continue.clear()
          op2 = self._graph.create_op("FloatOutput", [], [dtypes.float32])
          self.result = (op1, op2)

    g = ops.Graph()
    # No `switch_to_thread_local()` here: the expectations below rely on each
    # thread keeping its own name scope without it.
    threads = [NameSettingThread(g, i) for i in range(3)]
    self._runThreadsInLockstep(threads)

    # Threads enter "foo" in start order, so each gets the next scope suffix.
    suffixes = ["", "_1", "_2"]
    for t, s in zip(threads, suffixes):
      self.assertEqual("foo" + s + "/FloatOutput", t.result[0].name)
      self.assertEqual("foo" + s + "/FloatOutput_1", t.result[1].name)
1568
1569
class ObjectWithName(object):
  """Minimal stand-in object that only carries a read-only `name`.

  Collection tests use it to exercise `get_collection`'s scope filtering,
  which (as those tests show) matches against an item's `name`.
  """

  def __init__(self, name):
    self._name = name

  # Read-only accessor for the name given to the constructor.
  name = property(lambda self: self._name)
1578
1579
class CollectionTest(test_util.TensorFlowTestCase):
  """Tests for graph collections (`add_to_collection(s)`, `get_collection`)."""

  def test_get_collections(self):
    """`collections` and `get_all_collection_keys` list every key in use."""
    g = ops.Graph()
    self.assertSequenceEqual(g.collections, [])
    g.add_to_collection("key", 12)
    g.add_to_collection("key", 15)
    self.assertSequenceEqual(g.collections, ["key"])
    g.add_to_collection("other", "foo")
    self.assertSequenceEqual(sorted(g.collections), ["key", "other"])
    self.assertSequenceEqual(
        sorted(g.get_all_collection_keys()), ["key", "other"])

  def test_add_to_collection(self):
    """Covers ordering, scope filtering and snapshot-vs-reference access."""
    g = ops.Graph()
    g.add_to_collection("key", 12)
    g.add_to_collection("other", "foo")
    g.add_to_collection("key", 34)

    # "blah" holds three items, but only blank1 matches the scope filters
    # used below: 27 has no name and "junk/foo" matches neither "prefix" nor
    # the regex-style ".*x".
    g.add_to_collection("blah", 27)
    blank1 = ObjectWithName("prefix/foo")
    g.add_to_collection("blah", blank1)
    blank2 = ObjectWithName("junk/foo")
    g.add_to_collection("blah", blank2)

    self.assertEqual([12, 34], g.get_collection("key"))
    self.assertEqual([], g.get_collection("nothing"))
    self.assertEqual([27, blank1, blank2], g.get_collection("blah"))
    self.assertEqual([blank1], g.get_collection("blah", "prefix"))
    self.assertEqual([blank1], g.get_collection("blah", ".*x"))

    # Make sure that get_collection() returns a first-level
    # copy of the collection, while get_collection_ref() returns
    # the original list.
    other_collection_snapshot = g.get_collection("other")
    other_collection_ref = g.get_collection_ref("other")
    self.assertEqual(["foo"], other_collection_snapshot)
    self.assertEqual(["foo"], other_collection_ref)
    g.add_to_collection("other", "bar")
    self.assertEqual(["foo"], other_collection_snapshot)
    self.assertEqual(["foo", "bar"], other_collection_ref)
    self.assertEqual(["foo", "bar"], g.get_collection("other"))
    self.assertIs(other_collection_ref, g.get_collection_ref("other"))

    # Verify that getting an empty collection ref returns a modifiable list.
    empty_coll_ref = g.get_collection_ref("empty")
    self.assertEqual([], empty_coll_ref)
    empty_coll = g.get_collection("empty")
    self.assertEqual([], empty_coll)
    self.assertIsNot(empty_coll, empty_coll_ref)
    empty_coll_ref2 = g.get_collection_ref("empty")
    self.assertIs(empty_coll_ref2, empty_coll_ref)
    # Add to the collection.
    empty_coll_ref.append("something")
    self.assertEqual(["something"], empty_coll_ref)
    self.assertEqual(["something"], empty_coll_ref2)
    self.assertEqual([], empty_coll)
    self.assertEqual(["something"], g.get_collection("empty"))
    empty_coll_ref3 = g.get_collection_ref("empty")
    self.assertIs(empty_coll_ref3, empty_coll_ref)

  def test_add_to_collections_uniquify(self):
    """Duplicate collection names receive the value only once."""
    g = ops.Graph()
    g.add_to_collections([1, 2, 1], "key")
    # Make sure "key" is not added twice
    self.assertEqual(["key"], g.get_collection(1))

  def test_add_to_collections_from_list(self):
    g = ops.Graph()
    g.add_to_collections(["abc", "123"], "key")
    self.assertEqual(["key"], g.get_collection("abc"))
    self.assertEqual(["key"], g.get_collection("123"))

  def test_add_to_collections_from_tuple(self):
    g = ops.Graph()
    g.add_to_collections(("abc", "123"), "key")
    self.assertEqual(["key"], g.get_collection("abc"))
    self.assertEqual(["key"], g.get_collection("123"))

  def test_add_to_collections_from_generator(self):
    g = ops.Graph()

    def generator():
      yield "abc"
      yield "123"

    g.add_to_collections(generator(), "key")
    self.assertEqual(["key"], g.get_collection("abc"))
    self.assertEqual(["key"], g.get_collection("123"))

  def test_add_to_collections_from_set(self):
    g = ops.Graph()
    g.add_to_collections({"abc", "123"}, "key")
    self.assertEqual(["key"], g.get_collection("abc"))
    self.assertEqual(["key"], g.get_collection("123"))

  def test_add_to_collections_from_string(self):
    """A bare string names one collection; it is not iterated per char."""
    g = ops.Graph()
    g.add_to_collections("abc", "key")
    self.assertEqual(["key"], g.get_collection("abc"))
    self.assertEqual([], g.get_collection("a"))

  def test_default_graph(self):
    """Module-level helpers operate on the current default graph."""
    with ops.Graph().as_default():
      ops.add_to_collection("key", 90)
      ops.add_to_collection("key", 100)
      # Collections are ordered.
      self.assertEqual([90, 100], ops.get_collection("key"))

  def test_defun(self):
    """Nested defuns see outer collections; their additions stay local."""
    with context.eager_mode():

      @eager_function.defun
      def defun():
        ops.add_to_collection("int", 1)
        ops.add_to_collection("tensor", constant_op.constant(2))

        @eager_function.defun
        def inner_defun():
          self.assertEqual(ops.get_collection("int"), [1])
          three = ops.get_collection("tensor")[0] + ops.get_collection("int")[0]
          ops.add_to_collection("int", 2)
          self.assertEqual(ops.get_collection("int"), [1, 2])
          ops.add_to_collection("foo", "bar")
          self.assertEqual(ops.get_collection("foo"), ["bar"])
          return three

        self.assertEqual(ops.get_collection("int"), [1])
        three = inner_defun()
        # Additions made inside inner_defun are not visible out here.
        self.assertEqual(ops.get_collection("int"), [1])
        self.assertEqual(ops.get_collection("foo"), [])
        return three

      three = defun()
      self.assertEqual(three.numpy(), 3)
1715
1716
# The test op "FloatOutput" is created with no inputs throughout this file;
# register it as not differentiable.
ops.NotDifferentiable("FloatOutput")
1718
1719
@ops.RegisterGradient("CopyOp")
def _CopyGrad(op, x_grad):  # pylint: disable=invalid-name
  """Gradient for "CopyOp": passes the incoming gradient through unchanged."""
  del op  # Unused; the gradient does not depend on the forward op.
  return x_grad
1724
1725
@ops.RegisterGradient("copy_override")
def _CopyOverrideGrad(op, x_grad):  # pylint: disable=invalid-name
  """Gradient registered as "copy_override"; returns x_grad unchanged."""
  del op  # Unused; the gradient does not depend on the forward op.
  return x_grad
1730
1731
class RegistrationTest(test_util.TensorFlowTestCase):
  """Tests gradient-function lookup and `gradient_override_map`."""

  @test_util.run_deprecated_v1
  def testRegisterGradients(self):
    # Without an override, CopyOp resolves to the function registered for it.
    copied = test_ops.copy_op(test_ops.float_output())
    self.assertEqual(_CopyGrad, ops.get_gradient_function(copied.op))

  def testOverrideGradients(self):
    graph = ops.Graph()
    with graph.as_default():
      source = test_ops.float_output()
      # Ops built inside the override map resolve to "copy_override".
      with graph.gradient_override_map({"CopyOp": "copy_override"}):
        copied = test_ops.copy_op(source)
      self.assertEqual(_CopyOverrideGrad, ops.get_gradient_function(copied.op))

  def testNonExistentOverride(self):
    graph = ops.Graph()
    with graph.as_default():
      source = test_ops.float_output()
      with graph.gradient_override_map({"CopyOp": "unknown_override"}):
        copied = test_ops.copy_op(source)
      # Building the op succeeds; the unknown name only fails at lookup time.
      with self.assertRaisesRegexp(LookupError, "unknown_override"):
        ops.get_gradient_function(copied.op)
1758
1759
class ComparisonTest(test_util.TensorFlowTestCase):
  """Tests comparison/membership behavior of graph tensors."""

  def testMembershipAllowed(self):
    """Graph tensors can be used with `in` / `not in` on lists of tensors."""
    g = ops.Graph()
    t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1")
    t2 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop2")
    # Dedicated assertions give informative failure messages, unlike
    # assertTrue(isinstance(...)) / assertTrue(x in y).
    self.assertIsInstance(t1, ops.Tensor)
    self.assertIsInstance(t2, ops.Tensor)
    self.assertIn(t1, [t1])
    self.assertNotIn(t1, [t2])
1770
1771
class ControlDependenciesTest(test_util.TensorFlowTestCase):
  """Tests for `control_dependencies` scoping and redundant-input pruning."""

  @test_util.run_deprecated_v1
  def testBasic(self):
    """Ops created in a block gain its control inputs unless redundant."""
    g = ops.Graph()
    with g.as_default():
      # Creating unregistered ops with _apply_op() doesn't work with the C API
      # TODO(skyewm): address this more consistently. Possible solutions are
      # to use registered ops in all tests, create a way to register ops in
      # Python tests, or conditionally disable the op registration check in
      # the C API.
      a = constant_op.constant(1.0)
      b = constant_op.constant(1.0)
      with g.control_dependencies([a]):
        c = constant_op.constant(1.0)
        d = array_ops.identity(b)
        e = array_ops.identity(c)

    self.assertEqual(c.op.control_inputs, [a.op])
    self.assertEqual(d.op.control_inputs, [a.op])
    # e should be dominated by c: its input c was created inside the same
    # block (and already depends on a), so e gets no extra control input.
    self.assertEqual(e.op.control_inputs, [])

  @test_util.run_in_graph_and_eager_modes
  def testEager(self):
    """Callables are accepted; in eager mode the callable is called once."""
    def future():
      future.calls += 1
      return constant_op.constant(2.0)
    future.calls = 0

    if context.executing_eagerly():
      a = constant_op.constant(1.0)
      b = future
      with ops.control_dependencies([a, b]):
        c = constant_op.constant(3.0)
      self.assertEqual(future.calls, 1)
    else:
      g = ops.Graph()
      with g.as_default():
        a = constant_op.constant(1.0)
        b = future()
        with g.control_dependencies([a, b]):
          c = constant_op.constant(3.0)
      self.assertEqual(c.op.control_inputs, [a.op, b.op])
      self.assertEqual(future.calls, 1)

  def testBasicWithConversion(self):
    """Objects with `_as_graph_element` are converted to that element."""
    g = ops.Graph()
    a = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    class ConvertibleObj(object):

      def _as_graph_element(self):
        return a

    with g.control_dependencies([ConvertibleObj()]):
      c = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    self.assertEqual(c.op.control_inputs, [a.op])

  def testNested(self):
    """Nested blocks accumulate the same inputs as one flat block."""
    g = ops.Graph()
    a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    with g.control_dependencies([a_1, a_2, a_3, a_4]):
      b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    with g.control_dependencies([a_1]):
      with g.control_dependencies([a_2]):
        with g.control_dependencies([a_3]):
          with g.control_dependencies([a_4]):
            b_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    self.assertItemsEqual([a_1.op, a_2.op, a_3.op, a_4.op],
                          b_1.op.control_inputs)
    self.assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs)

  def testClear(self):
    """`control_dependencies(None)` drops outer inputs until it exits."""
    g = ops.Graph()
    a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    with g.control_dependencies([a_1]):
      with g.control_dependencies([a_2]):
        with g.control_dependencies(None):
          with g.control_dependencies([a_3]):
            with g.control_dependencies([a_4]):
              # deps [a_3, a_4]
              b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
            # deps = [a_3]
            b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
          # deps back to None
          b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32])
        # deps back to [a_1, a_2]
        b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
      # deps back to [a_1]
      b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
      with g.control_dependencies(None):
        # deps are None again
        b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs)
    self.assertItemsEqual([a_3.op], b_3.op.control_inputs)
    self.assertItemsEqual([], b_none.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs)
    self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
    self.assertItemsEqual([], b_none2.op.control_inputs)

  def testComplex(self):
    """Pruning of redundant control inputs across four nested blocks."""
    g = ops.Graph()

    # Usage pattern:
    # * Nodes a_i are FloatOutput ops (no inputs) defined at the outermost
    #   scope, and are used as control inputs for the ith nested scope.
    # * Nodes b_i are defined as TwoFloatInputsFloatOutput(a_3, a_4) at each
    #   scope.
    # * Nodes c_i are defined as TwoFloatInputsFloatOutput(a_1, b_1) at each
    #   scope.
    # * Nodes d_i are defined as TwoFloatInputsFloatOutput(b_i, c_i) at each
    #   scope.
    # * Node e_1 is a FloatOutput op; e_i is
    #   TwoFloatInputsFloatOutput(e_i-1, e_i-1) at each scope i > 1.
    #
    # As the assertions below show, a block's control input is omitted when
    # the op already consumes it as a data input (a_3 for b_i, a_1 for c_i),
    # or when one of the op's data inputs was itself created inside that
    # block (so d_i gets none, and e_i only gets a_i).

    a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    with g.control_dependencies([a_1]):
      b_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                      [dtypes.float32])
      c_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                      [dtypes.float32])
      d_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_1, c_1],
                      [dtypes.float32])
      e_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
      with g.control_dependencies([a_2]):
        b_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                        [dtypes.float32])
        c_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                        [dtypes.float32])
        d_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_2, c_2],
                        [dtypes.float32])
        e_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_1, e_1],
                        [dtypes.float32])
        with g.control_dependencies([a_3]):
          b_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                          [dtypes.float32])
          c_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                          [dtypes.float32])
          d_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_3, c_3],
                          [dtypes.float32])
          e_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_2, e_2],
                          [dtypes.float32])
          with g.control_dependencies([a_4]):
            b_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                            [dtypes.float32])
            c_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                            [dtypes.float32])
            d_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_4, c_4],
                            [dtypes.float32])
            e_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_3, e_3],
                            [dtypes.float32])

    self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_2.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_3.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_4.op.control_inputs)

    self.assertItemsEqual([], c_1.op.control_inputs)
    self.assertItemsEqual([a_2.op], c_2.op.control_inputs)
    self.assertItemsEqual([a_2.op, a_3.op], c_3.op.control_inputs)
    self.assertItemsEqual([a_2.op, a_3.op, a_4.op], c_4.op.control_inputs)

    self.assertItemsEqual([], d_1.op.control_inputs)
    self.assertItemsEqual([], d_2.op.control_inputs)
    self.assertItemsEqual([], d_3.op.control_inputs)
    self.assertItemsEqual([], d_4.op.control_inputs)

    self.assertItemsEqual([a_1.op], e_1.op.control_inputs)
    self.assertItemsEqual([a_2.op], e_2.op.control_inputs)
    self.assertItemsEqual([a_3.op], e_3.op.control_inputs)
    self.assertItemsEqual([a_4.op], e_4.op.control_inputs)

  def testRepeatedDependency(self):
    """Two outputs of one op yield a single control input on that op."""
    g = ops.Graph()
    a = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32])
    a_0, a_1 = a.outputs
    with g.control_dependencies([a_0]):
      b = _apply_op(g, "FloatOutput", [], [dtypes.float32])
      with g.control_dependencies([a_1]):
        c = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    self.assertEqual(b.op.control_inputs, [a])
    self.assertEqual(c.op.control_inputs, [a])

  def testNoControlDependencyWithDataDependency(self):
    """No control input is added when the op already uses it as data."""
    g = ops.Graph()
    a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    with g.control_dependencies([a]):
      b = _apply_op(g, "Identity", [a], [dtypes.float32])

    self.assertEqual(b.op.control_inputs, [])
1976
1977
class OpScopeTest(test_util.TensorFlowTestCase):
  """Tests for `ops.name_scope` naming rules and its `values` graph checks."""

  @test_util.run_in_graph_and_eager_modes
  def testNames(self):
    """How nested scope names are appended, reset and made absolute."""
    with ops.name_scope("foo") as foo:
      self.assertEqual("foo/", foo)
      with ops.name_scope("foo2") as foo2:
        self.assertEqual("foo/foo2/", foo2)
      # None (and "") reset to the root scope.
      with ops.name_scope(None) as empty1:
        self.assertEqual("", empty1)
        with ops.name_scope("foo3") as foo3:
          self.assertEqual("foo3/", foo3)
      with ops.name_scope("") as empty2:
        self.assertEqual("", empty2)
    # A name ending in "/" is used verbatim instead of being appended to the
    # current scope.
    with ops.name_scope("foo/") as outer_foo:
      self.assertEqual("foo/", outer_foo)
      with ops.name_scope("") as empty3:
        self.assertEqual("", empty3)
      with ops.name_scope("foo4") as foo4:
        self.assertEqual("foo/foo4/", foo4)
      with ops.name_scope("foo5//") as foo5:
        self.assertEqual("foo5//", foo5)
        with ops.name_scope("foo6") as foo6:
          self.assertEqual("foo5//foo6/", foo6)
      with ops.name_scope("/") as foo7:
        self.assertEqual("/", foo7)
      with ops.name_scope("//") as foo8:
        self.assertEqual("//", foo8)
      with ops.name_scope("a//b/c") as foo9:
        self.assertEqual("foo/a//b/c/", foo9)
    with ops.name_scope("a//b/c") as foo10:
      self.assertEqual("a//b/c/", foo10)

  @test_util.run_in_graph_and_eager_modes
  def testEagerDefaultScopeName(self):
    """With name=None, the default name is used as the scope name."""
    with ops.name_scope(None, "default") as scope:
      self.assertEqual(scope, "default/")
      with ops.name_scope(None, "default2") as scope2:
        self.assertEqual(scope2, "default/default2/")

  @test_util.run_deprecated_v1
  def testNoScopeName(self):
    """Neither a name nor a default name (with values) raises ValueError."""
    g0 = ops.Graph()
    values = [
        g0.create_op("A", [], [dtypes.float32]),
        g0.create_op("B", [], [dtypes.float32])
    ]
    with self.assertRaises(ValueError):
      with ops.name_scope(None, values=values):
        pass
    with self.assertRaises(ValueError):
      with ops.name_scope(None, None, values):
        pass

  @test_util.run_deprecated_v1
  def testEmptyScopeName(self):
    """"" gives the root scope and makes the values' graph the default."""
    g0 = ops.Graph()
    a = g0.create_op("A", [], [dtypes.float32])
    b = g0.create_op("B", [], [dtypes.float32])
    with ops.name_scope("", values=[a, b]) as scope:
      self.assertEqual("", scope)
      self.assertEqual(g0, ops.get_default_graph())
    # An explicit "" wins over the default name.
    with ops.name_scope("", "my_default_scope", [a, b]) as scope:
      self.assertEqual("", scope)
      self.assertEqual(g0, ops.get_default_graph())

  @test_util.run_deprecated_v1
  def testDefaultScopeName(self):
    """An explicit name wins; otherwise the default name is used."""
    g0 = ops.Graph()
    a = g0.create_op("A", [], [dtypes.float32])
    b = g0.create_op("B", [], [dtypes.float32])
    scope_name = "my_scope"
    default_scope_name = "my_default_scope"
    with ops.name_scope(scope_name, default_scope_name, [a, b]) as scope:
      self.assertEqual("%s/" % scope_name, scope)
      self.assertEqual(g0, ops.get_default_graph())
    with ops.name_scope(None, default_scope_name, [a, b]) as scope:
      self.assertEqual("%s/" % default_scope_name, scope)
      self.assertEqual(g0, ops.get_default_graph())
    # A list passed positionally lands in the default-name slot: TypeError.
    with self.assertRaises(TypeError):
      with ops.name_scope(scope_name, [a, b]):
        pass

  def _testGraphElements(self, graph_elements):
    """Helper: checks the scope uses the elements' graph and rejects mixes.

    Args:
      graph_elements: list of elements from one graph; the first must have a
        `.graph` attribute.
    """
    scope_name = "my_scope"
    with ops.name_scope(scope_name, values=graph_elements) as scope:
      self.assertEqual("%s/" % scope_name, scope)
      self.assertEqual(graph_elements[0].graph, ops.get_default_graph())
    # Elements drawn from two different graphs must be rejected.
    g1 = ops.Graph()
    a = g1.create_op("A", [], [dtypes.float32])
    with self.assertRaises(ValueError):
      with ops.name_scope(scope_name, values=graph_elements + [a]):
        pass

  @test_util.run_deprecated_v1
  def testTensor(self):
    # NOTE(review): despite the test name, `create_op` returns Operations
    # (they expose `.outputs`), so this exercises operations, not tensors.
    g0 = ops.Graph()
    a = g0.create_op("A", [], [dtypes.float32])
    b = g0.create_op("B", [], [dtypes.float32])
    self._testGraphElements([a, b])

  @test_util.run_deprecated_v1
  def testSparseTensor(self):
    """A SparseTensor mixed with operations from the same graph."""
    g0 = ops.Graph()
    a = g0.create_op("A", [], [dtypes.float32])
    b = g0.create_op("B", [], [dtypes.float32])
    sparse = sparse_tensor.SparseTensor(
        _apply_op(g0, "Int64Output", [], [dtypes.int64]),
        _apply_op(g0, "FloatOutput", [], [dtypes.float32]),
        _apply_op(g0, "Int64Output", [], [dtypes.int64]))
    self._testGraphElements([a, sparse, b])

  @test_util.run_deprecated_v1
  def testVariable(self):
    """A Variable mixed with operations from the same graph."""
    g0 = ops.Graph()
    with g0.as_default():
      variable = variables.Variable([1.0])
    a = g0.create_op("A", [], [dtypes.float32])
    b = g0.create_op("B", [], [dtypes.float32])
    self._testGraphElements([a, variable, b])
2098
2099
2100class InitScopeTest(test_util.TensorFlowTestCase):
2101
  def testClearsControlDependencies(self):
    """init_scope clears outer control deps, like control_dependencies(None).

    This mirrors ControlDependenciesTest.testClear, with `ops.init_scope()`
    in place of `g.control_dependencies(None)`.
    """
    g = ops.Graph()
    a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    with g.as_default():
      with g.control_dependencies([a_1]):
        with g.control_dependencies([a_2]):
          with ops.init_scope():
            with g.control_dependencies([a_3]):
              with g.control_dependencies([a_4]):
                # deps [a_3, a_4]
                b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
              # deps = [a_3]
              b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
            # deps back to None
            b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32])
          # deps back to [a_1, a_2]
          b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
        # deps back to [a_1]
        b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
        with ops.init_scope():
          # deps are None again
          b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs)
    self.assertItemsEqual([a_3.op], b_3.op.control_inputs)
    self.assertItemsEqual([], b_none.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs)
    self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
    self.assertItemsEqual([], b_none2.op.control_inputs)
2135
2136  def testLiftsOpsFromFunctions(self):
2137    g0 = ops.Graph()
2138    g1 = ops.Graph()
2139    g1._building_function = True  # pylint: disable=protected-access
2140    g2 = ops.Graph()
2141    g2._building_function = True  # pylint: disable=protected-access
2142
2143    with g0.as_default():
2144      with g1.as_default():
2145        with g2.as_default():
2146          with ops.init_scope():
2147            _ = constant_op.constant(1.0)
2148
2149    self.assertEqual(len(g2.get_operations()), 0)
2150    self.assertEqual(len(g1.get_operations()), 0)
2151    self.assertEqual(len(g0.get_operations()), 1)
2152
  def testPreservesDevices(self):
    """init_scope keeps the device set inside the function graph.

    When the function graph `g1` has no device scope active, ops created
    under init_scope pick up the outer graph's "CPU:0" device instead.
    """
    g0 = ops.Graph()
    with g0.as_default(), ops.device("CPU:0"):
      g1 = ops.Graph()
      g1._building_function = True  # pylint: disable=protected-access
      with g1.as_default():
        with ops.device("GPU:0"):
          with ops.init_scope():
            # init_scope should preserve device set under `g1`.
            on_gpu = constant_op.constant(1.0)
            self.assertEqual(on_gpu.device, "/device:GPU:0")
          still_on_gpu = constant_op.constant(1.0)
          self.assertEqual(still_on_gpu.device, "/device:GPU:0")
        # g0's device scope does not leak into g1 itself.
        blank = constant_op.constant(1.0)
        self.assertEqual(blank.device, "")
        with ops.init_scope():
          now_on_cpu = constant_op.constant(1.0)
          self.assertEqual(now_on_cpu.device, "/device:CPU:0")
      on_cpu = constant_op.constant(1.0)
      self.assertEqual(on_cpu.device, "/device:CPU:0")
2173
  def testComposes(self):
    """Nested init_scopes lift to the innermost non-function graph."""
    g0 = ops.Graph()
    g1 = ops.Graph()
    g1._building_function = True  # pylint: disable=protected-access
    g2 = ops.Graph()
    g2._building_function = True  # pylint: disable=protected-access
    g3 = ops.Graph()
    g3._building_function = False  # pylint: disable=protected-access

    with g0.as_default():
      with g1.as_default():
        with ops.init_scope():
          # This op should be lifted into g0.
          _ = constant_op.constant(1.0)
          self.assertIs(g0, ops.get_default_graph())
          self.assertEqual(len(g2.get_operations()), 0)
          self.assertEqual(len(g1.get_operations()), 0)
          self.assertEqual(len(g0.get_operations()), 1)
        with g2.as_default():
          with ops.init_scope():
            # This op should be lifted into g0.
            _ = constant_op.constant(1.0)
            self.assertIs(g0, ops.get_default_graph())
            with g3.as_default():
              with ops.init_scope():
                # This op should be lifted into g3, because g3 is not building a
                # function.
                _ = constant_op.constant(1.0)
                self.assertIs(g3, ops.get_default_graph())

    self.assertEqual(len(g3.get_operations()), 1)
    self.assertEqual(len(g2.get_operations()), 0)
    self.assertEqual(len(g1.get_operations()), 0)
    self.assertEqual(len(g0.get_operations()), 2)
2208
2209  def testEscapesToEagerContext(self):
2210    g = ops.Graph()
2211    g._building_function = True  # pylint: disable=protected-access
2212    with context.eager_mode():
2213      with context.graph_mode():
2214        with g.as_default():
2215          with ops.init_scope():
2216            # Because g is building a function, init_scope should
2217            # escape out to the eager context.
2218            self.assertTrue(context.executing_eagerly())
2219          # g should be reinstated as the default graph, and the
2220          # graph context should be re-entered.
2221          self.assertIs(g, ops.get_default_graph())
2222          self.assertFalse(context.executing_eagerly())
2223
2224  def testStaysInEagerWhenOnlyEagerContextActive(self):
2225    with context.eager_mode():
2226      with ops.init_scope():
2227        self.assertTrue(context.eager_mode())
2228      self.assertTrue(context.eager_mode())
2229
  def testEscapesDefunWhenInEagerMode(self):
    """Under defun, init_scope creates the variable once, outside the graph.

    Called eagerly, the function creates a fresh variable each time (always
    4). Wrapped in defun, the variable made under init_scope persists across
    calls, so the results are 4 and then 5.
    """

    def function_with_variables():
      with ops.init_scope():
        self.v = resource_variable_ops.ResourceVariable(3)
      return self.v.assign_add(1)

    with context.eager_mode():
      # Each invocation of function_with_variables recreates a variable.
      self.assertEqual(4, int(function_with_variables()))
      self.assertEqual(4, int(function_with_variables()))

      compiled = eager_function.defun(function_with_variables)
      # The init_scope in function_with_variables lifts the variable out
      # of the graph function constructed by defun; hence,
      # compiled now appears to be stateful.
      self.assertEqual(4, int(compiled()))
      self.assertEqual(5, int(compiled()))
2248
  def testEscapesDefunWhenInGraphMode(self):
    """In graph mode, init_scope lifts defun variable creation into `g`.

    The checks below count GLOBAL_VARIABLES in the enclosing graph: "foo"
    plus one "bar", even after calling the defun twice.
    """
    def function_with_variables(name):
      with ops.init_scope():
        _ = variable_scope.get_variable(name, shape=(1,))

    g = ops.Graph()
    with g.as_default():
      with self.cached_session():
        # First ensure that graphs that are not building functions are
        # not escaped.
        function_with_variables("foo")
        with self.assertRaisesRegexp(ValueError,
                                     r"Variable foo already exists.*"):
          # This will fail because reuse is not set to True.
          function_with_variables("foo")

        compiled = eager_function.defun(function_with_variables)
        compiled("bar")
        self.assertEqual(
            len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2)

        # The second call to `compiled` should not create variables: the
        # init_scope has lifted the variable creation code out of the defun.
        compiled("bar")
        self.assertEqual(
            len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2)
2275
  def testEscapesNestedDefun(self):
    """init_scope lifts variables out of nested defuns as well.

    Compiled: first call gives (0+1) + (1+2) = 4, second gives
    (1+1) + (3+2) = 7, because both variables persist between calls.
    """

    def inner_function():
      with ops.init_scope():
        self.v = resource_variable_ops.ResourceVariable(1)
      return self.v.assign_add(2)

    def outer_function(inner=None):
      with ops.init_scope():
        self.v0 = resource_variable_ops.ResourceVariable(0)
      return self.v0.assign_add(1) + inner()

    with context.eager_mode():
      # Each invocation of outer_function recreates variables.
      self.assertEqual(4, int(outer_function(inner=inner_function)))
      self.assertEqual(4, int(outer_function(inner=inner_function)))

      compiled_inner = eager_function.defun(inner_function)
      compiled_outer = eager_function.defun(outer_function)
      # The init_scope lifts variables out of the graph functions
      # constructed by defun; hence, compiled_outer should now appear to be
      # stateful.
      self.assertEqual(4, int(compiled_outer(inner=compiled_inner)))
      self.assertEqual(7, int(compiled_outer(inner=compiled_inner)))
2300
2301  @test_util.run_v1_only("b/120545219")
2302  def testFallsBackToGlobalGraphWhenAllGraphsAreBuildingFunctions(self):
2303    with context.graph_mode():
2304      ops.reset_default_graph()
2305      # This doesn't push anything onto the graph stack, but it does
2306      # set the stack's global graph.
2307      global_graph = ops.get_default_graph()
2308      fn_graph = ops.Graph()
2309
2310      # pylint: disable=protected-access
2311      fn_graph._building_function = True
2312      self.assertEqual(len(ops._default_graph_stack.stack), 0)
2313      with fn_graph.as_default():
2314        self.assertEqual(len(ops._default_graph_stack.stack), 1)
2315        with ops.init_scope():
2316          self.assertGreater(len(ops._default_graph_stack.stack), 1)
2317          dummy = constant_op.constant(1.0)
2318        self.assertEqual(len(ops._default_graph_stack.stack), 1)
2319      # Note that the global graph is _not_ on the graph stack.
2320      self.assertEqual(len(ops._default_graph_stack.stack), 0)
2321      # Ensure that `dummy` was added to the global graph.
2322      self.assertEqual(global_graph, dummy.graph)
2323      # pylint: enable=protected-access
2324
2325  def testInstallsDefaultGraphWhenGraphStackIsEmptyInGraphMode(self):
2326    with context.graph_mode():
2327      # pylint: disable=protected-access
2328      self.assertEqual(len(ops._default_graph_stack.stack), 0)
2329      with ops.init_scope():
2330        self.assertGreater(len(ops._default_graph_stack.stack), 0)
2331      self.assertEqual(len(ops._default_graph_stack.stack), 0)
2332      # pylint: enable=protected-access
2333
2334  def testPreservesNameScopeInGraphConstruction(self):
2335    with ops.Graph().as_default():
2336      function_graph = ops.Graph()
2337      with function_graph.as_default():
2338        with ops.name_scope("inner"), ops.init_scope():
2339          self.assertEqual(ops.get_name_scope(), "inner")
2340      self.assertEqual(ops.get_name_scope(), "")
2341
2342  def testEnteringGraphFromEagerIsSticky(self):
2343    with context.eager_mode():
2344      g = ops.Graph()
2345      with g.as_default():
2346        with ops.init_scope():
2347          self.assertFalse(context.executing_eagerly())
2348          self.assertEqual(g, ops.get_default_graph())
2349
2350  def testMixGraphEager(self):
2351    with context.eager_mode():
2352      c = constant_op.constant(1.0)
2353      with ops.Graph().as_default():
2354        with self.assertRaisesRegexp(
2355            RuntimeError, "Attempting to capture an EagerTensor"):
2356          math_ops.add(c, c)
2357        c2 = constant_op.constant(2.0)
2358      with self.assertRaisesRegexp(
2359          TypeError, "Graph tensors"):
2360        math_ops.add(c2, c2)
2361
2362  def testPreservesNameScopeInEagerExecution(self):
2363    with context.eager_mode():
2364      def foo():
2365        with ops.name_scope("inner"), ops.init_scope():
2366          if context.executing_eagerly():
2367            # A trailing slash is always appended when eager execution is
2368            # enabled.
2369            self.assertEqual(context.context().scope_name, "inner/")
2370          else:
2371            self.assertEqual(ops.get_name_scope(), "inner")
2372
2373      foo()
2374      self.assertEqual(ops.get_name_scope(), "")
2375      foo_compiled = eager_function.defun(foo)
2376      foo_compiled()
2377      self.assertEqual(ops.get_name_scope(), "")
2378
2379  def testExecutingEagerlyOutsideFunctions(self):
2380
2381    @eager_function.defun
2382    def f():
2383      return ops.executing_eagerly_outside_functions()
2384
2385    with context.eager_mode():
2386      self.assertTrue(ops.executing_eagerly_outside_functions())
2387      self.assertTrue(f())
2388      g = ops.Graph()
2389      with g.as_default():
2390        self.assertFalse(ops.executing_eagerly_outside_functions())
2391
2392
class GraphTest(test_util.TensorFlowTestCase):
  """Tests for default-graph management and per-graph bookkeeping."""

  def setUp(self):
    # Run the base-class setup as well: overriding setUp without it silently
    # skips whatever per-test initialization TensorFlowTestCase performs.
    super(GraphTest, self).setUp()
    ops.reset_default_graph()

  def _AssertDefault(self, expected):
    self.assertIs(expected, ops.get_default_graph())

  def testResetDefaultGraphNesting(self):
    # Resetting while another graph is installed as default must be rejected.
    g0 = ops.Graph()
    with self.assertRaises(AssertionError):
      with g0.as_default():
        ops.reset_default_graph()

  def testGraphContextManagerCancelsEager(self):
    with context.eager_mode():
      with ops.Graph().as_default():
        self.assertFalse(context.executing_eagerly())

  def testGraphContextManager(self):
    g0 = ops.Graph()
    with g0.as_default() as g1:
      self.assertIs(g0, g1)

  def testDefaultGraph(self):
    orig = ops.get_default_graph()
    # The global default graph does not count as an explicitly set default.
    self.assertFalse(ops.has_default_graph())
    self._AssertDefault(orig)
    g0 = ops.Graph()
    # Constructing a graph, or its context manager, does not install it.
    self.assertFalse(ops.has_default_graph())
    self._AssertDefault(orig)
    context_manager_0 = g0.as_default()
    self.assertFalse(ops.has_default_graph())
    self._AssertDefault(orig)
    with context_manager_0 as g0:
      self._AssertDefault(g0)
      with ops.Graph().as_default() as g1:
        self.assertTrue(ops.has_default_graph())
        self._AssertDefault(g1)
      self._AssertDefault(g0)
    self._AssertDefault(orig)
    self.assertFalse(ops.has_default_graph())

  def testPreventFeeding(self):
    g = ops.Graph()
    # Build the tensor inside `g` so the test exercises a tensor that
    # actually belongs to the graph whose bookkeeping is being checked.
    with g.as_default():
      a = constant_op.constant(2.0)
    self.assertTrue(g.is_feedable(a))
    g.prevent_feeding(a)
    self.assertFalse(g.is_feedable(a))

  @test_util.run_deprecated_v1
  def testPreventFetching(self):
    g = ops.Graph()
    # As above: the op under test must live in `g`.
    with g.as_default():
      a = constant_op.constant(2.0)
    self.assertTrue(g.is_fetchable(a))
    g.prevent_fetching(a.op)
    self.assertFalse(g.is_fetchable(a))

  def testAsGraphElementConversions(self):

    class ConvertibleObj(object):

      def _as_graph_element(self):
        return "FloatOutput:0"

    class NonConvertibleObj(object):

      pass

    g = ops.Graph()
    a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    self.assertEqual(a, g.as_graph_element(ConvertibleObj()))
    with self.assertRaises(TypeError):
      g.as_graph_element(NonConvertibleObj())

  # Regression test against creating custom __del__ functions in classes
  # involved in cyclic references, e.g. Graph and Operation. (Python won't gc
  # cycles that require calling a __del__ method, because the __del__ method can
  # theoretically increase the object's refcount to "save" it from gc, and any
  # already-deleted objects in the cycle would have be to restored.)
  def testGarbageCollected(self):
    # Create a graph we can delete and a weak reference to monitor if it's gc'd
    g = ops.Graph()
    g_ref = weakref.ref(g)
    # Create some ops
    with g.as_default():
      a = constant_op.constant(2.0)
      b = constant_op.constant(3.0)
      c = math_ops.add(a, b)
    # Create a session we can delete
    with session.Session(graph=g) as sess:
      self.evaluate(c)
    # Delete all references and trigger gc
    del g
    del a
    del b
    del c
    del sess
    gc.collect()
    self.assertIsNone(g_ref())

  def testRunnableAfterInvalidShape(self):
    with ops.Graph().as_default():
      with self.assertRaises(ValueError):
        math_ops.add([1, 2], [1, 2, 3])
      a = constant_op.constant(1)
      # The failed op construction above must not leave the graph unusable.
      with session.Session():
        self.evaluate(a)

  def testRunnableAfterInvalidShapeWithKernelLabelMap(self):
    g = ops.Graph()
    with g.as_default():
      # pylint: disable=protected-access
      with g._kernel_label_map({"KernelLabelRequired": "overload_1"}):
        # pylint: enable=protected-access
        with self.assertRaises(ValueError):
          test_ops.kernel_label_required(1)
      a = constant_op.constant(1)
      with session.Session():
        self.evaluate(a)
2511
2512
class AttrScopeTest(test_util.TensorFlowTestCase):
  """Tests for the private `Graph._attr_scope` context manager."""

  def _get_test_attrs(self):
    """Creates a no-op and returns its ("_A", "_B") attrs as text or None."""
    op = control_flow_ops.no_op()

    def read_attr(attr_name):
      try:
        return compat.as_text(op.get_attr(attr_name))
      except ValueError:
        # The attr is not set on this op.
        return None

    return tuple(read_attr(attr_name) for attr_name in ("_A", "_B"))

  @test_util.run_deprecated_v1
  def testNoLabel(self):
    with self.cached_session():
      observed = self._get_test_attrs()
      self.assertAllEqual((None, None), observed)

  @test_util.run_deprecated_v1
  def testLabelMap(self):
    with self.cached_session() as sess:
      # pylint: disable=protected-access
      attr_scope = sess.graph._attr_scope
      # pylint: enable=protected-access

      def string_attr(value):
        return attr_value_pb2.AttrValue(s=compat.as_bytes(value))

      observed = [self._get_test_attrs()]
      with attr_scope({"_A": string_attr("foo")}):
        observed.append(self._get_test_attrs())
        # Mapping an attr to None removes it for the nested scope.
        with attr_scope({"_A": None, "_B": string_attr("bar")}):
          observed.append(self._get_test_attrs())
          with attr_scope({"_A": string_attr("baz")}):
            observed.append(self._get_test_attrs())
          observed.append(self._get_test_attrs())
        observed.append(self._get_test_attrs())
      observed.append(self._get_test_attrs())

      expected = [
          (None, None),
          ("foo", None),
          (None, "bar"),
          ("baz", "bar"),
          (None, "bar"),
          ("foo", None),
          (None, None),
      ]
      for want, got in zip(expected, observed):
        self.assertAllEqual(want, got)
2560
2561
# Register `common_shapes.scalar_shape` as the shape function for ops of type
# "KernelLabel" (the test op created via `test_ops.kernel_label()` below).
ops.RegisterShape("KernelLabel")(common_shapes.scalar_shape)
2563
2564
class KernelLabelTest(test_util.TensorFlowTestCase):
  """Tests kernel selection through the private `Graph._kernel_label_map`."""

  @test_util.run_deprecated_v1
  def testNoLabel(self):
    with self.cached_session():
      label = test_ops.kernel_label()
      self.assertAllEqual(b"My label is: default", label.eval())

  @test_util.run_deprecated_v1
  def testLabelMap(self):
    with self.cached_session() as sess:
      # pylint: disable=protected-access
      label_map = sess.graph._kernel_label_map
      # pylint: enable=protected-access
      default_1 = test_ops.kernel_label()
      with label_map({"KernelLabel": "overload_1"}):
        overload_1_1 = test_ops.kernel_label()
        with label_map({"KernelLabel": "overload_2"}):
          overload_2 = test_ops.kernel_label()
          # An empty label selects the default kernel again.
          with label_map({"KernelLabel": ""}):
            default_2 = test_ops.kernel_label()
        # Leaving a scope restores the enclosing label.
        overload_1_2 = test_ops.kernel_label()
      default_3 = test_ops.kernel_label()

      expectations = [
          (b"My label is: default", default_1),
          (b"My label is: default", default_2),
          (b"My label is: default", default_3),
          (b"My label is: overload_1", overload_1_1),
          (b"My label is: overload_1", overload_1_2),
          (b"My label is: overload_2", overload_2),
      ]
      for expected, tensor in expectations:
        self.assertAllEqual(expected, self.evaluate(tensor))
2596
2597
class AsGraphDefTest(test_util.TensorFlowTestCase):
  """Tests for `Graph.graph_def_versions` and `Graph.as_graph_def`."""

  def testGraphDefVersion(self):
    """Test that the graphdef version is plumbed through to kernels."""
    with ops.Graph().as_default() as g:
      expected_version = g.graph_def_versions.producer
      with self.session(graph=g):
        observed_version = test_ops.graph_def_version().eval()
        self.assertEqual(expected_version, observed_version)

  def testAddShapes(self):
    with ops.Graph().as_default() as g:
      (unknown_rank, scalar, vector, matrix,
       partial_matrix) = _apply_op(g, "FiveFloatOutputs", [],
                                   [dtypes.float32] * 5)
      unknown_rank.set_shape(None)
      scalar.set_shape([])
      vector.set_shape([None])
      matrix.set_shape([43, 37])
      partial_matrix.set_shape([43, None])

      # Adds the node named "Const" expected in the GraphDef below.
      constant_op.constant(1.0)

      graph_def = g.as_graph_def(add_shapes=True)
      self.assertProtoEqualsVersion("""
      node { name: "FiveFloatOutputs" op: "FiveFloatOutputs"
        attr {
          key: "_output_shapes"
          value {
            list {
              shape { unknown_rank: true }
              shape { }
              shape { dim { size: -1 } }
              shape { dim { size: 43 } dim { size: 37 } }
              shape { dim { size: 43 } dim { size: -1 } }
            }
          }
        }
      }
    node { name: "Const" op: "Const"
      attr {
        key: "_output_shapes"
        value {
          list {
            shape { }
          }
        }
      }
      attr {
        key: "dtype"
        value { type: DT_FLOAT }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_FLOAT
            tensor_shape { }
         float_val: 1.0  } } } }
      """, graph_def)
2657
2658
@ops.RegisterStatistics("a", "flops")
def _calc_a_forward_flops(unused_graph, unused_node):
  """Reports a fixed "flops" statistic of 20 for ops of type "a"."""
  fixed_flops = 20
  return ops.OpStats("flops", fixed_flops)
2662
2663
class StatisticsTest(test_util.TensorFlowTestCase):
  """Tests for `ops.get_stats_for_node_def` and `ops.OpStats`."""

  def testRegisteredNode(self):
    graph = ops.Graph()
    node = ops._NodeDef("a", "an_a")  # pylint: disable=protected-access
    # Op type "a" has a "flops" statistic registered in this module.
    flops = ops.get_stats_for_node_def(graph, node, "flops")
    self.assertEqual(20, flops.value)
    # A statistic with no registered function comes back without a value.
    missing_stat = ops.get_stats_for_node_def(graph, node, "missing_stat")
    self.assertIsNone(missing_stat.value)

  def testUnregisteredNode(self):
    graph = ops.Graph()
    node = ops._NodeDef("b", "a_b")  # pylint: disable=protected-access
    weight_params = ops.get_stats_for_node_def(graph, node, "weight_params")
    self.assertIsNone(weight_params.value)

  def testAccumulateStatistics(self):
    flops_total = ops.OpStats("flops")
    self.assertIsNone(flops_total.value)
    second_flops = ops.OpStats("flops", 3)
    # Adding to an OpStats that has no value yet takes the other's value.
    flops_total += second_flops
    self.assertEqual(3, flops_total.value)
2686
2687
class DeviceStackTest(test_util.TensorFlowTestCase):
  """Tests the device-assignment metadata recorded on operations."""

  @test_util.run_deprecated_v1
  def testBasicDeviceAssignmentMetadata(self):

    def device_func(unused_op):
      return "/cpu:*"

    const_zero = constant_op.constant([0.0], name="zero")
    with ops.device("/cpu"):
      const_one = constant_op.constant([1.0], name="one")
      with ops.device("/cpu:0"):
        const_two = constant_op.constant([2.0], name="two")
    with ops.device(device_func):
      const_three = constant_op.constant(3.0, name="three")

    # pylint: disable=protected-access
    self.assertEqual(0, len(const_zero.op._device_assignments))

    one_list = const_one.op._device_assignments
    self.assertEqual(1, len(one_list))
    self.assertEqual("/cpu", one_list[0].obj)
    self.assertEqual("ops_test.py", os.path.basename(one_list[0].filename))

    # Nested device scopes record one entry per enclosing scope.
    two_list = const_two.op._device_assignments
    self.assertEqual(2, len(two_list))
    devices = [t.obj for t in two_list]
    self.assertEqual(set(["/cpu", "/cpu:0"]), set(devices))

    three_list = const_three.op._device_assignments
    # pylint: enable=protected-access
    self.assertEqual(1, len(three_list))
    # A device function is described by its name and source location.
    func_description = three_list[0].obj
    expected_regex = r"device_func<.*ops_test\.py, [0-9]+"
    self.assertRegexpMatches(func_description, expected_regex)

  @test_util.run_deprecated_v1
  def testDeviceAssignmentMetadataForGraphDeviceAndTfDeviceFunctions(self):

    # NOTE: the four statements below must keep their relative line spacing;
    # the lineno assertion at the end expects the two assignments to be
    # recorded exactly two lines apart.
    with ops.device("/cpu"):
      const_one = constant_op.constant([1.0], name="one")
    with ops.get_default_graph().device("/cpu"):
      const_two = constant_op.constant([2.0], name="two")

    # pylint: disable=protected-access
    one_metadata = const_one.op._device_assignments[0]
    two_metadata = const_two.op._device_assignments[0]
    # pylint: enable=protected-access

    # Verify both types of device assignment return the right stack info.
    self.assertEqual("ops_test.py", os.path.basename(one_metadata.filename))
    self.assertEqual(one_metadata.filename, two_metadata.filename)
    self.assertEqual(one_metadata.lineno + 2, two_metadata.lineno)
2738
2739
class ColocationGroupTest(test_util.TensorFlowTestCase):
  """Tests for `colocate_with` and the colocation groups it records."""

  @test_util.run_deprecated_v1
  def testBasic(self):
    anchor = constant_op.constant([2.0], name="a")
    with ops.colocate_with(anchor.op):
      colocated = constant_op.constant(3.0)
    free_standing = constant_op.constant(4.0)
    self.assertEqual([b"loc:@a"], anchor.op.colocation_groups())
    self.assertEqual([b"loc:@a"], colocated.op.colocation_groups())
    # An op created outside any colocation scope has no "_class" attr.
    with self.assertRaises(ValueError):
      free_standing.op.get_attr("_class")

  @test_util.run_deprecated_v1
  def testBasicColocationMetadata(self):
    const_two = constant_op.constant([2.0], name="two")
    with ops.colocate_with(const_two.op):
      const_three = constant_op.constant(3.0, name="three")
    # pylint: disable=protected-access
    colocation_info = const_three.op._colocation_dict
    # pylint: enable=protected-access
    self.assertIn("two", colocation_info)
    metadata = colocation_info["two"]
    self.assertIsNone(metadata.obj)
    # The recorded location is this test file, where colocate_with was called.
    self.assertEqual("ops_test.py", os.path.basename(metadata.filename))

  @test_util.run_deprecated_v1
  def testColocationDeviceInteraction(self):
    with ops.device("/cpu:0"):
      with ops.device("/device:GPU:0"):
        gpu_tensor = constant_op.constant([2.0], name="a")
      with ops.colocate_with(gpu_tensor.op):
        # Created under /cpu:0, but colocation is the stronger constraint,
        # so this op follows "a" onto /device:GPU:0.
        follower = constant_op.constant(3.0)
    self.assertEqual([b"loc:@a"], follower.op.colocation_groups())
    self.assertEqual(gpu_tensor.op.device, follower.op.device)

  @test_util.run_deprecated_v1
  def testColocationCanonicalization(self):
    with ops.device("/device:GPU:0"):
      constant_op.constant(2.0)
    with ops.device(lambda _: "/device:GPU:0"):
      via_function = constant_op.constant(3.0)
    with ops.get_default_graph().colocate_with(via_function):
      with ops.device("/device:GPU:0"):
        via_colocation = constant_op.constant(4.0)
    # All three ops target /device:GPU:0; the colocated op inherits its
    # anchor's device name, which compares equal after canonicalization.
    self.assertEqual(via_function.op.device, via_colocation.op.device)

  @test_util.run_deprecated_v1
  def testLocationOverrides(self):
    with ops.device("/cpu:0"):
      with ops.device("/device:GPU:0"):
        anchor = constant_op.constant([2.0], name="a")
        # Redundant inside the GPU scope, but it preserves the colocation in
        # the GraphDef in a portable way.
        with ops.colocate_with(anchor.op):
          colocated = constant_op.constant(3.0)
        inner = constant_op.constant(4.0)
      outer = constant_op.constant(5.0)

    self.assertEqual([b"loc:@a"], colocated.op.colocation_groups())
    self.assertEqual("/device:GPU:0", anchor.op.device)
    self.assertEqual(anchor.op.device, colocated.op.device)

    # Each enclosing device scope is restored when a nested scope exits.
    self.assertEqual("/device:GPU:0", inner.op.device)
    self.assertEqual("/device:CPU:0", outer.op.device)

  @test_util.run_deprecated_v1
  def testNestedColocateWith(self):
    root = constant_op.constant([2.0], name="a")
    with ops.colocate_with(root.op):
      child = constant_op.constant(3.0)
      with ops.colocate_with(child.op):
        grandchild = constant_op.constant(4.0)
    # Both ops end up in "a"'s group.
    self.assertEqual([b"loc:@a"], child.op.colocation_groups())
    self.assertEqual([b"loc:@a"], grandchild.op.colocation_groups())

  @test_util.run_deprecated_v1
  def testMultiColocationGroups(self):
    first = constant_op.constant([2.0], name="a")
    second = constant_op.constant(3.0, name="b")
    with ops.colocate_with(first.op), ops.colocate_with(second.op):
      joined = constant_op.constant(4.0)
    self.assertEqual({b"loc:@a", b"loc:@b"},
                     set(joined.op.colocation_groups()))

  @test_util.run_deprecated_v1
  def testColocationIgnoreStack(self):
    first = constant_op.constant([2.0], name="a")
    second = constant_op.constant(3.0, name="b")
    with ops.colocate_with(first.op):
      # ignore_existing drops the enclosing "a" constraint.
      with ops.colocate_with(second.op, ignore_existing=True):
        isolated = constant_op.constant(4.0)
    self.assertEqual({b"loc:@b"}, set(isolated.op.colocation_groups()))

  @test_util.run_deprecated_v1
  def testColocateWithReset(self):
    anchor = constant_op.constant([2.0], name="a")
    with ops.colocate_with(anchor.op):
      colocated = constant_op.constant(3.0, name="b")
      # Colocating with None while ignoring existing constraints leaves the
      # op colocated only with itself.
      with ops.colocate_with(None, ignore_existing=True):
        reset = constant_op.constant(4.0, name="c")
    self.assertEqual([b"loc:@a"], colocated.op.colocation_groups())
    self.assertEqual([b"loc:@c"], reset.op.colocation_groups())

  @test_util.run_deprecated_v1
  def testColocateWithInitialNoneThenNested(self):
    anchor = constant_op.constant([2.0], name="a")
    with ops.colocate_with(anchor.op):
      with ops.colocate_with(None, ignore_existing=True):
        new_root = constant_op.constant(3.0, name="b")
        with ops.colocate_with(new_root.op):
          follower = constant_op.constant(4.0, name="c")
    self.assertEqual([b"loc:@b"], new_root.op.colocation_groups())
    self.assertEqual([b"loc:@b"], follower.op.colocation_groups())

  @test_util.run_deprecated_v1
  def testColocateVariables(self):
    first_var = variables.Variable([2.0], name="a")
    with ops.colocate_with(first_var.op):
      second_var = variables.Variable([3.0], name="b")
    self.assertEqual([b"loc:@a"], second_var.op.colocation_groups())
2872
2873
class DeprecatedTest(test_util.TensorFlowTestCase):
  """Tests GraphDef-version gating of the test op `Old`."""

  def testSuccess(self):
    with ops.Graph().as_default() as g:
      # Producer version 7 predates the op's removal in version 8.
      test_util.set_producer_version(g, 7)
      old_op = test_ops.old()
      with self.session(graph=g):
        old_op.run()

  def _error(self):
    """Returns the regex for the "op removed" error at the current version."""
    message_template = (r"Op Old is not available in GraphDef version %d\. "
                        r"It has been removed in version 8\. For reasons\.")
    return message_template % versions.GRAPH_DEF_VERSION

  def testGraphConstructionFail(self):
    with ops.Graph().as_default():
      expected_error = self._error()
      with self.assertRaisesRegexp(NotImplementedError, expected_error):
        test_ops.old()
2892
2893
class DenseTensorLikeTypeTest(test_util.TensorFlowTestCase):
  """Tests for `ops.is_dense_tensor_like` and its type registration."""

  def testSuccess(self):
    # pylint: disable=protected-access
    node_def = ops._NodeDef("FloatOutput", "myop")
    # pylint: enable=protected-access
    op = ops.Operation(node_def, ops.Graph(), [], [dtypes.float32])
    self.assertTrue(ops.is_dense_tensor_like(op.outputs[0]))

    variable = variables.Variable([17])
    self.assertTrue(ops.is_dense_tensor_like(variable))

  class BadClassNoName(object):
    pass

  class BadClassBadName(object):

    def name(self):
      pass

  class BadClassNoDtype(object):

    @property
    def name(self):
      pass

  class BadClassBadDtype(object):

    @property
    def name(self):
      pass

    def dtype(self):
      pass

  def testBadClass(self):
    # Each class lacks `name`/`dtype`, or defines it as a method rather than
    # a property; registration must reject it and mention the attribute.
    cases = [
        (DenseTensorLikeTypeTest.BadClassNoName, "`name`"),
        (DenseTensorLikeTypeTest.BadClassBadName, "`name`"),
        (DenseTensorLikeTypeTest.BadClassNoDtype, "`dtype`"),
        (DenseTensorLikeTypeTest.BadClassBadDtype, "`dtype`"),
    ]
    for bad_class, expected_message in cases:
      with self.assertRaisesRegexp(TypeError, expected_message):
        ops.register_dense_tensor_like_type(bad_class)
2941
2942
class NameScopeTest(test_util.TensorFlowTestCase):
  """Tests name-scope stripping/prepending and `Graph.get_name_scope`."""

  def testStripAndPrependScope(self):
    # Each case: (input, result of stripping "hidden1", result of then
    # prepending "hidden2" to the stripped name).
    cases = [
        # Same prefix: stripped.
        ("hidden1/hidden1/weights", "hidden1/weights",
         "hidden2/hidden1/weights"),
        # Extra "/" after the prefix: still stripped.
        ("hidden1///hidden1/weights", "hidden1/weights",
         "hidden2/hidden1/weights"),
        # Leading "^" is preserved; same prefix is stripped.
        ("^hidden1/hidden1/weights", "^hidden1/weights",
         "^hidden2/hidden1/weights"),
        # Leading "loc:@" is preserved; same prefix is stripped.
        ("loc:@hidden1/hidden1/weights", "loc:@hidden1/weights",
         "loc:@hidden2/hidden1/weights"),
        # Different prefix: kept.
        ("hhidden1/hidden1/weights", "hhidden1/hidden1/weights",
         "hidden2/hhidden1/hidden1/weights"),
        # Not a prefix (no trailing "/"): kept.
        ("hidden1", "hidden1", "hidden2/hidden1"),
    ]
    for original, expected_stripped, expected_prepended in cases:
      stripped = ops.strip_name_scope(original, "hidden1")
      self.assertEqual(expected_stripped, stripped)
      self.assertEqual(expected_prepended,
                       ops.prepend_name_scope(stripped, "hidden2"))

  def testGetNameScope(self):
    with ops.Graph().as_default() as graph:
      with ops.name_scope("scope1"):
        with ops.name_scope("scope2"):
          with ops.name_scope("scope3"):
            self.assertEqual("scope1/scope2/scope3", graph.get_name_scope())
          self.assertEqual("scope1/scope2", graph.get_name_scope())
        self.assertEqual("scope1", graph.get_name_scope())
      self.assertEqual("", graph.get_name_scope())

  def testTwoGraphs(self):

    def enter_invalid_scope():
      outer_graph = ops.Graph()
      inner_graph = ops.Graph()
      with outer_graph.as_default(), inner_graph.as_default():
        with ops.name_scope("_"):
          pass

    with self.assertRaisesRegexp(ValueError, "'_' is not a valid scope name"):
      enter_invalid_scope()
2991
2992
class TracebackTest(test_util.TensorFlowTestCase):
  """Tests `Operation.traceback` against `traceback_with_start_lines`."""

  @test_util.run_deprecated_v1
  def testTracebackWithStartLines(self):
    with self.cached_session() as sess:
      a = constant_op.constant(2.0)
      sess.run(
          a,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(sess.graph.get_operations())

      # traceback_with_start_lines has one frame per traceback frame, and
      # each of its frames is the plain frame plus one trailing element.
      # (Uses assertEqual: the assertEquals alias is deprecated and was
      # removed in Python 3.12.)
      for op in sess.graph.get_operations():
        self.assertEqual(len(op.traceback), len(op.traceback_with_start_lines))
        for frame, frame_with_start_line in zip(
            op.traceback, op.traceback_with_start_lines):
          self.assertEqual(5, len(frame_with_start_line))
          self.assertEqual(frame, frame_with_start_line[:-1])
3013
3014
class EnableEagerExecutionTest(test_util.TensorFlowTestCase):
  """Argument validation for `ops.enable_eager_execution`."""

  @test_util.run_v1_only("b/120545219")
  def testBadArgumentsToEnableEagerExecution(self):
    # A device-placement constant in the config position.
    with self.assertRaisesRegexp(TypeError, "config must be a tf.ConfigProto"):
      ops.enable_eager_execution(context.DEVICE_PLACEMENT_SILENT)

    # A ConfigProto where the device policy belongs.
    config = config_pb2.ConfigProto()
    with self.assertRaisesRegexp(ValueError, "device_policy must be one of"):
      ops.enable_eager_execution(config, config)

    # A ConfigProto where the execution mode belongs.
    config = config_pb2.ConfigProto()
    with self.assertRaisesRegexp(ValueError, "execution_mode must be one of"):
      ops.enable_eager_execution(config, execution_mode=config)
3027
3028
if __name__ == "__main__":
  # Hand control to the test runner when this file is executed directly.
  googletest.main()
3031