• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import collections
21import copy
22import functools
23import itertools
24import multiprocessing.pool
25import os
26import sys
27import time
28import weakref
29
30from absl.testing import parameterized
31import numpy
32
33from tensorflow.core.protobuf import config_pb2
34from tensorflow.core.protobuf import rewriter_config_pb2
35from tensorflow.python.autograph.core import ag_ctx
36from tensorflow.python.data.ops import dataset_ops
37from tensorflow.python.data.ops import iterator_ops
38from tensorflow.python.eager import backprop
39from tensorflow.python.eager import cancellation
40from tensorflow.python.eager import context
41from tensorflow.python.eager import def_function
42from tensorflow.python.eager import function
43from tensorflow.python.framework import composite_tensor
44from tensorflow.python.framework import config
45from tensorflow.python.framework import constant_op
46from tensorflow.python.framework import dtypes
47from tensorflow.python.framework import errors
48from tensorflow.python.framework import func_graph
49from tensorflow.python.framework import function as tf_function
50from tensorflow.python.framework import indexed_slices
51from tensorflow.python.framework import ops
52from tensorflow.python.framework import random_seed
53from tensorflow.python.framework import sparse_tensor
54from tensorflow.python.framework import tensor_shape
55from tensorflow.python.framework import tensor_spec
56from tensorflow.python.framework import test_ops
57from tensorflow.python.framework import test_util
58from tensorflow.python.framework import type_spec
59from tensorflow.python.layers import convolutional
60from tensorflow.python.module import module
61from tensorflow.python.ops import array_ops
62from tensorflow.python.ops import check_ops
63from tensorflow.python.ops import clip_ops
64from tensorflow.python.ops import control_flow_ops
65from tensorflow.python.ops import data_flow_ops
66from tensorflow.python.ops import functional_ops
67from tensorflow.python.ops import gen_functional_ops
68from tensorflow.python.ops import gen_random_ops
69from tensorflow.python.ops import gen_resource_variable_ops
70from tensorflow.python.ops import gen_sendrecv_ops
71from tensorflow.python.ops import gradients_impl
72from tensorflow.python.ops import init_ops
73from tensorflow.python.ops import list_ops
74from tensorflow.python.ops import logging_ops
75from tensorflow.python.ops import math_ops
76from tensorflow.python.ops import random_ops
77from tensorflow.python.ops import resource_variable_ops
78from tensorflow.python.ops import string_ops
79from tensorflow.python.ops import variable_scope
80from tensorflow.python.ops import variables
81from tensorflow.python.ops.ragged import ragged_factory_ops
82from tensorflow.python.ops.ragged import ragged_tensor
83from tensorflow.python.ops.structured import structured_tensor
84from tensorflow.python.platform import test
85from tensorflow.python.training import training_ops
86from tensorflow.python.util import compat
87from tensorflow.python.util import nest
88from tensorflow.python.util import tf_decorator
89from tensorflow.python.util import tf_inspect
90
91try:
92  import attr  # pylint:disable=g-import-not-at-top
93except ImportError:
94  attr = None
95
96
def total_function_cache(defined):
  """Returns the union of `defined`'s primary and arg-relaxed caches.

  Everything produced by iterating `_function_cache.primary` and
  `_function_cache.arg_relaxed` (for mappings, their keys) is collected into
  a single set, so duplicates across the two caches are counted once.
  """
  # pylint: disable=protected-access
  cache = defined._function_cache
  # pylint: enable=protected-access
  return set(cache.primary).union(cache.arg_relaxed)
102
103
def _example_indexed_slices_with_dense_shape():
  """Builds a small IndexedSlices that includes an explicit dense_shape [2]."""
  values = constant_op.constant([1, 2])
  indices = constant_op.constant([0, 1])
  dense_shape = constant_op.constant([2])
  return indexed_slices.IndexedSlices(values, indices, dense_shape)
108
109
def _example_indexed_slices_without_dense_shape():
  """Builds a small IndexedSlices without passing a dense_shape argument."""
  values = constant_op.constant([1, 2])
  indices = constant_op.constant([0, 1])
  return indexed_slices.IndexedSlices(values, indices)
113
114
def _spec_for_value(value):
  """Maps `value` to the matching (possibly nested) TypeSpec structure.

  Sequences are traversed with `nest.map_structure`. Tensors and
  CompositeTensors are replaced by their TypeSpec. Any other leaf is
  returned unchanged.
  """
  if nest.is_sequence(value):
    return nest.map_structure(_spec_for_value, value)
  if not isinstance(value, (ops.Tensor, composite_tensor.CompositeTensor)):
    return value
  return type_spec.type_spec_from_value(value)
123
124
125# This dummy decorator imitates ordinary decorators utilizing tf_decorator.
def dummy_tf_decorator(method):
  """Wraps `method` in a pass-through built with tf_decorator.make_decorator.

  The wrapper forwards all positional and keyword arguments unchanged. It
  exists only to mimic how ordinary tf_decorator-based decorators look.
  """

  def passthrough(*args, **kwargs):
    return method(*args, **kwargs)

  return tf_decorator.make_decorator(method, passthrough)
132
133
134class FunctionTest(test.TestCase, parameterized.TestCase):
135
136  def setUp(self):
137    super(FunctionTest, self).setUp()
138    cpus = config.list_physical_devices('CPU')
139    # Set 4 virtual CPUs
140    config.set_logical_device_configuration(cpus[0], [
141        context.LogicalDeviceConfiguration(),
142        context.LogicalDeviceConfiguration(),
143        context.LogicalDeviceConfiguration(),
144        context.LogicalDeviceConfiguration()
145    ])
146
147  def testBasic(self):
148    matmul = def_function.function(math_ops.matmul)
149    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
150    sq = matmul(t, t, transpose_a=True)
151    sq2 = matmul(sq, t, transpose_a=True)
152    self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20])
153    self.assertAllEqual(sq2.numpy().reshape(-1), [52, 76, 74, 108])
154
  def testOnExitCallback(self):
    """Exit callbacks run when a trace finishes, not when they are registered.

    The inner `g` and the outer `f` each register a callback while being
    traced. The assertions inside the bodies check that registering does not
    invoke the callback. The outer assertions check that each trace of `f`
    appends g's value (1) before f's value (2).
    """
    values = []
    def append_1():
      values.append(1)

    def append_2():
      values.append(2)

    def g(x):
      old_values = list(values)
      ops.add_exit_callback_to_default_func_graph(append_1)
      # Registration alone must not have run the callback.
      self.assertEqual(old_values, values)
      return x + 1

    tf_g = def_function.function(g)

    def f(x):
      old_values = list(values)
      ops.add_exit_callback_to_default_func_graph(append_2)
      self.assertEqual(old_values, values)
      # g is traced (and its callback fires) while f is still being traced.
      return tf_g(x)

    tf_f = def_function.function(f)
    self.assertEmpty(values)
    tf_f(constant_op.constant(1.0))
    self.assertEqual(values, [1, 2])  # Once for g, once for f.
    tf_f(constant_op.constant([1.0]))  # force a retrace
    self.assertEqual(values, [1, 2, 1, 2])  # And again.
183
184  def testCannotAddExitCallbackWhenNotInFunctionScope(self):
185    with self.assertRaisesRegex(RuntimeError, 'when not building a function.'):
186      ops.add_exit_callback_to_default_func_graph(lambda: None)
187
188  def testVariable(self):
189    v1 = variables.Variable(1.0)
190    add = def_function.function(lambda x, v: x + v1 + v)
191    v2 = variables.Variable(1.0)
192    x = constant_op.constant(1.0)
193    r = add(x, v2)
194    self.assertEqual(3.0, self.evaluate(r))
195
196  def testVariableOnly(self):
197    v = variables.Variable(1.0)
198    add = def_function.function(lambda x: x.assign_add(1.0))
199    r1 = add(v)
200    self.assertEqual(2.0, self.evaluate(r1))
201    c = constant_op.constant(1.0)
202    with self.assertRaisesRegex(AttributeError, 'no attribute'):
203      add(c)
204
  @test_util.disable_tfrt('Packed tensor is not supported in tfrt yet.')
  def testPackedVariable(self):
    """Packed resource handles are updated and read per component device.

    Four variables are placed on /cpu:0, /cpu:1 (two of them) and /cpu:2, and
    their handles are packed in pairs. The traced function must tag both
    packed inputs with a `_composite_device` arg attribute equal to the packed
    tensor's device. Each device-scoped read must see its own component's
    value after the assign_add.
    """
    with ops.device('/cpu:0'):
      v0_0 = resource_variable_ops.ResourceVariable(1.0)
    with ops.device('/cpu:1'):
      v0_1 = resource_variable_ops.ResourceVariable(2.0)
      v1_0 = resource_variable_ops.ResourceVariable(3.0)
    with ops.device('/cpu:2'):
      v1_1 = resource_variable_ops.ResourceVariable(4.0)

    # packed_var_0 covers /cpu:0 and /cpu:1; packed_var_1 covers /cpu:1 and
    # /cpu:2.
    packed_var_0 = ops.pack_eager_tensors([v0_0.handle, v0_1.handle])
    packed_var_1 = ops.pack_eager_tensors([v1_0.handle, v1_1.handle])

    # TODO(b/145922293): use ResourceVariable.assign_add and
    # ResourceVariable.read_value directly once we support packing multiple
    # ResourceVariable into one ResourceVariable.
    @def_function.function
    def read_var():
      # A single assign_add on a packed handle is expected to update every
      # component; the final assertion checks all four results.
      resource_variable_ops.assign_add_variable_op(
          packed_var_0, constant_op.constant(5.0))
      resource_variable_ops.assign_add_variable_op(
          packed_var_1, constant_op.constant(6.0))
      # The enclosing device scope selects which component a read observes.
      with ops.device('/cpu:0'):
        read0 = resource_variable_ops.read_variable_op(
            packed_var_0, dtype=dtypes.float32)
      with ops.device('/cpu:1'):
        read1 = resource_variable_ops.read_variable_op(
            packed_var_0, dtype=dtypes.float32)
        read2 = resource_variable_ops.read_variable_op(
            packed_var_1, dtype=dtypes.float32)
      with ops.device('/cpu:2'):
        read3 = resource_variable_ops.read_variable_op(
            packed_var_1, dtype=dtypes.float32)

      return read0, read1, read2, read3

    arg_attrs = read_var.get_concrete_function().function_def.arg_attr
    self.assertLen(arg_attrs, 2)
    self.assertEqual(arg_attrs[0].attr['_composite_device'].s,
                     compat.as_bytes(packed_var_0.device))
    self.assertEqual(arg_attrs[1].attr['_composite_device'].s,
                     compat.as_bytes(packed_var_1.device))

    # Expected: each initial value plus the increment applied to its pack.
    self.assertAllEqual(read_var(), (1 + 5, 2 + 5, 3 + 6, 4 + 6))
249
  def testImplementsAttributeBasic(self):
    """Only the user-facing FunctionDef carries the implements attribute.

    Differentiating an `experimental_implements` function yields three
    FunctionDefs in the graph library. The two whose names contain 'forward'
    or 'backward' must not have IMPLEMENTS_ATTRIBUTE_NAME. The remaining one
    must have it set to 'func'.
    """
    v = def_function.function(
        experimental_implements='func')(lambda x, y: x + y)
    with context.graph_mode(), self.cached_session():
      a = array_ops.placeholder(dtypes.float32, ())
      b = array_ops.placeholder(dtypes.float32, ())
      v(a, b)
      gradients_impl.gradients(v(a, b), [a, b])
      fdefs = ops.get_default_graph().as_graph_def().library.function
      self.assertLen(fdefs, 3)
      not_present = 0
      present = 0
      for f in fdefs:
        name = f.signature.name
        if 'forward' in name or 'backward' in name:
          not_present += 1
          self.assertNotIn(function.IMPLEMENTS_ATTRIBUTE_NAME, f.attr, f)
        else:
          present += 1
          self.assertEqual(f.attr[function.IMPLEMENTS_ATTRIBUTE_NAME].s,
                           'func'.encode('ascii'), f)
      self.assertEqual(not_present, 2, fdefs)
      self.assertEqual(present, 1, fdefs)
273
274  def testImplementsAttributeAssertsOnSideInput(self):
275    with context.graph_mode(), self.cached_session():
276      z = array_ops.zeros(0)
277      v = def_function.function(
278          experimental_implements='func')(lambda x, y: x + y + z)
279      a = array_ops.ones((1.0,))
280      b = array_ops.ones((1.0,))
281      with self.assertRaisesRegex(AssertionError,
282                                  'variables are always captured'):
283        v(a, b)
284      functions = ops.get_default_graph().as_graph_def().library.function
285      self.assertEmpty(functions)
286
287  def testImplementsAttributeWorksWithGradientTape(self):
288    add = lambda x, y: x + y ** 2
289    add = def_function.function(experimental_implements='MyFunc')(add)
290    x = variables.Variable(3.0)
291    y = variables.Variable(2.0)
292
293    with backprop.GradientTape() as tape:
294      g = add(x, y)
295
296    dg_dy, dg_dx = tape.gradient(g, [y, x])
297    self.assertEqual(dg_dy.numpy(), 4.0)
298    self.assertEqual(dg_dx.numpy(), 1.0)
299
  def testImplementsAttributeWorksOnVariables(self):
    """Variables passed to an implements-function remain live inputs.

    Calls with (a, b) and (a, a) must share one FunctionDef. The result must
    reflect later updates to `a` rather than a value snapshotted at trace time.
    """
    with context.graph_mode(), self.cached_session():
      v = def_function.function(
          experimental_implements='func')(lambda x, y: x + y)
      a = variables.Variable((1.0,))
      b = variables.Variable((1.0,))
      r1 = v(a, b)
      _ = v(a, a)
      functions = ops.get_default_graph().as_graph_def().library.function
      # Verify that we created only one function
      self.assertLen(functions, 1)
      # Verify that eval() reads the current values.
      a.initializer.run()
      b.initializer.run()
      self.assertEqual(r1.eval(), 2)

      # After a += 1, re-evaluating r1 must observe the new value.
      a.assign_add([1]).eval()
      self.assertEqual(r1.eval(), 3)
318
  def testImplementsAttributeWorksOnConstants(self):
    """Python float arguments become real inputs of an implements-function.

    v(a, 2.) and v(2., a) must share a single FunctionDef with two input args,
    i.e. the constant is passed in rather than baked into the function body.
    """
    with context.graph_mode(), self.cached_session():
      v = def_function.function(
          experimental_implements='func')(lambda x, y: x + y)
      a = variables.Variable(1.0)
      r1 = v(a, 2.)
      r2 = v(2., a)
      functions = ops.get_default_graph().as_graph_def().library.function
      self.assertLen(functions, 1)
      self.assertLen(functions[0].signature.input_arg, 2)
      # Verify that eval() reads the current values.
      a.initializer.run()
      self.assertEqual(r1.eval(), 3)
      self.assertEqual(r2.eval(), 3)
333
  def testImplementsAttributeSpecializes(self):
    """Differently shaped list arguments produce separate FunctionDefs.

    v(a, [2.]) and v([2., 2], a) trace two functions. Both must keep two
    input args (nothing inlined), and their results must use the variable's
    current value.
    """
    with context.graph_mode(), self.cached_session():
      v = def_function.function(
          experimental_implements='func')(lambda x, y: x + y)
      a = variables.Variable(1.0)
      r1 = v(a, [2.])
      r2 = v([2., 2], a)
      functions = ops.get_default_graph().as_graph_def().library.function
      self.assertLen(functions, 2)
      # Ensure that all parameters are still there and haven't been inlined!

      self.assertLen(functions[0].signature.input_arg, 2)
      self.assertLen(functions[1].signature.input_arg, 2)
      # Verify that eval() reads the current values.
      a.initializer.run()
      numpy.testing.assert_equal(r1.eval(), [3.])
      numpy.testing.assert_equal(r2.eval(), [3., 3.])
351
352  def testImplementsWorksWithTensorSpec(self):
353    v = def_function.function(
354        experimental_implements='func')(lambda x, y: x + y)
355    v = v.get_concrete_function(
356        tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32),
357        tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32))
358    x = v(1., 2.)
359    self.assertEqual(x.numpy(), 3.)
360
361  def testImplementsAttributeAsNameAttrList(self):
362    implements_attr = (
363        'name: "embedding_matmul" attr {   key: "key1"   value {     i: 2   } '
364        '} attr {   key: "key2"   value {     b: false   } }')
365    v = def_function.function(
366        experimental_implements=implements_attr)(lambda x, y: x + y)
367    with context.graph_mode(), self.cached_session():
368      a = array_ops.placeholder(dtypes.float32, ())
369      b = array_ops.placeholder(dtypes.float32, ())
370      v(a, b)
371      gradients_impl.gradients(v(a, b), [a, b])
372      fdefs = ops.get_default_graph().as_graph_def().library.function
373      self.assertLen(fdefs, 3)
374      not_present = 0
375      present = 0
376      for f in fdefs:
377        name = f.signature.name
378        if 'forward' in name or 'backward' in name:
379          not_present += 1
380          self.assertNotIn(function.IMPLEMENTS_ATTRIBUTE_NAME, f.attr, f)
381        else:
382          present += 1
383          attr_value = f.attr[function.IMPLEMENTS_ATTRIBUTE_NAME]
384          self.assertIsNotNone(attr_value.func, f)
385          self.assertEqual(attr_value.func.name, 'embedding_matmul')
386          name_attrs = attr_value.func.attr
387          self.assertLen(name_attrs, 2)
388      self.assertEqual(not_present, 2, fdefs)
389      self.assertEqual(present, 1, fdefs)
390
391  def testExternalControlDependency(self):
392    with ops.Graph().as_default(), self.test_session():
393      v = variables.Variable(1.0)
394      v.initializer.run()
395
396      op = v.assign_add(1.0)
397
398      @function.defun
399      def f():
400        with ops.control_dependencies([op]):
401          return 1.0
402
403      self.evaluate(f())
404      self.assertAllEqual(self.evaluate(v), 2.0)
405
406  def testInputShapeFunctionRelaxation(self):
407    unknown_dim = [False]
408
409    @function.defun(experimental_relax_shapes=True)
410    def func(a):
411      if a._shape_tuple()[0] is None:
412        unknown_dim[0] = True
413      return a + 1
414
415    func(constant_op.constant([]))
416    self.assertFalse(unknown_dim[0])
417    self.assertLen(total_function_cache(func), 1)
418
419    func(constant_op.constant([1.0]))
420    self.assertFalse(unknown_dim[0])
421    self.assertLen(total_function_cache(func), 2)
422
423    func(constant_op.constant([1.0, 2.0]))
424    self.assertTrue(unknown_dim[0])
425    self.assertLen(total_function_cache(func), 2)
426
427  def testInputShapeRelaxationOnInstanceMethod(self):
428    # Test that experimental_relax_shapes is passed during
429    # instance method bounding.
430    unknown_dim = [False]
431
432    class Foo(object):
433
434      @def_function.function(experimental_relax_shapes=True)
435      def func(self, a):
436        if a._shape_tuple()[0] is None:
437          unknown_dim[0] = True
438        return a + 1
439
440    foo = Foo()
441    foo.func(constant_op.constant([]))
442    self.assertFalse(unknown_dim[0])
443
444    foo.func(constant_op.constant([1.0]))
445    self.assertFalse(unknown_dim[0])
446
447    foo.func(constant_op.constant([1.0, 2.0]))
448    self.assertTrue(unknown_dim[0])
449
  def testInputShapeFunctionRelaxationWithRaggedTensors(self):
    """Shape relaxation for RaggedTensor inputs operates on their TypeSpecs.

    `check_trace` calls `func` and asserts which TypeSpec was traced;
    `None` means the call reused an existing trace.
    """
    traced_type_spec = [None]

    @def_function.function(experimental_relax_shapes=True)
    def func(x):
      # Python side effect: only runs while tracing, so it records the
      # TypeSpec of each new trace.
      traced_type_spec[0] = x._type_spec
      return x

    def check_trace(x, expected_trace):
      traced_type_spec[0] = None
      func(x)
      self.assertEqual(traced_type_spec[0], expected_trace)

    check_trace(  # Initial call gets traced.
        ragged_factory_ops.constant([[1], [2, 3, 4]]),
        ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32))
    check_trace(  # Input TypeSpec is the same -> no retrace.
        ragged_factory_ops.constant([[1, 2], [3, 4]]), None)
    check_trace(  # Even if component tensor shapes change -> no retrace.
        ragged_factory_ops.constant([[1, 2], [3, 4, 5, 6]]), None)
    check_trace(  # Different TypeSpec shape (nrows): retrace
        ragged_factory_ops.constant([[1], [2], [3]]),
        ragged_tensor.RaggedTensorSpec([3, None], dtypes.int32))
    check_trace(  # Different nrows again: relax & retrace
        ragged_factory_ops.constant([[1], [2], [3], [4]]),
        ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32))
    check_trace(  # Different nrows yet again: no retrace
        ragged_factory_ops.constant([[1]]), None)
    check_trace(  # Different ragged_rank: retrace
        ragged_factory_ops.constant([[[1]]]),
        ragged_tensor.RaggedTensorSpec([1, None, None], dtypes.int32))
    check_trace(  # Different ragged_rank again: retrace & relax
        ragged_factory_ops.constant([[[1]], [[2]]]),
        ragged_tensor.RaggedTensorSpec([None, None, None], dtypes.int32))
484
  def testInputShapeFunctionRelaxationWithStructuredTensors(self):
    """Only shape differences in StructuredTensor specs trigger relaxation.

    `check_trace` calls `func` and asserts which TypeSpec was traced;
    `None` means the call reused an existing trace.
    """
    traced_type_spec = [None]

    @def_function.function(experimental_relax_shapes=True)
    def func(x):
      # Python side effect: only runs while tracing.
      traced_type_spec[0] = x._type_spec
      return x

    def check_trace(x, expected_trace):
      traced_type_spec[0] = None
      func(x)
      self.assertEqual(traced_type_spec[0], expected_trace)

    # If we have TypeSpecs that differ in ways other than just their shape,
    # then retrace each time.
    check_trace(
        structured_tensor.StructuredTensor.from_pyval({'a': [1]}),
        structured_tensor.StructuredTensorSpec(
            [], {'a': tensor_spec.TensorSpec((1,), dtypes.int32)}))
    check_trace(
        structured_tensor.StructuredTensor.from_pyval({'b': [1]}),
        structured_tensor.StructuredTensorSpec(
            [], {'b': tensor_spec.TensorSpec((1,), dtypes.int32)}))
    check_trace(
        structured_tensor.StructuredTensor.from_pyval({'c': [1]}),
        structured_tensor.StructuredTensorSpec(
            [], {'c': tensor_spec.TensorSpec((1,), dtypes.int32)}))

    # But if we call again with only shape different, then do relax:
    check_trace(  # retrace
        structured_tensor.StructuredTensor.from_pyval({'a': [1, 2]}),
        structured_tensor.StructuredTensorSpec(
            [], {'a': tensor_spec.TensorSpec((2,), dtypes.int32)}))
    check_trace(  # relax & retrace
        structured_tensor.StructuredTensor.from_pyval({'a': [1, 2, 3]}),
        structured_tensor.StructuredTensorSpec(
            [], {'a': tensor_spec.TensorSpec((None,), dtypes.int32)}))
    check_trace(  # use relaxed graph
        structured_tensor.StructuredTensor.from_pyval({'a': [1, 2, 3, 4]}),
        None)
525
  def testInputShapeFunctionRelaxationWithDatasetIterators(self):
    """Iterator TypeSpecs are relaxed one dimension at a time.

    `check_trace` calls `func` and asserts which IteratorSpec was traced;
    `None` means the call reused an existing trace.
    """
    # For dataset iterators, the TypeSpec includes type information that's
    # not derivable from the component tensors.  Make sure that the TypeSpec
    # shapes get relaxed as appropriate.

    traced_type_spec = [None]

    @def_function.function(experimental_relax_shapes=True)
    def func(x):
      # Python side effect: only runs while tracing.
      traced_type_spec[0] = x._type_spec
      return x

    def check_trace(x, expected_trace):
      traced_type_spec[0] = None
      func(x)
      self.assertEqual(traced_type_spec[0], expected_trace)

    ds_1_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([1, 2]))
    ds_2_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([2, 2]))
    ds_3_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([3, 2]))
    ds_4_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([4, 2]))
    ds_2_1 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([2, 1]))
    check_trace(  # shape=[1, 2]: retrace
        dataset_ops.make_one_shot_iterator(ds_1_2),
        iterator_ops.IteratorSpec(
            tensor_spec.TensorSpec([1, 2], dtypes.float32)))
    check_trace(  # shape=[1, 2]: no retrace (use the [1, 2] graph)
        dataset_ops.make_one_shot_iterator(ds_1_2), None)
    check_trace(  # shape=[2, 2]: retrace
        dataset_ops.make_one_shot_iterator(ds_2_2),
        iterator_ops.IteratorSpec(
            tensor_spec.TensorSpec([2, 2], dtypes.float32)))
    check_trace(  # shape=[3, 2]: relax to [None, 2] and retrace
        dataset_ops.make_one_shot_iterator(ds_3_2),
        iterator_ops.IteratorSpec(
            tensor_spec.TensorSpec([None, 2], dtypes.float32)))
    check_trace(  # shape=[4, 2]: no retrace (use the [None, 2] graph)
        dataset_ops.make_one_shot_iterator(ds_4_2), None)
    check_trace(  # shape=[2, 1]: relax to [None, None] and retrace
        dataset_ops.make_one_shot_iterator(ds_2_1),
        iterator_ops.IteratorSpec(
            tensor_spec.TensorSpec([None, None], dtypes.float32)))
