# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import functools
import itertools
import multiprocessing.pool
import os
import sys
import time
import weakref

from absl.testing import parameterized
import numpy

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import cancellation
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import function as tf_function
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.layers import convolutional
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gen_sendrecv_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.structured import structured_tensor
from tensorflow.python.platform import test
from tensorflow.python.saved_model.load import load
from tensorflow.python.saved_model.save import save
from tensorflow.python.training import training_ops
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect

try:
  import attr  # pylint:disable=g-import-not-at-top
except ImportError:
  attr = None


def total_function_cache(defined):
  # pylint: disable=protected-access
  return (set(defined._function_cache.primary)
          | set(defined._function_cache.arg_relaxed))
  # pylint: enable=protected-access
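# Roughly speaking, `_function_cache.primary` maps exact input signatures to
# concrete functions, while `arg_relaxed` holds traces whose input shapes were
# relaxed (e.g. to [None]) after repeated retraces with
# experimental_relax_shapes=True. `total_function_cache` counts both, so e.g.
# (see testInputShapeFunctionRelaxation below):
#
#   f = def_function.function(lambda x: x, experimental_relax_shapes=True)
#   f(constant_op.constant([1.]))          # 1 cache entry
#   f(constant_op.constant([1., 2.]))      # 2 cache entries
#   f(constant_op.constant([1., 2., 3.]))  # still 2; uses a relaxed trace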


def _example_indexed_slices_with_dense_shape():
  return indexed_slices.IndexedSlices(
      constant_op.constant([1, 2]), constant_op.constant([0, 1]),
      constant_op.constant([2]))


def _example_indexed_slices_without_dense_shape():
  return indexed_slices.IndexedSlices(
      constant_op.constant([1, 2]), constant_op.constant([0, 1]))


def _spec_for_value(value):
  """Returns the (nested) TypeSpec for a value."""
  if nest.is_sequence(value):
    return nest.map_structure(_spec_for_value, value)
  elif isinstance(value, (ops.Tensor, composite_tensor.CompositeTensor)):
    return type_spec.type_spec_from_value(value)
  else:
    return value
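# For example, `_spec_for_value(constant_op.constant([1, 2]))` returns
# `TensorSpec(shape=(2,), dtype=tf.int32)`, while plain Python values (e.g.
# ints or strings) are returned unchanged.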


# This dummy decorator imitates ordinary decorators utilizing tf_decorator.
def dummy_tf_decorator(method):

  def wrapper(*args, **kwargs):
    return method(*args, **kwargs)

  return tf_decorator.make_decorator(method, wrapper)
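# A minimal usage sketch: because the wrapper is registered through
# tf_decorator.make_decorator, utilities such as tf_inspect still report the
# original callable's signature:
#
#   @dummy_tf_decorator
#   def add(a, b):
#     return a + b
#
#   tf_inspect.getfullargspec(add).args  # ['a', 'b']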


# TODO(mdan): Organize these tests.
class FunctionTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(FunctionTest, self).setUp()
    cpus = config.list_physical_devices('CPU')
    # Set 4 virtual CPUs
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration()
    ])
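    # The four logical CPUs above let tests that pin ops to '/cpu:1' and
    # '/cpu:2' (e.g. testPackedVariable) run on a single physical CPU.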

  def testBasic(self):
    matmul = def_function.function(math_ops.matmul)
    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    sq = matmul(t, t, transpose_a=True)
    sq2 = matmul(sq, t, transpose_a=True)
    self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20])
    self.assertAllEqual(sq2.numpy().reshape(-1), [52, 76, 74, 108])

  def testPythonFunctionNotCallable(self):
    with self.assertRaisesRegex(TypeError, 'is not a callable object'):
      def_function.function(1)

  def testOnExitCallback(self):
    values = []
    def append_1():
      values.append(1)

    def append_2():
      values.append(2)

    def g(x):
      old_values = list(values)
      ops.add_exit_callback_to_default_func_graph(append_1)
      self.assertEqual(old_values, values)
      return x + 1

    tf_g = def_function.function(g)

    def f(x):
      old_values = list(values)
      ops.add_exit_callback_to_default_func_graph(append_2)
      self.assertEqual(old_values, values)
      return tf_g(x)

    tf_f = def_function.function(f)
    self.assertEmpty(values)
    tf_f(constant_op.constant(1.0))
    self.assertEqual(values, [1, 2])  # Once for g, once for f.
    tf_f(constant_op.constant([1.0]))  # force a retrace
    self.assertEqual(values, [1, 2, 1, 2])  # And again.

  def testCannotAddExitCallbackWhenNotInFunctionScope(self):
    with self.assertRaisesRegex(RuntimeError, 'when not building a function.'):
      ops.add_exit_callback_to_default_func_graph(lambda: None)

  def testVariable(self):
    v1 = variables.Variable(1.0)
    add = def_function.function(lambda x, v: x + v1 + v)
    v2 = variables.Variable(1.0)
    x = constant_op.constant(1.0)
    r = add(x, v2)
    self.assertEqual(3.0, self.evaluate(r))

  def testVariableOnly(self):
    v = variables.Variable(1.0)
    add = def_function.function(lambda x: x.assign_add(1.0))
    r1 = add(v)
    self.assertEqual(2.0, self.evaluate(r1))
    c = constant_op.constant(1.0)
    with self.assertRaisesRegex(AttributeError, 'no attribute'):
      add(c)

  @test_util.disable_tfrt('Packed tensor is not supported in tfrt yet.')
  def testPackedVariable(self):
    with ops.device('/cpu:0'):
      v0_0 = resource_variable_ops.ResourceVariable(1.0)
    with ops.device('/cpu:1'):
      v0_1 = resource_variable_ops.ResourceVariable(2.0)
      v1_0 = resource_variable_ops.ResourceVariable(3.0)
    with ops.device('/cpu:2'):
      v1_1 = resource_variable_ops.ResourceVariable(4.0)

    packed_var_0 = ops.pack_eager_tensors([v0_0.handle, v0_1.handle])
    packed_var_1 = ops.pack_eager_tensors([v1_0.handle, v1_1.handle])

    # TODO(b/145922293): use ResourceVariable.assign_add and
    # ResourceVariable.read_value directly once we support packing multiple
    # ResourceVariable into one ResourceVariable.
    @def_function.function
    def read_var():
      resource_variable_ops.assign_add_variable_op(
          packed_var_0, constant_op.constant(5.0))
      resource_variable_ops.assign_add_variable_op(
          packed_var_1, constant_op.constant(6.0))
      with ops.device('/cpu:0'):
        read0 = resource_variable_ops.read_variable_op(
            packed_var_0, dtype=dtypes.float32)
      with ops.device('/cpu:1'):
        read1 = resource_variable_ops.read_variable_op(
            packed_var_0, dtype=dtypes.float32)
        read2 = resource_variable_ops.read_variable_op(
            packed_var_1, dtype=dtypes.float32)
      with ops.device('/cpu:2'):
        read3 = resource_variable_ops.read_variable_op(
            packed_var_1, dtype=dtypes.float32)

      return read0, read1, read2, read3

    arg_attrs = read_var.get_concrete_function().function_def.arg_attr
    self.assertLen(arg_attrs, 2)
    self.assertEqual(arg_attrs[0].attr['_composite_device'].s,
                     compat.as_bytes(packed_var_0.device))
    self.assertEqual(arg_attrs[1].attr['_composite_device'].s,
                     compat.as_bytes(packed_var_1.device))

    self.assertAllEqual(read_var(), (1 + 5, 2 + 5, 3 + 6, 4 + 6))

  def testImplementsAttributeBasic(self):
    v = def_function.function(
        experimental_implements='func')(lambda x, y: x + y)
    with context.graph_mode(), self.cached_session():
      a = array_ops.placeholder(dtypes.float32, ())
      b = array_ops.placeholder(dtypes.float32, ())
      v(a, b)
      gradients_impl.gradients(v(a, b), [a, b])
      fdefs = ops.get_default_graph().as_graph_def().library.function
      self.assertLen(fdefs, 3)
      not_present = 0
      present = 0
      for f in fdefs:
        name = f.signature.name
        if 'forward' in name or 'backward' in name:
          not_present += 1
          self.assertNotIn(function.IMPLEMENTS_ATTRIBUTE_NAME, f.attr, f)
        else:
          present += 1
          self.assertEqual(f.attr[function.IMPLEMENTS_ATTRIBUTE_NAME].s,
                           'func'.encode('ascii'), f)
      self.assertEqual(not_present, 2, fdefs)
      self.assertEqual(present, 1, fdefs)

  def testImplementsAttributeAssertsOnSideInput(self):
    with context.graph_mode(), self.cached_session():
      z = array_ops.zeros(0)
      v = def_function.function(
          experimental_implements='func')(lambda x, y: x + y + z)
      a = array_ops.ones((1.0,))
      b = array_ops.ones((1.0,))
      with self.assertRaisesRegex(AssertionError,
                                  'variables are always captured'):
        v(a, b)
      functions = ops.get_default_graph().as_graph_def().library.function
      self.assertEmpty(functions)

  def testImplementsAttributeWorksWithGradientTape(self):
    add = lambda x, y: x + y ** 2
    add = def_function.function(experimental_implements='MyFunc')(add)
    x = variables.Variable(3.0)
    y = variables.Variable(2.0)

    with backprop.GradientTape() as tape:
      g = add(x, y)

    dg_dy, dg_dx = tape.gradient(g, [y, x])
    self.assertEqual(dg_dy.numpy(), 4.0)
    self.assertEqual(dg_dx.numpy(), 1.0)

  def testImplementsAttributeWorksOnVariables(self):
    with context.graph_mode(), self.cached_session():
      v = def_function.function(
          experimental_implements='func')(lambda x, y: x + y)
      a = variables.Variable((1.0,))
      b = variables.Variable((1.0,))
      r1 = v(a, b)
      _ = v(a, a)
      functions = ops.get_default_graph().as_graph_def().library.function
      # Verify that we created only one function
      self.assertLen(functions, 1)
      # Verify that eval() reads the current values.
      a.initializer.run()
      b.initializer.run()
      self.assertEqual(r1.eval(), 2)

      a.assign_add([1]).eval()
      self.assertEqual(r1.eval(), 3)

  def testImplementsAttributeWorksOnConstants(self):
    with context.graph_mode(), self.cached_session():
      v = def_function.function(
          experimental_implements='func')(lambda x, y: x + y)
      a = variables.Variable(1.0)
      r1 = v(a, 2.)
      r2 = v(2., a)
      functions = ops.get_default_graph().as_graph_def().library.function
      self.assertLen(functions, 1)
      self.assertLen(functions[0].signature.input_arg, 2)
      # Verify that eval() reads the current values.
      a.initializer.run()
      self.assertEqual(r1.eval(), 3)
      self.assertEqual(r2.eval(), 3)

  def testImplementsAttributeSpecializes(self):
    with context.graph_mode(), self.cached_session():
      v = def_function.function(
          experimental_implements='func')(lambda x, y: x + y)
      a = variables.Variable(1.0)
      r1 = v(a, [2.])
      r2 = v([2., 2], a)
      functions = ops.get_default_graph().as_graph_def().library.function
      self.assertLen(functions, 2)
      # Ensure that all parameters are still there and haven't been inlined!

      self.assertLen(functions[0].signature.input_arg, 2)
      self.assertLen(functions[1].signature.input_arg, 2)
      # Verify that eval() reads the current values.
      a.initializer.run()
      numpy.testing.assert_equal(r1.eval(), [3.])
      numpy.testing.assert_equal(r2.eval(), [3., 3.])

  def testImplementsWorksWithTensorSpec(self):
    v = def_function.function(
        experimental_implements='func')(lambda x, y: x + y)
    v = v.get_concrete_function(
        tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32),
        tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32))
    x = v(1., 2.)
    self.assertEqual(x.numpy(), 3.)

  def testImplementsAttributeAsNameAttrList(self):
    implements_attr = (
        'name: "embedding_matmul" attr {   key: "key1"   value {     i: 2   } '
        '} attr {   key: "key2"   value {     b: false   } }')
    v = def_function.function(
        experimental_implements=implements_attr)(lambda x, y: x + y)
    with context.graph_mode(), self.cached_session():
      a = array_ops.placeholder(dtypes.float32, ())
      b = array_ops.placeholder(dtypes.float32, ())
      v(a, b)
      gradients_impl.gradients(v(a, b), [a, b])
      fdefs = ops.get_default_graph().as_graph_def().library.function
      self.assertLen(fdefs, 3)
      not_present = 0
      present = 0
      for f in fdefs:
        name = f.signature.name
        if 'forward' in name or 'backward' in name:
          not_present += 1
          self.assertNotIn(function.IMPLEMENTS_ATTRIBUTE_NAME, f.attr, f)
        else:
          present += 1
          attr_value = f.attr[function.IMPLEMENTS_ATTRIBUTE_NAME]
          self.assertIsNotNone(attr_value.func, f)
          self.assertEqual(attr_value.func.name, 'embedding_matmul')
          name_attrs = attr_value.func.attr
          self.assertLen(name_attrs, 2)
      self.assertEqual(not_present, 2, fdefs)
      self.assertEqual(present, 1, fdefs)

  def testExternalControlDependency(self):
    with ops.Graph().as_default(), self.test_session():
      v = variables.Variable(1.0)
      v.initializer.run()

      op = v.assign_add(1.0)

      @function.defun
      def f():
        with ops.control_dependencies([op]):
          return 1.0

      self.evaluate(f())
      self.assertAllEqual(self.evaluate(v), 2.0)

  def testInputShapeFunctionRelaxation(self):
    unknown_dim = [False]

    @function.defun(experimental_relax_shapes=True)
    def func(a):
      if a._shape_tuple()[0] is None:
        unknown_dim[0] = True
      return a + 1

    func(constant_op.constant([]))
    self.assertFalse(unknown_dim[0])
    self.assertLen(total_function_cache(func), 1)

    func(constant_op.constant([1.0]))
    self.assertFalse(unknown_dim[0])
    self.assertLen(total_function_cache(func), 2)

    func(constant_op.constant([1.0, 2.0]))
    self.assertTrue(unknown_dim[0])
    self.assertLen(total_function_cache(func), 2)

  def testInputShapeRelaxationOnInstanceMethod(self):
    # Test that experimental_relax_shapes is passed during
    # instance method binding.
    unknown_dim = [False]

    class Foo(object):

      @def_function.function(experimental_relax_shapes=True)
      def func(self, a):
        if a._shape_tuple()[0] is None:
          unknown_dim[0] = True
        return a + 1

    foo = Foo()
    foo.func(constant_op.constant([]))
    self.assertFalse(unknown_dim[0])

    foo.func(constant_op.constant([1.0]))
    self.assertFalse(unknown_dim[0])

    foo.func(constant_op.constant([1.0, 2.0]))
    self.assertTrue(unknown_dim[0])

  def testInputShapeFunctionRelaxationWithRaggedTensors(self):
    traced_type_spec = [None]

    @def_function.function(experimental_relax_shapes=True)
    def func(x):
      traced_type_spec[0] = x._type_spec
      return x

    def check_trace(x, expected_trace):
      traced_type_spec[0] = None
      func(x)
      self.assertEqual(traced_type_spec[0], expected_trace)

    check_trace(  # Initial call gets traced.
        ragged_factory_ops.constant([[1], [2, 3, 4]]),
        ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32))
    check_trace(  # Input TypeSpec is the same -> no retrace.
        ragged_factory_ops.constant([[1, 2], [3, 4]]), None)
    check_trace(  # Even if component tensor shapes change -> no retrace.
        ragged_factory_ops.constant([[1, 2], [3, 4, 5, 6]]), None)
    check_trace(  # Different TypeSpec shape (nrows): retrace
        ragged_factory_ops.constant([[1], [2], [3]]),
        ragged_tensor.RaggedTensorSpec([3, None], dtypes.int32))
    check_trace(  # Different nrows again: relax & retrace
        ragged_factory_ops.constant([[1], [2], [3], [4]]),
        ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32))
    check_trace(  # Different nrows yet again: no retrace (relaxed graph fits)
        ragged_factory_ops.constant([[1]]), None)
    check_trace(  # Different ragged_rank: retrace
        ragged_factory_ops.constant([[[1]]]),
        ragged_tensor.RaggedTensorSpec([1, None, None], dtypes.int32))
    check_trace(  # Different ragged_rank again: retrace & relax
        ragged_factory_ops.constant([[[1]], [[2]]]),
        ragged_tensor.RaggedTensorSpec([None, None, None], dtypes.int32))

  def testInputShapeFunctionRelaxationWithStructuredTensors(self):
    traced_type_spec = [None]

    @def_function.function(experimental_relax_shapes=True)
    def func(x):
      traced_type_spec[0] = x._type_spec
      return x

    def check_trace(x, expected_trace):
      traced_type_spec[0] = None
      func(x)
      self.assertEqual(traced_type_spec[0], expected_trace)

    # If we have TypeSpecs that differ in ways other than just their shape,
    # then retrace each time.
    check_trace(
        structured_tensor.StructuredTensor.from_pyval({'a': [1]}),
        structured_tensor.StructuredTensorSpec(
            [], {'a': tensor_spec.TensorSpec((1,), dtypes.int32)}))
    check_trace(
        structured_tensor.StructuredTensor.from_pyval({'b': [1]}),
        structured_tensor.StructuredTensorSpec(
            [], {'b': tensor_spec.TensorSpec((1,), dtypes.int32)}))
    check_trace(
        structured_tensor.StructuredTensor.from_pyval({'c': [1]}),
        structured_tensor.StructuredTensorSpec(
            [], {'c': tensor_spec.TensorSpec((1,), dtypes.int32)}))

    # But if we call again with only shape different, then do relax:
    check_trace(  # retrace
        structured_tensor.StructuredTensor.from_pyval({'a': [1, 2]}),
        structured_tensor.StructuredTensorSpec(
            [], {'a': tensor_spec.TensorSpec((2,), dtypes.int32)}))
    check_trace(  # relax & retrace
        structured_tensor.StructuredTensor.from_pyval({'a': [1, 2, 3]}),
        structured_tensor.StructuredTensorSpec(
            [], {'a': tensor_spec.TensorSpec((None,), dtypes.int32)}))
    check_trace(  # use relaxed graph
        structured_tensor.StructuredTensor.from_pyval({'a': [1, 2, 3, 4]}),
        None)

  def testInputShapeFunctionRelaxationWithDatasetIterators(self):
    # For dataset iterators, the TypeSpec includes type information that's
    # not derivable from the component tensors.  Make sure that the TypeSpec
    # shapes get relaxed as appropriate.

    traced_type_spec = [None]

    @def_function.function(experimental_relax_shapes=True)
    def func(x):
      traced_type_spec[0] = x._type_spec
      return x

    def check_trace(x, expected_trace):
      traced_type_spec[0] = None
      func(x)
      self.assertEqual(traced_type_spec[0], expected_trace)

    ds_1_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([1, 2]))
    ds_2_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([2, 2]))
    ds_3_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([3, 2]))
    ds_4_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([4, 2]))
    ds_2_1 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([2, 1]))
    check_trace(  # shape=[1, 2]: retrace
        dataset_ops.make_one_shot_iterator(ds_1_2),
        iterator_ops.IteratorSpec(
            tensor_spec.TensorSpec([1, 2], dtypes.float32)))
    check_trace(  # shape=[1, 2]: no retrace (use the [1, 2] graph)
        dataset_ops.make_one_shot_iterator(ds_1_2), None)
    check_trace(  # shape=[2, 2]: retrace
        dataset_ops.make_one_shot_iterator(ds_2_2),
        iterator_ops.IteratorSpec(
            tensor_spec.TensorSpec([2, 2], dtypes.float32)))
    check_trace(  # shape=[3, 2]: relax to [None, 2] and retrace
        dataset_ops.make_one_shot_iterator(ds_3_2),
        iterator_ops.IteratorSpec(
            tensor_spec.TensorSpec([None, 2], dtypes.float32)))
    check_trace(  # shape=[4, 2]: no retrace (use the [None, 2] graph)
        dataset_ops.make_one_shot_iterator(ds_4_2), None)
    check_trace(  # shape=[2, 1]: relax to [None, None] and retrace
        dataset_ops.make_one_shot_iterator(ds_2_1),
        iterator_ops.IteratorSpec(
            tensor_spec.TensorSpec([None, None], dtypes.float32)))

  def testCapturesVariables(self):
    a = variables.Variable(1.0, trainable=False)
    b = variables.Variable(1.0)
    cc = [None]

    @def_function.function
    def f():
      c = cc[0]
      if c is None:
        c = cc[0] = variables.Variable(1.)
      return a + b + c + 1

    cf = f.get_concrete_function()
    c = cc[0]

    captured_variables = {v.ref() for v in (a, b, c)}
    trainable_variables = {v.ref() for v in (b, c)}
    self.assertEqual({v.ref() for v in cf.variables}, captured_variables)
    self.assertEqual({v.ref() for v in cf.trainable_variables},
                     trainable_variables)
    self.assertEqual(cf.variables, cf.graph.variables)
    self.assertEqual(cf.trainable_variables, cf.graph.trainable_variables)
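    # Note: Variable objects are not hashable, so the set comparisons above
    # use `Variable.ref()`, which returns a hashable reference.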

  def testNestedInputShapeFunctionRelaxation(self):
    unknown_dim = [False]

    @function.defun(experimental_relax_shapes=True)
    def func(a_, b_=None):
      del a_  # Only used to check which cache is used.
      self.assertEqual(b_[0]._shape_tuple(), ())
      if b_[1]._shape_tuple()[0] is None:
        unknown_dim[0] = True
      return b_[0] + 1

    a = 'hi'
    b0 = constant_op.constant(1.0)
    func(a, b_=[b0, constant_op.constant([])])
    self.assertFalse(unknown_dim[0])
    self.assertLen(total_function_cache(func), 1)

    func(a, b_=[b0, constant_op.constant([1.0])])
    self.assertFalse(unknown_dim[0])
    self.assertLen(total_function_cache(func), 2)

    func(a, b_=[b0, constant_op.constant([1.0, 1.0])])
    self.assertTrue(unknown_dim[0])
    self.assertLen(total_function_cache(func), 2)

    unknown_dim[0] = False

    # Now do the same, except with a new `a`, which is not a tensor; this
    # should change the cache key.
    a = 'bye'
    func(a, b_=[b0, constant_op.constant([])])
    self.assertFalse(unknown_dim[0])
    self.assertLen(total_function_cache(func), 3)

    # Since we already marked a cache miss for a function with the same
    # non-input signatures, here we will immediately start relaxing shapes.
    func(a, b_=[b0, constant_op.constant([1.0])])
    self.assertTrue(unknown_dim[0])
    self.assertLen(total_function_cache(func), 3)

  def testNestedShapeFunctionRelaxation(self):

    got_shape = [None]

    # The inner function will go through shape relaxation because the shapes it
    # receives will be [1], [2], [3], ...
    @def_function.function(experimental_relax_shapes=True)
    def bar(x_shape):
      got_shape[0] = x_shape._shape_tuple()
      return x_shape

    # The outer function will not go through shape relaxation because the
    # shapes it receives will be [1], [1, 1], [1, 1, 1], ... (a different
    # rank each time).
    @def_function.function(experimental_relax_shapes=True)
    def foo(ones):
      return bar(array_ops.shape(ones))

    for rank in range(1, 6):
      x_shape = self.evaluate(foo(array_ops.ones([1] * rank)))
      self.assertAllEqual(x_shape, [1] * rank)
      if rank < 3:
        self.assertEqual(got_shape[0], (rank,))
      else:
        self.assertEqual(got_shape[0], (None,))

  def testNoHash(self):

    @def_function.function()
    def f(_):
      return 1.0

    with self.assertRaisesRegex(ValueError, r'got.*set'):
      f(set([]))

  def testFuncName(self):

    @function.defun_with_attributes(attributes={'func_name': 'multiply'})
    def add(x, y):
      _ = x * y
      return x + y

    @function.defun
    def add_2(x, y):
      _ = x * y
      return x + y

    self.assertEqual(add._name, 'multiply')
    self.assertEqual(add_2._name, 'add_2')

  def testBasicGraphMode(self):
    matmul = def_function.function(math_ops.matmul)

    @def_function.function
    def sq(a):
      return matmul(a, a)

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    out = sq(t)
    self.assertAllEqual(out, math_ops.matmul(t, t).numpy())

  def testNestedInputsGraphMode(self):
    matmul = def_function.function(math_ops.matmul)

    pair = collections.namedtuple('pair', ['a', 'b'])

    @def_function.function
    def a_times_b(inputs):
      return matmul(inputs.a['a'], inputs.b['b'])

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])

    out = a_times_b(pair({'a': t}, {'b': t}))
    self.assertAllEqual(out, math_ops.matmul(t, t).numpy())

  def testNestedOutputsGraphMode(self):
    matmul = def_function.function(math_ops.matmul)

    pair = collections.namedtuple('pair', ['a', 'b'])

    @def_function.function()
    def pairs_mul(pair_a, pair_b):
      return pair(matmul(pair_a.a, pair_b.a), matmul(pair_a.b, pair_b.b))

    a = constant_op.constant([[1.0, 2.0], [1.0, 2.0]])
    b = constant_op.constant([[3.0, 4.0], [3.0, 4.0]])

    out = pairs_mul(pair(a, b), pair(b, a))
    expected = pair(math_ops.matmul(a, b).numpy(),
                    math_ops.matmul(b, a).numpy())
    self.assertAllClose(out, expected)

  @parameterized.named_parameters(
      dict(testcase_name='Defun',
           function_decorator=function.defun),
      dict(testcase_name='DefFunction',
           function_decorator=def_function.function))
  def testNestedFunctionGraphNotOutOfDate(self, function_decorator):
    @function_decorator
    def f():
      return constant_op.constant(1.)

    class _Model(object):

      @function_decorator
      def g(self):
        self.f = f.get_concrete_function()

    model = _Model()
    model.g()
    concrete = model.f
    weak_g_graph = weakref.ref(model.g.get_concrete_function().graph)
    self.assertIs(weak_g_graph(), concrete.graph.outer_graph)
    weak_g = weakref.ref(model.g)
    del model
    self.assertIsNone(weak_g())
    self.assertIsNone(weak_g_graph())
    self.assertIsNotNone(concrete.graph.outer_graph)
    self.assertIs(ops.get_default_graph(), concrete.graph.outer_graph)

  def testGraphEagerIsolation(self):

    @function.defun
    def f():
      self.v = variables.Variable(1.0)
      return self.v.read_value()

    self.assertAllEqual(f(), 1.0)

    with ops.Graph().as_default():
      self.assertEqual(f().shape, ())

  def testBasicGraphFunction(self):
    matmul = def_function.function(math_ops.matmul)

    @def_function.function
    def sq(a):
      return matmul(a, a)

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])

    sq_op = sq.get_concrete_function(t)
    self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2]))
    out = sq_op(t)
    self.assertAllEqual(out, math_ops.matmul(t, t).numpy())

  def testGetConcreteFunctionThreadSafety(self):

    @def_function.function
    def sq():
      t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
      return math_ops.matmul(t, t)

    concrete_functions = []

    def thread_func(_):
      cf = sq.get_concrete_function()
      concrete_functions.append(cf)

    num_threads = 100
    pool = multiprocessing.pool.ThreadPool(num_threads)
    _ = pool.map(thread_func, list(range(num_threads)))

    self.assertLen(set(concrete_functions), 1)
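    # Note: all 100 threads should end up with the same concrete function;
    # concurrent get_concrete_function() calls must not race to produce
    # duplicate traces.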

  def testGetConcreteFunctionThreadSafetyWithArgs(self):
    @def_function.function
    def add_100(*args):
      return math_ops.add_n(args)

    p = multiprocessing.pool.ThreadPool(2)
    args = (constant_op.constant(1.),) * 100
    f1, f2 = p.map(add_100.get_concrete_function, [args] * 2)
    # I see about len(args) + max(0, len(args) - 3) arguments expected.
    f1(*args)
    del f2

  def testInputSpecGraphFunction(self):
    matmul = def_function.function(math_ops.matmul)

    @def_function.function
    def sq(a):
      return matmul(a, a)

    sq_op = sq.get_concrete_function(
        tensor_spec.TensorSpec((None, None), dtypes.float32))
    self.assertEqual([None, None], sq_op.output_shapes.as_list())

    t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    out1 = sq_op(t1)
    self.assertAllEqual(out1, math_ops.matmul(t1, t1).numpy())

    t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    out2 = sq_op(t2)
    self.assertAllEqual(out2, math_ops.matmul(t2, t2).numpy())

  def testNestedInputSpecGraphFunction(self):
    matmul = def_function.function(math_ops.matmul)

    @def_function.function
    def sq(mats):
      ((a, b),) = mats
      return matmul(a, b)

    sq_op_autonamed = sq.get_concrete_function(
        [(tensor_spec.TensorSpec((None, None), dtypes.float32),
          tensor_spec.TensorSpec((None, None), dtypes.float32))])
    self.assertEqual([None, None], sq_op_autonamed.output_shapes.as_list())

    sq_op = sq.get_concrete_function(
        [(tensor_spec.TensorSpec((None, None), dtypes.float32,
                                 name='first_mat'),
          tensor_spec.TensorSpec((None, None), dtypes.float32,
                                 name='second_mat'))])
    self.assertEqual([None, None], sq_op.output_shapes.as_list())

    t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    t2 = constant_op.constant([[1.4, 2.4], [3.4, 4.4]])
    out = sq_op(first_mat=t1, second_mat=t2)
    self.assertAllEqual(out, math_ops.matmul(t1, t2).numpy())
    self.assertAllEqual(sq_op_autonamed(t1, t2),
                        math_ops.matmul(t1, t2).numpy())

  def testExecutingStatelessDefunConcurrently(self):

    @def_function.function
    def stateless(x):
      return math_ops.multiply(2.0, x)

    pool = multiprocessing.pool.ThreadPool()
    inputs = [constant_op.constant(1.0 * x) for x in range(100)]
    outputs = [float(out) for out in pool.map(stateless, inputs)]
    expected = [float(2.0 * x) for x in inputs]
    self.assertSequenceEqual(outputs, expected)

  def testExecutingManyStatelessDefunsConcurrently(self):

    @def_function.function
    def stateless(x):
      del x
      return math_ops.multiply(2.0, 2.0)

    pool = multiprocessing.pool.ThreadPool()
    # `pool.map` below instantiates 100 functions, one for each object.
    objects = [object() for _ in range(100)]
    outputs = [float(out) for out in pool.map(stateless, objects)]
    expected = [4.0] * 100
    self.assertSequenceEqual(outputs, expected)

  @test_util.disable_tfrt('b/169431085: This test is flaky on tfrt')
  def testExecutingStatefulDefunConcurrently(self):

    v = resource_variable_ops.ResourceVariable(1.0)

    @def_function.function
    def stateful(x):
      v.assign(x)

    pool = multiprocessing.pool.ThreadPool()
    inputs = [constant_op.constant(0.0)] * 100
    pool.map(stateful, inputs)
    self.assertEqual(float(v.read_value()), 0.0)

  def testExecutingManyStatefulDefunsConcurrently(self):

    v = resource_variable_ops.ResourceVariable(1.0)

    @def_function.function
    def stateful(x):
      del x
      return v.assign(0.0)

    pool = multiprocessing.pool.ThreadPool()
    # `pool.map` below instantiates 100 functions, one for each object.
    pool.map(stateful, [object() for _ in range(100)])
    self.assertEqual(float(v.read_value()), 0.0)

  def testShareRendezvous(self):

    # Prevent grappler from inlining the functions. Note that we run the send
    # & recv in graph mode, since in eager mode the functions would
    # automatically be inlined.
    context.context().set_optimizer_experimental_options(
        {'disable_meta_optimizer': True})

    cpu = '/device:CPU:0'

    signature = [tensor_spec.TensorSpec([], dtypes.int32)]

    @def_function.function
    def send():
      x = constant_op.constant(1)
      gen_sendrecv_ops.send(x, 'x', cpu, 0, cpu)
      return x

    send._shared_rendezvous = True  # pylint: disable=protected-access

    @def_function.function(input_signature=signature)
    def send_body(n):
      send()
      return n - 1

    @def_function.function
    def recv():
      return gen_sendrecv_ops.recv(dtypes.int32, 'x', cpu, 0, cpu)

    recv._shared_rendezvous = True  # pylint: disable=protected-access

    @def_function.function(input_signature=signature)
    def recv_body(n):
      recv()
      return n - 1

    @def_function.function(input_signature=signature)
    def cond(n):
      return n > 0

    # Instead of calling the send & recv functions directly we want to call
    # them through a functional while to ensure the rendezvous is shared
    # across the while boundary.
    @def_function.function
    def fn(n):
      functional_ops.While([n], cond.get_concrete_function(),
                           send_body.get_concrete_function())
      return functional_ops.While([n], cond.get_concrete_function(),
                                  recv_body.get_concrete_function())

    # Use a graph context since functions will not be automatically inlined
    with context.graph_mode(), self.cached_session():
      self.evaluate(fn(2))

  def disabled_testRandomSeed(self):

    @def_function.function
    def f():
      return random_ops.random_normal(())

    random_seed.set_random_seed(1)
    x = f()
    self.assertNotEqual(x, f())
    random_seed.set_random_seed(1)
    self.assertAllEqual(f(), x)

  def testNestedInputsGraphFunction(self):
    matmul = def_function.function(math_ops.matmul)

    pair = collections.namedtuple('pair', ['a', 'b'])

    @def_function.function
    def a_times_b(inputs):
      return matmul(inputs.a['a'], inputs.b['b'])

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    sq_op = a_times_b.get_concrete_function(
        pair(dict(a=tensor_spec.TensorSpec([2, 2], dtypes.float32, 'a')),
             dict(b=tensor_spec.TensorSpec([2, 2], dtypes.float32, 'b'))))
    self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2]))
    out = sq_op(a=t, b=t)
    self.assertAllEqual(out, math_ops.matmul(t, t).numpy())

  def testNestedOutputGraphFunction(self):
    matmul = def_function.function(math_ops.matmul)

    @def_function.function
    def sq(a):
      return (matmul(a, a), {'b': constant_op.constant(1.0)})

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])

    sq_op = sq.get_concrete_function(t)
    self.assertEqual(sq_op.output_shapes,
                     (tensor_shape.TensorShape([2, 2]),
                      {'b': tensor_shape.TensorShape([])}))
    self.assertEqual(sq_op.output_dtypes,
                     (dtypes.float32, {'b': dtypes.float32}))
    (a, b) = sq_op(t)
    self.assertAllEqual(a, math_ops.matmul(t, t).numpy())
    self.assertAllEqual(b['b'].numpy(), 1.0)

  def testGraphFunctionNoneOutput(self):
    @def_function.function
    def fn(unused_a, unused_b):
      return None

    x = constant_op.constant(1)
    fn_op = fn.get_concrete_function(x, x)
    self.assertEqual(fn_op.output_dtypes, None)
    self.assertEqual(fn_op.output_shapes, None)
    self.assertAllEqual(fn_op(x, x), None)

  def testDefunNumpyArraysConvertedToTensors(self):

    def f(x):
      self.assertIsInstance(x, ops.Tensor)
      return x

    x = random_ops.random_uniform([2, 2]).numpy()
    defined = function.defun(f)
    defined(x)
    self.assertLen(total_function_cache(defined), 1)

    x = random_ops.random_uniform([2, 2]).numpy()
    defined(x)
    # A NumPy array with different values but the same shape and dtype
    # shouldn't trigger another function definition.
    self.assertLen(total_function_cache(defined), 1)

    np_ones = numpy.ones([], numpy.float32)
    np_zeros = numpy.zeros([], numpy.float32)
    tf_ones = array_ops.ones([])
    tf_zeros = array_ops.zeros([])

    # Test that the numpy array is properly passed as an argument to the
    # graph function.
    self.assertEqual(1., defined(np_ones).numpy())
    self.assertLen(total_function_cache(defined), 2)
    self.assertEqual(0., defined(np_zeros).numpy())
    self.assertEqual(1., defined(tf_ones).numpy())
    self.assertEqual(0., defined(tf_zeros).numpy())
    self.assertLen(total_function_cache(defined), 2)

    # Test that mutable inputs are supported.
    mutable = numpy.ones([], numpy.float32)
    self.assertEqual(1., defined(mutable).numpy())
    mutable.fill(0)
    self.assertEqual(0., defined(mutable).numpy())

    class MyNdarray(numpy.ndarray):
      pass

    # Test that the subclasses of ndarray are converted too.
    self.assertEqual(1., defined(np_ones.view(MyNdarray)).numpy())
    self.assertEqual(0., defined(np_zeros.view(MyNdarray)).numpy())

    # We should not have triggered any re-tracing of the python function.
    self.assertLen(total_function_cache(defined), 2)

  def testNumpyDtypeInputSupported(self):
    @function.defun
    def f(x, dtype):
      return constant_op.constant(dtype(x))

    self.assertEqual(f(1, numpy.float32).numpy(), numpy.float32(1))
    self.assertEqual(f(2, numpy.float32).numpy(), numpy.float32(2))
    self.assertEqual(f(1, numpy.int32).numpy(), numpy.int32(1))
    self.assertEqual(f(2, numpy.int32).numpy(), numpy.int32(2))

  def testDefunNumpyArraysConvertedToTensorsInKwargs(self):

    def f(**kwargs):
      x = kwargs.pop('x')
      self.assertIsInstance(x, ops.Tensor)
      return x

    x = random_ops.random_uniform([2, 2]).numpy()
    defined = function.defun(f)
    defined(x=x)
    self.assertLen(total_function_cache(defined), 1)

    x = random_ops.random_uniform([2, 2]).numpy()
    defined(x=x)
    # A NumPy array with different values but the same shape and dtype
    # shouldn't trigger another function definition.
    self.assertLen(total_function_cache(defined), 1)

    # Test that the numpy array is properly passed as an argument to the
    # graph function.
    self.assertEqual(1., defined(x=numpy.ones([])).numpy())
    self.assertEqual(0., defined(x=numpy.zeros([])).numpy())
    self.assertEqual(1., defined(x=array_ops.ones([])).numpy())
    self.assertEqual(0., defined(x=array_ops.zeros([])).numpy())

  def testDefunCapturedInt32(self):
    x = constant_op.constant(1, dtype=dtypes.int32)

    @def_function.function
    def add_int32s():
      return x + x

    self.assertEqual(2, int(add_int32s()))

  def testDefunReadVariable(self):
    v = resource_variable_ops.ResourceVariable(1.0)

    @def_function.function
    def f():
      return v.read_value()

    self.assertEqual(1.0, float(f()))

  def testDefunAssignAddVariable(self):
    v = resource_variable_ops.ResourceVariable(1.0)
    x = constant_op.constant(2.0)

    @def_function.function
    def test_assign_add():
      v.assign_add(x)
      return v.read_value()

    self.assertEqual(3.0, float(test_assign_add()))

  @test_util.run_in_graph_and_eager_modes
  def testTensorInitializationInFunctionRaisesError(self):

    @def_function.function
    def tensor_init():
      with self.assertRaisesRegex(ValueError, 'could not be lifted out'):
        resource_variable_ops.ResourceVariable(constant_op.constant(2.0))

    tensor_init()

  @test_util.run_in_graph_and_eager_modes
  def testCallableTensorInitializationInFunction(self):

    @def_function.function
    def tensor_init():
      self.v = resource_variable_ops.ResourceVariable(
          lambda: constant_op.constant(2.0))
      return self.v.read_value()

    value = tensor_init()
    if not context.executing_eagerly():
      self.evaluate(variables.global_variables_initializer())
    self.assertEqual(self.evaluate(value), 2.0)

  @test_util.also_run_as_tf_function
  def testInitScopeTensorInitializationInFunction(self):

    @def_function.function
    def tensor_init():
      with ops.init_scope():
        const = constant_op.constant(2.0)
      # Note: this variable bypasses tf.function's variable-creation checks by
      # sidestepping variable_creator_scope (it uses ResourceVariable directly
      # instead of Variable).
      self.v = resource_variable_ops.ResourceVariable(const)
      return self.v.read_value()

    value = tensor_init()
    self.assertAllEqual(value, 2.0)

  @test_util.run_in_graph_and_eager_modes
  def testGetConcreteFunctionCreatesVariables(self):

    v_holder = []

    @def_function.function
    def tensor_init():
      if not v_holder:
        v_holder.append(variables.Variable(5.))
      return v_holder[0].read_value()

    concrete = tensor_init.get_concrete_function()
    self.evaluate(variables.global_variables_initializer())
    self.assertAllEqual(5., self.evaluate(concrete()))
    self.assertAllEqual(5., self.evaluate(tensor_init()))

  def testFuncGraphCaptureByValue(self):
    v = variables.Variable(1.0)

    def trivial_function():
      return v.read_value()

    graph_function = function.Function(
        trivial_function, 'test', capture_by_value=True)

    self.assertAllEqual(graph_function(), 1.0)
    v.assign(2.0)
    self.assertAllEqual(graph_function(), 1.0)
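    # Note: with capture_by_value=True the variable's value is baked into the
    # graph at trace time, so the assign(2.0) above does not change the
    # function's output.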

  def testFuncGraphCaptureByValueNested(self):
    v = variables.Variable(1.0)

    def trivial_function():
      return control_flow_ops.cond(
          array_ops.placeholder_with_default(True, ()),
          v.read_value, v.read_value)

    graph_function = function.Function(
        trivial_function, 'test', capture_by_value=True)

    self.assertAllEqual(graph_function(), 1.0)
    v.assign(2.0)
    self.assertAllEqual(graph_function(), 1.0)

  def testDefunShapeInferenceWithCapturedResourceVariable(self):
    v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])

    def f():
      x = constant_op.constant([[1, 2], [3, 4]])
      out = math_ops.matmul(v, x)
      self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
      # We do not return v directly since the tensor conversion function of
      # ResourceVariable returns the read value and not the resource itself.
      return v._handle

    compiled = def_function.function(f)
    var_handle = compiled()
    self.assertEqual(var_handle.dtype, dtypes.resource)
    self.assertEqual(var_handle.shape, tensor_shape.TensorShape([]))
    var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
    self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))

  def testShapeInferenceForMoreSpecificInput(self):

    def f(a):
      return array_ops.reshape(a, [-1, 3])

    signature = [tensor_spec.TensorSpec(None, dtypes.float32)]
    compiled = def_function.function(f, input_signature=signature)

    @def_function.function
    def use_f():
      inputs = array_ops.zeros([10, 10, 3])
      self.assertAllEqual(f(inputs).shape, compiled(inputs).shape)

    use_f()

  def testFuncListAttr(self):

    @function.defun
    def test_function(val):

      def fn1():
        return array_ops.ones([10])

      fn2 = lambda: array_ops.ones([10]) * 2

      def fn3(x=3):
        return array_ops.ones([10]) * x
      fn4 = functools.partial(fn3, x=4)
      fn5 = functools.partial(fn3, 5)

      return gen_functional_ops.case(val, [], [dtypes.float32],
                                     [function.defun(f).get_concrete_function()
                                      for f in (fn1, fn2, fn3, fn4, fn5)])

    ones = array_ops.ones([10])
    self.assertAllEqual([ones], test_function(0))
    self.assertAllEqual([ones * 2], test_function(1))
    self.assertAllEqual([ones * 3], test_function(2))
    self.assertAllEqual([ones * 4], test_function(3))
    self.assertAllEqual([ones * 5], test_function(4))
    self.assertAllEqual([ones * 5], test_function(22))  # default branch
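    # Note: an out-of-range branch index such as 22 falls through to the last
    # branch (fn5), hence `ones * 5` above.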

  @test_util.enable_control_flow_v2
  def testVariableInLoopInFunction(self):

    @function.defun
    def test_function():

      def loop_test(_):
        return False

      def loop_body(_):
        return variable_scope.get_variable('a', shape=())

      return control_flow_ops.while_loop(loop_test, loop_body, [0.0])

    self.assertEqual(test_function().shape, [])

  def testDefunShapeInferenceWithCapturedResourceVariableInGraphMode(self):
    with context.graph_mode():
      v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])

      def f():
        x = constant_op.constant([[1, 2], [3, 4]])
        out = math_ops.matmul(v, x)
        self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
        # We do not return v directly since the tensor conversion function of
        # ResourceVariable returns the read value and not the resource itself.
        return v._handle

      compiled = def_function.function(f)
      var_handle = compiled()
      self.assertEqual(var_handle.dtype, dtypes.resource)
      self.assertEqual(var_handle.shape, tensor_shape.TensorShape([]))
      var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
      self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))

  def testDefunShapeInferenceWithCapturedVariableInGraphMode(self):
    with context.graph_mode():
      v = variables.Variable([[1, 2], [3, 4]])

      def f():
        x = constant_op.constant([[1, 2], [3, 4]])
        out = math_ops.matmul(v, x)
        self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))

      # Check that shape inference works while creating the defun
      compiled = def_function.function(f)
      compiled()

  def testDefunShapeInferenceWithCapturedTensorListInGraphMode(self):
    with context.graph_mode():
      tensor_list = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
      tensor_list = list_ops.tensor_list_push_back(tensor_list,
                                                   constant_op.constant(1.0))
      tensor_list = list_ops.tensor_list_push_back(tensor_list,
                                                   constant_op.constant(2.0))

      def f():
        tl, value = list_ops.tensor_list_pop_back(
            tensor_list, element_dtype=dtypes.float32)
        self.assertEqual(value.shape, tensor_shape.TensorShape([]))
        return tl

      compiled = def_function.function(f)
      output_tensor_list = compiled()
      _, value = list_ops.tensor_list_pop_back(
          output_tensor_list, element_dtype=dtypes.float32)
      self.assertEqual(value.shape, tensor_shape.TensorShape([]))

  @test_util.run_in_graph_and_eager_modes
  def testDefunForcesResourceVariables(self):

    def variable_creator():
      self.v = variables.Variable(0.0)
      return self.v.read_value()

    self.v = None
    defined = function.defun(variable_creator)
    defined()  # Create the variable.
    self.assertIsInstance(
        self.v, resource_variable_ops.ResourceVariable)

  def testRunMetadata(self):

    @def_function.function
    def f(x):
      return x * x

    with ops.device('cpu:0'):
      context.enable_run_metadata()
      f(constant_op.constant(1.0))
    run_metadata = context.export_run_metadata()
    context.disable_run_metadata()
    self.assertLen(run_metadata.partition_graphs, 1)

  def testGraphModeCaptureVariable(self):
    with context.graph_mode(), self.cached_session():

      class HasAVar(object):

        def __init__(self):
          self.v = resource_variable_ops.ResourceVariable(1.0)

        def call(self):
          return self.v * 2

      o = HasAVar()
      self.evaluate(variables.global_variables_initializer())
      call = def_function.function(o.call)
      op = call()
      self.assertAllEqual(self.evaluate(op), 2.0)

  def testGraphModeManyFunctions(self):
    with ops.Graph().as_default(), self.cached_session():

      @def_function.function
      def f(x):
        return x * x

      @def_function.function
      def g(x):
        return f(x) + 1

      self.assertAllEqual(g(constant_op.constant(2.0)), 5.0)

  def testDict(self):

    @def_function.function
    def f(x):
      return {'name': x + 1}

    self.assertAllEqual(f(constant_op.constant(1.0))['name'], 2.0)

  def testWeakrefInputsRejected(self):

    @def_function.function
    def f(x):
      return x

    class Dummy:
      pass
    o = Dummy()
    wr = weakref.ref(o)

    with self.assertRaisesRegex(ValueError, 'weakref'):
      f(wr)

  def testTensorConversionWithDefun(self):

    @def_function.function
    def f(x):
      return math_ops.add(x, constant_op.constant(3))

    self.assertAllEqual(5, f(constant_op.constant(2)))

  def testTensorConversionCall(self):

    @def_function.function
    def f(x):
      return math_ops.add(x, constant_op.constant(3))

    @def_function.function
    def g(x):
      return f(f(x))

    self.assertAllEqual(8, g(constant_op.constant(2)))

  def testCallShape(self):

    @def_function.function
    def f(x):
      return x + 1

    @def_function.function
    def g(x):
      x = f(x)
      self.assertEqual(x.shape.as_list(), [])
      return None

    g(constant_op.constant(1.0))

  def testNestedDefunWithNoOutputAndTapedInput(self):
    three = resource_variable_ops.ResourceVariable(3.0, name='v')

    @def_function.function
    def f(x):
      # This function intentionally takes a taped variable as input, but does
      # not return any values.
1470      math_ops.add(x, three)
1471
1472    @def_function.function
1473    def g(x):
1474      y = math_ops.add(x, three)
1475      f(y)
1476
1477    g(three)
1478
1479  def testGatherResourceWithDefun(self):
1480    with ops.device('cpu:0'):
1481      v = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
1482
1483    def sum_gather():
1484      return math_ops.reduce_sum(array_ops.gather(v, [1, 2]))
1485
1486    defined = def_function.function(sum_gather)
1487    self.assertAllEqual(sum_gather(), defined())
1488
  @parameterized.named_parameters([
      ('IndexedSlicesWithDenseShape',
       _example_indexed_slices_with_dense_shape,),
      ('IndexedSlicesWithoutDenseShape',
       _example_indexed_slices_without_dense_shape,),
      ('RaggedTensorRaggedRank1', ragged_tensor.RaggedTensor.from_row_lengths,
       {'values': [1, 2, 3], 'row_lengths': [2, 0, 1]}),
      ('RaggedTensorRaggedRank2',
       ragged_tensor.RaggedTensor.from_nested_row_lengths,
       {'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}),
      ('SparseTensor', sparse_tensor.SparseTensor,
       {'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}),
  ])  # pyformat: disable
  def testReturnCompositeTensorWithDefun(self,
                                         factory_fn,
                                         factory_kwargs=None,
                                         input_signature=None):
    # Use a None sentinel instead of a mutable `{}` default argument.
    input_ct = factory_fn(**(factory_kwargs or {}))

    @def_function.function(input_signature=input_signature)
    def f():
      return input_ct

    output_ct = f()
    self.assertIsInstance(output_ct, type(input_ct))
    nest.assert_same_structure(input_ct, output_ct, expand_composites=True)

    input_flat = nest.flatten(input_ct, expand_composites=True)
    output_flat = nest.flatten(output_ct, expand_composites=True)
    for (input_component, output_component) in zip(input_flat, output_flat):
      self.assertAllEqual(input_component, output_component)

  @parameterized.named_parameters([
      ('IndexedSlicesWithDenseShape',
       _example_indexed_slices_with_dense_shape,),
      ('IndexedSlicesWithoutDenseShape',
       _example_indexed_slices_without_dense_shape,),
      ('RaggedTensorRaggedRank1',
       ragged_tensor.RaggedTensor.from_row_lengths,
       {'values': [1, 2, 3], 'row_lengths': [2, 0, 1]}),
      ('RaggedTensorRaggedRank2',
       ragged_tensor.RaggedTensor.from_nested_row_lengths,
       {'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}),
      ('SparseTensor',
       sparse_tensor.SparseTensor,
       {'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}),
      ('RaggedTensorRaggedRank1WithSignature',
       ragged_tensor.RaggedTensor.from_row_lengths,
       {'values': [1, 2, 3], 'row_lengths': [2, 0, 1]},
       [ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)]),
      ('RaggedTensorRaggedRank2WithSignature',
       ragged_tensor.RaggedTensor.from_nested_row_lengths,
       {'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]},
       [ragged_tensor.RaggedTensorSpec([None, None, None], dtypes.int32)]),
      ('SparseTensorWithSignature',
       sparse_tensor.SparseTensor,
       {'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]},
       [sparse_tensor.SparseTensorSpec([None], dtypes.int32)]),
  ])  # pyformat: disable
  def testCompositeAsArgumentTensorWithDefun(self,
                                             factory_fn,
                                             factory_kwargs=None,
                                             input_signature=None):
    # Use a None sentinel instead of a mutable `{}` default argument.
    input_ct = factory_fn(**(factory_kwargs or {}))

    @def_function.function(input_signature=input_signature)
    def f(x):
      return x

    output_ct = f(input_ct)
    self.assertIsInstance(output_ct, type(input_ct))
    nest.assert_same_structure(input_ct, output_ct, expand_composites=True)

    input_flat = nest.flatten(input_ct, expand_composites=True)
    output_flat = nest.flatten(output_ct, expand_composites=True)
    for (input_component, output_component) in zip(input_flat, output_flat):
      self.assertAllEqual(input_component, output_component)

  def testTracedCompositeDiscardsShapeInfo(self):
    # SparseTensorSpec intentionally excludes info about the number of elements
    # that are in a sparse tensor (which is recorded as st.indices.shape[0] and
    # st.values.shape[0]).  Similarly, RaggedTensorSpec intentionally excludes
    # info about the total number of values in a RaggedTensor (stored as
    # rt.values.shape[0]).  This test checks that the placeholders created by
    # tf.function() properly mask this shape info.
    @def_function.function
    def f(rt, st):
      self.assertEqual(st.indices.shape.as_list()[:1], [None])
      self.assertEqual(st.values.shape.as_list(), [None])
      return (rt, st)

    rt = ragged_factory_ops.constant([[1, 2], [3]])
    st = sparse_tensor.SparseTensor([[0]], [0], [10])
    f(rt, st)

  @test_util.run_gpu_only
  def testFunctionOnDevice(self):
    x = constant_op.constant([1.]).gpu()
    f = def_function.function(math_ops.add)
    y = f(x, x).cpu()
    self.assertAllEqual(y, [2.])

  @test_util.run_gpu_only
  @test_util.run_in_graph_and_eager_modes
  def testFunctionWithResourcesOnDifferentDevices(self):
    with ops.device('/cpu:0'):
      v_cpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])

    with ops.device('/gpu:0'):
      v_gpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])

    def sum_gather():
      cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu, [1, 2]))
      gpu_result = math_ops.reduce_sum(array_ops.gather(v_gpu, [1, 2]))
      return cpu_result, gpu_result

    defined = function.defun(sum_gather)
    if not context.executing_eagerly():
      self.evaluate(variables.global_variables_initializer())
    expected = self.evaluate(sum_gather())
    self.assertAllEqual(expected, self.evaluate(defined()))

  @test_util.run_gpu_only
  @test_util.run_in_graph_and_eager_modes
  def testOpInFunctionWithConflictingResourceInputs(self):
    with ops.device('/cpu:0'):
      v_cpu = resource_variable_ops.ResourceVariable(
          [0.0, 1.0, 2.0], name='cpu')
      v_also_cpu = resource_variable_ops.ResourceVariable(
          [0.0, 1.0, 2.0], name='also_cpu')

    with ops.device('/gpu:0'):
      v_gpu = resource_variable_ops.ResourceVariable(
          [0.0, 1.0, 2.0], name='gpu')

    @def_function.function
    def resource_apply_adam():
      training_ops.resource_apply_adam(
          v_cpu.handle,
          v_gpu.handle,
          v_also_cpu.handle,
          1.0,  # beta1_power
          1.0,  # beta2_power
          1.0,  # learning_rate
          1.0,  # beta1
          1.0,  # beta2
          1.0,  # epsilon
          [1.0, 1.0, 1.0],  # grad
          False)  # use_locking
      return None

    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        'Cannot place the graph because a reference or resource edge connects '
        'colocation groups with incompatible assigned devices'):
      if not context.executing_eagerly():
        self.evaluate(variables.global_variables_initializer())
      self.evaluate(resource_apply_adam())

  @test_util.run_gpu_only
  def testFunctionHandlesInputsOnDifferentDevices(self):
    # The Reshape op requires the shape tensor to be placed in host memory.
    reshape = def_function.function(array_ops.reshape)
    value = constant_op.constant([1., 2.]).gpu()
    shape = constant_op.constant([2, 1])
    reshaped = reshape(value, shape).cpu()
    self.assertAllEqual(reshaped, [[1], [2]])

  @test_util.run_gpu_only
  def testFunctionHandlesInputsPlacedOnTheWrongDeviceGracefully(self):
    # The Reshape op requires the shape tensor to be placed in host memory.
    reshape = def_function.function(array_ops.reshape)
    value = constant_op.constant([1., 2.])
    shape = constant_op.constant([2, 1]).gpu()
    reshape(value, shape)  # No error is raised

  def testNoneOutput(self):

    @def_function.function
    def my_function(_):
      return None

    self.assertAllEqual(my_function(1), None)

  def testNestedFunctions(self):
    # TensorFlow function (which is what would be used in TensorFlow graph
    # construction).
    @tf_function.Defun(dtypes.int32, dtypes.int32)
    def add(a, b):
      return math_ops.add(a, b)

    @def_function.function
    def add_one(x):
      return add(x, 1)

    self.assertAllEqual(3, add_one(constant_op.constant(2)))

  def testVariableCaptureInNestedFunctions(self):
    v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.int32)

    @def_function.function
    def inner_read():
      return v.read_value()

    @def_function.function
    def outer():
      return inner_read()

    self.assertEqual(1, int(outer()))

  def testReturnCapturedEagerTensor(self):
    t = constant_op.constant(1)

    @def_function.function
    def read():
      return t

    self.assertEqual(1, int(read()))

  def testReturnCapturedGraphTensor(self):
    with context.graph_mode(), self.cached_session():
      t = constant_op.constant(1)

      @def_function.function
      def read():
        return t

      self.assertEqual(1, int(self.evaluate(read())))

  def testSequenceInputs(self):
    clip_by_global_norm = def_function.function(clip_ops.clip_by_global_norm)
    t_list = [constant_op.constant(1.0), constant_op.constant(2.0)]
    clipped_list, global_norm = clip_by_global_norm(t_list,
                                                    constant_op.constant(.2))
    for t in clipped_list:
      self.assertIsInstance(t, ops.Tensor)
    self.assertIsInstance(global_norm, ops.Tensor)

  def testNestedSequenceInputs(self):

    def my_op(inputs):
      a, b, c = inputs
      e, f = b
      g, h = e
      return [a + a, [tuple([f + f, g + g]), h + h], c + c], a + f + g + h + c

    my_eager_op = def_function.function(my_op)
    ret = my_eager_op([
        constant_op.constant(1), [(constant_op.constant(2),
                                   constant_op.constant(3)),
                                  constant_op.constant(4)],
        constant_op.constant(5)
    ])
    self.assertLen(ret, 2)
    self.assertAllEqual(ret[0][0], 2)
    self.assertAllEqual(ret[0][1][0][0], 8)
    self.assertAllEqual(ret[0][1][0][1], 4)
    self.assertIsInstance(ret[0][1][0], tuple)
    self.assertAllEqual(ret[0][1][1], 6)
    self.assertAllEqual(ret[0][2], 10)
    self.assertAllEqual(ret[1], 15)

  def testVariableNamesRespectNameScopesWithDefun(self):
    @def_function.function
    def create_variable():
      with ops.name_scope('foo', skip_on_eager=False):
        v = resource_variable_ops.ResourceVariable(0.0, name='bar')
      self.assertEqual(v.name, 'foo/bar:0')

    create_variable()

  def testVariableNamesRespectNameScopesWithDefunInGraph(self):
    with context.graph_mode():
      @def_function.function
      def create_variable():
        with ops.name_scope('foo', skip_on_eager=False):
          v = resource_variable_ops.ResourceVariable([1.0, 2.0], name='bar')
        self.assertEqual(v.name, 'foo/bar:0')

      with ops.get_default_graph().as_default():
        create_variable()

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testCallOptionsMemory(self):

    @function.defun
    def model(x):
      return x + constant_op.constant(1.)

    # This happens with a lot of option toggles, e.g. soft device placement.
    context.context().function_call_options = None
    model(constant_op.constant(2.))

  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
  def testLayerInDefun(self):
    conv = convolutional.Conv2D(
        filters=1,
        kernel_size=2,
        kernel_initializer=init_ops.ones_initializer(),
        bias_initializer=init_ops.zeros_initializer())

    @function.defun
    def model(x):
      return conv(x)

    x = array_ops.ones([1, 2, 2, 1])
    y = model(x)

    if not context.executing_eagerly():
      self.evaluate(variables.global_variables_initializer())

    self.assertAllClose([[[[4.0]]]], self.evaluate(y))

  # Variable lifting is somewhat different between defun/tf.function, so testing
  # device placement on both makes sense.
  @parameterized.named_parameters(
      dict(testcase_name='Defun',
           function_decorator=function.defun),
      dict(testcase_name='DefFunction',
           function_decorator=def_function.function))
  @test_util.run_in_graph_and_eager_modes
  def testVariablesPlacedOnOutsideDevice(self, function_decorator):

    class _Obj(object):

      def __init__(self):
        self.v = None

      @function_decorator
      def f(self):
        if self.v is None:
          self.v = variables.Variable(1.)
        return self.v + 1.

    has_device = _Obj()
    with ops.device('cpu:0'):
      has_device.f()
    self.assertIn('CPU', has_device.v.device)

  @test_util.run_in_graph_and_eager_modes
  def testMultipleDeviceCheck(self):

    def f():
      with ops.device('cpu'):
        return test_ops.device_placement_op()

    func = function.defun(f)
    with ops.device('cpu:0'):
      output = self.evaluate(func())
      self.assertIn(compat.as_bytes('CPU:0'), output)

  @test_util.run_in_graph_and_eager_modes
  def testDeviceAnnotationsRespected(self):

    def multi_device_fn():
      with ops.device('/cpu:0'):
        s0 = test_ops.device_placement_op()
      with ops.device('/cpu:1'):
        s1 = test_ops.device_placement_op()
      with ops.device('/cpu:2'):
        s2 = test_ops.device_placement_op()
      s3 = test_ops.device_placement_op()
      return s0, s1, s2, s3

    defined = function.defun(multi_device_fn)
    outputs = self.evaluate(defined())
    self.assertLen(total_function_cache(defined), 1)
    self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
    self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
    self.assertIn(compat.as_bytes('CPU:2'), outputs[2])

    with ops.device('/cpu:3'):
      outputs = self.evaluate(defined())
    # All function definitions are agnostic to call site devices.
    self.assertLen(total_function_cache(defined), 1)
    self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
    self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
    self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
    self.assertIn(compat.as_bytes('CPU:3'), outputs[3])

    with ops.device('/cpu:0'):
      outputs = self.evaluate(defined())
    self.assertLen(total_function_cache(defined), 1)
    self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
    self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
    self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
    self.assertIn(compat.as_bytes('CPU:0'), outputs[3])

  @test_util.run_in_graph_and_eager_modes
  def testCallingGraphFunctionOnDifferentDevice(self):

    def func():
      return constant_op.constant(0)

    defined = def_function.function(func)
    with ops.device('cpu:0'):
      cpu_graph_function = defined.get_concrete_function()

    with ops.device('cpu:0'):
      self.assertEqual(
          self.evaluate(cpu_graph_function()), self.evaluate(func()))

    with ops.device('cpu:1'):
      self.assertEqual(0., self.evaluate(cpu_graph_function()))

    with ops.device(None):
      self.assertEqual(0., self.evaluate(cpu_graph_function()))

    default_graph_function = defined.get_concrete_function()
    self.assertEqual(
        self.evaluate(default_graph_function()), self.evaluate(func()))

    with ops.device('cpu:1'):
      self.assertEqual(0., self.evaluate(default_graph_function()))

  @test_util.run_gpu_only
  @test_util.run_in_graph_and_eager_modes
  def testColocateWithRespected(self):
    # TODO(b/113291792): Use multiple CPUs instead of a GPU.
    with ops.device('cpu:0'):
      x = array_ops.identity(1.0)

    with ops.device('gpu:0'):
      y = array_ops.identity(1.0)

    @def_function.function
    def foo():
      return test_ops.device_placement_op()

    with ops.colocate_with(x):
      self.assertIn(compat.as_bytes('CPU:0'), self.evaluate(foo()))

    with ops.colocate_with(y):
      self.assertIn(compat.as_bytes('GPU:0'), self.evaluate(foo()))

  def testVariablesAreTracked(self):
    v = resource_variable_ops.ResourceVariable(1.0)

    def foo(x):
      return v * x

    defined = def_function.function(foo)

    x = constant_op.constant([1.0])
    self.assertEqual(1., self.evaluate(defined(x)))
    v.assign(2.)

    x = constant_op.constant([1.0, 2.0])
    self.assertAllEqual([2., 4.], self.evaluate(defined(x)))

  def testCacheObjectHashCollisions(self):

    class Foo(object):

      def __hash__(self):
        return 42

    def func(foo):
      del foo
      return

    defined = function.defun(func)
    defined(Foo())
    self.assertLen(total_function_cache(defined), 1)

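    # The two instances hash identically but compare unequal (default
    # identity-based __eq__), so an equal hash alone must not produce a
    # cache hit; the second call gets its own concrete function.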
    defined(Foo())
    self.assertLen(total_function_cache(defined), 2)

  def testCacheTensorDtypeCollision(self):

    def func(t):
      return t + t

    defined = function.defun(func)
    t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
    defined(t)
    self.assertLen(total_function_cache(defined), 1)

    t = constant_op.constant([[1.0]], dtype=dtypes.complex128)
    defined(t)
    self.assertLen(total_function_cache(defined), 2)

  def testCacheTensorShapeCollision(self):

    def func(t):
      return t + t

    defined = function.defun(func)
    t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
    defined(t)
    self.assertLen(total_function_cache(defined), 1)

    t = constant_op.constant([1.0], dtype=dtypes.complex64)
    defined(t)
    self.assertLen(total_function_cache(defined), 2)

  def testCacheTensorShapeDtypeCollision(self):

    def func(t):
      return t + t

    defined = function.defun(func)
    t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
    defined(t)
    self.assertLen(total_function_cache(defined), 1)

    t = constant_op.constant([1.0], dtype=dtypes.complex128)
    defined(t)
    self.assertLen(total_function_cache(defined), 2)

  def testCacheTensorUnknownShapesCollisionRelaxedShapes(self):

    def func(t):
      return t + t

    with context.graph_mode(), self.cached_session():
      defined = function.defun(func, experimental_relax_shapes=True)

      p = array_ops.placeholder(dtype=dtypes.float32, shape=[])
      defined(p)
      self.assertLen(total_function_cache(defined), 1)

      p = array_ops.placeholder(dtype=dtypes.float32, shape=[1])
      defined(p)
      self.assertLen(total_function_cache(defined), 2)

      p = array_ops.placeholder(dtype=dtypes.float32, shape=[2])
      defined(p)
      # Gradual shape relaxation is performed, and the common shape between
      # [1] and [2] is one containing unknown dimensions.
      self.assertLen(total_function_cache(defined), 2)

      # pylint: disable=protected-access
      self.assertLen(defined._function_cache.arg_relaxed_specs, 1)
      relaxed_specs = (
          list(defined._function_cache.arg_relaxed_specs.values())[0])
      self.assertLen(relaxed_specs, 1)
      relaxed_shape = relaxed_specs[0].shape
      # pylint: enable=protected-access
      self.assertEqual(relaxed_shape.rank, 1)
      self.assertEqual(tensor_shape.dimension_value(relaxed_shape[0]), None)

      t = constant_op.constant([1.0, 1.0, 1.0], dtype=dtypes.float32)
      defined(t)
      # Shape (3,) matches the relaxed shape TensorShape([None]).
      self.assertLen(total_function_cache(defined), 2)

  def testPythonFunctionWithDefaultArgs(self):

    def func(foo, bar=1, baz=2):
      del foo
      del bar
      del baz
      return

    defined = function.defun(func)
    defined(0, baz=20)
    self.assertLen(total_function_cache(defined), 1)

    defined(1)  # bar=1, baz=2
    self.assertLen(total_function_cache(defined), 2)

    # This matches the previous call.
    defined(foo=1)
    self.assertLen(total_function_cache(defined), 2)

    defined(1, 2, 3)
    self.assertLen(total_function_cache(defined), 3)

    # This matches the previous call.
    defined(1, bar=2, baz=3)
    self.assertLen(total_function_cache(defined), 3)

    # This matches the previous call.
    defined(1, baz=3, bar=2)
    self.assertLen(total_function_cache(defined), 3)

  def testDatasetIteratorCaching(self):
    def func(it1, it2):
      next(it1)
      next(it2)
      return 0

    defined = function.defun(func)

    d = dataset_ops.DatasetV2.from_tensor_slices([1, 2, 3])
    it1 = iter(d)
    it2 = iter(d)
    _ = defined(it1, it2)  # The two iterators are different.
    self.assertLen(total_function_cache(defined), 1)

    it3 = iter(d)
    it4 = iter(d)
    _ = defined(it3, it4)  # Two different iterators again; should not retrace.
    self.assertLen(total_function_cache(defined), 1)

    it5 = iter(d)
    _ = defined(it5, it5)  # The same iterator in both positions; should retrace.
    self.assertLen(total_function_cache(defined), 2)

  def testFunctoolsPartialUnwrappedCorrectly(self):

    def full_function(a, b, c=3):
      return a, b, c

    partial = functools.partial(full_function, 1, c=4)
    a, b, c = partial(2)

    defined = function.defun(partial)
    func_a, func_b, func_c = defined(2)
    self.assertEqual(func_a.numpy(), a)
    self.assertEqual(func_b.numpy(), b)
    self.assertEqual(func_c.numpy(), c)

  def testInputSignatureWithMatchingInputs(self):

    def foo(a):
      self.assertEqual(a.shape, (2,))
      return a

    signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)]
    defined = function.defun(foo, input_signature=signature)
    a = array_ops.ones([2])
    self.assertAllEqual(a, defined(a))
    self.assertLen(total_function_cache(defined), 1)
    self.assertAllEqual(a, defined.get_concrete_function()(a))
    self.assertAllEqual(a, defined.get_concrete_function(a)(a))
    self.assertAllEqual(a, defined.get_concrete_function(
        tensor_spec.TensorSpec((2,), dtype=dtypes.float32))(a))
    self.assertLen(total_function_cache(defined), 1)

    def bar(a):
      self.assertEqual(a._shape_tuple(), (2, None))
      return a

    signature = [tensor_spec.TensorSpec((2, None), dtypes.float32)]
    defined = function.defun(bar, input_signature=signature)
    a = array_ops.ones([2, 1])
    out = defined(a)
    self.assertLen(total_function_cache(defined), 1)
    self.assertAllEqual(out, a)

    # Changing the second dimension shouldn't create a new function.
    b = array_ops.ones([2, 3])
    out = defined(b)
    self.assertLen(total_function_cache(defined), 1)
    self.assertAllEqual(out, b)

  def testInputSignatureWithDictInPositionalArgs(self):

    @function.defun
    def f(*_args, **_kwargs):
      return None

    f(1, x=2)
    self.assertLen(total_function_cache(f), 1)
    f(1, x=2)
    self.assertLen(total_function_cache(f), 1)
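    # Passing the dict positionally is a different calling convention from
    # passing x as a keyword, so a new function is traced.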
    f(1, {'x': 2})
    self.assertLen(total_function_cache(f), 2)

  def testInputSignatureWithCompatibleInputs(self):

    rank2_spec = tensor_spec.TensorSpec(shape=(None, None),
                                        dtype=dtypes.float32)

    @function.defun(input_signature=[rank2_spec])
    def func(a):
      self.assertEqual([None, None], a.shape.as_list())
      return array_ops.shape(a)

    self.assertAllEqual([3, 1], func([[0], [1.0], [1]]))
    self.assertAllEqual([2, 2], func(numpy.array([[1, 1], [2, 2]])))

    with self.assertRaisesRegex(ValueError, 'incompatible'):
      func([0.0, 1.0, 2.0])  # Wrong shape.

    with self.assertRaisesRegex(ValueError, 'incompatible'):
      func([['wrong dtype']])

  def testNoKeywordOnlyArgumentsWithInputSignature(self):
    if sys.version_info[0] < 3:
      self.skipTest('keyword_only arguments only exist in Python 3.')

    func = eval('lambda x, *, y: x')  # pylint: disable=eval-used
    signature = [tensor_spec.TensorSpec(None, dtypes.int32)]
    with self.assertRaisesRegex(
        ValueError, 'Cannot define a TensorFlow function from a Python '
        'function with keyword-only arguments when input_signature is '
        'provided.'):
      def_function.function(func, signature)

  def testNestedInputSignatures(self):

    def expected_foo(a, b):
      return [a, b]

    @function.defun(input_signature=[
        [tensor_spec.TensorSpec((2, None), dtypes.float32)] * 2,
        tensor_spec.TensorSpec((1,), dtypes.float32),
    ])
    def foo(a, b):
      self.assertEqual(a[0]._shape_tuple(), (2, None))
      self.assertEqual(a[1]._shape_tuple(), (2, None))
      self.assertEqual(b._shape_tuple(), (1,))
      return [a, b]

    a = array_ops.ones([2, 1])
    b = array_ops.ones([1])
    expected = expected_foo([a, a], b)
    out = foo([a, a], b)
    self.assertLen(total_function_cache(foo), 1)
    nest.assert_same_structure(out, expected)
    self.assertAllEqual(out[0][0], a)
    self.assertAllEqual(out[0][1], a)
    self.assertAllEqual(out[1], b)

    # Changing the unspecified dimensions shouldn't create a new function.
    a = array_ops.ones([2, 3])
    b = array_ops.ones([2, 5])
    c = array_ops.ones([1])
    expected = expected_foo([a, b], c)
    out = foo([a, b], c)
    self.assertLen(total_function_cache(foo), 1)
    nest.assert_same_structure(out, expected)
    self.assertAllEqual(out[0][0], a)
    self.assertAllEqual(out[0][1], b)
    self.assertAllEqual(out[1], c)

    # Passing compatible inputs should work.
    a = a.numpy().tolist()
    b = b.numpy().tolist()
    c = c.numpy().tolist()
    out = foo([a, b], c)
    self.assertLen(total_function_cache(foo), 1)
    nest.assert_same_structure(out, expected)
    self.assertAllEqual(out[0][0], a)
    self.assertAllEqual(out[0][1], b)
    self.assertAllEqual(out[1], c)

  def testNestedInputSignaturesWithDict(self):
    def expected_bar(a):
      return a

    @function.defun(input_signature=[{
        'a': tensor_spec.TensorSpec((2, None), dtypes.float32),
        'b': tensor_spec.TensorSpec((2, None), dtypes.float32),
        'c': tensor_spec.TensorSpec((1,), dtypes.float32)}])
    def bar(a):
      self.assertEqual(a['a']._shape_tuple(), (2, None))
      self.assertEqual(a['b']._shape_tuple(), (2, None))
      self.assertEqual(a['c']._shape_tuple(), (1,))
      return a

    a = array_ops.ones([2, 3])
    b = array_ops.ones([1])
    inputs = {'a': a, 'b': a, 'c': b}
    expected = expected_bar(inputs)
    out = bar(inputs)
    nest.assert_same_structure(out, expected)
    self.assertAllEqual(out['a'], expected['a'])
    self.assertAllEqual(out['b'], expected['b'])
    self.assertAllEqual(out['c'], expected['c'])

    # Passing compatible inputs should work.
    a = a.numpy().tolist()
    b = b.numpy().tolist()
    inputs = {'a': a, 'b': a, 'c': b}
    out = bar(inputs)
    nest.assert_same_structure(out, expected)
    self.assertAllEqual(out['a'], expected['a'])
    self.assertAllEqual(out['b'], expected['b'])
    self.assertAllEqual(out['c'], expected['c'])

  def testInputSignatureMustBeSequenceOfTensorSpecs(self):

    def foo(a, b):
      del a
      del b

    # Signatures must consist exclusively of `TensorSpec` objects.
    signature = [(2, 3), tensor_spec.TensorSpec([2, 3], dtypes.float32)]
    with self.assertRaisesRegex(TypeError, 'input_signature.*nested sequence'):
      def_function.function(foo, input_signature=signature)

    # Signatures must be either lists or tuples on their outermost levels.
    signature = {'t1': tensor_spec.TensorSpec([], dtypes.float32)}
    with self.assertRaisesRegex(
        TypeError, 'input_signature must be either a '
        'tuple or a list.*'):
      function.defun(foo, input_signature=signature)

  @test_util.run_in_graph_and_eager_modes
  def testInputsIncompatibleWithSignatureRaisesError(self):

    def foo(a):
      return a

    signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)]
    defined = def_function.function(foo, input_signature=signature)

    # Invalid shapes.
    with self.assertRaisesRegex(ValueError, 'Python inputs incompatible.*'):
      defined(array_ops.ones([3]))

    with self.assertRaisesRegex(ValueError, 'Python inputs incompatible.*'):
      defined(array_ops.ones([2, 1]))

    # Wrong number of arguments.
    with self.assertRaisesRegex(TypeError, 'specifies 1 .* got 2'):
      defined(array_ops.ones([2]), array_ops.ones([2]))
    with self.assertRaisesRegex(ValueError,
                                'Structure of Python function inputs.*'):
      defined()

    with self.assertRaisesRegex(ValueError,
                                'inputs incompatible with input_signature'):
      defined.get_concrete_function(
          tensor_spec.TensorSpec(shape=(3,), dtype=dtypes.float32))

  def testMismatchedConcreteSignatureRaisesError(self):

    @def_function.function
    def run_test():
      @def_function.function
      def f(x):
        return x

      with self.assertRaisesRegex(
          TypeError, 'ConcreteFunction .* was constructed .* but was called'):
        f.get_concrete_function(1)(constant_op.constant(1))

      with self.assertRaisesRegex(TypeError, r'f\(x\) expected .* but got .*'):
        f.get_concrete_function(constant_op.constant(1))(1)

      with self.assertRaisesRegex(
          TypeError, 'ConcreteFunction .* was constructed .* but was called'):
        f.get_concrete_function(1)(2)

    run_test()

  def testInputsIncompatibleWithNestedSignatureRaisesError(self):

    def foo(a, b):
      return [a, b]

    signature = [[tensor_spec.TensorSpec((1,), dtypes.float32)] * 2,
                 [tensor_spec.TensorSpec((1,), dtypes.float32)] * 2]
    defined = function.defun(foo, input_signature=signature)
    a = array_ops.ones([1])

    with self.assertRaisesRegex(ValueError,
                                'Structure of Python function inputs.*'):
      defined([a, a, a], [a])

    with self.assertRaisesRegex(ValueError,
                                'Structure of Python function inputs.*'):
      defined([a], [a, a, a])
    defined([a, a], [a, a])

  def testUnderspecifiedInputSignature(self):
    @function.defun(input_signature=[
        tensor_spec.TensorSpec([], dtypes.float32),
    ])
    def foo(a, training=True):
      if training:
        return a
      else:
        return -1.0 * a

    x = constant_op.constant(1.0)
    with self.assertRaisesRegex(
        TypeError, 'got keyword argument `training` '
        'that was not included in input_signature'):
      foo(x, training=True)

    with self.assertRaisesRegex(
        TypeError, 'got keyword argument `training` '
        'that was not included in input_signature'):
      foo(x, training=False)

    self.assertAllEqual(x.numpy(), foo(x).numpy())

  def testInputSignatureWithPartialFunction(self):
    def full_function(a, b, c=3.0):
      return a, b, c

    partial = functools.partial(full_function, 1, c=4)
    a, b, c = partial(2.0)
    signature = [tensor_spec.TensorSpec([], dtypes.float32)]
    defined = function.defun(partial, input_signature=signature)
    x = constant_op.constant(2.0)
    func_a, func_b, func_c = defined(x)
    self.assertEqual(func_a.numpy(), a)
    self.assertEqual(func_b.numpy(), b)
    self.assertEqual(func_c.numpy(), c)

  def testInputSignatureConversionWithDefaultArg(self):

    def foo(a, training=True):
      if training:
        return a
      else:
        return -1.0 * a

    signature = [
        tensor_spec.TensorSpec([], dtypes.float32),
        tensor_spec.TensorSpec([], dtypes.bool),
    ]
    defined = def_function.function(foo, input_signature=signature)
    a = constant_op.constant(1.0)
    self.assertAllEqual(a.numpy(), defined(a))
    self.assertAllEqual(a.numpy(), defined(a, training=True))
    self.assertAllEqual(-a.numpy(), defined(a, training=False))

  def testInputSignatureWithKeywordPositionalArgs(self):

    @function.defun(input_signature=[
        tensor_spec.TensorSpec([], dtypes.float32),
        tensor_spec.TensorSpec([], dtypes.int64)
    ])
    def foo(flt, integer):
      return flt, integer

    flt = constant_op.constant(1.0)
    integer = constant_op.constant(2, dtypes.int64)

    out1, out2 = foo(flt, integer)
    self.assertLen(total_function_cache(foo), 1)
    self.assertEqual(out1.numpy(), 1.0)
    self.assertEqual(out2.numpy(), 2)

    out1, out2 = foo(flt=flt, integer=integer)
    self.assertLen(total_function_cache(foo), 1)
    self.assertEqual(out1.numpy(), 1.0)
    self.assertEqual(out2.numpy(), 2)

    out1, out2 = foo(integer=integer, flt=flt)
    self.assertLen(total_function_cache(foo), 1)
    self.assertEqual(out1.numpy(), 1.0)
    self.assertEqual(out2.numpy(), 2)

    out1, out2 = foo(flt, integer=integer)
    self.assertLen(total_function_cache(foo), 1)
    self.assertEqual(out1.numpy(), 1.0)
    self.assertEqual(out2.numpy(), 2)

  def testInputSignatureWithKeywordArgs(self):
    def foo(a, b, **kwargs):
      del kwargs
      return a, b

    x = function.defun(
        foo,
        input_signature=[
            tensor_spec.TensorSpec([], dtypes.float32),
            tensor_spec.TensorSpec([], dtypes.int32)
        ]).get_concrete_function()
    result = x(constant_op.constant(5.0), constant_op.constant(5))
    self.assertAllEqual(result, [5.0, 5])

  def testInputSignatureWithCompositeTensors(self):
    def f(rt):
      self.assertEqual(rt.values.shape.as_list(), [None])
      self.assertEqual(rt.row_splits.shape.as_list(), [4])
      return rt

    signature = [ragged_tensor.RaggedTensorSpec(
        shape=[3, None], dtype=dtypes.int32)]
    defined = function.defun(f, input_signature=signature)
    rt1 = ragged_factory_ops.constant([[1], [], [2, 3, 4]])
    out1 = defined(rt1)
    self.assertLen(total_function_cache(defined), 1)
    self.assertAllEqual(out1.values, rt1.values)
    self.assertAllEqual(out1.row_splits, rt1.row_splits)

    # Changing the row lengths shouldn't create a new function.
    rt2 = ragged_factory_ops.constant([[1, 2], [3, 4], [5]])
    out2 = defined(rt2)
    self.assertLen(total_function_cache(defined), 1)
    self.assertAllEqual(out2.values, rt2.values)
    self.assertAllEqual(out2.row_splits, rt2.row_splits)

    # Different number of rows
    rt3 = ragged_factory_ops.constant([[1, 2], [3, 4], [5], [6]])
    with self.assertRaisesRegex(ValueError, 'incompatible'):
      defined(rt3)

    # Different dtype
    rt4 = ragged_factory_ops.constant([[1.0, 2.0], [], [3.0]])
    with self.assertRaisesRegex(ValueError, 'Structure .* does not match'):
      defined(rt4)

    # Different rank
    rt5 = ragged_factory_ops.constant([[[1]], [[2]], [[3]]])
    with self.assertRaisesRegex(ValueError, 'does not match'):
      defined(rt5)

  def testInputSignatureWithVariableArgs(self):

    def f(v):
      v.assign_add(1)

    signature = [
        resource_variable_ops.VariableSpec(shape=[], dtype=dtypes.int32)
    ]
    defined = function.defun(f, input_signature=signature)

    v1 = variables.Variable(0)
    v2 = variables.Variable(0)

    defined(v1)
    self.assertEqual(v1.numpy(), 1)
    self.assertEqual(v2.numpy(), 0)

    defined(v=v2)
    self.assertEqual(v1.numpy(), 1)
    self.assertEqual(v2.numpy(), 1)

  def testTensorKeywordArguments(self):

    def foo(a, b):
      del a
      return b

    defined = function.defun(foo)
    a = constant_op.constant(2.0)
    b = constant_op.constant([1.0, 2.0])
    one = defined(a, b)
    self.assertLen(total_function_cache(defined), 1)

    two = defined(a=a, b=b)
    self.assertLen(total_function_cache(defined), 1)

    three = defined(b=b, a=a)
    self.assertLen(total_function_cache(defined), 1)

    four = defined(a, b=b)
    self.assertLen(total_function_cache(defined), 1)

    # The next call corresponds to a new input signature, hence
    # we expect another function to be defined.
    five = defined(b, a)
    self.assertLen(total_function_cache(defined), 2)

    six = defined(a=b, b=a)
    self.assertLen(total_function_cache(defined), 2)

    seven = defined(b=a, a=b)
    self.assertLen(total_function_cache(defined), 2)

    self.assertAllEqual(one, [1.0, 2.0])
    self.assertAllEqual(two, [1.0, 2.0])
    self.assertAllEqual(three, [1.0, 2.0])
    self.assertAllEqual(four, [1.0, 2.0])
    self.assertAllEqual(five, 2.0)
    self.assertAllEqual(six, 2.0)
    self.assertAllEqual(seven, 2.0)

  def testDefuningInstanceMethod(self):

    integer = constant_op.constant(2, dtypes.int64)

    class Foo(object):

      def one(self, tensor):
        return tensor

      @def_function.function
      def two(self, tensor, other=integer):
        return self.one(tensor), other

    foo = Foo()
    t = constant_op.constant(1.0)
    one, two = foo.two(t)
    self.assertEqual(one.numpy(), 1.0)
    self.assertEqual(two.numpy(), 2)

  def testDefuningInstanceMethodWithDefaultArgument(self):

    integer = constant_op.constant(2, dtypes.int64)

    class Foo(object):

      @def_function.function
      def func(self, other=integer):
        return other

    foo = Foo()
    self.assertEqual(foo.func().numpy(), int(integer))

  def testPythonCallWithSideEffects(self):
    state = []

    @def_function.function
    def side_effecting_function():
      state.append(0)

    side_effecting_function()
    self.assertAllEqual(state, [0])

    # The second invocation should call the graph function, which shouldn't
    # trigger the list append.
    side_effecting_function()
    self.assertAllEqual(state, [0])

    # Calling the Python function directly, however, should create a
    # side effect.
    side_effecting_function.python_function()
    self.assertAllEqual(state, [0, 0])

  def testFunctionWithNestedFunctionCallAndSideEffects(self):
    v1 = variables.Variable(1.0)
    v2 = variables.Variable(1.0)

    @def_function.function
    def add_one(a):
      a.assign_add(1.0)

    # Grappler will inline calls to `add_one` into the function body; we check
    # that all side effects were executed.
    @def_function.function
    def side_effecting_function(a, b):
      add_one(a)
      add_one(b)
      return a + b

    result = side_effecting_function(v1, v2)
    self.assertEqual(result.numpy(), 4.0)

  def testFunctionWithExtraAttributes(self):
    @function.defun_with_attributes(attributes={'experimental_1': 'value1',
                                                'experimental_2': 2})
    def matmul(x, y):
      return math_ops.matmul(x, y)

    def add(x, y):
      return math_ops.add(x, y)
    defun_add = function.defun_with_attributes(
        add, attributes={'experimental_3': True, 'experimental_4': 1.0})

    with context.graph_mode(), self.cached_session():
      with ops.get_default_graph().as_default():
        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
        sq = matmul(t, t)
        double = defun_add(t, t)
        self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22])
        self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8])

        graph = ops.get_default_graph()
        # pylint: disable=protected-access
        self.assertLen(graph._functions, 2)
        functions = list(graph._functions.values())
        self.assertRegex(functions[0].definition.signature.name, '.*matmul.*')
        attrs = functions[0].definition.attr
        self.assertLen(attrs, 2)
        self.assertEqual(attrs['experimental_1'].s, b'value1')
        self.assertEqual(attrs['experimental_2'].i, 2)

        self.assertRegex(functions[1].definition.signature.name, '.*add.*')
        attrs = functions[1].definition.attr
        self.assertLen(attrs, 2)
        self.assertEqual(attrs['experimental_3'].b, True)
        self.assertEqual(attrs['experimental_4'].f, 1.0)
        # pylint: enable=protected-access

  def testFunctionWithInvalidAttribute(self):
    @function.defun_with_attributes(attributes={'experimental_1': ['value1']})
    def add(x, y):
      return math_ops.add(x, y)

    with self.assertRaisesRegex(ValueError,
                                'Attribute experimental_1 must be .* Got .*'):
      with context.graph_mode(), self.cached_session():
        with ops.get_default_graph().as_default():
          t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
          add(t, t)

  def testRegisterFunction(self):

    @function.defun
    def add(x, y):
      return math_ops.add(x, y)

    def matmul(x, y):
      return math_ops.matmul(x, y)
    defun_matmul = function.defun(matmul)

    with context.graph_mode(), self.cached_session():
      with ops.get_default_graph().as_default():
        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
        function.register(defun_matmul, t, t)
        function.register(add, t, t)

        graph = ops.get_default_graph()
        # pylint: disable=protected-access
        self.assertLen(graph._functions, 6)
        # Two sets of functions, each consisting of an (inference, forward,
        # backward) triple.
        functions = list(graph._functions.values())
        captured_function_names = [
            f.definition.signature.name for f in functions
        ]
        expected_func_name_regex = [
            '.*inference.*matmul.*',
            '.*forward.*matmul.*',
            '.*inference.*backward.*matmul.*',
            '.*inference.*add.*',
            '.*forward.*add.*',
            '.*inference.*backward.*add.*',
        ]
        for i in range(len(functions)):
          self.assertRegex(captured_function_names[i],
                           expected_func_name_regex[i])

        # Check that the forward and backward functions have the correct
        # attributes.
        self.assertEqual(
            functions[1].definition.attr['backward_function_name'].s,
            functions[2].name)
        self.assertEqual(
            functions[2].definition.attr['forward_function_name'].s,
            functions[1].name)

        self.assertEqual(
            functions[4].definition.attr['backward_function_name'].s,
            functions[5].name)
        self.assertEqual(
            functions[5].definition.attr['forward_function_name'].s,
            functions[4].name)

        sq = defun_matmul(t, t)
        double = add(t, t)
        self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22])
        self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8])
        # Make sure the pre-registered function is used, and no other function
        # is added.
        self.assertLen(graph._functions, 6)
        functions = list(graph._functions.values())
        for i in range(len(functions)):
          self.assertEqual(captured_function_names[i],
                           functions[i].definition.signature.name)

  @parameterized.named_parameters(
      dict(testcase_name='Defun',
           function_decorator=function.defun),
      dict(testcase_name='DefFunction',
           function_decorator=def_function.function))
  def testRegisterConcreteFunction(self, function_decorator):
    @function_decorator
    def py_add(x, y):
      return math_ops.add(x, y)

    py_add(array_ops.ones([]), array_ops.ones([]))
    add = py_add.get_concrete_function(
        tensor_spec.TensorSpec(None, dtypes.float32),
        tensor_spec.TensorSpec(None, dtypes.float32))

    @function_decorator
    def py_composite(x, y):
      return x, add(x, y)

    py_composite(array_ops.ones([]), array_ops.ones([]))
    composite = py_composite.get_concrete_function(
        tensor_spec.TensorSpec(None, dtypes.float32),
        tensor_spec.TensorSpec(None, dtypes.float32))

    with context.graph_mode(), self.cached_session():
      with ops.get_default_graph().as_default():
        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
        composite.add_to_graph()
        composite.add_gradient_functions_to_graph()

        graph = ops.get_default_graph()
        # pylint: disable=protected-access
        self.assertLen(graph._functions, 6)
        # Two sets of functions, each consisting of an (inference, forward,
        # backward) triple.
        functions = list(graph._functions.values())
        captured_function_names = [
            f.definition.signature.name for f in functions
        ]
        expected_func_name_regex = [
            '.*inference.*py_composite.*',
            '.*inference.*py_add.*',
            '.*forward.*py_composite.*',
            '.*forward.*py_add.*',
            '.*inference.*backward.*py_composite.*',
            '.*inference.*backward.*py_add.*',
        ]
        for expected, found in zip(
            expected_func_name_regex,
            captured_function_names):
          self.assertRegex(found, expected)

        composite_t, composite_double = composite(t, t)
        double = add(t, t)
        self.assertAllEqual([[2, 4], [6, 8]], self.evaluate(double))
        self.assertAllEqual([[2, 4], [6, 8]], self.evaluate(composite_double))
        self.assertAllEqual([[1, 2], [3, 4]], self.evaluate(composite_t))
        # Make sure the pre-registered function is used, and no other function
        # is added.
        self.assertLen(graph._functions, 6)

  @parameterized.named_parameters(
      dict(testcase_name='Defun',
           function_decorator=function.defun),
      dict(testcase_name='DefFunction',
           function_decorator=def_function.function))
  def testEagerCaptures(self, function_decorator):
    with context.eager_mode():
      large_tensor = array_ops.ones(shape=(256,))
      self.assertGreater(256, func_graph._EAGER_CONST_THRESHOLD)

      small_tensor = array_ops.ones(shape=(4,))
      self.assertLessEqual(4, func_graph._EAGER_CONST_THRESHOLD)

      v = resource_variable_ops.ResourceVariable(0.0)

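    # Small eager tensors (within _EAGER_CONST_THRESHOLD elements) are
    # captured as Const ops, while larger tensors and variables are captured
    # as placeholders, as asserted below.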
    for captured, op_type in [(large_tensor, 'Placeholder'),
                              (small_tensor, 'Const'), (v, 'Placeholder')]:
      @function_decorator
      def test_fn():
        return captured + 1  # pylint: disable=cell-var-from-loop

      g = test_fn.get_concrete_function().graph
      internal_captures = g.internal_captures
      self.assertLen(internal_captures, 1)
      self.assertEqual(internal_captures[0].op.type, op_type)

  def testRegisterFunctionWithInputSignature(self):
    def matmul(x, y):
      return math_ops.matmul(x, y)
    defun_matmul = function.defun(
        matmul,
        input_signature=[
            tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32),
            tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32)
        ])
    with context.graph_mode(), self.cached_session():
      with ops.get_default_graph().as_default():
        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
        function.register(defun_matmul, t, t)

        graph = ops.get_default_graph()
        # pylint: disable=protected-access
        self.assertLen(graph._functions, 3)

        # Test registering the function with the cache; note that the inputs
        # are ignored.
        function.register(defun_matmul)
        graph = ops.get_default_graph()
        self.assertLen(graph._functions, 3)

  def testRegisterFunctionWithCache(self):
    def matmul(x, y):
      return math_ops.matmul(x, y)
    defun_matmul = function.defun(matmul)

    with context.graph_mode(), self.cached_session():
      with ops.get_default_graph().as_default():
        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
        t2 = constant_op.constant([[2.0, 3.0], [4.0, 5.0]])
        function.register(defun_matmul, t, t)
        function.register(defun_matmul, t2, t2)

        graph = ops.get_default_graph()
        # Only one function is registered since the input params have the
        # same type.
        # pylint: disable=protected-access
        self.assertLen(graph._functions, 3)

  def testCallingFunctionWithDifferentVariables(self):

    @function.defun
    def foo(v):
      v.assign_add(1.0)
      return v.read_value()

    v = resource_variable_ops.ResourceVariable(0.0)
    graph_function = foo.get_concrete_function(v)
    self.assertLen(graph_function.inputs, 1)
    self.assertEmpty(graph_function.captured_inputs)

    self.assertEqual(float(graph_function(v)), 1.0)
    self.assertEqual(float(graph_function(v)), 2.0)

    w = resource_variable_ops.ResourceVariable(0.0)

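    # A concrete function traced with one variable can be invoked with a
    # different variable of the same dtype and shape.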
    @function.defun
    def bar(v):
      del v
      return constant_op.constant(1.0)

    graph_function = bar.get_concrete_function(v)
    self.assertEqual(float(graph_function(v)), 1.0)
    self.assertEqual(float(graph_function(w)), 1.0)

  def testCallingFunctionWithNonTensorsFails(self):

    @function.defun
    def foo(x):
      return x

    graph_function = foo.get_concrete_function(constant_op.constant(1.0))
    with self.assertRaises((TypeError, ValueError)):
      graph_function('Not a Tensor.')

  def testSwapImplementationWithGrapplerPlugin(self):
    # Set min_graph_nodes to -1, since the graph in this test is too small and
    # would otherwise be ignored by Grappler.
    rewrites = rewriter_config_pb2.RewriterConfig()
    rewrites.implementation_selector = rewriter_config_pb2.RewriterConfig.ON
    rewrites.min_graph_nodes = -1
    graph_options = config_pb2.GraphOptions(
        rewrite_options=rewrites, build_cost_model=1)
    config_proto = config_pb2.ConfigProto(graph_options=graph_options)

    with context.graph_mode(), self.cached_session(
        config=config_proto, graph=ops.Graph(), use_gpu=True):

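      # Both functions declare the same `api_implements` value, so Grappler's
      # implementation selector may substitute one for the other based on the
      # preferred device.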
      @function.defun_with_attributes(
          attributes={
              'api_implements': 'random_boost',
              'api_preferred_device': 'CPU'
          })
      def cpu_boost(x):
        return math_ops.add(x, 2.0)

      @function.defun_with_attributes(
          attributes={
              'api_implements': 'random_boost',
              'api_preferred_device': 'GPU'
          })
      def gpu_boost(x):
        return math_ops.add(x, 4.0)

      x = constant_op.constant(1.0)

      function.register(cpu_boost, x)
      y = gpu_boost(x)
      y_value = self.evaluate(y)

      if test.is_gpu_available():
        self.assertEqual(y_value, 5.0)
      else:
        # Grappler falls back to the CPU impl even though the GPU function
        # was called.
        self.assertEqual(y_value, 3.0)

  @test_util.disable_tfrt('b/174712583: TFRT doesn\'t support behavior '
                          'equivalent to implementation_selector for function')
  def testSwapImplementationInEager(self):
    if not context.executing_eagerly():
      self.skipTest('eager only')

    # testSharedRendezvous sets the disable_meta_optimizer flag to True. If
    # that subtest runs before this one, leaving the flag set to True would
    # cause this subtest to fail. To avoid that scenario, explicitly set the
    # disable_meta_optimizer flag to False here.
    context.context().set_optimizer_experimental_options({
        'min_graph_nodes': -1,
        'implementation_selector': True,
        'disable_meta_optimizer': False
    })

    @function.defun_with_attributes(
        attributes={'api_implements': 'foo',
                    'api_preferred_device': 'CPU'})
    def on_cpu(x):
      return x + 2

    @function.defun_with_attributes(
        attributes={'api_implements': 'foo',
                    'api_preferred_device': 'GPU'})
    def on_gpu(x):
      return x + 4

    @function.defun
    def run_on_cpu(t):
      function.register(on_cpu, t)
      with ops.device('CPU:0'):
        return on_gpu(t)

    # Expect the on_cpu branch to run, regardless of whether a GPU is
    # available.
    self.assertEqual(run_on_cpu(constant_op.constant(1)).numpy(), 3)

  def testDefunFunctionSeparateGraphs(self):
    with context.graph_mode():

      @function.defun
      def add(x):
        return x + 5

      @function.defun
      def maybe_add(x, should_add):
        if should_add:
          return add(x)
        else:
          return x

      with ops.Graph().as_default():
        x = constant_op.constant(11)
        maybe_add(x, True)
        self.assertLen(total_function_cache(maybe_add), 1)
        self.assertLen(total_function_cache(add), 1)

        maybe_add(x, False)
        self.assertLen(total_function_cache(maybe_add), 2)
        self.assertLen(total_function_cache(add), 1)

      with ops.Graph().as_default():
        x = constant_op.constant(11)
        maybe_add(x, True)
        self.assertLen(total_function_cache(maybe_add), 3)
        self.assertLen(total_function_cache(add), 2)

  def testCacheKeyOverlappingShapes(self):
    @function.defun
    def defined(t):
      return t

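    # Shapes [12, 1] and [1, 21] could collide if their dimensions were
    # naively concatenated in the cache key; verify that they trace
    # separately.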
3004    defined(array_ops.zeros([12, 1]))
3005    self.assertLen(total_function_cache(defined), 1)
3006
3007    defined(array_ops.zeros([1, 21]))
3008    self.assertLen(total_function_cache(defined), 2)
3009
3010  def testCacheKeyNestedLists(self):
3011    @function.defun
3012    def defined(l):
3013      return l
3014
    a = constant_op.constant(1.)
    b = constant_op.constant(2.)
    c = constant_op.constant(3.)
    defined([[a], b, c])
    self.assertLen(total_function_cache(defined), 1)

    defined([[a, b], c])
    self.assertLen(total_function_cache(defined), 2)

  def testCacheKeyAttrsClass(self):
    if attr is None:
      self.skipTest('attr module is unavailable.')

    @attr.s
    class TestClass(object):
      a = attr.ib()
      b = attr.ib()

    @function.defun
    def defined(l):
      return l

    defined(
        TestClass(
            constant_op.constant(1.),
            [constant_op.constant(2.),
             constant_op.constant(3.)]))
    self.assertLen(total_function_cache(defined), 1)
    defined(
        TestClass(
            constant_op.constant(1.),
            [constant_op.constant(2.),
             constant_op.constant(3.)]))
    self.assertLen(total_function_cache(defined), 1)

    defined(
        TestClass([constant_op.constant(1.),
                   constant_op.constant(2.)], constant_op.constant(3.)))
    self.assertLen(total_function_cache(defined), 2)

  def testCacheKeyVariables(self):
    @function.defun
    def defined(a, b, c):
      return a + b + c

    x = resource_variable_ops.ResourceVariable(0.0)
    y = resource_variable_ops.ResourceVariable(0.0)
    z = resource_variable_ops.ResourceVariable(0.0)

    # If tensor equality is not enabled, we always get a cache miss if the
    # function is called with different variables. With equality enabled we
    # should only get a miss if the aliasing changed.
    defined(x, y, z)
    self.assertLen(total_function_cache(defined), 1)
    defined(x, y, z)
    self.assertLen(total_function_cache(defined), 1)

    # Re-arranging the arguments causes a cache miss.
    defined(z, y, x)
    self.assertLen(total_function_cache(defined), 2)
    defined(z, y, x)
    self.assertLen(total_function_cache(defined), 2)

    # Aliasing causes a cache miss.
    defined(x, x, z)
    self.assertLen(total_function_cache(defined), 3)
    defined(x, x, z)
    self.assertLen(total_function_cache(defined), 3)

    # Re-arranging the arguments causes a cache miss.
    defined(y, y, z)
    self.assertLen(total_function_cache(defined), 4)
    defined(y, y, z)
    self.assertLen(total_function_cache(defined), 4)

    # Different alias positions cause a cache miss.
    defined(z, y, y)
    self.assertLen(total_function_cache(defined), 5)
    defined(z, y, y)
    self.assertLen(total_function_cache(defined), 5)

    x_copy = copy.deepcopy(x)

    # A deep copy causes a cache miss.
    defined(x_copy, y, z)
    self.assertLen(total_function_cache(defined), 6)
    defined(x_copy, y, z)
    self.assertLen(total_function_cache(defined), 6)

  def testVariableRetracing(self):
    v1 = variables.Variable(1.)
    v2 = variables.Variable(1.)
    v3 = copy.deepcopy(variables.Variable(1.))

    var_dict = {id(v1): constant_op.constant(1),
                id(v2): constant_op.constant(2),
                id(v3): constant_op.constant(3)}

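    # id(v) is evaluated while tracing, so the lookup below relies on each
    # distinct variable (including the deep copy) triggering its own trace.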
    @function.defun
    def lookup_tensor(v):
      return var_dict[id(v)]

    self.assertEqual(1, lookup_tensor(v1).numpy())
    self.assertEqual(2, lookup_tensor(v2).numpy())
    self.assertEqual(3, lookup_tensor(v3).numpy())

  def testDecoratedMethodInspect(self):

    class DefunnedMiniModel(object):

      @function.defun
      def call(self, inputs, training=True):
        pass

    m = DefunnedMiniModel()
    fullargspec = tf_inspect.getfullargspec(m.call)
    self.assertIn('training', fullargspec.args)

  def testFunctionModifiesInputList(self):
    # Tests `list` methods that do in-place modification, except `list.sort`,
    # which cannot even be "defunned" in the first place.

    def get_list():
      return [constant_op.constant(0.), constant_op.constant(1.)]

    expected_msg = '.*() should not modify'

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def append(l):
        l.append(constant_op.constant(0.))

      append(get_list())

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def extend(l):
        l.extend([constant_op.constant(0.)])

      extend(get_list())

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def insert(l):
        l.insert(0, constant_op.constant(0.))

      insert(get_list())

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def pop(l):
        l.pop()

      pop(get_list())

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def reverse(l):
        l.reverse()

      reverse(get_list())

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def remove(l):
        l.remove(l[0])

      remove(get_list())

    # `list.clear` exists in Py3 but not in Py2.
    if sys.version.startswith('3'):

      with self.assertRaisesRegex(ValueError, expected_msg):

        @def_function.function
        def clear(l):
          l.clear()

        clear(get_list())

    # One last test for keyword arguments
    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def kwdappend(**kwargs):
        l = kwargs['l']
        l.append(constant_op.constant(0.))

      kwdappend(l=get_list())

  def testFunctionModifiesInputDict(self):

    def get_dict():
      return {'t1': constant_op.constant(0.), 't2': constant_op.constant(1.)}

    expected_msg = '.* should not modify'

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def clear(m):
        m.clear()

      clear(get_dict())

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def pop(m):
        m.pop('t1')

      pop(get_dict())

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def popitem(m):
        m.popitem()

      popitem(get_dict())

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def update(m):
        m.update({'t1': constant_op.constant(3.)})

      update(get_dict())

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def setdefault(m):
        m.setdefault('t3', constant_op.constant(3.))

      setdefault(get_dict())

  def testFunctionModifiesInputNest(self):
    with self.assertRaisesRegex(ValueError, 'modify.* should not modify'):

      @def_function.function
      def modify(n):
        n[0]['t1'].append(constant_op.constant(1.))

      nested_input = [{
          't1': [constant_op.constant(0.),
                 constant_op.constant(1.)],
      },
                      constant_op.constant(2.)]

      modify(nested_input)

    with self.assertRaisesRegex(ValueError,
                                'modify_same_flat.* should not modify'):

      # The flat list doesn't change whereas the true structure changes
      @def_function.function
      def modify_same_flat(n):
        n[0].append(n[1].pop(0))

      nested_input = [[constant_op.constant(0.)],
                      [constant_op.constant(1.),
                       constant_op.constant(2.)]]

      modify_same_flat(nested_input)

  @test_util.disable_tfrt('b/173429686')
  def testExecutorType(self):
    @function.defun
    def add_five(x):
      return x + 5

    self.assertEqual(
        5,
        add_five(constant_op.constant(0, dtype=dtypes.int32)).numpy())

    with self.assertRaisesRegex(errors.NotFoundError, 'NON_EXISTENT_EXECUTOR'):
      with context.function_executor_type('NON_EXISTENT_EXECUTOR'):
        add_five(constant_op.constant(0, dtype=dtypes.int32))

    for executor_type in ('', 'DEFAULT', None):
      with context.function_executor_type(executor_type):
        self.assertAllEqual(
            5,
            add_five(constant_op.constant(0, dtype=dtypes.int32)).numpy())

  @test_util.assert_no_garbage_created
  def testReferenceCycles(self):

    fn = function.defun(lambda x: 2. * x)

    fn(constant_op.constant(4.0))
    weak_fn = weakref.ref(fn)
    del fn
    # Tests that the weak reference we made to the function is now dead, which
    # means the object has been deleted. This should be true as long as the
    # function itself is not involved in a reference cycle.
    self.assertIs(None, weak_fn())

  def testFunctionStackInErrorMessage(self):
    if context.executing_eagerly():
      # TODO(b/122736651): Remove this skipTest once fixed.
      self.skipTest('Error interpolation is not working when function is '
                    'invoked without PartitionedCallOp.')

    @def_function.function()
    def fn3(x):
      return x + 2

    @def_function.function()
    def fn2(x):
      check_ops.assert_equal(fn3(x), 3)
      return 2

    @def_function.function()
    def fn(x):
      return fn2(x)

    with self.assertRaises(errors.InvalidArgumentError) as cm:
      fn(2)
    e = cm.exception
    self.assertIn('fn -> fn2', e.message)
    self.assertIn('node assert_equal/Assert/Assert (defined at', e.message)
    self.assertNotIn('fn3', e.message)

  @test_util.run_gpu_only
  def testFunctionIsNotPinned(self):
    """Tests that functions aren't pinned to the CPU by the eager runtime."""
    seed1, seed2 = 79, 25
    shape = constant_op.constant([4, 7])
    dtype = dtypes.float32

    @def_function.function
    def func():
      with ops.device('GPU:0'):
        return gen_random_ops.random_standard_normal(
            shape, dtype=dtype, seed=seed1, seed2=seed2)

    with ops.device('GPU:0'):
      x = func()
      self.assertRegex(x.device, 'GPU')

  @test_util.run_in_graph_and_eager_modes
  def testShapeCaching(self):

    @function.defun
    def func(x):
      return array_ops.shape(x)

    @function.defun(
        input_signature=[tensor_spec.TensorSpec([None, None], dtypes.float32)])
    def calls_func(x):
      return func(x)

    self.assertAllEqual([1, 1], self.evaluate(func(array_ops.zeros([1, 1]))))
    self.assertAllEqual([2, 2], self.evaluate(func(array_ops.zeros([2, 2]))))
    self.assertAllEqual(
        [3, 3],
        self.evaluate(calls_func(array_ops.zeros([3, 3]))))

  def testLimitedRetracing(self):
    trace_count = [0]
    @function.defun
    def func(x):
      trace_count[0] += 1
      return x

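    # 50 iterations x 6 calls = 300 invocations. After a few shape-specific
    # traces, tf.function is expected to relax to a generic-shape trace, so
    # the total trace count stays far below the number of calls.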
    for _ in range(50):
      func(constant_op.constant(3.))
      func(constant_op.constant(4.))
      func(constant_op.constant([[1., 2.]]))
      func(constant_op.constant([[]]))
      func(constant_op.constant([[3., 4.], [5., 6.]]))
      func(constant_op.constant([[3., 4.], [5., 6.], [7., 8.]]))
    # Tracing more than twice per input doesn't make sense.
    self.assertLess(trace_count[0], 13)

  def testLimitedRetracingWithCompositeTensors(self):
    trace_count = [0]

    @def_function.function
    def f(x):
      trace_count[0] += 1
      return x

    for i in range(10):
      f(ragged_factory_ops.constant([[1, 2], [i]]))
      f(ragged_factory_ops.constant([[1, 2], [], [3, 4, 5]]))
      f(ragged_factory_ops.constant([[[1, 2], [3]], [[4, 5, 6]]]))
      self.assertEqual(trace_count[0], 3)

  def test_concrete_function_shape_mismatch(self):

    @def_function.function
    def f(argument_name):
      return argument_name + 1.

    f_concrete = f.get_concrete_function(constant_op.constant([1.]))

    # Calling a function from eager doesn't do any shape checking beyond what
    # the kernels themselves do while executing.
    self.assertAllEqual(
        [2., 3.],
        f_concrete(constant_op.constant([1., 2.])).numpy())

    @def_function.function
    def g():
      f_concrete(constant_op.constant([1., 2.]))

    with self.assertRaisesRegex(ValueError, 'argument_name'):
      g()

  @test_util.run_in_graph_and_eager_modes
  def test_shape_inference_with_symbolic_shapes(self):

    @def_function.function
    def _uses_symbolic_shapes(w, x, y):
      x = array_ops.identity(x, name='name_collision')
      x = array_ops.transpose(x, [1, 0, 2])
      x_batch = array_ops.shape(x)[0]
      y_batch = array_ops.shape(y)[0]
      y *= w
      n = y_batch // x_batch
      return array_ops.reshape(y, [n, x_batch, -1])

    conc = _uses_symbolic_shapes.get_concrete_function(
        tensor_spec.TensorSpec(None, dtypes.float32),
        tensor_spec.TensorSpec(None, dtypes.float32),
        tensor_spec.TensorSpec(None, dtypes.float32))

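    # In the calls below x is transposed to shape [4, 5, 2] / [4, 5, 3], so
    # x_batch is 4 in both cases; n = y_batch // x_batch gives 20 // 4 = 5
    # and 40 // 4 = 10, hence output shapes [5, 4, 2] and [10, 4, 3].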
    @def_function.function
    def _call_concrete():
      c = constant_op.constant(1.)
      array_ops.identity(c, name='name_collision')
      output1 = conc(array_ops.ones([2]),
                     array_ops.ones([5, 4, 2]),
                     array_ops.ones([20, 2]))
      self.assertEqual([5, 4, 2], output1.shape)
      output2 = conc(array_ops.ones([3]),
                     array_ops.ones([5, 4, 3]),
                     array_ops.ones([40, 3]))
      self.assertEqual([10, 4, 3], output2.shape)
      return output1, output2

    output1, output2 = _call_concrete()
    self.assertEqual((5, 4, 2), self.evaluate(output1).shape)
    self.assertEqual((10, 4, 3), self.evaluate(output2).shape)

  def testAutoGraphContext(self):

    @def_function.function
    def test_fn():
      self.assertEqual(
          ag_ctx.control_status_ctx().status, ag_ctx.Status.ENABLED)

    prev_status = ag_ctx.control_status_ctx().status
    test_fn()
    self.assertEqual(ag_ctx.control_status_ctx().status, prev_status)

  @test_util.disable_tfrt('b/170435618')
  def testCancelBeforeFunctionExecution(self):
    if not context.executing_eagerly():
      self.skipTest('eager only')

    q = data_flow_ops.FIFOQueue(1, dtypes.int32)

    @def_function.function
    def f():
      return q.dequeue()

    c_mgr = cancellation.CancellationManager()
    cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function())

    c_mgr.start_cancel()
    with self.assertRaises(errors.CancelledError):
      cancelable_func()

  @test_util.disable_tfrt('b/170435618')
  def testCancelBlockedFunctionExecution(self):
    if not context.executing_eagerly():
      self.skipTest('eager only')

    q = data_flow_ops.FIFOQueue(1, dtypes.int32)

    @def_function.function
    def f():
      return q.dequeue()

    c_mgr = cancellation.CancellationManager()
    cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function())

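    # The queue is empty, so the dequeue blocks. start_cancel() issued from
    # another thread is expected to abort the pending op with CancelledError.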
    def cancel_thread():
      time.sleep(0.5)
      c_mgr.start_cancel()

    t = self.checkedThread(cancel_thread)
    t.start()
    with self.assertRaises(errors.CancelledError):
      cancelable_func()
    t.join()

  @test_util.disable_tfrt('b/170435618')
  def testCancelAfterFunctionExecution(self):
    if not context.executing_eagerly():
      self.skipTest('eager only')

    q = data_flow_ops.FIFOQueue(1, dtypes.int32)
    q.enqueue(37)

    @def_function.function
    def f():
      return q.dequeue()

    c_mgr = cancellation.CancellationManager()
    cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function())

    self.assertAllEqual(37, cancelable_func().numpy())

    # Cancellation after the function executes is a no-op.
    c_mgr.start_cancel()

  def testAddFunctionCallback(self):
    functions = []
    def function_callback(f, name, graph, inputs, outputs):
      del name, graph, inputs, outputs
      functions.append(f)

    @def_function.function
    def plus_one(x):
      return x + 1

    try:
      function.add_function_callback(function_callback)
      x_float32 = numpy.array(3.0, dtype=numpy.float32)
      self.assertAllClose(plus_one(x_float32), 4.0)
      self.assertLen(functions, 1)
      # Function is already created. Executing it again should not invoke the
      # function callback.
      self.assertAllClose(plus_one(x_float32), 4.0)
      self.assertLen(functions, 1)
      # Signature change leads to a new Function being built.
      x_float64 = numpy.array(3.0, dtype=numpy.float64)
      self.assertAllClose(plus_one(x_float64), 4.0)
      self.assertLen(functions, 2)
    finally:
      function.clear_function_callbacks()

  def testFunctionCallbackAddOps(self):
    file_name = os.path.join(self.get_temp_dir(), 'test')

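    # Function callbacks fire when a new FuncGraph is traced; mutating the
    # graph here (adding a control dependency on a print-to-file op) should
    # be reflected in the function that ultimately executes.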
    def function_callback(f, name, graph, inputs, outputs):
      del f, name, inputs

      with graph.as_default():
        printer = logging_ops.print_v2(
            'hello',
            output_stream='file://' + file_name
        )
        outputs[0].op._add_control_input(printer)

    @def_function.function
    def plus_one(x):
      return x + 1

    self.addCleanup(function.clear_function_callbacks)
    function.add_function_callback(function_callback)
    x_float32 = numpy.array(3.0, dtype=numpy.float32)

    self.assertAllClose(plus_one(x_float32), 4.0)

    with open(file_name, 'r') as f:
      self.assertEqual(f.read().strip(), 'hello')

  def testRemoveFunctionCallback(self):
    functions_1 = []
    def function_callback_1(f, name, graph, inputs, outputs):
      del name, graph, inputs, outputs
      functions_1.append(f)

    functions_2 = []
    def function_callback_2(f, name, graph, inputs, outputs):
      del name, graph, inputs, outputs
      functions_2.append(f)

    @def_function.function
    def plus_one(x):
      return x + 1

    try:
      function.add_function_callback(function_callback_1)
      function.add_function_callback(function_callback_2)
      self.assertAllClose(plus_one(numpy.array(3.0, dtype=numpy.float32)), 4.0)
      self.assertLen(functions_1, 1)
      self.assertLen(functions_2, 1)
      function.remove_function_callback(function_callback_1)
      # The first callback should not be invoked after
      # remove_function_callback() is called.
      self.assertAllClose(plus_one(numpy.array(3.0, dtype=numpy.float64)), 4.0)
      self.assertLen(functions_1, 1)
      self.assertLen(functions_2, 2)
    finally:
      function.clear_function_callbacks()

  def testClearFunctionCallbacks(self):
    function.add_function_callback(lambda f: None)
    function.add_function_callback(lambda f: None)
    self.assertLen(function._function_callbacks, 2)
    function.clear_function_callbacks()
    self.assertEmpty(function._function_callbacks)  # pylint:disable=protected-access

  @test_util.run_in_graph_and_eager_modes
  def testConcreteFunctionWithNestedTensorInputs(self):

    @def_function.function
    def f(x, y):
      return (x['a'] + x['b'], y[0] + y[1])

    a = constant_op.constant(1000)
    b = constant_op.constant(200)
    c = constant_op.constant(30)
    d = {'a': a, 'b': b}
    e = (c, 4)

    # Test different argument signatures when constructing the concrete func.
    for cf in [
        f.get_concrete_function(d, e),
        f.get_concrete_function(d, y=e),
        f.get_concrete_function(y=e, x=d),
        f.get_concrete_function(_spec_for_value(d), _spec_for_value(e)),
        f.get_concrete_function(_spec_for_value(d), y=_spec_for_value(e)),
        f.get_concrete_function(y=_spec_for_value(e), x=_spec_for_value(d))
    ]:
      # Test different calling conventions when calling the concrete func.
      for output in [
          cf(d, e),  # structured signature
          cf(d, y=e),  # structured signature w/ kwarg
          cf(y=e, x=d),  # structured signature w/ 2 kwargs
          cf(a, b, c),  # flat signature
          cf(x=a, x_1=b, y=c)  # flat signature w/ kwargs
      ]:
        self.assertIsInstance(output, tuple)
        self.assertLen(output, 2)
        self.assertAllEqual(output[0], 1200)
        self.assertAllEqual(output[1], 34)

  @test_util.run_in_graph_and_eager_modes
  def testConcreteFunctionWithNestedNonTensorInputs(self):

    @def_function.function
    def f(x, y):
      return (x['a'] + x['b'], y[0] + y[1])

    a = {'a': constant_op.constant(1000), 'b': constant_op.constant(200)}
    b = (50, 3)

    for cf in [  # argument y is bound to non-Tensor value (50, 3).
        f.get_concrete_function(a, b),
        f.get_concrete_function(a, y=b),
        f.get_concrete_function(x=a, y=b)
    ]:
      for output in [cf(a), cf(x=a), cf(a, b), cf(x=a, y=b)]:
        self.assertAllEqual(output[0] + output[1], 1253)

  @test_util.run_in_graph_and_eager_modes
  def testConcreteFunctionWithNonTensorStringInputs(self):

    @def_function.function
    def f(x, y):
      return string_ops.string_join([x, y])

    a = constant_op.constant('a')
    b = 'b'

    cf = f.get_concrete_function(a, b)
    for output in [cf(a), cf(x=a), cf(a, b), cf(x=a, y=b)]:
      self.assertAllEqual(output, b'ab')

  @test_util.run_in_graph_and_eager_modes
  def testConcreteFunctionWithBoundNestedNonTensorInputs(self):

    @def_function.function
    def f(x, y):
      return (x['a'] + x['b'], y[0] + y[1])

    a = {'a': 3000, 'b': 200, 'c': 9000}
    b = (constant_op.constant(30), 4)

    for cf in [  # argument x is bound to non-tensor value `a`
        f.get_concrete_function(a, b),
        f.get_concrete_function(a, y=b),
        f.get_concrete_function(x=a, y=b)
    ]:
      for output in [cf(a, b), cf(a, y=b), cf(y=b), cf(x=a, y=b)]:
        self.assertAllEqual(output[0] + output[1], 3234)

  @test_util.run_in_graph_and_eager_modes
  def testConcreteFunctionWithAllBoundNestedNonTensorInputs(self):

    @def_function.function
    def f(x, y):
      return (x['a'] + x['b'], y[0] + y[1])

    a = {'a': 5000, 'b': 500}
    b = (50, 5)

    cf = f.get_concrete_function(a, b)
    for output in [cf(), cf(a), cf(y=b)]:
      self.assertAllEqual(output[0] + output[1], 5555)

  @test_util.run_in_graph_and_eager_modes
  def testConcreteFunctionMethodWithVarargs(self):
    float32_scalar = tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32)

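    # An input_signature can describe a *args parameter: each TensorSpec is
    # expected to bind one vararg position, so the concrete function below
    # accepts exactly two float32 scalars.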
    class MyModel(module.Module):

      @def_function.function(input_signature=[float32_scalar, float32_scalar])
      def add(self, *arg):
        return math_ops.add(*arg)

    m = MyModel()
    cf = m.add.get_concrete_function()
    cf(-12.0, 3.0)

  @test_util.run_in_graph_and_eager_modes
  def testConcreteFunctionStructuredSignatureKeywordOrder(self):
    # Check that keyword-only arguments are sorted appropriately, so that they
    # feed the right tensor into each input.
    @def_function.function
    def g(**kwargs):
      return string_ops.reduce_join(
          string_ops.reduce_join(
              ops.convert_to_tensor(sorted(kwargs.items())),
              axis=1,
              separator='='),
          axis=0,
          separator=', ')

    s = constant_op.constant('s')
    g.get_concrete_function(q=s, a=s, p=s, r=s, v=s, m=s, l=s)
    self.assertAllEqual(
        g(m='a', r='b', v='c', q='d', l='e', a='f', p='g'),
        b'a=f, l=e, m=a, p=g, q=d, r=b, v=c')
    self.assertAllEqual(
        g(q='d', a='f', p='g', r='b', v='c', m='a', l='e'),
        b'a=f, l=e, m=a, p=g, q=d, r=b, v=c')
    self.assertAllEqual(
        g(a='f', l='e', m='a', p='g', q='d', r='b', v='c'),
        b'a=f, l=e, m=a, p=g, q=d, r=b, v=c')

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters([
      dict(
          testcase_name='MissingArg',
          conc_args=lambda: (1, constant_op.constant(2)),
          call_args=lambda: (1,),
          error=r'func\(x, y\) missing required arguments: y'),
      dict(
          testcase_name='MissingVararg',
          conc_args=lambda: (1, 2, constant_op.constant(1.0)),
          call_args=lambda: (1, 2),
          error=r'func\(x, y, <arg3>\) missing required arguments: <arg3>'),
      dict(
          testcase_name='ExtraPositionalArg',
          conc_args=lambda: (1, 2),
          call_args=lambda: (1, 2, 3),
          error=r'func\(x, y\) takes 2 .* got 3'),
      dict(
          testcase_name='MissingKeywordOnlyArg',
          conc_args=lambda: (1, 2),
          conc_kwargs=lambda: {'c': constant_op.constant(1.0)},
          call_args=lambda: (1, 2),
          error=r'func\(x, y, \*, c\) missing required arguments: c'),
      dict(
          testcase_name='ExtraKeywordArg',
          conc_args=lambda: (1, 2),
          call_args=lambda: (1, 2),
          call_kwargs=lambda: {'c': constant_op.constant(1.0)},
          error=r'func\(x, y\) got unexpected keyword arguments: c'),
      dict(
          testcase_name='ExpectedRaggedGotNest',
          conc_args=lambda: (ragged_factory_ops.constant([[1, 2], [3]]),),
          call_args=lambda: ({
              'a': constant_op.constant([1, 2, 3])
          },),
          error=r'func\(x, y\): argument x had incorrect type\n'
          r'  expected: RaggedTensor\n'
          r"       got: {'a': (Eager)?Tensor}"),
      dict(
          testcase_name='WrongRaggedRank',
          conc_args=lambda: (ragged_factory_ops.constant([[1, 2], [3]]),),
          call_args=lambda: (ragged_factory_ops.constant([[[1]]]),),
          error=r'func\(x, y\): argument x had incorrect type\n'),
      dict(
          testcase_name='WrongRaggedDType',
          conc_args=lambda: (ragged_factory_ops.constant([[1]]),),
          call_args=lambda: (ragged_factory_ops.constant([[1.0]]),),
          error=r'func\(x, y\): argument x had incorrect type\n'),
      dict(
          testcase_name='ExpectedDictGotTensor',
          conc_args=lambda: ({
              'a': constant_op.constant(1),
              'b': constant_op.constant(1)
          },),
          call_args=lambda: (constant_op.constant(1),),
          error=r'func\(x, y\): argument x had incorrect type\n'),
      dict(
          testcase_name='ExpectedTupleGotTensor',
          conc_args=lambda:
          ((constant_op.constant(1), constant_op.constant(2)),),
          call_args=lambda: (constant_op.constant(1),),
          error=r'func\(x, y\): argument x had incorrect type\n'),
      dict(
          testcase_name='WrongDType',
          conc_args=lambda: (constant_op.constant(1),),
          call_args=lambda: (constant_op.constant(1.0),),
          exception=(ValueError, errors.InvalidArgumentError,
                     # on xla_gpu, we get InternalError instead.
                     errors.InternalError)),
      dict(
          testcase_name='ExpectedTensorGotInt',
          conc_args=lambda: (constant_op.constant(1),),
          call_args=lambda: (5,),
          error=r'func\(x, y\) expected a Tensor in x, but got int value 5'),
      dict(
          testcase_name='ExpectedIntGotDifferentInt',
          conc_args=lambda: (5,),
          call_args=lambda: (8,),
          error=r'ConcreteFunction func\(x, y\) was constructed with int '
          r'value 5 in x, but was called with int value 8'),
      dict(
          testcase_name='ExpectedIntGotTensor',
          conc_args=lambda: (5,),
          call_args=lambda: (constant_op.constant(6),),
          error=r'ConcreteFunction func\(x, y\) was constructed with int '
          'value 5 in x, but was called with (Eager)?Tensor value .*'),
      dict(
          testcase_name='TwoValuesForArgument',
          conc_args=lambda: (1, 2),
          call_args=lambda: (1, 2),
          call_kwargs=lambda: {'x': 3},
          error=r"func\(x, y\) got two values for 'x'"),
  ])
  # pylint: enable=g-long-lambda
  @test_util.run_in_graph_and_eager_modes
  def testConcreteFunctionStructuredSignatureError(self,
                                                   conc_args=(),
                                                   conc_kwargs=None,
                                                   call_args=(),
                                                   call_kwargs=None,
                                                   error='.*',
                                                   exception=TypeError):
    """Tests for errors in the structured signature.

    Args:
      conc_args: Positional arguments used for get_concrete_function.
      conc_kwargs: Keyword arguments used for get_concrete_function.
      call_args: Positional arguments used to call the function.
      call_kwargs: Keyword arguments used to call the function.
      error: Expected exception message.
      exception: Expected exception type.
    """
    conc_args = conc_args() if callable(conc_args) else conc_args
    conc_kwargs = conc_kwargs() if callable(conc_kwargs) else conc_kwargs or {}
    call_args = call_args() if callable(call_args) else call_args
    call_kwargs = call_kwargs() if callable(call_kwargs) else call_kwargs or {}
    self.assertIsInstance(conc_args, tuple)
    self.assertIsInstance(call_args, tuple)
    self.assertIsInstance(conc_kwargs, dict)
    self.assertIsInstance(call_kwargs, dict)

    @def_function.function
    def func(x, y=5, *varargs, **kwargs):  # pylint: disable=keyword-arg-before-vararg
      del y, varargs, kwargs
      return x

    conc = func.get_concrete_function(*conc_args, **conc_kwargs)
    with self.assertRaisesRegex(exception, error):
      self.evaluate(conc(*call_args, **call_kwargs))

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters([
      dict(
          testcase_name='MissingArg',
          conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
          call_args=lambda: (constant_op.constant(1),),
          error=r'func\(x, y\) missing required arguments: y'),
      dict(
          testcase_name='TwoValuesForArg',
          conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
          call_args=lambda: (constant_op.constant(1),),
          call_kwargs=lambda: {
              'x': constant_op.constant(1),
              'y': constant_op.constant(1)
          },
          error=r"func\(x, y\) got two values for 'x'"),
      dict(
          testcase_name='ExtraPositionalArg',
          conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
          call_args=lambda: (constant_op.constant(1), constant_op.constant(2),
                             constant_op.constant(3)),
          error=r'func\(x, y\) takes 2 .* got 3'),
      dict(
          testcase_name='UnexpectedKeywordArg',
          conc_args=lambda: (constant_op.constant(1),),
          call_args=lambda: (constant_op.constant(1),),
          call_kwargs=lambda: {'c': constant_op.constant(1)},
          error=r'func\(x\) got unexpected keyword arguments: c'),
      dict(
          testcase_name='MissingVararg',
          conc_args=lambda: (constant_op.constant(1), constant_op.constant(2),
                             constant_op.constant(3)),
          call_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
          error=r'func\(x, y, varargs_0\) missing required '
          r'arguments: varargs_0'),
      dict(
          testcase_name='MissingKeywordArg',
          conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
          conc_kwargs=lambda: {'c': constant_op.constant(1)},
          call_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
          error=r'func\(x, y, c\) missing required arguments: c'),
      dict(
          testcase_name='ExpectedTensorGotInt',
          conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
          call_args=lambda: (5, constant_op.constant(2)),
          error=r'func\(x, y\): expected argument #0\(zero-based\) to be '
          r'a Tensor; got int \(5\)'),
      dict(
          testcase_name='WrongDType',
          conc_args=lambda: (constant_op.constant(1),),
          call_args=lambda: (constant_op.constant(1.0),),
          exception=(ValueError, errors.InvalidArgumentError,
                     # on xla_gpu, we get InternalError instead.
                     errors.InternalError)),
      dict(
          testcase_name='MissingKeywordArgNestPiece',
          conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
          conc_kwargs=lambda: {'c': ragged_factory_ops.constant([[1]])},
          call_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
          call_kwargs=lambda: {'c': constant_op.constant(1)},
          error=r'func\(x, y, c, c_1\) missing required arguments: c_1'),
  ])
  # pylint: enable=g-long-lambda
  @test_util.run_in_graph_and_eager_modes
  def testConcreteFunctionFlatSignatureError(self,
                                             conc_args=(),
                                             conc_kwargs=None,
                                             call_args=(),
                                             call_kwargs=None,
                                             error='.*',
                                             exception=TypeError):
    """Tests for errors in the flat signature.

    Args:
      conc_args: Positional arguments used for get_concrete_function.
      conc_kwargs: Keyword arguments used for get_concrete_function.
      call_args: Positional arguments used to call the function.
      call_kwargs: Keyword arguments used to call the function.
      error: Expected exception message.
      exception: Expected exception type.
    """
    conc_args = conc_args() if callable(conc_args) else conc_args
    conc_kwargs = conc_kwargs() if callable(conc_kwargs) else conc_kwargs or {}
    call_args = call_args() if callable(call_args) else call_args
    call_kwargs = call_kwargs() if callable(call_kwargs) else call_kwargs or {}
    self.assertIsInstance(conc_args, tuple)
    self.assertIsInstance(call_args, tuple)
    self.assertIsInstance(conc_kwargs, dict)
    self.assertIsInstance(call_kwargs, dict)

    @def_function.function
    def func(x, y=5, *varargs, **kwargs):  # pylint: disable=keyword-arg-before-vararg
      del y, varargs, kwargs
      return x

    conc = func.get_concrete_function(*conc_args, **conc_kwargs)

    # Remove _function_spec to disable the structured signature.
    conc._set_function_spec(None)  # pylint: disable=protected-access

    with self.assertRaisesRegex(exception, error):
      self.evaluate(conc(*call_args, **call_kwargs))

  @test_util.run_in_graph_and_eager_modes
  def testConcreteFunctionAmbiguousSignature(self):
    # When both the flat & structured signatures are applicable, but they
    # give different results, we use the structured signature.  Note: we expect
    # this to be extremely rare.
    @def_function.function
    def f(x, y):
      return x * 10 + y

    conc = f.get_concrete_function(
        x=tensor_spec.TensorSpec(None, dtypes.int32, name='y'),
        y=tensor_spec.TensorSpec(None, dtypes.int32, name='x'))

    result = conc(x=constant_op.constant(5), y=constant_op.constant(6))
    self.assertAllEqual(result, 56)

  def testPrettyPrintedSignature(self):

    @def_function.function
    def func(x, kangaroo=None, octopus=7):
      del octopus, kangaroo
      return x

    scalar = constant_op.constant(5)
    vector = constant_op.constant([10, 10, 20])
    ragged = ragged_factory_ops.constant([[10, 20], [40]])

    c1 = func.get_concrete_function(scalar, vector)
    c1_summary = r'func\(x, kangaroo, octopus=7\)'
    c1_details = (r'  Args:\n'
                  r'    x: int32 Tensor, shape=\(\)\n'
                  r'    kangaroo: int32 Tensor, shape=\(3,\)\n'
                  r'  Returns:\n'
                  r'    int32 Tensor, shape=\(\)')
    self.assertRegex(c1.pretty_printed_signature(verbose=False), c1_summary)
    self.assertRegex(
        c1.pretty_printed_signature(verbose=True),
        c1_summary + '\n' + c1_details)
    self.assertRegex(
        repr(c1), r'<ConcreteFunction func\(x, kangaroo, octopus=7\) at .*>')
    self.assertRegex(
        str(c1), 'ConcreteFunction {}\n{}'.format(c1_summary, c1_details))

    c2 = func.get_concrete_function(scalar, ragged, 3)
    c2_summary = r'func\(x, kangaroo, octopus=3\)'
    c2_details = (r'  Args:\n'
                  r'    x: int32 Tensor, shape=\(\)\n'
                  r'    kangaroo: RaggedTensorSpec\(.*\)\n'
                  r'  Returns:\n'
                  r'    int32 Tensor, shape=\(\)')
    self.assertRegex(c2.pretty_printed_signature(),
                     c2_summary + '\n' + c2_details)

    c3 = func.get_concrete_function({'a': scalar, 'b': [ragged, ragged]})
    c3_summary = r'func\(x, kangaroo=None, octopus=7\)'
    c3_details = (r'  Args:\n'
                  r"    x: {'a': <1>, 'b': \[<2>, <3>\]}\n"
                  r'      <1>: int32 Tensor, shape=\(\)\n'
                  r'      <2>: RaggedTensorSpec\(.*\)\n'
                  r'      <3>: RaggedTensorSpec\(.*\)\n'
                  r'  Returns:\n'
                  r"    {'a': <1>, 'b': \[<2>, <3>\]}\n"
                  r'      <1>: int32 Tensor, shape=\(\)\n'
                  r'      <2>: RaggedTensorSpec\(.*\)\n'
                  r'      <3>: RaggedTensorSpec\(.*\)')

    # Python 3.5 does not guarantee deterministic iteration of dict contents,
    # which can lead to a mismatch in the pretty_printed_signature output for
    # "Args".
    if sys.version_info >= (3, 6):
      self.assertRegex(c3.pretty_printed_signature(),
                       c3_summary + '\n' + c3_details)

    # pylint: disable=keyword-arg-before-vararg
    @def_function.function
    def func2(x, y=3, *args, **kwargs):
      return (x, y, args, kwargs)

    c4 = func2.get_concrete_function(scalar, 4, 5, a=scalar)
    c4_summary = 'func2(x, y=4, <arg3>=5, *, a)'
    self.assertEqual(c4.pretty_printed_signature(verbose=False), c4_summary)

    c5 = func2.get_concrete_function(8, vector)
    c5_summary = 'func2(x=8, y)'
    self.assertEqual(c5.pretty_printed_signature(verbose=False), c5_summary)

  def testPrettyPrintedExplicitSignatureWithKeywordArg(self):  # b/159639913

    @def_function.function(input_signature=[tensor_spec.TensorSpec(None)])
    def fn(a, b=1):
      return a + b

    concrete_fn = fn.get_concrete_function()
    self.assertEqual(concrete_fn.pretty_printed_signature(False), 'fn(a)')
    self.assertEqual(
        concrete_fn.pretty_printed_signature(True), 'fn(a)\n'
        '  Args:\n'
        '    a: float32 Tensor, shape=<unknown>\n'
        '  Returns:\n'
        '    float32 Tensor, shape=<unknown>')

  def testPrettyPrintedSignatureLoadedNamedTuple(self):
    Point = collections.namedtuple('Point', ['x', 'y'])

    @def_function.function
    def fn(b, a):  # pylint: disable=unused-argument
      return 1.

    b = Point(
        x=constant_op.constant(1., dtype=dtypes.float32),
        y=constant_op.constant(1., dtype=dtypes.float32))
    a = Point(
        x=constant_op.constant(1, dtype=dtypes.int32),
        y=constant_op.constant(1, dtype=dtypes.int32))

    mod = module.Module()
    f = fn.get_concrete_function(b, a)
    save(mod, '/tmp/f', signatures=f)
    loaded = load('/tmp/f')

    printed = loaded.signatures['serving_default'].pretty_printed_signature()
    self.assertIn('a: int32 Tensor, shape=()', printed)
    self.assertIn('a_1: int32 Tensor, shape=()', printed)
    self.assertIn('b: float32 Tensor, shape=()', printed)
    self.assertIn('b_1: float32 Tensor, shape=()', printed)

  @test_util.run_in_graph_and_eager_modes
  def testIndexedSlicesAsGradientsForConcreteFunctions(self):

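    # array_ops.gather produces an IndexedSlices gradient, so the gradient
    # flowing back into summing_rnn is sparse; the point of this test is
    # that such a gradient passes through a concrete function without error.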
    @def_function.function
    def summing_rnn(inputs):
      return math_ops.reduce_sum(inputs, axis=1)

    @def_function.function
    def gradients(inputs):
      with backprop.GradientTape() as tape:
        tape.watch(inputs)
        hidden = summing_rnn(inputs)
        hidden = array_ops.gather(hidden, constant_op.constant([0]))
        loss = math_ops.reduce_mean(hidden)
      return tape.gradient(loss, inputs)

    gradients(constant_op.constant([[[1.0], [2.0]]]))  # No error is raised

  def testFollowTypeHintsTraceBasic(self):
    trace_count = [0]

    def func(x: ops.Tensor):
      trace_count[0] += 1
      return x

    enabled = def_function.function(func, experimental_follow_type_hints=True)
    disabled = def_function.function(func, experimental_follow_type_hints=False)

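    # With experimental_follow_type_hints=True, arguments annotated as
    # ops.Tensor are expected to be converted to tensors before the cache
    # key is formed, so the varying Python ints below reuse a single trace.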
    enabled(1)  # Initial call gets traced
    enabled(2)
    enabled(3)
    self.assertEqual(trace_count[0], 1)

    trace_count = [0]
    disabled(1)
    disabled(2)  # Retrace
    disabled(3)  # Retrace
    self.assertEqual(trace_count[0], 3)

  def testFollowTypeHintsTraceWithArgs(self):
    trace_count = [0]

    def func(*args: ops.Tensor):
      trace_count[0] += 1
      return args

    enabled = def_function.function(func, experimental_follow_type_hints=True)
    disabled = def_function.function(func, experimental_follow_type_hints=False)

    args = (
        'abc',
        'def',
    ) * 20
    args2 = (
        'def',
        'abc',
    ) * 20

    enabled(args)
    enabled(args2)
    self.assertEqual(trace_count[0], 1)

    trace_count = [0]
    disabled(args)
    disabled(args2)  # Retrace
    self.assertEqual(trace_count[0], 2)

  def testFollowTypeHintsTraceWithKwargs(self):
    trace_count = [0]

    def func(t: ops.Tensor, **kwargs: ops.Tensor):
      del kwargs
      trace_count[0] += 1
      return t

    enabled = def_function.function(func, experimental_follow_type_hints=True)
    disabled = def_function.function(func, experimental_follow_type_hints=False)

    enabled(1, x=1, y=1.0, z='one')
    enabled(2, x=2, y=2.0, z='two')
    self.assertEqual(trace_count[0], 1)

    trace_count = [0]
    disabled(1, x=1, y=1.0, z='one')
    disabled(2, x=2, y=2.0, z='two')  # Retrace
    self.assertEqual(trace_count[0], 2)

  def testFollowTypeHintsTraceWithMultipleInputTypes(self):
    trace_count = [0]

    def func(t: ops.Tensor, *args: ops.Tensor, **kwargs: ops.Tensor):
      del args, kwargs
      trace_count[0] += 1
      return t

    enabled = def_function.function(func, experimental_follow_type_hints=True)
    disabled = def_function.function(func, experimental_follow_type_hints=False)

    enabled(1, constant_op.constant(1), 'str', x=4.0)
    enabled(2, constant_op.constant(2), 'str2', x=5.0)
    self.assertEqual(trace_count[0], 1)

    trace_count = [0]
    disabled(1, constant_op.constant(1), 'str', x=4.0)
    disabled(2, constant_op.constant(2), 'str2', x=5.0)  # Retrace
    self.assertEqual(trace_count[0], 2)

  def testFollowTypeHintsTraceWithOnlyArgNamed(self):
    trace_count = [0]

    def func(t: ops.Tensor, i: int = 1, **kwargs):  # pylint: disable=bad-whitespace
      del i, kwargs
      trace_count[0] += 1
      return t

    enabled = def_function.function(func, experimental_follow_type_hints=True)

    enabled(1, 3, x=4.0, y='str')
    enabled(2, 4, x=4.0, y='str')  # Retrace
    self.assertEqual(trace_count[0], 2)

  def testFollowTypeHintsTraceWithNotAllNamed(self):
    trace_count = [0]

    def func(x, y: ops.Tensor, z: int):
      del y, z
      trace_count[0] += 1
      return x

    enabled = def_function.function(func, experimental_follow_type_hints=True)

    enabled(1, 2, 3)
    enabled(1, 20, 3)  # No retrace - change in ops.Tensor typed arg
    enabled(2, 2, 3)  # Retrace - change in untyped arg
    enabled(2, 2, 4)  # Retrace - change in int-typed arg (not followed)
    self.assertEqual(trace_count[0], 3)

  def testFollowTypeHintsTraceWithOnlyArgsNamed(self):
    trace_count = [0]

    def func(x, y, *args: ops.Tensor):
      del y, args
      trace_count[0] += 1
      return x

    enabled = def_function.function(func, experimental_follow_type_hints=True)

    enabled(1, 20, 3, 4, 5, 6)
    enabled(1, 20, 3, 4, 5, 60)  # No retrace - change in *args
    enabled(1, 30, 7, 8, 9, 10)  # Retrace - change in args
    self.assertEqual(trace_count[0], 2)

  def testFollowTypeHintsTraceWithOnlyKwargsNamed(self):
    trace_count = [0]

    def func(x, y, *args, **kwargs: ops.Tensor):
      del y, args, kwargs
      trace_count[0] += 1
      return x

    enabled = def_function.function(func, experimental_follow_type_hints=True)

    enabled(1, 2, 3, 4, 5, 6, a=1.0, b=2.0, c=3.0)
    enabled(
        1, 2, 3, 4, 5, 6, a=1.5, b=2.5,
        c=3.5)  # No retrace - change in **kwargs
    enabled(100, 2, 3, 4, 5, 6, a=1.0, b=2.0, c=3.0)  # Retrace - change in args
    enabled(
        1, 2, 3, 4, 5, 100, a=1.0, b=2.0, c=3.0)  # Retrace - change in *args
    self.assertEqual(trace_count[0], 3)

  def testFollowTypeHintsTraceWithArgsEquals(self):
    trace_count = [0]

    def func(
        x: ops.Tensor = 0,  # pylint:disable=bad-whitespace
        y: int = 1,  # pylint:disable=bad-whitespace
        **kwargs: ops.Tensor):
      del y, kwargs
      trace_count[0] += 1
      return x

    enabled = def_function.function(func, experimental_follow_type_hints=True)

    enabled(x=1, y=2, z=3)
    enabled(x=1, y=3, z=3)  # Retrace - change in args
    enabled(x=2, y=2, z=4)  # No retrace - change in args and **kwargs
    enabled(x=2, y=2, z=4, u=5)  # Retrace - change in **kwargs
    self.assertEqual(trace_count[0], 3)

  def testFollowTypeHintsWithTensorSpec(self):
    def func(x: ops.Tensor, y):
      return x + y
    v = def_function.function(experimental_follow_type_hints=True)(func)
    v = v.get_concrete_function(
        tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32), 3)
    x = v(constant_op.constant(1.), 3)
    self.assertEqual(x.numpy(), 4.)

  def testFollowTypeHintsTraceWithKwArgsAndNoVarKws(self):
    trace_count = [0]

    def func(a: int, b: ops.Tensor,
             x: ops.Tensor = 0, y: int = 1):
      del a, b, y
      trace_count[0] += 1
      return x

    enabled = def_function.function(func, experimental_follow_type_hints=True)

    enabled(0, 0, x=1, y=2)
    enabled(0, 0, x=2, y=2,)  # No retrace, since only tensor changed
    self.assertEqual(trace_count[0], 1)

    # Pass args as keyword args.
    enabled(a=0, b=0, x=2, y=2,)  # No retrace, args are the same
    self.assertEqual(trace_count[0], 1)

    enabled(a=1, b=0, x=2, y=2,)  # Retrace, since non-tensor arg changed
    self.assertEqual(trace_count[0], 2)

    enabled(a=1, b=2, x=2, y=2)  # No retrace, since only tensor changed
    self.assertEqual(trace_count[0], 2)

    trace_count[0] = 0
    disabled = def_function.function(func, experimental_follow_type_hints=False)
    disabled(0, 0, x=1, y=2)
    disabled(0, 0, x=2, y=2,)  # Retrace
    self.assertEqual(trace_count[0], 2)

  def testFollowTypeHintsTraceWithArgsEqualsTypedKwargs(self):
    trace_count = [0]

    def func(x, y, **kwargs: ops.Tensor):
      del y, kwargs
      trace_count[0] += 1
      return x

    enabled = def_function.function(func, experimental_follow_type_hints=True)

    enabled(x=1, y=2, z=3)
    enabled(x=1, y=3, z=3)  # Retrace
    enabled(x=1, y=2, z=4)  # No retrace
    enabled(x=2, y=2, z=4)  # Retrace
    enabled(x=2, y=2, z=4, u=5)  # Retrace
    self.assertEqual(trace_count[0], 4)

  def testFollowTypeHintsTraceWithArgsEqualsTypedArgs(self):
    trace_count = [0]

    def func(x: ops.Tensor, y: int, **kwargs):
      del y, kwargs
      trace_count[0] += 1
      return x

    enabled = def_function.function(func, experimental_follow_type_hints=True)

    enabled(x=1, y=2, z=3)
    enabled(x=1, y=3, z=3)  # Retrace
    enabled(x=1, y=2, z=4)  # Retrace
    enabled(x=2, y=2, z=3)  # No retrace
    enabled(x=2, y=2, z=4, u=5)  # Retrace
    self.assertEqual(trace_count[0], 4)

  def testFollowTypeHintsTraceWithKwOnlyArgsBasic(self):
    trace_count = [0]

    def func(*, a: ops.Tensor = None, b=1):  # pylint: disable=bad-whitespace
      del b
      trace_count[0] += 1
      return a

    enabled = def_function.function(func, experimental_follow_type_hints=True)

    enabled(a=1, b=2)
    enabled(a=2, b=2)  # No retrace
    enabled(a=1, b=1)  # Retrace
    self.assertEqual(trace_count[0], 2)

  def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedArg(self):
    trace_count = [0]

    def func(arg: ops.Tensor, *args, kwonly, **kwargs):
      del args, kwonly, kwargs
      trace_count[0] += 1
      return arg

    enabled = def_function.function(func, experimental_follow_type_hints=True)

    enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)
    enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)  # No retrace
    enabled(1000, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)  # No retrace
    enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7)  # Retrace
    enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7)  # Retrace
    enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70)  # Retrace
    self.assertEqual(trace_count[0], 4)

  def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedArgs(self):
    trace_count = [0]

    def func(arg, *args: ops.Tensor, kwonly, **kwargs):
      del args, kwonly, kwargs
      trace_count[0] += 1
      return arg

    enabled = def_function.function(func, experimental_follow_type_hints=True)

    enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)
    enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)  # Retrace
    enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7)  # No retrace
    enabled(1, 200, 300, 400, kwonly=5, kwarg1=6, kwarg2=7)  # No retrace
    enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7)  # Retrace
    enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70)  # Retrace
    self.assertEqual(trace_count[0], 4)

  def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedKwOnlyArg(self):
    trace_count = [0]

    def func(arg, *args, kwonly: ops.Tensor, **kwargs):
      del args, kwonly, kwargs
      trace_count[0] += 1
      return arg

    enabled = def_function.function(func, experimental_follow_type_hints=True)

    enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)
    enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)  # Retrace
    enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7)  # Retrace
    enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7)  # No retrace
    enabled(1, 2, 3, 4, kwonly=500, kwarg1=6, kwarg2=7)  # No retrace
    enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70)  # Retrace
    self.assertEqual(trace_count[0], 4)

  def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedKwargs(self):
    trace_count = [0]

    def func(arg, *args, kwonly, **kwargs: ops.Tensor):
      del args, kwonly, kwargs
      trace_count[0] += 1
      return arg

    enabled = def_function.function(func, experimental_follow_type_hints=True)

    enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)
    enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)  # Retrace
    enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7)  # Retrace
    enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7)  # Retrace
    enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70)  # No retrace
    enabled(1, 2, 3, 4, kwonly=5, kwarg1=600, kwarg2=700)  # No retrace
    self.assertEqual(trace_count[0], 4)

  def testWithExtraWrapper(self):

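    # dummy_tf_decorator (defined elsewhere in this test module) wraps the
    # method in an extra tf_decorator layer; tf.function is expected to see
    # through it to the underlying argspec and default values.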
    class Foo(module.Module):

      def __init__(self):
        super().__init__()
        self.var = None

      @def_function.function
      @dummy_tf_decorator
      def add(self, x, y, z=1):
        if self.var is None:
          return x + y + z

    foo = Foo()
    self.assertEqual(foo.add(2, 3).numpy(), 6)

  @parameterized.parameters([(def_function.function, dummy_tf_decorator),
                             (dummy_tf_decorator, def_function.function),
                             (def_function.function, def_function.function)])
  def testWithExtraWrapperRedundantArgs(self, decorator1, decorator2):

    class Foo(module.Module):

      def __init__(self):
        super().__init__()
        self.var = None

      @decorator1
      @decorator2
      def add1(self, x, y):
        if self.var is None:
          return x + y

    foo = Foo()
    with self.assertRaisesRegex(TypeError, 'got two values'):
      foo.add1(2, x=3)  # pylint: disable=redundant-keyword-arg,no-value-for-parameter

  def testWithExtraWrapperMissingArgs(self):

    class Foo(module.Module):

      def __init__(self):
        super().__init__()
        self.var = None

      @def_function.function
      @dummy_tf_decorator
      def add1(self, x, y):
        if self.var is None:
          return x + y

      @def_function.function
      @dummy_tf_decorator
      def add2(self, x, y):
        if self.var is None:
          return x + y

      @def_function.function
      @def_function.function
      def add3(self, x, y):
        if self.var is None:
          return x + y

    foo = Foo()
    with self.assertRaisesRegex(
        TypeError, 'missing 1 required positional argument: \'y\''):
      foo.add1(2)  # pylint: disable=no-value-for-parameter

    with self.assertRaisesRegex(TypeError, 'missing 1 required argument: x'):
      foo.add1(y=2)  # pylint: disable=no-value-for-parameter

    with self.assertRaisesRegex(
        TypeError, 'missing 1 required positional argument: \'y\''):
      foo.add2(2)  # pylint: disable=no-value-for-parameter

    with self.assertRaisesRegex(TypeError, 'missing 1 required argument: x'):
      foo.add2(y=2)  # pylint: disable=no-value-for-parameter

    with self.assertRaisesRegex(
        TypeError, 'missing 1 required positional argument: \'y\''):
      foo.add3(2)  # pylint: disable=no-value-for-parameter

    with self.assertRaisesRegex(TypeError, 'missing 1 required argument: x'):
      foo.add3(y=2)  # pylint: disable=no-value-for-parameter

  def testMissingArgsTfFunctionedMethod(self):

    class A(object):

      def func(self, position_arg1, position_arg2):
        return position_arg1, position_arg2

      @def_function.function
      def decorated_method(self, position_arg1, position_arg2):
        return position_arg1, position_arg2

    a_instance = A()
    tf_method_pos = def_function.function(a_instance.func)
    with self.assertRaisesRegex(
        TypeError, '.* missing 1 required argument: position_arg1'):
      tf_method_pos(position_arg2='foo')

    # tf.function-decorated instance methods need to be tested because of
    # the __get__ method implementation.
    tf_func_decorated_method = def_function.function(
        a_instance.decorated_method)
    tf_func_decorated_method(position_arg1='foo', position_arg2='bar')
    with self.assertRaisesRegex(
        TypeError, '.* missing 1 required argument: position_arg1'):
      tf_func_decorated_method(position_arg2='bar')

  def testMissingArgsTfFunctionedObject(self):

    class A(object):

      def __call__(self, position_arg1, position_arg2):
        return position_arg1, position_arg2

    a_instance = A()

    # A tf.function-decorated callable object needs to be tested because of
    # the special inspect results.
    tf_func_obj = def_function.function(a_instance)
    tf_func_obj(position_arg1=1, position_arg2=2)
    with self.assertRaisesRegex(
        TypeError, '.* missing 1 required argument: position_arg1'):
      tf_func_obj(position_arg2='bar')
4607
4608  def testMissingArgsTfFunctionedFunctions(self):
4609
4610    def func_pos(position_arg1, position_arg2):
4611      return position_arg1, position_arg2
4612
4613    def func_with_default(position_arg, named_arg=None):
4614      return position_arg, named_arg
4615
4616    def func_pos_3args(position_arg1, position_arg2, position_arg3):
4617      return position_arg1, position_arg2, position_arg3
4618
4619    tf_func_pos = def_function.function(func_pos)
4620    with self.assertRaisesRegex(
4621        TypeError, '.* missing 1 required argument: position_arg1'):
4622      tf_func_pos(position_arg2='foo')
4623
4624    tf_func_with_default = def_function.function(func_with_default)
4625    tf_func_with_default(position_arg='bar')
4626    with self.assertRaisesRegex(TypeError,
4627                                '.* missing 1 required argument: position_arg'):
4628      tf_func_with_default(named_arg='foo')
4629
4630    tf_func_pos_3args = def_function.function(func_pos_3args)
4631    with self.assertRaisesRegex(
4632        TypeError,
4633        '.* missing required arguments: position_arg1, position_arg3'):
4634      tf_func_pos_3args(position_arg2='foo')
4635
4636  def testShapeInferencePropagateConstNestedStack(self):
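    # Checks that the constant `s=5` passed by `g` propagates through the
    # nested call's `stack`, giving `y` the fully static shape [3, 5].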

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec((None, None), dtype=dtypes.int32),
        tensor_spec.TensorSpec((), dtype=dtypes.int32),
    ])
    def f(x, s):
      old_shape = array_ops.shape(x)
      new_shape = array_ops.stack([old_shape[0], s], axis=0)
      y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
      return y

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=(3, 6), dtype=dtypes.int32)
    ])
    def g(x):
      y = f(x, s=5)
      assert y.shape.as_list() == [3, 5], y.shape.as_list()
      return y

    self.assertAllEqual(
        g(array_ops.zeros([3, 6], dtype=dtypes.int32)), array_ops.ones([3, 5]))

  def testShapeInferencePropagateConstNestedUnstackStack(self):

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec((None, None), dtype=dtypes.int32),
        tensor_spec.TensorSpec((), dtype=dtypes.int32),
    ])
    def f(x, s):
      s0, _ = array_ops.unstack(array_ops.shape(x), axis=0)
      new_shape = array_ops.stack([s0, s], axis=0)
      y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
      return y

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=(3, 6), dtype=dtypes.int32)
    ])
    def g(x):
      y = f(x, s=5)
      assert y.shape.as_list() == [3, 5], y.shape.as_list()
      return y

    self.assertAllEqual(
        g(array_ops.zeros([3, 6], dtype=dtypes.int32)), array_ops.ones([3, 5]))

  def testShapeInferencePropagateConstNestedConcat(self):

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec((), dtype=dtypes.int32),
        tensor_spec.TensorSpec((), dtype=dtypes.int32),
        tensor_spec.TensorSpec((), dtype=dtypes.int32),
    ])
    def f(d1, d2, d3):
      new_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1)
      y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
      return y

    @def_function.function()
    def g():
      y = f(1, 2, 3)
      assert y.shape.as_list() == [1, 2, 3], y.shape.as_list()
      return y

    self.assertAllEqual(g(), array_ops.ones([1, 2, 3]))

  def testShapeInferencePropagateConstDoubleNested(self):
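    # Like the concat case above, but `f` is wrapped in a second layer of
    # def_function.function; constant shape propagation should survive both
    # levels of nesting.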

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec((), dtype=dtypes.int32),
        tensor_spec.TensorSpec((), dtype=dtypes.int32),
        tensor_spec.TensorSpec((), dtype=dtypes.int32),
    ])
    def f(d1, d2, d3):
      new_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1)
      y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
      return y

    @def_function.function()
    def g():
      y = def_function.function(f)(1, 2, 3)
      assert y.shape.as_list() == [1, 2, 3], y.shape.as_list()
      return y

    self.assertAllEqual(g(), array_ops.ones([1, 2, 3]))

  @test_util.run_v2_only
  def testControlDependencyAfterInline(self):
    v = variables.Variable(0.)

    @def_function.function
    def assign():
      return v.assign(1.)

    @def_function.function
    def assign_add():
      return v.assign_add(1.)

    @def_function.function
    def f():
      check_ops.assert_equal_v2(assign(), 1.)
      check_ops.assert_equal_v2(assign_add(), 2.)

    # We don't have a way to inspect the inlined graph in Python, so we run it
    # multiple times to have more confidence the dependency is correct.
    for _ in range(30):
      f()

  @test_util.run_v2_only
  def testReadInFuncWriteOutside(self):
    # Run many times since we are testing for a potential race condition.
    for _ in range(30):
      # pylint: disable=cell-var-from-loop
      v = variables.Variable(1.)

      @def_function.function
      def add_one():
        return v + 1.

      @def_function.function
      def get_v_plus_one():
        v_plus_one = add_one()
        v.assign_add(2.0)
        return v_plus_one

      self.assertAllEqual(get_v_plus_one(), 2.0)


class MultiDeviceTest(test.TestCase, parameterized.TestCase):

  @test_util.run_gpu_only
  def testMultiDeviceOutput(self):
    """Tests that functions can produce outputs on multiple devices."""
    @function.defun
    def func(a, b, transpose_a):
      with ops.device('/device:CPU:0'):
        m1 = math_ops.matmul(a, b, transpose_a=transpose_a)
      with ops.device('/device:GPU:0'):
        m2 = math_ops.matmul(a, b, transpose_a=transpose_a)
      return m1, m2

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    m1, m2 = func(t, t, transpose_a=True)
    self.assertAllEqual(m1.numpy(), [[10, 14], [14, 20]])
    self.assertRegex(m1.backing_device, 'CPU')
    self.assertAllEqual(m2.numpy(), [[10, 14], [14, 20]])
    self.assertRegex(m2.backing_device, 'GPU')

  @test_util.run_gpu_only
  def testEmptyBody(self):
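    # The function body creates no new ops; it only swaps its inputs, so each
    # output should keep the backing device of the corresponding input.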
    @function.defun
    def func(a, b):
      return b, a

    with ops.device('/device:CPU:0'):
      a = array_ops.identity(3.0)
    with ops.device('/device:GPU:0'):
      b = array_ops.identity(5.0)

    m1, m2 = func(a, b)
    self.assertAllEqual(m1.numpy(), 5.0)
    self.assertRegex(m1.backing_device, 'GPU')
    self.assertAllEqual(m2.numpy(), 3.0)
    self.assertRegex(m2.backing_device, 'CPU')

  @test_util.run_gpu_only
  def testMultiDeviceInt32(self):
    """Tests that multi-device functions can take and output INT32s.

    When an INT32 device tensor is fed into a function, it is copied to CPU
    by the eager runtime. The function sees all INT32 inputs on CPU.

    We set allocator attribute 'on_host' for INT32 outputs. They can be
    partitioned into the GPU component function, but will be allocated on
    CPU nevertheless.

    FunctionLibraryRuntime now has experimental support for
    `ints_on_device`, which could be tried here in the future.

    """
    with ops.device('/device:CPU:0'):
      int_cpu = constant_op.constant(3, dtype=dtypes.int32)
      resource = resource_variable_ops.ResourceVariable(5, dtype=dtypes.int32)
    with ops.device('/device:GPU:0'):
      int_gpu = constant_op.constant(7, dtype=dtypes.int32)

    @function.defun
    def func(int_cpu, resource, int_gpu):
      with ops.device('/device:CPU:0'):
        m1 = int_cpu * resource + int_gpu
      with ops.device('/device:GPU:0'):
        # This computation will happen on GPU but m2 will be copied to CPU.
        m2 = int_gpu * resource + int_cpu + 1
      return m1, m2

    m1, m2 = func(int_cpu, resource, int_gpu)
    self.assertAllEqual(m1.numpy(), 22)
    self.assertRegex(m1.backing_device, 'CPU')
    self.assertAllEqual(m2.numpy(), 39)
    self.assertRegex(m2.backing_device, 'CPU')

    # Flip the arguments.
    m1, m2 = func(int_gpu, resource, int_cpu)
    self.assertAllEqual(m1.numpy(), 38)
    self.assertRegex(m1.backing_device, 'CPU')
    self.assertAllEqual(m2.numpy(), 23)
    self.assertRegex(m2.backing_device, 'CPU')

  @test_util.run_gpu_only
  def testMultiDeviceColocateWith(self):
    """Tests that function's outputs respect colocation constraints."""
    @function.defun
    def func(a, b):
      with ops.colocate_with(a):
        ra = 2 * a
      with ops.colocate_with(b):
        rb = 3 * b
      return ra, rb

    devices = ['/device:CPU:0', '/device:GPU:0']
    for dev1, dev2 in itertools.product(devices, devices):
      with ops.device(dev1):
        a = array_ops.identity(1.0)
      with ops.device(dev2):
        b = array_ops.identity(10.0)

      ra, rb = func(a, b)
      self.assertEqual(ra.numpy(), 2.0)
      self.assertRegex(ra.backing_device, dev1)
      self.assertEqual(rb.numpy(), 30.0)
      self.assertRegex(rb.backing_device, dev2)

  @test_util.run_gpu_only
  def testMultiDeviceResources(self):
    with ops.device('/device:CPU:0'):
      c1 = resource_variable_ops.ResourceVariable(2.0)
      c2 = resource_variable_ops.ResourceVariable(7.0)
    with ops.device('/device:GPU:0'):
      g1 = resource_variable_ops.ResourceVariable(3.0)
      g2 = resource_variable_ops.ResourceVariable(5.0)

    @function.defun
    def func(resource1, resource2):
      with ops.device('/device:CPU:0'):
        result1 = resource1 * g2
      with ops.device('/device:GPU:0'):
        result2 = resource2 * c2
      return result1, result2

    r1, r2 = func(c1, g1)
    self.assertEqual(r1.numpy(), 10.0)
    self.assertRegex(r1.backing_device, 'CPU')
    self.assertEqual(r2.numpy(), 21.0)
    self.assertRegex(r2.backing_device, 'GPU')

    # Call with flipped inputs. Check that we look at the resources' devices
    # and reinstantiate the function when the inputs' devices change.
    r1, r2 = func(g1, c1)
    self.assertEqual(r1.numpy(), 15.0)
    self.assertRegex(r1.backing_device, 'CPU')
    self.assertEqual(r2.numpy(), 14.0)
    self.assertRegex(r2.backing_device, 'GPU')

  @test_util.run_gpu_only
  def testOutputResources(self):
    with ops.device('/device:CPU:0'):
      c1 = resource_variable_ops.ResourceVariable(2.0)
    with ops.device('/device:GPU:0'):
      g1 = resource_variable_ops.ResourceVariable(3.0)

    @function.defun
    def func(resource1, resource2):
      with ops.device('/device:CPU:0'):
        result1 = resource1 * 5
      with ops.device('/device:GPU:0'):
        result2 = resource2 * 7
      return result1, resource1.handle, result2, resource2.handle

    r1, res1, r2, res2 = func(c1, g1)
    self.assertEqual(r1.numpy(), 10.0)
    self.assertRegex(r1.backing_device, 'CPU')
    self.assertEqual(r2.numpy(), 21.0)
    self.assertRegex(r2.backing_device, 'GPU')

    def check_handle(handle, expected_value):
      self.assertRegex(handle.backing_device, 'CPU')
      tensor = gen_resource_variable_ops.read_variable_op(
          handle, dtypes.float32)
      self.assertEqual(tensor.numpy(), expected_value)

    # Check that handles returned from functions are on CPU and an op using
    # the resource handle is correctly placed on the device backing the
    # resource.
    check_handle(res1, 2.0)
    check_handle(res2, 3.0)

    # Call with flipped inputs to make sure the function is reinstantiated
    # and the eager runtime does not mess up the device assignment for ops
    # consuming handles returned from defuns.
    r1, res1, r2, res2 = func(g1, c1)
    self.assertEqual(r1.numpy(), 15.0)
    self.assertRegex(r1.backing_device, 'CPU')
    self.assertEqual(r2.numpy(), 14.0)
    self.assertRegex(r2.backing_device, 'GPU')
    check_handle(res1, 3.0)
    check_handle(res2, 2.0)

  @test_util.run_gpu_only
  def testPassResourceThroughNestedFunctionCall(self):
    """Test passing GPU resource to noinline function call placed on CPU.

    PartitionedCallOp must not enforce any particular device assignment for the
    resource output. The inner function is marked `_nospecialize` so that
    Grappler does not prune the unused function output.
    """

    with ops.device('/device:GPU:0'):
      g1 = resource_variable_ops.ResourceVariable(3.0)

    @function.defun_with_attributes(attributes={
        '_noinline': True,
        '_nospecialize': True
    })
    def inner(resource1):
      return resource1 * 2, resource1.handle

    @function.defun
    def outer(resource1):
      with ops.device('/device:CPU:0'):
        r1, _ = inner(resource1)
      return r1

    r1 = outer(g1)

    self.assertEqual(r1.numpy(), 6.0)
    self.assertRegex(r1.backing_device, 'CPU')

  @test_util.run_gpu_only
  def testReturnResourceFromNestedFunctionCall(self):
    """Test returning GPU resource from noinline function call placed on CPU.

    When inferring output devices for the return value, do not set a device for
    returns of DT_RESOURCE data type based on the device assignment of the node
    that produced that resource. For example, a function call placed on CPU
    can return resources that live on GPU.
    """

    with ops.device('/device:GPU:0'):
      g1 = resource_variable_ops.ResourceVariable(3.0)

    @function.defun_with_attributes(attributes={
        '_noinline': True
    })
    def inner(resource1):
      resource1.assign_add(2.0)
      return resource1 * 2, resource1.handle

    @function.defun
    def outer(resource1):
      with ops.device('/device:CPU:0'):
        r1, res1 = inner(resource1)
      return r1, res1

    r1, res1 = outer(g1)

    self.assertEqual(r1.numpy(), 10.0)
    self.assertRegex(r1.backing_device, 'CPU')

    def check_handle(handle, expected_value):
      self.assertRegex(handle.backing_device, 'CPU')
      tensor = gen_resource_variable_ops.read_variable_op(
          handle, dtypes.float32)
      self.assertEqual(tensor.numpy(), expected_value)

    # Check that handles returned from functions are on CPU and an op using
    # the resource handle is correctly placed on the device backing the
    # resource.
    check_handle(res1, 5.0)

  @test_util.run_gpu_only
  def testComplexInputOutputDevicePattern(self):
    """Tests input/output mapping logic in partitioning."""
    with ops.device('/device:CPU:0'):
      rc0 = resource_variable_ops.ResourceVariable(2.0)
      rc1 = resource_variable_ops.ResourceVariable(3.0)
      cc0 = array_ops.identity(5.0)
      cc1 = array_ops.identity(7.0)
    with ops.device('/device:GPU:0'):
      rg0 = resource_variable_ops.ResourceVariable(11.0)
      rg1 = resource_variable_ops.ResourceVariable(13.0)
      cg0 = array_ops.identity(17.0)
      cg1 = array_ops.identity(19.0)

    # Make sure tensors are on expected devices.
    for tensor in [cc0, cc1]:
      self.assertRegex(tensor.backing_device, 'CPU:0')
    for tensor in [cg0, cg1]:
      self.assertRegex(tensor.backing_device, 'GPU:0')

    @function.defun
    def func(rc0, cc0, cg0, rc1, cg1, rg0, rg1, cc1):
      with ops.device('/device:CPU:0'):
        m1 = rc0 * cg0
      with ops.device('/device:GPU:0'):
        m2 = rg0 * cc0

      with ops.device('/device:CPU:0'):
        r1 = 1000.0 * m2 + rc1 * cg1
      with ops.device('/device:GPU:0'):
        r2 = 1000.0 * m1 + rg1 * cc1

      return r1, r2, m2, m1

    r1, r2, m2, m1 = func(rc0, cc0, cg0, rc1, cg1, rg0, rg1, cc1)
    self.assertRegex(m1.backing_device, 'CPU')
    self.assertRegex(r1.backing_device, 'CPU')
    self.assertRegex(m2.backing_device, 'GPU')
    self.assertRegex(r2.backing_device, 'GPU')
    self.assertEqual(m1.numpy(), 34.0)
    self.assertEqual(r1.numpy(), 55000.0 + 3.0 * 19.0)
    self.assertEqual(m2.numpy(), 55.0)
    self.assertEqual(r2.numpy(), 34000.0 + 13.0 * 7.0)

  @test_util.run_gpu_only
  def testArgumentPruning(self):
    """Tests functions taking unnecessary arguments."""
    with ops.device('/device:CPU:0'):
      c1 = constant_op.constant(5.0)
      c2 = constant_op.constant(7.0)

    with ops.device('/device:GPU:0'):
      g1 = constant_op.constant(11.0)
      g2 = constant_op.constant(13.0)
      g3 = constant_op.constant(17.0)

    @function.defun
    def func(g1, g2, c1, g3, c2):  # pylint: disable=unused-argument
      # arguments g1 and g2 are unused and can be pruned by grappler.
      return c1 * g3 * c2

    result = func(g1, g2, c1, g3, c2)
    self.assertEqual(result.numpy(), 5.0 * 7.0 * 17.0)

  def testNestedCallWatchedVariables(self):

    v = variables.Variable(4.)

    @def_function.function
    def f():
      return v ** 2.

    with backprop.GradientTape() as tape:
      f()

    self.assertEqual((v,), tape.watched_variables())

    @def_function.function
    def g():
      return f()

    with backprop.GradientTape() as tape:
      g()

    self.assertEqual((v,), tape.watched_variables())

    # f() can rely on the variable being read during its trace. g() checks that
    # variables from a function which knows about them are recorded on the
    # tape. h() tests that functions forward knowledge of variables to callers.

    @def_function.function
    def h():
      return g()

    with backprop.GradientTape() as tape:
      h()

    self.assertEqual((v,), tape.watched_variables())

  def testDeferredCapture(self):
    value = 1.0

    @def_function.function
    def lazy_capture(x):
      y = ops.get_default_graph().capture_call_time_value(
          lambda: value, tensor_spec.TensorSpec(None))
      return x + y

    self.assertAllEqual(lazy_capture(2.0), 3.0)
    # After changing the value of `value` the function call should return a
    # different result.
    value = 2.0
    self.assertAllEqual(lazy_capture(2.0), 4.0)

  def testDeferredCaptureWithKey(self):
    value0 = 1.0
    value1 = 2.0

    @def_function.function
    def lazy_capture(x):
      w = ops.get_default_graph().capture_call_time_value(
          lambda: value0, tensor_spec.TensorSpec(None), key=0)
      y = ops.get_default_graph().capture_call_time_value(
          lambda: value1, tensor_spec.TensorSpec(None), key=1)
      def bad_closure():
        raise ValueError('Should not run')
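      # `key=1` was already captured above, so the graph reuses that capture
      # and `bad_closure` is never invoked.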
      z = ops.get_default_graph().capture_call_time_value(
          bad_closure, tensor_spec.TensorSpec(None), key=1)
      return x + y + w + z

    self.assertAllEqual(lazy_capture(2.0), 7.0)
    value0 = 2.0
    value1 = 3.0
    self.assertAllEqual(lazy_capture(2.0), 10.0)

  def testDeferredCaptureTypeError(self):
    value = constant_op.constant(1.0)

    @def_function.function
    def lazy_capture(x):
      y = ops.get_default_graph().capture_call_time_value(
          lambda: value, tensor_spec.TensorSpec(()))
      return x + y

    self.assertAllEqual(lazy_capture(2.0), 3.0)

    # dtype mismatch
    value = constant_op.constant(1)
    with self.assertRaisesRegex(ValueError, 'Value .* to a tensor with dtype'):
      lazy_capture(2.0)

    # shape mismatch
    value = constant_op.constant([1.0])
    with self.assertRaisesRegex(ValueError, 'Value .* shape'):
      lazy_capture(2.0)

  def testDeferredCaptureReturnNestWithCompositeTensor(self):
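    # The deferred capture returns a nest containing an IndexedSlices, a
    # RaggedTensor and a SparseTensor; each component should round-trip
    # through the function call unchanged.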
    i_s = indexed_slices.IndexedSlices(
        constant_op.constant([1, 2]),
        constant_op.constant([0, 1], dtype=dtypes.int64),
        constant_op.constant([2]))
    r_t = ragged_factory_ops.constant([[[1, 2], [3]], [[4, 5, 6]]])
    s_t = sparse_tensor.SparseTensor(
        values=[1, 2, 3], indices=[[0], [8], [10]], dense_shape=[20])

    @def_function.function
    def lazy_capture():
      y = ops.get_default_graph().capture_call_time_value(
          lambda: {'i': i_s, 't': (r_t, s_t)},
          {'i': indexed_slices.IndexedSlicesSpec(
              dtype=dtypes.int32, dense_shape_dtype=dtypes.int32),
           't': (ragged_tensor.RaggedTensorSpec([2, None, None], dtypes.int32),
                 sparse_tensor.SparseTensorSpec([None], dtypes.int32))})
      return y['i'], y['t']

    i, (r, s) = lazy_capture()
    self.assertAllEqual(i_s.values, i.values)
    self.assertAllEqual(i_s.indices, i.indices)
    self.assertAllEqual(i_s.dense_shape, i.dense_shape)
    self.assertAllEqual(r_t, r)
    self.assertAllEqual(s_t.indices, s.indices)
    self.assertAllEqual(s_t.values, s.values)
    self.assertAllEqual(s_t.dense_shape, s.dense_shape)

  def testDeferredCaptureCompositeTensorSpecTypeMismatch(self):
    value = indexed_slices.IndexedSlices(
        constant_op.constant([1, 2]),
        constant_op.constant([0, 1], dtype=dtypes.int64))

    @def_function.function
    def lazy_capture():
      return ops.get_default_graph().capture_call_time_value(
          lambda: value,
          indexed_slices.IndexedSlicesSpec(dtype=dtypes.int32))

    # Type matches spec.
    lazy_capture()

    # Extra dense shape component.
    value = indexed_slices.IndexedSlices(
        constant_op.constant([1, 2]),
        constant_op.constant([0, 1], dtype=dtypes.int64),
        constant_op.constant([2]))
    with self.assertRaises(ValueError):
      lazy_capture()

    # Index dtype mismatch int32 vs. int64.
    value = indexed_slices.IndexedSlices(
        constant_op.constant([1, 2]),
        constant_op.constant([0, 1]))
    with self.assertRaises(ValueError):
      lazy_capture()

  def testFunctoolsLruCache(self):
    self.skipTest(
        "b/194845243: inspect.getfullargspec doesn't unwrap Python decorators.")

    @def_function.function
    @functools.lru_cache(maxsize=2)
    def f(a):
      return 2 * a

    self.assertAllEqual(f(1), constant_op.constant(2))


if __name__ == '__main__':
  ops.enable_eager_execution()
  test.main()