• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import collections
21import functools
22import itertools
23from multiprocessing.pool import ThreadPool
24import sys
25import weakref
26
27from absl.testing import parameterized
28import numpy
29
30from tensorflow.core.protobuf import config_pb2
31from tensorflow.core.protobuf import rewriter_config_pb2
32from tensorflow.python import keras
33from tensorflow.python.eager import context
34from tensorflow.python.eager import def_function
35from tensorflow.python.eager import function
36from tensorflow.python.framework import constant_op
37from tensorflow.python.framework import dtypes
38from tensorflow.python.framework import errors
39from tensorflow.python.framework import function as tf_function
40from tensorflow.python.framework import ops
41from tensorflow.python.framework import random_seed
42from tensorflow.python.framework import tensor_shape
43from tensorflow.python.framework import tensor_spec
44from tensorflow.python.framework import test_ops
45from tensorflow.python.framework import test_util
46from tensorflow.python.keras.engine import training as keras_training
47from tensorflow.python.layers import convolutional
48from tensorflow.python.ops import array_ops
49from tensorflow.python.ops import check_ops
50from tensorflow.python.ops import clip_ops
51from tensorflow.python.ops import control_flow_ops
52from tensorflow.python.ops import gen_functional_ops
53from tensorflow.python.ops import gen_random_ops
54from tensorflow.python.ops import gen_resource_variable_ops
55from tensorflow.python.ops import init_ops
56from tensorflow.python.ops import list_ops
57from tensorflow.python.ops import math_ops
58from tensorflow.python.ops import random_ops
59from tensorflow.python.ops import resource_variable_ops
60from tensorflow.python.ops import variable_scope
61from tensorflow.python.ops import variables
62from tensorflow.python.platform import test
63from tensorflow.python.training import training_ops
64from tensorflow.python.util import compat
65from tensorflow.python.util import nest
66from tensorflow.python.util import tf_inspect
67
68
69def total_function_cache(defined):
70  # pylint: disable=protected-access
71  return (set(defined._function_cache.primary)
72          | set(defined._function_cache.arg_relaxed))
73  # pylint: enable=protected-access
74
75
76class MiniModel(keras_training.Model):
77  """Minimal model for mnist.
78
79  Useful for testing and debugging on slow TPU simulators.
80  """
81
82  def __init__(self):
83    super(MiniModel, self).__init__(name='')
84    self.fc = keras.layers.Dense(1, name='fc', kernel_initializer='ones',
85                                 bias_initializer='ones')
86
87  def call(self, inputs, training=True):
88    return self.fc(inputs)
89
90
91class DefunnedMiniModel(MiniModel):
92
93  @function.defun
94  def call(self, inputs, training=True):
95    return super(DefunnedMiniModel, self).call(inputs, training=training)
96
97
98class FunctionTest(test.TestCase, parameterized.TestCase):
99
100  def testBasic(self):
101    # TODO(b/121134877): Remove the autograph override.
102    matmul = def_function.function(math_ops.matmul, autograph=False)
103    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
104    sq = matmul(t, t, transpose_a=True)
105    sq2 = matmul(sq, t, transpose_a=True)
106    self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20])
107    self.assertAllEqual(sq2.numpy().reshape(-1), [52, 76, 74, 108])
108
109  def testVariable(self):
110    v1 = variables.Variable(1.0)
111    add = def_function.function(lambda x, v: x + v1 + v)
112    v2 = variables.Variable(1.0)
113    x = constant_op.constant(1.0)
114    r = add(x, v2)
115    self.assertEqual(3.0, self.evaluate(r))
116
117  def testExternalControlDependency(self):
118    with ops.Graph().as_default(), self.test_session():
119      v = variables.Variable(1.0)
120      v.initializer.run()
121
122      op = v.assign_add(1.0)
123
124      @function.defun
125      def f():
126        with ops.control_dependencies([op]):
127          return 1.0
128
129      self.evaluate(f())
130      self.assertAllEqual(self.evaluate(v), 2.0)
131
132  def testInputShapeFunctionRelaxation(self):
133    unknown_dim = [False]
134
135    @function.defun
136    def func(a):
137      if a._shape_tuple()[0] is None:
138        unknown_dim[0] = True
139      return a + 1
140
141    func(constant_op.constant([]))
142    self.assertFalse(unknown_dim[0])
143    self.assertLen(total_function_cache(func), 1)
144
145    func(constant_op.constant([1.0]))
146    self.assertFalse(unknown_dim[0])
147    self.assertLen(total_function_cache(func), 2)
148
149    func(constant_op.constant([1.0, 2.0]))
150    self.assertTrue(unknown_dim[0])
151    self.assertLen(total_function_cache(func), 2)
152
153  def testNestedInputShapeFunctionRelaxation(self):
154    unknown_dim = [False]
155
156    @function.defun
157    def func(a_, b_=None):
158      del a_  # Only used to check which cache is used.
159      self.assertEqual(b_[0]._shape_tuple(), ())
160      if b_[1]._shape_tuple()[0] is None:
161        unknown_dim[0] = True
162      return b_[0] + 1
163
164    a = 'hi'
165    b0 = constant_op.constant(1.0)
166    func(a, b_=[b0, constant_op.constant([])])
167    self.assertFalse(unknown_dim[0])
168    self.assertLen(total_function_cache(func), 1)
169
170    func(a, b_=[b0, constant_op.constant([1.0])])
171    self.assertFalse(unknown_dim[0])
172    self.assertLen(total_function_cache(func), 2)
173
174    func(a, b_=[b0, constant_op.constant([1.0, 1.0])])
175    self.assertTrue(unknown_dim[0])
176    self.assertLen(total_function_cache(func), 2)
177
178    unknown_dim[0] = False
179
180    # Now do the same except with a new a which is not a tensor; this should
181    # change the cache key.
182    a = 'bye'
183    func(a, b_=[b0, constant_op.constant([])])
184    self.assertFalse(unknown_dim[0])
185    self.assertLen(total_function_cache(func), 3)
186
187    # Since we already marked a cache miss for a function with the same
188    # non-input signatures, here we will immediately start relaxing shapes.
189    func(a, b_=[b0, constant_op.constant([1.0])])
190    self.assertTrue(unknown_dim[0])
191    self.assertLen(total_function_cache(func), 3)
192
193  def testFunctionRelaxationLosesInnerDimWithKerasLayer(self):
194    layer = keras.layers.Dense(1)
195    fn = def_function.function()(layer)
196
197    with self.captureWritesToStream(sys.stderr) as printed:
198      fn(array_ops.ones((3, 2)))
199      self.assertNotIn('ValueError', printed.contents())
200    with self.captureWritesToStream(sys.stderr) as printed:
201      # Use batch size 2 to trigger a second cache miss on the shape.
202      fn(array_ops.ones((2, 2)))
203      self.assertNotIn('ValueError', printed.contents())
204
205    # Shape relaxation passes TensorShape([None, None]), which causes layer
206    # matmul to fail, due to incompatible dims.  What would have been a graph
207    # build time error (layer would complain about the inner dim being 4).
208    with self.captureWritesToStream(sys.stderr) as printed:
209      with self.assertRaisesRegexp(errors.InvalidArgumentError, r'MatMul'):
210        fn(array_ops.ones((3, 4)))
211
212  def testNestedShapeFunctionRelaxation(self):
213
214    got_shape = [None]
215
216    # The inner function will go through shape relaxation because the shapes it
217    # receives will be [1], [2], [3], ...
218    @def_function.function
219    def bar(x_shape):
220      got_shape[0] = x_shape._shape_tuple()
221      return x_shape
222
223    # The outer function will not go through shape relaxation because the shapes
224    # it receives will be [1], [[1]], [[[1]]], ...
225    @def_function.function
226    def foo(ones):
227      return bar(array_ops.shape(ones))
228
229    for rank in range(1, 6):
230      x_shape = self.evaluate(foo(array_ops.ones([1] * rank)))
231      self.assertAllEqual(x_shape, [1] * rank)
232      if rank < 3:
233        self.assertEqual(got_shape[0], (rank,))
234      else:
235        self.assertEqual(got_shape[0], (None,))
236
237  def testNoHash(self):
238
239    @def_function.function()
240    def f(_):
241      return 1.0
242
243    with self.assertRaisesRegexp(TypeError, 'set'):
244      f(set([]))
245
246  def testFuncName(self):
247
248    @function.defun_with_attributes(attributes={'func_name': 'multiply'})
249    def add(x, y):
250      _ = x * y
251      return x + y
252
253    @function.defun
254    def add_2(x, y):
255      _ = x * y
256      return x + y
257
258    self.assertEqual(add._name, 'multiply')
259    self.assertEqual(add_2._name, 'add_2')
260
261  def testBasicGraphMode(self):
262    # TODO(b/121134877): Remove the autograph override.
263    matmul = def_function.function(math_ops.matmul, autograph=False)
264
265    @def_function.function
266    def sq(a):
267      return matmul(a, a)
268
269    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
270    out = sq(t)
271    self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
272
273  def testNestedInputsGraphMode(self):
274    # TODO(b/121134877): Remove the autograph override.
275    matmul = def_function.function(math_ops.matmul, autograph=False)
276
277    pair = collections.namedtuple('pair', ['a', 'b'])
278
279    @def_function.function
280    def a_times_b(inputs):
281      return matmul(inputs.a['a'], inputs.b['b'])
282
283    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
284
285    out = a_times_b(pair({'a': t}, {'b': t}))
286    self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
287
288  def testNestedOutputsGraphMode(self):
289    # TODO(b/121134877): Remove the autograph override.
290    matmul = def_function.function(math_ops.matmul, autograph=False)
291
292    pair = collections.namedtuple('pair', ['a', 'b'])
293
294    @def_function.function()
295    def pairs_mul(pair_a, pair_b):
296      return pair(matmul(pair_a.a, pair_b.a), matmul(pair_a.b, pair_b.b))
297
298    a = constant_op.constant([[1.0, 2.0], [1.0, 2.0]])
299    b = constant_op.constant([[3.0, 4.0], [3.0, 4.0]])
300
301    out = pairs_mul(pair(a, b), pair(b, a))
302    expected = pair(math_ops.matmul(a, b).numpy(),
303                    math_ops.matmul(b, a).numpy())
304    self.assertAllClose(out, expected)
305
306  def testGraphEagerIsolation(self):
307
308    @function.defun
309    def f():
310      self.v = variables.Variable(1.0)
311      return self.v.read_value()
312
313    self.assertAllEqual(f(), 1.0)
314
315    with ops.Graph().as_default():
316      self.assertEqual(f().shape, ())
317
318  def testBasicGraphFunction(self):
319    # TODO(b/121134877): Remove the autograph override.
320    matmul = def_function.function(math_ops.matmul, autograph=False)
321
322    @def_function.function
323    def sq(a):
324      return matmul(a, a)
325
326    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
327
328    sq_op = sq.get_concrete_function(t)
329    self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2]))
330    out = sq_op(t)
331    self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
332
333  def testInputSpecGraphFunction(self):
334    # TODO(b/121134877): Remove the autograph override.
335    matmul = def_function.function(math_ops.matmul, autograph=False)
336
337    @def_function.function
338    def sq(a):
339      return matmul(a, a)
340
341    sq_op = sq.get_concrete_function(
342        tensor_spec.TensorSpec((None, None), dtypes.float32))
343    self.assertEqual([None, None], sq_op.output_shapes.as_list())
344
345    t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
346    out1 = sq_op(t1)
347    self.assertAllEqual(out1, math_ops.matmul(t1, t1).numpy())
348
349    t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
350    out2 = sq_op(t2)
351    self.assertAllEqual(out2, math_ops.matmul(t2, t2).numpy())
352
353  def testNestedInputSpecGraphFunction(self):
354    # TODO(b/121134877): Remove the autograph override.
355    matmul = def_function.function(math_ops.matmul, autograph=False)
356
357    @def_function.function
358    def sq(mats):
359      ((a, b),) = mats
360      return matmul(a, b)
361
362    with self.assertRaisesRegexp(ValueError, "two arguments named 'mats'"):
363      sq.get_concrete_function(
364          [(tensor_spec.TensorSpec((None, None), dtypes.float32),
365            tensor_spec.TensorSpec((None, None), dtypes.float32))])
366    sq_op = sq.get_concrete_function(
367        [(tensor_spec.TensorSpec((None, None), dtypes.float32,
368                                 name='first_mat'),
369          tensor_spec.TensorSpec((None, None), dtypes.float32,
370                                 name='second_mat'))])
371    self.assertEqual([None, None], sq_op.output_shapes.as_list())
372
373    t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
374    t2 = constant_op.constant([[1.4, 2.4], [3.4, 4.4]])
375    with self.assertRaisesRegexp(
376        TypeError, 'bound to Tensors within nested structures'):
377      sq_op(t1, t2)
378    out = sq_op(first_mat=t1, second_mat=t2)
379    self.assertAllEqual(out, math_ops.matmul(t1, t2).numpy())
380
381  def testExecutingStatelessDefunConcurrently(self):
382
383    @def_function.function
384    def stateless(x):
385      return math_ops.multiply(2.0, x)
386
387    pool = ThreadPool()
388    inputs = [constant_op.constant(1.0 * x) for x in range(100)]
389    outputs = [float(out) for out in pool.map(stateless, inputs)]
390    expected = [float(2.0 * x) for x in inputs]
391    self.assertSequenceEqual(outputs, expected)
392
393  def testExecutingManyStatelessDefunsConcurrently(self):
394
395    @def_function.function
396    def stateless(x):
397      del x
398      return math_ops.multiply(2.0, 2.0)
399
400    pool = ThreadPool()
401    # `pool.map` below instantiates 100 functions, one for each object.
402    outputs = [
403        float(out)
404        for out in pool.map(stateless, [object() for _ in range(100)])
405    ]
406    expected = [4.0] * 100
407    self.assertSequenceEqual(outputs, expected)
408
409  def testExecutingStatefulDefunConcurrently(self):
410
411    v = resource_variable_ops.ResourceVariable(1.0)
412
413    @def_function.function
414    def stateful(x):
415      v.assign(x)
416
417    pool = ThreadPool()
418    inputs = [constant_op.constant(0.0)] * 100
419    pool.map(stateful, inputs)
420    self.assertEqual(float(v.read_value()), 0.0)
421
422  def testExecutingManyStatefulDefunsConcurrently(self):
423
424    v = resource_variable_ops.ResourceVariable(1.0)
425
426    @def_function.function
427    def stateful(x):
428      del x
429      return v.assign(0.0)
430
431    pool = ThreadPool()
432    # `pool.map` below instantiates 100 functions, one for each object.
433    pool.map(stateful, [object() for _ in range(100)])
434    self.assertEqual(float(v.read_value()), 0.0)
435
436  def disabled_testRandomSeed(self):
437
438    @def_function.function
439    def f():
440      return random_ops.random_normal(())
441
442    random_seed.set_random_seed(1)
443    x = f()
444    self.assertNotEqual(x, f())
445    random_seed.set_random_seed(1)
446    self.assertAllEqual(f(), x)
447
448  def testNestedInputsGraphFunction(self):
449    # TODO(b/121134877): Remove the autograph override.
450    matmul = def_function.function(math_ops.matmul, autograph=False)
451
452    pair = collections.namedtuple('pair', ['a', 'b'])
453
454    @def_function.function
455    def a_times_b(inputs):
456      return matmul(inputs.a['a'], inputs.b['b'])
457
458    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
459    sq_op = a_times_b.get_concrete_function(
460        pair(dict(a=tensor_spec.TensorSpec([2, 2], dtypes.float32, 'a')),
461             dict(b=tensor_spec.TensorSpec([2, 2], dtypes.float32, 'b'))))
462    self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2]))
463    out = sq_op(a=t, b=t)
464    self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
465
466  def testNestedOutputGraphFunction(self):
467    # TODO(b/121134877): Remove the autograph override.
468    matmul = def_function.function(math_ops.matmul, autograph=False)
469
470    @def_function.function
471    def sq(a):
472      return (matmul(a, a), {'b': constant_op.constant(1.0)})
473
474    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
475
476    sq_op = sq.get_concrete_function(t)
477    self.assertEqual(sq_op.output_shapes,
478                     (tensor_shape.TensorShape([2, 2]),
479                      {'b': tensor_shape.TensorShape([])}))
480    self.assertEqual(sq_op.output_dtypes,
481                     (dtypes.float32, {'b': dtypes.float32}))
482    (a, b) = sq_op(t)
483    self.assertAllEqual(a, math_ops.matmul(t, t).numpy())
484    self.assertAllEqual(b['b'].numpy(), 1.0)
485
486  def testGraphFunctionNoneOutput(self):
487    @def_function.function
488    def fn(unused_a, unused_b):
489      return None
490
491    x = constant_op.constant(1)
492    fn_op = fn.get_concrete_function(x, x)
493    self.assertEqual(fn_op.output_dtypes, None)
494    self.assertEqual(fn_op.output_shapes, None)
495    self.assertAllEqual(fn_op(x, x), None)
496
497  def testDefunNumpyArraysConvertedToTensors(self):
498
499    def f(x):
500      self.assertIsInstance(x, ops.Tensor)
501      return x
502
503    x = random_ops.random_uniform([2, 2]).numpy()
504    defined = function.defun(f)
505    defined(x)
506    self.assertLen(total_function_cache(defined), 1)
507
508    x = random_ops.random_uniform([2, 2]).numpy()
509    defined(x)
510    # A NumPy array with different values but the same shape and dtype
511    # shouldn't trigger another function definition.
512    self.assertLen(total_function_cache(defined), 1)
513
514    # Test that the numpy array is properly an argument to the graph function.
515    self.assertEqual(1., defined(numpy.ones([])).numpy())
516    self.assertEqual(0., defined(numpy.zeros([])).numpy())
517    self.assertEqual(1., defined(array_ops.ones([])).numpy())
518    self.assertEqual(0., defined(array_ops.zeros([])).numpy())
519
520  def testDefunCapturedInt32(self):
521    x = constant_op.constant(1, dtype=dtypes.int32)
522
523    @def_function.function
524    def add_int32s():
525      return x + x
526
527    self.assertEqual(2, int(add_int32s()))
528
529  def testDefunReadVariable(self):
530    v = resource_variable_ops.ResourceVariable(1.0)
531
532    @def_function.function
533    def f():
534      return v.read_value()
535
536    self.assertEqual(1.0, float(f()))
537
538  def testDefunAssignAddVariable(self):
539    v = resource_variable_ops.ResourceVariable(1.0)
540    x = constant_op.constant(2.0)
541
542    @def_function.function
543    def test_assign_add():
544      v.assign_add(x)
545      return v.read_value()
546
547    self.assertEqual(3.0, float(test_assign_add()))
548
549  @test_util.run_in_graph_and_eager_modes
550  def testTensorInitializationInFunctionRaisesError(self):
551    error_msg = ('Tensor-typed variable initializers must either be '
552                 'wrapped in an init_scope or callable.*')
553
554    @def_function.function
555    def tensor_init():
556      with self.assertRaisesRegexp(ValueError, error_msg):
557        resource_variable_ops.ResourceVariable(constant_op.constant(2.0))
558
559    tensor_init()
560
561  @test_util.run_in_graph_and_eager_modes
562  def testCallableTensorInitializationInFunction(self):
563
564    @def_function.function
565    def tensor_init():
566      self.v = resource_variable_ops.ResourceVariable(
567          lambda: constant_op.constant(2.0))
568      return self.v.read_value()
569
570    value = tensor_init()
571    if not context.executing_eagerly():
572      self.evaluate(variables.global_variables_initializer())
573    self.assertEqual(self.evaluate(value), 2.0)
574
575  @test_util.also_run_as_tf_function
576  def testInitScopeTensorInitializationInFunction(self):
577
578    @def_function.function
579    def tensor_init():
580      with ops.init_scope():
581        const = constant_op.constant(2.0)
582      # Note: this variable bypasses tf.function's variable creation
583      # requirements by bypassing variable_creator_scope by using
584      # ResourceVariable instead of Variable.
585      self.v = resource_variable_ops.ResourceVariable(const)
586      return self.v.read_value()
587
588    value = tensor_init()
589    self.assertAllEqual(value, 2.0)
590
591  @test_util.run_in_graph_and_eager_modes
592  def testGetConcreteFunctionCreatesVariables(self):
593
594    v_holder = []
595
596    @def_function.function
597    def tensor_init():
598      if not v_holder:
599        v_holder.append(variables.Variable(5.))
600      return v_holder[0].read_value()
601
602    concrete = tensor_init.get_concrete_function()
603    self.evaluate(variables.global_variables_initializer())
604    self.assertAllEqual(5., self.evaluate(concrete()))
605    self.assertAllEqual(5., self.evaluate(tensor_init()))
606
607  def testFuncGraphCaptureByValue(self):
608    v = variables.Variable(1.0)
609
610    def trivial_function():
611      return v.read_value()
612
613    graph_function = function.Function(
614        trivial_function, 'test', capture_by_value=True)
615
616    self.assertAllEqual(graph_function(), 1.0)
617    v.assign(2.0)
618    self.assertAllEqual(graph_function(), 1.0)
619
620  def testFuncGraphCaptureByValueNested(self):
621    v = variables.Variable(1.0)
622
623    def trivial_function():
624      return control_flow_ops.cond(
625          array_ops.placeholder_with_default(True, ()),
626          v.read_value, v.read_value)
627
628    graph_function = function.Function(
629        trivial_function, 'test', capture_by_value=True)
630
631    self.assertAllEqual(graph_function(), 1.0)
632    v.assign(2.0)
633    self.assertAllEqual(graph_function(), 1.0)
634
635  def testDefunShapeInferenceWithCapturedResourceVariable(self):
636    v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])
637
638    def f():
639      x = constant_op.constant([[1, 2], [3, 4]])
640      out = math_ops.matmul(v, x)
641      self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
642      # We do not return v directly since the tensor conversion function of
643      # ResourceVariable returns the read value and not the resource itself.
644      return v._handle
645
646    compiled = def_function.function(f)
647    var_handle = compiled()
648    self.assertEqual(var_handle.dtype, dtypes.resource)
649    self.assertEqual(var_handle.shape, tensor_shape.scalar())
650    var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
651    self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))
652
653  def testShapeInferenceForMoreSpecificInput(self):
654    self.skipTest('b/124219898')
655
656    def f(a):
657      return array_ops.reshape(a, [-1, 3])
658
659    signature = [tensor_spec.TensorSpec(None, dtypes.float32)]
660    compiled = def_function.function(f, input_signature=signature)
661
662    with ops.Graph().as_default():
663      inputs = array_ops.zeros([10, 10, 3])
664      self.assertAllEqual(f(inputs).shape, compiled(inputs).shape)
665
666  def testFuncListAttr(self):
667
668    @function.defun
669    def test_function(val):
670
671      def fn1():
672        return array_ops.ones([10])
673
674      fn2 = lambda: array_ops.ones([10]) * 2
675
676      def fn3(x=2):
677        return array_ops.ones([10]) * x
678      fn3 = functools.partial(fn3, x=3)
679
680      return gen_functional_ops.case(val, [], [dtypes.float32],
681                                     [function.defun(f).get_concrete_function()
682                                      for f in (fn1, fn2, fn3)])
683
684    ones = array_ops.ones([10])
685    self.assertAllEqual([ones], test_function(0))
686    self.assertAllEqual([ones * 2], test_function(1))
687    self.assertAllEqual([ones * 3], test_function(2))
688    self.assertAllEqual([ones * 3], test_function(22))  # default branch
689
690  @test_util.enable_control_flow_v2
691  def testVariableInLoopInFunction(self):
692
693    @function.defun
694    def test_function():
695
696      def loop_test(_):
697        return False
698
699      def loop_body(_):
700        return variable_scope.get_variable('a', shape=())
701
702      return control_flow_ops.while_loop(loop_test, loop_body, [0.0])
703
704    self.assertEqual(test_function().shape, [])
705
706  def testDefunShapeInferenceWithCapturedResourceVariableInGraphMode(self):
707    with context.graph_mode():
708      v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])
709
710      def f():
711        x = constant_op.constant([[1, 2], [3, 4]])
712        out = math_ops.matmul(v, x)
713        self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
714        # We do not return v directly since the tensor conversion function of
715        # ResourceVariable returns the read value and not the resource itself.
716        return v._handle
717
718      compiled = def_function.function(f)
719      var_handle = compiled()
720      self.assertEqual(var_handle.dtype, dtypes.resource)
721      self.assertEqual(var_handle.shape, tensor_shape.scalar())
722      var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
723      self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))
724
725  def testDefunShapeInferenceWithCapturedVariableInGraphMode(self):
726    with context.graph_mode():
727      v = variables.Variable([[1, 2], [3, 4]])
728
729      def f():
730        x = constant_op.constant([[1, 2], [3, 4]])
731        out = math_ops.matmul(v, x)
732        self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
733
734      # Check that shape inference works while creating the defun
735      compiled = def_function.function(f)
736      compiled()
737
738  def testDefunShapeInferenceWithCapturedTensorListInGraphMode(self):
739    with context.graph_mode():
740      tensor_list = list_ops.empty_tensor_list(
741          element_dtype=dtypes.float32,
742          element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
743      tensor_list = list_ops.tensor_list_push_back(tensor_list,
744                                                   constant_op.constant(1.0))
745      tensor_list = list_ops.tensor_list_push_back(tensor_list,
746                                                   constant_op.constant(2.0))
747
748      def f():
749        tl, value = list_ops.tensor_list_pop_back(
750            tensor_list, element_dtype=dtypes.float32)
751        self.assertEqual(value.shape, tensor_shape.scalar())
752        return tl
753
754      compiled = def_function.function(f)
755      output_tensor_list = compiled()
756      _, value = list_ops.tensor_list_pop_back(
757          output_tensor_list, element_dtype=dtypes.float32)
758      self.assertEqual(value.shape, tensor_shape.scalar())
759
760  @test_util.run_in_graph_and_eager_modes
761  def testDefunForcesResourceVariables(self):
762
763    def variable_creator():
764      self.v = variables.Variable(0.0)
765      return self.v.read_value()
766
767    self.v = None
768    defined = function.defun(variable_creator)
769    defined()  # Create the variable.
770    self.assertIsInstance(
771        self.v, resource_variable_ops.ResourceVariable)
772
773  def testRunMetadata(self):
774
775    @def_function.function
776    def f(x):
777      return x * x
778
779    with ops.device('cpu:0'):
780      context.enable_run_metadata()
781      f(constant_op.constant(1.0))
782    run_metadata = context.export_run_metadata()
783    context.disable_run_metadata()
784    step_stats = run_metadata.step_stats
785    self.assertNotEmpty(step_stats.dev_stats)
786    cpu_stats = step_stats.dev_stats[0]
787    self.assertEqual('/job:localhost/replica:0/task:0/device:CPU:0',
788                     cpu_stats.device)
789    # Testing for at least 2 because the function call should generate at most
790    # one entry in the step_stats; the ops inside function can generate
791    # arbitrarily many (placeholders, return identities, etc, might be included
792    # or not in the future, so shouldn't be tested for exactly.
793    self.assertGreaterEqual(len(cpu_stats.node_stats), 2)
794    self.assertLen(run_metadata.partition_graphs, 1)
795
796  def testGraphModeCaptureVariable(self):
797    with context.graph_mode(), self.cached_session():
798
799      class HasAVar(object):
800
801        def __init__(self):
802          self.v = resource_variable_ops.ResourceVariable(1.0)
803
804        def call(self):
805          return self.v * 2
806
807      o = HasAVar()
808      self.evaluate(variables.global_variables_initializer())
809      call = def_function.function(o.call)
810      op = call()
811      self.assertAllEqual(self.evaluate(op), 2.0)
812
813  def testGraphModeManyFunctions(self):
814    with ops.Graph().as_default(), self.cached_session():
815
816      @def_function.function
817      def f(x):
818        return x * x
819
820      @def_function.function
821      def g(x):
822        return f(x) + 1
823
824      self.assertAllEqual(g(constant_op.constant(2.0)).eval(), 5.0)
825
826  def testDict(self):
827
828    @def_function.function
829    def f(x):
830      return {'name': x + 1}
831
832    self.assertAllEqual(f(constant_op.constant(1.0))['name'], 2.0)
833
834  def testTensorConversionWithDefun(self):
835
836    @def_function.function
837    def f(x):
838      return math_ops.add(x, constant_op.constant(3))
839
840    self.assertAllEqual(5, f(constant_op.constant(2)))
841
842  def testTensorConversionCall(self):
843
844    @def_function.function
845    def f(x):
846      return math_ops.add(x, constant_op.constant(3))
847
848    @def_function.function
849    def g(x):
850      return f(f(x))
851
852    self.assertAllEqual(8, g(constant_op.constant(2)))
853
854  def testCallShape(self):
855
856    @def_function.function
857    def f(x):
858      return x + 1
859
860    @def_function.function
861    def g(x):
862      x = f(x)
863      self.assertEqual(x.shape.as_list(), [])
864      return None
865
866    g(constant_op.constant(1.0))
867
868  def testNestedDefunWithNoOutputAndTapedInput(self):
869    three = resource_variable_ops.ResourceVariable(3.0, name='v')
870
871    @def_function.function
872    def f(x):
873      # This function intentionally takes a taped variable as input,
874      # but does not return any values
875      math_ops.add(x, three)
876
877    @def_function.function
878    def g(x):
879      y = math_ops.add(x, three)
880      f(y)
881
882    g(three)
883
884  def testGatherResourceWithDefun(self):
885    with ops.device('cpu:0'):
886      v = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
887
888    def sum_gather():
889      return math_ops.reduce_sum(array_ops.gather(v, [1, 2]))
890
891    defined = def_function.function(sum_gather)
892    self.assertAllEqual(sum_gather(), defined())
893
894  def testReturningIndexedSlicesWithDefun(self):
895
896    def validate(indexed_slice):
897      @def_function.function
898      def f():
899        return indexed_slice
900
901      output = f()
902      self.assertIsInstance(output, ops.IndexedSlices)
903      self.assertAllEqual(indexed_slice.values, output.values)
904      self.assertAllEqual(indexed_slice.indices, output.indices)
905      self.assertAllEqual(indexed_slice.dense_shape, output.dense_shape)
906
907      self.assertEqual(
908          f.get_concrete_function().output_shapes,
909          indexed_slice.values.shape)
910
911    arg = ops.IndexedSlices(
912        values=constant_op.constant([1, 2]),
913        indices=constant_op.constant([0, 1]),
914        dense_shape=constant_op.constant([2]))
915    validate(arg)
916
917    arg = ops.IndexedSlices(
918        values=constant_op.constant([1, 2]),
919        indices=constant_op.constant([0, 1]),
920        dense_shape=None)
921    validate(arg)
922
923  def testIndexedSliceAsArgumentWithDefun(self):
924
925    @def_function.function
926    def f(indexed_slice):
927      return indexed_slice
928
929    def validate(arg):
930      output = f(arg)
931      self.assertIsInstance(output, ops.IndexedSlices)
932      self.assertAllEqual(arg.values, output.values)
933      self.assertAllEqual(arg.indices, output.indices)
934      self.assertAllEqual(arg.dense_shape, output.dense_shape)
935
936    indexed_slice = ops.IndexedSlices(
937        values=constant_op.constant([1]),
938        indices=constant_op.constant([0]),
939        dense_shape=constant_op.constant([1]))
940    validate(indexed_slice)
941
942    # Test that `f` works even when `dense_shape` is None.
943    indexed_slice = ops.IndexedSlices(
944        values=constant_op.constant([1]),
945        indices=constant_op.constant([0]),
946        dense_shape=None)
947    validate(indexed_slice)
948
949  def testFunctionOnDevice(self):
950    if not context.context().num_gpus():
951      self.skipTest('No GPUs found')
952
953    x = constant_op.constant([1.]).gpu()
954    # TODO(b/121134877): Remove the autograph override.
955    f = def_function.function(math_ops.add, autograph=False)
956    y = f(x, x).cpu()
957    self.assertAllEqual(y, [2.])
958
959  @test_util.run_in_graph_and_eager_modes
960  def testFunctionWithResourcesOnDifferentDevices(self):
961    if not context.context().num_gpus():
962      self.skipTest('No GPUs found.')
963
964    with ops.device('/cpu:0'):
965      v_cpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
966
967    with ops.device('/gpu:0'):
968      v_gpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
969
970    def sum_gather():
971      cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu, [1, 2]))
972      gpu_result = math_ops.reduce_sum(array_ops.gather(v_gpu, [1, 2]))
973      return cpu_result, gpu_result
974
975    defined = function.defun(sum_gather)
976    if not context.executing_eagerly():
977      self.evaluate(variables.global_variables_initializer())
978    expected = self.evaluate(sum_gather())
979    self.assertAllEqual(expected, self.evaluate(defined()))
980
981  @test_util.run_in_graph_and_eager_modes
982  def testOpInFunctionWithConflictingResourceInputs(self):
983    if not context.context().num_gpus():
984      self.skipTest('No GPUs found.')
985
986    with ops.device('/cpu:0'):
987      v_cpu = resource_variable_ops.ResourceVariable(
988          [0.0, 1.0, 2.0], name='cpu')
989      v_also_cpu = resource_variable_ops.ResourceVariable(
990          [0.0, 1.0, 2.0], name='also_cpu')
991
992    with ops.device('/gpu:0'):
993      v_gpu = resource_variable_ops.ResourceVariable(
994          [0.0, 1.0, 2.0], name='gpu')
995
996    @def_function.function
997    def resource_apply_adam():
998      training_ops.resource_apply_adam(
999          v_cpu.handle,
1000          v_gpu.handle,
1001          v_also_cpu.handle,
1002          1.0,  # beta1_power
1003          1.0,  # beta2_power
1004          1.0,  # learning_rate
1005          1.0,  # beta1
1006          1.0,  # beta2
1007          1.0,  # epsilon,
1008          [1.0, 1.0, 1.0],  # grad
1009          False)  # use_locking
1010      return None
1011
1012    with self.assertRaisesRegexp(
1013        errors.InvalidArgumentError,
1014        'Cannot place the graph because a reference or resource edge connects '
1015        'colocation groups with incompatible assigned devices'):
1016      if not context.executing_eagerly():
1017        self.evaluate(variables.global_variables_initializer())
1018      self.evaluate(resource_apply_adam())
1019
1020  def testFunctionHandlesInputsOnDifferentDevices(self):
1021    if not context.context().num_gpus():
1022      self.skipTest('No GPUs found')
1023
1024    # The Reshape op requires the shape tensor to be placed in host memory.
1025    # TODO(b/121134877): Remove the autograph override.
1026    reshape = def_function.function(array_ops.reshape, autograph=False)
1027    value = constant_op.constant([1., 2.]).gpu()
1028    shape = constant_op.constant([2, 1])
1029    reshaped = reshape(value, shape).cpu()
1030    self.assertAllEqual(reshaped, [[1], [2]])
1031
1032  def testFunctionHandlesInputsPlacedOnTheWrongDeviceGracefully(self):
1033    if not context.context().num_gpus():
1034      self.skipTest('No GPUs found')
1035
1036    # The Reshape op requires the shape tensor to be placed in host memory.
1037    # TODO(b/121134877): Remove the autograph override.
1038    reshape = def_function.function(array_ops.reshape, autograph=False)
1039    value = constant_op.constant([1., 2.])
1040    shape = constant_op.constant([2, 1]).gpu()
1041    reshape(value, shape)  # No error is raised
1042
1043  def testNoneOutput(self):
1044
1045    @def_function.function
1046    def my_function(_):
1047      return None
1048
1049    self.assertAllEqual(my_function(1), None)
1050
1051  def testNestedFunctions(self):
1052    # TensorFlow function (which is what would be used in TensorFlow graph
1053    # construction).
1054    @tf_function.Defun(dtypes.int32, dtypes.int32)
1055    def add(a, b):
1056      return math_ops.add(a, b)
1057
1058    @def_function.function
1059    def add_one(x):
1060      return add(x, 1)
1061
1062    self.assertAllEqual(3, add_one(constant_op.constant(2)))
1063
1064  def testVariableCaptureInNestedFunctions(self):
1065    v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.int32)
1066
1067    @def_function.function
1068    def inner_read():
1069      return v.read_value()
1070
1071    @def_function.function
1072    def outer():
1073      return inner_read()
1074
1075    self.assertEqual(1, int(outer()))
1076
1077  def testReturnCapturedEagerTensor(self):
1078    t = constant_op.constant(1)
1079
1080    @def_function.function
1081    def read():
1082      return t
1083
1084    self.assertEqual(1, int(read()))
1085
1086  def testReturnCapturedGraphTensor(self):
1087    with context.graph_mode(), self.cached_session():
1088      t = constant_op.constant(1)
1089
1090      @def_function.function
1091      def read():
1092        return t
1093
1094      self.assertEqual(1, int(self.evaluate(read())))
1095
1096  def testSequenceInputs(self):
1097    # TODO(b/121134877): Remove the autograph override.
1098    clip_by_global_norm = def_function.function(
1099        clip_ops.clip_by_global_norm, autograph=False)
1100    t_list = [constant_op.constant(1.0), constant_op.constant(2.0)]
1101    clipped_list, global_norm = clip_by_global_norm(t_list,
1102                                                    constant_op.constant(.2))
1103    for t in clipped_list:
1104      self.assertIsInstance(t, ops.Tensor)
1105    self.assertIsInstance(global_norm, ops.Tensor)
1106
1107  def testNestedSequenceInputs(self):
1108
1109    def my_op(inputs):
1110      a, b, c = inputs
1111      e, f = b
1112      g, h = e
1113      return [a + a, [tuple([f + f, g + g]), h + h], c + c], a + f + g + h + c
1114
1115    my_eager_op = def_function.function(my_op)
1116    ret = my_eager_op([
1117        constant_op.constant(1), [(constant_op.constant(2),
1118                                   constant_op.constant(3)),
1119                                  constant_op.constant(4)],
1120        constant_op.constant(5)
1121    ])
1122    self.assertLen(ret, 2)
1123    self.assertAllEqual(ret[0][0], 2)
1124    self.assertAllEqual(ret[0][1][0][0], 8)
1125    self.assertAllEqual(ret[0][1][0][1], 4)
1126    self.assertIsInstance(ret[0][1][0], tuple)
1127    self.assertAllEqual(ret[0][1][1], 6)
1128    self.assertAllEqual(ret[0][2], 10)
1129    self.assertAllEqual(ret[1], 15)
1130
1131  def testVariableNamesRespectNameScopesWithDefun(self):
1132    @def_function.function
1133    def create_variable():
1134      with ops.name_scope('foo'):
1135        v = resource_variable_ops.ResourceVariable(0.0, name='bar')
1136      self.assertEqual(v.name, 'foo/bar:0')
1137
1138    create_variable()
1139
1140  def testVariableNamesRespectNameScopesWithDefunInGraph(self):
1141    with context.graph_mode():
1142      @def_function.function
1143      def create_variable():
1144        with ops.name_scope('foo'):
1145          v = resource_variable_ops.ResourceVariable([1.0, 2.0], name='bar')
1146        self.assertEqual(v.name, 'foo/bar:0')
1147
1148      with ops.get_default_graph().as_default():
1149        create_variable()
1150
1151  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
1152  def testLayerInDefun(self):
1153    conv = convolutional.Conv2D(
1154        filters=1,
1155        kernel_size=2,
1156        kernel_initializer=init_ops.ones_initializer(),
1157        bias_initializer=init_ops.zeros_initializer())
1158
1159    @function.defun
1160    def model(x):
1161      return conv(x)
1162
1163    x = array_ops.ones([1, 2, 2, 1])
1164    y = model(x)
1165
1166    if not context.executing_eagerly():
1167      self.evaluate(variables.global_variables_initializer())
1168
1169    self.assertAllClose([[[[4.0]]]], self.evaluate(y))
1170
1171  # Variable lifting is somewhat different between defun/tf.function, so testing
1172  # device placement on both makes sense.
1173  @parameterized.named_parameters(
1174      dict(testcase_name='Defun',
1175           function_decorator=function.defun),
1176      dict(testcase_name='DefFunction',
1177           function_decorator=def_function.function))
1178  @test_util.run_in_graph_and_eager_modes
1179  def testVariablesPlacedOnOutsideDevice(self, function_decorator):
1180
1181    class _Obj(object):
1182
1183      def __init__(self):
1184        self.v = None
1185
1186      @function_decorator
1187      def f(self):
1188        if self.v is None:
1189          self.v = variables.Variable(1.)
1190        return self.v + 1.
1191
1192    has_device = _Obj()
1193    with ops.device('cpu:0'):
1194      has_device.f()
1195    self.assertIn('CPU', has_device.v.device)
1196
1197  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
1198  def testDefunKerasModelCall(self):
1199    model = MiniModel()
1200    model.call = function.defun(model.call)
1201
1202    x = array_ops.ones([1, 2])
1203    y = model(x)
1204
1205    if not context.executing_eagerly():
1206      self.evaluate(variables.global_variables_initializer())
1207
1208    self.assertAllEqual([[3.0]], self.evaluate(y))
1209
1210    # Break the reference cycle between the MiniModel and the defun:
1211    # `MiniModel` --(through its `call` method)--> `Function`
1212    # `Function` --(instancemethod on `MiniModel`)--> `MiniModel`
1213    del model.call
1214
1215  # Note: The ConfigProto below unfortunately only configures graph
1216  # construction. Eager's configuration is controlled in `__main__`.
1217  @test_util.run_in_graph_and_eager_modes(
1218      config=config_pb2.ConfigProto(device_count={'CPU': 4}))
1219  @test_util.run_v1_only('b/120545219')
1220  def testDeviceAnnotationsRespected(self):
1221
1222    def multi_device_fn():
1223      with ops.device('/cpu:0'):
1224        s0 = test_ops.device_placement_op()
1225      with ops.device('/cpu:1'):
1226        s1 = test_ops.device_placement_op()
1227      with ops.device('/cpu:2'):
1228        s2 = test_ops.device_placement_op()
1229      s3 = test_ops.device_placement_op()
1230      return s0, s1, s2, s3
1231
1232    defined = function.defun(multi_device_fn)
1233    outputs = self.evaluate(defined())
1234    self.assertLen(total_function_cache(defined), 1)
1235    self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
1236    self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
1237    self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
1238
1239    with ops.device('/cpu:3'):
1240      outputs = self.evaluate(defined())
1241    # All function definitions are agnostic to call site devices.
1242    self.assertLen(total_function_cache(defined), 1)
1243    self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
1244    self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
1245    self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
1246    self.assertIn(compat.as_bytes('CPU:3'), outputs[3])
1247
1248    with ops.device('/cpu:0'):
1249      outputs = self.evaluate(defined())
1250    self.assertLen(total_function_cache(defined), 1)
1251    self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
1252    self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
1253    self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
1254    self.assertIn(compat.as_bytes('CPU:0'), outputs[3])
1255
1256  @test_util.run_in_graph_and_eager_modes(
1257      config=config_pb2.ConfigProto(device_count={'CPU': 2}))
1258  @test_util.run_v1_only('b/120545219')
1259  def testCallingGraphFunctionOnDifferentDevice(self):
1260
1261    def func():
1262      return constant_op.constant(0)
1263
1264    defined = def_function.function(func)
1265    with ops.device('cpu:0'):
1266      cpu_graph_function = defined.get_concrete_function()
1267
1268    with ops.device('cpu:0'):
1269      self.assertEqual(
1270          self.evaluate(cpu_graph_function()), self.evaluate(func()))
1271
1272    with ops.device('cpu:1'):
1273      self.assertEqual(0., self.evaluate(cpu_graph_function()))
1274
1275    with ops.device(None):
1276      self.assertEqual(0., self.evaluate(cpu_graph_function()))
1277
1278    default_graph_function = defined.get_concrete_function()
1279    self.assertEqual(
1280        self.evaluate(default_graph_function()), self.evaluate(func()))
1281
1282    with ops.device('cpu:1'):
1283      self.assertEqual(0., self.evaluate(default_graph_function()))
1284
1285  @test_util.run_in_graph_and_eager_modes
1286  def testColocateWithRespected(self):
1287    # TODO(b/113291792): Use multiple CPUs instead of a GPU.
1288    if not context.context().num_gpus():
1289      self.skipTest('No GPUs found.')
1290
1291    with ops.device('cpu:0'):
1292      x = constant_op.constant(1.0)
1293
1294    with ops.device('gpu:0'):
1295      y = constant_op.constant(1.0)
1296
1297    @def_function.function
1298    def foo():
1299      return test_ops.device_placement_op()
1300
1301    with ops.colocate_with(x):
1302      self.assertIn(compat.as_bytes('CPU:0'), self.evaluate(foo()))
1303
1304    with ops.colocate_with(y):
1305      self.assertIn(compat.as_bytes('GPU:0'), self.evaluate(foo()))
1306
1307  def testVariablesAreTracked(self):
1308    v = resource_variable_ops.ResourceVariable(1.0)
1309
1310    def foo(x):
1311      return v * x
1312
1313    defined = def_function.function(foo)
1314
1315    x = constant_op.constant([1.0])
1316    self.assertEqual(1., self.evaluate(defined(x)))
1317    v.assign(2.)
1318
1319    x = constant_op.constant([1.0, 2.0])
1320    self.assertAllEqual([2., 4.], self.evaluate(defined(x)))
1321
1322  def testCacheObjectHashCollisions(self):
1323
1324    class Foo(object):
1325
1326      def __hash__(self):
1327        return 42
1328
1329    def func(foo):
1330      del foo
1331      return
1332
1333    defined = function.defun(func)
1334    defined(Foo())
1335    self.assertLen(total_function_cache(defined), 1)
1336
1337    defined(Foo())
1338    self.assertLen(total_function_cache(defined), 2)
1339
1340  def testCacheTensorDtypeCollision(self):
1341
1342    def func(t):
1343      return t + t
1344
1345    defined = function.defun(func)
1346    t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
1347    defined(t)
1348    self.assertLen(total_function_cache(defined), 1)
1349
1350    t = constant_op.constant([[1.0]], dtype=dtypes.complex128)
1351    defined(t)
1352    self.assertLen(total_function_cache(defined), 2)
1353
1354  def testCacheTensorShapeCollision(self):
1355
1356    def func(t):
1357      return t + t
1358
1359    defined = function.defun(func)
1360    t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
1361    defined(t)
1362    self.assertLen(total_function_cache(defined), 1)
1363
1364    t = constant_op.constant([1.0], dtype=dtypes.complex64)
1365    defined(t)
1366    self.assertLen(total_function_cache(defined), 2)
1367
1368  def testCacheTensorShapeDtypeCollision(self):
1369
1370    def func(t):
1371      return t + t
1372
1373    defined = function.defun(func)
1374    t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
1375    defined(t)
1376    self.assertLen(total_function_cache(defined), 1)
1377
1378    t = constant_op.constant([1.0], dtype=dtypes.complex128)
1379    defined(t)
1380    self.assertLen(total_function_cache(defined), 2)
1381
1382  def testCacheTensorUnknownShapesCollision(self):
1383
1384    def func(t):
1385      return t + t
1386
1387    with context.graph_mode(), self.cached_session():
1388      defined = function.defun(func)
1389
1390      p = array_ops.placeholder(dtype=dtypes.float32, shape=[])
1391      defined(p)
1392      self.assertLen(total_function_cache(defined), 1)
1393
1394      p = array_ops.placeholder(dtype=dtypes.float32, shape=[1])
1395      defined(p)
1396      self.assertLen(total_function_cache(defined), 2)
1397
1398      p = array_ops.placeholder(dtype=dtypes.float32, shape=[2])
1399      defined(p)
1400      # Gradual shape relaxation is performed; and the common shape between
1401      # [1] and [2] is one containing unknown dimensions.
1402      self.assertLen(total_function_cache(defined), 2)
1403
1404      # pylint: disable=protected-access
1405      self.assertLen(defined._function_cache.arg_relaxed_shapes, 1)
1406      relaxed_shapes = (
1407          list(defined._function_cache.arg_relaxed_shapes.values())[0])
1408      self.assertEqual(len(relaxed_shapes), 1)
1409      relaxed_shape = relaxed_shapes[0]
1410      # pylint: enable=protected-access
1411      self.assertEqual(relaxed_shape.rank, 1)
1412      self.assertEqual(tensor_shape.dimension_value(relaxed_shape[0]), None)
1413
1414      t = constant_op.constant([1.0, 1.0, 1.0], dtype=dtypes.float32)
1415      defined(t)
1416      # Shape (3,) matches the relaxed shape TensorShape([None])
1417      self.assertLen(total_function_cache(defined), 2)
1418
1419  def testPythonFunctionWithDefaultArgs(self):
1420
1421    def func(foo, bar=1, baz=2):
1422      del foo
1423      del bar
1424      del baz
1425      return
1426
1427    defined = function.defun(func)
1428    defined(0, baz=20)
1429
1430    def cache_keys():
1431      """Sanitizes cache keys of non-input metadata."""
1432      return tuple(key[0] for key in total_function_cache(defined))
1433
1434    # `True` corresponds to the fact that we're executing eagerly
1435    self.assertIn(('URRRu', (0, 1, 20)), cache_keys())
1436
1437    defined(1)  # bar=1, baz=2
1438    self.assertIn(('URRRu', (1, 1, 2)), cache_keys())
1439
1440    # This matches the previous call.
1441    defined(foo=1)
1442    self.assertLen(total_function_cache(defined), 2)
1443
1444    defined(1, 2, 3)
1445    self.assertLen(total_function_cache(defined), 3)
1446    self.assertIn(('URRRu', (1, 2, 3)), cache_keys())
1447
1448    # This matches the previous call.
1449    defined(1, bar=2, baz=3)
1450    self.assertLen(total_function_cache(defined), 3)
1451
1452    # This matches the previous call.
1453    defined(1, baz=3, bar=2)
1454    self.assertLen(total_function_cache(defined), 3)
1455
1456  def testFunctoolsPartialUnwrappedCorrectly(self):
1457
1458    def full_function(a, b, c=3):
1459      return a, b, c
1460
1461    partial = functools.partial(full_function, 1, c=4)
1462    a, b, c = partial(2)
1463
1464    defined = function.defun(partial)
1465    func_a, func_b, func_c = defined(2)
1466    self.assertEqual(func_a.numpy(), a)
1467    self.assertEqual(func_b.numpy(), b)
1468    self.assertEqual(func_c.numpy(), c)
1469
1470  def testInputSignatureWithMatchingInputs(self):
1471
1472    def foo(a):
1473      self.assertEqual(a.shape, (2,))
1474      return a
1475
1476    signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)]
1477    defined = function.defun(foo, input_signature=signature)
1478    a = array_ops.ones([2])
1479    self.assertAllEqual(a, defined(a))
1480    self.assertLen(total_function_cache(defined), 1)
1481    self.assertAllEqual(a, defined.get_concrete_function()(a))
1482    self.assertAllEqual(a, defined.get_concrete_function(a)(a))
1483    self.assertAllEqual(a, defined.get_concrete_function(
1484        tensor_spec.TensorSpec((2,), dtype=dtypes.float32))(a))
1485    self.assertLen(total_function_cache(defined), 1)
1486
1487    def bar(a):
1488      self.assertEqual(a._shape_tuple(), (2, None))
1489      return a
1490
1491    signature = [tensor_spec.TensorSpec((2, None), dtypes.float32)]
1492    defined = function.defun(bar, input_signature=signature)
1493    a = array_ops.ones([2, 1])
1494    out = defined(a)
1495    self.assertLen(total_function_cache(defined), 1)
1496    self.assertAllEqual(out, a)
1497
1498    # Changing the second dimension shouldn't create a new function.
1499    b = array_ops.ones([2, 3])
1500    out = defined(b)
1501    self.assertLen(total_function_cache(defined), 1)
1502    self.assertAllEqual(out, b)
1503
1504  def testInputSignatureWithCompatibleInputs(self):
1505
1506    rank2_spec = tensor_spec.TensorSpec(shape=(None, None),
1507                                        dtype=dtypes.float32)
1508
1509    @function.defun(input_signature=[rank2_spec])
1510    def func(a):
1511      self.assertEqual([None, None], a.shape.as_list())
1512      return array_ops.shape(a)
1513
1514    self.assertAllEqual([3, 1], func([[0], [1.0], [1]]))
1515    self.assertAllEqual([2, 2], func(numpy.array([[1, 1], [2, 2]])))
1516
1517    with self.assertRaisesRegexp(ValueError, 'incompatible'):
1518      func([0.0, 1.0, 2.0])  # Wrong shape.
1519
1520    with self.assertRaisesRegexp(ValueError, 'incompatible'):
1521      func([['wrong dtype']])
1522
1523  def testNestedInputSignatures(self):
1524
1525    def expected_foo(a, b):
1526      return [a, b]
1527
1528    @function.defun(input_signature=[
1529        [tensor_spec.TensorSpec((2, None), dtypes.float32)] * 2,
1530        tensor_spec.TensorSpec((1,), dtypes.float32),
1531    ])
1532    def foo(a, b):
1533      self.assertEqual(a[0]._shape_tuple(), (2, None))
1534      self.assertEqual(a[1]._shape_tuple(), (2, None))
1535      self.assertEqual(b._shape_tuple(), (1,))
1536      return [a, b]
1537
1538    a = array_ops.ones([2, 1])
1539    b = array_ops.ones([1])
1540    expected = expected_foo([a, a], b)
1541    out = foo([a, a], b)
1542    self.assertLen(total_function_cache(foo), 1)
1543    nest.assert_same_structure(out, expected)
1544    self.assertAllEqual(out[0][0], a)
1545    self.assertAllEqual(out[0][1], a)
1546    self.assertAllEqual(out[1], b)
1547
1548    # Changing the unspecified dimensions shouldn't create a new function.
1549    a = array_ops.ones([2, 3])
1550    b = array_ops.ones([2, 5])
1551    c = array_ops.ones([1])
1552    expected = expected_foo([a, b], c)
1553    out = foo([a, b], c)
1554    self.assertLen(total_function_cache(foo), 1)
1555    nest.assert_same_structure(out, expected)
1556    self.assertAllEqual(out[0][0], a)
1557    self.assertAllEqual(out[0][1], b)
1558    self.assertAllEqual(out[1], c)
1559
1560    # Passing compatible inputs should work.
1561    a = a.numpy().tolist()
1562    b = b.numpy().tolist()
1563    c = c.numpy().tolist()
1564    out = foo([a, b], c)
1565    self.assertLen(total_function_cache(foo), 1)
1566    nest.assert_same_structure(out, expected)
1567    self.assertAllEqual(out[0][0], a)
1568    self.assertAllEqual(out[0][1], b)
1569    self.assertAllEqual(out[1], c)
1570
1571  def testNestedInputSignaturesWithDict(self):
1572    def expected_bar(a):
1573      return a
1574
1575    @function.defun(input_signature=[{
1576        'a': tensor_spec.TensorSpec((2, None), dtypes.float32),
1577        'b': tensor_spec.TensorSpec((2, None), dtypes.float32),
1578        'c': tensor_spec.TensorSpec((1,), dtypes.float32)}])
1579    def bar(a):
1580      self.assertEqual(a['a']._shape_tuple(), (2, None))
1581      self.assertEqual(a['b']._shape_tuple(), (2, None))
1582      self.assertEqual(a['c']._shape_tuple(), (1,))
1583      return a
1584
1585    a = array_ops.ones([2, 3])
1586    b = array_ops.ones([1])
1587    inputs = {'a': a, 'b': a, 'c': b}
1588    expected = expected_bar(inputs)
1589    out = bar(inputs)
1590    nest.assert_same_structure(out, expected)
1591    self.assertAllEqual(out['a'], expected['a'])
1592    self.assertAllEqual(out['b'], expected['b'])
1593    self.assertAllEqual(out['c'], expected['c'])
1594
1595    # Passing compatible inputs should work.
1596    a = a.numpy().tolist()
1597    b = b.numpy().tolist()
1598    inputs = {'a': a, 'b': a, 'c': b}
1599    out = bar(inputs)
1600    nest.assert_same_structure(out, expected)
1601    self.assertAllEqual(out['a'], expected['a'])
1602    self.assertAllEqual(out['b'], expected['b'])
1603    self.assertAllEqual(out['c'], expected['c'])
1604
1605  def testInputSignatureMustBeSequenceOfTensorSpecs(self):
1606
1607    def foo(a, b):
1608      del a
1609      del b
1610
1611    # Signatures must consist exclusively of `TensorSpec` objects.
1612    signature = [(2, 3), tensor_spec.TensorSpec([2, 3], dtypes.float32)]
1613    with self.assertRaisesRegexp(TypeError, 'Invalid input_signature.*'):
1614      def_function.function(foo, input_signature=signature)
1615
1616    # Signatures must be either lists or tuples on their outermost levels.
1617    signature = {'t1': tensor_spec.TensorSpec([], dtypes.float32)}
1618    with self.assertRaisesRegexp(TypeError, 'input_signature must be either a '
1619                                 'tuple or a list.*'):
1620      function.defun(foo, input_signature=signature)
1621
1622  @test_util.run_in_graph_and_eager_modes
1623  def testInputsIncompatibleWithSignatureRaisesError(self):
1624
1625    def foo(a):
1626      return a
1627
1628    signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)]
1629    defined = def_function.function(foo, input_signature=signature)
1630
1631    # Invalid shapes.
1632    with self.assertRaisesRegexp(ValueError, 'Python inputs incompatible.*'):
1633      defined(array_ops.ones([3]))
1634
1635    with self.assertRaisesRegexp(ValueError, 'Python inputs incompatible.*'):
1636      defined(array_ops.ones([2, 1]))
1637
1638    # Wrong number of arguments.
1639    with self.assertRaisesRegexp(TypeError, 'Received 2 argument\(s\)'):
1640      defined(array_ops.ones([2]), array_ops.ones([2]))
1641    with self.assertRaisesRegexp(ValueError,
1642                                 'Structure of Python function inputs.*'):
1643      defined()
1644
1645    with self.assertRaisesRegexp(ValueError,
1646                                 'inputs incompatible with input_signature'):
1647      defined.get_concrete_function(
1648          tensor_spec.TensorSpec(shape=(3,), dtype=dtypes.float32))
1649
1650  def testInputsIncompatibleWithNestedSignatureRaisesError(self):
1651
1652    def foo(a, b):
1653      return [a, b]
1654
1655    signature = [[tensor_spec.TensorSpec((1,), dtypes.float32)] * 2,
1656                 [tensor_spec.TensorSpec((1,), dtypes.float32)] * 2]
1657    defined = function.defun(foo, input_signature=signature)
1658    a = array_ops.ones([1])
1659
1660    with self.assertRaisesRegexp(ValueError,
1661                                 'Structure of Python function inputs.*'):
1662      defined([a, a, a], [a])
1663
1664    with self.assertRaisesRegexp(ValueError,
1665                                 'Structure of Python function inputs.*'):
1666      defined([a], [a, a, a])
1667    defined([a, a], [a, a])
1668
1669  def testUnderspecifiedInputSignature(self):
1670    @function.defun(input_signature=[
1671        tensor_spec.TensorSpec([], dtypes.float32),
1672    ])
1673    def foo(a, training=True):
1674      if training:
1675        return a
1676      else:
1677        return -1.0 * a
1678
1679    x = constant_op.constant(1.0)
1680    with self.assertRaisesRegexp(TypeError, 'only pass arguments'):
1681      foo(x, training=True)
1682
1683    with self.assertRaisesRegexp(TypeError, 'only pass arguments'):
1684      foo(x, training=False)
1685
1686    self.assertAllEqual(x.numpy(), foo(x).numpy())
1687
1688  def testInputSignatureWithPartialFunction(self):
1689    self.skipTest('b/124441704')
1690    def full_function(a, b, c=3.0):
1691      return a, b, c
1692
1693    partial = functools.partial(full_function, 1, c=4)
1694    a, b, c = partial(2.0)
1695    signature = [tensor_spec.TensorSpec([], dtypes.float32)]
1696    defined = function.defun(partial, input_signature=signature)
1697    x = constant_op.constant(2.0)
1698    func_a, func_b, func_c = defined(x)
1699    self.assertEqual(func_a.numpy(), a)
1700    self.assertEqual(func_b.numpy(), b)
1701    self.assertEqual(func_c.numpy(), c)
1702
1703  def testInputSignatureConversionWithDefaultArg(self):
1704
1705    def foo(a, training=True):
1706      if training:
1707        return a
1708      else:
1709        return -1.0 * a
1710
1711    signature = [
1712        tensor_spec.TensorSpec([], dtypes.float32),
1713        tensor_spec.TensorSpec([], dtypes.bool),
1714    ]
1715    defined = def_function.function(foo, input_signature=signature)
1716    a = constant_op.constant(1.0)
1717    self.assertAllEqual(a.numpy(), defined(a))
1718    self.assertAllEqual(a.numpy(), defined(a, training=True))
1719    self.assertAllEqual(-a.numpy(), defined(a, training=False))
1720
1721  def testInputSignatureWithKeywordPositionalArgs(self):
1722
1723    @function.defun(input_signature=[
1724        tensor_spec.TensorSpec([], dtypes.float32),
1725        tensor_spec.TensorSpec([], dtypes.int64)
1726    ])
1727    def foo(flt, integer):
1728      return flt, integer
1729
1730    flt = constant_op.constant(1.0)
1731    integer = constant_op.constant(2, dtypes.int64)
1732
1733    out1, out2 = foo(flt, integer)
1734    self.assertLen(total_function_cache(foo), 1)
1735    self.assertEqual(out1.numpy(), 1.0)
1736    self.assertEqual(out2.numpy(), 2)
1737
1738    out1, out2 = foo(flt=flt, integer=integer)
1739    self.assertLen(total_function_cache(foo), 1)
1740    self.assertEqual(out1.numpy(), 1.0)
1741    self.assertEqual(out2.numpy(), 2)
1742
1743    out1, out2 = foo(integer=integer, flt=flt)
1744    self.assertLen(total_function_cache(foo), 1)
1745    self.assertEqual(out1.numpy(), 1.0)
1746    self.assertEqual(out2.numpy(), 2)
1747
1748    out1, out2 = foo(flt, integer=integer)
1749    self.assertLen(total_function_cache(foo), 1)
1750    self.assertEqual(out1.numpy(), 1.0)
1751    self.assertEqual(out2.numpy(), 2)
1752
1753  def testInputSignatureWithKeywordArgsFails(self):
1754
1755    def foo(a, **kwargs):
1756      del a
1757      del kwargs
1758
1759    with self.assertRaisesRegexp(
1760        ValueError, 'Cannot define a TensorFlow function from a Python '
1761        'function with keyword arguments when input_signature.*'):
1762      function.defun(
1763          foo,
1764          input_signature=[
1765              tensor_spec.TensorSpec([], dtypes.float32),
1766              tensor_spec.TensorSpec([], dtypes.int64)
1767          ])
1768
1769  def testTensorKeywordArguments(self):
1770
1771    def foo(a, b):
1772      del a
1773      return b
1774
1775    defined = function.defun(foo)
1776    a = constant_op.constant(2.0)
1777    b = constant_op.constant([1.0, 2.0])
1778    one = defined(a, b)
1779    self.assertLen(total_function_cache(defined), 1)
1780
1781    two = defined(a=a, b=b)
1782    self.assertLen(total_function_cache(defined), 1)
1783
1784    three = defined(b=b, a=a)
1785    self.assertLen(total_function_cache(defined), 1)
1786
1787    four = defined(a, b=b)
1788    self.assertLen(total_function_cache(defined), 1)
1789
1790    # The next call corresponds to a new input signature, hence
1791    # we expect another function to be defined.
1792    five = defined(b, a)
1793    self.assertLen(total_function_cache(defined), 2)
1794
1795    six = defined(a=b, b=a)
1796    self.assertLen(total_function_cache(defined), 2)
1797
1798    seven = defined(b=a, a=b)
1799    self.assertLen(total_function_cache(defined), 2)
1800
1801    self.assertAllEqual(one, [1.0, 2.0])
1802    self.assertAllEqual(two, [1.0, 2.0])
1803    self.assertAllEqual(three, [1.0, 2.0])
1804    self.assertAllEqual(four, [1.0, 2.0])
1805    self.assertAllEqual(five, 2.0)
1806    self.assertAllEqual(six, 2.0)
1807    self.assertAllEqual(seven, 2.0)
1808
1809  def testDefuningInstanceMethod(self):
1810
1811    integer = constant_op.constant(2, dtypes.int64)
1812
1813    class Foo(object):
1814
1815      def one(self, tensor):
1816        return tensor
1817
1818      @def_function.function
1819      def two(self, tensor, other=integer):
1820        return self.one(tensor), other
1821
1822    foo = Foo()
1823    t = constant_op.constant(1.0)
1824    one, two = foo.two(t)
1825    self.assertEqual(one.numpy(), 1.0)
1826    self.assertEqual(two.numpy(), 2)
1827
1828  def testDefuningInstanceMethodWithDefaultArgument(self):
1829
1830    integer = constant_op.constant(2, dtypes.int64)
1831
1832    class Foo(object):
1833
1834      @def_function.function
1835      def func(self, other=integer):
1836        return other
1837
1838    foo = Foo()
1839    self.assertEqual(foo.func().numpy(), int(integer))
1840
1841  def testPythonCallWithSideEffects(self):
1842    state = []
1843
1844    @def_function.function
1845    def side_effecting_function():
1846      state.append(0)
1847
1848    side_effecting_function()
1849    self.assertAllEqual(state, [0])
1850
1851    # The second invocation should call the graph function, which shouldn't
1852    # trigger the list append.
1853    side_effecting_function()
1854    self.assertAllEqual(state, [0])
1855
1856    # Whereas calling the python function directly should create a side-effect.
1857    side_effecting_function.python_function()
1858    self.assertAllEqual(state, [0, 0])
1859
1860  def testFunctionWithNestedFunctionCallAndSideEffects(self):
1861    v1 = variables.Variable(1.0)
1862    v2 = variables.Variable(1.0)
1863
1864    @def_function.function
1865    def add_one(a):
1866      a.assign_add(1.0)
1867
1868    # Grappler will inline calls to `add_one` into the function body, we check
1869    # that all side-effects were executed.
1870    @def_function.function
1871    def side_effecting_function(a, b):
1872      add_one(a)
1873      add_one(b)
1874      return a + b
1875
1876    result = side_effecting_function(v1, v2)
1877    self.assertEqual(result.numpy(), 4.0)
1878
1879  def testFunctionWithExtraAttributes(self):
1880    @function.defun_with_attributes(attributes={'experimental_1': 'value1',
1881                                                'experimental_2': 2})
1882    def matmul(x, y):
1883      return math_ops.matmul(x, y)
1884
1885    def add(x, y):
1886      return math_ops.add(x, y)
1887    defun_add = function.defun_with_attributes(
1888        add, attributes={'experimental_3': True, 'experimental_4': 1.0})
1889
1890    with context.graph_mode(), self.cached_session():
1891      with ops.get_default_graph().as_default():
1892        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
1893        sq = matmul(t, t)
1894        double = defun_add(t, t)
1895        self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22])
1896        self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8])
1897
1898        graph = ops.get_default_graph()
1899        # pylint: disable=protected-access
1900        self.assertLen(graph._functions, 2)
1901        functions = list(graph._functions.values())
1902        self.assertRegexpMatches(
1903            functions[0].definition.signature.name, '.*matmul.*')
1904        attrs = functions[0].definition.attr
1905        self.assertLen(attrs, 2)
1906        self.assertEqual(attrs['experimental_1'].s, b'value1')
1907        self.assertEqual(attrs['experimental_2'].i, 2)
1908
1909        self.assertRegexpMatches(
1910            functions[1].definition.signature.name, '.*add.*')
1911        attrs = functions[1].definition.attr
1912        self.assertLen(attrs, 2)
1913        self.assertEqual(attrs['experimental_3'].b, True)
1914        self.assertEqual(attrs['experimental_4'].f, 1.0)
1915        # pylint: enable=protected-access
1916
1917  def testFunctionWithInvalidAttribute(self):
1918    @function.defun_with_attributes(attributes={'experimental_1': ['value1']})
1919    def add(x, y):
1920      return math_ops.add(x, y)
1921
1922    with self.assertRaisesRegexp(ValueError,
1923                                 '.*Unsupported attribute type.*'):
1924      with context.graph_mode(), self.cached_session():
1925        with ops.get_default_graph().as_default():
1926          t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
1927          add(t, t)
1928
1929  def testRegisterFunction(self):
1930
1931    @function.defun
1932    def add(x, y):
1933      return math_ops.add(x, y)
1934
1935    def matmul(x, y):
1936      return math_ops.matmul(x, y)
1937    defun_matmul = function.defun(matmul)
1938
1939    with context.graph_mode(), self.cached_session():
1940      with ops.get_default_graph().as_default():
1941        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
1942        function.register(defun_matmul, t, t)
1943        function.register(add, t, t)
1944
1945        graph = ops.get_default_graph()
1946        # pylint: disable=protected-access
1947        self.assertLen(graph._functions, 6)
1948        # two sets of functions, each of them are (inference, forward, backward)
1949        functions = list(graph._functions.values())
1950        captured_function_names = [
1951            f.definition.signature.name for f in functions
1952        ]
1953        expected_func_name_regex = [
1954            '.*inference.*matmul.*',
1955            '.*forward.*matmul.*',
1956            '.*inference.*backward.*matmul.*',
1957            '.*inference.*add.*',
1958            '.*forward.*add.*',
1959            '.*inference.*backward.*add.*',
1960        ]
1961        for i in range(len(functions)):
1962          self.assertRegexpMatches(captured_function_names[i],
1963                                   expected_func_name_regex[i])
1964
1965        # Check the forward and backward function has the correct attributes.
1966        self.assertEqual(
1967            functions[1].definition.attr['backward_function_name'].s,
1968            functions[2].name)
1969        self.assertEqual(
1970            functions[2].definition.attr['forward_function_name'].s,
1971            functions[1].name)
1972
1973        self.assertEqual(
1974            functions[4].definition.attr['backward_function_name'].s,
1975            functions[5].name)
1976        self.assertEqual(
1977            functions[5].definition.attr['forward_function_name'].s,
1978            functions[4].name)
1979
1980        sq = defun_matmul(t, t)
1981        double = add(t, t)
1982        self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22])
1983        self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8])
1984        # Make sure the pre registered function is used, and no other function
1985        # is added.
1986        self.assertLen(graph._functions, 6)
1987        functions = list(graph._functions.values())
1988        for i in range(len(functions)):
1989          self.assertEqual(captured_function_names[i],
1990                           functions[i].definition.signature.name)
1991
1992  @parameterized.named_parameters(
1993      dict(testcase_name='Defun',
1994           function_decorator=function.defun),
1995      dict(testcase_name='DefFunction',
1996           function_decorator=def_function.function))
1997  def testRegisterConcreteFunction(self, function_decorator):
1998    @function_decorator
1999    def py_add(x, y):
2000      return math_ops.add(x, y)
2001
2002    py_add(array_ops.ones([]), array_ops.ones([]))
2003    add = py_add.get_concrete_function(
2004        tensor_spec.TensorSpec(None, dtypes.float32),
2005        tensor_spec.TensorSpec(None, dtypes.float32))
2006
2007    @function_decorator
2008    def py_composite(x, y):
2009      return x, add(x, y)
2010
2011    py_composite(array_ops.ones([]), array_ops.ones([]))
2012    composite = py_composite.get_concrete_function(
2013        tensor_spec.TensorSpec(None, dtypes.float32),
2014        tensor_spec.TensorSpec(None, dtypes.float32))
2015
2016    with context.graph_mode(), self.cached_session():
2017      with ops.get_default_graph().as_default():
2018        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
2019        composite.add_to_graph(register_gradient_functions=True)
2020
2021        graph = ops.get_default_graph()
2022        # pylint: disable=protected-access
2023        self.assertLen(graph._functions, 6)
2024        # two sets of functions, each of them are (inference, forward, backward)
2025        functions = list(graph._functions.values())
2026        captured_function_names = [
2027            f.definition.signature.name for f in functions
2028        ]
2029        expected_func_name_regex = [
2030            '.*inference.*py_composite.*',
2031            '.*inference.*py_add.*',
2032            '.*forward.*py_composite.*',
2033            '.*forward.*py_add.*',
2034            '.*inference.*backward.*py_composite.*',
2035            '.*inference.*backward.*py_add.*',
2036        ]
2037        for expected, found in zip(
2038            expected_func_name_regex,
2039            captured_function_names):
2040          self.assertRegexpMatches(found, expected)
2041
2042        composite_t, composite_double = composite(t, t)
2043        double = add(t, t)
2044        self.assertAllEqual([[2, 4], [6, 8]], self.evaluate(double))
2045        self.assertAllEqual([[2, 4], [6, 8]], self.evaluate(composite_double))
2046        self.assertAllEqual([[1, 2], [3, 4]], self.evaluate(composite_t))
2047        # Make sure the pre registered function is used, and no other function
2048        # is added.
2049        self.assertLen(graph._functions, 6)
2050
2051  def testRegisterFunctionWithInputSignature(self):
2052    def matmul(x, y):
2053      return math_ops.matmul(x, y)
2054    defun_matmul = function.defun(
2055        matmul,
2056        input_signature=[
2057            tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32),
2058            tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32)
2059        ])
2060    with context.graph_mode(), self.cached_session():
2061      with ops.get_default_graph().as_default():
2062        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
2063        function.register(defun_matmul, t, t)
2064
2065        graph = ops.get_default_graph()
2066        # pylint: disable=protected-access
2067        self.assertLen(graph._functions, 3)
2068
2069        # Test register function with cache, note inputs are ignored.
2070        function.register(defun_matmul)
2071        graph = ops.get_default_graph()
2072        self.assertLen(graph._functions, 3)
2073
2074  def testRegisterFunctionWithCache(self):
2075    def matmul(x, y):
2076      return math_ops.matmul(x, y)
2077    defun_matmul = function.defun(matmul)
2078
2079    with context.graph_mode(), self.cached_session():
2080      with ops.get_default_graph().as_default():
2081        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
2082        t2 = constant_op.constant([[2.0, 3.0], [4.0, 5.0]])
2083        function.register(defun_matmul, t, t)
2084        function.register(defun_matmul, t2, t2)
2085
2086        graph = ops.get_default_graph()
2087        # Only one function is registered since the input param are in same type
2088        # pylint: disable=protected-access
2089        self.assertLen(graph._functions, 3)
2090
2091  def testCallingFunctionWithDifferentVariables(self):
2092
2093    @function.defun
2094    def foo(v):
2095      v.assign_add(1.0)
2096      return v.read_value()
2097
2098    v = resource_variable_ops.ResourceVariable(0.0)
2099    graph_function = foo.get_concrete_function(v)
2100    self.assertLen(graph_function.inputs, 1)
2101    self.assertEmpty(graph_function.captured_inputs)
2102
2103    self.assertEqual(float(graph_function(v)), 1.0)
2104    self.assertEqual(float(graph_function(v)), 2.0)
2105
2106    w = resource_variable_ops.ResourceVariable(0.0)
2107
2108    @function.defun
2109    def bar(v):
2110      del v
2111      return constant_op.constant(1.0)
2112
2113    graph_function = bar.get_concrete_function(v)
2114    self.assertEqual(float(graph_function(v)), 1.0)
2115    self.assertEqual(float(graph_function(w)), 1.0)
2116
2117  def testCallingFunctionWithNonTensorsFails(self):
2118
2119    @function.defun
2120    def foo(x):
2121      return x
2122
2123    graph_function = foo.get_concrete_function(constant_op.constant(1.0))
2124    with self.assertRaisesRegexp(
2125        ValueError, 'All inputs to `ConcreteFunction`s must be Tensors;.*'):
2126      graph_function('Not a Tensor.')
2127
2128  def testSwapImplementationWithGrapplerPlugin(self):
2129    # Set the min_graph_nodes to -1 since the graph in this test is too small,
2130    # and will be ignored by grappler if don't set this.
2131    rewrites = rewriter_config_pb2.RewriterConfig()
2132    rewrites.implementation_selector = rewriter_config_pb2.RewriterConfig.ON
2133    rewrites.min_graph_nodes = -1
2134    graph_options = config_pb2.GraphOptions(
2135        rewrite_options=rewrites, build_cost_model=1)
2136    config = config_pb2.ConfigProto(graph_options=graph_options)
2137
2138    with context.graph_mode(), self.cached_session(
2139        config=config, graph=ops.Graph(), use_gpu=True):
2140
2141      @function.defun_with_attributes(
2142          attributes={
2143              'api_implements': 'random_boost',
2144              'api_preferred_device': 'CPU'
2145          })
2146      def cpu_boost(x):
2147        return math_ops.add(x, 2.0)
2148
2149      @function.defun_with_attributes(
2150          attributes={
2151              'api_implements': 'random_boost',
2152              'api_preferred_device': 'GPU'
2153          })
2154      def gpu_boost(x):
2155        return math_ops.add(x, 4.0)
2156
2157      x = constant_op.constant(1.0)
2158
2159      function.register(cpu_boost, x)
2160      y = gpu_boost(x)
2161      y_value = self.evaluate(y)
2162
2163      if test.is_gpu_available():
2164        self.assertEqual(y_value, 5.0)
2165      else:
2166        # Grappler fallback to use the CPU impl even called with GPU function.
2167        self.assertEqual(y_value, 3.0)
2168
2169  def testDefunFunctionSeparateGraphs(self):
2170    with context.graph_mode():
2171
2172      @function.defun
2173      def add(x):
2174        return x + 5
2175
2176      @function.defun
2177      def maybe_add(x, should_add):
2178        if should_add:
2179          return add(x)
2180        else:
2181          return x
2182
2183      with ops.Graph().as_default():
2184        x = constant_op.constant(11)
2185        maybe_add(x, True)
2186        self.assertLen(total_function_cache(maybe_add), 1)
2187        self.assertLen(total_function_cache(add), 1)
2188
2189        maybe_add(x, False)
2190        self.assertLen(total_function_cache(maybe_add), 2)
2191        self.assertLen(total_function_cache(add), 1)
2192
2193      with ops.Graph().as_default():
2194        x = constant_op.constant(11)
2195        maybe_add(x, True)
2196        self.assertLen(total_function_cache(maybe_add), 3)
2197        self.assertLen(total_function_cache(add), 2)
2198
2199  def testCacheKeyOverlappingShapes(self):
2200    @function.defun
2201    def defined(t):
2202      return t
2203
2204    defined(array_ops.zeros([12, 1]))
2205    self.assertLen(total_function_cache(defined), 1)
2206
2207    defined(array_ops.zeros([1, 21]))
2208    self.assertLen(total_function_cache(defined), 2)
2209
2210  def testCacheKeyNestedLists(self):
2211    @function.defun
2212    def defined(l):
2213      return l
2214
2215    a = constant_op.constant(1.)
2216    b = constant_op.constant(2.)
2217    c = constant_op.constant(3.)
2218    defined([[a], b, c])
2219    self.assertLen(total_function_cache(defined), 1)
2220
2221    defined([[a, b], c])
2222    self.assertLen(total_function_cache(defined), 2)
2223
2224  def testDecoratedMethod(self):
2225    m = DefunnedMiniModel()
2226    instance_call_one = m.call(array_ops.ones([1, 2]), training=True)
2227    instance_call_two = m.call(
2228        inputs=array_ops.ones([1, 2]), training=True)
2229    class_call = DefunnedMiniModel.call(m, array_ops.ones([1, 2]),
2230                                        training=True)
2231    self.assertAllEqual(instance_call_one, instance_call_two)
2232    self.assertAllEqual(instance_call_one, class_call)
2233
2234  def testDecoratedMethodUniqueFunctionPerInstance(self):
2235    m = DefunnedMiniModel()
2236    n = DefunnedMiniModel()
2237
2238    class_method_one = DefunnedMiniModel.call
2239    class_method_two = DefunnedMiniModel.call
2240
2241    m_method_one = m.call
2242    m_method_two = m.call
2243
2244    n_method_one = n.call
2245    n_method_two = n.call
2246
2247    self.assertEqual(class_method_one, class_method_two)
2248    self.assertEqual(m_method_one, m_method_two)
2249    self.assertEqual(n_method_one, n_method_two)
2250    self.assertNotEqual(m.call, n.call)
2251
2252  def testDecoratedMethodInspect(self):
2253    m = DefunnedMiniModel()
2254    fullargspec = tf_inspect.getfullargspec(m.call)
2255    self.assertIn('training', fullargspec.args)
2256
2257  def testDecoratedMethodGetConcreteFunction(self):
2258    m = DefunnedMiniModel()
2259    instance_call_one = m.call.get_concrete_function(
2260        array_ops.ones([1, 2]), training=False)
2261    instance_call_two = m.call.get_concrete_function(
2262        inputs=array_ops.ones([1, 2]), training=False)
2263    self.assertAllEqual(instance_call_one(array_ops.ones([1, 2])),
2264                        instance_call_two(array_ops.ones([1, 2])))
2265
2266    # Also make sure get_concrete_function works on the class method
2267    DefunnedMiniModel.call.get_concrete_function(
2268        m, array_ops.ones([1, 2]), training=False)
2269    DefunnedMiniModel.call.get_concrete_function(
2270        m, inputs=array_ops.ones([1, 2]), training=True)
2271
2272  def testFunctionModifiesInputList(self):
2273    # Tests on `list` methods that do in place modification, except `list.sort`
2274    # since it cannot even be "defunned" in the first place
2275
2276    def get_list():
2277      return [constant_op.constant(0.), constant_op.constant(1.)]
2278
2279    expected_msg = (
2280        'Function to be traced should not modify structure of input '
2281        'arguments. Check if your function has list and dictionary '
2282        'operations that alter input arguments, '
2283        'such as `list.pop`, `list.append`')
2284
2285    with self.assertRaisesRegexp(ValueError, expected_msg):
2286
2287      @def_function.function
2288      def append(l):
2289        l.append(constant_op.constant(0.))
2290
2291      append(get_list())
2292
2293    with self.assertRaisesRegexp(ValueError, expected_msg):
2294
2295      @def_function.function
2296      def extend(l):
2297        l.extend([constant_op.constant(0.)])
2298
2299      extend(get_list())
2300
2301    with self.assertRaisesRegexp(ValueError, expected_msg):
2302
2303      @def_function.function
2304      def insert(l):
2305        l.insert(0, constant_op.constant(0.))
2306
2307      insert(get_list())
2308
2309    with self.assertRaisesRegexp(ValueError, expected_msg):
2310
2311      @def_function.function
2312      def pop(l):
2313        l.pop()
2314
2315      pop(get_list())
2316
2317    with self.assertRaisesRegexp(ValueError, expected_msg):
2318
2319      @def_function.function
2320      def reverse(l):
2321        l.reverse()
2322
2323      reverse(get_list())
2324
2325    with self.assertRaisesRegexp(ValueError, expected_msg):
2326
2327      @def_function.function
2328      def remove(l):
2329        l.remove(l[0])
2330
2331      remove(get_list())
2332
2333    # `list.clear` is a method that is in Py3 but not Py2
2334    if sys.version.startswith('3'):
2335
2336      with self.assertRaisesRegexp(ValueError, expected_msg):
2337
2338        @def_function.function
2339        def clear(l):
2340          l.clear()
2341
2342        clear(get_list())
2343
2344    # One last test for keyword arguments
2345    with self.assertRaisesRegexp(ValueError, expected_msg):
2346
2347      @def_function.function
2348      def kwdappend(**kwargs):
2349        l = kwargs['l']
2350        l.append(constant_op.constant(0.))
2351
2352      kwdappend(l=get_list())
2353
2354  def testFunctionModifiesInputDict(self):
2355
2356    def get_dict():
2357      return {'t1': constant_op.constant(0.), 't2': constant_op.constant(1.)}
2358
2359    expected_msg = (
2360        'Function to be traced should not modify structure of input '
2361        'arguments. Check if your function has list and dictionary '
2362        'operations that alter input arguments, '
2363        'such as `list.pop`, `list.append`')
2364
2365    with self.assertRaisesRegexp(ValueError, expected_msg):
2366
2367      @def_function.function
2368      def clear(m):
2369        m.clear()
2370
2371      clear(get_dict())
2372
2373    with self.assertRaisesRegexp(ValueError, expected_msg):
2374
2375      @def_function.function
2376      def pop(m):
2377        m.pop('t1')
2378
2379      pop(get_dict())
2380
2381    with self.assertRaisesRegexp(ValueError, expected_msg):
2382
2383      @def_function.function
2384      def popitem(m):
2385        m.popitem()
2386
2387      popitem(get_dict())
2388
2389    with self.assertRaisesRegexp(ValueError, expected_msg):
2390
2391      @def_function.function
2392      def update(m):
2393        m.update({'t1': constant_op.constant(3.)})
2394
2395      update(get_dict())
2396
2397    with self.assertRaisesRegexp(ValueError, expected_msg):
2398
2399      @def_function.function
2400      def setdefault(m):
2401        m.setdefault('t3', constant_op.constant(3.))
2402
2403      setdefault(get_dict())
2404
2405  def testFunctionModifiesInputNest(self):
2406    # Test on functions that modify structure of nested input arguments
2407    expected_msg = (
2408        'Function to be traced should not modify structure of input '
2409        'arguments. Check if your function has list and dictionary '
2410        'operations that alter input arguments, '
2411        'such as `list.pop`, `list.append`')
2412
2413    with self.assertRaisesRegexp(ValueError, expected_msg):
2414
2415      @def_function.function
2416      def modify(n):
2417        n[0]['t1'].append(constant_op.constant(1.))
2418
2419      nested_input = [{
2420          't1': [constant_op.constant(0.),
2421                 constant_op.constant(1.)],
2422      },
2423                      constant_op.constant(2.)]
2424
2425      modify(nested_input)
2426
2427    with self.assertRaisesRegexp(ValueError, expected_msg):
2428
2429      # The flat list doesn't change whereas the true structure changes
2430      @def_function.function
2431      def modify_same_flat(n):
2432        n[0].append(n[1].pop(0))
2433
2434      nested_input = [[constant_op.constant(0.)],
2435                      [constant_op.constant(1.),
2436                       constant_op.constant(2.)]]
2437
2438      modify_same_flat(nested_input)
2439
2440  def testDecoratedMethodVariableCleanup(self):
2441    m = DefunnedMiniModel()
2442    m(array_ops.ones([1, 2]))
2443    weak_variables = weakref.WeakSet(m.variables)
2444    self.assertLen(weak_variables, 2)
2445    del m
2446    self.assertEqual([], list(weak_variables))
2447
2448  def testExecutorType(self):
2449    @function.defun
2450    def add_five(x):
2451      return x + 5
2452
2453    self.assertEqual(
2454        5,
2455        add_five(constant_op.constant(0, dtype=dtypes.int32)).numpy())
2456
2457    with self.assertRaisesRegexp(errors.NotFoundError, 'NON_EXISTENT_EXECUTOR'):
2458      with context.function_executor_type('NON_EXISTENT_EXECUTOR'):
2459        add_five(constant_op.constant(0, dtype=dtypes.int32))
2460
2461    for executor_type in ('', 'DEFAULT', None):
2462      with context.function_executor_type(executor_type):
2463        self.assertAllEqual(
2464            5,
2465            add_five(constant_op.constant(0, dtype=dtypes.int32)).numpy())
2466
2467  @test_util.assert_no_garbage_created
2468  def testReferenceCycles(self):
2469
2470    fn = function.defun(lambda x: 2. * x)
2471
2472    fn(constant_op.constant(4.0))
2473    weak_fn = weakref.ref(fn)
2474    del fn
2475    # Tests that the weak reference we made to the function is now dead, which
2476    # means the object has been deleted. This should be true as long as the
2477    # function itself is not involved in a reference cycle.
2478    self.assertIs(None, weak_fn())
2479
2480  def testFunctionStackInErrorMessage(self):
2481    if context.executing_eagerly():
2482      # TODO(b/122736651): Remove this skipTest once fixed.
2483      self.skipTest('Error interpolation is not working when function is '
2484                    'invoked without PartitionedCallOp.')
2485
2486    @def_function.function()
2487    def fn3(x):
2488      return x + 2
2489
2490    @def_function.function()
2491    def fn2(x):
2492      check_ops.assert_equal(fn3(x), 3)
2493      return 2
2494
2495    @def_function.function()
2496    def fn(x):
2497      return fn2(x)
2498
2499    with self.assertRaises(errors.InvalidArgumentError) as cm:
2500      fn(2)
2501    e = cm.exception
2502    self.assertIn('fn -> fn2', e.message)
2503    self.assertIn('node assert_equal/Assert/Assert (defined at', e.message)
2504    self.assertNotIn('fn3', e.message)
2505
2506  def testFunctionIsNotPinned(self):
2507    """Tests that functions aren't pinned to the CPU by the eager runtime."""
2508    if not context.context().num_gpus():
2509      self.skipTest('No GPUs found.')
2510    seed1, seed2 = 79, 25
2511    shape = constant_op.constant([4, 7])
2512    dtype = dtypes.float32
2513
2514    @def_function.function
2515    def func():
2516      with ops.device('GPU:0'):
2517        return gen_random_ops.random_standard_normal(
2518            shape, dtype=dtype, seed=seed1, seed2=seed2)
2519
2520    with ops.device('GPU:0'):
2521      x = func()
2522      self.assertRegexpMatches(x.device, 'GPU')
2523
2524  @test_util.run_in_graph_and_eager_modes
2525  def testShapeCaching(self):
2526
2527    @function.defun
2528    def func(x):
2529      return array_ops.shape(x)
2530
2531    @function.defun(
2532        input_signature=[tensor_spec.TensorSpec([None, None], dtypes.float32)])
2533    def calls_func(x):
2534      return func(x)
2535
2536    self.assertAllEqual([1, 1], self.evaluate(func(array_ops.zeros([1, 1]))))
2537    self.assertAllEqual([2, 2], self.evaluate(func(array_ops.zeros([2, 2]))))
2538    self.assertAllEqual(
2539        [3, 3],
2540        self.evaluate(calls_func(array_ops.zeros([3, 3]))))
2541
2542  def testLimitedRetracing(self):
2543    trace_count = [0]
2544    @function.defun
2545    def func(x):
2546      trace_count[0] += 1
2547      return x
2548
2549    for _ in range(50):
2550      func(constant_op.constant(3.))
2551      func(constant_op.constant(4.))
2552      func(constant_op.constant([[1., 2.]]))
2553      func(constant_op.constant([[]]))
2554      func(constant_op.constant([[3., 4.], [5., 6.]]))
2555      func(constant_op.constant([[3., 4.], [5., 6.], [7., 8.]]))
2556    # Tracing more than twice per input doesn't make sense.
2557    self.assertLess(trace_count[0], 13)
2558
2559
2560class MultiDeviceTest(test.TestCase, parameterized.TestCase):
2561
2562  def testMultiDeviceOutput(self):
2563    """Tests that functions can produce outputs on multiple devices."""
2564    if not context.context().num_gpus():
2565      self.skipTest('No GPUs found.')
2566
2567    @function.defun
2568    def func(a, b, transpose_a):
2569      with ops.device('/device:CPU:0'):
2570        m1 = math_ops.matmul(a, b, transpose_a=transpose_a)
2571      with ops.device('/device:GPU:0'):
2572        m2 = math_ops.matmul(a, b, transpose_a=transpose_a)
2573      return m1, m2
2574
2575    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
2576    m1, m2 = func(t, t, transpose_a=True)
2577    self.assertAllEqual(m1.numpy(), [[10, 14], [14, 20]])
2578    self.assertRegexpMatches(m1.backing_device, 'CPU')
2579    self.assertAllEqual(m2.numpy(), [[10, 14], [14, 20]])
2580    self.assertRegexpMatches(m2.backing_device, 'GPU')
2581
2582  def testEmptyBody(self):
2583    if not context.context().num_gpus():
2584      self.skipTest('No GPUs found.')
2585
2586    @function.defun
2587    def func(a, b):
2588      return b, a
2589
2590    with ops.device('/device:CPU:0'):
2591      a = constant_op.constant(3.0)
2592    with ops.device('/device:GPU:0'):
2593      b = constant_op.constant(5.0)
2594
2595    m1, m2 = func(a, b)
2596    self.assertAllEqual(m1.numpy(), 5.0)
2597    self.assertRegexpMatches(m1.backing_device, 'GPU')
2598    self.assertAllEqual(m2.numpy(), 3.0)
2599    self.assertRegexpMatches(m2.backing_device, 'CPU')
2600
2601  def testMultiDeviceInt32(self):
2602    """Tests that multi-device functions can take and output INT32s.
2603
2604    When an INT32 device tensor is fed into a function, it is copied to CPU
2605    by the eager runtime. The function sees all INT32 inputs on CPU.
2606
2607    We set allocator attribute 'on_host' for INT32 outputs. They can be
2608    partitioned into the GPU component function, but will be allocated on
2609    CPU nevertheless.
2610
2611    There is experimental support for `ints_on_device` in
2612    FunctionLibraryRuntime now. We can try that.
2613
2614    """
2615    if not context.context().num_gpus():
2616      self.skipTest('No GPUs found.')
2617
2618    with ops.device('/device:CPU:0'):
2619      int_cpu = constant_op.constant(3, dtype=dtypes.int32)
2620      resource = resource_variable_ops.ResourceVariable(5, dtype=dtypes.int32)
2621    with ops.device('/device:GPU:0'):
2622      int_gpu = constant_op.constant(7, dtype=dtypes.int32)
2623
2624    @function.defun
2625    def func(int_cpu, resource, int_gpu):
2626      with ops.device('/device:CPU:0'):
2627        m1 = int_cpu * resource + int_gpu
2628      with ops.device('/device:GPU:0'):
2629        # This computation will happen on GPU but m2 will be copied to CPU.
2630        m2 = int_gpu * resource + int_cpu + 1
2631      return m1, m2
2632
2633    m1, m2 = func(int_cpu, resource, int_gpu)
2634    self.assertAllEqual(m1.numpy(), 22)
2635    self.assertRegexpMatches(m1.backing_device, 'CPU')
2636    self.assertAllEqual(m2.numpy(), 39)
2637    self.assertRegexpMatches(m2.backing_device, 'CPU')
2638
2639    # flip arguments
2640    m1, m2 = func(int_gpu, resource, int_cpu)
2641    self.assertAllEqual(m1.numpy(), 38)
2642    self.assertRegexpMatches(m1.backing_device, 'CPU')
2643    self.assertAllEqual(m2.numpy(), 23)
2644    self.assertRegexpMatches(m2.backing_device, 'CPU')
2645
2646  def testMultiDeviceColocateWith(self):
2647    """Tests that function's outputs respect colocation constraints."""
2648    if not context.context().num_gpus():
2649      self.skipTest('No GPUs found.')
2650
2651    @function.defun
2652    def func(a, b):
2653      with ops.colocate_with(a):
2654        ra = 2 * a
2655      with ops.colocate_with(b):
2656        rb = 3 * b
2657      return ra, rb
2658
2659    devices = ['/device:CPU:0', '/device:GPU:0']
2660    for dev1, dev2 in itertools.product(devices, devices):
2661      with ops.device(dev1):
2662        a = constant_op.constant(1.0)
2663      with ops.device(dev2):
2664        b = constant_op.constant(10.0)
2665
2666      ra, rb = func(a, b)
2667      self.assertEqual(ra.numpy(), 2.0)
2668      self.assertRegexpMatches(ra.backing_device, dev1)
2669      self.assertEqual(rb.numpy(), 30.0)
2670      self.assertRegexpMatches(rb.backing_device, dev2)
2671
2672  def testMultiDeviceResources(self):
2673    if not context.context().num_gpus():
2674      self.skipTest('No GPUs found.')
2675
2676    with ops.device('/device:CPU:0'):
2677      c1 = resource_variable_ops.ResourceVariable(2.0)
2678      c2 = resource_variable_ops.ResourceVariable(7.0)
2679    with ops.device('/device:GPU:0'):
2680      g1 = resource_variable_ops.ResourceVariable(3.0)
2681      g2 = resource_variable_ops.ResourceVariable(5.0)
2682
2683    @function.defun
2684    def func(resource1, resource2):
2685      with ops.device('/device:CPU:0'):
2686        result1 = resource1 * g2
2687      with ops.device('/device:GPU:0'):
2688        result2 = resource2 * c2
2689      return result1, result2
2690
2691    r1, r2 = func(c1, g1)
2692    self.assertEqual(r1.numpy(), 10.0)
2693    self.assertRegexpMatches(r1.backing_device, 'CPU')
2694    self.assertEqual(r2.numpy(), 21.0)
2695    self.assertRegexpMatches(r2.backing_device, 'GPU')
2696
2697    # Call with flipped inputs. Check that we look at resource's
2698    # device and reinstantiates the function when inputs' devices change.
2699    r1, r2 = func(g1, c1)
2700    self.assertEqual(r1.numpy(), 15.0)
2701    self.assertRegexpMatches(r1.backing_device, 'CPU')
2702    self.assertEqual(r2.numpy(), 14.0)
2703    self.assertRegexpMatches(r2.backing_device, 'GPU')
2704
2705  def testOutputResources(self):
2706    if not context.context().num_gpus():
2707      self.skipTest('No GPUs found.')
2708
2709    with ops.device('/device:CPU:0'):
2710      c1 = resource_variable_ops.ResourceVariable(2.0)
2711    with ops.device('/device:GPU:0'):
2712      g1 = resource_variable_ops.ResourceVariable(3.0)
2713
2714    @function.defun
2715    def func(resource1, resource2):
2716      with ops.device('/device:CPU:0'):
2717        result1 = resource1 * 5
2718      with ops.device('/device:GPU:0'):
2719        result2 = resource2 * 7
2720      return result1, resource1.handle, result2, resource2.handle
2721
2722    r1, res1, r2, res2 = func(c1, g1)
2723    self.assertEqual(r1.numpy(), 10.0)
2724    self.assertRegexpMatches(r1.backing_device, 'CPU')
2725    self.assertEqual(r2.numpy(), 21.0)
2726    self.assertRegexpMatches(r2.backing_device, 'GPU')
2727
2728    def check_handle(handle, expected_value):
2729      self.assertRegexpMatches(handle.backing_device, 'CPU')
2730      tensor = gen_resource_variable_ops.read_variable_op(
2731          handle, dtypes.float32)
2732      self.assertEqual(tensor.numpy(), expected_value)
2733
2734    # Check that handles returned from functions are on CPU and an op using
2735    # the resource handle is correctly placed on the device backing the
2736    # resource.
2737    check_handle(res1, 2.0)
2738    check_handle(res2, 3.0)
2739
2740    # Call with flipped inputs to make sure the same the function is
2741    # reinstantiated and eager runtime does not mess up the device assignment
2742    # for ops consuming handles returned from defuns.
2743    r1, res1, r2, res2 = func(g1, c1)
2744    self.assertEqual(r1.numpy(), 15.0)
2745    self.assertRegexpMatches(r1.backing_device, 'CPU')
2746    self.assertEqual(r2.numpy(), 14.0)
2747    self.assertRegexpMatches(r2.backing_device, 'GPU')
2748    check_handle(res1, 3.0)
2749    check_handle(res2, 2.0)
2750
2751  def testComplexInputOutputDevicePattern(self):
2752    """Tests input/output mapping logic in partitioning."""
2753    if not context.context().num_gpus():
2754      self.skipTest('No GPUs found.')
2755
2756    with ops.device('/device:CPU:0'):
2757      rc0 = resource_variable_ops.ResourceVariable(2.0)
2758      rc1 = resource_variable_ops.ResourceVariable(3.0)
2759      cc0 = constant_op.constant(5.0)
2760      cc1 = constant_op.constant(7.0)
2761    with ops.device('/device:GPU:0'):
2762      rg0 = resource_variable_ops.ResourceVariable(11.0)
2763      rg1 = resource_variable_ops.ResourceVariable(13.0)
2764      cg0 = constant_op.constant(17.0)
2765      cg1 = constant_op.constant(19.0)
2766
2767    # Make sure tensors are on expected devices.
2768    for tensor in [cc0, cc1]:
2769      self.assertRegexpMatches(tensor.backing_device, 'CPU:0')
2770    for tensor in [cg0, cg1]:
2771      self.assertRegexpMatches(tensor.backing_device, 'GPU:0')
2772
2773    @function.defun
2774    def func(rc0, cc0, cg0, rc1, cg1, rg0, rg1, cc1):
2775      with ops.device('/device:CPU:0'):
2776        m1 = rc0 * cg0
2777      with ops.device('/device:GPU:0'):
2778        m2 = rg0 * cc0
2779
2780      with ops.device('/device:CPU:0'):
2781        r1 = 1000.0 * m2 + rc1 * cg1
2782      with ops.device('/device:GPU:0'):
2783        r2 = 1000.0 * m1 + rg1 * cc1
2784
2785      return r1, r2, m2, m1
2786
2787    r1, r2, m2, m1 = func(rc0, cc0, cg0, rc1, cg1, rg0, rg1, cc1)
2788    self.assertRegexpMatches(m1.backing_device, 'CPU')
2789    self.assertRegexpMatches(r1.backing_device, 'CPU')
2790    self.assertRegexpMatches(m2.backing_device, 'GPU')
2791    self.assertRegexpMatches(r2.backing_device, 'GPU')
2792    self.assertEqual(m1.numpy(), 34.0)
2793    self.assertEqual(r1.numpy(), 55000.0 + 3.0 * 19.0)
2794    self.assertEqual(m2.numpy(), 55.0)
2795    self.assertEqual(r2.numpy(), 34000.0 + 13.0 * 7.0)
2796
2797  def testArgumentPrunning(self):
2798    """Tests functions taking unnecessary arguments."""
2799    if not context.context().num_gpus():
2800      self.skipTest('No GPUs found.')
2801
2802    with ops.device('/device:CPU:0'):
2803      c1 = constant_op.constant(5.0)
2804      c2 = constant_op.constant(7.0)
2805
2806    with ops.device('/device:GPU:0'):
2807      g1 = constant_op.constant(11.0)
2808      g2 = constant_op.constant(13.0)
2809      g3 = constant_op.constant(17.0)
2810
2811    @function.defun
2812    def func(g1, g2, c1, g3, c2):  # pylint: disable=unused-argument
2813      # arguments g1 and g2 are unused and can be pruned by grappler.
2814      return c1 * g3 * c2
2815
2816    result = func(g1, g2, c1, g3, c2)
2817    self.assertEqual(result.numpy(), 5.0 * 7.0 * 17.0)
2818
2819
2820if __name__ == '__main__':
2821  ops.enable_eager_execution(
2822      config=config_pb2.ConfigProto(device_count={'CPU': 4}))
2823  test.main()
2824