568
569  def testCapturesVariables(self):
570    a = variables.Variable(1.0, trainable=False)
571    b = variables.Variable(1.0)
572    cc = [None]
573
574    @def_function.function
575    def f():
576      c = cc[0]
577      if c is None:
578        c = cc[0] = variables.Variable(1.)
579      return a + b + c + 1
580
581    cf = f.get_concrete_function()
582    c = cc[0]
583
584    captured_variables = {v.ref() for v in (a, b, c)}
585    trainable_variables = {v.ref() for v in (b, c)}
586    self.assertEqual({v.ref() for v in cf.variables}, captured_variables)
587    self.assertEqual({v.ref() for v in cf.trainable_variables},
588                     trainable_variables)
589    self.assertEqual(cf.variables, cf.graph.variables)
590    self.assertEqual(cf.trainable_variables, cf.graph.trainable_variables)
591
  def testNestedInputShapeFunctionRelaxation(self):
    """Shape relaxation applies to tensors nested in keyword arguments.

    The non-tensor first argument is part of the cache key, so changing it
    starts a new set of traces (and cache entries).
    """
    unknown_dim = [False]

    @function.defun(experimental_relax_shapes=True)
    def func(a_, b_=None):
      del a_  # Only used to check which cache is used.
      self.assertEqual(b_[0]._shape_tuple(), ())
      if b_[1]._shape_tuple()[0] is None:
        unknown_dim[0] = True
      return b_[0] + 1

    a = 'hi'
    b0 = constant_op.constant(1.0)
    func(a, b_=[b0, constant_op.constant([])])
    self.assertFalse(unknown_dim[0])
    self.assertLen(total_function_cache(func), 1)

    func(a, b_=[b0, constant_op.constant([1.0])])
    self.assertFalse(unknown_dim[0])
    self.assertLen(total_function_cache(func), 2)

    # Third distinct shape: relaxed trace, combined cache size unchanged.
    func(a, b_=[b0, constant_op.constant([1.0, 1.0])])
    self.assertTrue(unknown_dim[0])
    self.assertLen(total_function_cache(func), 2)

    unknown_dim[0] = False

    # Now do the same except with a new a which is not a tensor; this should
    # change the cache key.
    a = 'bye'
    func(a, b_=[b0, constant_op.constant([])])
    self.assertFalse(unknown_dim[0])
    self.assertLen(total_function_cache(func), 3)

    # Since we already marked a cache miss for a function with the same
    # non-input signatures, here we will immediately start relaxing shapes.
    func(a, b_=[b0, constant_op.constant([1.0])])
    self.assertTrue(unknown_dim[0])
    self.assertLen(total_function_cache(func), 3)
631
632  def testNestedShapeFunctionRelaxation(self):
633
634    got_shape = [None]
635
636    # The inner function will go through shape relaxation because the shapes it
637    # receives will be [1], [2], [3], ...
638    @def_function.function(experimental_relax_shapes=True)
639    def bar(x_shape):
640      got_shape[0] = x_shape._shape_tuple()
641      return x_shape
642
643    # The outer function will not go through shape relaxation because the shapes
644    # it receives will be [1], [[1]], [[[1]]], ...
645    @def_function.function(experimental_relax_shapes=True)
646    def foo(ones):
647      return bar(array_ops.shape(ones))
648
649    for rank in range(1, 6):
650      x_shape = self.evaluate(foo(array_ops.ones([1] * rank)))
651      self.assertAllEqual(x_shape, [1] * rank)
652      if rank < 3:
653        self.assertEqual(got_shape[0], (rank,))
654      else:
655        self.assertEqual(got_shape[0], (None,))
656
657  def testNoHash(self):
658
659    @def_function.function()
660    def f(_):
661      return 1.0
662
663    with self.assertRaisesRegex(ValueError, r'Got type: set'):
664      f(set([]))
665
  def testFuncName(self):
    """A 'func_name' attribute overrides a defun's name.

    Without that attribute the name is taken from the Python function (here
    `add_2`), so the inner functions below must not be renamed.
    """

    @function.defun_with_attributes(attributes={'func_name': 'multiply'})
    def add(x, y):
      _ = x * y
      return x + y

    @function.defun
    def add_2(x, y):
      _ = x * y
      return x + y

    self.assertEqual(add._name, 'multiply')
    self.assertEqual(add_2._name, 'add_2')
680
681  def testBasicGraphMode(self):
682    matmul = def_function.function(math_ops.matmul)
683
684    @def_function.function
685    def sq(a):
686      return matmul(a, a)
687
688    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
689    out = sq(t)
690    self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
691
692  def testNestedInputsGraphMode(self):
693    matmul = def_function.function(math_ops.matmul)
694
695    pair = collections.namedtuple('pair', ['a', 'b'])
696
697    @def_function.function
698    def a_times_b(inputs):
699      return matmul(inputs.a['a'], inputs.b['b'])
700
701    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
702
703    out = a_times_b(pair({'a': t}, {'b': t}))
704    self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
705
706  def testNestedOutputsGraphMode(self):
707    matmul = def_function.function(math_ops.matmul)
708
709    pair = collections.namedtuple('pair', ['a', 'b'])
710
711    @def_function.function()
712    def pairs_mul(pair_a, pair_b):
713      return pair(matmul(pair_a.a, pair_b.a), matmul(pair_a.b, pair_b.b))
714
715    a = constant_op.constant([[1.0, 2.0], [1.0, 2.0]])
716    b = constant_op.constant([[3.0, 4.0], [3.0, 4.0]])
717
718    out = pairs_mul(pair(a, b), pair(b, a))
719    expected = pair(math_ops.matmul(a, b).numpy(),
720                    math_ops.matmul(b, a).numpy())
721    self.assertAllClose(out, expected)
722
  @parameterized.named_parameters(
      dict(testcase_name='Defun',
           function_decorator=function.defun),
      dict(testcase_name='DefFunction',
           function_decorator=def_function.function))
  def testNestedFunctionGraphNotOutOfDate(self, function_decorator):
    """A nested concrete function does not keep its dead outer graph alive.

    `f`'s concrete function is created while tracing `model.g`, so its
    outer_graph starts out as g's graph. After the model is deleted, g and
    its graph must be freed, and outer_graph must then resolve to the current
    default graph.
    """
    @function_decorator
    def f():
      return constant_op.constant(1.)

    class _Model(object):

      @function_decorator
      def g(self):
        self.f = f.get_concrete_function()

    model = _Model()
    model.g()
    concrete = model.f
    weak_g_graph = weakref.ref(model.g.get_concrete_function().graph)
    self.assertIs(weak_g_graph(), concrete.graph.outer_graph)
    weak_g = weakref.ref(model.g)
    # The weakref checks below presumably rely on CPython freeing objects as
    # soon as their last strong reference is dropped.
    del model
    self.assertIsNone(weak_g())
    self.assertIsNone(weak_g_graph())
    self.assertIsNotNone(concrete.graph.outer_graph)
    self.assertIs(ops.get_default_graph(), concrete.graph.outer_graph)
750
751  def testGraphEagerIsolation(self):
752
753    @function.defun
754    def f():
755      self.v = variables.Variable(1.0)
756      return self.v.read_value()
757
758    self.assertAllEqual(f(), 1.0)
759
760    with ops.Graph().as_default():
761      self.assertEqual(f().shape, ())
762
763  def testBasicGraphFunction(self):
764    matmul = def_function.function(math_ops.matmul)
765
766    @def_function.function
767    def sq(a):
768      return matmul(a, a)
769
770    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
771
772    sq_op = sq.get_concrete_function(t)
773    self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2]))
774    out = sq_op(t)
775    self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
776
777  def testGetConcreteFunctionThreadSafety(self):
778
779    @def_function.function
780    def sq():
781      t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
782      return math_ops.matmul(t, t)
783
784    concrete_functions = []
785
786    def thread_func(_):
787      cf = sq.get_concrete_function()
788      concrete_functions.append(cf)
789
790    num_threads = 100
791    pool = multiprocessing.pool.ThreadPool(num_threads)
792    _ = pool.map(thread_func, list(range(num_threads)))
793
794    self.assertLen(set(concrete_functions), 1)
795
796  def testGetConcreteFunctionThreadSafetyWithArgs(self):
797    @def_function.function
798    def add_100(*args):
799      return math_ops.add_n(args)
800
801    p = multiprocessing.pool.ThreadPool(2)
802    args = (constant_op.constant(1.),) * 100
803    f1, f2 = p.map(add_100.get_concrete_function, [args] * 2)
804    # I see about len(args) + max(0, len(args) - 3) arguments expected.
805    f1(*args)
806    del f2
807
808  def testInputSpecGraphFunction(self):
809    matmul = def_function.function(math_ops.matmul)
810
811    @def_function.function
812    def sq(a):
813      return matmul(a, a)
814
815    sq_op = sq.get_concrete_function(
816        tensor_spec.TensorSpec((None, None), dtypes.float32))
817    self.assertEqual([None, None], sq_op.output_shapes.as_list())
818
819    t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
820    out1 = sq_op(t1)
821    self.assertAllEqual(out1, math_ops.matmul(t1, t1).numpy())
822
823    t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
824    out2 = sq_op(t2)
825    self.assertAllEqual(out2, math_ops.matmul(t2, t2).numpy())
826
827  def testNestedInputSpecGraphFunction(self):
828    matmul = def_function.function(math_ops.matmul)
829
830    @def_function.function
831    def sq(mats):
832      ((a, b),) = mats
833      return matmul(a, b)
834
835    sq_op_autonamed = sq.get_concrete_function(
836        [(tensor_spec.TensorSpec((None, None), dtypes.float32),
837          tensor_spec.TensorSpec((None, None), dtypes.float32))])
838    self.assertEqual([None, None], sq_op_autonamed.output_shapes.as_list())
839
840    sq_op = sq.get_concrete_function(
841        [(tensor_spec.TensorSpec((None, None), dtypes.float32,
842                                 name='first_mat'),
843          tensor_spec.TensorSpec((None, None), dtypes.float32,
844                                 name='second_mat'))])
845    self.assertEqual([None, None], sq_op.output_shapes.as_list())
846
847    t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
848    t2 = constant_op.constant([[1.4, 2.4], [3.4, 4.4]])
849    out = sq_op(first_mat=t1, second_mat=t2)
850    self.assertAllEqual(out, math_ops.matmul(t1, t2).numpy())
851    self.assertAllEqual(sq_op_autonamed(t1, t2),
852                        math_ops.matmul(t1, t2).numpy())
853
854  def testExecutingStatelessDefunConcurrently(self):
855
856    @def_function.function
857    def stateless(x):
858      return math_ops.multiply(2.0, x)
859
860    pool = multiprocessing.pool.ThreadPool()
861    inputs = [constant_op.constant(1.0 * x) for x in range(100)]
862    outputs = [float(out) for out in pool.map(stateless, inputs)]
863    expected = [float(2.0 * x) for x in inputs]
864    self.assertSequenceEqual(outputs, expected)
865
866  def testExecutingManyStatelessDefunsConcurrently(self):
867
868    @def_function.function
869    def stateless(x):
870      del x
871      return math_ops.multiply(2.0, 2.0)
872
873    pool = multiprocessing.pool.ThreadPool()
874    # `pool.map` below instantiates 100 functions, one for each object.
875    objects = [object() for _ in range(100)]
876    outputs = [float(out) for out in pool.map(stateless, objects)]
877    expected = [4.0] * 100
878    self.assertSequenceEqual(outputs, expected)
879
880  @test_util.disable_tfrt('b/169431085: This test is flaky on tfrt')
881  def testExecutingStatefulDefunConcurrently(self):
882
883    v = resource_variable_ops.ResourceVariable(1.0)
884
885    @def_function.function
886    def stateful(x):
887      v.assign(x)
888
889    pool = multiprocessing.pool.ThreadPool()
890    inputs = [constant_op.constant(0.0)] * 100
891    pool.map(stateful, inputs)
892    self.assertEqual(float(v.read_value()), 0.0)
893
894  def testExecutingManyStatefulDefunsConcurrently(self):
895
896    v = resource_variable_ops.ResourceVariable(1.0)
897
898    @def_function.function
899    def stateful(x):
900      del x
901      return v.assign(0.0)
902
903    pool = multiprocessing.pool.ThreadPool()
904    # `pool.map` below instantiates 100 functions, one for each object.
905    pool.map(stateful, [object() for _ in range(100)])
906    self.assertEqual(float(v.read_value()), 0.0)
907
  def testShareRendezvous(self):
    """Send/recv functions with `_shared_rendezvous` communicate across While.

    `send` and `recv` are traced with `_shared_rendezvous = True` and are
    invoked from the bodies of two functional While loops inside `fn`. The
    value sent under key 'x' must therefore be received across the loop
    boundary.
    """

    # Disable grappler from inlining the functions. Note we run the send & recv
    # in graph mode since with eager mode the function should automatically be
    # inlined.
    # NOTE(review): this mutates the eager context's optimizer options and the
    # visible code never restores them. Confirm whether the test harness resets
    # the context between tests.
    context.context().set_optimizer_experimental_options(
        {'disable_meta_optimizer': True})

    # Both ends of the send/recv pair use the same CPU device.
    cpu = '/device:CPU:0'

    # Scalar int32 loop counter shared by the loop cond and both loop bodies.
    signature = [tensor_spec.TensorSpec([], dtypes.int32)]

    @def_function.function
    def send():
      x = constant_op.constant(1)
      gen_sendrecv_ops.send(x, 'x', cpu, 0, cpu)
      return x

    send._shared_rendezvous = True  # pylint: disable=protected-access

    # Loop body: perform one send, then decrement the counter.
    @def_function.function(input_signature=signature)
    def send_body(n):
      send()
      return n - 1

    @def_function.function
    def recv():
      return gen_sendrecv_ops.recv(dtypes.int32, 'x', cpu, 0, cpu)

    recv._shared_rendezvous = True  # pylint: disable=protected-access

    # Loop body: perform one recv, then decrement the counter.
    @def_function.function(input_signature=signature)
    def recv_body(n):
      recv()
      return n - 1

    @def_function.function(input_signature=signature)
    def cond(n):
      return n > 0

    # Instead of calling the send & recv functions directly we want to call them
    # through a functional while to ensure the rendezvous is shared across the
    # while boundary.
    # The output of the first (send) While is not used; only the recv loop's
    # result is returned.
    @def_function.function
    def fn(n):
      functional_ops.While([n], cond.get_concrete_function(),
                           send_body.get_concrete_function())
      return functional_ops.While([n], cond.get_concrete_function(),
                                  recv_body.get_concrete_function())

    # Use a graph context since functions will not be automatically inlined
    # With n=2, each loop body runs for n = 2 and n = 1 before cond fails.
    with context.graph_mode(), self.cached_session():
      self.evaluate(fn(2))
961
962  def disabled_testRandomSeed(self):
963
964    @def_function.function
965    def f():
966      return random_ops.random_normal(())
967
968    random_seed.set_random_seed(1)
969    x = f()
970    self.assertNotEqual(x, f())
971    random_seed.set_random_seed(1)
972    self.assertAllEqual(f(), x)
973
974  def testNestedInputsGraphFunction(self):
975    matmul = def_function.function(math_ops.matmul)
976
977    pair = collections.namedtuple('pair', ['a', 'b'])
978
979    @def_function.function
980    def a_times_b(inputs):
981      return matmul(inputs.a['a'], inputs.b['b'])
982
983    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
984    sq_op = a_times_b.get_concrete_function(
985        pair(dict(a=tensor_spec.TensorSpec([2, 2], dtypes.float32, 'a')),
986             dict(b=tensor_spec.TensorSpec([2, 2], dtypes.float32, 'b'))))
987    self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2]))
988    out = sq_op(a=t, b=t)
989    self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
990
991  def testNestedOutputGraphFunction(self):
992    matmul = def_function.function(math_ops.matmul)
993
994    @def_function.function
995    def sq(a):
996      return (matmul(a, a), {'b': constant_op.constant(1.0)})
997
998    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
999
1000    sq_op = sq.get_concrete_function(t)
1001    self.assertEqual(sq_op.output_shapes,
1002                     (tensor_shape.TensorShape([2, 2]),
1003                      {'b': tensor_shape.TensorShape([])}))
1004    self.assertEqual(sq_op.output_dtypes,
1005                     (dtypes.float32, {'b': dtypes.float32}))
1006    (a, b) = sq_op(t)
1007    self.assertAllEqual(a, math_ops.matmul(t, t).numpy())
1008    self.assertAllEqual(b['b'].numpy(), 1.0)
1009
1010  def testGraphFunctionNoneOutput(self):
1011    @def_function.function
1012    def fn(unused_a, unused_b):
1013      return None
1014
1015    x = constant_op.constant(1)
1016    fn_op = fn.get_concrete_function(x, x)
1017    self.assertEqual(fn_op.output_dtypes, None)
1018    self.assertEqual(fn_op.output_shapes, None)
1019    self.assertAllEqual(fn_op(x, x), None)
1020
1021  def testDefunNumpyArraysConvertedToTensors(self):
1022
1023    def f(x):
1024      self.assertIsInstance(x, ops.Tensor)
1025      return x
1026
1027    x = random_ops.random_uniform([2, 2]).numpy()
1028    defined = function.defun(f)
1029    defined(x)
1030    self.assertLen(total_function_cache(defined), 1)
1031
1032    x = random_ops.random_uniform([2, 2]).numpy()
1033    defined(x)
1034    # A NumPy array with different values but the same shape and dtype
1035    # shouldn't trigger another function definition.
1036    self.assertLen(total_function_cache(defined), 1)
1037
1038    np_ones = numpy.ones([], numpy.float32)
1039    np_zeros = numpy.zeros([], numpy.float32)
1040    tf_ones = array_ops.ones([])
1041    tf_zeros = array_ops.zeros([])
1042
1043    # Test that the numpy array is properly an argument to the graph function.
1044    self.assertEqual(1., defined(np_ones).numpy())
1045    self.assertLen(total_function_cache(defined), 2)
1046    self.assertEqual(0., defined(np_zeros).numpy())
1047    self.assertEqual(1., defined(tf_ones).numpy())
1048    self.assertEqual(0., defined(tf_zeros).numpy())
1049    self.assertLen(total_function_cache(defined), 2)
1050
1051    # Test that mutable inputs are supported.
1052    mutable = numpy.ones([], numpy.float32)
1053    self.assertEqual(1., defined(mutable).numpy())
1054    mutable.fill(0)
1055    self.assertEqual(0., defined(mutable).numpy())
1056
1057    class MyNdarray(numpy.ndarray):
1058      pass
1059
1060    # Test that the subclasses of ndarray are converted too.
1061    self.assertEqual(1., defined(np_ones.view(MyNdarray)).numpy())
1062    self.assertEqual(0., defined(np_zeros.view(MyNdarray)).numpy())
1063
1064    # We should not have triggered any re-tracing of the python function.
1065    self.assertLen(total_function_cache(defined), 2)
1066
1067  def testNumpyDtypeInputSupported(self):
1068    @function.defun
1069    def f(x, dtype):
1070      return constant_op.constant(dtype(x))
1071
1072    self.assertEqual(f(1, numpy.float32).numpy(), numpy.float32(1))
1073    self.assertEqual(f(2, numpy.float32).numpy(), numpy.float32(2))
1074    self.assertEqual(f(1, numpy.int32).numpy(), numpy.int32(1))
1075    self.assertEqual(f(2, numpy.int32).numpy(), numpy.int32(2))
1076
1077  def testDefunNumpyArraysConvertedToTensorsInKwargs(self):
1078
1079    def f(**kwargs):
1080      x = kwargs.pop('x')
1081      self.assertIsInstance(x, ops.Tensor)
1082      return x
1083
1084    x = random_ops.random_uniform([2, 2]).numpy()
1085    defined = function.defun(f)
1086    defined(x=x)
1087    self.assertLen(total_function_cache(defined), 1)
1088
1089    x = random_ops.random_uniform([2, 2]).numpy()
1090    defined(x=x)
1091    # A NumPy array with different values but the same shape and dtype
1092    # shouldn't trigger another function definition.
1093    self.assertLen(total_function_cache(defined), 1)
1094
1095    # Test that the numpy array is properly an argument to the graph function.
1096    self.assertEqual(1., defined(x=numpy.ones([])).numpy())
1097    self.assertEqual(0., defined(x=numpy.zeros([])).numpy())
1098    self.assertEqual(1., defined(x=array_ops.ones([])).numpy())
1099    self.assertEqual(0., defined(x=array_ops.zeros([])).numpy())
1100
1101  def testDefunCapturedInt32(self):
1102    x = constant_op.constant(1, dtype=dtypes.int32)
1103
1104    @def_function.function
1105    def add_int32s():
1106      return x + x
1107
1108    self.assertEqual(2, int(add_int32s()))
1109
1110  def testDefunReadVariable(self):
1111    v = resource_variable_ops.ResourceVariable(1.0)
1112
1113    @def_function.function
1114    def f():
1115      return v.read_value()
1116
1117    self.assertEqual(1.0, float(f()))
1118
1119  def testDefunAssignAddVariable(self):
1120    v = resource_variable_ops.ResourceVariable(1.0)
1121    x = constant_op.constant(2.0)
1122
1123    @def_function.function
1124    def test_assign_add():
1125      v.assign_add(x)
1126      return v.read_value()
1127
1128    self.assertEqual(3.0, float(test_assign_add()))
1129
1130  @test_util.run_in_graph_and_eager_modes
1131  def testTensorInitializationInFunctionRaisesError(self):
1132    error_msg = ('Tensor-typed variable initializers must either be '
1133                 'wrapped in an init_scope or callable.*')
1134
1135    @def_function.function
1136    def tensor_init():
1137      with self.assertRaisesRegex(ValueError, error_msg):
1138        resource_variable_ops.ResourceVariable(constant_op.constant(2.0))
1139
1140    tensor_init()
1141
1142  @test_util.run_in_graph_and_eager_modes
1143  def testCallableTensorInitializationInFunction(self):
1144
1145    @def_function.function
1146    def tensor_init():
1147      self.v = resource_variable_ops.ResourceVariable(
1148          lambda: constant_op.constant(2.0))
1149      return self.v.read_value()
1150
1151    value = tensor_init()
1152    if not context.executing_eagerly():
1153      self.evaluate(variables.global_variables_initializer())
1154    self.assertEqual(self.evaluate(value), 2.0)
1155
1156  @test_util.also_run_as_tf_function
1157  def testInitScopeTensorInitializationInFunction(self):
1158
1159    @def_function.function
1160    def tensor_init():
1161      with ops.init_scope():
1162        const = constant_op.constant(2.0)
1163      # Note: this variable bypasses tf.function's variable creation
1164      # requirements by bypassing variable_creator_scope by using
1165      # ResourceVariable instead of Variable.
1166      self.v = resource_variable_ops.ResourceVariable(const)
1167      return self.v.read_value()
1168
1169    value = tensor_init()
1170    self.assertAllEqual(value, 2.0)
1171
1172  @test_util.run_in_graph_and_eager_modes
1173  def testGetConcreteFunctionCreatesVariables(self):
1174
1175    v_holder = []
1176
1177    @def_function.function
1178    def tensor_init():
1179      if not v_holder:
1180        v_holder.append(variables.Variable(5.))
1181      return v_holder[0].read_value()
1182
1183    concrete = tensor_init.get_concrete_function()
1184    self.evaluate(variables.global_variables_initializer())
1185    self.assertAllEqual(5., self.evaluate(concrete()))
1186    self.assertAllEqual(5., self.evaluate(tensor_init()))
1187
1188  def testFuncGraphCaptureByValue(self):
1189    v = variables.Variable(1.0)
1190
1191    def trivial_function():
1192      return v.read_value()
1193
1194    graph_function = function.Function(
1195        trivial_function, 'test', capture_by_value=True)
1196
1197    self.assertAllEqual(graph_function(), 1.0)
1198    v.assign(2.0)
1199    self.assertAllEqual(graph_function(), 1.0)
1200
1201  def testFuncGraphCaptureByValueNested(self):
1202    v = variables.Variable(1.0)
1203
1204    def trivial_function():
1205      return control_flow_ops.cond(
1206          array_ops.placeholder_with_default(True, ()),
1207          v.read_value, v.read_value)
1208
1209    graph_function = function.Function(
1210        trivial_function, 'test', capture_by_value=True)
1211
1212    self.assertAllEqual(graph_function(), 1.0)
1213    v.assign(2.0)
1214    self.assertAllEqual(graph_function(), 1.0)
1215
1216  def testDefunShapeInferenceWithCapturedResourceVariable(self):
1217    v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])
1218
1219    def f():
1220      x = constant_op.constant([[1, 2], [3, 4]])
1221      out = math_ops.matmul(v, x)
1222      self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
1223      # We do not return v directly since the tensor conversion function of
1224      # ResourceVariable returns the read value and not the resource itself.
1225      return v._handle
1226
1227    compiled = def_function.function(f)
1228    var_handle = compiled()
1229    self.assertEqual(var_handle.dtype, dtypes.resource)
1230    self.assertEqual(var_handle.shape, tensor_shape.TensorShape([]))
1231    var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
1232    self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))
1233
1234  def testShapeInferenceForMoreSpecificInput(self):
1235
1236    def f(a):
1237      return array_ops.reshape(a, [-1, 3])
1238
1239    signature = [tensor_spec.TensorSpec(None, dtypes.float32)]
1240    compiled = def_function.function(f, input_signature=signature)
1241
1242    @def_function.function
1243    def use_f():
1244      inputs = array_ops.zeros([10, 10, 3])
1245      self.assertAllEqual(f(inputs).shape, compiled(inputs).shape)
1246
1247    use_f()
1248
1249  def testFuncListAttr(self):
1250
1251    @function.defun
1252    def test_function(val):
1253
1254      def fn1():
1255        return array_ops.ones([10])
1256
1257      fn2 = lambda: array_ops.ones([10]) * 2
1258
1259      def fn3(x=3):
1260        return array_ops.ones([10]) * x
1261      fn4 = functools.partial(fn3, x=4)
1262      fn5 = functools.partial(fn3, 5)
1263
1264      return gen_functional_ops.case(val, [], [dtypes.float32],
1265                                     [function.defun(f).get_concrete_function()
1266                                      for f in (fn1, fn2, fn3, fn4, fn5)])
1267
1268    ones = array_ops.ones([10])
1269    self.assertAllEqual([ones], test_function(0))
1270    self.assertAllEqual([ones * 2], test_function(1))
1271    self.assertAllEqual([ones * 3], test_function(2))
1272    self.assertAllEqual([ones * 4], test_function(3))
1273    self.assertAllEqual([ones * 5], test_function(4))
1274    self.assertAllEqual([ones * 5], test_function(22))  # default branch
1275
1276  @test_util.enable_control_flow_v2
1277  def testVariableInLoopInFunction(self):
1278
1279    @function.defun
1280    def test_function():
1281
1282      def loop_test(_):
1283        return False
1284
1285      def loop_body(_):
1286        return variable_scope.get_variable('a', shape=())
1287
1288      return control_flow_ops.while_loop(loop_test, loop_body, [0.0])
1289
1290    self.assertEqual(test_function().shape, [])
1291
1292  def testDefunShapeInferenceWithCapturedResourceVariableInGraphMode(self):
1293    with context.graph_mode():
1294      v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])
1295
1296      def f():
1297        x = constant_op.constant([[1, 2], [3, 4]])
1298        out = math_ops.matmul(v, x)
1299        self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
1300        # We do not return v directly since the tensor conversion function of
1301        # ResourceVariable returns the read value and not the resource itself.
1302        return v._handle
1303
1304      compiled = def_function.function(f)
1305      var_handle = compiled()
1306      self.assertEqual(var_handle.dtype, dtypes.resource)
1307      self.assertEqual(var_handle.shape, tensor_shape.TensorShape([]))
1308      var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
1309      self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))
1310
1311  def testDefunShapeInferenceWithCapturedVariableInGraphMode(self):
1312    with context.graph_mode():
1313      v = variables.Variable([[1, 2], [3, 4]])
1314
1315      def f():
1316        x = constant_op.constant([[1, 2], [3, 4]])
1317        out = math_ops.matmul(v, x)
1318        self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
1319
1320      # Check that shape inference works while creating the defun
1321      compiled = def_function.function(f)
1322      compiled()
1323
1324  def testDefunShapeInferenceWithCapturedTensorListInGraphMode(self):
1325    with context.graph_mode():
1326      tensor_list = list_ops.empty_tensor_list(
1327          element_dtype=dtypes.float32,
1328          element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
1329      tensor_list = list_ops.tensor_list_push_back(tensor_list,
1330                                                   constant_op.constant(1.0))
1331      tensor_list = list_ops.tensor_list_push_back(tensor_list,
1332                                                   constant_op.constant(2.0))
1333
1334      def f():
1335        tl, value = list_ops.tensor_list_pop_back(
1336            tensor_list, element_dtype=dtypes.float32)
1337        self.assertEqual(value.shape, tensor_shape.TensorShape([]))
1338        return tl
1339
1340      compiled = def_function.function(f)
1341      output_tensor_list = compiled()
1342      _, value = list_ops.tensor_list_pop_back(
1343          output_tensor_list, element_dtype=dtypes.float32)
1344      self.assertEqual(value.shape, tensor_shape.TensorShape([]))
1345
1346  @test_util.run_in_graph_and_eager_modes
1347  def testDefunForcesResourceVariables(self):
1348
1349    def variable_creator():
1350      self.v = variables.Variable(0.0)
1351      return self.v.read_value()
1352
1353    self.v = None
1354    defined = function.defun(variable_creator)
1355    defined()  # Create the variable.
1356    self.assertIsInstance(
1357        self.v, resource_variable_ops.ResourceVariable)
1358
1359  def testRunMetadata(self):
1360
1361    @def_function.function
1362    def f(x):
1363      return x * x
1364
1365    with ops.device('cpu:0'):
1366      context.enable_run_metadata()
1367      f(constant_op.constant(1.0))
1368    run_metadata = context.export_run_metadata()
1369    context.disable_run_metadata()
1370    self.assertLen(run_metadata.partition_graphs, 1)
1371
1372  def testGraphModeCaptureVariable(self):
1373    with context.graph_mode(), self.cached_session():
1374
1375      class HasAVar(object):
1376
1377        def __init__(self):
1378          self.v = resource_variable_ops.ResourceVariable(1.0)
1379
1380        def call(self):
1381          return self.v * 2
1382
1383      o = HasAVar()
1384      self.evaluate(variables.global_variables_initializer())
1385      call = def_function.function(o.call)
1386      op = call()
1387      self.assertAllEqual(self.evaluate(op), 2.0)
1388
1389  def testGraphModeManyFunctions(self):
1390    with ops.Graph().as_default(), self.cached_session():
1391
1392      @def_function.function
1393      def f(x):
1394        return x * x
1395
1396      @def_function.function
1397      def g(x):
1398        return f(x) + 1
1399
1400      self.assertAllEqual(g(constant_op.constant(2.0)), 5.0)
1401
1402  def testDict(self):
1403
1404    @def_function.function
1405    def f(x):
1406      return {'name': x + 1}
1407
1408    self.assertAllEqual(f(constant_op.constant(1.0))['name'], 2.0)
1409
1410  def testTensorConversionWithDefun(self):
1411
1412    @def_function.function
1413    def f(x):
1414      return math_ops.add(x, constant_op.constant(3))
1415
1416    self.assertAllEqual(5, f(constant_op.constant(2)))
1417
1418  def testTensorConversionCall(self):
1419
1420    @def_function.function
1421    def f(x):
1422      return math_ops.add(x, constant_op.constant(3))
1423
1424    @def_function.function
1425    def g(x):
1426      return f(f(x))
1427
1428    self.assertAllEqual(8, g(constant_op.constant(2)))
1429
1430  def testCallShape(self):
1431
1432    @def_function.function
1433    def f(x):
1434      return x + 1
1435
1436    @def_function.function
1437    def g(x):
1438      x = f(x)
1439      self.assertEqual(x.shape.as_list(), [])
1440      return None
1441
1442    g(constant_op.constant(1.0))
1443
1444  def testNestedDefunWithNoOutputAndTapedInput(self):
1445    three = resource_variable_ops.ResourceVariable(3.0, name='v')
1446
1447    @def_function.function
1448    def f(x):
1449      # This function intentionally takes a taped variable as input,
1450      # but does not return any values
1451      math_ops.add(x, three)
1452
1453    @def_function.function
1454    def g(x):
1455      y = math_ops.add(x, three)
1456      f(y)
1457
1458    g(three)
1459
1460  def testGatherResourceWithDefun(self):
1461    with ops.device('cpu:0'):
1462      v = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
1463
1464    def sum_gather():
1465      return math_ops.reduce_sum(array_ops.gather(v, [1, 2]))
1466
1467    defined = def_function.function(sum_gather)
1468    self.assertAllEqual(sum_gather(), defined())
1469
1470  @parameterized.named_parameters([
1471      ('IndexedSlicesWithDenseShape',
1472       _example_indexed_slices_with_dense_shape,),
1473      ('IndexedSlicesWithoutDenseShape',
1474       _example_indexed_slices_without_dense_shape,),
1475      ('RaggedTensorRaggedRank1', ragged_tensor.RaggedTensor.from_row_lengths,
1476       {'values': [1, 2, 3], 'row_lengths': [2, 0, 1]}),
1477      ('RaggedTensorRaggedRank2',
1478       ragged_tensor.RaggedTensor.from_nested_row_lengths,
1479       {'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}),
1480      ('SparseTensor', sparse_tensor.SparseTensor,
1481       {'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}),
1482  ])  # pyformat: disable
1483  def testReturnCompositeTensorWithDefun(self,
1484                                         factory_fn,
1485                                         factory_kwargs={},
1486                                         input_signature=None):
1487    input_ct = factory_fn(**factory_kwargs)
1488
1489    @def_function.function(input_signature=input_signature)
1490    def f():
1491      return input_ct
1492
1493    output_ct = f()
1494    self.assertIsInstance(output_ct, type(input_ct))
1495    nest.assert_same_structure(input_ct, output_ct, expand_composites=True)
1496
1497    input_flat = nest.flatten(input_ct, expand_composites=True)
1498    output_flat = nest.flatten(output_ct, expand_composites=True)
1499    for (input_component, output_component) in zip(input_flat, output_flat):
1500      self.assertAllEqual(input_component, output_component)
1501
1502  @parameterized.named_parameters([
1503      ('IndexedSlicesWithDenseShape',
1504       _example_indexed_slices_with_dense_shape,),
1505      ('IndexedSlicesWithoutDenseShape',
1506       _example_indexed_slices_without_dense_shape,),
1507      ('RaggedTensorRaggedRank1',
1508       ragged_tensor.RaggedTensor.from_row_lengths,
1509       {'values': [1, 2, 3], 'row_lengths': [2, 0, 1]}),
1510      ('RaggedTensorRaggedRank2',
1511       ragged_tensor.RaggedTensor.from_nested_row_lengths,
1512       {'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}),
1513      ('SparseTensor',
1514       sparse_tensor.SparseTensor,
1515       {'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}),
1516      ('RaggedTensorRaggedRank1WithSignature',
1517       ragged_tensor.RaggedTensor.from_row_lengths,
1518       {'values': [1, 2, 3], 'row_lengths': [2, 0, 1]},
1519       [ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)]),
1520      ('RaggedTensorRaggedRank2WithSignature',
1521       ragged_tensor.RaggedTensor.from_nested_row_lengths,
1522       {'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]},
1523       [ragged_tensor.RaggedTensorSpec([None, None, None], dtypes.int32)]),
1524      ('SparseTensorWithSignature',
1525       sparse_tensor.SparseTensor,
1526       {'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]},
1527       [sparse_tensor.SparseTensorSpec([None], dtypes.int32)]),
1528  ])  # pyformat: disable
1529  def testCompositeAsArgumentTensorWithDefun(self,
1530                                             factory_fn,
1531                                             factory_kwargs={},
1532                                             input_signature=None):
1533    input_ct = factory_fn(**factory_kwargs)
1534
1535    @def_function.function(input_signature=input_signature)
1536    def f(x):
1537      return x
1538
1539    output_ct = f(input_ct)
1540    self.assertIsInstance(output_ct, type(input_ct))
1541    nest.assert_same_structure(input_ct, output_ct, expand_composites=True)
1542
1543    input_flat = nest.flatten(input_ct, expand_composites=True)
1544    output_flat = nest.flatten(output_ct, expand_composites=True)
1545    for (input_component, output_component) in zip(input_flat, output_flat):
1546      self.assertAllEqual(input_component, output_component)
1547
1548  def testTracedCompositeDiscardsShapeInfo(self):
1549    # SparseTensorSpec intentionally excludes info about the number of elements
1550    # that are in a sparse tensor (which is recorded as st.indices.shape[0] and
1551    # st.values.shape[0]).  Similarly, RaggedTensorSpec intentionally excludes
1552    # info about the total number of values in a RaggedTensor (stored as
1553    # rt.values.shape[0]).  This test checks that the placeholders created by
1554    # tf.function() properly mask this shape info.
1555    @def_function.function
1556    def f(rt, st):
1557      self.assertEqual(st.indices.shape.as_list()[:1], [None])
1558      self.assertEqual(st.values.shape.as_list(), [None])
1559      return (rt, st)
1560
1561    rt = ragged_factory_ops.constant([[1, 2], [3]])
1562    st = sparse_tensor.SparseTensor([[0]], [0], [10])
1563    f(rt, st)
1564
1565  @test_util.run_gpu_only
1566  def testFunctionOnDevice(self):
1567    x = constant_op.constant([1.]).gpu()
1568    f = def_function.function(math_ops.add)
1569    y = f(x, x).cpu()
1570    self.assertAllEqual(y, [2.])
1571
1572  @test_util.run_gpu_only
1573  @test_util.run_in_graph_and_eager_modes
1574  def testFunctionWithResourcesOnDifferentDevices(self):
1575    with ops.device('/cpu:0'):
1576      v_cpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
1577
1578    with ops.device('/gpu:0'):
1579      v_gpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
1580
1581    def sum_gather():
1582      cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu, [1, 2]))
1583      gpu_result = math_ops.reduce_sum(array_ops.gather(v_gpu, [1, 2]))
1584      return cpu_result, gpu_result
1585
1586    defined = function.defun(sum_gather)
1587    if not context.executing_eagerly():
1588      self.evaluate(variables.global_variables_initializer())
1589    expected = self.evaluate(sum_gather())
1590    self.assertAllEqual(expected, self.evaluate(defined()))
1591
1592  @test_util.run_gpu_only
1593  @test_util.run_in_graph_and_eager_modes
1594  def testOpInFunctionWithConflictingResourceInputs(self):
1595    with ops.device('/cpu:0'):
1596      v_cpu = resource_variable_ops.ResourceVariable(
1597          [0.0, 1.0, 2.0], name='cpu')
1598      v_also_cpu = resource_variable_ops.ResourceVariable(
1599          [0.0, 1.0, 2.0], name='also_cpu')
1600
1601    with ops.device('/gpu:0'):
1602      v_gpu = resource_variable_ops.ResourceVariable(
1603          [0.0, 1.0, 2.0], name='gpu')
1604
1605    @def_function.function
1606    def resource_apply_adam():
1607      training_ops.resource_apply_adam(
1608          v_cpu.handle,
1609          v_gpu.handle,
1610          v_also_cpu.handle,
1611          1.0,  # beta1_power
1612          1.0,  # beta2_power
1613          1.0,  # learning_rate
1614          1.0,  # beta1
1615          1.0,  # beta2
1616          1.0,  # epsilon,
1617          [1.0, 1.0, 1.0],  # grad
1618          False)  # use_locking
1619      return None
1620
1621    with self.assertRaisesRegex(
1622        errors.InvalidArgumentError,
1623        'Cannot place the graph because a reference or resource edge connects '
1624        'colocation groups with incompatible assigned devices'):
1625      if not context.executing_eagerly():
1626        self.evaluate(variables.global_variables_initializer())
1627      self.evaluate(resource_apply_adam())
1628
1629  @test_util.run_gpu_only
1630  def testFunctionHandlesInputsOnDifferentDevices(self):
1631    # The Reshape op requires the shape tensor to be placed in host memory.
1632    reshape = def_function.function(array_ops.reshape)
1633    value = constant_op.constant([1., 2.]).gpu()
1634    shape = constant_op.constant([2, 1])
1635    reshaped = reshape(value, shape).cpu()
1636    self.assertAllEqual(reshaped, [[1], [2]])
1637
1638  @test_util.run_gpu_only
1639  def testFunctionHandlesInputsPlacedOnTheWrongDeviceGracefully(self):
1640    # The Reshape op requires the shape tensor to be placed in host memory.
1641    reshape = def_function.function(array_ops.reshape)
1642    value = constant_op.constant([1., 2.])
1643    shape = constant_op.constant([2, 1]).gpu()
1644    reshape(value, shape)  # No error is raised
1645
1646  def testNoneOutput(self):
1647
1648    @def_function.function
1649    def my_function(_):
1650      return None
1651
1652    self.assertAllEqual(my_function(1), None)
1653
1654  def testNestedFunctions(self):
1655    # TensorFlow function (which is what would be used in TensorFlow graph
1656    # construction).
1657    @tf_function.Defun(dtypes.int32, dtypes.int32)
1658    def add(a, b):
1659      return math_ops.add(a, b)
1660
1661    @def_function.function
1662    def add_one(x):
1663      return add(x, 1)
1664
1665    self.assertAllEqual(3, add_one(constant_op.constant(2)))
1666
1667  def testVariableCaptureInNestedFunctions(self):
1668    v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.int32)
1669
1670    @def_function.function
1671    def inner_read():
1672      return v.read_value()
1673
1674    @def_function.function
1675    def outer():
1676      return inner_read()
1677
1678    self.assertEqual(1, int(outer()))
1679
1680  def testReturnCapturedEagerTensor(self):
1681    t = constant_op.constant(1)
1682
1683    @def_function.function
1684    def read():
1685      return t
1686
1687    self.assertEqual(1, int(read()))
1688
1689  def testReturnCapturedGraphTensor(self):
1690    with context.graph_mode(), self.cached_session():
1691      t = constant_op.constant(1)
1692
1693      @def_function.function
1694      def read():
1695        return t
1696
1697      self.assertEqual(1, int(self.evaluate(read())))
1698
1699  def testSequenceInputs(self):
1700    clip_by_global_norm = def_function.function(clip_ops.clip_by_global_norm)
1701    t_list = [constant_op.constant(1.0), constant_op.constant(2.0)]
1702    clipped_list, global_norm = clip_by_global_norm(t_list,
1703                                                    constant_op.constant(.2))
1704    for t in clipped_list:
1705      self.assertIsInstance(t, ops.Tensor)
1706    self.assertIsInstance(global_norm, ops.Tensor)
1707
1708  def testNestedSequenceInputs(self):
1709
1710    def my_op(inputs):
1711      a, b, c = inputs
1712      e, f = b
1713      g, h = e
1714      return [a + a, [tuple([f + f, g + g]), h + h], c + c], a + f + g + h + c
1715
1716    my_eager_op = def_function.function(my_op)
1717    ret = my_eager_op([
1718        constant_op.constant(1), [(constant_op.constant(2),
1719                                   constant_op.constant(3)),
1720                                  constant_op.constant(4)],
1721        constant_op.constant(5)
1722    ])
1723    self.assertLen(ret, 2)
1724    self.assertAllEqual(ret[0][0], 2)
1725    self.assertAllEqual(ret[0][1][0][0], 8)
1726    self.assertAllEqual(ret[0][1][0][1], 4)
1727    self.assertIsInstance(ret[0][1][0], tuple)
1728    self.assertAllEqual(ret[0][1][1], 6)
1729    self.assertAllEqual(ret[0][2], 10)
1730    self.assertAllEqual(ret[1], 15)
1731
1732  def testVariableNamesRespectNameScopesWithDefun(self):
1733    @def_function.function
1734    def create_variable():
1735      with ops.name_scope('foo', skip_on_eager=False):
1736        v = resource_variable_ops.ResourceVariable(0.0, name='bar')
1737      self.assertEqual(v.name, 'foo/bar:0')
1738
1739    create_variable()
1740
1741  def testVariableNamesRespectNameScopesWithDefunInGraph(self):
1742    with context.graph_mode():
1743      @def_function.function
1744      def create_variable():
1745        with ops.name_scope('foo', skip_on_eager=False):
1746          v = resource_variable_ops.ResourceVariable([1.0, 2.0], name='bar')
1747        self.assertEqual(v.name, 'foo/bar:0')
1748
1749      with ops.get_default_graph().as_default():
1750        create_variable()
1751
1752  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
1753  def testLayerInDefun(self):
1754    conv = convolutional.Conv2D(
1755        filters=1,
1756        kernel_size=2,
1757        kernel_initializer=init_ops.ones_initializer(),
1758        bias_initializer=init_ops.zeros_initializer())
1759
1760    @function.defun
1761    def model(x):
1762      return conv(x)
1763
1764    x = array_ops.ones([1, 2, 2, 1])
1765    y = model(x)
1766
1767    if not context.executing_eagerly():
1768      self.evaluate(variables.global_variables_initializer())
1769
1770    self.assertAllClose([[[[4.0]]]], self.evaluate(y))
1771
1772  # Variable lifting is somewhat different between defun/tf.function, so testing
1773  # device placement on both makes sense.
1774  @parameterized.named_parameters(
1775      dict(testcase_name='Defun',
1776           function_decorator=function.defun),
1777      dict(testcase_name='DefFunction',
1778           function_decorator=def_function.function))
1779  @test_util.run_in_graph_and_eager_modes
1780  def testVariablesPlacedOnOutsideDevice(self, function_decorator):
1781
1782    class _Obj(object):
1783
1784      def __init__(self):
1785        self.v = None
1786
1787      @function_decorator
1788      def f(self):
1789        if self.v is None:
1790          self.v = variables.Variable(1.)
1791        return self.v + 1.
1792
1793    has_device = _Obj()
1794    with ops.device('cpu:0'):
1795      has_device.f()
1796    self.assertIn('CPU', has_device.v.device)
1797
1798  @test_util.run_in_graph_and_eager_modes
1799  def testMultipleDeviceCheck(self):
1800
1801    def f():
1802      with ops.device('cpu'):
1803        return test_ops.device_placement_op()
1804
1805    func = function.defun(f)
1806    with ops.device('cpu:0'):
1807      output = self.evaluate(func())
1808      self.assertIn(compat.as_bytes('CPU:0'), output)
1809
1810  @test_util.run_in_graph_and_eager_modes
1811  def testDeviceAnnotationsRespected(self):
1812
1813    def multi_device_fn():
1814      with ops.device('/cpu:0'):
1815        s0 = test_ops.device_placement_op()
1816      with ops.device('/cpu:1'):
1817        s1 = test_ops.device_placement_op()
1818      with ops.device('/cpu:2'):
1819        s2 = test_ops.device_placement_op()
1820      s3 = test_ops.device_placement_op()
1821      return s0, s1, s2, s3
1822
1823    defined = function.defun(multi_device_fn)
1824    outputs = self.evaluate(defined())
1825    self.assertLen(total_function_cache(defined), 1)
1826    self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
1827    self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
1828    self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
1829
1830    with ops.device('/cpu:3'):
1831      outputs = self.evaluate(defined())
1832    # All function definitions are agnostic to call site devices.
1833    self.assertLen(total_function_cache(defined), 1)
1834    self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
1835    self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
1836    self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
1837    self.assertIn(compat.as_bytes('CPU:3'), outputs[3])
1838
1839    with ops.device('/cpu:0'):
1840      outputs = self.evaluate(defined())
1841    self.assertLen(total_function_cache(defined), 1)
1842    self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
1843    self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
1844    self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
1845    self.assertIn(compat.as_bytes('CPU:0'), outputs[3])
1846
1847  @test_util.run_in_graph_and_eager_modes
1848  def testCallingGraphFunctionOnDifferentDevice(self):
1849
1850    def func():
1851      return constant_op.constant(0)
1852
1853    defined = def_function.function(func)
1854    with ops.device('cpu:0'):
1855      cpu_graph_function = defined.get_concrete_function()
1856
1857    with ops.device('cpu:0'):
1858      self.assertEqual(
1859          self.evaluate(cpu_graph_function()), self.evaluate(func()))
1860
1861    with ops.device('cpu:1'):
1862      self.assertEqual(0., self.evaluate(cpu_graph_function()))
1863
1864    with ops.device(None):
1865      self.assertEqual(0., self.evaluate(cpu_graph_function()))
1866
1867    default_graph_function = defined.get_concrete_function()
1868    self.assertEqual(
1869        self.evaluate(default_graph_function()), self.evaluate(func()))
1870
1871    with ops.device('cpu:1'):
1872      self.assertEqual(0., self.evaluate(default_graph_function()))
1873
1874  @test_util.run_gpu_only
1875  @test_util.run_in_graph_and_eager_modes
1876  def testColocateWithRespected(self):
1877    # TODO(b/113291792): Use multiple CPUs instead of a GPU.
1878    with ops.device('cpu:0'):
1879      x = array_ops.identity(1.0)
1880
1881    with ops.device('gpu:0'):
1882      y = array_ops.identity(1.0)
1883
1884    @def_function.function
1885    def foo():
1886      return test_ops.device_placement_op()
1887
1888    with ops.colocate_with(x):
1889      self.assertIn(compat.as_bytes('CPU:0'), self.evaluate(foo()))
1890
1891    with ops.colocate_with(y):
1892      self.assertIn(compat.as_bytes('GPU:0'), self.evaluate(foo()))
1893
1894  def testVariablesAreTracked(self):
1895    v = resource_variable_ops.ResourceVariable(1.0)
1896
1897    def foo(x):
1898      return v * x
1899
1900    defined = def_function.function(foo)
1901
1902    x = constant_op.constant([1.0])
1903    self.assertEqual(1., self.evaluate(defined(x)))
1904    v.assign(2.)
1905
1906    x = constant_op.constant([1.0, 2.0])
1907    self.assertAllEqual([2., 4.], self.evaluate(defined(x)))
1908
1909  def testCacheObjectHashCollisions(self):
1910
1911    class Foo(object):
1912
1913      def __hash__(self):
1914        return 42
1915
1916    def func(foo):
1917      del foo
1918      return
1919
1920    defined = function.defun(func)
1921    defined(Foo())
1922    self.assertLen(total_function_cache(defined), 1)
1923
1924    defined(Foo())
1925    self.assertLen(total_function_cache(defined), 2)
1926
1927  def testCacheTensorDtypeCollision(self):
1928
1929    def func(t):
1930      return t + t
1931
1932    defined = function.defun(func)
1933    t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
1934    defined(t)
1935    self.assertLen(total_function_cache(defined), 1)
1936
1937    t = constant_op.constant([[1.0]], dtype=dtypes.complex128)
1938    defined(t)
1939    self.assertLen(total_function_cache(defined), 2)
1940
1941  def testCacheTensorShapeCollision(self):
1942
1943    def func(t):
1944      return t + t
1945
1946    defined = function.defun(func)
1947    t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
1948    defined(t)
1949    self.assertLen(total_function_cache(defined), 1)
1950
1951    t = constant_op.constant([1.0], dtype=dtypes.complex64)
1952    defined(t)
1953    self.assertLen(total_function_cache(defined), 2)
1954
1955  def testCacheTensorShapeDtypeCollision(self):
1956
1957    def func(t):
1958      return t + t
1959
1960    defined = function.defun(func)
1961    t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
1962    defined(t)
1963    self.assertLen(total_function_cache(defined), 1)
1964
1965    t = constant_op.constant([1.0], dtype=dtypes.complex128)
1966    defined(t)
1967    self.assertLen(total_function_cache(defined), 2)
1968
1969  def testCacheTensorUnknownShapesCollisionRelaxedShapes(self):
1970
1971    def func(t):
1972      return t + t
1973
1974    with context.graph_mode(), self.cached_session():
1975      defined = function.defun(func, experimental_relax_shapes=True)
1976
1977      p = array_ops.placeholder(dtype=dtypes.float32, shape=[])
1978      defined(p)
1979      self.assertLen(total_function_cache(defined), 1)
1980
1981      p = array_ops.placeholder(dtype=dtypes.float32, shape=[1])
1982      defined(p)
1983      self.assertLen(total_function_cache(defined), 2)
1984
1985      p = array_ops.placeholder(dtype=dtypes.float32, shape=[2])
1986      defined(p)
1987      # Gradual shape relaxation is performed; and the common shape between
1988      # [1] and [2] is one containing unknown dimensions.
1989      self.assertLen(total_function_cache(defined), 2)
1990
1991      # pylint: disable=protected-access
1992      self.assertLen(defined._function_cache.arg_relaxed_specs, 1)
1993      relaxed_specs = (
1994          list(defined._function_cache.arg_relaxed_specs.values())[0])
1995      self.assertLen(relaxed_specs, 1)
1996      relaxed_shape = relaxed_specs[0].shape
1997      # pylint: enable=protected-access
1998      self.assertEqual(relaxed_shape.rank, 1)
1999      self.assertEqual(tensor_shape.dimension_value(relaxed_shape[0]), None)
2000
2001      t = constant_op.constant([1.0, 1.0, 1.0], dtype=dtypes.float32)
2002      defined(t)
2003      # Shape (3,) matches the relaxed shape TensorShape([None])
2004      self.assertLen(total_function_cache(defined), 2)
2005
2006  def testPythonFunctionWithDefaultArgs(self):
2007
2008    def func(foo, bar=1, baz=2):
2009      del foo
2010      del bar
2011      del baz
2012      return
2013
2014    defined = function.defun(func)
2015    defined(0, baz=20)
2016
2017    def cache_keys():
2018      """Sanitizes cache keys of non-input metadata."""
2019      return tuple(key[0] for key in total_function_cache(defined))
2020
2021    # `True` corresponds to the fact that we're executing eagerly
2022    self.assertIn(('URRRu', (0, 1, 20)), cache_keys())
2023
2024    defined(1)  # bar=1, baz=2
2025    self.assertIn(('URRRu', (1, 1, 2)), cache_keys())
2026
2027    # This matches the previous call.
2028    defined(foo=1)
2029    self.assertLen(total_function_cache(defined), 2)
2030
2031    defined(1, 2, 3)
2032    self.assertLen(total_function_cache(defined), 3)
2033    self.assertIn(('URRRu', (1, 2, 3)), cache_keys())
2034
2035    # This matches the previous call.
2036    defined(1, bar=2, baz=3)
2037    self.assertLen(total_function_cache(defined), 3)
2038
2039    # This matches the previous call.
2040    defined(1, baz=3, bar=2)
2041    self.assertLen(total_function_cache(defined), 3)
2042
2043  def testFunctoolsPartialUnwrappedCorrectly(self):
2044
2045    def full_function(a, b, c=3):
2046      return a, b, c
2047
2048    partial = functools.partial(full_function, 1, c=4)
2049    a, b, c = partial(2)
2050
2051    defined = function.defun(partial)
2052    func_a, func_b, func_c = defined(2)
2053    self.assertEqual(func_a.numpy(), a)
2054    self.assertEqual(func_b.numpy(), b)
2055    self.assertEqual(func_c.numpy(), c)
2056
2057  def testInputSignatureWithMatchingInputs(self):
2058
2059    def foo(a):
2060      self.assertEqual(a.shape, (2,))
2061      return a
2062
2063    signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)]
2064    defined = function.defun(foo, input_signature=signature)
2065    a = array_ops.ones([2])
2066    self.assertAllEqual(a, defined(a))
2067    self.assertLen(total_function_cache(defined), 1)
2068    self.assertAllEqual(a, defined.get_concrete_function()(a))
2069    self.assertAllEqual(a, defined.get_concrete_function(a)(a))
2070    self.assertAllEqual(a, defined.get_concrete_function(
2071        tensor_spec.TensorSpec((2,), dtype=dtypes.float32))(a))
2072    self.assertLen(total_function_cache(defined), 1)
2073
2074    def bar(a):
2075      self.assertEqual(a._shape_tuple(), (2, None))
2076      return a
2077
2078    signature = [tensor_spec.TensorSpec((2, None), dtypes.float32)]
2079    defined = function.defun(bar, input_signature=signature)
2080    a = array_ops.ones([2, 1])
2081    out = defined(a)
2082    self.assertLen(total_function_cache(defined), 1)
2083    self.assertAllEqual(out, a)
2084
2085    # Changing the second dimension shouldn't create a new function.
2086    b = array_ops.ones([2, 3])
2087    out = defined(b)
2088    self.assertLen(total_function_cache(defined), 1)
2089    self.assertAllEqual(out, b)
2090
2091  def testInputSignatureWithCompatibleInputs(self):
2092
2093    rank2_spec = tensor_spec.TensorSpec(shape=(None, None),
2094                                        dtype=dtypes.float32)
2095
2096    @function.defun(input_signature=[rank2_spec])
2097    def func(a):
2098      self.assertEqual([None, None], a.shape.as_list())
2099      return array_ops.shape(a)
2100
2101    self.assertAllEqual([3, 1], func([[0], [1.0], [1]]))
2102    self.assertAllEqual([2, 2], func(numpy.array([[1, 1], [2, 2]])))
2103
2104    with self.assertRaisesRegex(ValueError, 'incompatible'):
2105      func([0.0, 1.0, 2.0])  # Wrong shape.
2106
2107    with self.assertRaisesRegex(ValueError, 'incompatible'):
2108      func([['wrong dtype']])
2109
2110  def testNoKeywordOnlyArgumentsWithInputSignature(self):
2111    if sys.version_info[0] < 3:
2112      self.skipTest('keyword_only arguments only exist in Python 3.')
2113
2114    func = eval('lambda x, *, y: x')  # pylint: disable=eval-used
2115    signature = [tensor_spec.TensorSpec(None, dtypes.int32)]
2116    with self.assertRaisesRegex(
2117        ValueError, 'Cannot define a TensorFlow function from a Python '
2118        'function with keyword-only arguments when input_signature is '
2119        'provided.'):
2120      def_function.function(func, signature)
2121
2122  def testNestedInputSignatures(self):
2123
2124    def expected_foo(a, b):
2125      return [a, b]
2126
2127    @function.defun(input_signature=[
2128        [tensor_spec.TensorSpec((2, None), dtypes.float32)] * 2,
2129        tensor_spec.TensorSpec((1,), dtypes.float32),
2130    ])
2131    def foo(a, b):
2132      self.assertEqual(a[0]._shape_tuple(), (2, None))
2133      self.assertEqual(a[1]._shape_tuple(), (2, None))
2134      self.assertEqual(b._shape_tuple(), (1,))
2135      return [a, b]
2136
2137    a = array_ops.ones([2, 1])
2138    b = array_ops.ones([1])
2139    expected = expected_foo([a, a], b)
2140    out = foo([a, a], b)
2141    self.assertLen(total_function_cache(foo), 1)
2142    nest.assert_same_structure(out, expected)
2143    self.assertAllEqual(out[0][0], a)
2144    self.assertAllEqual(out[0][1], a)
2145    self.assertAllEqual(out[1], b)
2146
2147    # Changing the unspecified dimensions shouldn't create a new function.
2148    a = array_ops.ones([2, 3])
2149    b = array_ops.ones([2, 5])
2150    c = array_ops.ones([1])
2151    expected = expected_foo([a, b], c)
2152    out = foo([a, b], c)
2153    self.assertLen(total_function_cache(foo), 1)
2154    nest.assert_same_structure(out, expected)
2155    self.assertAllEqual(out[0][0], a)
2156    self.assertAllEqual(out[0][1], b)
2157    self.assertAllEqual(out[1], c)
2158
2159    # Passing compatible inputs should work.
2160    a = a.numpy().tolist()
2161    b = b.numpy().tolist()
2162    c = c.numpy().tolist()
2163    out = foo([a, b], c)
2164    self.assertLen(total_function_cache(foo), 1)
2165    nest.assert_same_structure(out, expected)
2166    self.assertAllEqual(out[0][0], a)
2167    self.assertAllEqual(out[0][1], b)
2168    self.assertAllEqual(out[1], c)
2169
2170  def testNestedInputSignaturesWithDict(self):
2171    def expected_bar(a):
2172      return a
2173
2174    @function.defun(input_signature=[{
2175        'a': tensor_spec.TensorSpec((2, None), dtypes.float32),
2176        'b': tensor_spec.TensorSpec((2, None), dtypes.float32),
2177        'c': tensor_spec.TensorSpec((1,), dtypes.float32)}])
2178    def bar(a):
2179      self.assertEqual(a['a']._shape_tuple(), (2, None))
2180      self.assertEqual(a['b']._shape_tuple(), (2, None))
2181      self.assertEqual(a['c']._shape_tuple(), (1,))
2182      return a
2183
2184    a = array_ops.ones([2, 3])
2185    b = array_ops.ones([1])
2186    inputs = {'a': a, 'b': a, 'c': b}
2187    expected = expected_bar(inputs)
2188    out = bar(inputs)
2189    nest.assert_same_structure(out, expected)
2190    self.assertAllEqual(out['a'], expected['a'])
2191    self.assertAllEqual(out['b'], expected['b'])
2192    self.assertAllEqual(out['c'], expected['c'])
2193
2194    # Passing compatible inputs should work.
2195    a = a.numpy().tolist()
2196    b = b.numpy().tolist()
2197    inputs = {'a': a, 'b': a, 'c': b}
2198    out = bar(inputs)
2199    nest.assert_same_structure(out, expected)
2200    self.assertAllEqual(out['a'], expected['a'])
2201    self.assertAllEqual(out['b'], expected['b'])
2202    self.assertAllEqual(out['c'], expected['c'])
2203
2204  def testInputSignatureMustBeSequenceOfTensorSpecs(self):
2205
2206    def foo(a, b):
2207      del a
2208      del b
2209
2210    # Signatures must consist exclusively of `TensorSpec` objects.
2211    signature = [(2, 3), tensor_spec.TensorSpec([2, 3], dtypes.float32)]
2212    with self.assertRaisesRegex(TypeError, 'Invalid input_signature.*'):
2213      def_function.function(foo, input_signature=signature)
2214
2215    # Signatures must be either lists or tuples on their outermost levels.
2216    signature = {'t1': tensor_spec.TensorSpec([], dtypes.float32)}
2217    with self.assertRaisesRegex(
2218        TypeError, 'input_signature must be either a '
2219        'tuple or a list.*'):
2220      function.defun(foo, input_signature=signature)
2221
2222  @test_util.run_in_graph_and_eager_modes
2223  def testInputsIncompatibleWithSignatureRaisesError(self):
2224
2225    def foo(a):
2226      return a
2227
2228    signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)]
2229    defined = def_function.function(foo, input_signature=signature)
2230
2231    # Invalid shapes.
2232    with self.assertRaisesRegex(ValueError, 'Python inputs incompatible.*'):
2233      defined(array_ops.ones([3]))
2234
2235    with self.assertRaisesRegex(ValueError, 'Python inputs incompatible.*'):
2236      defined(array_ops.ones([2, 1]))
2237
2238    # Wrong number of arguments.
2239    with self.assertRaisesRegex(
2240        TypeError, r'takes 1 positional arguments \(as specified by the '
2241        r'input_signature\) but 2 were given'):
2242      defined(array_ops.ones([2]), array_ops.ones([2]))
2243    with self.assertRaisesRegex(ValueError,
2244                                'Structure of Python function inputs.*'):
2245      defined()
2246
2247    with self.assertRaisesRegex(ValueError,
2248                                'inputs incompatible with input_signature'):
2249      defined.get_concrete_function(
2250          tensor_spec.TensorSpec(shape=(3,), dtype=dtypes.float32))
2251
2252  def testInputsIncompatibleWithNestedSignatureRaisesError(self):
2253
2254    def foo(a, b):
2255      return [a, b]
2256
2257    signature = [[tensor_spec.TensorSpec((1,), dtypes.float32)] * 2,
2258                 [tensor_spec.TensorSpec((1,), dtypes.float32)] * 2]
2259    defined = function.defun(foo, input_signature=signature)
2260    a = array_ops.ones([1])
2261
2262    with self.assertRaisesRegex(ValueError,
2263                                'Structure of Python function inputs.*'):
2264      defined([a, a, a], [a])
2265
2266    with self.assertRaisesRegex(ValueError,
2267                                'Structure of Python function inputs.*'):
2268      defined([a], [a, a, a])
2269    defined([a, a], [a, a])
2270
2271  def testUnderspecifiedInputSignature(self):
2272    @function.defun(input_signature=[
2273        tensor_spec.TensorSpec([], dtypes.float32),
2274    ])
2275    def foo(a, training=True):
2276      if training:
2277        return a
2278      else:
2279        return -1.0 * a
2280
2281    x = constant_op.constant(1.0)
2282    with self.assertRaisesRegex(
2283        TypeError, 'got keyword argument `training` '
2284        'that was not included in input_signature'):
2285      foo(x, training=True)
2286
2287    with self.assertRaisesRegex(
2288        TypeError, 'got keyword argument `training` '
2289        'that was not included in input_signature'):
2290      foo(x, training=False)
2291
2292    self.assertAllEqual(x.numpy(), foo(x).numpy())
2293
2294  def testInputSignatureWithPartialFunction(self):
2295    def full_function(a, b, c=3.0):
2296      return a, b, c
2297
2298    partial = functools.partial(full_function, 1, c=4)
2299    a, b, c = partial(2.0)
2300    signature = [tensor_spec.TensorSpec([], dtypes.float32)]
2301    defined = function.defun(partial, input_signature=signature)
2302    x = constant_op.constant(2.0)
2303    func_a, func_b, func_c = defined(x)
2304    self.assertEqual(func_a.numpy(), a)
2305    self.assertEqual(func_b.numpy(), b)
2306    self.assertEqual(func_c.numpy(), c)
2307
2308  def testInputSignatureConversionWithDefaultArg(self):
2309
2310    def foo(a, training=True):
2311      if training:
2312        return a
2313      else:
2314        return -1.0 * a
2315
2316    signature = [
2317        tensor_spec.TensorSpec([], dtypes.float32),
2318        tensor_spec.TensorSpec([], dtypes.bool),
2319    ]
2320    defined = def_function.function(foo, input_signature=signature)
2321    a = constant_op.constant(1.0)
2322    self.assertAllEqual(a.numpy(), defined(a))
2323    self.assertAllEqual(a.numpy(), defined(a, training=True))
2324    self.assertAllEqual(-a.numpy(), defined(a, training=False))
2325
2326  def testInputSignatureWithKeywordPositionalArgs(self):
2327
2328    @function.defun(input_signature=[
2329        tensor_spec.TensorSpec([], dtypes.float32),
2330        tensor_spec.TensorSpec([], dtypes.int64)
2331    ])
2332    def foo(flt, integer):
2333      return flt, integer
2334
2335    flt = constant_op.constant(1.0)
2336    integer = constant_op.constant(2, dtypes.int64)
2337
2338    out1, out2 = foo(flt, integer)
2339    self.assertLen(total_function_cache(foo), 1)
2340    self.assertEqual(out1.numpy(), 1.0)
2341    self.assertEqual(out2.numpy(), 2)
2342
2343    out1, out2 = foo(flt=flt, integer=integer)
2344    self.assertLen(total_function_cache(foo), 1)
2345    self.assertEqual(out1.numpy(), 1.0)
2346    self.assertEqual(out2.numpy(), 2)
2347
2348    out1, out2 = foo(integer=integer, flt=flt)
2349    self.assertLen(total_function_cache(foo), 1)
2350    self.assertEqual(out1.numpy(), 1.0)
2351    self.assertEqual(out2.numpy(), 2)
2352
2353    out1, out2 = foo(flt, integer=integer)
2354    self.assertLen(total_function_cache(foo), 1)
2355    self.assertEqual(out1.numpy(), 1.0)
2356    self.assertEqual(out2.numpy(), 2)
2357
2358  def testInputSignatureWithKeywordArgs(self):
2359    def foo(a, b, **kwargs):
2360      del kwargs
2361      return a, b
2362
2363    x = function.defun(
2364        foo,
2365        input_signature=[
2366            tensor_spec.TensorSpec([], dtypes.float32),
2367            tensor_spec.TensorSpec([], dtypes.int32)
2368        ]).get_concrete_function()
2369    result = x(constant_op.constant(5.0), constant_op.constant(5))
2370    self.assertAllEqual(result, [5.0, 5])
2371
2372  def testInputSignatureWithCompositeTensors(self):
2373    def f(rt):
2374      self.assertEqual(rt.values.shape.as_list(), [None])
2375      self.assertEqual(rt.row_splits.shape.as_list(), [4])
2376      return rt
2377
2378    signature = [ragged_tensor.RaggedTensorSpec(
2379        shape=[3, None], dtype=dtypes.int32)]
2380    defined = function.defun(f, input_signature=signature)
2381    rt1 = ragged_factory_ops.constant([[1], [], [2, 3, 4]])
2382    out1 = defined(rt1)
2383    self.assertLen(total_function_cache(defined), 1)
2384    self.assertAllEqual(out1.values, rt1.values)
2385    self.assertAllEqual(out1.row_splits, rt1.row_splits)
2386
2387    # Changing the row lengths shouldn't create a new function.
2388    rt2 = ragged_factory_ops.constant([[1, 2], [3, 4], [5]])
2389    out2 = defined(rt2)
2390    self.assertLen(total_function_cache(defined), 1)
2391    self.assertAllEqual(out2.values, rt2.values)
2392    self.assertAllEqual(out2.row_splits, rt2.row_splits)
2393
2394    # Different number of rows
2395    rt3 = ragged_factory_ops.constant([[1, 2], [3, 4], [5], [6]])
2396    with self.assertRaisesRegex(ValueError, 'incompatible'):
2397      defined(rt3)
2398
2399    # Different dtype
2400    rt4 = ragged_factory_ops.constant([[1.0, 2.0], [], [3.0]])
2401    with self.assertRaisesRegex(ValueError, 'Structure .* does not match'):
2402      defined(rt4)
2403
2404    # Different rank
2405    rt5 = ragged_factory_ops.constant([[[1]], [[2]], [[3]]])
2406    with self.assertRaisesRegex(ValueError, 'does not match'):
2407      defined(rt5)
2408
2409  def testInputSignatureWithVariableArgs(self):
2410
2411    def f(v):
2412      v.assign_add(1)
2413
2414    signature = [
2415        resource_variable_ops.VariableSpec(shape=[], dtype=dtypes.int32)
2416    ]
2417    defined = function.defun(f, input_signature=signature)
2418
2419    v1 = variables.Variable(0)
2420    v2 = variables.Variable(0)
2421
2422    defined(v1)
2423    self.assertEqual(v1.numpy(), 1)
2424    self.assertEqual(v2.numpy(), 0)
2425
2426    defined(v=v2)
2427    self.assertEqual(v1.numpy(), 1)
2428    self.assertEqual(v2.numpy(), 1)
2429
2430  def testTensorKeywordArguments(self):
2431
2432    def foo(a, b):
2433      del a
2434      return b
2435
2436    defined = function.defun(foo)
2437    a = constant_op.constant(2.0)
2438    b = constant_op.constant([1.0, 2.0])
2439    one = defined(a, b)
2440    self.assertLen(total_function_cache(defined), 1)
2441
2442    two = defined(a=a, b=b)
2443    self.assertLen(total_function_cache(defined), 1)
2444
2445    three = defined(b=b, a=a)
2446    self.assertLen(total_function_cache(defined), 1)
2447
2448    four = defined(a, b=b)
2449    self.assertLen(total_function_cache(defined), 1)
2450
2451    # The next call corresponds to a new input signature, hence
2452    # we expect another function to be defined.
2453    five = defined(b, a)
2454    self.assertLen(total_function_cache(defined), 2)
2455
2456    six = defined(a=b, b=a)
2457    self.assertLen(total_function_cache(defined), 2)
2458
2459    seven = defined(b=a, a=b)
2460    self.assertLen(total_function_cache(defined), 2)
2461
2462    self.assertAllEqual(one, [1.0, 2.0])
2463    self.assertAllEqual(two, [1.0, 2.0])
2464    self.assertAllEqual(three, [1.0, 2.0])
2465    self.assertAllEqual(four, [1.0, 2.0])
2466    self.assertAllEqual(five, 2.0)
2467    self.assertAllEqual(six, 2.0)
2468    self.assertAllEqual(seven, 2.0)
2469
2470  def testDefuningInstanceMethod(self):
2471
2472    integer = constant_op.constant(2, dtypes.int64)
2473
2474    class Foo(object):
2475
2476      def one(self, tensor):
2477        return tensor
2478
2479      @def_function.function
2480      def two(self, tensor, other=integer):
2481        return self.one(tensor), other
2482
2483    foo = Foo()
2484    t = constant_op.constant(1.0)
2485    one, two = foo.two(t)
2486    self.assertEqual(one.numpy(), 1.0)
2487    self.assertEqual(two.numpy(), 2)
2488
2489  def testDefuningInstanceMethodWithDefaultArgument(self):
2490
2491    integer = constant_op.constant(2, dtypes.int64)
2492
2493    class Foo(object):
2494
2495      @def_function.function
2496      def func(self, other=integer):
2497        return other
2498
2499    foo = Foo()
2500    self.assertEqual(foo.func().numpy(), int(integer))
2501
2502  def testPythonCallWithSideEffects(self):
2503    state = []
2504
2505    @def_function.function
2506    def side_effecting_function():
2507      state.append(0)
2508
2509    side_effecting_function()
2510    self.assertAllEqual(state, [0])
2511
2512    # The second invocation should call the graph function, which shouldn't
2513    # trigger the list append.
2514    side_effecting_function()
2515    self.assertAllEqual(state, [0])
2516
2517    # Whereas calling the python function directly should create a side-effect.
2518    side_effecting_function.python_function()
2519    self.assertAllEqual(state, [0, 0])
2520
2521  def testFunctionWithNestedFunctionCallAndSideEffects(self):
2522    v1 = variables.Variable(1.0)
2523    v2 = variables.Variable(1.0)
2524
2525    @def_function.function
2526    def add_one(a):
2527      a.assign_add(1.0)
2528
2529    # Grappler will inline calls to `add_one` into the function body, we check
2530    # that all side-effects were executed.
2531    @def_function.function
2532    def side_effecting_function(a, b):
2533      add_one(a)
2534      add_one(b)
2535      return a + b
2536
2537    result = side_effecting_function(v1, v2)
2538    self.assertEqual(result.numpy(), 4.0)
2539
2540  def testFunctionWithExtraAttributes(self):
2541    @function.defun_with_attributes(attributes={'experimental_1': 'value1',
2542                                                'experimental_2': 2})
2543    def matmul(x, y):
2544      return math_ops.matmul(x, y)
2545
2546    def add(x, y):
2547      return math_ops.add(x, y)
2548    defun_add = function.defun_with_attributes(
2549        add, attributes={'experimental_3': True, 'experimental_4': 1.0})
2550
2551    with context.graph_mode(), self.cached_session():
2552      with ops.get_default_graph().as_default():
2553        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
2554        sq = matmul(t, t)
2555        double = defun_add(t, t)
2556        self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22])
2557        self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8])
2558
2559        graph = ops.get_default_graph()
2560        # pylint: disable=protected-access
2561        self.assertLen(graph._functions, 2)
2562        functions = list(graph._functions.values())
2563        self.assertRegex(functions[0].definition.signature.name, '.*matmul.*')
2564        attrs = functions[0].definition.attr
2565        self.assertLen(attrs, 2)
2566        self.assertEqual(attrs['experimental_1'].s, b'value1')
2567        self.assertEqual(attrs['experimental_2'].i, 2)
2568
2569        self.assertRegex(functions[1].definition.signature.name, '.*add.*')
2570        attrs = functions[1].definition.attr
2571        self.assertLen(attrs, 2)
2572        self.assertEqual(attrs['experimental_3'].b, True)
2573        self.assertEqual(attrs['experimental_4'].f, 1.0)
2574        # pylint: enable=protected-access
2575
2576  def testFunctionWithInvalidAttribute(self):
2577    @function.defun_with_attributes(attributes={'experimental_1': ['value1']})
2578    def add(x, y):
2579      return math_ops.add(x, y)
2580
2581    with self.assertRaisesRegex(ValueError, '.*Unsupported attribute type.*'):
2582      with context.graph_mode(), self.cached_session():
2583        with ops.get_default_graph().as_default():
2584          t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
2585          add(t, t)
2586
2587  def testRegisterFunction(self):
2588
2589    @function.defun
2590    def add(x, y):
2591      return math_ops.add(x, y)
2592
2593    def matmul(x, y):
2594      return math_ops.matmul(x, y)
2595    defun_matmul = function.defun(matmul)
2596
2597    with context.graph_mode(), self.cached_session():
2598      with ops.get_default_graph().as_default():
2599        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
2600        function.register(defun_matmul, t, t)
2601        function.register(add, t, t)
2602
2603        graph = ops.get_default_graph()
2604        # pylint: disable=protected-access
2605        self.assertLen(graph._functions, 6)
2606        # two sets of functions, each of them are (inference, forward, backward)
2607        functions = list(graph._functions.values())
2608        captured_function_names = [
2609            f.definition.signature.name for f in functions
2610        ]
2611        expected_func_name_regex = [
2612            '.*inference.*matmul.*',
2613            '.*forward.*matmul.*',
2614            '.*inference.*backward.*matmul.*',
2615            '.*inference.*add.*',
2616            '.*forward.*add.*',
2617            '.*inference.*backward.*add.*',
2618        ]
2619        for i in range(len(functions)):
2620          self.assertRegex(captured_function_names[i],
2621                           expected_func_name_regex[i])
2622
2623        # Check the forward and backward function has the correct attributes.
2624        self.assertEqual(
2625            functions[1].definition.attr['backward_function_name'].s,
2626            functions[2].name)
2627        self.assertEqual(
2628            functions[2].definition.attr['forward_function_name'].s,
2629            functions[1].name)
2630
2631        self.assertEqual(
2632            functions[4].definition.attr['backward_function_name'].s,
2633            functions[5].name)
2634        self.assertEqual(
2635            functions[5].definition.attr['forward_function_name'].s,
2636            functions[4].name)
2637
2638        sq = defun_matmul(t, t)
2639        double = add(t, t)
2640        self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22])
2641        self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8])
2642        # Make sure the pre registered function is used, and no other function
2643        # is added.
2644        self.assertLen(graph._functions, 6)
2645        functions = list(graph._functions.values())
2646        for i in range(len(functions)):
2647          self.assertEqual(captured_function_names[i],
2648                           functions[i].definition.signature.name)
2649
2650  @parameterized.named_parameters(
2651      dict(testcase_name='Defun',
2652           function_decorator=function.defun),
2653      dict(testcase_name='DefFunction',
2654           function_decorator=def_function.function))
2655  def testRegisterConcreteFunction(self, function_decorator):
2656    @function_decorator
2657    def py_add(x, y):
2658      return math_ops.add(x, y)
2659
2660    py_add(array_ops.ones([]), array_ops.ones([]))
2661    add = py_add.get_concrete_function(
2662        tensor_spec.TensorSpec(None, dtypes.float32),
2663        tensor_spec.TensorSpec(None, dtypes.float32))
2664
2665    @function_decorator
2666    def py_composite(x, y):
2667      return x, add(x, y)
2668
2669    py_composite(array_ops.ones([]), array_ops.ones([]))
2670    composite = py_composite.get_concrete_function(
2671        tensor_spec.TensorSpec(None, dtypes.float32),
2672        tensor_spec.TensorSpec(None, dtypes.float32))
2673
2674    with context.graph_mode(), self.cached_session():
2675      with ops.get_default_graph().as_default():
2676        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
2677        composite.add_to_graph()
2678        composite.add_gradient_functions_to_graph()
2679
2680        graph = ops.get_default_graph()
2681        # pylint: disable=protected-access
2682        self.assertLen(graph._functions, 6)
2683        # two sets of functions, each of them are (inference, forward, backward)
2684        functions = list(graph._functions.values())
2685        captured_function_names = [
2686            f.definition.signature.name for f in functions
2687        ]
2688        expected_func_name_regex = [
2689            '.*inference.*py_composite.*',
2690            '.*inference.*py_add.*',
2691            '.*forward.*py_composite.*',
2692            '.*forward.*py_add.*',
2693            '.*inference.*backward.*py_composite.*',
2694            '.*inference.*backward.*py_add.*',
2695        ]
2696        for expected, found in zip(
2697            expected_func_name_regex,
2698            captured_function_names):
2699          self.assertRegex(found, expected)
2700
2701        composite_t, composite_double = composite(t, t)
2702        double = add(t, t)
2703        self.assertAllEqual([[2, 4], [6, 8]], self.evaluate(double))
2704        self.assertAllEqual([[2, 4], [6, 8]], self.evaluate(composite_double))
2705        self.assertAllEqual([[1, 2], [3, 4]], self.evaluate(composite_t))
2706        # Make sure the pre registered function is used, and no other function
2707        # is added.
2708        self.assertLen(graph._functions, 6)
2709
2710  @parameterized.named_parameters(
2711      dict(testcase_name='Defun',
2712           function_decorator=function.defun),
2713      dict(testcase_name='DefFunction',
2714           function_decorator=def_function.function))
2715  def testEagerCaptures(self, function_decorator):
2716    with context.eager_mode():
2717      large_tensor = array_ops.ones(shape=(256,))
2718      self.assertGreater(256, func_graph._EAGER_CONST_THRESHOLD)
2719
2720      small_tensor = array_ops.ones(shape=(4,))
2721      self.assertLessEqual(4, func_graph._EAGER_CONST_THRESHOLD)
2722
2723      v = resource_variable_ops.ResourceVariable(0.0)
2724
2725    for captured, op_type in [(large_tensor, 'Placeholder'),
2726                              (small_tensor, 'Const'), (v, 'Placeholder')]:
2727      @function_decorator
2728      def test_fn():
2729        return captured + 1  # pylint: disable=cell-var-from-loop
2730
2731      g = test_fn.get_concrete_function().graph
2732      internal_captures = g.internal_captures
2733      self.assertLen(internal_captures, 1)
2734      self.assertEqual(internal_captures[0].op.type, op_type)
2735
2736  def testRegisterFunctionWithInputSignature(self):
2737    def matmul(x, y):
2738      return math_ops.matmul(x, y)
2739    defun_matmul = function.defun(
2740        matmul,
2741        input_signature=[
2742            tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32),
2743            tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32)
2744        ])
2745    with context.graph_mode(), self.cached_session():
2746      with ops.get_default_graph().as_default():
2747        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
2748        function.register(defun_matmul, t, t)
2749
2750        graph = ops.get_default_graph()
2751        # pylint: disable=protected-access
2752        self.assertLen(graph._functions, 3)
2753
2754        # Test register function with cache, note inputs are ignored.
2755        function.register(defun_matmul)
2756        graph = ops.get_default_graph()
2757        self.assertLen(graph._functions, 3)
2758
2759  def testRegisterFunctionWithCache(self):
2760    def matmul(x, y):
2761      return math_ops.matmul(x, y)
2762    defun_matmul = function.defun(matmul)
2763
2764    with context.graph_mode(), self.cached_session():
2765      with ops.get_default_graph().as_default():
2766        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
2767        t2 = constant_op.constant([[2.0, 3.0], [4.0, 5.0]])
2768        function.register(defun_matmul, t, t)
2769        function.register(defun_matmul, t2, t2)
2770
2771        graph = ops.get_default_graph()
2772        # Only one function is registered since the input param are in same type
2773        # pylint: disable=protected-access
2774        self.assertLen(graph._functions, 3)
2775
2776  def testCallingFunctionWithDifferentVariables(self):
2777
2778    @function.defun
2779    def foo(v):
2780      v.assign_add(1.0)
2781      return v.read_value()
2782
2783    v = resource_variable_ops.ResourceVariable(0.0)
2784    graph_function = foo.get_concrete_function(v)
2785    self.assertLen(graph_function.inputs, 1)
2786    self.assertEmpty(graph_function.captured_inputs)
2787
2788    self.assertEqual(float(graph_function(v)), 1.0)
2789    self.assertEqual(float(graph_function(v)), 2.0)
2790
2791    w = resource_variable_ops.ResourceVariable(0.0)
2792
2793    @function.defun
2794    def bar(v):
2795      del v
2796      return constant_op.constant(1.0)
2797
2798    graph_function = bar.get_concrete_function(v)
2799    self.assertEqual(float(graph_function(v)), 1.0)
2800    self.assertEqual(float(graph_function(w)), 1.0)
2801
2802  def testCallingFunctionWithNonTensorsFails(self):
2803
2804    @function.defun
2805    def foo(x):
2806      return x
2807
2808    graph_function = foo.get_concrete_function(constant_op.constant(1.0))
2809    with self.assertRaises((TypeError, ValueError)):
2810      graph_function('Not a Tensor.')
2811
  def testSwapImplementationWithGrapplerPlugin(self):
    """Grappler's implementation selector chooses between registered variants.

    `cpu_boost` and `gpu_boost` share the `api_implements` value
    'random_boost'. `cpu_boost` is only registered in the graph, and only
    `gpu_boost` is called. The resulting value shows which body actually ran:
    1.0 + 2.0 for the CPU variant, or 1.0 + 4.0 for the GPU variant.
    """
    # Set min_graph_nodes to -1 because the graph in this test is too small;
    # without this setting Grappler would ignore it.
    rewrites = rewriter_config_pb2.RewriterConfig()
    rewrites.implementation_selector = rewriter_config_pb2.RewriterConfig.ON
    rewrites.min_graph_nodes = -1
    graph_options = config_pb2.GraphOptions(
        rewrite_options=rewrites, build_cost_model=1)
    config_proto = config_pb2.ConfigProto(graph_options=graph_options)

    with context.graph_mode(), self.cached_session(
        config=config_proto, graph=ops.Graph(), use_gpu=True):

      # CPU variant: adds 2.0.
      @function.defun_with_attributes(
          attributes={
              'api_implements': 'random_boost',
              'api_preferred_device': 'CPU'
          })
      def cpu_boost(x):
        return math_ops.add(x, 2.0)

      # GPU variant: adds 4.0.
      @function.defun_with_attributes(
          attributes={
              'api_implements': 'random_boost',
              'api_preferred_device': 'GPU'
          })
      def gpu_boost(x):
        return math_ops.add(x, 4.0)

      x = constant_op.constant(1.0)

      # Register the CPU variant without calling it; presumably this lets the
      # selector find it as an alternative to `gpu_boost` — TODO confirm.
      function.register(cpu_boost, x)
      y = gpu_boost(x)
      y_value = self.evaluate(y)

      if test.is_gpu_available():
        # With a GPU, the called GPU variant runs: 1.0 + 4.0.
        self.assertEqual(y_value, 5.0)
      else:
        # Without a GPU, Grappler falls back to the CPU implementation even
        # though the GPU function was called: 1.0 + 2.0.
        self.assertEqual(y_value, 3.0)
2852
  @test_util.disable_tfrt('b/174712583: TFRT doesn\'t support behavior '
                          'equivalent to implementation_selector for function')
  def testSwapImplementationInEager(self):
    """Under a CPU device scope, the selector runs the CPU implementation.

    `on_cpu` and `on_gpu` share the `api_implements` value 'foo'. `run_on_cpu`
    registers `on_cpu`, then calls `on_gpu` inside `ops.device('CPU:0')`. A
    result of 3 (1 + 2) shows that the `on_cpu` body ran instead of `on_gpu`
    (which would give 1 + 4).
    """
    if not context.executing_eagerly():
      self.skipTest('eager only')

    # testSharedRendezvous sets disable_meta_optimizer to True. If that test
    # runs before this one, the flag stays True and this test fails, so set
    # it back to False explicitly here.
    # NOTE(review): these options are set on the global eager context and are
    # not restored here; presumably they stay in effect for later tests.
    context.context().set_optimizer_experimental_options({
        'min_graph_nodes': -1,
        'implementation_selector': True,
        'disable_meta_optimizer': False
    })

    @function.defun_with_attributes(
        attributes={'api_implements': 'foo',
                    'api_preferred_device': 'CPU'})
    def on_cpu(x):
      return x + 2

    @function.defun_with_attributes(
        attributes={'api_implements': 'foo',
                    'api_preferred_device': 'GPU'})
    def on_gpu(x):
      return x + 4

    @function.defun
    def run_on_cpu(t):
      # Register the CPU alternative in this function's graph, then call the
      # GPU variant under a CPU device scope.
      function.register(on_cpu, t)
      with ops.device('CPU:0'):
        return on_gpu(t)

    # Expect to run the on_cpu branch, regardless whether gpu is available.
    self.assertEqual(run_on_cpu(constant_op.constant(1)).numpy(), 3)
2889
2890  def testDefunFunctionSeparateGraphs(self):
2891    with context.graph_mode():
2892
2893      @function.defun
2894      def add(x):
2895        return x + 5
2896
2897      @function.defun
2898      def maybe_add(x, should_add):
2899        if should_add:
2900          return add(x)
2901        else:
2902          return x
2903
2904      with ops.Graph().as_default():
2905        x = constant_op.constant(11)
2906        maybe_add(x, True)
2907        self.assertLen(total_function_cache(maybe_add), 1)
2908        self.assertLen(total_function_cache(add), 1)
2909
2910        maybe_add(x, False)
2911        self.assertLen(total_function_cache(maybe_add), 2)
2912        self.assertLen(total_function_cache(add), 1)
2913
2914      with ops.Graph().as_default():
2915        x = constant_op.constant(11)
2916        maybe_add(x, True)
2917        self.assertLen(total_function_cache(maybe_add), 3)
2918        self.assertLen(total_function_cache(add), 2)
2919
2920  def testCacheKeyOverlappingShapes(self):
2921    @function.defun
2922    def defined(t):
2923      return t
2924
2925    defined(array_ops.zeros([12, 1]))
2926    self.assertLen(total_function_cache(defined), 1)
2927
2928    defined(array_ops.zeros([1, 21]))
2929    self.assertLen(total_function_cache(defined), 2)
2930
2931  def testCacheKeyNestedLists(self):
2932    @function.defun
2933    def defined(l):
2934      return l
2935
2936    a = constant_op.constant(1.)
2937    b = constant_op.constant(2.)
2938    c = constant_op.constant(3.)
2939    defined([[a], b, c])
2940    self.assertLen(total_function_cache(defined), 1)
2941
2942    defined([[a, b], c])
2943    self.assertLen(total_function_cache(defined), 2)
2944
2945  def testCacheKeyAttrsClass(self):
2946    if attr is None:
2947      self.skipTest('attr module is unavailable.')
2948
2949    @attr.s
2950    class TestClass(object):
2951      a = attr.ib()
2952      b = attr.ib()
2953
2954    @function.defun
2955    def defined(l):
2956      return l
2957
2958    defined(
2959        TestClass(
2960            constant_op.constant(1.),
2961            [constant_op.constant(2.),
2962             constant_op.constant(3.)]))
2963    self.assertLen(total_function_cache(defined), 1)
2964    defined(
2965        TestClass(
2966            constant_op.constant(1.),
2967            [constant_op.constant(2.),
2968             constant_op.constant(3.)]))
2969    self.assertLen(total_function_cache(defined), 1)
2970
2971    defined(
2972        TestClass([constant_op.constant(1.),
2973                   constant_op.constant(2.)], constant_op.constant(3.)))
2974    self.assertLen(total_function_cache(defined), 2)
2975
  def testCacheKeyVariables(self):
    """Cache hits and misses when a defun is called with resource variables.

    Repeating an identical call is a hit. Changing the argument order, the
    aliasing pattern, the particular variables passed, or passing a deep
    copy is a miss.
    """
    @function.defun
    def defined(a, b, c):
      return a + b + c

    x = resource_variable_ops.ResourceVariable(0.0)
    y = resource_variable_ops.ResourceVariable(0.0)
    z = resource_variable_ops.ResourceVariable(0.0)

    # If tensor equality is not enabled, we always get a cache miss if the
    # function is called with different variables. With equality enabled we
    # should only get a miss if the aliasing changed.
    # The assertions below expect a miss whenever the variables passed, or
    # their order, change.
    defined(x, y, z)
    self.assertLen(total_function_cache(defined), 1)
    defined(x, y, z)
    self.assertLen(total_function_cache(defined), 1)

    # Re-arranging arguments causes cache miss
    defined(z, y, x)
    self.assertLen(total_function_cache(defined), 2)
    defined(z, y, x)
    self.assertLen(total_function_cache(defined), 2)

    # Aliasing causes cache miss
    defined(x, x, z)
    self.assertLen(total_function_cache(defined), 3)
    defined(x, x, z)
    self.assertLen(total_function_cache(defined), 3)

    # Same aliasing pattern as (x, x, z), but a different aliased variable,
    # causes cache miss
    defined(y, y, z)
    self.assertLen(total_function_cache(defined), 4)
    defined(y, y, z)
    self.assertLen(total_function_cache(defined), 4)

    # Different alias positions cause cache miss
    defined(z, y, y)
    self.assertLen(total_function_cache(defined), 5)
    defined(z, y, y)
    self.assertLen(total_function_cache(defined), 5)

    x_copy = copy.deepcopy(x)

    # Deep copy causes cache miss
    defined(x_copy, y, z)
    self.assertLen(total_function_cache(defined), 6)
    defined(x_copy, y, z)
    self.assertLen(total_function_cache(defined), 6)
3024
3025  def testVariableRetracing(self):
3026    v1 = variables.Variable(1.)
3027    v2 = variables.Variable(1.)
3028    v3 = copy.deepcopy(variables.Variable(1.))
3029
3030    var_dict = {id(v1): constant_op.constant(1),
3031                id(v2): constant_op.constant(2),
3032                id(v3): constant_op.constant(3)}
3033
3034    @function.defun
3035    def lookup_tensor(v):
3036      return var_dict[id(v)]
3037
3038    self.assertEqual(1, lookup_tensor(v1).numpy())
3039    self.assertEqual(2, lookup_tensor(v2).numpy())
3040    self.assertEqual(3, lookup_tensor(v3).numpy())
3041
3042  def testDecoratedMethodInspect(self):
3043
3044    class DefunnedMiniModel(object):
3045
3046      @function.defun
3047      def call(self, inputs, training=True):
3048        pass
3049
3050    m = DefunnedMiniModel()
3051    fullargspec = tf_inspect.getfullargspec(m.call)
3052    self.assertIn('training', fullargspec.args)
3053
3054  def testFunctionModifiesInputList(self):
3055    # Tests on `list` methods that do in place modification, except `list.sort`
3056    # since it cannot even be "defunned" in the first place
3057
3058    def get_list():
3059      return [constant_op.constant(0.), constant_op.constant(1.)]
3060
3061    expected_msg = '.*() should not modify'
3062
3063    with self.assertRaisesRegex(ValueError, expected_msg):
3064
3065      @def_function.function
3066      def append(l):
3067        l.append(constant_op.constant(0.))
3068
3069      append(get_list())
3070
3071    with self.assertRaisesRegex(ValueError, expected_msg):
3072
3073      @def_function.function
3074      def extend(l):
3075        l.extend([constant_op.constant(0.)])
3076
3077      extend(get_list())
3078
3079    with self.assertRaisesRegex(ValueError, expected_msg):
3080
3081      @def_function.function
3082      def insert(l):
3083        l.insert(0, constant_op.constant(0.))
3084
3085      insert(get_list())
3086
3087    with self.assertRaisesRegex(ValueError, expected_msg):
3088
3089      @def_function.function
3090      def pop(l):
3091        l.pop()
3092
3093      pop(get_list())
3094
3095    with self.assertRaisesRegex(ValueError, expected_msg):
3096
3097      @def_function.function
3098      def reverse(l):
3099        l.reverse()
3100
3101      reverse(get_list())
3102
3103    with self.assertRaisesRegex(ValueError, expected_msg):
3104
3105      @def_function.function
3106      def remove(l):
3107        l.remove(l[0])
3108
3109      remove(get_list())
3110
3111    # `list.clear` is a method that is in Py3 but not Py2
3112    if sys.version.startswith('3'):
3113
3114      with self.assertRaisesRegex(ValueError, expected_msg):
3115
3116        @def_function.function
3117        def clear(l):
3118          l.clear()
3119
3120        clear(get_list())
3121
3122    # One last test for keyword arguments
3123    with self.assertRaisesRegex(ValueError, expected_msg):
3124
3125      @def_function.function
3126      def kwdappend(**kwargs):
3127        l = kwargs['l']
3128        l.append(constant_op.constant(0.))
3129
3130      kwdappend(l=get_list())
3131
3132  def testFunctionModifiesInputDict(self):
3133
3134    def get_dict():
3135      return {'t1': constant_op.constant(0.), 't2': constant_op.constant(1.)}
3136
3137    expected_msg = '.* should not modify'
3138
3139    with self.assertRaisesRegex(ValueError, expected_msg):
3140
3141      @def_function.function
3142      def clear(m):
3143        m.clear()
3144
3145      clear(get_dict())
3146
3147    with self.assertRaisesRegex(ValueError, expected_msg):
3148
3149      @def_function.function
3150      def pop(m):
3151        m.pop('t1')
3152
3153      pop(get_dict())
3154
3155    with self.assertRaisesRegex(ValueError, expected_msg):
3156
3157      @def_function.function
3158      def popitem(m):
3159        m.popitem()
3160
3161      popitem(get_dict())
3162
3163    with self.assertRaisesRegex(ValueError, expected_msg):
3164
3165      @def_function.function
3166      def update(m):
3167        m.update({'t1': constant_op.constant(3.)})
3168
3169      update(get_dict())
3170
3171    with self.assertRaisesRegex(ValueError, expected_msg):
3172
3173      @def_function.function
3174      def setdefault(m):
3175        m.setdefault('t3', constant_op.constant(3.))
3176
3177      setdefault(get_dict())
3178
3179  def testFunctionModifiesInputNest(self):
3180    with self.assertRaisesRegex(ValueError, 'modify.* should not modify'):
3181
3182      @def_function.function
3183      def modify(n):
3184        n[0]['t1'].append(constant_op.constant(1.))
3185
3186      nested_input = [{
3187          't1': [constant_op.constant(0.),
3188                 constant_op.constant(1.)],
3189      },
3190                      constant_op.constant(2.)]
3191
3192      modify(nested_input)
3193
3194    with self.assertRaisesRegex(ValueError,
3195                                'modify_same_flat.* should not modify'):
3196
3197      # The flat list doesn't change whereas the true structure changes
3198      @def_function.function
3199      def modify_same_flat(n):
3200        n[0].append(n[1].pop(0))
3201
3202      nested_input = [[constant_op.constant(0.)],
3203                      [constant_op.constant(1.),
3204                       constant_op.constant(2.)]]
3205
3206      modify_same_flat(nested_input)
3207
3208  @test_util.disable_tfrt('b/173429686')
3209  def testExecutorType(self):
3210    @function.defun
3211    def add_five(x):
3212      return x + 5
3213
3214    self.assertEqual(
3215        5,
3216        add_five(constant_op.constant(0, dtype=dtypes.int32)).numpy())
3217
3218    with self.assertRaisesRegex(errors.NotFoundError, 'NON_EXISTENT_EXECUTOR'):
3219      with context.function_executor_type('NON_EXISTENT_EXECUTOR'):
3220        add_five(constant_op.constant(0, dtype=dtypes.int32))
3221
3222    for executor_type in ('', 'DEFAULT', None):
3223      with context.function_executor_type(executor_type):
3224        self.assertAllEqual(
3225            5,
3226            add_five(constant_op.constant(0, dtype=dtypes.int32)).numpy())
3227
3228  @test_util.assert_no_garbage_created
3229  def testReferenceCycles(self):
3230
3231    fn = function.defun(lambda x: 2. * x)
3232
3233    fn(constant_op.constant(4.0))
3234    weak_fn = weakref.ref(fn)
3235    del fn
3236    # Tests that the weak reference we made to the function is now dead, which
3237    # means the object has been deleted. This should be true as long as the
3238    # function itself is not involved in a reference cycle.
3239    self.assertIs(None, weak_fn())
3240
  def testFunctionStackInErrorMessage(self):
    """An assertion failure in nested functions reports the function stack.

    `fn` calls `fn2`, whose `assert_equal(fn3(2), 3)` fails because
    `fn3(2)` is 4. The error message should name the `fn -> fn2` chain and the
    failing Assert node. It should not mention `fn3`, since the failing op is
    in `fn2`.
    """
    if context.executing_eagerly():
      # TODO(b/122736651): Remove this skipTest once fixed.
      self.skipTest('Error interpolation is not working when function is '
                    'invoked without PartitionedCallOp.')

    @def_function.function()
    def fn3(x):
      return x + 2

    @def_function.function()
    def fn2(x):
      check_ops.assert_equal(fn3(x), 3)
      return 2

    @def_function.function()
    def fn(x):
      return fn2(x)

    with self.assertRaises(errors.InvalidArgumentError) as cm:
      fn(2)
    e = cm.exception
    # The function names and the assert op's node name are part of the
    # expected message, so renaming any of them breaks these checks.
    self.assertIn('fn -> fn2', e.message)
    self.assertIn('node assert_equal/Assert/Assert (defined at', e.message)
    self.assertNotIn('fn3', e.message)
3266
3267  @test_util.run_gpu_only
3268  def testFunctionIsNotPinned(self):
3269    """Tests that functions aren't pinned to the CPU by the eager runtime."""
3270    seed1, seed2 = 79, 25
3271    shape = constant_op.constant([4, 7])
3272    dtype = dtypes.float32
3273
3274    @def_function.function
3275    def func():
3276      with ops.device('GPU:0'):
3277        return gen_random_ops.random_standard_normal(
3278            shape, dtype=dtype, seed=seed1, seed2=seed2)
3279
3280    with ops.device('GPU:0'):
3281      x = func()
3282      self.assertRegex(x.device, 'GPU')
3283
3284  @test_util.run_in_graph_and_eager_modes
3285  def testShapeCaching(self):
3286
3287    @function.defun
3288    def func(x):
3289      return array_ops.shape(x)
3290
3291    @function.defun(
3292        input_signature=[tensor_spec.TensorSpec([None, None], dtypes.float32)])
3293    def calls_func(x):
3294      return func(x)
3295
3296    self.assertAllEqual([1, 1], self.evaluate(func(array_ops.zeros([1, 1]))))
3297    self.assertAllEqual([2, 2], self.evaluate(func(array_ops.zeros([2, 2]))))
3298    self.assertAllEqual(
3299        [3, 3],
3300        self.evaluate(calls_func(array_ops.zeros([3, 3]))))
3301
3302  def testLimitedRetracing(self):
3303    trace_count = [0]
3304    @function.defun
3305    def func(x):
3306      trace_count[0] += 1
3307      return x
3308
3309    for _ in range(50):
3310      func(constant_op.constant(3.))
3311      func(constant_op.constant(4.))
3312      func(constant_op.constant([[1., 2.]]))
3313      func(constant_op.constant([[]]))
3314      func(constant_op.constant([[3., 4.], [5., 6.]]))
3315      func(constant_op.constant([[3., 4.], [5., 6.], [7., 8.]]))
3316    # Tracing more than twice per input doesn't make sense.
3317    self.assertLess(trace_count[0], 13)
3318
3319  def testLimitedRetracingWithCompositeTensors(self):
3320    trace_count = [0]
3321
3322    @def_function.function
3323    def f(x):
3324      trace_count[0] += 1
3325      return x
3326
3327    for i in range(10):
3328      f(ragged_factory_ops.constant([[1, 2], [i]]))
3329      f(ragged_factory_ops.constant([[1, 2], [], [3, 4, 5]]))
3330      f(ragged_factory_ops.constant([[[1, 2], [3]], [[4, 5, 6]]]))
3331      self.assertEqual(trace_count[0], 3)
3332
3333  def test_concrete_function_shape_mismatch(self):
3334
3335    @def_function.function
3336    def f(argument_name):
3337      return argument_name + 1.
3338
3339    f_concrete = f.get_concrete_function(constant_op.constant([1.]))
3340
3341    # Calling a function from eager doesn't do any shape checking above what
3342    # kernels do while executing.
3343    self.assertAllEqual(
3344        [2., 3.],
3345        f_concrete(constant_op.constant([1., 2.])).numpy())
3346
3347    @def_function.function
3348    def g():
3349      f_concrete(constant_op.constant([1., 2.]))
3350
3351    with self.assertRaisesRegex(ValueError, 'argument_name'):
3352      g()
3353
  @test_util.run_in_graph_and_eager_modes
  def test_shape_inference_with_symbolic_shapes(self):
    """Static output shapes of a concrete function traced with unknown shapes.

    `conc` is traced with fully unknown input shapes. Each call inside
    `_call_concrete` must still produce a static output shape computed from
    that call's concrete argument shapes.
    """

    @def_function.function
    def _uses_symbolic_shapes(w, x, y):
      x = array_ops.identity(x, name='name_collision')
      # Swap the first two dimensions of x.
      x = array_ops.transpose(x, [1, 0, 2])
      x_batch = array_ops.shape(x)[0]
      y_batch = array_ops.shape(y)[0]
      y *= w
      n = y_batch // x_batch
      # Output shape is [y_batch // x_batch, x_batch, -1].
      return array_ops.reshape(y, [n, x_batch, -1])

    conc = _uses_symbolic_shapes.get_concrete_function(
        tensor_spec.TensorSpec(None, dtypes.float32),
        tensor_spec.TensorSpec(None, dtypes.float32),
        tensor_spec.TensorSpec(None, dtypes.float32))

    @def_function.function
    def _call_concrete():
      c = constant_op.constant(1.)
      # Same op name as inside `_uses_symbolic_shapes`. Presumably this checks
      # that a caller/callee name clash does not confuse shape inference —
      # TODO confirm intent.
      array_ops.identity(c, name='name_collision')
      # x: [5, 4, 2] -> transposed [4, 5, 2], so x_batch = 4; y: [20, 2],
      # so n = 20 // 4 = 5 and the output is [5, 4, 2].
      output1 = conc(array_ops.ones([2]),
                     array_ops.ones([5, 4, 2]),
                     array_ops.ones([20, 2]))
      self.assertEqual([5, 4, 2], output1.shape)
      # x_batch = 4 again; y: [40, 3], so n = 10 and the output is [10, 4, 3].
      output2 = conc(array_ops.ones([3]),
                     array_ops.ones([5, 4, 3]),
                     array_ops.ones([40, 3]))
      self.assertEqual([10, 4, 3], output2.shape)
      return output1, output2

    output1, output2 = _call_concrete()
    self.assertEqual((5, 4, 2), self.evaluate(output1).shape)
    self.assertEqual((10, 4, 3), self.evaluate(output2).shape)
3389
3390  def testAutoGraphContext(self):
3391
3392    @def_function.function
3393    def test_fn():
3394      self.assertEqual(
3395          ag_ctx.control_status_ctx().status, ag_ctx.Status.ENABLED)
3396
3397    prev_status = ag_ctx.control_status_ctx().status
3398    test_fn()
3399    self.assertEqual(ag_ctx.control_status_ctx().status, prev_status)
3400
3401  @test_util.disable_tfrt('b/170435618')
3402  def testCancelBeforeFunctionExecution(self):
3403    if not context.executing_eagerly():
3404      self.skipTest('eager only')
3405
3406    q = data_flow_ops.FIFOQueue(1, dtypes.int32)
3407
3408    @def_function.function
3409    def f():
3410      return q.dequeue()
3411
3412    c_mgr = cancellation.CancellationManager()
3413    cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function())
3414
3415    c_mgr.start_cancel()
3416    with self.assertRaises(errors.CancelledError):
3417      cancelable_func()
3418
3419  @test_util.disable_tfrt('b/170435618')
3420  def testCancelBlockedFunctionExecution(self):
3421    if not context.executing_eagerly():
3422      self.skipTest('eager only')
3423
3424    q = data_flow_ops.FIFOQueue(1, dtypes.int32)
3425
3426    @def_function.function
3427    def f():
3428      return q.dequeue()
3429
3430    c_mgr = cancellation.CancellationManager()
3431    cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function())
3432
3433    def cancel_thread():
3434      time.sleep(0.5)
3435      c_mgr.start_cancel()
3436
3437    t = self.checkedThread(cancel_thread)
3438    t.start()
3439    with self.assertRaises(errors.CancelledError):
3440      cancelable_func()
3441    t.join()
3442
3443  @test_util.disable_tfrt('b/170435618')
3444  def testCancelAfterFunctionExecution(self):
3445    if not context.executing_eagerly():
3446      self.skipTest('eager only')
3447
3448    q = data_flow_ops.FIFOQueue(1, dtypes.int32)
3449    q.enqueue(37)
3450
3451    @def_function.function
3452    def f():
3453      return q.dequeue()
3454
3455    c_mgr = cancellation.CancellationManager()
3456    cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function())
3457
3458    self.assertAllEqual(37, cancelable_func().numpy())
3459
3460    # Cancellation after the function executes is a no-op.
3461    c_mgr.start_cancel()
3462
3463  def testAddFunctionCallback(self):
3464    functions = []
3465    def function_callback(f, name, graph, inputs, outputs):
3466      del name, graph, inputs, outputs
3467      functions.append(f)
3468
3469    @def_function.function
3470    def plus_one(x):
3471      return x + 1
3472
3473    try:
3474      function.add_function_callback(function_callback)
3475      x_float32 = numpy.array(3.0, dtype=numpy.float32)
3476      self.assertAllClose(plus_one(x_float32), 4.0)
3477      self.assertLen(functions, 1)
3478      # Function is already created. Executing it again should not invoke the
3479      # function callback.
3480      self.assertAllClose(plus_one(x_float32), 4.0)
3481      self.assertLen(functions, 1)
3482      # Signature change leads to a new Function being built.
3483      x_float64 = numpy.array(3.0, dtype=numpy.float64)
3484      self.assertAllClose(plus_one(x_float64), 4.0)
3485      self.assertLen(functions, 2)
3486    finally:
3487      function.clear_function_callbacks()
3488
3489  def testFunctionCallbackAddOps(self):
3490    file_name = os.path.join(self.get_temp_dir(), 'test')
3491
3492    def function_callback(f, name, graph, inputs, outputs):
3493      del f, name, inputs
3494
3495      with graph.as_default():
3496        printer = logging_ops.print_v2(
3497            'hello',
3498            output_stream='file://' + file_name
3499        )
3500        outputs[0].op._add_control_input(printer)
3501
3502    @def_function.function
3503    def plus_one(x):
3504      return x + 1
3505
3506    self.addCleanup(function.clear_function_callbacks)
3507    function.add_function_callback(function_callback)
3508    x_float32 = numpy.array(3.0, dtype=numpy.float32)
3509
3510    self.assertAllClose(plus_one(x_float32), 4.0)
3511
3512    with open(file_name, 'r') as f:
3513      self.assertEqual(f.read().strip(), 'hello')
3514
3515  def testRemoveFunctionCallback(self):
3516    functions_1 = []
3517    def function_callback_1(f, name, graph, inputs, outputs):
3518      del name, graph, inputs, outputs
3519      functions_1.append(f)
3520
3521    functions_2 = []
3522    def function_callback_2(f, name, graph, inputs, outputs):
3523      del name, graph, inputs, outputs
3524      functions_2.append(f)
3525
3526    @def_function.function
3527    def plus_one(x):
3528      return x + 1
3529
3530    try:
3531      function.add_function_callback(function_callback_1)
3532      function.add_function_callback(function_callback_2)
3533      self.assertAllClose(plus_one(numpy.array(3.0, dtype=numpy.float32)), 4.0)
3534      self.assertLen(functions_1, 1)
3535      self.assertLen(functions_2, 1)
3536      function.remove_function_callback(function_callback_1)
3537      # The 1st callback should not be invokved after remove_function_callback()
3538      # is called.
3539      self.assertAllClose(plus_one(numpy.array(3.0, dtype=numpy.float64)), 4.0)
3540      self.assertLen(functions_1, 1)
3541      self.assertLen(functions_2, 2)
3542    finally:
3543      function.clear_function_callbacks()
3544
3545  def testClearFunctionCallbacks(self):
3546    function.add_function_callback(lambda f: None)
3547    function.add_function_callback(lambda f: None)
3548    self.assertLen(function._function_callbacks, 2)
3549    function.clear_function_callbacks()
3550    self.assertEmpty(function._function_callbacks)  # pylint:disable=protected-access
3551
3552  @test_util.run_in_graph_and_eager_modes
3553  def testConcreteFunctionWithNestedTensorInputs(self):
3554
3555    @def_function.function
3556    def f(x, y):
3557      return (x['a'] + x['b'], y[0] + y[1])
3558
3559    a = constant_op.constant(1000)
3560    b = constant_op.constant(200)
3561    c = constant_op.constant(30)
3562    d = {'a': a, 'b': b}
3563    e = (c, 4)
3564
3565    # Test different argument signatures when constructing the concrete func.
3566    for cf in [
3567        f.get_concrete_function(d, e),
3568        f.get_concrete_function(d, y=e),
3569        f.get_concrete_function(y=e, x=d),
3570        f.get_concrete_function(_spec_for_value(d), _spec_for_value(e)),
3571        f.get_concrete_function(_spec_for_value(d), y=_spec_for_value(e)),
3572        f.get_concrete_function(y=_spec_for_value(e), x=_spec_for_value(d))
3573    ]:
3574      # Test different calling conventions when calling the concrete func.
3575      for output in [
3576          cf(d, e),  # structured signature
3577          cf(d, y=e),  # structured signature w/ kwarg
3578          cf(y=e, x=d),  # structured signature w/ 2 kwargs
3579          cf(a, b, c),  # flat signature
3580          cf(x=a, x_1=b, y=c)  # flat signature w/ kwargs
3581      ]:
3582        self.assertIsInstance(output, tuple)
3583        self.assertLen(output, 2)
3584        self.assertAllEqual(output[0], 1200)
3585        self.assertAllEqual(output[1], 34)
3586
3587  @test_util.run_in_graph_and_eager_modes
3588  def testConcreteFunctionWithNestedNonTensorInputs(self):
3589
3590    @def_function.function
3591    def f(x, y):
3592      return (x['a'] + x['b'], y[0] + y[1])
3593
3594    a = {'a': constant_op.constant(1000), 'b': constant_op.constant(200)}
3595    b = (50, 3)
3596
3597    for cf in [  # argument y is bound to non-Tensor value (50, 3).
3598        f.get_concrete_function(a, b),
3599        f.get_concrete_function(a, y=b),
3600        f.get_concrete_function(x=a, y=b)
3601    ]:
3602      for output in [cf(a), cf(x=a), cf(a, b), cf(x=a, y=b)]:
3603        self.assertAllEqual(output[0] + output[1], 1253)
3604
3605  @test_util.run_in_graph_and_eager_modes
3606  def testConcreteFunctionWithNonTensorStringInputs(self):
3607
3608    @def_function.function
3609    def f(x, y):
3610      return string_ops.string_join([x, y])
3611
3612    a = constant_op.constant('a')
3613    b = 'b'
3614
3615    cf = f.get_concrete_function(a, b)
3616    for output in [cf(a), cf(x=a), cf(a, b), cf(x=a, y=b)]:
3617      self.assertAllEqual(output, b'ab')
3618
3619  @test_util.run_in_graph_and_eager_modes
3620  def testConcreteFunctionWithBoundNestedNonTensorInputs(self):
3621
3622    @def_function.function
3623    def f(x, y):
3624      return (x['a'] + x['b'], y[0] + y[1])
3625
3626    a = {'a': 3000, 'b': 200, 'c': 9000}
3627    b = (constant_op.constant(30), 4)
3628
3629    for cf in [  # argument x is bound to non-tensor value `a`
3630        f.get_concrete_function(a, b),
3631        f.get_concrete_function(a, y=b),
3632        f.get_concrete_function(x=a, y=b)
3633    ]:
3634      for output in [cf(a, b), cf(a, y=b), cf(y=b), cf(x=a, y=b)]:
3635        self.assertAllEqual(output[0] + output[1], 3234)
3636
3637  @test_util.run_in_graph_and_eager_modes
3638  def testConcreteFunctionWithAllBoundNestedNonTensorInputs(self):
3639
3640    @def_function.function
3641    def f(x, y):
3642      return (x['a'] + x['b'], y[0] + y[1])
3643
3644    a = {'a': 5000, 'b': 500}
3645    b = (50, 5)
3646
3647    cf = f.get_concrete_function(a, b)
3648    for output in [cf(), cf(a), cf(y=b)]:
3649      self.assertAllEqual(output[0] + output[1], 5555)
3650
3651  @test_util.run_in_graph_and_eager_modes
3652  def testConcreteFunctionMethodWithVarargs(self):
3653    float32_scalar = tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32)
3654
3655    class MyModel(module.Module):
3656
3657      @def_function.function(input_signature=[float32_scalar, float32_scalar])
3658      def add(self, *arg):
3659        return math_ops.add(*arg)
3660
3661    m = MyModel()
3662    cf = m.add.get_concrete_function()
3663    cf(-12.0, 3.0)
3664
3665  @test_util.run_in_graph_and_eager_modes
3666  def testConcreteFunctionStructuredSignatureKeywordOrder(self):
3667    # Check that keyword-only arguments are sorted appropriately, so that they
3668    # feed the right tensor into each input.
3669    @def_function.function
3670    def g(**kwargs):
3671      return string_ops.reduce_join(
3672          string_ops.reduce_join(
3673              ops.convert_to_tensor(sorted(kwargs.items())),
3674              axis=1,
3675              separator='='),
3676          axis=0,
3677          separator=', ')
3678
3679    s = constant_op.constant('s')
3680    g.get_concrete_function(q=s, a=s, p=s, r=s, v=s, m=s, l=s)
3681    self.assertAllEqual(
3682        g(m='a', r='b', v='c', q='d', l='e', a='f', p='g'),
3683        b'a=f, l=e, m=a, p=g, q=d, r=b, v=c')
3684    self.assertAllEqual(
3685        g(q='d', a='f', p='g', r='b', v='c', m='a', l='e'),
3686        b'a=f, l=e, m=a, p=g, q=d, r=b, v=c')
3687    self.assertAllEqual(
3688        g(a='f', l='e', m='a', p='g', q='d', r='b', v='c'),
3689        b'a=f, l=e, m=a, p=g, q=d, r=b, v=c')
3690
  # pylint: disable=g-long-lambda
  # Each case wraps its arguments in lambdas that the test body calls before
  # use. This is presumably so tensors are created inside the graph/eager
  # context set up by run_in_graph_and_eager_modes, not at class-definition
  # time.
  @parameterized.named_parameters([
      dict(
          testcase_name='MissingArg',
          conc_args=lambda: (1, constant_op.constant(2)),
          call_args=lambda: (1,),
          error=r'func\(x, y\) missing required arguments: y'),
      dict(
          testcase_name='MissingVararg',
          conc_args=lambda: (1, 2, constant_op.constant(1.0)),
          call_args=lambda: (1, 2),
          error=r'func\(x, y, <arg3>\) missing required arguments: <arg3>'),
      dict(
          testcase_name='ExtraPositionalArg',
          conc_args=lambda: (1, 2),
          call_args=lambda: (1, 2, 3),
          error=r'func\(x, y\) takes 2 positional arguments but 3 were given'),
      dict(
          testcase_name='MissingKeywordOnlyArg',
          conc_args=lambda: (1, 2),
          conc_kwargs=lambda: {'c': constant_op.constant(1.0)},
          call_args=lambda: (1, 2),
          error=r'func\(x, y, \*, c\) missing required arguments: c'),
      dict(
          testcase_name='ExtraKeywordArg',
          conc_args=lambda: (1, 2),
          call_args=lambda: (1, 2),
          call_kwargs=lambda: {'c': constant_op.constant(1.0)},
          error=r'func\(x, y\) got unexpected keyword arguments: c'),
      dict(
          testcase_name='ExpectedRaggedGotNest',
          conc_args=lambda: (ragged_factory_ops.constant([[1, 2], [3]]),),
          call_args=lambda: ({
              'a': constant_op.constant([1, 2, 3])
          },),
          error=r'func\(x, y\): argument x had incorrect type\n'
          r'  expected: RaggedTensor\n'
          r"       got: {'a': (Eager)?Tensor}"),
      dict(
          testcase_name='WrongRaggedRank',
          conc_args=lambda: (ragged_factory_ops.constant([[1, 2], [3]]),),
          call_args=lambda: (ragged_factory_ops.constant([[[1]]]),),
          error=r'func\(x, y\): argument x had incorrect type\n'),
      dict(
          testcase_name='WrongRaggedDType',
          conc_args=lambda: (ragged_factory_ops.constant([[1]]),),
          call_args=lambda: (ragged_factory_ops.constant([[1.0]]),),
          error=r'func\(x, y\): argument x had incorrect type\n'),
      dict(
          testcase_name='ExpectedDictGotTensor',
          conc_args=lambda: ({
              'a': constant_op.constant(1),
              'b': constant_op.constant(1)
          },),
          call_args=lambda: (constant_op.constant(1),),
          error=r'func\(x, y\): argument x had incorrect type\n'),
      dict(
          testcase_name='ExpectedTupleGotTensor',
          conc_args=lambda:
          ((constant_op.constant(1), constant_op.constant(2)),),
          call_args=lambda: (constant_op.constant(1),),
          error=r'func\(x, y\): argument x had incorrect type\n'),
      dict(
          testcase_name='WrongDType',
          conc_args=lambda: (constant_op.constant(1),),
          call_args=lambda: (constant_op.constant(1.0),),
          exception=(ValueError, errors.InvalidArgumentError,
                     # on xla_gpu, we get InternalError instead.
                     errors.InternalError)),
      dict(
          testcase_name='ExpectedTensorGotInt',
          conc_args=lambda: (constant_op.constant(1),),
          call_args=lambda: (5,),
          error=r'func\(x, y\) expected a Tensor in x, but got int value 5'),
      dict(
          testcase_name='ExpectedIntGotDifferentInt',
          conc_args=lambda: (5,),
          call_args=lambda: (8,),
          error=r'ConcreteFunction func\(x, y\) was constructed with int '
          r'value 5 in x, but was called with int value 8'),
      dict(
          testcase_name='ExpectedIntGotTensor',
          conc_args=lambda: (5,),
          call_args=lambda: (constant_op.constant(6),),
          error=r'ConcreteFunction func\(x, y\) was constructed with int '
          'value 5 in x, but was called with (Eager)?Tensor value .*'),
      dict(
          testcase_name='TwoValuesForArgument',
          conc_args=lambda: (1, 2),
          call_args=lambda: (1, 2),
          call_kwargs=lambda: {'x': 3},
          error=r"func\(x, y\) got two values for argument 'x'"),
  ])
  # pylint: enable=g-long-lambda
  @test_util.run_in_graph_and_eager_modes
  def testConcreteFunctionStructuredSignatureError(self,
                                                   conc_args=(),
                                                   conc_kwargs=None,
                                                   call_args=(),
                                                   call_kwargs=None,
                                                   error='.*',
                                                   exception=TypeError):
    """Tests for errors in the structured signature.

    A concrete function is built from `conc_args`/`conc_kwargs`, then called
    with mismatching `call_args`/`call_kwargs`. The call must raise
    `exception` with a message matching `error`.

    Args:
      conc_args: Positional arguments used for get_concrete_function (or a
        zero-argument callable returning them).
      conc_kwargs: Keyword arguments used for get_concrete_function (or a
        zero-argument callable returning them); None means no kwargs.
      call_args: Positional arguments used to call the function (or a
        zero-argument callable returning them).
      call_kwargs: Keyword arguments used to call the function (or a
        zero-argument callable returning them); None means no kwargs.
      error: Regex that the exception message must match.
      exception: Expected exception type, or tuple of acceptable types.
    """
    # Materialize lazily built arguments; missing kwargs default to {}.
    conc_args = conc_args() if callable(conc_args) else conc_args
    conc_kwargs = conc_kwargs() if callable(conc_kwargs) else conc_kwargs or {}
    call_args = call_args() if callable(call_args) else call_args
    call_kwargs = call_kwargs() if callable(call_kwargs) else call_kwargs or {}
    self.assertIsInstance(conc_args, tuple)
    self.assertIsInstance(call_args, tuple)
    self.assertIsInstance(conc_kwargs, dict)
    self.assertIsInstance(call_kwargs, dict)

    # The signature mixes a required arg, a defaulted arg, *varargs and
    # **kwargs so each kind of binding error in the table can be provoked.
    @def_function.function
    def func(x, y=5, *varargs, **kwargs):  # pylint: disable=keyword-arg-before-vararg
      del y, varargs, kwargs
      return x

    conc = func.get_concrete_function(*conc_args, **conc_kwargs)
    # evaluate() is needed so that, in graph mode, runtime errors (such as a
    # dtype mismatch) surface inside the assertRaisesRegex context.
    with self.assertRaisesRegex(exception, error):
      self.evaluate(conc(*call_args, **call_kwargs))
3820
  # pylint: disable=g-long-lambda
  # Arguments are wrapped in lambdas that the test body calls before use,
  # presumably so tensors are created inside the graph/eager context set up by
  # run_in_graph_and_eager_modes. The expected messages name the flat inputs
  # (e.g. `varargs_0`, `c_1`) rather than the Python parameters.
  @parameterized.named_parameters([
      dict(
          testcase_name='MissingArg',
          conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
          call_args=lambda: (constant_op.constant(1),),
          error=r'func\(x, y\) missing required arguments: y'),
      dict(
          testcase_name='TwoValuesForArg',
          conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
          call_args=lambda: (constant_op.constant(1),),
          call_kwargs=lambda: {
              'x': constant_op.constant(1),
              'y': constant_op.constant(1)
          },
          error=r"func\(x, y\) got two values for argument 'x'"),
      dict(
          testcase_name='ExtraPositionalArg',
          conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
          call_args=lambda: (constant_op.constant(1), constant_op.constant(2),
                             constant_op.constant(3)),
          error=r'func\(x, y\) takes 2 positional arguments but 3 were given'),
      dict(
          testcase_name='UnexpectedKeywordArg',
          conc_args=lambda: (constant_op.constant(1),),
          call_args=lambda: (constant_op.constant(1),),
          call_kwargs=lambda: {'c': constant_op.constant(1)},
          error=r'func\(x\) got unexpected keyword arguments: c'),
      dict(
          testcase_name='MissingVararg',
          conc_args=lambda: (constant_op.constant(1), constant_op.constant(2),
                             constant_op.constant(3)),
          call_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
          error=r'func\(x, y, varargs_0\) missing required '
          r'arguments: varargs_0'),
      dict(
          testcase_name='MissingKeywordArg',
          conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
          conc_kwargs=lambda: {'c': constant_op.constant(1)},
          call_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
          error=r'func\(x, y, c\) missing required arguments: c'),
      dict(
          testcase_name='ExpectedTensorGotInt',
          conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
          call_args=lambda: (5, constant_op.constant(2)),
          error=r'func\(x, y\): expected argument #0\(zero-based\) to be '
          r'a Tensor; got int \(5\)'),
      dict(
          testcase_name='WrongDType',
          conc_args=lambda: (constant_op.constant(1),),
          call_args=lambda: (constant_op.constant(1.0),),
          exception=(ValueError, errors.InvalidArgumentError,
                     # on xla_gpu, we get InternalError instead.
                     errors.InternalError)),
      dict(
          testcase_name='MissingKeywordArgNestPiece',
          conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
          conc_kwargs=lambda: {'c': ragged_factory_ops.constant([[1]])},
          call_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
          call_kwargs=lambda: {'c': constant_op.constant(1)},
          error=r'func\(x, y, c, c_1\) missing required arguments: c_1'),
  ])
  # pylint: enable=g-long-lambda
  @test_util.run_in_graph_and_eager_modes
  def testConcreteFunctionFlatSignatureError(self,
                                             conc_args=(),
                                             conc_kwargs=None,
                                             call_args=(),
                                             call_kwargs=None,
                                             error='.*',
                                             exception=TypeError):
    """Tests for errors in the flat signature.

    The concrete function's function spec is removed, so calls are checked
    against the flat list of inputs instead of the structured signature.

    Args:
      conc_args: Positional arguments used for get_concrete_function (or a
        zero-argument callable returning them).
      conc_kwargs: Keyword arguments used for get_concrete_function (or a
        zero-argument callable returning them); None means no kwargs.
      call_args: Positional arguments used to call the function (or a
        zero-argument callable returning them).
      call_kwargs: Keyword arguments used to call the function (or a
        zero-argument callable returning them); None means no kwargs.
      error: Regex that the exception message must match.
      exception: Expected exception type, or tuple of acceptable types.
    """
    # Materialize lazily built arguments; missing kwargs default to {}.
    conc_args = conc_args() if callable(conc_args) else conc_args
    conc_kwargs = conc_kwargs() if callable(conc_kwargs) else conc_kwargs or {}
    call_args = call_args() if callable(call_args) else call_args
    call_kwargs = call_kwargs() if callable(call_kwargs) else call_kwargs or {}
    self.assertIsInstance(conc_args, tuple)
    self.assertIsInstance(call_args, tuple)
    self.assertIsInstance(conc_kwargs, dict)
    self.assertIsInstance(call_kwargs, dict)

    @def_function.function
    def func(x, y=5, *varargs, **kwargs):  # pylint: disable=keyword-arg-before-vararg
      del y, varargs, kwargs
      return x

    conc = func.get_concrete_function(*conc_args, **conc_kwargs)

    # Remove _function_spec, to disable the structured signature.
    conc._set_function_spec(None)  # pylint: disable=protected-access

    # evaluate() is needed so that, in graph mode, runtime errors surface
    # inside the assertRaisesRegex context.
    with self.assertRaisesRegex(exception, error):
      self.evaluate(conc(*call_args, **call_kwargs))
3923
3924  @test_util.run_in_graph_and_eager_modes
3925  def testConcreteFunctionAmbiguousSignature(self):
3926    # When both the flat & structured signatures are applicable, but they
3927    # give different results, we use the structured signature.  Note: we expect
3928    # this to be extremely rare.
3929    @def_function.function
3930    def f(x, y):
3931      return x * 10 + y
3932
3933    conc = f.get_concrete_function(
3934        x=tensor_spec.TensorSpec(None, dtypes.int32, name='y'),
3935        y=tensor_spec.TensorSpec(None, dtypes.int32, name='x'))
3936
3937    result = conc(x=constant_op.constant(5), y=constant_op.constant(6))
3938    self.assertAllEqual(result, 56)
3939
  def testPrettyPrintedSignature(self):
    """Checks pretty_printed_signature(), repr() and str() of concrete fns.

    Covers Tensor, RaggedTensor and nested-structure arguments, defaulted and
    explicitly bound Python values, and *args/**kwargs bindings.
    """

    @def_function.function
    def func(x, kangaroo=None, octopus=7):
      del octopus, kangaroo
      return x

    scalar = constant_op.constant(5)
    vector = constant_op.constant([10, 10, 20])
    ragged = ragged_factory_ops.constant([[10, 20], [40]])

    # `octopus` is not passed, so its default (7) is shown in the summary.
    c1 = func.get_concrete_function(scalar, vector)
    c1_summary = r'func\(x, kangaroo, octopus=7\)'
    c1_details = (r'  Args:\n'
                  r'    x: int32 Tensor, shape=\(\)\n'
                  r'    kangaroo: int32 Tensor, shape=\(3,\)\n'
                  r'  Returns:\n'
                  r'    int32 Tensor, shape=\(\)')
    self.assertRegex(c1.pretty_printed_signature(verbose=False), c1_summary)
    self.assertRegex(
        c1.pretty_printed_signature(verbose=True),
        c1_summary + '\n' + c1_details)
    self.assertRegex(
        repr(c1), r'<ConcreteFunction func\(x, kangaroo, octopus=7\) at .*>')
    self.assertRegex(
        str(c1), 'ConcreteFunction {}\n{}'.format(c1_summary, c1_details))

    # A Python value passed explicitly (3) replaces the default in the summary.
    c2 = func.get_concrete_function(scalar, ragged, 3)
    c2_summary = r'func\(x, kangaroo, octopus=3\)'
    c2_details = (r'  Args:\n'
                  r'    x: int32 Tensor, shape=\(\)\n'
                  r'    kangaroo: RaggedTensorSpec\(.*\)\n'
                  r'  Returns:\n'
                  r'    int32 Tensor, shape=\(\)')
    self.assertRegex(c2.pretty_printed_signature(),
                     c2_summary + '\n' + c2_details)

    # A nested argument is printed as a structure with numbered placeholders
    # (<1>, <2>, ...), each described on its own line.
    c3 = func.get_concrete_function({'a': scalar, 'b': [ragged, ragged]})
    c3_summary = r'func\(x, kangaroo=None, octopus=7\)'
    c3_details = (r'  Args:\n'
                  r"    x: {'a': <1>, 'b': \[<2>, <3>\]}\n"
                  r'      <1>: int32 Tensor, shape=\(\)\n'
                  r'      <2>: RaggedTensorSpec\(.*\)\n'
                  r'      <3>: RaggedTensorSpec\(.*\)\n'
                  r'  Returns:\n'
                  r"    {'a': <1>, 'b': \[<2>, <3>\]}\n"
                  r'      <1>: int32 Tensor, shape=\(\)\n'
                  r'      <2>: RaggedTensorSpec\(.*\)\n'
                  r'      <3>: RaggedTensorSpec\(.*\)')

    # Python 3.5 does not guarantee a deterministic iteration order for dict
    # contents, which can lead to a mismatch in the "Args" section of the
    # pretty_printed_signature output, so only check on 3.6+.
    if sys.version_info >= (3, 6):
      self.assertRegex(c3.pretty_printed_signature(),
                       c3_summary + '\n' + c3_details)

    # pylint: disable=keyword-arg-before-vararg
    @def_function.function
    def func2(x, y=3, *args, **kwargs):
      return (x, y, args, kwargs)

    # An extra positional value appears as `<arg3>=5`; a keyword Tensor
    # appears as a keyword-only argument after `*`.
    c4 = func2.get_concrete_function(scalar, 4, 5, a=scalar)
    c4_summary = 'func2(x, y=4, <arg3>=5, *, a)'
    self.assertEqual(c4.pretty_printed_signature(verbose=False), c4_summary)

    # A Python value bound to `x` is shown as `x=8`; the Tensor `y` has none.
    c5 = func2.get_concrete_function(8, vector)
    c5_summary = 'func2(x=8, y)'
    self.assertEqual(c5.pretty_printed_signature(verbose=False), c5_summary)
4008
4009  def testPrettyPrintedExplicitSignatureWithKeywordArg(self):  # b/159639913
4010
4011    @def_function.function(input_signature=[tensor_spec.TensorSpec(None)])
4012    def fn(a, b=1):
4013      return a + b
4014
4015    concrete_fn = fn.get_concrete_function()
4016    self.assertEqual(concrete_fn.pretty_printed_signature(False), 'fn(a)')
4017    self.assertEqual(
4018        concrete_fn.pretty_printed_signature(True), 'fn(a)\n'
4019        '  Args:\n'
4020        '    a: float32 Tensor, shape=<unknown>\n'
4021        '  Returns:\n'
4022        '    float32 Tensor, shape=<unknown>')
4023
4024  @test_util.run_in_graph_and_eager_modes
4025  def testIndexedSlicesAsGradientsForConcreteFunctions(self):
4026
4027    @def_function.function
4028    def summing_rnn(inputs):
4029      return math_ops.reduce_sum(inputs, axis=1)
4030
4031    @def_function.function
4032    def gradients(inputs):
4033      with backprop.GradientTape() as tape:
4034        tape.watch(inputs)
4035        hidden = summing_rnn(inputs)
4036        hidden = array_ops.gather(hidden, constant_op.constant([0]))
4037        loss = math_ops.reduce_mean(hidden)
4038      return tape.gradient(loss, inputs)
4039
4040    gradients(constant_op.constant([[[1.0], [2.0]]]))  # No error is raised
4041
4042  def testFollowTypeHintsTraceBasic(self):
4043    trace_count = [0]
4044
4045    def func(x: ops.Tensor):
4046      trace_count[0] += 1
4047      return x
4048
4049    enabled = def_function.function(func, experimental_follow_type_hints=True)
4050    disabled = def_function.function(func, experimental_follow_type_hints=False)
4051
4052    enabled(1)  # Initial call gets traced
4053    enabled(2)
4054    enabled(3)
4055    self.assertEqual(trace_count[0], 1)
4056
4057    trace_count = [0]
4058    disabled(1)
4059    disabled(2)  # Retrace
4060    disabled(3)  # Retrace
4061    self.assertEqual(trace_count[0], 3)
4062
4063  def testFollowTypeHintsTraceWithArgs(self):
4064    trace_count = [0]
4065
4066    def func(*args: ops.Tensor):
4067      trace_count[0] += 1
4068      return args
4069
4070    enabled = def_function.function(func, experimental_follow_type_hints=True)
4071    disabled = def_function.function(func, experimental_follow_type_hints=False)
4072
4073    args = (
4074        'abc',
4075        'def',
4076    ) * 20
4077    args2 = (
4078        'def',
4079        'abc',
4080    ) * 20
4081
4082    enabled(args)
4083    enabled(args2)
4084    self.assertEqual(trace_count[0], 1)
4085
4086    trace_count = [0]
4087    disabled(args)
4088    disabled(args2)  # Retrace
4089    self.assertEqual(trace_count[0], 2)
4090
4091  def testFollowTypeHintsTraceWithKwargs(self):
4092    trace_count = [0]
4093
4094    def func(t: ops.Tensor, **kwargs: ops.Tensor):
4095      del kwargs
4096      trace_count[0] += 1
4097      return t
4098
4099    enabled = def_function.function(func, experimental_follow_type_hints=True)
4100    disabled = def_function.function(func, experimental_follow_type_hints=False)
4101
4102    enabled(1, x=1, y=1.0, z='one')
4103    enabled(2, x=2, y=2.0, z='two')
4104    self.assertEqual(trace_count[0], 1)
4105
4106    trace_count = [0]
4107    disabled(1, x=1, y=1.0, z='one')
4108    disabled(2, x=2, y=2.0, z='two')  # Retrace
4109    self.assertEqual(trace_count[0], 2)
4110
4111  def testFollowTypeHintsTraceWithMultipleInputTypes(self):
4112    trace_count = [0]
4113
4114    def func(t: ops.Tensor, *args: ops.Tensor, **kwargs: ops.Tensor):
4115      del args, kwargs
4116      trace_count[0] += 1
4117      return t
4118
4119    enabled = def_function.function(func, experimental_follow_type_hints=True)
4120    disabled = def_function.function(func, experimental_follow_type_hints=False)
4121
4122    enabled(1, constant_op.constant(1), 'str', x=4.0)
4123    enabled(2, constant_op.constant(2), 'str2', x=5.0)
4124    self.assertEqual(trace_count[0], 1)
4125
4126    trace_count = [0]
4127    disabled(1, constant_op.constant(1), 'str', x=4.0)
4128    disabled(2, constant_op.constant(2), 'str2', x=5.0)  # Retrace
4129    self.assertEqual(trace_count[0], 2)
4130
4131  def testFollowTypeHintsTraceWithOnlyArgNamed(self):
4132    trace_count = [0]
4133
4134    def func(t: ops.Tensor, i: int = 1, **kwargs):  # pylint: disable=bad-whitespace
4135      del i, kwargs
4136      trace_count[0] += 1
4137      return t
4138
4139    enabled = def_function.function(func, experimental_follow_type_hints=True)
4140
4141    enabled(1, 3, x=4.0, y='str')
4142    enabled(2, 4, x=4.0, y='str')  # Retrace
4143    self.assertEqual(trace_count[0], 2)
4144
4145  def testFollowTypeHintsTraceWithNotAllNamed(self):
4146    trace_count = [0]
4147
4148    def func(x, y: ops.Tensor, z: int):
4149      del y, z
4150      trace_count[0] += 1
4151      return x
4152
4153    enabled = def_function.function(func, experimental_follow_type_hints=True)
4154
4155    enabled(1, 2, 3)
4156    enabled(1, 20, 3)  # No retrace - change in ops.Tensor typed arg
4157    enabled(2, 2, 3)  # Retrace - change in untyped arg
4158    enabled(2, 2, 4)  # Retrace - change in typed arg
4159    self.assertEqual(trace_count[0], 3)
4160
4161  def testFollowTypeHintsTraceWithOnlyArgsNamed(self):
4162    trace_count = [0]
4163
4164    def func(x, y, *args: ops.Tensor):
4165      del y, args
4166      trace_count[0] += 1
4167      return x
4168
4169    enabled = def_function.function(func, experimental_follow_type_hints=True)
4170
4171    enabled(1, 20, 3, 4, 5, 6)
4172    enabled(1, 20, 3, 4, 5, 60)  # No retrace - change in *args
4173    enabled(1, 30, 7, 8, 9, 10)  # Retrace - change in args
4174    self.assertEqual(trace_count[0], 2)
4175
4176  def testFollowTypeHintsTraceWithOnlyKwargsNamed(self):
4177    trace_count = [0]
4178
4179    def func(x, y, *args, **kwargs: ops.Tensor):
4180      del y, args, kwargs
4181      trace_count[0] += 1
4182      return x
4183
4184    enabled = def_function.function(func, experimental_follow_type_hints=True)
4185
4186    enabled(1, 2, 3, 4, 5, 6, a=1.0, b=2.0, c=3.0)
4187    enabled(
4188        1, 2, 3, 4, 5, 6, a=1.5, b=2.5,
4189        c=3.5)  # No retrace - change in **kwargs
4190    enabled(100, 2, 3, 4, 5, 6, a=1.0, b=2.0, c=3.0)  # Retrace - change in args
4191    enabled(
4192        1, 2, 3, 4, 5, 100, a=1.0, b=2.0, c=3.0)  # Retrace - change in *args
4193    self.assertEqual(trace_count[0], 3)
4194
4195  def testFollowTypeHintsTraceWithArgsEquals(self):
4196    trace_count = [0]
4197
4198    def func(
4199        x: ops.Tensor = 0,  # pylint:disable=bad-whitespace
4200        y: int = 1,  # pylint:disable=bad-whitespace
4201        **kwargs: ops.Tensor):
4202      del y, kwargs
4203      trace_count[0] += 1
4204      return x
4205
4206    enabled = def_function.function(func, experimental_follow_type_hints=True)
4207
4208    enabled(x=1, y=2, z=3)
4209    enabled(x=1, y=3, z=3)  # Retrace - change in args
4210    enabled(x=2, y=2, z=4)  # No retrace - change in args and **kwargs
4211    enabled(x=2, y=2, z=4, u=5)  # Retrace - change in **kwargs
4212    self.assertEqual(trace_count[0], 3)
4213
4214  def testFollowTypeHintsWithTensorSpec(self):
4215    def func(x: ops.Tensor, y):
4216      return x + y
4217    v = def_function.function(experimental_follow_type_hints=True)(func)
4218    v = v.get_concrete_function(
4219        tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32), 3)
4220    x = v(constant_op.constant(1.), 3)
4221    self.assertEqual(x.numpy(), 4.)
4222
4223  def testFollowTypeHintsTraceWithKwArgsAndNoVarKws(self):
4224    trace_count = [0]
4225
4226    def func(a: int, b: ops.Tensor,
4227             x: ops.Tensor = 0, y: int = 1):
4228      del a, b, y
4229      trace_count[0] += 1
4230      return x
4231
4232    enabled = def_function.function(func, experimental_follow_type_hints=True)
4233
4234    enabled(0, 0, x=1, y=2)
4235    enabled(0, 0, x=2, y=2,)  # No retrace, since only tensor changed
4236    self.assertEqual(trace_count[0], 1)
4237
4238    # Pass args as keyword args.
4239    enabled(a=0, b=0, x=2, y=2,)  # No retrace, args are the same
4240    self.assertEqual(trace_count[0], 1)
4241
4242    enabled(a=1, b=0, x=2, y=2,)  # Retrace, since non-tensor arg changed
4243    self.assertEqual(trace_count[0], 2)
4244
4245    enabled(a=1, b=2, x=2, y=2)  # No retrace, since only tensor changed
4246    self.assertEqual(trace_count[0], 2)
4247
4248    trace_count[0] = 0
4249    disabled = def_function.function(func, experimental_follow_type_hints=False)
4250    disabled(0, 0, x=1, y=2)
4251    disabled(0, 0, x=2, y=2,)  # Retrace
4252    self.assertEqual(trace_count[0], 2)
4253
4254  def testFollowTypeHintsTraceWithArgsEqualsTypedKwargs(self):
4255    trace_count = [0]
4256
4257    def func(x, y, **kwargs: ops.Tensor):
4258      del y, kwargs
4259      trace_count[0] += 1
4260      return x
4261
4262    enabled = def_function.function(func, experimental_follow_type_hints=True)
4263
4264    enabled(x=1, y=2, z=3)
4265    enabled(x=1, y=3, z=3)  # Retrace
4266    enabled(x=1, y=2, z=4)  # No retrace
4267    enabled(x=2, y=2, z=4)  # Retrace
4268    enabled(x=2, y=2, z=4, u=5)  # Retrace
4269    self.assertEqual(trace_count[0], 4)
4270
4271  def testFollowTypeHintsTraceWithArgsEqualsTypedArgs(self):
4272    trace_count = [0]
4273
4274    def func(x: ops.Tensor, y: int, **kwargs):
4275      del y, kwargs
4276      trace_count[0] += 1
4277      return x
4278
4279    enabled = def_function.function(func, experimental_follow_type_hints=True)
4280
4281    enabled(x=1, y=2, z=3)
4282    enabled(x=1, y=3, z=3)  # Retrace
4283    enabled(x=1, y=2, z=4)  # Retrace
4284    enabled(x=2, y=2, z=3)  # No retrace
4285    enabled(x=2, y=2, z=4, u=5)  # Retrace
4286    self.assertEqual(trace_count[0], 4)
4287
4288  def testFollowTypeHintsTraceWithKwOnlyArgsBasic(self):
4289    trace_count = [0]
4290
4291    def func(*, a: ops.Tensor = None, b=1):  # pylint: disable=bad-whitespace
4292      del b
4293      trace_count[0] += 1
4294      return a
4295
4296    enabled = def_function.function(func, experimental_follow_type_hints=True)
4297
4298    enabled(a=1, b=2)
4299    enabled(a=2, b=2)  # No retrace
4300    enabled(a=1, b=1)  # Retrace
4301    self.assertEqual(trace_count[0], 2)
4302
4303  def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedArg(self):
4304    trace_count = [0]
4305
4306    def func(arg: ops.Tensor, *args, kwonly, **kwargs):
4307      del args, kwonly, kwargs
4308      trace_count[0] += 1
4309      return arg
4310
4311    enabled = def_function.function(func, experimental_follow_type_hints=True)
4312
4313    enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)
4314    enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)  # No retrace
4315    enabled(1000, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)  # No retrace
4316    enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7)  # Retrace
4317    enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7)  # Retrace
4318    enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70)  # Retrace
4319    self.assertEqual(trace_count[0], 4)
4320
4321  def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedArgs(self):
4322    trace_count = [0]
4323
4324    def func(arg, *args: ops.Tensor, kwonly, **kwargs):
4325      del args, kwonly, kwargs
4326      trace_count[0] += 1
4327      return arg
4328
4329    enabled = def_function.function(func, experimental_follow_type_hints=True)
4330
4331    enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)
4332    enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)  # Retrace
4333    enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7)  # No retrace
4334    enabled(1, 200, 300, 400, kwonly=5, kwarg1=6, kwarg2=7)  # No retrace
4335    enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7)  # Retrace
4336    enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70)  # Retrace
4337    self.assertEqual(trace_count[0], 4)
4338
4339  def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedKwOnlyArg(self):
4340    trace_count = [0]
4341
4342    def func(arg, *args, kwonly: ops.Tensor, **kwargs):
4343      del args, kwonly, kwargs
4344      trace_count[0] += 1
4345      return arg
4346
4347    enabled = def_function.function(func, experimental_follow_type_hints=True)
4348
4349    enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)
4350    enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)  # Retrace
4351    enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7)  # Retrace
4352    enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7)  # No retrace
4353    enabled(1, 2, 3, 4, kwonly=500, kwarg1=6, kwarg2=7)  # No retrace
4354    enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70)  # Retrace
4355    self.assertEqual(trace_count[0], 4)
4356
4357  def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedKwargs(self):
4358    trace_count = [0]
4359
4360    def func(arg, *args, kwonly, **kwargs: ops.Tensor):
4361      del args, kwonly, kwargs
4362      trace_count[0] += 1
4363      return arg
4364
4365    enabled = def_function.function(func, experimental_follow_type_hints=True)
4366
4367    enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)
4368    enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)  # Retrace
4369    enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7)  # Retrace
4370    enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7)  # Retrace
4371    enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70)  # No retrace
4372    enabled(1, 2, 3, 4, kwonly=5, kwarg1=600, kwarg2=700)  # No retrace
4373    self.assertEqual(trace_count[0], 4)
4374
4375  def testWithExtraWrapper(self):
4376
4377    class Foo(module.Module):
4378
4379      def __init__(self):
4380        super().__init__()
4381        self.var = None
4382
4383      @def_function.function
4384      @dummy_tf_decorator
4385      def add(self, x, y, z=1):
4386        if self.var is None:
4387          return x + y + z
4388
4389    foo = Foo()
4390    self.assertEqual(foo.add(2, 3).numpy(), 6)
4391
4392  @parameterized.parameters([(def_function.function, dummy_tf_decorator),
4393                             (dummy_tf_decorator, def_function.function),
4394                             (def_function.function, def_function.function)])
4395  def testWithExtraWrapperRedundantArgs(self, decorator1, decorator2):
4396
4397    class Foo(module.Module):
4398
4399      def __init__(self):
4400        super().__init__()
4401        self.var = None
4402
4403      @decorator1
4404      @decorator2
4405      def add1(self, x, y):
4406        if self.var is None:
4407          return x + y
4408
4409    foo = Foo()
4410    with self.assertRaisesRegex(TypeError, 'got two values for argument'):
4411      foo.add1(2, x=3)  # pylint: disable=redundant-keyword-arg,no-value-for-parameter
4412
4413  def testWithExtraWrapperMissingArgs(self):
4414
4415    class Foo(module.Module):
4416
4417      def __init__(self):
4418        super().__init__()
4419        self.var = None
4420
4421      @def_function.function
4422      @dummy_tf_decorator
4423      def add1(self, x, y):
4424        if self.var is None:
4425          return x + y
4426
4427      @def_function.function
4428      @dummy_tf_decorator
4429      def add2(self, x, y):
4430        if self.var is None:
4431          return x + y
4432
4433      @def_function.function
4434      @def_function.function
4435      def add3(self, x, y):
4436        if self.var is None:
4437          return x + y
4438
4439    foo = Foo()
4440    with self.assertRaisesRegex(
4441        TypeError, 'missing 1 required positional argument: \'y\''):
4442      foo.add1(2)  # pylint: disable=no-value-for-parameter
4443
4444    with self.assertRaisesRegex(TypeError, 'missing 1 required argument: x'):
4445      foo.add1(y=2)  # pylint: disable=no-value-for-parameter
4446
4447    with self.assertRaisesRegex(
4448        TypeError, 'missing 1 required positional argument: \'y\''):
4449      foo.add2(2)  # pylint: disable=no-value-for-parameter
4450
4451    with self.assertRaisesRegex(TypeError, 'missing 1 required argument: x'):
4452      foo.add2(y=2)  # pylint: disable=no-value-for-parameter
4453
4454    with self.assertRaisesRegex(
4455        TypeError, 'missing 1 required positional argument: \'y\''):
4456      foo.add3(2)  # pylint: disable=no-value-for-parameter
4457
4458    with self.assertRaisesRegex(TypeError, 'missing 1 required argument: x'):
4459      foo.add3(y=2)  # pylint: disable=no-value-for-parameter
4460
4461  def testMissingArgsTfFunctionedMethod(self):
4462
4463    class A(object):
4464
4465      def func(self, position_arg1, position_arg2):
4466        return position_arg1, position_arg2
4467
4468      @def_function.function
4469      def decorated_method(self, position_arg1, position_arg2):
4470        return position_arg1, position_arg2
4471
4472    a_instance = A()
4473    tf_method_pos = def_function.function(a_instance.func)
4474    with self.assertRaisesRegex(
4475        TypeError, '.* missing 1 required argument: position_arg1'):
4476      tf_method_pos(position_arg2='foo')
4477
4478    # tf.function-decorated instance methods need to be tested because of
4479    # the __get__ method implementation.
4480    tf_func_decorated_method = def_function.function(
4481        a_instance.decorated_method)
4482    tf_func_decorated_method(position_arg1='foo', position_arg2='bar')
4483    with self.assertRaisesRegex(
4484        TypeError, '.* missing 1 required argument: position_arg1'):
4485      tf_func_decorated_method(position_arg2='bar')
4486
4487  def testMissingArgsTfFunctionedObject(self):
4488
4489    class A(object):
4490
4491      def __call__(self, position_arg1, position_arg2):
4492        return position_arg1, position_arg2
4493
4494    a_instance = A()
4495
4496    # A tf.function-decorated callable object needs to be tested because of
4497    # the special inspect results.
4498    tf_func_obj = def_function.function(a_instance)
4499    tf_func_obj(position_arg1=1, position_arg2=2)
4500    with self.assertRaisesRegex(
4501        TypeError, '.* missing 1 required argument: position_arg1'):
4502      tf_func_obj(position_arg2='bar')
4503
4504  def testMissingArgsTfFunctionedFunctions(self):
4505
4506    def func_pos(position_arg1, position_arg2):
4507      return position_arg1, position_arg2
4508
4509    def func_with_default(position_arg, named_arg=None):
4510      return position_arg, named_arg
4511
4512    def func_pos_3args(position_arg1, position_arg2, position_arg3):
4513      return position_arg1, position_arg2, position_arg3
4514
4515    tf_func_pos = def_function.function(func_pos)
4516    with self.assertRaisesRegex(
4517        TypeError, '.* missing 1 required argument: position_arg1'):
4518      tf_func_pos(position_arg2='foo')
4519
4520    tf_func_with_default = def_function.function(func_with_default)
4521    tf_func_with_default(position_arg='bar')
4522    with self.assertRaisesRegex(TypeError,
4523                                '.* missing 1 required argument: position_arg'):
4524      tf_func_with_default(named_arg='foo')
4525
4526    tf_func_pos_3args = def_function.function(func_pos_3args)
4527    with self.assertRaisesRegex(
4528        TypeError,
4529        '.* missing required arguments: position_arg1, position_arg3'):
4530      tf_func_pos_3args(position_arg2='foo')
4531
4532  def testShapeInferencePropagateConstNestedStack(self):
4533
4534    @def_function.function(input_signature=[
4535        tensor_spec.TensorSpec((None, None), dtype=dtypes.int32),
4536        tensor_spec.TensorSpec((), dtype=dtypes.int32),
4537    ])
4538    def f(x, s):
4539      old_shape = array_ops.shape(x)
4540      new_shape = array_ops.stack([old_shape[0], s], axis=0)
4541      y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
4542      return y
4543
4544    @def_function.function(input_signature=[
4545        tensor_spec.TensorSpec(shape=(3, 6), dtype=dtypes.int32)
4546    ])
4547    def g(x):
4548      y = f(x, s=5)
4549      assert y.shape.as_list() == [3, 5], y.shape.as_list()
4550      return y
4551
4552    self.assertAllEqual(
4553        g(array_ops.zeros([3, 6], dtype=dtypes.int32)), array_ops.ones([3, 5]))
4554
4555  def testShapeInferencePropagateConstNestedUnstackStack(self):
4556
4557    @def_function.function(input_signature=[
4558        tensor_spec.TensorSpec((None, None), dtype=dtypes.int32),
4559        tensor_spec.TensorSpec((), dtype=dtypes.int32),
4560    ])
4561    def f(x, s):
4562      s0, _ = array_ops.unstack(array_ops.shape(x), axis=0)
4563      new_shape = array_ops.stack([s0, s], axis=0)
4564      y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
4565      return y
4566
4567    @def_function.function(input_signature=[
4568        tensor_spec.TensorSpec(shape=(3, 6), dtype=dtypes.int32)
4569    ])
4570    def g(x):
4571      y = f(x, s=5)
4572      assert y.shape.as_list() == [3, 5], y.shape.as_list()
4573      return y
4574
4575    self.assertAllEqual(
4576        g(array_ops.zeros([3, 6], dtype=dtypes.int32)), array_ops.ones([3, 5]))
4577
4578  def testShapeInferencePropagateConstNestedConcat(self):
4579
4580    @def_function.function(input_signature=[
4581        tensor_spec.TensorSpec((), dtype=dtypes.int32),
4582        tensor_spec.TensorSpec((), dtype=dtypes.int32),
4583        tensor_spec.TensorSpec((), dtype=dtypes.int32),
4584    ])
4585    def f(d1, d2, d3):
4586      new_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1)
4587      y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
4588      return y
4589
4590    @def_function.function()
4591    def g():
4592      y = f(1, 2, 3)
4593      assert y.shape.as_list() == [1, 2, 3], y.shape.as_list()
4594      return y
4595
4596    self.assertAllEqual(g(), array_ops.ones([1, 2, 3]))
4597
4598  def testShapeInferencePropagateConstDoubleNested(self):
4599
4600    @def_function.function(input_signature=[
4601        tensor_spec.TensorSpec((), dtype=dtypes.int32),
4602        tensor_spec.TensorSpec((), dtype=dtypes.int32),
4603        tensor_spec.TensorSpec((), dtype=dtypes.int32),
4604    ])
4605    def f(d1, d2, d3):
4606      new_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1)
4607      y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
4608      return y
4609
4610    @def_function.function()
4611    def g():
4612      y = def_function.function(f)(1, 2, 3)
4613      assert y.shape.as_list() == [1, 2, 3], y.shape.as_list()
4614      return y
4615
4616    self.assertAllEqual(g(), array_ops.ones([1, 2, 3]))
4617
4618  @test_util.run_v2_only
4619  def testControlDependencyAfterInline(self):
4620    v = variables.Variable(0.)
4621
4622    @def_function.function
4623    def assign():
4624      return v.assign(1.)
4625
4626    @def_function.function
4627    def assign_add():
4628      return v.assign_add(1.)
4629
4630    @def_function.function
4631    def f():
4632      check_ops.assert_equal_v2(assign(), 1.)
4633      check_ops.assert_equal_v2(assign_add(), 2.)
4634
4635    # We don't have a way to inspect the inlined graph in Python, so we run it
4636    # multiple times to have more confidence the dependency is correct.
4637    for _ in range(30):
4638      f()
4639
4640  @test_util.run_v2_only
4641  def testReadInFuncWriteOutside(self):
4642    # Run many times since we are testing for a potential race condition.
4643    for _ in range(30):
4644      # pylint: disable=cell-var-from-loop
4645      v = variables.Variable(1.)
4646
4647      @def_function.function
4648      def add_one():
4649        return v + 1.
4650
4651      @def_function.function
4652      def get_v_plus_one():
4653        v_plus_one = add_one()
4654        v.assign_add(2.0)
4655        return v_plus_one
4656
4657      self.assertAllEqual(get_v_plus_one(), 2.0)
4658
4659
4660class MultiDeviceTest(test.TestCase, parameterized.TestCase):
4661
4662  @test_util.run_gpu_only
4663  def testMultiDeviceOutput(self):
4664    """Tests that functions can produce outputs on multiple devices."""
4665    @function.defun
4666    def func(a, b, transpose_a):
4667      with ops.device('/device:CPU:0'):
4668        m1 = math_ops.matmul(a, b, transpose_a=transpose_a)
4669      with ops.device('/device:GPU:0'):
4670        m2 = math_ops.matmul(a, b, transpose_a=transpose_a)
4671      return m1, m2
4672
4673    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
4674    m1, m2 = func(t, t, transpose_a=True)
4675    self.assertAllEqual(m1.numpy(), [[10, 14], [14, 20]])
4676    self.assertRegex(m1.backing_device, 'CPU')
4677    self.assertAllEqual(m2.numpy(), [[10, 14], [14, 20]])
4678    self.assertRegex(m2.backing_device, 'GPU')
4679
4680  @test_util.run_gpu_only
4681  def testEmptyBody(self):
4682    @function.defun
4683    def func(a, b):
4684      return b, a
4685
4686    with ops.device('/device:CPU:0'):
4687      a = array_ops.identity(3.0)
4688    with ops.device('/device:GPU:0'):
4689      b = array_ops.identity(5.0)
4690
4691    m1, m2 = func(a, b)
4692    self.assertAllEqual(m1.numpy(), 5.0)
4693    self.assertRegex(m1.backing_device, 'GPU')
4694    self.assertAllEqual(m2.numpy(), 3.0)
4695    self.assertRegex(m2.backing_device, 'CPU')
4696
4697  @test_util.run_gpu_only
4698  def testMultiDeviceInt32(self):
4699    """Tests that multi-device functions can take and output INT32s.
4700
4701    When an INT32 device tensor is fed into a function, it is copied to CPU
4702    by the eager runtime. The function sees all INT32 inputs on CPU.
4703
4704    We set allocator attribute 'on_host' for INT32 outputs. They can be
4705    partitioned into the GPU component function, but will be allocated on
4706    CPU nevertheless.
4707
4708    There is experimental support for `ints_on_device` in
4709    FunctionLibraryRuntime now. We can try that.
4710
4711    """
4712    with ops.device('/device:CPU:0'):
4713      int_cpu = constant_op.constant(3, dtype=dtypes.int32)
4714      resource = resource_variable_ops.ResourceVariable(5, dtype=dtypes.int32)
4715    with ops.device('/device:GPU:0'):
4716      int_gpu = constant_op.constant(7, dtype=dtypes.int32)
4717
4718    @function.defun
4719    def func(int_cpu, resource, int_gpu):
4720      with ops.device('/device:CPU:0'):
4721        m1 = int_cpu * resource + int_gpu
4722      with ops.device('/device:GPU:0'):
4723        # This computation will happen on GPU but m2 will be copied to CPU.
4724        m2 = int_gpu * resource + int_cpu + 1
4725      return m1, m2
4726
4727    m1, m2 = func(int_cpu, resource, int_gpu)
4728    self.assertAllEqual(m1.numpy(), 22)
4729    self.assertRegex(m1.backing_device, 'CPU')
4730    self.assertAllEqual(m2.numpy(), 39)
4731    self.assertRegex(m2.backing_device, 'CPU')
4732
4733    # flip arguments
4734    m1, m2 = func(int_gpu, resource, int_cpu)
4735    self.assertAllEqual(m1.numpy(), 38)
4736    self.assertRegex(m1.backing_device, 'CPU')
4737    self.assertAllEqual(m2.numpy(), 23)
4738    self.assertRegex(m2.backing_device, 'CPU')
4739
4740  @test_util.run_gpu_only
4741  def testMultiDeviceColocateWith(self):
4742    """Tests that function's outputs respect colocation constraints."""
4743    @function.defun
4744    def func(a, b):
4745      with ops.colocate_with(a):
4746        ra = 2 * a
4747      with ops.colocate_with(b):
4748        rb = 3 * b
4749      return ra, rb
4750
4751    devices = ['/device:CPU:0', '/device:GPU:0']
4752    for dev1, dev2 in itertools.product(devices, devices):
4753      with ops.device(dev1):
4754        a = array_ops.identity(1.0)
4755      with ops.device(dev2):
4756        b = array_ops.identity(10.0)
4757
4758      ra, rb = func(a, b)
4759      self.assertEqual(ra.numpy(), 2.0)
4760      self.assertRegex(ra.backing_device, dev1)
4761      self.assertEqual(rb.numpy(), 30.0)
4762      self.assertRegex(rb.backing_device, dev2)
4763
4764  @test_util.run_gpu_only
4765  def testMultiDeviceResources(self):
4766    with ops.device('/device:CPU:0'):
4767      c1 = resource_variable_ops.ResourceVariable(2.0)
4768      c2 = resource_variable_ops.ResourceVariable(7.0)
4769    with ops.device('/device:GPU:0'):
4770      g1 = resource_variable_ops.ResourceVariable(3.0)
4771      g2 = resource_variable_ops.ResourceVariable(5.0)
4772
4773    @function.defun
4774    def func(resource1, resource2):
4775      with ops.device('/device:CPU:0'):
4776        result1 = resource1 * g2
4777      with ops.device('/device:GPU:0'):
4778        result2 = resource2 * c2
4779      return result1, result2
4780
4781    r1, r2 = func(c1, g1)
4782    self.assertEqual(r1.numpy(), 10.0)
4783    self.assertRegex(r1.backing_device, 'CPU')
4784    self.assertEqual(r2.numpy(), 21.0)
4785    self.assertRegex(r2.backing_device, 'GPU')
4786
4787    # Call with flipped inputs. Check that we look at resource's
4788    # device and reinstantiates the function when inputs' devices change.
4789    r1, r2 = func(g1, c1)
4790    self.assertEqual(r1.numpy(), 15.0)
4791    self.assertRegex(r1.backing_device, 'CPU')
4792    self.assertEqual(r2.numpy(), 14.0)
4793    self.assertRegex(r2.backing_device, 'GPU')
4794
4795  @test_util.run_gpu_only
4796  def testOutputResources(self):
4797    with ops.device('/device:CPU:0'):
4798      c1 = resource_variable_ops.ResourceVariable(2.0)
4799    with ops.device('/device:GPU:0'):
4800      g1 = resource_variable_ops.ResourceVariable(3.0)
4801
4802    @function.defun
4803    def func(resource1, resource2):
4804      with ops.device('/device:CPU:0'):
4805        result1 = resource1 * 5
4806      with ops.device('/device:GPU:0'):
4807        result2 = resource2 * 7
4808      return result1, resource1.handle, result2, resource2.handle
4809
4810    r1, res1, r2, res2 = func(c1, g1)
4811    self.assertEqual(r1.numpy(), 10.0)
4812    self.assertRegex(r1.backing_device, 'CPU')
4813    self.assertEqual(r2.numpy(), 21.0)
4814    self.assertRegex(r2.backing_device, 'GPU')
4815
4816    def check_handle(handle, expected_value):
4817      self.assertRegex(handle.backing_device, 'CPU')
4818      tensor = gen_resource_variable_ops.read_variable_op(
4819          handle, dtypes.float32)
4820      self.assertEqual(tensor.numpy(), expected_value)
4821
4822    # Check that handles returned from functions are on CPU and an op using
4823    # the resource handle is correctly placed on the device backing the
4824    # resource.
4825    check_handle(res1, 2.0)
4826    check_handle(res2, 3.0)
4827
4828    # Call with flipped inputs to make sure the same the function is
4829    # reinstantiated and eager runtime does not mess up the device assignment
4830    # for ops consuming handles returned from defuns.
4831    r1, res1, r2, res2 = func(g1, c1)
4832    self.assertEqual(r1.numpy(), 15.0)
4833    self.assertRegex(r1.backing_device, 'CPU')
4834    self.assertEqual(r2.numpy(), 14.0)
4835    self.assertRegex(r2.backing_device, 'GPU')
4836    check_handle(res1, 3.0)
4837    check_handle(res2, 2.0)
4838
4839  @test_util.run_gpu_only
4840  def testPassResourceThroughNestedFunctionCall(self):
4841    """Test passing GPU resource to noinline function call placed on CPU.
4842
4843    PartitionedCallOp must not enforce any particular device assignment for the
4844    resource output. Inner function marked as `_nospecialize`, so Grappler would
4845    not prune unused function output.
4846    """
4847
4848    with ops.device('/device:GPU:0'):
4849      g1 = resource_variable_ops.ResourceVariable(3.0)
4850
4851    @function.defun_with_attributes(attributes={
4852        '_noinline': True,
4853        '_nospecialize': True
4854    })
4855    def inner(resource1):
4856      return resource1 * 2, resource1.handle
4857
4858    @function.defun
4859    def outer(resource1):
4860      with ops.device('/device:CPU:0'):
4861        r1, _ = inner(resource1)
4862      return r1
4863
4864    r1 = outer(g1)
4865
4866    self.assertEqual(r1.numpy(), 6.0)
4867    self.assertRegex(r1.backing_device, 'CPU')
4868
4869  @test_util.run_gpu_only
4870  def testReturnResourceFromNestedFunctionCall(self):
4871    """Test returning GPU resource from noinline function call placed on CPU.
4872
4873    When inferring output devices for the return value, do not set a device for
4874    returns of DT_RESOURCE data type based on the device assignment of the node
4875    that produced that resource. As an example function call placed on CPU can
4876    return resources on GPU.
4877    """
4878
4879    with ops.device('/device:GPU:0'):
4880      g1 = resource_variable_ops.ResourceVariable(3.0)
4881
4882    @function.defun_with_attributes(attributes={
4883        '_noinline': True
4884    })
4885    def inner(resource1):
4886      resource1.assign_add(2.0)
4887      return resource1 * 2, resource1.handle
4888
4889    @function.defun
4890    def outer(resource1):
4891      with ops.device('/device:CPU:0'):
4892        r1, res1 = inner(resource1)
4893      return r1, res1
4894
4895    r1, res1 = outer(g1)
4896
4897    self.assertEqual(r1.numpy(), 10.0)
4898    self.assertRegex(r1.backing_device, 'CPU')
4899
4900    def check_handle(handle, expected_value):
4901      self.assertRegex(handle.backing_device, 'CPU')
4902      tensor = gen_resource_variable_ops.read_variable_op(
4903          handle, dtypes.float32)
4904      self.assertEqual(tensor.numpy(), expected_value)
4905
4906    # Check that handles returned from functions are on CPU and an op using
4907    # the resource handle is correctly placed on the device backing the
4908    # resource.
4909    check_handle(res1, 5.0)
4910
4911  @test_util.run_gpu_only
4912  def testComplexInputOutputDevicePattern(self):
4913    """Tests input/output mapping logic in partitioning."""
4914    with ops.device('/device:CPU:0'):
4915      rc0 = resource_variable_ops.ResourceVariable(2.0)
4916      rc1 = resource_variable_ops.ResourceVariable(3.0)
4917      cc0 = array_ops.identity(5.0)
4918      cc1 = array_ops.identity(7.0)
4919    with ops.device('/device:GPU:0'):
4920      rg0 = resource_variable_ops.ResourceVariable(11.0)
4921      rg1 = resource_variable_ops.ResourceVariable(13.0)
4922      cg0 = array_ops.identity(17.0)
4923      cg1 = array_ops.identity(19.0)
4924
4925    # Make sure tensors are on expected devices.
4926    for tensor in [cc0, cc1]:
4927      self.assertRegex(tensor.backing_device, 'CPU:0')
4928    for tensor in [cg0, cg1]:
4929      self.assertRegex(tensor.backing_device, 'GPU:0')
4930
4931    @function.defun
4932    def func(rc0, cc0, cg0, rc1, cg1, rg0, rg1, cc1):
4933      with ops.device('/device:CPU:0'):
4934        m1 = rc0 * cg0
4935      with ops.device('/device:GPU:0'):
4936        m2 = rg0 * cc0
4937
4938      with ops.device('/device:CPU:0'):
4939        r1 = 1000.0 * m2 + rc1 * cg1
4940      with ops.device('/device:GPU:0'):
4941        r2 = 1000.0 * m1 + rg1 * cc1
4942
4943      return r1, r2, m2, m1
4944
4945    r1, r2, m2, m1 = func(rc0, cc0, cg0, rc1, cg1, rg0, rg1, cc1)
4946    self.assertRegex(m1.backing_device, 'CPU')
4947    self.assertRegex(r1.backing_device, 'CPU')
4948    self.assertRegex(m2.backing_device, 'GPU')
4949    self.assertRegex(r2.backing_device, 'GPU')
4950    self.assertEqual(m1.numpy(), 34.0)
4951    self.assertEqual(r1.numpy(), 55000.0 + 3.0 * 19.0)
4952    self.assertEqual(m2.numpy(), 55.0)
4953    self.assertEqual(r2.numpy(), 34000.0 + 13.0 * 7.0)
4954
4955  @test_util.run_gpu_only
4956  def testArgumentPruning(self):
4957    """Tests functions taking unnecessary arguments."""
4958    with ops.device('/device:CPU:0'):
4959      c1 = constant_op.constant(5.0)
4960      c2 = constant_op.constant(7.0)
4961
4962    with ops.device('/device:GPU:0'):
4963      g1 = constant_op.constant(11.0)
4964      g2 = constant_op.constant(13.0)
4965      g3 = constant_op.constant(17.0)
4966
4967    @function.defun
4968    def func(g1, g2, c1, g3, c2):  # pylint: disable=unused-argument
4969      # arguments g1 and g2 are unused and can be pruned by grappler.
4970      return c1 * g3 * c2
4971
4972    result = func(g1, g2, c1, g3, c2)
4973    self.assertEqual(result.numpy(), 5.0 * 7.0 * 17.0)
4974
4975  def testNestedCallWatchedVariables(self):
4976
4977    v = variables.Variable(4.)
4978
4979    @def_function.function
4980    def f():
4981      return v ** 2.
4982
4983    with backprop.GradientTape() as tape:
4984      f()
4985
4986    self.assertEqual((v,), tape.watched_variables())
4987
4988    @def_function.function
4989    def g():
4990      return f()
4991
4992    with backprop.GradientTape() as tape:
4993      g()
4994
4995    self.assertEqual((v,), tape.watched_variables())
4996
4997    # f() can rely on the variable being read during its trace. g() checks that
4998    # variables from a function which knows about them are recorded on the
4999    # tape. h() tests that functions forward knowledge of variables to callers.
5000
5001    @def_function.function
5002    def h():
5003      return g()
5004
5005    with backprop.GradientTape() as tape:
5006      h()
5007
5008    self.assertEqual((v,), tape.watched_variables())
5009
5010  def testDeferredCapture(self):
5011    value = 1.0
5012
5013    @def_function.function
5014    def lazy_capture(x):
5015      y = ops.get_default_graph().capture_call_time_value(
5016          lambda: value, tensor_spec.TensorSpec(None))
5017      return x + y
5018
5019    self.assertAllEqual(lazy_capture(2.0), 3.0)
5020    # After changing the value of `value` the function call should return a
5021    # different result.
5022    value = 2.0
5023    self.assertAllEqual(lazy_capture(2.0), 4.0)
5024
5025  def testDeferredCaptureWithKey(self):
5026    value0 = 1.0
5027    value1 = 2.0
5028
5029    @def_function.function
5030    def lazy_capture(x):
5031      w = ops.get_default_graph().capture_call_time_value(
5032          lambda: value0, tensor_spec.TensorSpec(None), key=0)
5033      y = ops.get_default_graph().capture_call_time_value(
5034          lambda: value1, tensor_spec.TensorSpec(None), key=1)
5035      def bad_closure():
5036        raise ValueError('Should not run')
5037      z = ops.get_default_graph().capture_call_time_value(
5038          bad_closure, tensor_spec.TensorSpec(None), key=1)
5039      return x + y + w + z
5040
5041    self.assertAllEqual(lazy_capture(2.0), 7.0)
5042    value0 = 2.0
5043    value1 = 3.0
5044    self.assertAllEqual(lazy_capture(2.0), 10.0)
5045
5046  def testDeferredCaptureTypeError(self):
5047    value = constant_op.constant(1.0)
5048
5049    @def_function.function
5050    def lazy_capture(x):
5051      y = ops.get_default_graph().capture_call_time_value(
5052          lambda: value, tensor_spec.TensorSpec(()))
5053      return x + y
5054
5055    self.assertAllEqual(lazy_capture(2.0), 3.0)
5056
5057    # dtype mismatch
5058    value = constant_op.constant(1)
5059    with self.assertRaisesRegex(ValueError, 'Value .* to a tensor with dtype'):
5060      lazy_capture(2.0)
5061
5062    # shape mismatch
5063    value = constant_op.constant([1.0])
5064    with self.assertRaisesRegex(ValueError, 'Value .* shape'):
5065      lazy_capture(2.0)
5066
5067  def testDeferredCaptureReturnNestWithCompositeTensor(self):
5068    i_s = indexed_slices.IndexedSlices(
5069        constant_op.constant([1, 2]),
5070        constant_op.constant([0, 1], dtype=dtypes.int64),
5071        constant_op.constant([2]))
5072    r_t = ragged_factory_ops.constant([[[1, 2], [3]], [[4, 5, 6]]])
5073    s_t = sparse_tensor.SparseTensor(
5074        values=[1, 2, 3], indices=[[0], [8], [10]], dense_shape=[20])
5075
5076    @def_function.function
5077    def lazy_capture():
5078      y = ops.get_default_graph().capture_call_time_value(
5079          lambda: {'i': i_s, 't': (r_t, s_t)},
5080          {'i': indexed_slices.IndexedSlicesSpec(
5081              dtype=dtypes.int32, dense_shape_dtype=dtypes.int32),
5082           't': (ragged_tensor.RaggedTensorSpec([2, None, None], dtypes.int32),
5083                 sparse_tensor.SparseTensorSpec([None], dtypes.int32))})
5084      return y['i'], y['t']
5085
5086    i, (r, s) = lazy_capture()
5087    self.assertAllEqual(i_s.values, i.values)
5088    self.assertAllEqual(i_s.indices, i.indices)
5089    self.assertAllEqual(i_s.dense_shape, i.dense_shape)
5090    self.assertAllEqual(r_t, r)
5091    self.assertAllEqual(s_t.indices, s.indices)
5092    self.assertAllEqual(s_t.values, s.values)
5093    self.assertAllEqual(s_t.dense_shape, s.dense_shape)
5094
5095  def testDeferredCaptureCompositeTensorSpecTypeMismatch(self):
5096    value = indexed_slices.IndexedSlices(
5097        constant_op.constant([1, 2]),
5098        constant_op.constant([0, 1], dtype=dtypes.int64))
5099
5100    @def_function.function
5101    def lazy_capture():
5102      return ops.get_default_graph().capture_call_time_value(
5103          lambda: value,
5104          indexed_slices.IndexedSlicesSpec(dtype=dtypes.int32))
5105
5106    # Type matches spec.
5107    lazy_capture()
5108
5109    # Extra dense shape component.
5110    value = indexed_slices.IndexedSlices(
5111        constant_op.constant([1, 2]),
5112        constant_op.constant([0, 1], dtype=dtypes.int64),
5113        constant_op.constant([2]))
5114    with self.assertRaises(ValueError):
5115      lazy_capture()
5116
5117    # Index dtype mismatch int32 vs. int64.
5118    value = indexed_slices.IndexedSlices(
5119        constant_op.constant([1, 2]),
5120        constant_op.constant([0, 1]))
5121    with self.assertRaises(ValueError):
5122      lazy_capture()
5123
5124
if __name__ == '__main__':
  # Switch the process to eager execution first, then hand control to the
  # TensorFlow test runner, which runs the test cases defined in this module.
  ops.enable_eager_execution()
  test.main()
5128