1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import collections 21import copy 22import functools 23import itertools 24import multiprocessing.pool 25import os 26import sys 27import time 28import weakref 29 30from absl.testing import parameterized 31import numpy 32 33from tensorflow.core.protobuf import config_pb2 34from tensorflow.core.protobuf import rewriter_config_pb2 35from tensorflow.python.autograph.core import ag_ctx 36from tensorflow.python.data.ops import dataset_ops 37from tensorflow.python.data.ops import iterator_ops 38from tensorflow.python.eager import backprop 39from tensorflow.python.eager import cancellation 40from tensorflow.python.eager import context 41from tensorflow.python.eager import def_function 42from tensorflow.python.eager import function 43from tensorflow.python.framework import composite_tensor 44from tensorflow.python.framework import config 45from tensorflow.python.framework import constant_op 46from tensorflow.python.framework import dtypes 47from tensorflow.python.framework import errors 48from tensorflow.python.framework import func_graph 49from tensorflow.python.framework import function as tf_function 50from tensorflow.python.framework import indexed_slices 51from tensorflow.python.framework import ops 52from tensorflow.python.framework import random_seed 53from tensorflow.python.framework import sparse_tensor 54from tensorflow.python.framework import tensor_shape 55from tensorflow.python.framework import tensor_spec 56from tensorflow.python.framework import test_ops 57from tensorflow.python.framework import test_util 58from tensorflow.python.framework import type_spec 59from tensorflow.python.layers import convolutional 60from tensorflow.python.module import module 61from tensorflow.python.ops import array_ops 62from tensorflow.python.ops import check_ops 63from tensorflow.python.ops import clip_ops 64from tensorflow.python.ops import control_flow_ops 65from tensorflow.python.ops import data_flow_ops 66from tensorflow.python.ops import functional_ops 67from tensorflow.python.ops import gen_functional_ops 68from tensorflow.python.ops import gen_random_ops 69from tensorflow.python.ops import gen_resource_variable_ops 70from tensorflow.python.ops import gen_sendrecv_ops 71from tensorflow.python.ops import gradients_impl 72from tensorflow.python.ops import init_ops 73from tensorflow.python.ops import list_ops 74from tensorflow.python.ops import logging_ops 75from tensorflow.python.ops import math_ops 76from tensorflow.python.ops import random_ops 77from tensorflow.python.ops import resource_variable_ops 78from tensorflow.python.ops import string_ops 79from tensorflow.python.ops import variable_scope 80from tensorflow.python.ops import variables 81from tensorflow.python.ops.ragged import ragged_factory_ops 82from tensorflow.python.ops.ragged import ragged_tensor 83from tensorflow.python.ops.structured import structured_tensor 84from tensorflow.python.platform import test 85from tensorflow.python.saved_model.load import load 86from tensorflow.python.saved_model.save import save 87from tensorflow.python.training import training_ops 88from tensorflow.python.util import compat 89from tensorflow.python.util import nest 90from tensorflow.python.util import tf_decorator 91from tensorflow.python.util import tf_inspect 92 93try: 94 import attr # pylint:disable=g-import-not-at-top 95except ImportError: 96 attr = None 97 98 99def total_function_cache(defined): 100 # pylint: disable=protected-access 101 return (set(defined._function_cache.primary) 102 | set(defined._function_cache.arg_relaxed)) 103 # pylint: enable=protected-access 104 105 106def _example_indexed_slices_with_dense_shape(): 107 return indexed_slices.IndexedSlices( 108 constant_op.constant([1, 2]), constant_op.constant([0, 1]), 109 constant_op.constant([2])) 110 111 112def _example_indexed_slices_without_dense_shape(): 113 return indexed_slices.IndexedSlices( 114 constant_op.constant([1, 2]), constant_op.constant([0, 1])) 115 116 117def _spec_for_value(value): 118 """Returns the (nested) TypeSpec for a value.""" 119 if nest.is_sequence(value): 120 return nest.map_structure(_spec_for_value, value) 121 elif isinstance(value, (ops.Tensor, composite_tensor.CompositeTensor)): 122 return type_spec.type_spec_from_value(value) 123 else: 124 return value 125 126 127# This dummy decorator imitates ordinary decorators utilizing tf_decorator. 128def dummy_tf_decorator(method): 129 130 def wrapper(*args, **kwargs): 131 return method(*args, **kwargs) 132 133 return tf_decorator.make_decorator(method, wrapper) 134 135 136# TODO(mdan): Organize these tests. 137class FunctionTest(test.TestCase, parameterized.TestCase): 138 139 def setUp(self): 140 super(FunctionTest, self).setUp() 141 cpus = config.list_physical_devices('CPU') 142 # Set 4 virtual CPUs 143 config.set_logical_device_configuration(cpus[0], [ 144 context.LogicalDeviceConfiguration(), 145 context.LogicalDeviceConfiguration(), 146 context.LogicalDeviceConfiguration(), 147 context.LogicalDeviceConfiguration() 148 ]) 149 150 def testBasic(self): 151 matmul = def_function.function(math_ops.matmul) 152 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 153 sq = matmul(t, t, transpose_a=True) 154 sq2 = matmul(sq, t, transpose_a=True) 155 self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20]) 156 self.assertAllEqual(sq2.numpy().reshape(-1), [52, 76, 74, 108]) 157 158 def testPythonFunctionNotCallable(self): 159 with self.assertRaisesRegex(TypeError, 'is not a callable object'): 160 def_function.function(1) 161 162 def testOnExitCallback(self): 163 values = [] 164 def append_1(): 165 values.append(1) 166 167 def append_2(): 168 values.append(2) 169 170 def g(x): 171 old_values = list(values) 172 ops.add_exit_callback_to_default_func_graph(append_1) 173 self.assertEqual(old_values, values) 174 return x + 1 175 176 tf_g = def_function.function(g) 177 178 def f(x): 179 old_values = list(values) 180 ops.add_exit_callback_to_default_func_graph(append_2) 181 self.assertEqual(old_values, values) 182 return tf_g(x) 183 184 tf_f = def_function.function(f) 185 self.assertEmpty(values) 186 tf_f(constant_op.constant(1.0)) 187 self.assertEqual(values, [1, 2]) # Once for g, once for f. 188 tf_f(constant_op.constant([1.0])) # force a retrace 189 self.assertEqual(values, [1, 2, 1, 2]) # And again. 190 191 def testCannotAddExitCallbackWhenNotInFunctionScope(self): 192 with self.assertRaisesRegex(RuntimeError, 'when not building a function.'): 193 ops.add_exit_callback_to_default_func_graph(lambda: None) 194 195 def testVariable(self): 196 v1 = variables.Variable(1.0) 197 add = def_function.function(lambda x, v: x + v1 + v) 198 v2 = variables.Variable(1.0) 199 x = constant_op.constant(1.0) 200 r = add(x, v2) 201 self.assertEqual(3.0, self.evaluate(r)) 202 203 def testVariableOnly(self): 204 v = variables.Variable(1.0) 205 add = def_function.function(lambda x: x.assign_add(1.0)) 206 r1 = add(v) 207 self.assertEqual(2.0, self.evaluate(r1)) 208 c = constant_op.constant(1.0) 209 with self.assertRaisesRegex(AttributeError, 'no attribute'): 210 add(c) 211 212 @test_util.disable_tfrt('Packed tensor is not supported in tfrt yet.') 213 def testPackedVariable(self): 214 with ops.device('/cpu:0'): 215 v0_0 = resource_variable_ops.ResourceVariable(1.0) 216 with ops.device('/cpu:1'): 217 v0_1 = resource_variable_ops.ResourceVariable(2.0) 218 v1_0 = resource_variable_ops.ResourceVariable(3.0) 219 with ops.device('/cpu:2'): 220 v1_1 = resource_variable_ops.ResourceVariable(4.0) 221 222 packed_var_0 = ops.pack_eager_tensors([v0_0.handle, v0_1.handle]) 223 packed_var_1 = ops.pack_eager_tensors([v1_0.handle, v1_1.handle]) 224 225 # TODO(b/145922293): use ResourceVariable.assign_add and 226 # ResourceVariable.read_value directly once we support packing multiple 227 # ResourceVariable into one ResourceVariable. 228 @def_function.function 229 def read_var(): 230 resource_variable_ops.assign_add_variable_op( 231 packed_var_0, constant_op.constant(5.0)) 232 resource_variable_ops.assign_add_variable_op( 233 packed_var_1, constant_op.constant(6.0)) 234 with ops.device('/cpu:0'): 235 read0 = resource_variable_ops.read_variable_op( 236 packed_var_0, dtype=dtypes.float32) 237 with ops.device('/cpu:1'): 238 read1 = resource_variable_ops.read_variable_op( 239 packed_var_0, dtype=dtypes.float32) 240 read2 = resource_variable_ops.read_variable_op( 241 packed_var_1, dtype=dtypes.float32) 242 with ops.device('/cpu:2'): 243 read3 = resource_variable_ops.read_variable_op( 244 packed_var_1, dtype=dtypes.float32) 245 246 return read0, read1, read2, read3 247 248 arg_attrs = read_var.get_concrete_function().function_def.arg_attr 249 self.assertLen(arg_attrs, 2) 250 self.assertEqual(arg_attrs[0].attr['_composite_device'].s, 251 compat.as_bytes(packed_var_0.device)) 252 self.assertEqual(arg_attrs[1].attr['_composite_device'].s, 253 compat.as_bytes(packed_var_1.device)) 254 255 self.assertAllEqual(read_var(), (1 + 5, 2 + 5, 3 + 6, 4 + 6)) 256 257 def testImplementsAttributeBasic(self): 258 v = def_function.function( 259 experimental_implements='func')(lambda x, y: x + y) 260 with context.graph_mode(), self.cached_session(): 261 a = array_ops.placeholder(dtypes.float32, ()) 262 b = array_ops.placeholder(dtypes.float32, ()) 263 v(a, b) 264 gradients_impl.gradients(v(a, b), [a, b]) 265 fdefs = ops.get_default_graph().as_graph_def().library.function 266 self.assertLen(fdefs, 3) 267 not_present = 0 268 present = 0 269 for f in fdefs: 270 name = f.signature.name 271 if 'forward' in name or 'backward' in name: 272 not_present += 1 273 self.assertNotIn(function.IMPLEMENTS_ATTRIBUTE_NAME, f.attr, f) 274 else: 275 present += 1 276 self.assertEqual(f.attr[function.IMPLEMENTS_ATTRIBUTE_NAME].s, 277 'func'.encode('ascii'), f) 278 self.assertEqual(not_present, 2, fdefs) 279 self.assertEqual(present, 1, fdefs) 280 281 def testImplementsAttributeAssertsOnSideInput(self): 282 with context.graph_mode(), self.cached_session(): 283 z = array_ops.zeros(0) 284 v = def_function.function( 285 experimental_implements='func')(lambda x, y: x + y + z) 286 a = array_ops.ones((1.0,)) 287 b = array_ops.ones((1.0,)) 288 with self.assertRaisesRegex(AssertionError, 289 'variables are always captured'): 290 v(a, b) 291 functions = ops.get_default_graph().as_graph_def().library.function 292 self.assertEmpty(functions) 293 294 def testImplementsAttributeWorksWithGradientTape(self): 295 add = lambda x, y: x + y ** 2 296 add = def_function.function(experimental_implements='MyFunc')(add) 297 x = variables.Variable(3.0) 298 y = variables.Variable(2.0) 299 300 with backprop.GradientTape() as tape: 301 g = add(x, y) 302 303 dg_dy, dg_dx = tape.gradient(g, [y, x]) 304 self.assertEqual(dg_dy.numpy(), 4.0) 305 self.assertEqual(dg_dx.numpy(), 1.0) 306 307 def testImplementsAttributeWorksOnVariables(self): 308 with context.graph_mode(), self.cached_session(): 309 v = def_function.function( 310 experimental_implements='func')(lambda x, y: x + y) 311 a = variables.Variable((1.0,)) 312 b = variables.Variable((1.0,)) 313 r1 = v(a, b) 314 _ = v(a, a) 315 functions = ops.get_default_graph().as_graph_def().library.function 316 # Verify that we created only one function 317 self.assertLen(functions, 1) 318 # Verify that eval() reads the current values. 319 a.initializer.run() 320 b.initializer.run() 321 self.assertEqual(r1.eval(), 2) 322 323 a.assign_add([1]).eval() 324 self.assertEqual(r1.eval(), 3) 325 326 def testImplementsAttributeWorksOnConstants(self): 327 with context.graph_mode(), self.cached_session(): 328 v = def_function.function( 329 experimental_implements='func')(lambda x, y: x + y) 330 a = variables.Variable(1.0) 331 r1 = v(a, 2.) 332 r2 = v(2., a) 333 functions = ops.get_default_graph().as_graph_def().library.function 334 self.assertLen(functions, 1) 335 self.assertLen(functions[0].signature.input_arg, 2) 336 # Verify that eval() reads the current values. 337 a.initializer.run() 338 self.assertEqual(r1.eval(), 3) 339 self.assertEqual(r2.eval(), 3) 340 341 def testImplementsAttributeSpecializes(self): 342 with context.graph_mode(), self.cached_session(): 343 v = def_function.function( 344 experimental_implements='func')(lambda x, y: x + y) 345 a = variables.Variable(1.0) 346 r1 = v(a, [2.]) 347 r2 = v([2., 2], a) 348 functions = ops.get_default_graph().as_graph_def().library.function 349 self.assertLen(functions, 2) 350 # Ensure that all parameters are still there and haven't been inlined! 351 352 self.assertLen(functions[0].signature.input_arg, 2) 353 self.assertLen(functions[1].signature.input_arg, 2) 354 # Verify that eval() reads the current values. 355 a.initializer.run() 356 numpy.testing.assert_equal(r1.eval(), [3.]) 357 numpy.testing.assert_equal(r2.eval(), [3., 3.]) 358 359 def testImplementsWorksWithTensorSpec(self): 360 v = def_function.function( 361 experimental_implements='func')(lambda x, y: x + y) 362 v = v.get_concrete_function( 363 tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32), 364 tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32)) 365 x = v(1., 2.) 366 self.assertEqual(x.numpy(), 3.) 367 368 def testImplementsAttributeAsNameAttrList(self): 369 implements_attr = ( 370 'name: "embedding_matmul" attr { key: "key1" value { i: 2 } ' 371 '} attr { key: "key2" value { b: false } }') 372 v = def_function.function( 373 experimental_implements=implements_attr)(lambda x, y: x + y) 374 with context.graph_mode(), self.cached_session(): 375 a = array_ops.placeholder(dtypes.float32, ()) 376 b = array_ops.placeholder(dtypes.float32, ()) 377 v(a, b) 378 gradients_impl.gradients(v(a, b), [a, b]) 379 fdefs = ops.get_default_graph().as_graph_def().library.function 380 self.assertLen(fdefs, 3) 381 not_present = 0 382 present = 0 383 for f in fdefs: 384 name = f.signature.name 385 if 'forward' in name or 'backward' in name: 386 not_present += 1 387 self.assertNotIn(function.IMPLEMENTS_ATTRIBUTE_NAME, f.attr, f) 388 else: 389 present += 1 390 attr_value = f.attr[function.IMPLEMENTS_ATTRIBUTE_NAME] 391 self.assertIsNotNone(attr_value.func, f) 392 self.assertEqual(attr_value.func.name, 'embedding_matmul') 393 name_attrs = attr_value.func.attr 394 self.assertLen(name_attrs, 2) 395 self.assertEqual(not_present, 2, fdefs) 396 self.assertEqual(present, 1, fdefs) 397 398 def testExternalControlDependency(self): 399 with ops.Graph().as_default(), self.test_session(): 400 v = variables.Variable(1.0) 401 v.initializer.run() 402 403 op = v.assign_add(1.0) 404 405 @function.defun 406 def f(): 407 with ops.control_dependencies([op]): 408 return 1.0 409 410 self.evaluate(f()) 411 self.assertAllEqual(self.evaluate(v), 2.0) 412 413 def testInputShapeFunctionRelaxation(self): 414 unknown_dim = [False] 415 416 @function.defun(experimental_relax_shapes=True) 417 def func(a): 418 if a._shape_tuple()[0] is None: 419 unknown_dim[0] = True 420 return a + 1 421 422 func(constant_op.constant([])) 423 self.assertFalse(unknown_dim[0]) 424 self.assertLen(total_function_cache(func), 1) 425 426 func(constant_op.constant([1.0])) 427 self.assertFalse(unknown_dim[0]) 428 self.assertLen(total_function_cache(func), 2) 429 430 func(constant_op.constant([1.0, 2.0])) 431 self.assertTrue(unknown_dim[0]) 432 self.assertLen(total_function_cache(func), 2) 433 434 def testInputShapeRelaxationOnInstanceMethod(self): 435 # Test that experimental_relax_shapes is passed during 436 # instance method bounding. 437 unknown_dim = [False] 438 439 class Foo(object): 440 441 @def_function.function(experimental_relax_shapes=True) 442 def func(self, a): 443 if a._shape_tuple()[0] is None: 444 unknown_dim[0] = True 445 return a + 1 446 447 foo = Foo() 448 foo.func(constant_op.constant([])) 449 self.assertFalse(unknown_dim[0]) 450 451 foo.func(constant_op.constant([1.0])) 452 self.assertFalse(unknown_dim[0]) 453 454 foo.func(constant_op.constant([1.0, 2.0])) 455 self.assertTrue(unknown_dim[0]) 456 457 def testInputShapeFunctionRelaxationWithRaggedTensors(self): 458 traced_type_spec = [None] 459 460 @def_function.function(experimental_relax_shapes=True) 461 def func(x): 462 traced_type_spec[0] = x._type_spec 463 return x 464 465 def check_trace(x, expected_trace): 466 traced_type_spec[0] = None 467 func(x) 468 self.assertEqual(traced_type_spec[0], expected_trace) 469 470 check_trace( # Initial call gets traced. 471 ragged_factory_ops.constant([[1], [2, 3, 4]]), 472 ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32)) 473 check_trace( # Input TypeSpec is the same -> no retrace. 474 ragged_factory_ops.constant([[1, 2], [3, 4]]), None) 475 check_trace( # Even if component tensor shapes change -> no retrace. 476 ragged_factory_ops.constant([[1, 2], [3, 4, 5, 6]]), None) 477 check_trace( # Different TypeSpec shape (nrows): retrace 478 ragged_factory_ops.constant([[1], [2], [3]]), 479 ragged_tensor.RaggedTensorSpec([3, None], dtypes.int32)) 480 check_trace( # Different nrows again: relax & retrace 481 ragged_factory_ops.constant([[1], [2], [3], [4]]), 482 ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)) 483 check_trace( # Different nrows yet again: not retrace 484 ragged_factory_ops.constant([[1]]), None) 485 check_trace( # Different ragged_rank: retrace 486 ragged_factory_ops.constant([[[1]]]), 487 ragged_tensor.RaggedTensorSpec([1, None, None], dtypes.int32)) 488 check_trace( # Different ragged_rank again: retrace & relax 489 ragged_factory_ops.constant([[[1]], [[2]]]), 490 ragged_tensor.RaggedTensorSpec([None, None, None], dtypes.int32)) 491 492 def testInputShapeFunctionRelaxationWithStructuredTensors(self): 493 traced_type_spec = [None] 494 495 @def_function.function(experimental_relax_shapes=True) 496 def func(x): 497 traced_type_spec[0] = x._type_spec 498 return x 499 500 def check_trace(x, expected_trace): 501 traced_type_spec[0] = None 502 func(x) 503 self.assertEqual(traced_type_spec[0], expected_trace) 504 505 # If we have TypeSpecs that differ in ways other than just their shape, 506 # then retrace each time. 507 check_trace( 508 structured_tensor.StructuredTensor.from_pyval({'a': [1]}), 509 structured_tensor.StructuredTensorSpec( 510 [], {'a': tensor_spec.TensorSpec((1,), dtypes.int32)})) 511 check_trace( 512 structured_tensor.StructuredTensor.from_pyval({'b': [1]}), 513 structured_tensor.StructuredTensorSpec( 514 [], {'b': tensor_spec.TensorSpec((1,), dtypes.int32)})) 515 check_trace( 516 structured_tensor.StructuredTensor.from_pyval({'c': [1]}), 517 structured_tensor.StructuredTensorSpec( 518 [], {'c': tensor_spec.TensorSpec((1,), dtypes.int32)})) 519 520 # But if we call again with only shape different, then do relax: 521 check_trace( # retrace 522 structured_tensor.StructuredTensor.from_pyval({'a': [1, 2]}), 523 structured_tensor.StructuredTensorSpec( 524 [], {'a': tensor_spec.TensorSpec((2,), dtypes.int32)})) 525 check_trace( # relax & retrace 526 structured_tensor.StructuredTensor.from_pyval({'a': [1, 2, 3]}), 527 structured_tensor.StructuredTensorSpec( 528 [], {'a': tensor_spec.TensorSpec((None,), dtypes.int32)})) 529 check_trace( # use relaxed graph 530 structured_tensor.StructuredTensor.from_pyval({'a': [1, 2, 3, 4]}), 531 None) 532 533 def testInputShapeFunctionRelaxationWithDatasetIterators(self): 534 # For dataset iterators, the TypeSpec includes type information that's 535 # not derivable from the component tensors. Make sure that the TypeSpec 536 # shapes get relaxed as appropriate. 537 538 traced_type_spec = [None] 539 540 @def_function.function(experimental_relax_shapes=True) 541 def func(x): 542 traced_type_spec[0] = x._type_spec 543 return x 544 545 def check_trace(x, expected_trace): 546 traced_type_spec[0] = None 547 func(x) 548 self.assertEqual(traced_type_spec[0], expected_trace) 549 550 ds_1_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([1, 2])) 551 ds_2_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([2, 2])) 552 ds_3_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([3, 2])) 553 ds_4_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([4, 2])) 554 ds_2_1 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([2, 1])) 555 check_trace( # shape=[1, 2]: retrace 556 dataset_ops.make_one_shot_iterator(ds_1_2), 557 iterator_ops.IteratorSpec( 558 tensor_spec.TensorSpec([1, 2], dtypes.float32))) 559 check_trace( # shape=[1, 2]: no retrace (use the [1, 2] graph) 560 dataset_ops.make_one_shot_iterator(ds_1_2), None) 561 check_trace( # shape=[2, 2]: retrace 562 dataset_ops.make_one_shot_iterator(ds_2_2), 563 iterator_ops.IteratorSpec( 564 tensor_spec.TensorSpec([2, 2], dtypes.float32))) 565 check_trace( # shape=[3, 2]: relax to [None, 2] and retrace 566 dataset_ops.make_one_shot_iterator(ds_3_2), 567 iterator_ops.IteratorSpec( 568 tensor_spec.TensorSpec([None, 2], dtypes.float32))) 569 check_trace( # shape=[4, 2]: no retrace (use the [None, 2] graph) 570 dataset_ops.make_one_shot_iterator(ds_4_2), None) 571 check_trace( # shape=[2, 1]: relax to [None, None] and retrace 572 dataset_ops.make_one_shot_iterator(ds_2_1), 573 iterator_ops.IteratorSpec( 574 tensor_spec.TensorSpec([None, None], dtypes.float32))) 575 576 def testCapturesVariables(self): 577 a = variables.Variable(1.0, trainable=False) 578 b = variables.Variable(1.0) 579 cc = [None] 580 581 @def_function.function 582 def f(): 583 c = cc[0] 584 if c is None: 585 c = cc[0] = variables.Variable(1.) 586 return a + b + c + 1 587 588 cf = f.get_concrete_function() 589 c = cc[0] 590 591 captured_variables = {v.ref() for v in (a, b, c)} 592 trainable_variables = {v.ref() for v in (b, c)} 593 self.assertEqual({v.ref() for v in cf.variables}, captured_variables) 594 self.assertEqual({v.ref() for v in cf.trainable_variables}, 595 trainable_variables) 596 self.assertEqual(cf.variables, cf.graph.variables) 597 self.assertEqual(cf.trainable_variables, cf.graph.trainable_variables) 598 599 def testNestedInputShapeFunctionRelaxation(self): 600 unknown_dim = [False] 601 602 @function.defun(experimental_relax_shapes=True) 603 def func(a_, b_=None): 604 del a_ # Only used to check which cache is used. 605 self.assertEqual(b_[0]._shape_tuple(), ()) 606 if b_[1]._shape_tuple()[0] is None: 607 unknown_dim[0] = True 608 return b_[0] + 1 609 610 a = 'hi' 611 b0 = constant_op.constant(1.0) 612 func(a, b_=[b0, constant_op.constant([])]) 613 self.assertFalse(unknown_dim[0]) 614 self.assertLen(total_function_cache(func), 1) 615 616 func(a, b_=[b0, constant_op.constant([1.0])]) 617 self.assertFalse(unknown_dim[0]) 618 self.assertLen(total_function_cache(func), 2) 619 620 func(a, b_=[b0, constant_op.constant([1.0, 1.0])]) 621 self.assertTrue(unknown_dim[0]) 622 self.assertLen(total_function_cache(func), 2) 623 624 unknown_dim[0] = False 625 626 # Now do the same except with a new a which is not a tensor; this should 627 # change the cache key. 628 a = 'bye' 629 func(a, b_=[b0, constant_op.constant([])]) 630 self.assertFalse(unknown_dim[0]) 631 self.assertLen(total_function_cache(func), 3) 632 633 # Since we already marked a cache miss for a function with the same 634 # non-input signatures, here we will immediately start relaxing shapes. 635 func(a, b_=[b0, constant_op.constant([1.0])]) 636 self.assertTrue(unknown_dim[0]) 637 self.assertLen(total_function_cache(func), 3) 638 639 def testNestedShapeFunctionRelaxation(self): 640 641 got_shape = [None] 642 643 # The inner function will go through shape relaxation because the shapes it 644 # receives will be [1], [2], [3], ... 645 @def_function.function(experimental_relax_shapes=True) 646 def bar(x_shape): 647 got_shape[0] = x_shape._shape_tuple() 648 return x_shape 649 650 # The outer function will not go through shape relaxation because the shapes 651 # it receives will be [1], [[1]], [[[1]]], ... 652 @def_function.function(experimental_relax_shapes=True) 653 def foo(ones): 654 return bar(array_ops.shape(ones)) 655 656 for rank in range(1, 6): 657 x_shape = self.evaluate(foo(array_ops.ones([1] * rank))) 658 self.assertAllEqual(x_shape, [1] * rank) 659 if rank < 3: 660 self.assertEqual(got_shape[0], (rank,)) 661 else: 662 self.assertEqual(got_shape[0], (None,)) 663 664 def testNoHash(self): 665 666 @def_function.function() 667 def f(_): 668 return 1.0 669 670 with self.assertRaisesRegex(ValueError, r'got.*set'): 671 f(set([])) 672 673 def testFuncName(self): 674 675 @function.defun_with_attributes(attributes={'func_name': 'multiply'}) 676 def add(x, y): 677 _ = x * y 678 return x + y 679 680 @function.defun 681 def add_2(x, y): 682 _ = x * y 683 return x + y 684 685 self.assertEqual(add._name, 'multiply') 686 self.assertEqual(add_2._name, 'add_2') 687 688 def testBasicGraphMode(self): 689 matmul = def_function.function(math_ops.matmul) 690 691 @def_function.function 692 def sq(a): 693 return matmul(a, a) 694 695 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 696 out = sq(t) 697 self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) 698 699 def testNestedInputsGraphMode(self): 700 matmul = def_function.function(math_ops.matmul) 701 702 pair = collections.namedtuple('pair', ['a', 'b']) 703 704 @def_function.function 705 def a_times_b(inputs): 706 return matmul(inputs.a['a'], inputs.b['b']) 707 708 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 709 710 out = a_times_b(pair({'a': t}, {'b': t})) 711 self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) 712 713 def testNestedOutputsGraphMode(self): 714 matmul = def_function.function(math_ops.matmul) 715 716 pair = collections.namedtuple('pair', ['a', 'b']) 717 718 @def_function.function() 719 def pairs_mul(pair_a, pair_b): 720 return pair(matmul(pair_a.a, pair_b.a), matmul(pair_a.b, pair_b.b)) 721 722 a = constant_op.constant([[1.0, 2.0], [1.0, 2.0]]) 723 b = constant_op.constant([[3.0, 4.0], [3.0, 4.0]]) 724 725 out = pairs_mul(pair(a, b), pair(b, a)) 726 expected = pair(math_ops.matmul(a, b).numpy(), 727 math_ops.matmul(b, a).numpy()) 728 self.assertAllClose(out, expected) 729 730 @parameterized.named_parameters( 731 dict(testcase_name='Defun', 732 function_decorator=function.defun), 733 dict(testcase_name='DefFunction', 734 function_decorator=def_function.function)) 735 def testNestedFunctionGraphNotOutOfDate(self, function_decorator): 736 @function_decorator 737 def f(): 738 return constant_op.constant(1.) 739 740 class _Model(object): 741 742 @function_decorator 743 def g(self): 744 self.f = f.get_concrete_function() 745 746 model = _Model() 747 model.g() 748 concrete = model.f 749 weak_g_graph = weakref.ref(model.g.get_concrete_function().graph) 750 self.assertIs(weak_g_graph(), concrete.graph.outer_graph) 751 weak_g = weakref.ref(model.g) 752 del model 753 self.assertIsNone(weak_g()) 754 self.assertIsNone(weak_g_graph()) 755 self.assertIsNotNone(concrete.graph.outer_graph) 756 self.assertIs(ops.get_default_graph(), concrete.graph.outer_graph) 757 758 def testGraphEagerIsolation(self): 759 760 @function.defun 761 def f(): 762 self.v = variables.Variable(1.0) 763 return self.v.read_value() 764 765 self.assertAllEqual(f(), 1.0) 766 767 with ops.Graph().as_default(): 768 self.assertEqual(f().shape, ()) 769 770 def testBasicGraphFunction(self): 771 matmul = def_function.function(math_ops.matmul) 772 773 @def_function.function 774 def sq(a): 775 return matmul(a, a) 776 777 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 778 779 sq_op = sq.get_concrete_function(t) 780 self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2])) 781 out = sq_op(t) 782 self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) 783 784 def testGetConcreteFunctionThreadSafety(self): 785 786 @def_function.function 787 def sq(): 788 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 789 return math_ops.matmul(t, t) 790 791 concrete_functions = [] 792 793 def thread_func(_): 794 cf = sq.get_concrete_function() 795 concrete_functions.append(cf) 796 797 num_threads = 100 798 pool = multiprocessing.pool.ThreadPool(num_threads) 799 _ = pool.map(thread_func, list(range(num_threads))) 800 801 self.assertLen(set(concrete_functions), 1) 802 803 def testGetConcreteFunctionThreadSafetyWithArgs(self): 804 @def_function.function 805 def add_100(*args): 806 return math_ops.add_n(args) 807 808 p = multiprocessing.pool.ThreadPool(2) 809 args = (constant_op.constant(1.),) * 100 810 f1, f2 = p.map(add_100.get_concrete_function, [args] * 2) 811 # I see about len(args) + max(0, len(args) - 3) arguments expected. 812 f1(*args) 813 del f2 814 815 def testInputSpecGraphFunction(self): 816 matmul = def_function.function(math_ops.matmul) 817 818 @def_function.function 819 def sq(a): 820 return matmul(a, a) 821 822 sq_op = sq.get_concrete_function( 823 tensor_spec.TensorSpec((None, None), dtypes.float32)) 824 self.assertEqual([None, None], sq_op.output_shapes.as_list()) 825 826 t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 827 out1 = sq_op(t1) 828 self.assertAllEqual(out1, math_ops.matmul(t1, t1).numpy()) 829 830 t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 831 out2 = sq_op(t2) 832 self.assertAllEqual(out2, math_ops.matmul(t2, t2).numpy()) 833 834 def testNestedInputSpecGraphFunction(self): 835 matmul = def_function.function(math_ops.matmul) 836 837 @def_function.function 838 def sq(mats): 839 ((a, b),) = mats 840 return matmul(a, b) 841 842 sq_op_autonamed = sq.get_concrete_function( 843 [(tensor_spec.TensorSpec((None, None), dtypes.float32), 844 tensor_spec.TensorSpec((None, None), dtypes.float32))]) 845 self.assertEqual([None, None], sq_op_autonamed.output_shapes.as_list()) 846 847 sq_op = sq.get_concrete_function( 848 [(tensor_spec.TensorSpec((None, None), dtypes.float32, 849 name='first_mat'), 850 tensor_spec.TensorSpec((None, None), dtypes.float32, 851 name='second_mat'))]) 852 self.assertEqual([None, None], sq_op.output_shapes.as_list()) 853 854 t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 855 t2 = constant_op.constant([[1.4, 2.4], [3.4, 4.4]]) 856 out = sq_op(first_mat=t1, second_mat=t2) 857 self.assertAllEqual(out, math_ops.matmul(t1, t2).numpy()) 858 self.assertAllEqual(sq_op_autonamed(t1, t2), 859 math_ops.matmul(t1, t2).numpy()) 860 861 def testExecutingStatelessDefunConcurrently(self): 862 863 @def_function.function 864 def stateless(x): 865 return math_ops.multiply(2.0, x) 866 867 pool = multiprocessing.pool.ThreadPool() 868 inputs = [constant_op.constant(1.0 * x) for x in range(100)] 869 outputs = [float(out) for out in pool.map(stateless, inputs)] 870 expected = [float(2.0 * x) for x in inputs] 871 self.assertSequenceEqual(outputs, expected) 872 873 def testExecutingManyStatelessDefunsConcurrently(self): 874 875 @def_function.function 876 def stateless(x): 877 del x 878 return math_ops.multiply(2.0, 2.0) 879 880 pool = multiprocessing.pool.ThreadPool() 881 # `pool.map` below instantiates 100 functions, one for each object. 882 objects = [object() for _ in range(100)] 883 outputs = [float(out) for out in pool.map(stateless, objects)] 884 expected = [4.0] * 100 885 self.assertSequenceEqual(outputs, expected) 886 887 @test_util.disable_tfrt('b/169431085: This test is flaky on tfrt') 888 def testExecutingStatefulDefunConcurrently(self): 889 890 v = resource_variable_ops.ResourceVariable(1.0) 891 892 @def_function.function 893 def stateful(x): 894 v.assign(x) 895 896 pool = multiprocessing.pool.ThreadPool() 897 inputs = [constant_op.constant(0.0)] * 100 898 pool.map(stateful, inputs) 899 self.assertEqual(float(v.read_value()), 0.0) 900 901 def testExecutingManyStatefulDefunsConcurrently(self): 902 903 v = resource_variable_ops.ResourceVariable(1.0) 904 905 @def_function.function 906 def stateful(x): 907 del x 908 return v.assign(0.0) 909 910 pool = multiprocessing.pool.ThreadPool() 911 # `pool.map` below instantiates 100 functions, one for each object. 912 pool.map(stateful, [object() for _ in range(100)]) 913 self.assertEqual(float(v.read_value()), 0.0) 914 915 def testShareRendezvous(self): 916 917 # Disable grappler from inlining the functions. Note we run the send & recv 918 # in graph mode since with eager mode the function should automatically be 919 # inlined. 920 context.context().set_optimizer_experimental_options( 921 {'disable_meta_optimizer': True}) 922 923 cpu = '/device:CPU:0' 924 925 signature = [tensor_spec.TensorSpec([], dtypes.int32)] 926 927 @def_function.function 928 def send(): 929 x = constant_op.constant(1) 930 gen_sendrecv_ops.send(x, 'x', cpu, 0, cpu) 931 return x 932 933 send._shared_rendezvous = True # pylint: disable=protected-access 934 935 @def_function.function(input_signature=signature) 936 def send_body(n): 937 send() 938 return n - 1 939 940 @def_function.function 941 def recv(): 942 return gen_sendrecv_ops.recv(dtypes.int32, 'x', cpu, 0, cpu) 943 944 recv._shared_rendezvous = True # pylint: disable=protected-access 945 946 @def_function.function(input_signature=signature) 947 def recv_body(n): 948 recv() 949 return n - 1 950 951 @def_function.function(input_signature=signature) 952 def cond(n): 953 return n > 0 954 955 # Instead of calling the send & recv functions directly we want to call them 956 # through a functional while to ensure the rendezvous is shared across the 957 # while boundary. 958 @def_function.function 959 def fn(n): 960 functional_ops.While([n], cond.get_concrete_function(), 961 send_body.get_concrete_function()) 962 return functional_ops.While([n], cond.get_concrete_function(), 963 recv_body.get_concrete_function()) 964 965 # Use a graph context since functions will not be automatically inlined 966 with context.graph_mode(), self.cached_session(): 967 self.evaluate(fn(2)) 968 969 def disabled_testRandomSeed(self): 970 971 @def_function.function 972 def f(): 973 return random_ops.random_normal(()) 974 975 random_seed.set_random_seed(1) 976 x = f() 977 self.assertNotEqual(x, f()) 978 random_seed.set_random_seed(1) 979 self.assertAllEqual(f(), x) 980 981 def testNestedInputsGraphFunction(self): 982 matmul = def_function.function(math_ops.matmul) 983 984 pair = collections.namedtuple('pair', ['a', 'b']) 985 986 @def_function.function 987 def a_times_b(inputs): 988 return matmul(inputs.a['a'], inputs.b['b']) 989 990 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 991 sq_op = a_times_b.get_concrete_function( 992 pair(dict(a=tensor_spec.TensorSpec([2, 2], dtypes.float32, 'a')), 993 dict(b=tensor_spec.TensorSpec([2, 2], dtypes.float32, 'b')))) 994 self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2])) 995 out = sq_op(a=t, b=t) 996 self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) 997 998 def testNestedOutputGraphFunction(self): 999 matmul = def_function.function(math_ops.matmul) 1000 1001 @def_function.function 1002 def sq(a): 1003 return (matmul(a, a), {'b': constant_op.constant(1.0)}) 1004 1005 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 1006 1007 sq_op = sq.get_concrete_function(t) 1008 self.assertEqual(sq_op.output_shapes, 1009 (tensor_shape.TensorShape([2, 2]), 1010 {'b': tensor_shape.TensorShape([])})) 1011 self.assertEqual(sq_op.output_dtypes, 1012 (dtypes.float32, {'b': dtypes.float32})) 1013 (a, b) = sq_op(t) 1014 self.assertAllEqual(a, math_ops.matmul(t, t).numpy()) 1015 self.assertAllEqual(b['b'].numpy(), 1.0) 1016 1017 def testGraphFunctionNoneOutput(self): 1018 @def_function.function 1019 def fn(unused_a, unused_b): 1020 return None 1021 1022 x = constant_op.constant(1) 1023 fn_op = fn.get_concrete_function(x, x) 1024 self.assertEqual(fn_op.output_dtypes, None) 1025 self.assertEqual(fn_op.output_shapes, None) 1026 self.assertAllEqual(fn_op(x, x), None) 1027 1028 def testDefunNumpyArraysConvertedToTensors(self): 1029 1030 def f(x): 1031 self.assertIsInstance(x, ops.Tensor) 1032 return x 1033 1034 x = random_ops.random_uniform([2, 2]).numpy() 1035 defined = function.defun(f) 1036 defined(x) 1037 self.assertLen(total_function_cache(defined), 1) 1038 1039 x = random_ops.random_uniform([2, 2]).numpy() 1040 defined(x) 1041 # A NumPy array with different values but the same shape and dtype 1042 # shouldn't trigger another function definition. 1043 self.assertLen(total_function_cache(defined), 1) 1044 1045 np_ones = numpy.ones([], numpy.float32) 1046 np_zeros = numpy.zeros([], numpy.float32) 1047 tf_ones = array_ops.ones([]) 1048 tf_zeros = array_ops.zeros([]) 1049 1050 # Test that the numpy array is properly an argument to the graph function. 1051 self.assertEqual(1., defined(np_ones).numpy()) 1052 self.assertLen(total_function_cache(defined), 2) 1053 self.assertEqual(0., defined(np_zeros).numpy()) 1054 self.assertEqual(1., defined(tf_ones).numpy()) 1055 self.assertEqual(0., defined(tf_zeros).numpy()) 1056 self.assertLen(total_function_cache(defined), 2) 1057 1058 # Test that mutable inputs are supported. 1059 mutable = numpy.ones([], numpy.float32) 1060 self.assertEqual(1., defined(mutable).numpy()) 1061 mutable.fill(0) 1062 self.assertEqual(0., defined(mutable).numpy()) 1063 1064 class MyNdarray(numpy.ndarray): 1065 pass 1066 1067 # Test that the subclasses of ndarray are converted too. 1068 self.assertEqual(1., defined(np_ones.view(MyNdarray)).numpy()) 1069 self.assertEqual(0., defined(np_zeros.view(MyNdarray)).numpy()) 1070 1071 # We should not have triggered any re-tracing of the python function. 1072 self.assertLen(total_function_cache(defined), 2) 1073 1074 def testNumpyDtypeInputSupported(self): 1075 @function.defun 1076 def f(x, dtype): 1077 return constant_op.constant(dtype(x)) 1078 1079 self.assertEqual(f(1, numpy.float32).numpy(), numpy.float32(1)) 1080 self.assertEqual(f(2, numpy.float32).numpy(), numpy.float32(2)) 1081 self.assertEqual(f(1, numpy.int32).numpy(), numpy.int32(1)) 1082 self.assertEqual(f(2, numpy.int32).numpy(), numpy.int32(2)) 1083 1084 def testDefunNumpyArraysConvertedToTensorsInKwargs(self): 1085 1086 def f(**kwargs): 1087 x = kwargs.pop('x') 1088 self.assertIsInstance(x, ops.Tensor) 1089 return x 1090 1091 x = random_ops.random_uniform([2, 2]).numpy() 1092 defined = function.defun(f) 1093 defined(x=x) 1094 self.assertLen(total_function_cache(defined), 1) 1095 1096 x = random_ops.random_uniform([2, 2]).numpy() 1097 defined(x=x) 1098 # A NumPy array with different values but the same shape and dtype 1099 # shouldn't trigger another function definition. 1100 self.assertLen(total_function_cache(defined), 1) 1101 1102 # Test that the numpy array is properly an argument to the graph function. 1103 self.assertEqual(1., defined(x=numpy.ones([])).numpy()) 1104 self.assertEqual(0., defined(x=numpy.zeros([])).numpy()) 1105 self.assertEqual(1., defined(x=array_ops.ones([])).numpy()) 1106 self.assertEqual(0., defined(x=array_ops.zeros([])).numpy()) 1107 1108 def testDefunCapturedInt32(self): 1109 x = constant_op.constant(1, dtype=dtypes.int32) 1110 1111 @def_function.function 1112 def add_int32s(): 1113 return x + x 1114 1115 self.assertEqual(2, int(add_int32s())) 1116 1117 def testDefunReadVariable(self): 1118 v = resource_variable_ops.ResourceVariable(1.0) 1119 1120 @def_function.function 1121 def f(): 1122 return v.read_value() 1123 1124 self.assertEqual(1.0, float(f())) 1125 1126 def testDefunAssignAddVariable(self): 1127 v = resource_variable_ops.ResourceVariable(1.0) 1128 x = constant_op.constant(2.0) 1129 1130 @def_function.function 1131 def test_assign_add(): 1132 v.assign_add(x) 1133 return v.read_value() 1134 1135 self.assertEqual(3.0, float(test_assign_add())) 1136 1137 @test_util.run_in_graph_and_eager_modes 1138 def testTensorInitializationInFunctionRaisesError(self): 1139 1140 @def_function.function 1141 def tensor_init(): 1142 with self.assertRaisesRegex(ValueError, 'could not be lifted out'): 1143 resource_variable_ops.ResourceVariable(constant_op.constant(2.0)) 1144 1145 tensor_init() 1146 1147 @test_util.run_in_graph_and_eager_modes 1148 def testCallableTensorInitializationInFunction(self): 1149 1150 @def_function.function 1151 def tensor_init(): 1152 self.v = resource_variable_ops.ResourceVariable( 1153 lambda: constant_op.constant(2.0)) 1154 return self.v.read_value() 1155 1156 value = tensor_init() 1157 if not context.executing_eagerly(): 1158 self.evaluate(variables.global_variables_initializer()) 1159 self.assertEqual(self.evaluate(value), 2.0) 1160 1161 @test_util.also_run_as_tf_function 1162 def testInitScopeTensorInitializationInFunction(self): 1163 1164 @def_function.function 1165 def tensor_init(): 1166 with ops.init_scope(): 1167 const = constant_op.constant(2.0) 1168 # Note: this variable bypasses tf.function's variable creation 1169 # requirements by bypassing variable_creator_scope by using 1170 # ResourceVariable instead of Variable. 1171 self.v = resource_variable_ops.ResourceVariable(const) 1172 return self.v.read_value() 1173 1174 value = tensor_init() 1175 self.assertAllEqual(value, 2.0) 1176 1177 @test_util.run_in_graph_and_eager_modes 1178 def testGetConcreteFunctionCreatesVariables(self): 1179 1180 v_holder = [] 1181 1182 @def_function.function 1183 def tensor_init(): 1184 if not v_holder: 1185 v_holder.append(variables.Variable(5.)) 1186 return v_holder[0].read_value() 1187 1188 concrete = tensor_init.get_concrete_function() 1189 self.evaluate(variables.global_variables_initializer()) 1190 self.assertAllEqual(5., self.evaluate(concrete())) 1191 self.assertAllEqual(5., self.evaluate(tensor_init())) 1192 1193 def testFuncGraphCaptureByValue(self): 1194 v = variables.Variable(1.0) 1195 1196 def trivial_function(): 1197 return v.read_value() 1198 1199 graph_function = function.Function( 1200 trivial_function, 'test', capture_by_value=True) 1201 1202 self.assertAllEqual(graph_function(), 1.0) 1203 v.assign(2.0) 1204 self.assertAllEqual(graph_function(), 1.0) 1205 1206 def testFuncGraphCaptureByValueNested(self): 1207 v = variables.Variable(1.0) 1208 1209 def trivial_function(): 1210 return control_flow_ops.cond( 1211 array_ops.placeholder_with_default(True, ()), 1212 v.read_value, v.read_value) 1213 1214 graph_function = function.Function( 1215 trivial_function, 'test', capture_by_value=True) 1216 1217 self.assertAllEqual(graph_function(), 1.0) 1218 v.assign(2.0) 1219 self.assertAllEqual(graph_function(), 1.0) 1220 1221 def testDefunShapeInferenceWithCapturedResourceVariable(self): 1222 v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]]) 1223 1224 def f(): 1225 x = constant_op.constant([[1, 2], [3, 4]]) 1226 out = math_ops.matmul(v, x) 1227 self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2])) 1228 # We do not return v directly since the tensor conversion function of 1229 # ResourceVariable returns the read value and not the resource itself. 1230 return v._handle 1231 1232 compiled = def_function.function(f) 1233 var_handle = compiled() 1234 self.assertEqual(var_handle.dtype, dtypes.resource) 1235 self.assertEqual(var_handle.shape, tensor_shape.TensorShape([])) 1236 var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype) 1237 self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2])) 1238 1239 def testShapeInferenceForMoreSpecificInput(self): 1240 1241 def f(a): 1242 return array_ops.reshape(a, [-1, 3]) 1243 1244 signature = [tensor_spec.TensorSpec(None, dtypes.float32)] 1245 compiled = def_function.function(f, input_signature=signature) 1246 1247 @def_function.function 1248 def use_f(): 1249 inputs = array_ops.zeros([10, 10, 3]) 1250 self.assertAllEqual(f(inputs).shape, compiled(inputs).shape) 1251 1252 use_f() 1253 1254 def testFuncListAttr(self): 1255 1256 @function.defun 1257 def test_function(val): 1258 1259 def fn1(): 1260 return array_ops.ones([10]) 1261 1262 fn2 = lambda: array_ops.ones([10]) * 2 1263 1264 def fn3(x=3): 1265 return array_ops.ones([10]) * x 1266 fn4 = functools.partial(fn3, x=4) 1267 fn5 = functools.partial(fn3, 5) 1268 1269 return gen_functional_ops.case(val, [], [dtypes.float32], 1270 [function.defun(f).get_concrete_function() 1271 for f in (fn1, fn2, fn3, fn4, fn5)]) 1272 1273 ones = array_ops.ones([10]) 1274 self.assertAllEqual([ones], test_function(0)) 1275 self.assertAllEqual([ones * 2], test_function(1)) 1276 self.assertAllEqual([ones * 3], test_function(2)) 1277 self.assertAllEqual([ones * 4], test_function(3)) 1278 self.assertAllEqual([ones * 5], test_function(4)) 1279 self.assertAllEqual([ones * 5], test_function(22)) # default branch 1280 1281 @test_util.enable_control_flow_v2 1282 def testVariableInLoopInFunction(self): 1283 1284 @function.defun 1285 def test_function(): 1286 1287 def loop_test(_): 1288 return False 1289 1290 def loop_body(_): 1291 return variable_scope.get_variable('a', shape=()) 1292 1293 return control_flow_ops.while_loop(loop_test, loop_body, [0.0]) 1294 1295 self.assertEqual(test_function().shape, []) 1296 1297 def testDefunShapeInferenceWithCapturedResourceVariableInGraphMode(self): 1298 with context.graph_mode(): 1299 v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]]) 1300 1301 def f(): 1302 x = constant_op.constant([[1, 2], [3, 4]]) 1303 out = math_ops.matmul(v, x) 1304 self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2])) 1305 # We do not return v directly since the tensor conversion function of 1306 # ResourceVariable returns the read value and not the resource itself. 1307 return v._handle 1308 1309 compiled = def_function.function(f) 1310 var_handle = compiled() 1311 self.assertEqual(var_handle.dtype, dtypes.resource) 1312 self.assertEqual(var_handle.shape, tensor_shape.TensorShape([])) 1313 var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype) 1314 self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2])) 1315 1316 def testDefunShapeInferenceWithCapturedVariableInGraphMode(self): 1317 with context.graph_mode(): 1318 v = variables.Variable([[1, 2], [3, 4]]) 1319 1320 def f(): 1321 x = constant_op.constant([[1, 2], [3, 4]]) 1322 out = math_ops.matmul(v, x) 1323 self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2])) 1324 1325 # Check that shape inference works while creating the defun 1326 compiled = def_function.function(f) 1327 compiled() 1328 1329 def testDefunShapeInferenceWithCapturedTensorListInGraphMode(self): 1330 with context.graph_mode(): 1331 tensor_list = list_ops.empty_tensor_list( 1332 element_dtype=dtypes.float32, 1333 element_shape=ops.convert_to_tensor([], dtype=dtypes.int32)) 1334 tensor_list = list_ops.tensor_list_push_back(tensor_list, 1335 constant_op.constant(1.0)) 1336 tensor_list = list_ops.tensor_list_push_back(tensor_list, 1337 constant_op.constant(2.0)) 1338 1339 def f(): 1340 tl, value = list_ops.tensor_list_pop_back( 1341 tensor_list, element_dtype=dtypes.float32) 1342 self.assertEqual(value.shape, tensor_shape.TensorShape([])) 1343 return tl 1344 1345 compiled = def_function.function(f) 1346 output_tensor_list = compiled() 1347 _, value = list_ops.tensor_list_pop_back( 1348 output_tensor_list, element_dtype=dtypes.float32) 1349 self.assertEqual(value.shape, tensor_shape.TensorShape([])) 1350 1351 @test_util.run_in_graph_and_eager_modes 1352 def testDefunForcesResourceVariables(self): 1353 1354 def variable_creator(): 1355 self.v = variables.Variable(0.0) 1356 return self.v.read_value() 1357 1358 self.v = None 1359 defined = function.defun(variable_creator) 1360 defined() # Create the variable. 1361 self.assertIsInstance( 1362 self.v, resource_variable_ops.ResourceVariable) 1363 1364 def testRunMetadata(self): 1365 1366 @def_function.function 1367 def f(x): 1368 return x * x 1369 1370 with ops.device('cpu:0'): 1371 context.enable_run_metadata() 1372 f(constant_op.constant(1.0)) 1373 run_metadata = context.export_run_metadata() 1374 context.disable_run_metadata() 1375 self.assertLen(run_metadata.partition_graphs, 1) 1376 1377 def testGraphModeCaptureVariable(self): 1378 with context.graph_mode(), self.cached_session(): 1379 1380 class HasAVar(object): 1381 1382 def __init__(self): 1383 self.v = resource_variable_ops.ResourceVariable(1.0) 1384 1385 def call(self): 1386 return self.v * 2 1387 1388 o = HasAVar() 1389 self.evaluate(variables.global_variables_initializer()) 1390 call = def_function.function(o.call) 1391 op = call() 1392 self.assertAllEqual(self.evaluate(op), 2.0) 1393 1394 def testGraphModeManyFunctions(self): 1395 with ops.Graph().as_default(), self.cached_session(): 1396 1397 @def_function.function 1398 def f(x): 1399 return x * x 1400 1401 @def_function.function 1402 def g(x): 1403 return f(x) + 1 1404 1405 self.assertAllEqual(g(constant_op.constant(2.0)), 5.0) 1406 1407 def testDict(self): 1408 1409 @def_function.function 1410 def f(x): 1411 return {'name': x + 1} 1412 1413 self.assertAllEqual(f(constant_op.constant(1.0))['name'], 2.0) 1414 1415 def testWeakrefInputsRejected(self): 1416 1417 @def_function.function 1418 def f(x): 1419 return x 1420 1421 class Dummy: 1422 pass 1423 o = Dummy() 1424 wr = weakref.ref(o) 1425 1426 with self.assertRaisesRegex(ValueError, 'weakref'): 1427 f(wr) 1428 1429 def testTensorConversionWithDefun(self): 1430 1431 @def_function.function 1432 def f(x): 1433 return math_ops.add(x, constant_op.constant(3)) 1434 1435 self.assertAllEqual(5, f(constant_op.constant(2))) 1436 1437 def testTensorConversionCall(self): 1438 1439 @def_function.function 1440 def f(x): 1441 return math_ops.add(x, constant_op.constant(3)) 1442 1443 @def_function.function 1444 def g(x): 1445 return f(f(x)) 1446 1447 self.assertAllEqual(8, g(constant_op.constant(2))) 1448 1449 def testCallShape(self): 1450 1451 @def_function.function 1452 def f(x): 1453 return x + 1 1454 1455 @def_function.function 1456 def g(x): 1457 x = f(x) 1458 self.assertEqual(x.shape.as_list(), []) 1459 return None 1460 1461 g(constant_op.constant(1.0)) 1462 1463 def testNestedDefunWithNoOutputAndTapedInput(self): 1464 three = resource_variable_ops.ResourceVariable(3.0, name='v') 1465 1466 @def_function.function 1467 def f(x): 1468 # This function intentionally takes a taped variable as input, 1469 # but does not return any values 1470 math_ops.add(x, three) 1471 1472 @def_function.function 1473 def g(x): 1474 y = math_ops.add(x, three) 1475 f(y) 1476 1477 g(three) 1478 1479 def testGatherResourceWithDefun(self): 1480 with ops.device('cpu:0'): 1481 v = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0]) 1482 1483 def sum_gather(): 1484 return math_ops.reduce_sum(array_ops.gather(v, [1, 2])) 1485 1486 defined = def_function.function(sum_gather) 1487 self.assertAllEqual(sum_gather(), defined()) 1488 1489 @parameterized.named_parameters([ 1490 ('IndexedSlicesWithDenseShape', 1491 _example_indexed_slices_with_dense_shape,), 1492 ('IndexedSlicesWithoutDenseShape', 1493 _example_indexed_slices_without_dense_shape,), 1494 ('RaggedTensorRaggedRank1', ragged_tensor.RaggedTensor.from_row_lengths, 1495 {'values': [1, 2, 3], 'row_lengths': [2, 0, 1]}), 1496 ('RaggedTensorRaggedRank2', 1497 ragged_tensor.RaggedTensor.from_nested_row_lengths, 1498 {'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}), 1499 ('SparseTensor', sparse_tensor.SparseTensor, 1500 {'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}), 1501 ]) # pyformat: disable 1502 def testReturnCompositeTensorWithDefun(self, 1503 factory_fn, 1504 factory_kwargs={}, 1505 input_signature=None): 1506 input_ct = factory_fn(**factory_kwargs) 1507 1508 @def_function.function(input_signature=input_signature) 1509 def f(): 1510 return input_ct 1511 1512 output_ct = f() 1513 self.assertIsInstance(output_ct, type(input_ct)) 1514 nest.assert_same_structure(input_ct, output_ct, expand_composites=True) 1515 1516 input_flat = nest.flatten(input_ct, expand_composites=True) 1517 output_flat = nest.flatten(output_ct, expand_composites=True) 1518 for (input_component, output_component) in zip(input_flat, output_flat): 1519 self.assertAllEqual(input_component, output_component) 1520 1521 @parameterized.named_parameters([ 1522 ('IndexedSlicesWithDenseShape', 1523 _example_indexed_slices_with_dense_shape,), 1524 ('IndexedSlicesWithoutDenseShape', 1525 _example_indexed_slices_without_dense_shape,), 1526 ('RaggedTensorRaggedRank1', 1527 ragged_tensor.RaggedTensor.from_row_lengths, 1528 {'values': [1, 2, 3], 'row_lengths': [2, 0, 1]}), 1529 ('RaggedTensorRaggedRank2', 1530 ragged_tensor.RaggedTensor.from_nested_row_lengths, 1531 {'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}), 1532 ('SparseTensor', 1533 sparse_tensor.SparseTensor, 1534 {'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}), 1535 ('RaggedTensorRaggedRank1WithSignature', 1536 ragged_tensor.RaggedTensor.from_row_lengths, 1537 {'values': [1, 2, 3], 'row_lengths': [2, 0, 1]}, 1538 [ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)]), 1539 ('RaggedTensorRaggedRank2WithSignature', 1540 ragged_tensor.RaggedTensor.from_nested_row_lengths, 1541 {'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}, 1542 [ragged_tensor.RaggedTensorSpec([None, None, None], dtypes.int32)]), 1543 ('SparseTensorWithSignature', 1544 sparse_tensor.SparseTensor, 1545 {'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}, 1546 [sparse_tensor.SparseTensorSpec([None], dtypes.int32)]), 1547 ]) # pyformat: disable 1548 def testCompositeAsArgumentTensorWithDefun(self, 1549 factory_fn, 1550 factory_kwargs={}, 1551 input_signature=None): 1552 input_ct = factory_fn(**factory_kwargs) 1553 1554 @def_function.function(input_signature=input_signature) 1555 def f(x): 1556 return x 1557 1558 output_ct = f(input_ct) 1559 self.assertIsInstance(output_ct, type(input_ct)) 1560 nest.assert_same_structure(input_ct, output_ct, expand_composites=True) 1561 1562 input_flat = nest.flatten(input_ct, expand_composites=True) 1563 output_flat = nest.flatten(output_ct, expand_composites=True) 1564 for (input_component, output_component) in zip(input_flat, output_flat): 1565 self.assertAllEqual(input_component, output_component) 1566 1567 def testTracedCompositeDiscardsShapeInfo(self): 1568 # SparseTensorSpec intentionally excludes info about the number of elements 1569 # that are in a sparse tensor (which is recorded as st.indices.shape[0] and 1570 # st.values.shape[0]). Similarly, RaggedTensorSpec intentionally excludes 1571 # info about the total number of values in a RaggedTensor (stored as 1572 # rt.values.shape[0]). This test checks that the placeholders created by 1573 # tf.function() properly mask this shape info. 1574 @def_function.function 1575 def f(rt, st): 1576 self.assertEqual(st.indices.shape.as_list()[:1], [None]) 1577 self.assertEqual(st.values.shape.as_list(), [None]) 1578 return (rt, st) 1579 1580 rt = ragged_factory_ops.constant([[1, 2], [3]]) 1581 st = sparse_tensor.SparseTensor([[0]], [0], [10]) 1582 f(rt, st) 1583 1584 @test_util.run_gpu_only 1585 def testFunctionOnDevice(self): 1586 x = constant_op.constant([1.]).gpu() 1587 f = def_function.function(math_ops.add) 1588 y = f(x, x).cpu() 1589 self.assertAllEqual(y, [2.]) 1590 1591 @test_util.run_gpu_only 1592 @test_util.run_in_graph_and_eager_modes 1593 def testFunctionWithResourcesOnDifferentDevices(self): 1594 with ops.device('/cpu:0'): 1595 v_cpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0]) 1596 1597 with ops.device('/gpu:0'): 1598 v_gpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0]) 1599 1600 def sum_gather(): 1601 cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu, [1, 2])) 1602 gpu_result = math_ops.reduce_sum(array_ops.gather(v_gpu, [1, 2])) 1603 return cpu_result, gpu_result 1604 1605 defined = function.defun(sum_gather) 1606 if not context.executing_eagerly(): 1607 self.evaluate(variables.global_variables_initializer()) 1608 expected = self.evaluate(sum_gather()) 1609 self.assertAllEqual(expected, self.evaluate(defined())) 1610 1611 @test_util.run_gpu_only 1612 @test_util.run_in_graph_and_eager_modes 1613 def testOpInFunctionWithConflictingResourceInputs(self): 1614 with ops.device('/cpu:0'): 1615 v_cpu = resource_variable_ops.ResourceVariable( 1616 [0.0, 1.0, 2.0], name='cpu') 1617 v_also_cpu = resource_variable_ops.ResourceVariable( 1618 [0.0, 1.0, 2.0], name='also_cpu') 1619 1620 with ops.device('/gpu:0'): 1621 v_gpu = resource_variable_ops.ResourceVariable( 1622 [0.0, 1.0, 2.0], name='gpu') 1623 1624 @def_function.function 1625 def resource_apply_adam(): 1626 training_ops.resource_apply_adam( 1627 v_cpu.handle, 1628 v_gpu.handle, 1629 v_also_cpu.handle, 1630 1.0, # beta1_power 1631 1.0, # beta2_power 1632 1.0, # learning_rate 1633 1.0, # beta1 1634 1.0, # beta2 1635 1.0, # epsilon, 1636 [1.0, 1.0, 1.0], # grad 1637 False) # use_locking 1638 return None 1639 1640 with self.assertRaisesRegex( 1641 errors.InvalidArgumentError, 1642 'Cannot place the graph because a reference or resource edge connects ' 1643 'colocation groups with incompatible assigned devices'): 1644 if not context.executing_eagerly(): 1645 self.evaluate(variables.global_variables_initializer()) 1646 self.evaluate(resource_apply_adam()) 1647 1648 @test_util.run_gpu_only 1649 def testFunctionHandlesInputsOnDifferentDevices(self): 1650 # The Reshape op requires the shape tensor to be placed in host memory. 1651 reshape = def_function.function(array_ops.reshape) 1652 value = constant_op.constant([1., 2.]).gpu() 1653 shape = constant_op.constant([2, 1]) 1654 reshaped = reshape(value, shape).cpu() 1655 self.assertAllEqual(reshaped, [[1], [2]]) 1656 1657 @test_util.run_gpu_only 1658 def testFunctionHandlesInputsPlacedOnTheWrongDeviceGracefully(self): 1659 # The Reshape op requires the shape tensor to be placed in host memory. 1660 reshape = def_function.function(array_ops.reshape) 1661 value = constant_op.constant([1., 2.]) 1662 shape = constant_op.constant([2, 1]).gpu() 1663 reshape(value, shape) # No error is raised 1664 1665 def testNoneOutput(self): 1666 1667 @def_function.function 1668 def my_function(_): 1669 return None 1670 1671 self.assertAllEqual(my_function(1), None) 1672 1673 def testNestedFunctions(self): 1674 # TensorFlow function (which is what would be used in TensorFlow graph 1675 # construction). 1676 @tf_function.Defun(dtypes.int32, dtypes.int32) 1677 def add(a, b): 1678 return math_ops.add(a, b) 1679 1680 @def_function.function 1681 def add_one(x): 1682 return add(x, 1) 1683 1684 self.assertAllEqual(3, add_one(constant_op.constant(2))) 1685 1686 def testVariableCaptureInNestedFunctions(self): 1687 v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.int32) 1688 1689 @def_function.function 1690 def inner_read(): 1691 return v.read_value() 1692 1693 @def_function.function 1694 def outer(): 1695 return inner_read() 1696 1697 self.assertEqual(1, int(outer())) 1698 1699 def testReturnCapturedEagerTensor(self): 1700 t = constant_op.constant(1) 1701 1702 @def_function.function 1703 def read(): 1704 return t 1705 1706 self.assertEqual(1, int(read())) 1707 1708 def testReturnCapturedGraphTensor(self): 1709 with context.graph_mode(), self.cached_session(): 1710 t = constant_op.constant(1) 1711 1712 @def_function.function 1713 def read(): 1714 return t 1715 1716 self.assertEqual(1, int(self.evaluate(read()))) 1717 1718 def testSequenceInputs(self): 1719 clip_by_global_norm = def_function.function(clip_ops.clip_by_global_norm) 1720 t_list = [constant_op.constant(1.0), constant_op.constant(2.0)] 1721 clipped_list, global_norm = clip_by_global_norm(t_list, 1722 constant_op.constant(.2)) 1723 for t in clipped_list: 1724 self.assertIsInstance(t, ops.Tensor) 1725 self.assertIsInstance(global_norm, ops.Tensor) 1726 1727 def testNestedSequenceInputs(self): 1728 1729 def my_op(inputs): 1730 a, b, c = inputs 1731 e, f = b 1732 g, h = e 1733 return [a + a, [tuple([f + f, g + g]), h + h], c + c], a + f + g + h + c 1734 1735 my_eager_op = def_function.function(my_op) 1736 ret = my_eager_op([ 1737 constant_op.constant(1), [(constant_op.constant(2), 1738 constant_op.constant(3)), 1739 constant_op.constant(4)], 1740 constant_op.constant(5) 1741 ]) 1742 self.assertLen(ret, 2) 1743 self.assertAllEqual(ret[0][0], 2) 1744 self.assertAllEqual(ret[0][1][0][0], 8) 1745 self.assertAllEqual(ret[0][1][0][1], 4) 1746 self.assertIsInstance(ret[0][1][0], tuple) 1747 self.assertAllEqual(ret[0][1][1], 6) 1748 self.assertAllEqual(ret[0][2], 10) 1749 self.assertAllEqual(ret[1], 15) 1750 1751 def testVariableNamesRespectNameScopesWithDefun(self): 1752 @def_function.function 1753 def create_variable(): 1754 with ops.name_scope('foo', skip_on_eager=False): 1755 v = resource_variable_ops.ResourceVariable(0.0, name='bar') 1756 self.assertEqual(v.name, 'foo/bar:0') 1757 1758 create_variable() 1759 1760 def testVariableNamesRespectNameScopesWithDefunInGraph(self): 1761 with context.graph_mode(): 1762 @def_function.function 1763 def create_variable(): 1764 with ops.name_scope('foo', skip_on_eager=False): 1765 v = resource_variable_ops.ResourceVariable([1.0, 2.0], name='bar') 1766 self.assertEqual(v.name, 'foo/bar:0') 1767 1768 with ops.get_default_graph().as_default(): 1769 create_variable() 1770 1771 @test_util.assert_no_new_pyobjects_executing_eagerly 1772 def testCallOptionsMemory(self): 1773 1774 @function.defun 1775 def model(x): 1776 return x + constant_op.constant(1.) 1777 1778 # This happens with a lot of option toggles, e.g. soft device placement 1779 context.context().function_call_options = None 1780 model(constant_op.constant(2.)) 1781 1782 @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) 1783 def testLayerInDefun(self): 1784 conv = convolutional.Conv2D( 1785 filters=1, 1786 kernel_size=2, 1787 kernel_initializer=init_ops.ones_initializer(), 1788 bias_initializer=init_ops.zeros_initializer()) 1789 1790 @function.defun 1791 def model(x): 1792 return conv(x) 1793 1794 x = array_ops.ones([1, 2, 2, 1]) 1795 y = model(x) 1796 1797 if not context.executing_eagerly(): 1798 self.evaluate(variables.global_variables_initializer()) 1799 1800 self.assertAllClose([[[[4.0]]]], self.evaluate(y)) 1801 1802 # Variable lifting is somewhat different between defun/tf.function, so testing 1803 # device placement on both makes sense. 1804 @parameterized.named_parameters( 1805 dict(testcase_name='Defun', 1806 function_decorator=function.defun), 1807 dict(testcase_name='DefFunction', 1808 function_decorator=def_function.function)) 1809 @test_util.run_in_graph_and_eager_modes 1810 def testVariablesPlacedOnOutsideDevice(self, function_decorator): 1811 1812 class _Obj(object): 1813 1814 def __init__(self): 1815 self.v = None 1816 1817 @function_decorator 1818 def f(self): 1819 if self.v is None: 1820 self.v = variables.Variable(1.) 1821 return self.v + 1. 1822 1823 has_device = _Obj() 1824 with ops.device('cpu:0'): 1825 has_device.f() 1826 self.assertIn('CPU', has_device.v.device) 1827 1828 @test_util.run_in_graph_and_eager_modes 1829 def testMultipleDeviceCheck(self): 1830 1831 def f(): 1832 with ops.device('cpu'): 1833 return test_ops.device_placement_op() 1834 1835 func = function.defun(f) 1836 with ops.device('cpu:0'): 1837 output = self.evaluate(func()) 1838 self.assertIn(compat.as_bytes('CPU:0'), output) 1839 1840 @test_util.run_in_graph_and_eager_modes 1841 def testDeviceAnnotationsRespected(self): 1842 1843 def multi_device_fn(): 1844 with ops.device('/cpu:0'): 1845 s0 = test_ops.device_placement_op() 1846 with ops.device('/cpu:1'): 1847 s1 = test_ops.device_placement_op() 1848 with ops.device('/cpu:2'): 1849 s2 = test_ops.device_placement_op() 1850 s3 = test_ops.device_placement_op() 1851 return s0, s1, s2, s3 1852 1853 defined = function.defun(multi_device_fn) 1854 outputs = self.evaluate(defined()) 1855 self.assertLen(total_function_cache(defined), 1) 1856 self.assertIn(compat.as_bytes('CPU:0'), outputs[0]) 1857 self.assertIn(compat.as_bytes('CPU:1'), outputs[1]) 1858 self.assertIn(compat.as_bytes('CPU:2'), outputs[2]) 1859 1860 with ops.device('/cpu:3'): 1861 outputs = self.evaluate(defined()) 1862 # All function definitions are agnostic to call site devices. 1863 self.assertLen(total_function_cache(defined), 1) 1864 self.assertIn(compat.as_bytes('CPU:0'), outputs[0]) 1865 self.assertIn(compat.as_bytes('CPU:1'), outputs[1]) 1866 self.assertIn(compat.as_bytes('CPU:2'), outputs[2]) 1867 self.assertIn(compat.as_bytes('CPU:3'), outputs[3]) 1868 1869 with ops.device('/cpu:0'): 1870 outputs = self.evaluate(defined()) 1871 self.assertLen(total_function_cache(defined), 1) 1872 self.assertIn(compat.as_bytes('CPU:0'), outputs[0]) 1873 self.assertIn(compat.as_bytes('CPU:1'), outputs[1]) 1874 self.assertIn(compat.as_bytes('CPU:2'), outputs[2]) 1875 self.assertIn(compat.as_bytes('CPU:0'), outputs[3]) 1876 1877 @test_util.run_in_graph_and_eager_modes 1878 def testCallingGraphFunctionOnDifferentDevice(self): 1879 1880 def func(): 1881 return constant_op.constant(0) 1882 1883 defined = def_function.function(func) 1884 with ops.device('cpu:0'): 1885 cpu_graph_function = defined.get_concrete_function() 1886 1887 with ops.device('cpu:0'): 1888 self.assertEqual( 1889 self.evaluate(cpu_graph_function()), self.evaluate(func())) 1890 1891 with ops.device('cpu:1'): 1892 self.assertEqual(0., self.evaluate(cpu_graph_function())) 1893 1894 with ops.device(None): 1895 self.assertEqual(0., self.evaluate(cpu_graph_function())) 1896 1897 default_graph_function = defined.get_concrete_function() 1898 self.assertEqual( 1899 self.evaluate(default_graph_function()), self.evaluate(func())) 1900 1901 with ops.device('cpu:1'): 1902 self.assertEqual(0., self.evaluate(default_graph_function())) 1903 1904 @test_util.run_gpu_only 1905 @test_util.run_in_graph_and_eager_modes 1906 def testColocateWithRespected(self): 1907 # TODO(b/113291792): Use multiple CPUs instead of a GPU. 1908 with ops.device('cpu:0'): 1909 x = array_ops.identity(1.0) 1910 1911 with ops.device('gpu:0'): 1912 y = array_ops.identity(1.0) 1913 1914 @def_function.function 1915 def foo(): 1916 return test_ops.device_placement_op() 1917 1918 with ops.colocate_with(x): 1919 self.assertIn(compat.as_bytes('CPU:0'), self.evaluate(foo())) 1920 1921 with ops.colocate_with(y): 1922 self.assertIn(compat.as_bytes('GPU:0'), self.evaluate(foo())) 1923 1924 def testVariablesAreTracked(self): 1925 v = resource_variable_ops.ResourceVariable(1.0) 1926 1927 def foo(x): 1928 return v * x 1929 1930 defined = def_function.function(foo) 1931 1932 x = constant_op.constant([1.0]) 1933 self.assertEqual(1., self.evaluate(defined(x))) 1934 v.assign(2.) 1935 1936 x = constant_op.constant([1.0, 2.0]) 1937 self.assertAllEqual([2., 4.], self.evaluate(defined(x))) 1938 1939 def testCacheObjectHashCollisions(self): 1940 1941 class Foo(object): 1942 1943 def __hash__(self): 1944 return 42 1945 1946 def func(foo): 1947 del foo 1948 return 1949 1950 defined = function.defun(func) 1951 defined(Foo()) 1952 self.assertLen(total_function_cache(defined), 1) 1953 1954 defined(Foo()) 1955 self.assertLen(total_function_cache(defined), 2) 1956 1957 def testCacheTensorDtypeCollision(self): 1958 1959 def func(t): 1960 return t + t 1961 1962 defined = function.defun(func) 1963 t = constant_op.constant([[1.0]], dtype=dtypes.complex64) 1964 defined(t) 1965 self.assertLen(total_function_cache(defined), 1) 1966 1967 t = constant_op.constant([[1.0]], dtype=dtypes.complex128) 1968 defined(t) 1969 self.assertLen(total_function_cache(defined), 2) 1970 1971 def testCacheTensorShapeCollision(self): 1972 1973 def func(t): 1974 return t + t 1975 1976 defined = function.defun(func) 1977 t = constant_op.constant([[1.0]], dtype=dtypes.complex64) 1978 defined(t) 1979 self.assertLen(total_function_cache(defined), 1) 1980 1981 t = constant_op.constant([1.0], dtype=dtypes.complex64) 1982 defined(t) 1983 self.assertLen(total_function_cache(defined), 2) 1984 1985 def testCacheTensorShapeDtypeCollision(self): 1986 1987 def func(t): 1988 return t + t 1989 1990 defined = function.defun(func) 1991 t = constant_op.constant([[1.0]], dtype=dtypes.complex64) 1992 defined(t) 1993 self.assertLen(total_function_cache(defined), 1) 1994 1995 t = constant_op.constant([1.0], dtype=dtypes.complex128) 1996 defined(t) 1997 self.assertLen(total_function_cache(defined), 2) 1998 1999 def testCacheTensorUnknownShapesCollisionRelaxedShapes(self): 2000 2001 def func(t): 2002 return t + t 2003 2004 with context.graph_mode(), self.cached_session(): 2005 defined = function.defun(func, experimental_relax_shapes=True) 2006 2007 p = array_ops.placeholder(dtype=dtypes.float32, shape=[]) 2008 defined(p) 2009 self.assertLen(total_function_cache(defined), 1) 2010 2011 p = array_ops.placeholder(dtype=dtypes.float32, shape=[1]) 2012 defined(p) 2013 self.assertLen(total_function_cache(defined), 2) 2014 2015 p = array_ops.placeholder(dtype=dtypes.float32, shape=[2]) 2016 defined(p) 2017 # Gradual shape relaxation is performed; and the common shape between 2018 # [1] and [2] is one containing unknown dimensions. 2019 self.assertLen(total_function_cache(defined), 2) 2020 2021 # pylint: disable=protected-access 2022 self.assertLen(defined._function_cache.arg_relaxed_specs, 1) 2023 relaxed_specs = ( 2024 list(defined._function_cache.arg_relaxed_specs.values())[0]) 2025 self.assertLen(relaxed_specs, 1) 2026 relaxed_shape = relaxed_specs[0].shape 2027 # pylint: enable=protected-access 2028 self.assertEqual(relaxed_shape.rank, 1) 2029 self.assertEqual(tensor_shape.dimension_value(relaxed_shape[0]), None) 2030 2031 t = constant_op.constant([1.0, 1.0, 1.0], dtype=dtypes.float32) 2032 defined(t) 2033 # Shape (3,) matches the relaxed shape TensorShape([None]) 2034 self.assertLen(total_function_cache(defined), 2) 2035 2036 def testPythonFunctionWithDefaultArgs(self): 2037 2038 def func(foo, bar=1, baz=2): 2039 del foo 2040 del bar 2041 del baz 2042 return 2043 2044 defined = function.defun(func) 2045 defined(0, baz=20) 2046 self.assertLen(total_function_cache(defined), 1) 2047 2048 defined(1) # bar=1, baz=2 2049 self.assertLen(total_function_cache(defined), 2) 2050 2051 # This matches the previous call. 2052 defined(foo=1) 2053 self.assertLen(total_function_cache(defined), 2) 2054 2055 defined(1, 2, 3) 2056 self.assertLen(total_function_cache(defined), 3) 2057 2058 # This matches the previous call. 2059 defined(1, bar=2, baz=3) 2060 self.assertLen(total_function_cache(defined), 3) 2061 2062 # This matches the previous call. 2063 defined(1, baz=3, bar=2) 2064 self.assertLen(total_function_cache(defined), 3) 2065 2066 def testDatasetIteratorCaching(self): 2067 def func(it1, it2): 2068 next(it1) 2069 next(it2) 2070 return 0 2071 2072 defined = function.defun(func) 2073 2074 d = dataset_ops.DatasetV2.from_tensor_slices([1, 2, 3]) 2075 it1 = iter(d) 2076 it2 = iter(d) 2077 _ = defined(it1, it2) # The two iterators are different 2078 self.assertLen(total_function_cache(defined), 1) 2079 2080 it3 = iter(d) 2081 it4 = iter(d) 2082 _ = defined(it3, it4) # The two iterators are different, should not retrace 2083 self.assertLen(total_function_cache(defined), 1) 2084 2085 it5 = iter(d) 2086 _ = defined(it5, it5) # The two iterators are the same, should retrace 2087 self.assertLen(total_function_cache(defined), 2) 2088 2089 def testFunctoolsPartialUnwrappedCorrectly(self): 2090 2091 def full_function(a, b, c=3): 2092 return a, b, c 2093 2094 partial = functools.partial(full_function, 1, c=4) 2095 a, b, c = partial(2) 2096 2097 defined = function.defun(partial) 2098 func_a, func_b, func_c = defined(2) 2099 self.assertEqual(func_a.numpy(), a) 2100 self.assertEqual(func_b.numpy(), b) 2101 self.assertEqual(func_c.numpy(), c) 2102 2103 def testInputSignatureWithMatchingInputs(self): 2104 2105 def foo(a): 2106 self.assertEqual(a.shape, (2,)) 2107 return a 2108 2109 signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)] 2110 defined = function.defun(foo, input_signature=signature) 2111 a = array_ops.ones([2]) 2112 self.assertAllEqual(a, defined(a)) 2113 self.assertLen(total_function_cache(defined), 1) 2114 self.assertAllEqual(a, defined.get_concrete_function()(a)) 2115 self.assertAllEqual(a, defined.get_concrete_function(a)(a)) 2116 self.assertAllEqual(a, defined.get_concrete_function( 2117 tensor_spec.TensorSpec((2,), dtype=dtypes.float32))(a)) 2118 self.assertLen(total_function_cache(defined), 1) 2119 2120 def bar(a): 2121 self.assertEqual(a._shape_tuple(), (2, None)) 2122 return a 2123 2124 signature = [tensor_spec.TensorSpec((2, None), dtypes.float32)] 2125 defined = function.defun(bar, input_signature=signature) 2126 a = array_ops.ones([2, 1]) 2127 out = defined(a) 2128 self.assertLen(total_function_cache(defined), 1) 2129 self.assertAllEqual(out, a) 2130 2131 # Changing the second dimension shouldn't create a new function. 2132 b = array_ops.ones([2, 3]) 2133 out = defined(b) 2134 self.assertLen(total_function_cache(defined), 1) 2135 self.assertAllEqual(out, b) 2136 2137 def testInputSignatureWithDictInPositionalArgs(self): 2138 2139 @function.defun 2140 def f(*_args, **_kwargs): 2141 return None 2142 2143 f(1, x=2) 2144 self.assertLen(total_function_cache(f), 1) 2145 f(1, x=2) 2146 self.assertLen(total_function_cache(f), 1) 2147 f(1, {'x': 2}) 2148 self.assertLen(total_function_cache(f), 2) 2149 2150 def testInputSignatureWithCompatibleInputs(self): 2151 2152 rank2_spec = tensor_spec.TensorSpec(shape=(None, None), 2153 dtype=dtypes.float32) 2154 2155 @function.defun(input_signature=[rank2_spec]) 2156 def func(a): 2157 self.assertEqual([None, None], a.shape.as_list()) 2158 return array_ops.shape(a) 2159 2160 self.assertAllEqual([3, 1], func([[0], [1.0], [1]])) 2161 self.assertAllEqual([2, 2], func(numpy.array([[1, 1], [2, 2]]))) 2162 2163 with self.assertRaisesRegex(ValueError, 'incompatible'): 2164 func([0.0, 1.0, 2.0]) # Wrong shape. 2165 2166 with self.assertRaisesRegex(ValueError, 'incompatible'): 2167 func([['wrong dtype']]) 2168 2169 def testNoKeywordOnlyArgumentsWithInputSignature(self): 2170 if sys.version_info[0] < 3: 2171 self.skipTest('keyword_only arguments only exist in Python 3.') 2172 2173 func = eval('lambda x, *, y: x') # pylint: disable=eval-used 2174 signature = [tensor_spec.TensorSpec(None, dtypes.int32)] 2175 with self.assertRaisesRegex( 2176 ValueError, 'Cannot define a TensorFlow function from a Python ' 2177 'function with keyword-only arguments when input_signature is ' 2178 'provided.'): 2179 def_function.function(func, signature) 2180 2181 def testNestedInputSignatures(self): 2182 2183 def expected_foo(a, b): 2184 return [a, b] 2185 2186 @function.defun(input_signature=[ 2187 [tensor_spec.TensorSpec((2, None), dtypes.float32)] * 2, 2188 tensor_spec.TensorSpec((1,), dtypes.float32), 2189 ]) 2190 def foo(a, b): 2191 self.assertEqual(a[0]._shape_tuple(), (2, None)) 2192 self.assertEqual(a[1]._shape_tuple(), (2, None)) 2193 self.assertEqual(b._shape_tuple(), (1,)) 2194 return [a, b] 2195 2196 a = array_ops.ones([2, 1]) 2197 b = array_ops.ones([1]) 2198 expected = expected_foo([a, a], b) 2199 out = foo([a, a], b) 2200 self.assertLen(total_function_cache(foo), 1) 2201 nest.assert_same_structure(out, expected) 2202 self.assertAllEqual(out[0][0], a) 2203 self.assertAllEqual(out[0][1], a) 2204 self.assertAllEqual(out[1], b) 2205 2206 # Changing the unspecified dimensions shouldn't create a new function. 2207 a = array_ops.ones([2, 3]) 2208 b = array_ops.ones([2, 5]) 2209 c = array_ops.ones([1]) 2210 expected = expected_foo([a, b], c) 2211 out = foo([a, b], c) 2212 self.assertLen(total_function_cache(foo), 1) 2213 nest.assert_same_structure(out, expected) 2214 self.assertAllEqual(out[0][0], a) 2215 self.assertAllEqual(out[0][1], b) 2216 self.assertAllEqual(out[1], c) 2217 2218 # Passing compatible inputs should work. 2219 a = a.numpy().tolist() 2220 b = b.numpy().tolist() 2221 c = c.numpy().tolist() 2222 out = foo([a, b], c) 2223 self.assertLen(total_function_cache(foo), 1) 2224 nest.assert_same_structure(out, expected) 2225 self.assertAllEqual(out[0][0], a) 2226 self.assertAllEqual(out[0][1], b) 2227 self.assertAllEqual(out[1], c) 2228 2229 def testNestedInputSignaturesWithDict(self): 2230 def expected_bar(a): 2231 return a 2232 2233 @function.defun(input_signature=[{ 2234 'a': tensor_spec.TensorSpec((2, None), dtypes.float32), 2235 'b': tensor_spec.TensorSpec((2, None), dtypes.float32), 2236 'c': tensor_spec.TensorSpec((1,), dtypes.float32)}]) 2237 def bar(a): 2238 self.assertEqual(a['a']._shape_tuple(), (2, None)) 2239 self.assertEqual(a['b']._shape_tuple(), (2, None)) 2240 self.assertEqual(a['c']._shape_tuple(), (1,)) 2241 return a 2242 2243 a = array_ops.ones([2, 3]) 2244 b = array_ops.ones([1]) 2245 inputs = {'a': a, 'b': a, 'c': b} 2246 expected = expected_bar(inputs) 2247 out = bar(inputs) 2248 nest.assert_same_structure(out, expected) 2249 self.assertAllEqual(out['a'], expected['a']) 2250 self.assertAllEqual(out['b'], expected['b']) 2251 self.assertAllEqual(out['c'], expected['c']) 2252 2253 # Passing compatible inputs should work. 2254 a = a.numpy().tolist() 2255 b = b.numpy().tolist() 2256 inputs = {'a': a, 'b': a, 'c': b} 2257 out = bar(inputs) 2258 nest.assert_same_structure(out, expected) 2259 self.assertAllEqual(out['a'], expected['a']) 2260 self.assertAllEqual(out['b'], expected['b']) 2261 self.assertAllEqual(out['c'], expected['c']) 2262 2263 def testInputSignatureMustBeSequenceOfTensorSpecs(self): 2264 2265 def foo(a, b): 2266 del a 2267 del b 2268 2269 # Signatures must consist exclusively of `TensorSpec` objects. 2270 signature = [(2, 3), tensor_spec.TensorSpec([2, 3], dtypes.float32)] 2271 with self.assertRaisesRegex(TypeError, 'input_signature.*nested sequence'): 2272 def_function.function(foo, input_signature=signature) 2273 2274 # Signatures must be either lists or tuples on their outermost levels. 2275 signature = {'t1': tensor_spec.TensorSpec([], dtypes.float32)} 2276 with self.assertRaisesRegex( 2277 TypeError, 'input_signature must be either a ' 2278 'tuple or a list.*'): 2279 function.defun(foo, input_signature=signature) 2280 2281 @test_util.run_in_graph_and_eager_modes 2282 def testInputsIncompatibleWithSignatureRaisesError(self): 2283 2284 def foo(a): 2285 return a 2286 2287 signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)] 2288 defined = def_function.function(foo, input_signature=signature) 2289 2290 # Invalid shapes. 2291 with self.assertRaisesRegex(ValueError, 'Python inputs incompatible.*'): 2292 defined(array_ops.ones([3])) 2293 2294 with self.assertRaisesRegex(ValueError, 'Python inputs incompatible.*'): 2295 defined(array_ops.ones([2, 1])) 2296 2297 # Wrong number of arguments. 2298 with self.assertRaisesRegex(TypeError, 'specifies 1 .* got 2'): 2299 defined(array_ops.ones([2]), array_ops.ones([2])) 2300 with self.assertRaisesRegex(ValueError, 2301 'Structure of Python function inputs.*'): 2302 defined() 2303 2304 with self.assertRaisesRegex(ValueError, 2305 'inputs incompatible with input_signature'): 2306 defined.get_concrete_function( 2307 tensor_spec.TensorSpec(shape=(3,), dtype=dtypes.float32)) 2308 2309 def testMismatchedConcreteSignatureRaisesError(self): 2310 2311 @def_function.function 2312 def run_test(): 2313 @def_function.function 2314 def f(x): 2315 return x 2316 2317 with self.assertRaisesRegex( 2318 TypeError, 'ConcreteFunction .* was constructed .* but was called'): 2319 f.get_concrete_function(1)(constant_op.constant(1)) 2320 2321 with self.assertRaisesRegex(TypeError, r'f\(x\) expected .* but got .*'): 2322 f.get_concrete_function(constant_op.constant(1))(1) 2323 2324 with self.assertRaisesRegex( 2325 TypeError, 'ConcreteFunction .* was constructed .* but was called'): 2326 f.get_concrete_function(1)(2) 2327 2328 run_test() 2329 2330 def testInputsIncompatibleWithNestedSignatureRaisesError(self): 2331 2332 def foo(a, b): 2333 return [a, b] 2334 2335 signature = [[tensor_spec.TensorSpec((1,), dtypes.float32)] * 2, 2336 [tensor_spec.TensorSpec((1,), dtypes.float32)] * 2] 2337 defined = function.defun(foo, input_signature=signature) 2338 a = array_ops.ones([1]) 2339 2340 with self.assertRaisesRegex(ValueError, 2341 'Structure of Python function inputs.*'): 2342 defined([a, a, a], [a]) 2343 2344 with self.assertRaisesRegex(ValueError, 2345 'Structure of Python function inputs.*'): 2346 defined([a], [a, a, a]) 2347 defined([a, a], [a, a]) 2348 2349 def testUnderspecifiedInputSignature(self): 2350 @function.defun(input_signature=[ 2351 tensor_spec.TensorSpec([], dtypes.float32), 2352 ]) 2353 def foo(a, training=True): 2354 if training: 2355 return a 2356 else: 2357 return -1.0 * a 2358 2359 x = constant_op.constant(1.0) 2360 with self.assertRaisesRegex( 2361 TypeError, 'got keyword argument `training` ' 2362 'that was not included in input_signature'): 2363 foo(x, training=True) 2364 2365 with self.assertRaisesRegex( 2366 TypeError, 'got keyword argument `training` ' 2367 'that was not included in input_signature'): 2368 foo(x, training=False) 2369 2370 self.assertAllEqual(x.numpy(), foo(x).numpy()) 2371 2372 def testInputSignatureWithPartialFunction(self): 2373 def full_function(a, b, c=3.0): 2374 return a, b, c 2375 2376 partial = functools.partial(full_function, 1, c=4) 2377 a, b, c = partial(2.0) 2378 signature = [tensor_spec.TensorSpec([], dtypes.float32)] 2379 defined = function.defun(partial, input_signature=signature) 2380 x = constant_op.constant(2.0) 2381 func_a, func_b, func_c = defined(x) 2382 self.assertEqual(func_a.numpy(), a) 2383 self.assertEqual(func_b.numpy(), b) 2384 self.assertEqual(func_c.numpy(), c) 2385 2386 def testInputSignatureConversionWithDefaultArg(self): 2387 2388 def foo(a, training=True): 2389 if training: 2390 return a 2391 else: 2392 return -1.0 * a 2393 2394 signature = [ 2395 tensor_spec.TensorSpec([], dtypes.float32), 2396 tensor_spec.TensorSpec([], dtypes.bool), 2397 ] 2398 defined = def_function.function(foo, input_signature=signature) 2399 a = constant_op.constant(1.0) 2400 self.assertAllEqual(a.numpy(), defined(a)) 2401 self.assertAllEqual(a.numpy(), defined(a, training=True)) 2402 self.assertAllEqual(-a.numpy(), defined(a, training=False)) 2403 2404 def testInputSignatureWithKeywordPositionalArgs(self): 2405 2406 @function.defun(input_signature=[ 2407 tensor_spec.TensorSpec([], dtypes.float32), 2408 tensor_spec.TensorSpec([], dtypes.int64) 2409 ]) 2410 def foo(flt, integer): 2411 return flt, integer 2412 2413 flt = constant_op.constant(1.0) 2414 integer = constant_op.constant(2, dtypes.int64) 2415 2416 out1, out2 = foo(flt, integer) 2417 self.assertLen(total_function_cache(foo), 1) 2418 self.assertEqual(out1.numpy(), 1.0) 2419 self.assertEqual(out2.numpy(), 2) 2420 2421 out1, out2 = foo(flt=flt, integer=integer) 2422 self.assertLen(total_function_cache(foo), 1) 2423 self.assertEqual(out1.numpy(), 1.0) 2424 self.assertEqual(out2.numpy(), 2) 2425 2426 out1, out2 = foo(integer=integer, flt=flt) 2427 self.assertLen(total_function_cache(foo), 1) 2428 self.assertEqual(out1.numpy(), 1.0) 2429 self.assertEqual(out2.numpy(), 2) 2430 2431 out1, out2 = foo(flt, integer=integer) 2432 self.assertLen(total_function_cache(foo), 1) 2433 self.assertEqual(out1.numpy(), 1.0) 2434 self.assertEqual(out2.numpy(), 2) 2435 2436 def testInputSignatureWithKeywordArgs(self): 2437 def foo(a, b, **kwargs): 2438 del kwargs 2439 return a, b 2440 2441 x = function.defun( 2442 foo, 2443 input_signature=[ 2444 tensor_spec.TensorSpec([], dtypes.float32), 2445 tensor_spec.TensorSpec([], dtypes.int32) 2446 ]).get_concrete_function() 2447 result = x(constant_op.constant(5.0), constant_op.constant(5)) 2448 self.assertAllEqual(result, [5.0, 5]) 2449 2450 def testInputSignatureWithCompositeTensors(self): 2451 def f(rt): 2452 self.assertEqual(rt.values.shape.as_list(), [None]) 2453 self.assertEqual(rt.row_splits.shape.as_list(), [4]) 2454 return rt 2455 2456 signature = [ragged_tensor.RaggedTensorSpec( 2457 shape=[3, None], dtype=dtypes.int32)] 2458 defined = function.defun(f, input_signature=signature) 2459 rt1 = ragged_factory_ops.constant([[1], [], [2, 3, 4]]) 2460 out1 = defined(rt1) 2461 self.assertLen(total_function_cache(defined), 1) 2462 self.assertAllEqual(out1.values, rt1.values) 2463 self.assertAllEqual(out1.row_splits, rt1.row_splits) 2464 2465 # Changing the row lengths shouldn't create a new function. 2466 rt2 = ragged_factory_ops.constant([[1, 2], [3, 4], [5]]) 2467 out2 = defined(rt2) 2468 self.assertLen(total_function_cache(defined), 1) 2469 self.assertAllEqual(out2.values, rt2.values) 2470 self.assertAllEqual(out2.row_splits, rt2.row_splits) 2471 2472 # Different number of rows 2473 rt3 = ragged_factory_ops.constant([[1, 2], [3, 4], [5], [6]]) 2474 with self.assertRaisesRegex(ValueError, 'incompatible'): 2475 defined(rt3) 2476 2477 # Different dtype 2478 rt4 = ragged_factory_ops.constant([[1.0, 2.0], [], [3.0]]) 2479 with self.assertRaisesRegex(ValueError, 'Structure .* does not match'): 2480 defined(rt4) 2481 2482 # Different rank 2483 rt5 = ragged_factory_ops.constant([[[1]], [[2]], [[3]]]) 2484 with self.assertRaisesRegex(ValueError, 'does not match'): 2485 defined(rt5) 2486 2487 def testInputSignatureWithVariableArgs(self): 2488 2489 def f(v): 2490 v.assign_add(1) 2491 2492 signature = [ 2493 resource_variable_ops.VariableSpec(shape=[], dtype=dtypes.int32) 2494 ] 2495 defined = function.defun(f, input_signature=signature) 2496 2497 v1 = variables.Variable(0) 2498 v2 = variables.Variable(0) 2499 2500 defined(v1) 2501 self.assertEqual(v1.numpy(), 1) 2502 self.assertEqual(v2.numpy(), 0) 2503 2504 defined(v=v2) 2505 self.assertEqual(v1.numpy(), 1) 2506 self.assertEqual(v2.numpy(), 1) 2507 2508 def testTensorKeywordArguments(self): 2509 2510 def foo(a, b): 2511 del a 2512 return b 2513 2514 defined = function.defun(foo) 2515 a = constant_op.constant(2.0) 2516 b = constant_op.constant([1.0, 2.0]) 2517 one = defined(a, b) 2518 self.assertLen(total_function_cache(defined), 1) 2519 2520 two = defined(a=a, b=b) 2521 self.assertLen(total_function_cache(defined), 1) 2522 2523 three = defined(b=b, a=a) 2524 self.assertLen(total_function_cache(defined), 1) 2525 2526 four = defined(a, b=b) 2527 self.assertLen(total_function_cache(defined), 1) 2528 2529 # The next call corresponds to a new input signature, hence 2530 # we expect another function to be defined. 2531 five = defined(b, a) 2532 self.assertLen(total_function_cache(defined), 2) 2533 2534 six = defined(a=b, b=a) 2535 self.assertLen(total_function_cache(defined), 2) 2536 2537 seven = defined(b=a, a=b) 2538 self.assertLen(total_function_cache(defined), 2) 2539 2540 self.assertAllEqual(one, [1.0, 2.0]) 2541 self.assertAllEqual(two, [1.0, 2.0]) 2542 self.assertAllEqual(three, [1.0, 2.0]) 2543 self.assertAllEqual(four, [1.0, 2.0]) 2544 self.assertAllEqual(five, 2.0) 2545 self.assertAllEqual(six, 2.0) 2546 self.assertAllEqual(seven, 2.0) 2547 2548 def testDefuningInstanceMethod(self): 2549 2550 integer = constant_op.constant(2, dtypes.int64) 2551 2552 class Foo(object): 2553 2554 def one(self, tensor): 2555 return tensor 2556 2557 @def_function.function 2558 def two(self, tensor, other=integer): 2559 return self.one(tensor), other 2560 2561 foo = Foo() 2562 t = constant_op.constant(1.0) 2563 one, two = foo.two(t) 2564 self.assertEqual(one.numpy(), 1.0) 2565 self.assertEqual(two.numpy(), 2) 2566 2567 def testDefuningInstanceMethodWithDefaultArgument(self): 2568 2569 integer = constant_op.constant(2, dtypes.int64) 2570 2571 class Foo(object): 2572 2573 @def_function.function 2574 def func(self, other=integer): 2575 return other 2576 2577 foo = Foo() 2578 self.assertEqual(foo.func().numpy(), int(integer)) 2579 2580 def testPythonCallWithSideEffects(self): 2581 state = [] 2582 2583 @def_function.function 2584 def side_effecting_function(): 2585 state.append(0) 2586 2587 side_effecting_function() 2588 self.assertAllEqual(state, [0]) 2589 2590 # The second invocation should call the graph function, which shouldn't 2591 # trigger the list append. 2592 side_effecting_function() 2593 self.assertAllEqual(state, [0]) 2594 2595 # Whereas calling the python function directly should create a side-effect. 2596 side_effecting_function.python_function() 2597 self.assertAllEqual(state, [0, 0]) 2598 2599 def testFunctionWithNestedFunctionCallAndSideEffects(self): 2600 v1 = variables.Variable(1.0) 2601 v2 = variables.Variable(1.0) 2602 2603 @def_function.function 2604 def add_one(a): 2605 a.assign_add(1.0) 2606 2607 # Grappler will inline calls to `add_one` into the function body, we check 2608 # that all side-effects were executed. 2609 @def_function.function 2610 def side_effecting_function(a, b): 2611 add_one(a) 2612 add_one(b) 2613 return a + b 2614 2615 result = side_effecting_function(v1, v2) 2616 self.assertEqual(result.numpy(), 4.0) 2617 2618 def testFunctionWithExtraAttributes(self): 2619 @function.defun_with_attributes(attributes={'experimental_1': 'value1', 2620 'experimental_2': 2}) 2621 def matmul(x, y): 2622 return math_ops.matmul(x, y) 2623 2624 def add(x, y): 2625 return math_ops.add(x, y) 2626 defun_add = function.defun_with_attributes( 2627 add, attributes={'experimental_3': True, 'experimental_4': 1.0}) 2628 2629 with context.graph_mode(), self.cached_session(): 2630 with ops.get_default_graph().as_default(): 2631 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 2632 sq = matmul(t, t) 2633 double = defun_add(t, t) 2634 self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22]) 2635 self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8]) 2636 2637 graph = ops.get_default_graph() 2638 # pylint: disable=protected-access 2639 self.assertLen(graph._functions, 2) 2640 functions = list(graph._functions.values()) 2641 self.assertRegex(functions[0].definition.signature.name, '.*matmul.*') 2642 attrs = functions[0].definition.attr 2643 self.assertLen(attrs, 2) 2644 self.assertEqual(attrs['experimental_1'].s, b'value1') 2645 self.assertEqual(attrs['experimental_2'].i, 2) 2646 2647 self.assertRegex(functions[1].definition.signature.name, '.*add.*') 2648 attrs = functions[1].definition.attr 2649 self.assertLen(attrs, 2) 2650 self.assertEqual(attrs['experimental_3'].b, True) 2651 self.assertEqual(attrs['experimental_4'].f, 1.0) 2652 # pylint: enable=protected-access 2653 2654 def testFunctionWithInvalidAttribute(self): 2655 @function.defun_with_attributes(attributes={'experimental_1': ['value1']}) 2656 def add(x, y): 2657 return math_ops.add(x, y) 2658 2659 with self.assertRaisesRegex(ValueError, 2660 'Attribute experimental_1 must be .* Got .*'): 2661 with context.graph_mode(), self.cached_session(): 2662 with ops.get_default_graph().as_default(): 2663 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 2664 add(t, t) 2665 2666 def testRegisterFunction(self): 2667 2668 @function.defun 2669 def add(x, y): 2670 return math_ops.add(x, y) 2671 2672 def matmul(x, y): 2673 return math_ops.matmul(x, y) 2674 defun_matmul = function.defun(matmul) 2675 2676 with context.graph_mode(), self.cached_session(): 2677 with ops.get_default_graph().as_default(): 2678 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 2679 function.register(defun_matmul, t, t) 2680 function.register(add, t, t) 2681 2682 graph = ops.get_default_graph() 2683 # pylint: disable=protected-access 2684 self.assertLen(graph._functions, 6) 2685 # two sets of functions, each of them are (inference, forward, backward) 2686 functions = list(graph._functions.values()) 2687 captured_function_names = [ 2688 f.definition.signature.name for f in functions 2689 ] 2690 expected_func_name_regex = [ 2691 '.*inference.*matmul.*', 2692 '.*forward.*matmul.*', 2693 '.*inference.*backward.*matmul.*', 2694 '.*inference.*add.*', 2695 '.*forward.*add.*', 2696 '.*inference.*backward.*add.*', 2697 ] 2698 for i in range(len(functions)): 2699 self.assertRegex(captured_function_names[i], 2700 expected_func_name_regex[i]) 2701 2702 # Check the forward and backward function has the correct attributes. 2703 self.assertEqual( 2704 functions[1].definition.attr['backward_function_name'].s, 2705 functions[2].name) 2706 self.assertEqual( 2707 functions[2].definition.attr['forward_function_name'].s, 2708 functions[1].name) 2709 2710 self.assertEqual( 2711 functions[4].definition.attr['backward_function_name'].s, 2712 functions[5].name) 2713 self.assertEqual( 2714 functions[5].definition.attr['forward_function_name'].s, 2715 functions[4].name) 2716 2717 sq = defun_matmul(t, t) 2718 double = add(t, t) 2719 self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22]) 2720 self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8]) 2721 # Make sure the pre registered function is used, and no other function 2722 # is added. 2723 self.assertLen(graph._functions, 6) 2724 functions = list(graph._functions.values()) 2725 for i in range(len(functions)): 2726 self.assertEqual(captured_function_names[i], 2727 functions[i].definition.signature.name) 2728 2729 @parameterized.named_parameters( 2730 dict(testcase_name='Defun', 2731 function_decorator=function.defun), 2732 dict(testcase_name='DefFunction', 2733 function_decorator=def_function.function)) 2734 def testRegisterConcreteFunction(self, function_decorator): 2735 @function_decorator 2736 def py_add(x, y): 2737 return math_ops.add(x, y) 2738 2739 py_add(array_ops.ones([]), array_ops.ones([])) 2740 add = py_add.get_concrete_function( 2741 tensor_spec.TensorSpec(None, dtypes.float32), 2742 tensor_spec.TensorSpec(None, dtypes.float32)) 2743 2744 @function_decorator 2745 def py_composite(x, y): 2746 return x, add(x, y) 2747 2748 py_composite(array_ops.ones([]), array_ops.ones([])) 2749 composite = py_composite.get_concrete_function( 2750 tensor_spec.TensorSpec(None, dtypes.float32), 2751 tensor_spec.TensorSpec(None, dtypes.float32)) 2752 2753 with context.graph_mode(), self.cached_session(): 2754 with ops.get_default_graph().as_default(): 2755 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 2756 composite.add_to_graph() 2757 composite.add_gradient_functions_to_graph() 2758 2759 graph = ops.get_default_graph() 2760 # pylint: disable=protected-access 2761 self.assertLen(graph._functions, 6) 2762 # two sets of functions, each of them are (inference, forward, backward) 2763 functions = list(graph._functions.values()) 2764 captured_function_names = [ 2765 f.definition.signature.name for f in functions 2766 ] 2767 expected_func_name_regex = [ 2768 '.*inference.*py_composite.*', 2769 '.*inference.*py_add.*', 2770 '.*forward.*py_composite.*', 2771 '.*forward.*py_add.*', 2772 '.*inference.*backward.*py_composite.*', 2773 '.*inference.*backward.*py_add.*', 2774 ] 2775 for expected, found in zip( 2776 expected_func_name_regex, 2777 captured_function_names): 2778 self.assertRegex(found, expected) 2779 2780 composite_t, composite_double = composite(t, t) 2781 double = add(t, t) 2782 self.assertAllEqual([[2, 4], [6, 8]], self.evaluate(double)) 2783 self.assertAllEqual([[2, 4], [6, 8]], self.evaluate(composite_double)) 2784 self.assertAllEqual([[1, 2], [3, 4]], self.evaluate(composite_t)) 2785 # Make sure the pre registered function is used, and no other function 2786 # is added. 2787 self.assertLen(graph._functions, 6) 2788 2789 @parameterized.named_parameters( 2790 dict(testcase_name='Defun', 2791 function_decorator=function.defun), 2792 dict(testcase_name='DefFunction', 2793 function_decorator=def_function.function)) 2794 def testEagerCaptures(self, function_decorator): 2795 with context.eager_mode(): 2796 large_tensor = array_ops.ones(shape=(256,)) 2797 self.assertGreater(256, func_graph._EAGER_CONST_THRESHOLD) 2798 2799 small_tensor = array_ops.ones(shape=(4,)) 2800 self.assertLessEqual(4, func_graph._EAGER_CONST_THRESHOLD) 2801 2802 v = resource_variable_ops.ResourceVariable(0.0) 2803 2804 for captured, op_type in [(large_tensor, 'Placeholder'), 2805 (small_tensor, 'Const'), (v, 'Placeholder')]: 2806 @function_decorator 2807 def test_fn(): 2808 return captured + 1 # pylint: disable=cell-var-from-loop 2809 2810 g = test_fn.get_concrete_function().graph 2811 internal_captures = g.internal_captures 2812 self.assertLen(internal_captures, 1) 2813 self.assertEqual(internal_captures[0].op.type, op_type) 2814 2815 def testRegisterFunctionWithInputSignature(self): 2816 def matmul(x, y): 2817 return math_ops.matmul(x, y) 2818 defun_matmul = function.defun( 2819 matmul, 2820 input_signature=[ 2821 tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32), 2822 tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32) 2823 ]) 2824 with context.graph_mode(), self.cached_session(): 2825 with ops.get_default_graph().as_default(): 2826 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 2827 function.register(defun_matmul, t, t) 2828 2829 graph = ops.get_default_graph() 2830 # pylint: disable=protected-access 2831 self.assertLen(graph._functions, 3) 2832 2833 # Test register function with cache, note inputs are ignored. 2834 function.register(defun_matmul) 2835 graph = ops.get_default_graph() 2836 self.assertLen(graph._functions, 3) 2837 2838 def testRegisterFunctionWithCache(self): 2839 def matmul(x, y): 2840 return math_ops.matmul(x, y) 2841 defun_matmul = function.defun(matmul) 2842 2843 with context.graph_mode(), self.cached_session(): 2844 with ops.get_default_graph().as_default(): 2845 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 2846 t2 = constant_op.constant([[2.0, 3.0], [4.0, 5.0]]) 2847 function.register(defun_matmul, t, t) 2848 function.register(defun_matmul, t2, t2) 2849 2850 graph = ops.get_default_graph() 2851 # Only one function is registered since the input param are in same type 2852 # pylint: disable=protected-access 2853 self.assertLen(graph._functions, 3) 2854 2855 def testCallingFunctionWithDifferentVariables(self): 2856 2857 @function.defun 2858 def foo(v): 2859 v.assign_add(1.0) 2860 return v.read_value() 2861 2862 v = resource_variable_ops.ResourceVariable(0.0) 2863 graph_function = foo.get_concrete_function(v) 2864 self.assertLen(graph_function.inputs, 1) 2865 self.assertEmpty(graph_function.captured_inputs) 2866 2867 self.assertEqual(float(graph_function(v)), 1.0) 2868 self.assertEqual(float(graph_function(v)), 2.0) 2869 2870 w = resource_variable_ops.ResourceVariable(0.0) 2871 2872 @function.defun 2873 def bar(v): 2874 del v 2875 return constant_op.constant(1.0) 2876 2877 graph_function = bar.get_concrete_function(v) 2878 self.assertEqual(float(graph_function(v)), 1.0) 2879 self.assertEqual(float(graph_function(w)), 1.0) 2880 2881 def testCallingFunctionWithNonTensorsFails(self): 2882 2883 @function.defun 2884 def foo(x): 2885 return x 2886 2887 graph_function = foo.get_concrete_function(constant_op.constant(1.0)) 2888 with self.assertRaises((TypeError, ValueError)): 2889 graph_function('Not a Tensor.') 2890 2891 def testSwapImplementationWithGrapplerPlugin(self): 2892 # Set the min_graph_nodes to -1 since the graph in this test is too small, 2893 # and will be ignored by grappler if don't set this. 2894 rewrites = rewriter_config_pb2.RewriterConfig() 2895 rewrites.implementation_selector = rewriter_config_pb2.RewriterConfig.ON 2896 rewrites.min_graph_nodes = -1 2897 graph_options = config_pb2.GraphOptions( 2898 rewrite_options=rewrites, build_cost_model=1) 2899 config_proto = config_pb2.ConfigProto(graph_options=graph_options) 2900 2901 with context.graph_mode(), self.cached_session( 2902 config=config_proto, graph=ops.Graph(), use_gpu=True): 2903 2904 @function.defun_with_attributes( 2905 attributes={ 2906 'api_implements': 'random_boost', 2907 'api_preferred_device': 'CPU' 2908 }) 2909 def cpu_boost(x): 2910 return math_ops.add(x, 2.0) 2911 2912 @function.defun_with_attributes( 2913 attributes={ 2914 'api_implements': 'random_boost', 2915 'api_preferred_device': 'GPU' 2916 }) 2917 def gpu_boost(x): 2918 return math_ops.add(x, 4.0) 2919 2920 x = constant_op.constant(1.0) 2921 2922 function.register(cpu_boost, x) 2923 y = gpu_boost(x) 2924 y_value = self.evaluate(y) 2925 2926 if test.is_gpu_available(): 2927 self.assertEqual(y_value, 5.0) 2928 else: 2929 # Grappler fallback to use the CPU impl even called with GPU function. 2930 self.assertEqual(y_value, 3.0) 2931 2932 @test_util.disable_tfrt('b/174712583: TFRT doesn\'t support behavior ' 2933 'equivalent to implementation_selector for function') 2934 def testSwapImplementationInEager(self): 2935 if not context.executing_eagerly(): 2936 self.skipTest('eager only') 2937 2938 # testSharedRendezvous sets the disable_meta_optimizer flag to True 2939 # if that subtest runs before this one, then having that set to True 2940 # will cause this subtest to fail. To avoid that scenario, explicitly 2941 # set the disable_meta_optimizer flag to false here 2942 context.context().set_optimizer_experimental_options({ 2943 'min_graph_nodes': -1, 2944 'implementation_selector': True, 2945 'disable_meta_optimizer': False 2946 }) 2947 2948 @function.defun_with_attributes( 2949 attributes={'api_implements': 'foo', 2950 'api_preferred_device': 'CPU'}) 2951 def on_cpu(x): 2952 return x + 2 2953 2954 @function.defun_with_attributes( 2955 attributes={'api_implements': 'foo', 2956 'api_preferred_device': 'GPU'}) 2957 def on_gpu(x): 2958 return x + 4 2959 2960 @function.defun 2961 def run_on_cpu(t): 2962 function.register(on_cpu, t) 2963 with ops.device('CPU:0'): 2964 return on_gpu(t) 2965 2966 # Expect to run the on_cpu branch, regardless whether gpu is available. 2967 self.assertEqual(run_on_cpu(constant_op.constant(1)).numpy(), 3) 2968 2969 def testDefunFunctionSeparateGraphs(self): 2970 with context.graph_mode(): 2971 2972 @function.defun 2973 def add(x): 2974 return x + 5 2975 2976 @function.defun 2977 def maybe_add(x, should_add): 2978 if should_add: 2979 return add(x) 2980 else: 2981 return x 2982 2983 with ops.Graph().as_default(): 2984 x = constant_op.constant(11) 2985 maybe_add(x, True) 2986 self.assertLen(total_function_cache(maybe_add), 1) 2987 self.assertLen(total_function_cache(add), 1) 2988 2989 maybe_add(x, False) 2990 self.assertLen(total_function_cache(maybe_add), 2) 2991 self.assertLen(total_function_cache(add), 1) 2992 2993 with ops.Graph().as_default(): 2994 x = constant_op.constant(11) 2995 maybe_add(x, True) 2996 self.assertLen(total_function_cache(maybe_add), 3) 2997 self.assertLen(total_function_cache(add), 2) 2998 2999 def testCacheKeyOverlappingShapes(self): 3000 @function.defun 3001 def defined(t): 3002 return t 3003 3004 defined(array_ops.zeros([12, 1])) 3005 self.assertLen(total_function_cache(defined), 1) 3006 3007 defined(array_ops.zeros([1, 21])) 3008 self.assertLen(total_function_cache(defined), 2) 3009 3010 def testCacheKeyNestedLists(self): 3011 @function.defun 3012 def defined(l): 3013 return l 3014 3015 a = constant_op.constant(1.) 3016 b = constant_op.constant(2.) 3017 c = constant_op.constant(3.) 3018 defined([[a], b, c]) 3019 self.assertLen(total_function_cache(defined), 1) 3020 3021 defined([[a, b], c]) 3022 self.assertLen(total_function_cache(defined), 2) 3023 3024 def testCacheKeyAttrsClass(self): 3025 if attr is None: 3026 self.skipTest('attr module is unavailable.') 3027 3028 @attr.s 3029 class TestClass(object): 3030 a = attr.ib() 3031 b = attr.ib() 3032 3033 @function.defun 3034 def defined(l): 3035 return l 3036 3037 defined( 3038 TestClass( 3039 constant_op.constant(1.), 3040 [constant_op.constant(2.), 3041 constant_op.constant(3.)])) 3042 self.assertLen(total_function_cache(defined), 1) 3043 defined( 3044 TestClass( 3045 constant_op.constant(1.), 3046 [constant_op.constant(2.), 3047 constant_op.constant(3.)])) 3048 self.assertLen(total_function_cache(defined), 1) 3049 3050 defined( 3051 TestClass([constant_op.constant(1.), 3052 constant_op.constant(2.)], constant_op.constant(3.))) 3053 self.assertLen(total_function_cache(defined), 2) 3054 3055 def testCacheKeyVariables(self): 3056 @function.defun 3057 def defined(a, b, c): 3058 return a + b + c 3059 3060 x = resource_variable_ops.ResourceVariable(0.0) 3061 y = resource_variable_ops.ResourceVariable(0.0) 3062 z = resource_variable_ops.ResourceVariable(0.0) 3063 3064 # If tensor equality is not enabled, we always get a cache miss if the 3065 # function is called with different variables. With equality enabled we 3066 # should only get a miss if the aliasing changed. 3067 defined(x, y, z) 3068 self.assertLen(total_function_cache(defined), 1) 3069 defined(x, y, z) 3070 self.assertLen(total_function_cache(defined), 1) 3071 3072 # Re-arranging arguments causes cache miss 3073 defined(z, y, x) 3074 self.assertLen(total_function_cache(defined), 2) 3075 defined(z, y, x) 3076 self.assertLen(total_function_cache(defined), 2) 3077 3078 # Aliasing causes cache miss 3079 defined(x, x, z) 3080 self.assertLen(total_function_cache(defined), 3) 3081 defined(x, x, z) 3082 self.assertLen(total_function_cache(defined), 3) 3083 3084 # Re-arranging arguments causes cache miss 3085 defined(y, y, z) 3086 self.assertLen(total_function_cache(defined), 4) 3087 defined(y, y, z) 3088 self.assertLen(total_function_cache(defined), 4) 3089 3090 # Different alias positions causes cache miss 3091 defined(z, y, y) 3092 self.assertLen(total_function_cache(defined), 5) 3093 defined(z, y, y) 3094 self.assertLen(total_function_cache(defined), 5) 3095 3096 x_copy = copy.deepcopy(x) 3097 3098 # Deep copy causes cache miss 3099 defined(x_copy, y, z) 3100 self.assertLen(total_function_cache(defined), 6) 3101 defined(x_copy, y, z) 3102 self.assertLen(total_function_cache(defined), 6) 3103 3104 def testVariableRetracing(self): 3105 v1 = variables.Variable(1.) 3106 v2 = variables.Variable(1.) 3107 v3 = copy.deepcopy(variables.Variable(1.)) 3108 3109 var_dict = {id(v1): constant_op.constant(1), 3110 id(v2): constant_op.constant(2), 3111 id(v3): constant_op.constant(3)} 3112 3113 @function.defun 3114 def lookup_tensor(v): 3115 return var_dict[id(v)] 3116 3117 self.assertEqual(1, lookup_tensor(v1).numpy()) 3118 self.assertEqual(2, lookup_tensor(v2).numpy()) 3119 self.assertEqual(3, lookup_tensor(v3).numpy()) 3120 3121 def testDecoratedMethodInspect(self): 3122 3123 class DefunnedMiniModel(object): 3124 3125 @function.defun 3126 def call(self, inputs, training=True): 3127 pass 3128 3129 m = DefunnedMiniModel() 3130 fullargspec = tf_inspect.getfullargspec(m.call) 3131 self.assertIn('training', fullargspec.args) 3132 3133 def testFunctionModifiesInputList(self): 3134 # Tests on `list` methods that do in place modification, except `list.sort` 3135 # since it cannot even be "defunned" in the first place 3136 3137 def get_list(): 3138 return [constant_op.constant(0.), constant_op.constant(1.)] 3139 3140 expected_msg = '.*() should not modify' 3141 3142 with self.assertRaisesRegex(ValueError, expected_msg): 3143 3144 @def_function.function 3145 def append(l): 3146 l.append(constant_op.constant(0.)) 3147 3148 append(get_list()) 3149 3150 with self.assertRaisesRegex(ValueError, expected_msg): 3151 3152 @def_function.function 3153 def extend(l): 3154 l.extend([constant_op.constant(0.)]) 3155 3156 extend(get_list()) 3157 3158 with self.assertRaisesRegex(ValueError, expected_msg): 3159 3160 @def_function.function 3161 def insert(l): 3162 l.insert(0, constant_op.constant(0.)) 3163 3164 insert(get_list()) 3165 3166 with self.assertRaisesRegex(ValueError, expected_msg): 3167 3168 @def_function.function 3169 def pop(l): 3170 l.pop() 3171 3172 pop(get_list()) 3173 3174 with self.assertRaisesRegex(ValueError, expected_msg): 3175 3176 @def_function.function 3177 def reverse(l): 3178 l.reverse() 3179 3180 reverse(get_list()) 3181 3182 with self.assertRaisesRegex(ValueError, expected_msg): 3183 3184 @def_function.function 3185 def remove(l): 3186 l.remove(l[0]) 3187 3188 remove(get_list()) 3189 3190 # `list.clear` is a method that is in Py3 but not Py2 3191 if sys.version.startswith('3'): 3192 3193 with self.assertRaisesRegex(ValueError, expected_msg): 3194 3195 @def_function.function 3196 def clear(l): 3197 l.clear() 3198 3199 clear(get_list()) 3200 3201 # One last test for keyword arguments 3202 with self.assertRaisesRegex(ValueError, expected_msg): 3203 3204 @def_function.function 3205 def kwdappend(**kwargs): 3206 l = kwargs['l'] 3207 l.append(constant_op.constant(0.)) 3208 3209 kwdappend(l=get_list()) 3210 3211 def testFunctionModifiesInputDict(self): 3212 3213 def get_dict(): 3214 return {'t1': constant_op.constant(0.), 't2': constant_op.constant(1.)} 3215 3216 expected_msg = '.* should not modify' 3217 3218 with self.assertRaisesRegex(ValueError, expected_msg): 3219 3220 @def_function.function 3221 def clear(m): 3222 m.clear() 3223 3224 clear(get_dict()) 3225 3226 with self.assertRaisesRegex(ValueError, expected_msg): 3227 3228 @def_function.function 3229 def pop(m): 3230 m.pop('t1') 3231 3232 pop(get_dict()) 3233 3234 with self.assertRaisesRegex(ValueError, expected_msg): 3235 3236 @def_function.function 3237 def popitem(m): 3238 m.popitem() 3239 3240 popitem(get_dict()) 3241 3242 with self.assertRaisesRegex(ValueError, expected_msg): 3243 3244 @def_function.function 3245 def update(m): 3246 m.update({'t1': constant_op.constant(3.)}) 3247 3248 update(get_dict()) 3249 3250 with self.assertRaisesRegex(ValueError, expected_msg): 3251 3252 @def_function.function 3253 def setdefault(m): 3254 m.setdefault('t3', constant_op.constant(3.)) 3255 3256 setdefault(get_dict()) 3257 3258 def testFunctionModifiesInputNest(self): 3259 with self.assertRaisesRegex(ValueError, 'modify.* should not modify'): 3260 3261 @def_function.function 3262 def modify(n): 3263 n[0]['t1'].append(constant_op.constant(1.)) 3264 3265 nested_input = [{ 3266 't1': [constant_op.constant(0.), 3267 constant_op.constant(1.)], 3268 }, 3269 constant_op.constant(2.)] 3270 3271 modify(nested_input) 3272 3273 with self.assertRaisesRegex(ValueError, 3274 'modify_same_flat.* should not modify'): 3275 3276 # The flat list doesn't change whereas the true structure changes 3277 @def_function.function 3278 def modify_same_flat(n): 3279 n[0].append(n[1].pop(0)) 3280 3281 nested_input = [[constant_op.constant(0.)], 3282 [constant_op.constant(1.), 3283 constant_op.constant(2.)]] 3284 3285 modify_same_flat(nested_input) 3286 3287 @test_util.disable_tfrt('b/173429686') 3288 def testExecutorType(self): 3289 @function.defun 3290 def add_five(x): 3291 return x + 5 3292 3293 self.assertEqual( 3294 5, 3295 add_five(constant_op.constant(0, dtype=dtypes.int32)).numpy()) 3296 3297 with self.assertRaisesRegex(errors.NotFoundError, 'NON_EXISTENT_EXECUTOR'): 3298 with context.function_executor_type('NON_EXISTENT_EXECUTOR'): 3299 add_five(constant_op.constant(0, dtype=dtypes.int32)) 3300 3301 for executor_type in ('', 'DEFAULT', None): 3302 with context.function_executor_type(executor_type): 3303 self.assertAllEqual( 3304 5, 3305 add_five(constant_op.constant(0, dtype=dtypes.int32)).numpy()) 3306 3307 @test_util.assert_no_garbage_created 3308 def testReferenceCycles(self): 3309 3310 fn = function.defun(lambda x: 2. * x) 3311 3312 fn(constant_op.constant(4.0)) 3313 weak_fn = weakref.ref(fn) 3314 del fn 3315 # Tests that the weak reference we made to the function is now dead, which 3316 # means the object has been deleted. This should be true as long as the 3317 # function itself is not involved in a reference cycle. 3318 self.assertIs(None, weak_fn()) 3319 3320 def testFunctionStackInErrorMessage(self): 3321 if context.executing_eagerly(): 3322 # TODO(b/122736651): Remove this skipTest once fixed. 3323 self.skipTest('Error interpolation is not working when function is ' 3324 'invoked without PartitionedCallOp.') 3325 3326 @def_function.function() 3327 def fn3(x): 3328 return x + 2 3329 3330 @def_function.function() 3331 def fn2(x): 3332 check_ops.assert_equal(fn3(x), 3) 3333 return 2 3334 3335 @def_function.function() 3336 def fn(x): 3337 return fn2(x) 3338 3339 with self.assertRaises(errors.InvalidArgumentError) as cm: 3340 fn(2) 3341 e = cm.exception 3342 self.assertIn('fn -> fn2', e.message) 3343 self.assertIn('node assert_equal/Assert/Assert (defined at', e.message) 3344 self.assertNotIn('fn3', e.message) 3345 3346 @test_util.run_gpu_only 3347 def testFunctionIsNotPinned(self): 3348 """Tests that functions aren't pinned to the CPU by the eager runtime.""" 3349 seed1, seed2 = 79, 25 3350 shape = constant_op.constant([4, 7]) 3351 dtype = dtypes.float32 3352 3353 @def_function.function 3354 def func(): 3355 with ops.device('GPU:0'): 3356 return gen_random_ops.random_standard_normal( 3357 shape, dtype=dtype, seed=seed1, seed2=seed2) 3358 3359 with ops.device('GPU:0'): 3360 x = func() 3361 self.assertRegex(x.device, 'GPU') 3362 3363 @test_util.run_in_graph_and_eager_modes 3364 def testShapeCaching(self): 3365 3366 @function.defun 3367 def func(x): 3368 return array_ops.shape(x) 3369 3370 @function.defun( 3371 input_signature=[tensor_spec.TensorSpec([None, None], dtypes.float32)]) 3372 def calls_func(x): 3373 return func(x) 3374 3375 self.assertAllEqual([1, 1], self.evaluate(func(array_ops.zeros([1, 1])))) 3376 self.assertAllEqual([2, 2], self.evaluate(func(array_ops.zeros([2, 2])))) 3377 self.assertAllEqual( 3378 [3, 3], 3379 self.evaluate(calls_func(array_ops.zeros([3, 3])))) 3380 3381 def testLimitedRetracing(self): 3382 trace_count = [0] 3383 @function.defun 3384 def func(x): 3385 trace_count[0] += 1 3386 return x 3387 3388 for _ in range(50): 3389 func(constant_op.constant(3.)) 3390 func(constant_op.constant(4.)) 3391 func(constant_op.constant([[1., 2.]])) 3392 func(constant_op.constant([[]])) 3393 func(constant_op.constant([[3., 4.], [5., 6.]])) 3394 func(constant_op.constant([[3., 4.], [5., 6.], [7., 8.]])) 3395 # Tracing more than twice per input doesn't make sense. 3396 self.assertLess(trace_count[0], 13) 3397 3398 def testLimitedRetracingWithCompositeTensors(self): 3399 trace_count = [0] 3400 3401 @def_function.function 3402 def f(x): 3403 trace_count[0] += 1 3404 return x 3405 3406 for i in range(10): 3407 f(ragged_factory_ops.constant([[1, 2], [i]])) 3408 f(ragged_factory_ops.constant([[1, 2], [], [3, 4, 5]])) 3409 f(ragged_factory_ops.constant([[[1, 2], [3]], [[4, 5, 6]]])) 3410 self.assertEqual(trace_count[0], 3) 3411 3412 def test_concrete_function_shape_mismatch(self): 3413 3414 @def_function.function 3415 def f(argument_name): 3416 return argument_name + 1. 3417 3418 f_concrete = f.get_concrete_function(constant_op.constant([1.])) 3419 3420 # Calling a function from eager doesn't do any shape checking above what 3421 # kernels do while executing. 3422 self.assertAllEqual( 3423 [2., 3.], 3424 f_concrete(constant_op.constant([1., 2.])).numpy()) 3425 3426 @def_function.function 3427 def g(): 3428 f_concrete(constant_op.constant([1., 2.])) 3429 3430 with self.assertRaisesRegex(ValueError, 'argument_name'): 3431 g() 3432 3433 @test_util.run_in_graph_and_eager_modes 3434 def test_shape_inference_with_symbolic_shapes(self): 3435 3436 @def_function.function 3437 def _uses_symbolic_shapes(w, x, y): 3438 x = array_ops.identity(x, name='name_collision') 3439 x = array_ops.transpose(x, [1, 0, 2]) 3440 x_batch = array_ops.shape(x)[0] 3441 y_batch = array_ops.shape(y)[0] 3442 y *= w 3443 n = y_batch // x_batch 3444 return array_ops.reshape(y, [n, x_batch, -1]) 3445 3446 conc = _uses_symbolic_shapes.get_concrete_function( 3447 tensor_spec.TensorSpec(None, dtypes.float32), 3448 tensor_spec.TensorSpec(None, dtypes.float32), 3449 tensor_spec.TensorSpec(None, dtypes.float32)) 3450 3451 @def_function.function 3452 def _call_concrete(): 3453 c = constant_op.constant(1.) 3454 array_ops.identity(c, name='name_collision') 3455 output1 = conc(array_ops.ones([2]), 3456 array_ops.ones([5, 4, 2]), 3457 array_ops.ones([20, 2])) 3458 self.assertEqual([5, 4, 2], output1.shape) 3459 output2 = conc(array_ops.ones([3]), 3460 array_ops.ones([5, 4, 3]), 3461 array_ops.ones([40, 3])) 3462 self.assertEqual([10, 4, 3], output2.shape) 3463 return output1, output2 3464 3465 output1, output2 = _call_concrete() 3466 self.assertEqual((5, 4, 2), self.evaluate(output1).shape) 3467 self.assertEqual((10, 4, 3), self.evaluate(output2).shape) 3468 3469 def testAutoGraphContext(self): 3470 3471 @def_function.function 3472 def test_fn(): 3473 self.assertEqual( 3474 ag_ctx.control_status_ctx().status, ag_ctx.Status.ENABLED) 3475 3476 prev_status = ag_ctx.control_status_ctx().status 3477 test_fn() 3478 self.assertEqual(ag_ctx.control_status_ctx().status, prev_status) 3479 3480 @test_util.disable_tfrt('b/170435618') 3481 def testCancelBeforeFunctionExecution(self): 3482 if not context.executing_eagerly(): 3483 self.skipTest('eager only') 3484 3485 q = data_flow_ops.FIFOQueue(1, dtypes.int32) 3486 3487 @def_function.function 3488 def f(): 3489 return q.dequeue() 3490 3491 c_mgr = cancellation.CancellationManager() 3492 cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function()) 3493 3494 c_mgr.start_cancel() 3495 with self.assertRaises(errors.CancelledError): 3496 cancelable_func() 3497 3498 @test_util.disable_tfrt('b/170435618') 3499 def testCancelBlockedFunctionExecution(self): 3500 if not context.executing_eagerly(): 3501 self.skipTest('eager only') 3502 3503 q = data_flow_ops.FIFOQueue(1, dtypes.int32) 3504 3505 @def_function.function 3506 def f(): 3507 return q.dequeue() 3508 3509 c_mgr = cancellation.CancellationManager() 3510 cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function()) 3511 3512 def cancel_thread(): 3513 time.sleep(0.5) 3514 c_mgr.start_cancel() 3515 3516 t = self.checkedThread(cancel_thread) 3517 t.start() 3518 with self.assertRaises(errors.CancelledError): 3519 cancelable_func() 3520 t.join() 3521 3522 @test_util.disable_tfrt('b/170435618') 3523 def testCancelAfterFunctionExecution(self): 3524 if not context.executing_eagerly(): 3525 self.skipTest('eager only') 3526 3527 q = data_flow_ops.FIFOQueue(1, dtypes.int32) 3528 q.enqueue(37) 3529 3530 @def_function.function 3531 def f(): 3532 return q.dequeue() 3533 3534 c_mgr = cancellation.CancellationManager() 3535 cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function()) 3536 3537 self.assertAllEqual(37, cancelable_func().numpy()) 3538 3539 # Cancellation after the function executes is a no-op. 3540 c_mgr.start_cancel() 3541 3542 def testAddFunctionCallback(self): 3543 functions = [] 3544 def function_callback(f, name, graph, inputs, outputs): 3545 del name, graph, inputs, outputs 3546 functions.append(f) 3547 3548 @def_function.function 3549 def plus_one(x): 3550 return x + 1 3551 3552 try: 3553 function.add_function_callback(function_callback) 3554 x_float32 = numpy.array(3.0, dtype=numpy.float32) 3555 self.assertAllClose(plus_one(x_float32), 4.0) 3556 self.assertLen(functions, 1) 3557 # Function is already created. Executing it again should not invoke the 3558 # function callback. 3559 self.assertAllClose(plus_one(x_float32), 4.0) 3560 self.assertLen(functions, 1) 3561 # Signature change leads to a new Function being built. 3562 x_float64 = numpy.array(3.0, dtype=numpy.float64) 3563 self.assertAllClose(plus_one(x_float64), 4.0) 3564 self.assertLen(functions, 2) 3565 finally: 3566 function.clear_function_callbacks() 3567 3568 def testFunctionCallbackAddOps(self): 3569 file_name = os.path.join(self.get_temp_dir(), 'test') 3570 3571 def function_callback(f, name, graph, inputs, outputs): 3572 del f, name, inputs 3573 3574 with graph.as_default(): 3575 printer = logging_ops.print_v2( 3576 'hello', 3577 output_stream='file://' + file_name 3578 ) 3579 outputs[0].op._add_control_input(printer) 3580 3581 @def_function.function 3582 def plus_one(x): 3583 return x + 1 3584 3585 self.addCleanup(function.clear_function_callbacks) 3586 function.add_function_callback(function_callback) 3587 x_float32 = numpy.array(3.0, dtype=numpy.float32) 3588 3589 self.assertAllClose(plus_one(x_float32), 4.0) 3590 3591 with open(file_name, 'r') as f: 3592 self.assertEqual(f.read().strip(), 'hello') 3593 3594 def testRemoveFunctionCallback(self): 3595 functions_1 = [] 3596 def function_callback_1(f, name, graph, inputs, outputs): 3597 del name, graph, inputs, outputs 3598 functions_1.append(f) 3599 3600 functions_2 = [] 3601 def function_callback_2(f, name, graph, inputs, outputs): 3602 del name, graph, inputs, outputs 3603 functions_2.append(f) 3604 3605 @def_function.function 3606 def plus_one(x): 3607 return x + 1 3608 3609 try: 3610 function.add_function_callback(function_callback_1) 3611 function.add_function_callback(function_callback_2) 3612 self.assertAllClose(plus_one(numpy.array(3.0, dtype=numpy.float32)), 4.0) 3613 self.assertLen(functions_1, 1) 3614 self.assertLen(functions_2, 1) 3615 function.remove_function_callback(function_callback_1) 3616 # The 1st callback should not be invokved after remove_function_callback() 3617 # is called. 3618 self.assertAllClose(plus_one(numpy.array(3.0, dtype=numpy.float64)), 4.0) 3619 self.assertLen(functions_1, 1) 3620 self.assertLen(functions_2, 2) 3621 finally: 3622 function.clear_function_callbacks() 3623 3624 def testClearFunctionCallbacks(self): 3625 function.add_function_callback(lambda f: None) 3626 function.add_function_callback(lambda f: None) 3627 self.assertLen(function._function_callbacks, 2) 3628 function.clear_function_callbacks() 3629 self.assertEmpty(function._function_callbacks) # pylint:disable=protected-access 3630 3631 @test_util.run_in_graph_and_eager_modes 3632 def testConcreteFunctionWithNestedTensorInputs(self): 3633 3634 @def_function.function 3635 def f(x, y): 3636 return (x['a'] + x['b'], y[0] + y[1]) 3637 3638 a = constant_op.constant(1000) 3639 b = constant_op.constant(200) 3640 c = constant_op.constant(30) 3641 d = {'a': a, 'b': b} 3642 e = (c, 4) 3643 3644 # Test different argument signatures when constructing the concrete func. 3645 for cf in [ 3646 f.get_concrete_function(d, e), 3647 f.get_concrete_function(d, y=e), 3648 f.get_concrete_function(y=e, x=d), 3649 f.get_concrete_function(_spec_for_value(d), _spec_for_value(e)), 3650 f.get_concrete_function(_spec_for_value(d), y=_spec_for_value(e)), 3651 f.get_concrete_function(y=_spec_for_value(e), x=_spec_for_value(d)) 3652 ]: 3653 # Test different calling conventions when calling the concrete func. 3654 for output in [ 3655 cf(d, e), # structured signature 3656 cf(d, y=e), # structured signature w/ kwarg 3657 cf(y=e, x=d), # structured signature w/ 2 kwargs 3658 cf(a, b, c), # flat signature 3659 cf(x=a, x_1=b, y=c) # flat signature w/ kwargs 3660 ]: 3661 self.assertIsInstance(output, tuple) 3662 self.assertLen(output, 2) 3663 self.assertAllEqual(output[0], 1200) 3664 self.assertAllEqual(output[1], 34) 3665 3666 @test_util.run_in_graph_and_eager_modes 3667 def testConcreteFunctionWithNestedNonTensorInputs(self): 3668 3669 @def_function.function 3670 def f(x, y): 3671 return (x['a'] + x['b'], y[0] + y[1]) 3672 3673 a = {'a': constant_op.constant(1000), 'b': constant_op.constant(200)} 3674 b = (50, 3) 3675 3676 for cf in [ # argument y is bound to non-Tensor value (50, 3). 3677 f.get_concrete_function(a, b), 3678 f.get_concrete_function(a, y=b), 3679 f.get_concrete_function(x=a, y=b) 3680 ]: 3681 for output in [cf(a), cf(x=a), cf(a, b), cf(x=a, y=b)]: 3682 self.assertAllEqual(output[0] + output[1], 1253) 3683 3684 @test_util.run_in_graph_and_eager_modes 3685 def testConcreteFunctionWithNonTensorStringInputs(self): 3686 3687 @def_function.function 3688 def f(x, y): 3689 return string_ops.string_join([x, y]) 3690 3691 a = constant_op.constant('a') 3692 b = 'b' 3693 3694 cf = f.get_concrete_function(a, b) 3695 for output in [cf(a), cf(x=a), cf(a, b), cf(x=a, y=b)]: 3696 self.assertAllEqual(output, b'ab') 3697 3698 @test_util.run_in_graph_and_eager_modes 3699 def testConcreteFunctionWithBoundNestedNonTensorInputs(self): 3700 3701 @def_function.function 3702 def f(x, y): 3703 return (x['a'] + x['b'], y[0] + y[1]) 3704 3705 a = {'a': 3000, 'b': 200, 'c': 9000} 3706 b = (constant_op.constant(30), 4) 3707 3708 for cf in [ # argument x is bound to non-tensor value `a` 3709 f.get_concrete_function(a, b), 3710 f.get_concrete_function(a, y=b), 3711 f.get_concrete_function(x=a, y=b) 3712 ]: 3713 for output in [cf(a, b), cf(a, y=b), cf(y=b), cf(x=a, y=b)]: 3714 self.assertAllEqual(output[0] + output[1], 3234) 3715 3716 @test_util.run_in_graph_and_eager_modes 3717 def testConcreteFunctionWithAllBoundNestedNonTensorInputs(self): 3718 3719 @def_function.function 3720 def f(x, y): 3721 return (x['a'] + x['b'], y[0] + y[1]) 3722 3723 a = {'a': 5000, 'b': 500} 3724 b = (50, 5) 3725 3726 cf = f.get_concrete_function(a, b) 3727 for output in [cf(), cf(a), cf(y=b)]: 3728 self.assertAllEqual(output[0] + output[1], 5555) 3729 3730 @test_util.run_in_graph_and_eager_modes 3731 def testConcreteFunctionMethodWithVarargs(self): 3732 float32_scalar = tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32) 3733 3734 class MyModel(module.Module): 3735 3736 @def_function.function(input_signature=[float32_scalar, float32_scalar]) 3737 def add(self, *arg): 3738 return math_ops.add(*arg) 3739 3740 m = MyModel() 3741 cf = m.add.get_concrete_function() 3742 cf(-12.0, 3.0) 3743 3744 @test_util.run_in_graph_and_eager_modes 3745 def testConcreteFunctionStructuredSignatureKeywordOrder(self): 3746 # Check that keyword-only arguments are sorted appropriately, so that they 3747 # feed the right tensor into each input. 3748 @def_function.function 3749 def g(**kwargs): 3750 return string_ops.reduce_join( 3751 string_ops.reduce_join( 3752 ops.convert_to_tensor(sorted(kwargs.items())), 3753 axis=1, 3754 separator='='), 3755 axis=0, 3756 separator=', ') 3757 3758 s = constant_op.constant('s') 3759 g.get_concrete_function(q=s, a=s, p=s, r=s, v=s, m=s, l=s) 3760 self.assertAllEqual( 3761 g(m='a', r='b', v='c', q='d', l='e', a='f', p='g'), 3762 b'a=f, l=e, m=a, p=g, q=d, r=b, v=c') 3763 self.assertAllEqual( 3764 g(q='d', a='f', p='g', r='b', v='c', m='a', l='e'), 3765 b'a=f, l=e, m=a, p=g, q=d, r=b, v=c') 3766 self.assertAllEqual( 3767 g(a='f', l='e', m='a', p='g', q='d', r='b', v='c'), 3768 b'a=f, l=e, m=a, p=g, q=d, r=b, v=c') 3769 3770 # pylint: disable=g-long-lambda 3771 @parameterized.named_parameters([ 3772 dict( 3773 testcase_name='MissingArg', 3774 conc_args=lambda: (1, constant_op.constant(2)), 3775 call_args=lambda: (1,), 3776 error=r'func\(x, y\) missing required arguments: y'), 3777 dict( 3778 testcase_name='MissingVararg', 3779 conc_args=lambda: (1, 2, constant_op.constant(1.0)), 3780 call_args=lambda: (1, 2), 3781 error=r'func\(x, y, <arg3>\) missing required arguments: <arg3>'), 3782 dict( 3783 testcase_name='ExtraPositionalArg', 3784 conc_args=lambda: (1, 2), 3785 call_args=lambda: (1, 2, 3), 3786 error=r'func\(x, y\) takes 2 .* got 3'), 3787 dict( 3788 testcase_name='MissingKeywordOnlyArg', 3789 conc_args=lambda: (1, 2), 3790 conc_kwargs=lambda: {'c': constant_op.constant(1.0)}, 3791 call_args=lambda: (1, 2), 3792 error=r'func\(x, y, \*, c\) missing required arguments: c'), 3793 dict( 3794 testcase_name='ExtraKeywordArg', 3795 conc_args=lambda: (1, 2), 3796 call_args=lambda: (1, 2), 3797 call_kwargs=lambda: {'c': constant_op.constant(1.0)}, 3798 error=r'func\(x, y\) got unexpected keyword arguments: c'), 3799 dict( 3800 testcase_name='ExpectedRaggedGotNest', 3801 conc_args=lambda: (ragged_factory_ops.constant([[1, 2], [3]]),), 3802 call_args=lambda: ({ 3803 'a': constant_op.constant([1, 2, 3]) 3804 },), 3805 error=r'func\(x, y\): argument x had incorrect type\n' 3806 r' expected: RaggedTensor\n' 3807 r" got: {'a': (Eager)?Tensor}"), 3808 dict( 3809 testcase_name='WrongRaggedRank', 3810 conc_args=lambda: (ragged_factory_ops.constant([[1, 2], [3]]),), 3811 call_args=lambda: (ragged_factory_ops.constant([[[1]]]),), 3812 error=r'func\(x, y\): argument x had incorrect type\n'), 3813 dict( 3814 testcase_name='WrongRaggedDType', 3815 conc_args=lambda: (ragged_factory_ops.constant([[1]]),), 3816 call_args=lambda: (ragged_factory_ops.constant([[1.0]]),), 3817 error=r'func\(x, y\): argument x had incorrect type\n'), 3818 dict( 3819 testcase_name='ExpectedDictGotTensor', 3820 conc_args=lambda: ({ 3821 'a': constant_op.constant(1), 3822 'b': constant_op.constant(1) 3823 },), 3824 call_args=lambda: (constant_op.constant(1),), 3825 error=r'func\(x, y\): argument x had incorrect type\n'), 3826 dict( 3827 testcase_name='ExpectedTupleGotTensor', 3828 conc_args=lambda: 3829 ((constant_op.constant(1), constant_op.constant(2)),), 3830 call_args=lambda: (constant_op.constant(1),), 3831 error=r'func\(x, y\): argument x had incorrect type\n'), 3832 dict( 3833 testcase_name='WrongDType', 3834 conc_args=lambda: (constant_op.constant(1),), 3835 call_args=lambda: (constant_op.constant(1.0),), 3836 exception=(ValueError, errors.InvalidArgumentError, 3837 # on xla_gpu, we get InternalError instead. 3838 errors.InternalError)), 3839 dict( 3840 testcase_name='ExpectedTensorGotInt', 3841 conc_args=lambda: (constant_op.constant(1),), 3842 call_args=lambda: (5,), 3843 error=r'func\(x, y\) expected a Tensor in x, but got int value 5'), 3844 dict( 3845 testcase_name='ExpectedIntGotDifferentInt', 3846 conc_args=lambda: (5,), 3847 call_args=lambda: (8,), 3848 error=r'ConcreteFunction func\(x, y\) was constructed with int ' 3849 r'value 5 in x, but was called with int value 8'), 3850 dict( 3851 testcase_name='ExpectedIntGotTensor', 3852 conc_args=lambda: (5,), 3853 call_args=lambda: (constant_op.constant(6),), 3854 error=r'ConcreteFunction func\(x, y\) was constructed with int ' 3855 'value 5 in x, but was called with (Eager)?Tensor value .*'), 3856 dict( 3857 testcase_name='TwoValuesForArgument', 3858 conc_args=lambda: (1, 2), 3859 call_args=lambda: (1, 2), 3860 call_kwargs=lambda: {'x': 3}, 3861 error=r"func\(x, y\) got two values for 'x'"), 3862 ]) 3863 # pylint: enable=g-long-lambda 3864 @test_util.run_in_graph_and_eager_modes 3865 def testConcreteFunctionStructuredSignatureError(self, 3866 conc_args=(), 3867 conc_kwargs=None, 3868 call_args=(), 3869 call_kwargs=None, 3870 error='.*', 3871 exception=TypeError): 3872 """Tests for errors in the structrued signature. 3873 3874 Args: 3875 conc_args: Positional arguments used for get_concrete_function. 3876 conc_kwargs: Keyword arguments used for get_concrete_function. 3877 call_args: Positional arguments used to call the function. 3878 call_kwargs: Keyword arguments used to call the function. 3879 error: Expected exception message. 3880 exception: Expected exception type. 3881 """ 3882 conc_args = conc_args() if callable(conc_args) else conc_args 3883 conc_kwargs = conc_kwargs() if callable(conc_kwargs) else conc_kwargs or {} 3884 call_args = call_args() if callable(call_args) else call_args 3885 call_kwargs = call_kwargs() if callable(call_kwargs) else call_kwargs or {} 3886 self.assertIsInstance(conc_args, tuple) 3887 self.assertIsInstance(call_args, tuple) 3888 self.assertIsInstance(conc_kwargs, dict) 3889 self.assertIsInstance(call_kwargs, dict) 3890 3891 @def_function.function 3892 def func(x, y=5, *varargs, **kwargs): # pylint: disable=keyword-arg-before-vararg 3893 del y, varargs, kwargs 3894 return x 3895 3896 conc = func.get_concrete_function(*conc_args, **conc_kwargs) 3897 with self.assertRaisesRegex(exception, error): 3898 self.evaluate(conc(*call_args, **call_kwargs)) 3899 3900 # pylint: disable=g-long-lambda 3901 @parameterized.named_parameters([ 3902 dict( 3903 testcase_name='MissingArg', 3904 conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)), 3905 call_args=lambda: (constant_op.constant(1),), 3906 error=r'func\(x, y\) missing required arguments: y'), 3907 dict( 3908 testcase_name='TwoValuesForArg', 3909 conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)), 3910 call_args=lambda: (constant_op.constant(1),), 3911 call_kwargs=lambda: { 3912 'x': constant_op.constant(1), 3913 'y': constant_op.constant(1) 3914 }, 3915 error=r"func\(x, y\) got two values for 'x'"), 3916 dict( 3917 testcase_name='ExtraPositionalArg', 3918 conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)), 3919 call_args=lambda: (constant_op.constant(1), constant_op.constant(2), 3920 constant_op.constant(3)), 3921 error=r'func\(x, y\) takes 2 .* got 3'), 3922 dict( 3923 testcase_name='UnexpectedKeywordArg', 3924 conc_args=lambda: (constant_op.constant(1),), 3925 call_args=lambda: (constant_op.constant(1),), 3926 call_kwargs=lambda: {'c': constant_op.constant(1)}, 3927 error=r'func\(x\) got unexpected keyword arguments: c'), 3928 dict( 3929 testcase_name='MissingVararg', 3930 conc_args=lambda: (constant_op.constant(1), constant_op.constant(2), 3931 constant_op.constant(3)), 3932 call_args=lambda: (constant_op.constant(1), constant_op.constant(2)), 3933 error=r'func\(x, y, varargs_0\) missing required ' 3934 r'arguments: varargs_0'), 3935 dict( 3936 testcase_name='MissingKeywordArg', 3937 conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)), 3938 conc_kwargs=lambda: {'c': constant_op.constant(1)}, 3939 call_args=lambda: (constant_op.constant(1), constant_op.constant(2)), 3940 error=r'func\(x, y, c\) missing required arguments: c'), 3941 dict( 3942 testcase_name='ExpectedTensorGotInt', 3943 conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)), 3944 call_args=lambda: (5, constant_op.constant(2)), 3945 error=r'func\(x, y\): expected argument #0\(zero-based\) to be ' 3946 r'a Tensor; got int \(5\)'), 3947 dict( 3948 testcase_name='WrongDType', 3949 conc_args=lambda: (constant_op.constant(1),), 3950 call_args=lambda: (constant_op.constant(1.0),), 3951 exception=(ValueError, errors.InvalidArgumentError, 3952 # on xla_gpu, we get InternalError instead. 3953 errors.InternalError)), 3954 dict( 3955 testcase_name='MissingKeywordArgNestPiece', 3956 conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)), 3957 conc_kwargs=lambda: {'c': ragged_factory_ops.constant([[1]])}, 3958 call_args=lambda: (constant_op.constant(1), constant_op.constant(2)), 3959 call_kwargs=lambda: {'c': constant_op.constant(1)}, 3960 error=r'func\(x, y, c, c_1\) missing required arguments: c_1'), 3961 ]) 3962 # pylint: enable=g-long-lambda 3963 @test_util.run_in_graph_and_eager_modes 3964 def testConcreteFunctionFlatSignatureError(self, 3965 conc_args=(), 3966 conc_kwargs=None, 3967 call_args=(), 3968 call_kwargs=None, 3969 error='.*', 3970 exception=TypeError): 3971 """Tests for errors in the flat signature. 3972 3973 Args: 3974 conc_args: Positional arguments used for get_concrete_function. 3975 conc_kwargs: Keyword arguments used for get_concrete_function. 3976 call_args: Positional arguments used to call the function. 3977 call_kwargs: Keyword arguments used to call the function. 3978 error: Expected exception message. 3979 exception: Expected exception type. 3980 """ 3981 conc_args = conc_args() if callable(conc_args) else conc_args 3982 conc_kwargs = conc_kwargs() if callable(conc_kwargs) else conc_kwargs or {} 3983 call_args = call_args() if callable(call_args) else call_args 3984 call_kwargs = call_kwargs() if callable(call_kwargs) else call_kwargs or {} 3985 self.assertIsInstance(conc_args, tuple) 3986 self.assertIsInstance(call_args, tuple) 3987 self.assertIsInstance(conc_kwargs, dict) 3988 self.assertIsInstance(call_kwargs, dict) 3989 3990 @def_function.function 3991 def func(x, y=5, *varargs, **kwargs): # pylint: disable=keyword-arg-before-vararg 3992 del y, varargs, kwargs 3993 return x 3994 3995 conc = func.get_concrete_function(*conc_args, **conc_kwargs) 3996 3997 # Remove _function_spec, to disable the structured signature. 3998 conc._set_function_spec(None) # pylint: disable=protected-access 3999 4000 with self.assertRaisesRegex(exception, error): 4001 self.evaluate(conc(*call_args, **call_kwargs)) 4002 4003 @test_util.run_in_graph_and_eager_modes 4004 def testConcreteFunctionAmbiguousSignature(self): 4005 # When both the flat & structured signatures are applicable, but they 4006 # give different results, we use the structured signature. Note: we expect 4007 # this to be extremely rare. 4008 @def_function.function 4009 def f(x, y): 4010 return x * 10 + y 4011 4012 conc = f.get_concrete_function( 4013 x=tensor_spec.TensorSpec(None, dtypes.int32, name='y'), 4014 y=tensor_spec.TensorSpec(None, dtypes.int32, name='x')) 4015 4016 result = conc(x=constant_op.constant(5), y=constant_op.constant(6)) 4017 self.assertAllEqual(result, 56) 4018 4019 def testPrettyPrintedSignature(self): 4020 4021 @def_function.function 4022 def func(x, kangaroo=None, octopus=7): 4023 del octopus, kangaroo 4024 return x 4025 4026 scalar = constant_op.constant(5) 4027 vector = constant_op.constant([10, 10, 20]) 4028 ragged = ragged_factory_ops.constant([[10, 20], [40]]) 4029 4030 c1 = func.get_concrete_function(scalar, vector) 4031 c1_summary = r'func\(x, kangaroo, octopus=7\)' 4032 c1_details = (r' Args:\n' 4033 r' x: int32 Tensor, shape=\(\)\n' 4034 r' kangaroo: int32 Tensor, shape=\(3,\)\n' 4035 r' Returns:\n' 4036 r' int32 Tensor, shape=\(\)') 4037 self.assertRegex(c1.pretty_printed_signature(verbose=False), c1_summary) 4038 self.assertRegex( 4039 c1.pretty_printed_signature(verbose=True), 4040 c1_summary + '\n' + c1_details) 4041 self.assertRegex( 4042 repr(c1), r'<ConcreteFunction func\(x, kangaroo, octopus=7\) at .*>') 4043 self.assertRegex( 4044 str(c1), 'ConcreteFunction {}\n{}'.format(c1_summary, c1_details)) 4045 4046 c2 = func.get_concrete_function(scalar, ragged, 3) 4047 c2_summary = r'func\(x, kangaroo, octopus=3\)' 4048 c2_details = (r' Args:\n' 4049 r' x: int32 Tensor, shape=\(\)\n' 4050 r' kangaroo: RaggedTensorSpec\(.*\)\n' 4051 r' Returns:\n' 4052 r' int32 Tensor, shape=\(\)') 4053 self.assertRegex(c2.pretty_printed_signature(), 4054 c2_summary + '\n' + c2_details) 4055 4056 c3 = func.get_concrete_function({'a': scalar, 'b': [ragged, ragged]}) 4057 c3_summary = r'func\(x, kangaroo=None, octopus=7\)' 4058 c3_details = (r' Args:\n' 4059 r" x: {'a': <1>, 'b': \[<2>, <3>\]}\n" 4060 r' <1>: int32 Tensor, shape=\(\)\n' 4061 r' <2>: RaggedTensorSpec\(.*\)\n' 4062 r' <3>: RaggedTensorSpec\(.*\)\n' 4063 r' Returns:\n' 4064 r" {'a': <1>, 'b': \[<2>, <3>\]}\n" 4065 r' <1>: int32 Tensor, shape=\(\)\n' 4066 r' <2>: RaggedTensorSpec\(.*\)\n' 4067 r' <3>: RaggedTensorSpec\(.*\)') 4068 4069 # python 3.5 does not gurantee deterministic iteration of dict contents 4070 # which can lead mismatch on pretty_printed_signature output for "Args" 4071 if sys.version_info >= (3, 6): 4072 self.assertRegex(c3.pretty_printed_signature(), 4073 c3_summary + '\n' + c3_details) 4074 4075 # pylint: disable=keyword-arg-before-vararg 4076 @def_function.function 4077 def func2(x, y=3, *args, **kwargs): 4078 return (x, y, args, kwargs) 4079 4080 c4 = func2.get_concrete_function(scalar, 4, 5, a=scalar) 4081 c4_summary = 'func2(x, y=4, <arg3>=5, *, a)' 4082 self.assertEqual(c4.pretty_printed_signature(verbose=False), c4_summary) 4083 4084 c5 = func2.get_concrete_function(8, vector) 4085 c5_summary = 'func2(x=8, y)' 4086 self.assertEqual(c5.pretty_printed_signature(verbose=False), c5_summary) 4087 4088 def testPrettyPrintedExplicitSignatureWithKeywordArg(self): # b/159639913 4089 4090 @def_function.function(input_signature=[tensor_spec.TensorSpec(None)]) 4091 def fn(a, b=1): 4092 return a + b 4093 4094 concrete_fn = fn.get_concrete_function() 4095 self.assertEqual(concrete_fn.pretty_printed_signature(False), 'fn(a)') 4096 self.assertEqual( 4097 concrete_fn.pretty_printed_signature(True), 'fn(a)\n' 4098 ' Args:\n' 4099 ' a: float32 Tensor, shape=<unknown>\n' 4100 ' Returns:\n' 4101 ' float32 Tensor, shape=<unknown>') 4102 4103 def testPrettyPrintedSignatureLoadedNamedTuple(self): 4104 Point = collections.namedtuple('Point', ['x', 'y']) 4105 4106 @def_function.function 4107 def fn(b, a): # pylint: disable=unused-argument 4108 return 1. 4109 4110 b = Point( 4111 x=constant_op.constant(1., dtype=dtypes.float32), 4112 y=constant_op.constant(1., dtype=dtypes.float32)) 4113 a = Point( 4114 x=constant_op.constant(1, dtype=dtypes.int32), 4115 y=constant_op.constant(1, dtype=dtypes.int32)) 4116 4117 mod = module.Module() 4118 f = fn.get_concrete_function(b, a) 4119 save(mod, '/tmp/f', signatures=f) 4120 loaded = load('/tmp/f') 4121 4122 printed = loaded.signatures['serving_default'].pretty_printed_signature() 4123 self.assertIn('a: int32 Tensor, shape=()', printed) 4124 self.assertIn('a_1: int32 Tensor, shape=()', printed) 4125 self.assertIn('b: float32 Tensor, shape=()', printed) 4126 self.assertIn('b_1: float32 Tensor, shape=()', printed) 4127 4128 @test_util.run_in_graph_and_eager_modes 4129 def testIndexedSlicesAsGradientsForConcreteFunctions(self): 4130 4131 @def_function.function 4132 def summing_rnn(inputs): 4133 return math_ops.reduce_sum(inputs, axis=1) 4134 4135 @def_function.function 4136 def gradients(inputs): 4137 with backprop.GradientTape() as tape: 4138 tape.watch(inputs) 4139 hidden = summing_rnn(inputs) 4140 hidden = array_ops.gather(hidden, constant_op.constant([0])) 4141 loss = math_ops.reduce_mean(hidden) 4142 return tape.gradient(loss, inputs) 4143 4144 gradients(constant_op.constant([[[1.0], [2.0]]])) # No error is raised 4145 4146 def testFollowTypeHintsTraceBasic(self): 4147 trace_count = [0] 4148 4149 def func(x: ops.Tensor): 4150 trace_count[0] += 1 4151 return x 4152 4153 enabled = def_function.function(func, experimental_follow_type_hints=True) 4154 disabled = def_function.function(func, experimental_follow_type_hints=False) 4155 4156 enabled(1) # Initial call gets traced 4157 enabled(2) 4158 enabled(3) 4159 self.assertEqual(trace_count[0], 1) 4160 4161 trace_count = [0] 4162 disabled(1) 4163 disabled(2) # Retrace 4164 disabled(3) # Retrace 4165 self.assertEqual(trace_count[0], 3) 4166 4167 def testFollowTypeHintsTraceWithArgs(self): 4168 trace_count = [0] 4169 4170 def func(*args: ops.Tensor): 4171 trace_count[0] += 1 4172 return args 4173 4174 enabled = def_function.function(func, experimental_follow_type_hints=True) 4175 disabled = def_function.function(func, experimental_follow_type_hints=False) 4176 4177 args = ( 4178 'abc', 4179 'def', 4180 ) * 20 4181 args2 = ( 4182 'def', 4183 'abc', 4184 ) * 20 4185 4186 enabled(args) 4187 enabled(args2) 4188 self.assertEqual(trace_count[0], 1) 4189 4190 trace_count = [0] 4191 disabled(args) 4192 disabled(args2) # Retrace 4193 self.assertEqual(trace_count[0], 2) 4194 4195 def testFollowTypeHintsTraceWithKwargs(self): 4196 trace_count = [0] 4197 4198 def func(t: ops.Tensor, **kwargs: ops.Tensor): 4199 del kwargs 4200 trace_count[0] += 1 4201 return t 4202 4203 enabled = def_function.function(func, experimental_follow_type_hints=True) 4204 disabled = def_function.function(func, experimental_follow_type_hints=False) 4205 4206 enabled(1, x=1, y=1.0, z='one') 4207 enabled(2, x=2, y=2.0, z='two') 4208 self.assertEqual(trace_count[0], 1) 4209 4210 trace_count = [0] 4211 disabled(1, x=1, y=1.0, z='one') 4212 disabled(2, x=2, y=2.0, z='two') # Retrace 4213 self.assertEqual(trace_count[0], 2) 4214 4215 def testFollowTypeHintsTraceWithMultipleInputTypes(self): 4216 trace_count = [0] 4217 4218 def func(t: ops.Tensor, *args: ops.Tensor, **kwargs: ops.Tensor): 4219 del args, kwargs 4220 trace_count[0] += 1 4221 return t 4222 4223 enabled = def_function.function(func, experimental_follow_type_hints=True) 4224 disabled = def_function.function(func, experimental_follow_type_hints=False) 4225 4226 enabled(1, constant_op.constant(1), 'str', x=4.0) 4227 enabled(2, constant_op.constant(2), 'str2', x=5.0) 4228 self.assertEqual(trace_count[0], 1) 4229 4230 trace_count = [0] 4231 disabled(1, constant_op.constant(1), 'str', x=4.0) 4232 disabled(2, constant_op.constant(2), 'str2', x=5.0) # Retrace 4233 self.assertEqual(trace_count[0], 2) 4234 4235 def testFollowTypeHintsTraceWithOnlyArgNamed(self): 4236 trace_count = [0] 4237 4238 def func(t: ops.Tensor, i: int = 1, **kwargs): # pylint: disable=bad-whitespace 4239 del i, kwargs 4240 trace_count[0] += 1 4241 return t 4242 4243 enabled = def_function.function(func, experimental_follow_type_hints=True) 4244 4245 enabled(1, 3, x=4.0, y='str') 4246 enabled(2, 4, x=4.0, y='str') # Retrace 4247 self.assertEqual(trace_count[0], 2) 4248 4249 def testFollowTypeHintsTraceWithNotAllNamed(self): 4250 trace_count = [0] 4251 4252 def func(x, y: ops.Tensor, z: int): 4253 del y, z 4254 trace_count[0] += 1 4255 return x 4256 4257 enabled = def_function.function(func, experimental_follow_type_hints=True) 4258 4259 enabled(1, 2, 3) 4260 enabled(1, 20, 3) # No retrace - change in ops.Tensor typed arg 4261 enabled(2, 2, 3) # Retrace - change in untyped arg 4262 enabled(2, 2, 4) # Retrace - change in typed arg 4263 self.assertEqual(trace_count[0], 3) 4264 4265 def testFollowTypeHintsTraceWithOnlyArgsNamed(self): 4266 trace_count = [0] 4267 4268 def func(x, y, *args: ops.Tensor): 4269 del y, args 4270 trace_count[0] += 1 4271 return x 4272 4273 enabled = def_function.function(func, experimental_follow_type_hints=True) 4274 4275 enabled(1, 20, 3, 4, 5, 6) 4276 enabled(1, 20, 3, 4, 5, 60) # No retrace - change in *args 4277 enabled(1, 30, 7, 8, 9, 10) # Retrace - change in args 4278 self.assertEqual(trace_count[0], 2) 4279 4280 def testFollowTypeHintsTraceWithOnlyKwargsNamed(self): 4281 trace_count = [0] 4282 4283 def func(x, y, *args, **kwargs: ops.Tensor): 4284 del y, args, kwargs 4285 trace_count[0] += 1 4286 return x 4287 4288 enabled = def_function.function(func, experimental_follow_type_hints=True) 4289 4290 enabled(1, 2, 3, 4, 5, 6, a=1.0, b=2.0, c=3.0) 4291 enabled( 4292 1, 2, 3, 4, 5, 6, a=1.5, b=2.5, 4293 c=3.5) # No retrace - change in **kwargs 4294 enabled(100, 2, 3, 4, 5, 6, a=1.0, b=2.0, c=3.0) # Retrace - change in args 4295 enabled( 4296 1, 2, 3, 4, 5, 100, a=1.0, b=2.0, c=3.0) # Retrace - change in *args 4297 self.assertEqual(trace_count[0], 3) 4298 4299 def testFollowTypeHintsTraceWithArgsEquals(self): 4300 trace_count = [0] 4301 4302 def func( 4303 x: ops.Tensor = 0, # pylint:disable=bad-whitespace 4304 y: int = 1, # pylint:disable=bad-whitespace 4305 **kwargs: ops.Tensor): 4306 del y, kwargs 4307 trace_count[0] += 1 4308 return x 4309 4310 enabled = def_function.function(func, experimental_follow_type_hints=True) 4311 4312 enabled(x=1, y=2, z=3) 4313 enabled(x=1, y=3, z=3) # Retrace - change in args 4314 enabled(x=2, y=2, z=4) # No retrace - change in args and **kwargs 4315 enabled(x=2, y=2, z=4, u=5) # Retrace - change in **kwargs 4316 self.assertEqual(trace_count[0], 3) 4317 4318 def testFollowTypeHintsWithTensorSpec(self): 4319 def func(x: ops.Tensor, y): 4320 return x + y 4321 v = def_function.function(experimental_follow_type_hints=True)(func) 4322 v = v.get_concrete_function( 4323 tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32), 3) 4324 x = v(constant_op.constant(1.), 3) 4325 self.assertEqual(x.numpy(), 4.) 4326 4327 def testFollowTypeHintsTraceWithKwArgsAndNoVarKws(self): 4328 trace_count = [0] 4329 4330 def func(a: int, b: ops.Tensor, 4331 x: ops.Tensor = 0, y: int = 1): 4332 del a, b, y 4333 trace_count[0] += 1 4334 return x 4335 4336 enabled = def_function.function(func, experimental_follow_type_hints=True) 4337 4338 enabled(0, 0, x=1, y=2) 4339 enabled(0, 0, x=2, y=2,) # No retrace, since only tensor changed 4340 self.assertEqual(trace_count[0], 1) 4341 4342 # Pass args as keyword args. 4343 enabled(a=0, b=0, x=2, y=2,) # No retrace, args are the same 4344 self.assertEqual(trace_count[0], 1) 4345 4346 enabled(a=1, b=0, x=2, y=2,) # Retrace, since non-tensor arg changed 4347 self.assertEqual(trace_count[0], 2) 4348 4349 enabled(a=1, b=2, x=2, y=2) # No retrace, since only tensor changed 4350 self.assertEqual(trace_count[0], 2) 4351 4352 trace_count[0] = 0 4353 disabled = def_function.function(func, experimental_follow_type_hints=False) 4354 disabled(0, 0, x=1, y=2) 4355 disabled(0, 0, x=2, y=2,) # Retrace 4356 self.assertEqual(trace_count[0], 2) 4357 4358 def testFollowTypeHintsTraceWithArgsEqualsTypedKwargs(self): 4359 trace_count = [0] 4360 4361 def func(x, y, **kwargs: ops.Tensor): 4362 del y, kwargs 4363 trace_count[0] += 1 4364 return x 4365 4366 enabled = def_function.function(func, experimental_follow_type_hints=True) 4367 4368 enabled(x=1, y=2, z=3) 4369 enabled(x=1, y=3, z=3) # Retrace 4370 enabled(x=1, y=2, z=4) # No retrace 4371 enabled(x=2, y=2, z=4) # Retrace 4372 enabled(x=2, y=2, z=4, u=5) # Retrace 4373 self.assertEqual(trace_count[0], 4) 4374 4375 def testFollowTypeHintsTraceWithArgsEqualsTypedArgs(self): 4376 trace_count = [0] 4377 4378 def func(x: ops.Tensor, y: int, **kwargs): 4379 del y, kwargs 4380 trace_count[0] += 1 4381 return x 4382 4383 enabled = def_function.function(func, experimental_follow_type_hints=True) 4384 4385 enabled(x=1, y=2, z=3) 4386 enabled(x=1, y=3, z=3) # Retrace 4387 enabled(x=1, y=2, z=4) # Retrace 4388 enabled(x=2, y=2, z=3) # No retrace 4389 enabled(x=2, y=2, z=4, u=5) # Retrace 4390 self.assertEqual(trace_count[0], 4) 4391 4392 def testFollowTypeHintsTraceWithKwOnlyArgsBasic(self): 4393 trace_count = [0] 4394 4395 def func(*, a: ops.Tensor = None, b=1): # pylint: disable=bad-whitespace 4396 del b 4397 trace_count[0] += 1 4398 return a 4399 4400 enabled = def_function.function(func, experimental_follow_type_hints=True) 4401 4402 enabled(a=1, b=2) 4403 enabled(a=2, b=2) # No retrace 4404 enabled(a=1, b=1) # Retrace 4405 self.assertEqual(trace_count[0], 2) 4406 4407 def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedArg(self): 4408 trace_count = [0] 4409 4410 def func(arg: ops.Tensor, *args, kwonly, **kwargs): 4411 del args, kwonly, kwargs 4412 trace_count[0] += 1 4413 return arg 4414 4415 enabled = def_function.function(func, experimental_follow_type_hints=True) 4416 4417 enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) 4418 enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # No retrace 4419 enabled(1000, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # No retrace 4420 enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7) # Retrace 4421 enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7) # Retrace 4422 enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70) # Retrace 4423 self.assertEqual(trace_count[0], 4) 4424 4425 def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedArgs(self): 4426 trace_count = [0] 4427 4428 def func(arg, *args: ops.Tensor, kwonly, **kwargs): 4429 del args, kwonly, kwargs 4430 trace_count[0] += 1 4431 return arg 4432 4433 enabled = def_function.function(func, experimental_follow_type_hints=True) 4434 4435 enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) 4436 enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # Retrace 4437 enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7) # No retrace 4438 enabled(1, 200, 300, 400, kwonly=5, kwarg1=6, kwarg2=7) # No retrace 4439 enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7) # Retrace 4440 enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70) # Retrace 4441 self.assertEqual(trace_count[0], 4) 4442 4443 def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedKwOnlyArg(self): 4444 trace_count = [0] 4445 4446 def func(arg, *args, kwonly: ops.Tensor, **kwargs): 4447 del args, kwonly, kwargs 4448 trace_count[0] += 1 4449 return arg 4450 4451 enabled = def_function.function(func, experimental_follow_type_hints=True) 4452 4453 enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) 4454 enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # Retrace 4455 enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7) # Retrace 4456 enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7) # No retrace 4457 enabled(1, 2, 3, 4, kwonly=500, kwarg1=6, kwarg2=7) # No retrace 4458 enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70) # Retrace 4459 self.assertEqual(trace_count[0], 4) 4460 4461 def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedKwargs(self): 4462 trace_count = [0] 4463 4464 def func(arg, *args, kwonly, **kwargs: ops.Tensor): 4465 del args, kwonly, kwargs 4466 trace_count[0] += 1 4467 return arg 4468 4469 enabled = def_function.function(func, experimental_follow_type_hints=True) 4470 4471 enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) 4472 enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # Retrace 4473 enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7) # Retrace 4474 enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7) # Retrace 4475 enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70) # No retrace 4476 enabled(1, 2, 3, 4, kwonly=5, kwarg1=600, kwarg2=700) # No retrace 4477 self.assertEqual(trace_count[0], 4) 4478 4479 def testWithExtraWrapper(self): 4480 4481 class Foo(module.Module): 4482 4483 def __init__(self): 4484 super().__init__() 4485 self.var = None 4486 4487 @def_function.function 4488 @dummy_tf_decorator 4489 def add(self, x, y, z=1): 4490 if self.var is None: 4491 return x + y + z 4492 4493 foo = Foo() 4494 self.assertEqual(foo.add(2, 3).numpy(), 6) 4495 4496 @parameterized.parameters([(def_function.function, dummy_tf_decorator), 4497 (dummy_tf_decorator, def_function.function), 4498 (def_function.function, def_function.function)]) 4499 def testWithExtraWrapperRedundantArgs(self, decorator1, decorator2): 4500 4501 class Foo(module.Module): 4502 4503 def __init__(self): 4504 super().__init__() 4505 self.var = None 4506 4507 @decorator1 4508 @decorator2 4509 def add1(self, x, y): 4510 if self.var is None: 4511 return x + y 4512 4513 foo = Foo() 4514 with self.assertRaisesRegex(TypeError, 'got two values'): 4515 foo.add1(2, x=3) # pylint: disable=redundant-keyword-arg,no-value-for-parameter 4516 4517 def testWithExtraWrapperMissingArgs(self): 4518 4519 class Foo(module.Module): 4520 4521 def __init__(self): 4522 super().__init__() 4523 self.var = None 4524 4525 @def_function.function 4526 @dummy_tf_decorator 4527 def add1(self, x, y): 4528 if self.var is None: 4529 return x + y 4530 4531 @def_function.function 4532 @dummy_tf_decorator 4533 def add2(self, x, y): 4534 if self.var is None: 4535 return x + y 4536 4537 @def_function.function 4538 @def_function.function 4539 def add3(self, x, y): 4540 if self.var is None: 4541 return x + y 4542 4543 foo = Foo() 4544 with self.assertRaisesRegex( 4545 TypeError, 'missing 1 required positional argument: \'y\''): 4546 foo.add1(2) # pylint: disable=no-value-for-parameter 4547 4548 with self.assertRaisesRegex(TypeError, 'missing 1 required argument: x'): 4549 foo.add1(y=2) # pylint: disable=no-value-for-parameter 4550 4551 with self.assertRaisesRegex( 4552 TypeError, 'missing 1 required positional argument: \'y\''): 4553 foo.add2(2) # pylint: disable=no-value-for-parameter 4554 4555 with self.assertRaisesRegex(TypeError, 'missing 1 required argument: x'): 4556 foo.add2(y=2) # pylint: disable=no-value-for-parameter 4557 4558 with self.assertRaisesRegex( 4559 TypeError, 'missing 1 required positional argument: \'y\''): 4560 foo.add3(2) # pylint: disable=no-value-for-parameter 4561 4562 with self.assertRaisesRegex(TypeError, 'missing 1 required argument: x'): 4563 foo.add3(y=2) # pylint: disable=no-value-for-parameter 4564 4565 def testMissingArgsTfFunctionedMethod(self): 4566 4567 class A(object): 4568 4569 def func(self, position_arg1, position_arg2): 4570 return position_arg1, position_arg2 4571 4572 @def_function.function 4573 def decorated_method(self, position_arg1, position_arg2): 4574 return position_arg1, position_arg2 4575 4576 a_instance = A() 4577 tf_method_pos = def_function.function(a_instance.func) 4578 with self.assertRaisesRegex( 4579 TypeError, '.* missing 1 required argument: position_arg1'): 4580 tf_method_pos(position_arg2='foo') 4581 4582 # tf.function-decorated instance methods need to be tested because of 4583 # the __get__ method implementation. 4584 tf_func_decorated_method = def_function.function( 4585 a_instance.decorated_method) 4586 tf_func_decorated_method(position_arg1='foo', position_arg2='bar') 4587 with self.assertRaisesRegex( 4588 TypeError, '.* missing 1 required argument: position_arg1'): 4589 tf_func_decorated_method(position_arg2='bar') 4590 4591 def testMissingArgsTfFunctionedObject(self): 4592 4593 class A(object): 4594 4595 def __call__(self, position_arg1, position_arg2): 4596 return position_arg1, position_arg2 4597 4598 a_instance = A() 4599 4600 # A tf.function-decorated callable object needs to be tested because of 4601 # the special inspect results. 4602 tf_func_obj = def_function.function(a_instance) 4603 tf_func_obj(position_arg1=1, position_arg2=2) 4604 with self.assertRaisesRegex( 4605 TypeError, '.* missing 1 required argument: position_arg1'): 4606 tf_func_obj(position_arg2='bar') 4607 4608 def testMissingArgsTfFunctionedFunctions(self): 4609 4610 def func_pos(position_arg1, position_arg2): 4611 return position_arg1, position_arg2 4612 4613 def func_with_default(position_arg, named_arg=None): 4614 return position_arg, named_arg 4615 4616 def func_pos_3args(position_arg1, position_arg2, position_arg3): 4617 return position_arg1, position_arg2, position_arg3 4618 4619 tf_func_pos = def_function.function(func_pos) 4620 with self.assertRaisesRegex( 4621 TypeError, '.* missing 1 required argument: position_arg1'): 4622 tf_func_pos(position_arg2='foo') 4623 4624 tf_func_with_default = def_function.function(func_with_default) 4625 tf_func_with_default(position_arg='bar') 4626 with self.assertRaisesRegex(TypeError, 4627 '.* missing 1 required argument: position_arg'): 4628 tf_func_with_default(named_arg='foo') 4629 4630 tf_func_pos_3args = def_function.function(func_pos_3args) 4631 with self.assertRaisesRegex( 4632 TypeError, 4633 '.* missing required arguments: position_arg1, position_arg3'): 4634 tf_func_pos_3args(position_arg2='foo') 4635 4636 def testShapeInferencePropagateConstNestedStack(self): 4637 4638 @def_function.function(input_signature=[ 4639 tensor_spec.TensorSpec((None, None), dtype=dtypes.int32), 4640 tensor_spec.TensorSpec((), dtype=dtypes.int32), 4641 ]) 4642 def f(x, s): 4643 old_shape = array_ops.shape(x) 4644 new_shape = array_ops.stack([old_shape[0], s], axis=0) 4645 y = array_ops.ones(shape=new_shape, dtype=dtypes.int32) 4646 return y 4647 4648 @def_function.function(input_signature=[ 4649 tensor_spec.TensorSpec(shape=(3, 6), dtype=dtypes.int32) 4650 ]) 4651 def g(x): 4652 y = f(x, s=5) 4653 assert y.shape.as_list() == [3, 5], y.shape.as_list() 4654 return y 4655 4656 self.assertAllEqual( 4657 g(array_ops.zeros([3, 6], dtype=dtypes.int32)), array_ops.ones([3, 5])) 4658 4659 def testShapeInferencePropagateConstNestedUnstackStack(self): 4660 4661 @def_function.function(input_signature=[ 4662 tensor_spec.TensorSpec((None, None), dtype=dtypes.int32), 4663 tensor_spec.TensorSpec((), dtype=dtypes.int32), 4664 ]) 4665 def f(x, s): 4666 s0, _ = array_ops.unstack(array_ops.shape(x), axis=0) 4667 new_shape = array_ops.stack([s0, s], axis=0) 4668 y = array_ops.ones(shape=new_shape, dtype=dtypes.int32) 4669 return y 4670 4671 @def_function.function(input_signature=[ 4672 tensor_spec.TensorSpec(shape=(3, 6), dtype=dtypes.int32) 4673 ]) 4674 def g(x): 4675 y = f(x, s=5) 4676 assert y.shape.as_list() == [3, 5], y.shape.as_list() 4677 return y 4678 4679 self.assertAllEqual( 4680 g(array_ops.zeros([3, 6], dtype=dtypes.int32)), array_ops.ones([3, 5])) 4681 4682 def testShapeInferencePropagateConstNestedConcat(self): 4683 4684 @def_function.function(input_signature=[ 4685 tensor_spec.TensorSpec((), dtype=dtypes.int32), 4686 tensor_spec.TensorSpec((), dtype=dtypes.int32), 4687 tensor_spec.TensorSpec((), dtype=dtypes.int32), 4688 ]) 4689 def f(d1, d2, d3): 4690 new_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1) 4691 y = array_ops.ones(shape=new_shape, dtype=dtypes.int32) 4692 return y 4693 4694 @def_function.function() 4695 def g(): 4696 y = f(1, 2, 3) 4697 assert y.shape.as_list() == [1, 2, 3], y.shape.as_list() 4698 return y 4699 4700 self.assertAllEqual(g(), array_ops.ones([1, 2, 3])) 4701 4702 def testShapeInferencePropagateConstDoubleNested(self): 4703 4704 @def_function.function(input_signature=[ 4705 tensor_spec.TensorSpec((), dtype=dtypes.int32), 4706 tensor_spec.TensorSpec((), dtype=dtypes.int32), 4707 tensor_spec.TensorSpec((), dtype=dtypes.int32), 4708 ]) 4709 def f(d1, d2, d3): 4710 new_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1) 4711 y = array_ops.ones(shape=new_shape, dtype=dtypes.int32) 4712 return y 4713 4714 @def_function.function() 4715 def g(): 4716 y = def_function.function(f)(1, 2, 3) 4717 assert y.shape.as_list() == [1, 2, 3], y.shape.as_list() 4718 return y 4719 4720 self.assertAllEqual(g(), array_ops.ones([1, 2, 3])) 4721 4722 @test_util.run_v2_only 4723 def testControlDependencyAfterInline(self): 4724 v = variables.Variable(0.) 4725 4726 @def_function.function 4727 def assign(): 4728 return v.assign(1.) 4729 4730 @def_function.function 4731 def assign_add(): 4732 return v.assign_add(1.) 4733 4734 @def_function.function 4735 def f(): 4736 check_ops.assert_equal_v2(assign(), 1.) 4737 check_ops.assert_equal_v2(assign_add(), 2.) 4738 4739 # We don't have a way to inspect the inlined graph in Python, so we run it 4740 # multiple times to have more confidence the dependency is correct. 4741 for _ in range(30): 4742 f() 4743 4744 @test_util.run_v2_only 4745 def testReadInFuncWriteOutside(self): 4746 # Run many times since we are testing for a potential race condition. 4747 for _ in range(30): 4748 # pylint: disable=cell-var-from-loop 4749 v = variables.Variable(1.) 4750 4751 @def_function.function 4752 def add_one(): 4753 return v + 1. 4754 4755 @def_function.function 4756 def get_v_plus_one(): 4757 v_plus_one = add_one() 4758 v.assign_add(2.0) 4759 return v_plus_one 4760 4761 self.assertAllEqual(get_v_plus_one(), 2.0) 4762 4763 4764class MultiDeviceTest(test.TestCase, parameterized.TestCase): 4765 4766 @test_util.run_gpu_only 4767 def testMultiDeviceOutput(self): 4768 """Tests that functions can produce outputs on multiple devices.""" 4769 @function.defun 4770 def func(a, b, transpose_a): 4771 with ops.device('/device:CPU:0'): 4772 m1 = math_ops.matmul(a, b, transpose_a=transpose_a) 4773 with ops.device('/device:GPU:0'): 4774 m2 = math_ops.matmul(a, b, transpose_a=transpose_a) 4775 return m1, m2 4776 4777 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 4778 m1, m2 = func(t, t, transpose_a=True) 4779 self.assertAllEqual(m1.numpy(), [[10, 14], [14, 20]]) 4780 self.assertRegex(m1.backing_device, 'CPU') 4781 self.assertAllEqual(m2.numpy(), [[10, 14], [14, 20]]) 4782 self.assertRegex(m2.backing_device, 'GPU') 4783 4784 @test_util.run_gpu_only 4785 def testEmptyBody(self): 4786 @function.defun 4787 def func(a, b): 4788 return b, a 4789 4790 with ops.device('/device:CPU:0'): 4791 a = array_ops.identity(3.0) 4792 with ops.device('/device:GPU:0'): 4793 b = array_ops.identity(5.0) 4794 4795 m1, m2 = func(a, b) 4796 self.assertAllEqual(m1.numpy(), 5.0) 4797 self.assertRegex(m1.backing_device, 'GPU') 4798 self.assertAllEqual(m2.numpy(), 3.0) 4799 self.assertRegex(m2.backing_device, 'CPU') 4800 4801 @test_util.run_gpu_only 4802 def testMultiDeviceInt32(self): 4803 """Tests that multi-device functions can take and output INT32s. 4804 4805 When an INT32 device tensor is fed into a function, it is copied to CPU 4806 by the eager runtime. The function sees all INT32 inputs on CPU. 4807 4808 We set allocator attribute 'on_host' for INT32 outputs. They can be 4809 partitioned into the GPU component function, but will be allocated on 4810 CPU nevertheless. 4811 4812 There is experimental support for `ints_on_device` in 4813 FunctionLibraryRuntime now. We can try that. 4814 4815 """ 4816 with ops.device('/device:CPU:0'): 4817 int_cpu = constant_op.constant(3, dtype=dtypes.int32) 4818 resource = resource_variable_ops.ResourceVariable(5, dtype=dtypes.int32) 4819 with ops.device('/device:GPU:0'): 4820 int_gpu = constant_op.constant(7, dtype=dtypes.int32) 4821 4822 @function.defun 4823 def func(int_cpu, resource, int_gpu): 4824 with ops.device('/device:CPU:0'): 4825 m1 = int_cpu * resource + int_gpu 4826 with ops.device('/device:GPU:0'): 4827 # This computation will happen on GPU but m2 will be copied to CPU. 4828 m2 = int_gpu * resource + int_cpu + 1 4829 return m1, m2 4830 4831 m1, m2 = func(int_cpu, resource, int_gpu) 4832 self.assertAllEqual(m1.numpy(), 22) 4833 self.assertRegex(m1.backing_device, 'CPU') 4834 self.assertAllEqual(m2.numpy(), 39) 4835 self.assertRegex(m2.backing_device, 'CPU') 4836 4837 # flip arguments 4838 m1, m2 = func(int_gpu, resource, int_cpu) 4839 self.assertAllEqual(m1.numpy(), 38) 4840 self.assertRegex(m1.backing_device, 'CPU') 4841 self.assertAllEqual(m2.numpy(), 23) 4842 self.assertRegex(m2.backing_device, 'CPU') 4843 4844 @test_util.run_gpu_only 4845 def testMultiDeviceColocateWith(self): 4846 """Tests that function's outputs respect colocation constraints.""" 4847 @function.defun 4848 def func(a, b): 4849 with ops.colocate_with(a): 4850 ra = 2 * a 4851 with ops.colocate_with(b): 4852 rb = 3 * b 4853 return ra, rb 4854 4855 devices = ['/device:CPU:0', '/device:GPU:0'] 4856 for dev1, dev2 in itertools.product(devices, devices): 4857 with ops.device(dev1): 4858 a = array_ops.identity(1.0) 4859 with ops.device(dev2): 4860 b = array_ops.identity(10.0) 4861 4862 ra, rb = func(a, b) 4863 self.assertEqual(ra.numpy(), 2.0) 4864 self.assertRegex(ra.backing_device, dev1) 4865 self.assertEqual(rb.numpy(), 30.0) 4866 self.assertRegex(rb.backing_device, dev2) 4867 4868 @test_util.run_gpu_only 4869 def testMultiDeviceResources(self): 4870 with ops.device('/device:CPU:0'): 4871 c1 = resource_variable_ops.ResourceVariable(2.0) 4872 c2 = resource_variable_ops.ResourceVariable(7.0) 4873 with ops.device('/device:GPU:0'): 4874 g1 = resource_variable_ops.ResourceVariable(3.0) 4875 g2 = resource_variable_ops.ResourceVariable(5.0) 4876 4877 @function.defun 4878 def func(resource1, resource2): 4879 with ops.device('/device:CPU:0'): 4880 result1 = resource1 * g2 4881 with ops.device('/device:GPU:0'): 4882 result2 = resource2 * c2 4883 return result1, result2 4884 4885 r1, r2 = func(c1, g1) 4886 self.assertEqual(r1.numpy(), 10.0) 4887 self.assertRegex(r1.backing_device, 'CPU') 4888 self.assertEqual(r2.numpy(), 21.0) 4889 self.assertRegex(r2.backing_device, 'GPU') 4890 4891 # Call with flipped inputs. Check that we look at resource's 4892 # device and reinstantiates the function when inputs' devices change. 4893 r1, r2 = func(g1, c1) 4894 self.assertEqual(r1.numpy(), 15.0) 4895 self.assertRegex(r1.backing_device, 'CPU') 4896 self.assertEqual(r2.numpy(), 14.0) 4897 self.assertRegex(r2.backing_device, 'GPU') 4898 4899 @test_util.run_gpu_only 4900 def testOutputResources(self): 4901 with ops.device('/device:CPU:0'): 4902 c1 = resource_variable_ops.ResourceVariable(2.0) 4903 with ops.device('/device:GPU:0'): 4904 g1 = resource_variable_ops.ResourceVariable(3.0) 4905 4906 @function.defun 4907 def func(resource1, resource2): 4908 with ops.device('/device:CPU:0'): 4909 result1 = resource1 * 5 4910 with ops.device('/device:GPU:0'): 4911 result2 = resource2 * 7 4912 return result1, resource1.handle, result2, resource2.handle 4913 4914 r1, res1, r2, res2 = func(c1, g1) 4915 self.assertEqual(r1.numpy(), 10.0) 4916 self.assertRegex(r1.backing_device, 'CPU') 4917 self.assertEqual(r2.numpy(), 21.0) 4918 self.assertRegex(r2.backing_device, 'GPU') 4919 4920 def check_handle(handle, expected_value): 4921 self.assertRegex(handle.backing_device, 'CPU') 4922 tensor = gen_resource_variable_ops.read_variable_op( 4923 handle, dtypes.float32) 4924 self.assertEqual(tensor.numpy(), expected_value) 4925 4926 # Check that handles returned from functions are on CPU and an op using 4927 # the resource handle is correctly placed on the device backing the 4928 # resource. 4929 check_handle(res1, 2.0) 4930 check_handle(res2, 3.0) 4931 4932 # Call with flipped inputs to make sure the same the function is 4933 # reinstantiated and eager runtime does not mess up the device assignment 4934 # for ops consuming handles returned from defuns. 4935 r1, res1, r2, res2 = func(g1, c1) 4936 self.assertEqual(r1.numpy(), 15.0) 4937 self.assertRegex(r1.backing_device, 'CPU') 4938 self.assertEqual(r2.numpy(), 14.0) 4939 self.assertRegex(r2.backing_device, 'GPU') 4940 check_handle(res1, 3.0) 4941 check_handle(res2, 2.0) 4942 4943 @test_util.run_gpu_only 4944 def testPassResourceThroughNestedFunctionCall(self): 4945 """Test passing GPU resource to noinline function call placed on CPU. 4946 4947 PartitionedCallOp must not enforce any particular device assignment for the 4948 resource output. Inner function marked as `_nospecialize`, so Grappler would 4949 not prune unused function output. 4950 """ 4951 4952 with ops.device('/device:GPU:0'): 4953 g1 = resource_variable_ops.ResourceVariable(3.0) 4954 4955 @function.defun_with_attributes(attributes={ 4956 '_noinline': True, 4957 '_nospecialize': True 4958 }) 4959 def inner(resource1): 4960 return resource1 * 2, resource1.handle 4961 4962 @function.defun 4963 def outer(resource1): 4964 with ops.device('/device:CPU:0'): 4965 r1, _ = inner(resource1) 4966 return r1 4967 4968 r1 = outer(g1) 4969 4970 self.assertEqual(r1.numpy(), 6.0) 4971 self.assertRegex(r1.backing_device, 'CPU') 4972 4973 @test_util.run_gpu_only 4974 def testReturnResourceFromNestedFunctionCall(self): 4975 """Test returning GPU resource from noinline function call placed on CPU. 4976 4977 When inferring output devices for the return value, do not set a device for 4978 returns of DT_RESOURCE data type based on the device assignment of the node 4979 that produced that resource. As an example function call placed on CPU can 4980 return resources on GPU. 4981 """ 4982 4983 with ops.device('/device:GPU:0'): 4984 g1 = resource_variable_ops.ResourceVariable(3.0) 4985 4986 @function.defun_with_attributes(attributes={ 4987 '_noinline': True 4988 }) 4989 def inner(resource1): 4990 resource1.assign_add(2.0) 4991 return resource1 * 2, resource1.handle 4992 4993 @function.defun 4994 def outer(resource1): 4995 with ops.device('/device:CPU:0'): 4996 r1, res1 = inner(resource1) 4997 return r1, res1 4998 4999 r1, res1 = outer(g1) 5000 5001 self.assertEqual(r1.numpy(), 10.0) 5002 self.assertRegex(r1.backing_device, 'CPU') 5003 5004 def check_handle(handle, expected_value): 5005 self.assertRegex(handle.backing_device, 'CPU') 5006 tensor = gen_resource_variable_ops.read_variable_op( 5007 handle, dtypes.float32) 5008 self.assertEqual(tensor.numpy(), expected_value) 5009 5010 # Check that handles returned from functions are on CPU and an op using 5011 # the resource handle is correctly placed on the device backing the 5012 # resource. 5013 check_handle(res1, 5.0) 5014 5015 @test_util.run_gpu_only 5016 def testComplexInputOutputDevicePattern(self): 5017 """Tests input/output mapping logic in partitioning.""" 5018 with ops.device('/device:CPU:0'): 5019 rc0 = resource_variable_ops.ResourceVariable(2.0) 5020 rc1 = resource_variable_ops.ResourceVariable(3.0) 5021 cc0 = array_ops.identity(5.0) 5022 cc1 = array_ops.identity(7.0) 5023 with ops.device('/device:GPU:0'): 5024 rg0 = resource_variable_ops.ResourceVariable(11.0) 5025 rg1 = resource_variable_ops.ResourceVariable(13.0) 5026 cg0 = array_ops.identity(17.0) 5027 cg1 = array_ops.identity(19.0) 5028 5029 # Make sure tensors are on expected devices. 5030 for tensor in [cc0, cc1]: 5031 self.assertRegex(tensor.backing_device, 'CPU:0') 5032 for tensor in [cg0, cg1]: 5033 self.assertRegex(tensor.backing_device, 'GPU:0') 5034 5035 @function.defun 5036 def func(rc0, cc0, cg0, rc1, cg1, rg0, rg1, cc1): 5037 with ops.device('/device:CPU:0'): 5038 m1 = rc0 * cg0 5039 with ops.device('/device:GPU:0'): 5040 m2 = rg0 * cc0 5041 5042 with ops.device('/device:CPU:0'): 5043 r1 = 1000.0 * m2 + rc1 * cg1 5044 with ops.device('/device:GPU:0'): 5045 r2 = 1000.0 * m1 + rg1 * cc1 5046 5047 return r1, r2, m2, m1 5048 5049 r1, r2, m2, m1 = func(rc0, cc0, cg0, rc1, cg1, rg0, rg1, cc1) 5050 self.assertRegex(m1.backing_device, 'CPU') 5051 self.assertRegex(r1.backing_device, 'CPU') 5052 self.assertRegex(m2.backing_device, 'GPU') 5053 self.assertRegex(r2.backing_device, 'GPU') 5054 self.assertEqual(m1.numpy(), 34.0) 5055 self.assertEqual(r1.numpy(), 55000.0 + 3.0 * 19.0) 5056 self.assertEqual(m2.numpy(), 55.0) 5057 self.assertEqual(r2.numpy(), 34000.0 + 13.0 * 7.0) 5058 5059 @test_util.run_gpu_only 5060 def testArgumentPruning(self): 5061 """Tests functions taking unnecessary arguments.""" 5062 with ops.device('/device:CPU:0'): 5063 c1 = constant_op.constant(5.0) 5064 c2 = constant_op.constant(7.0) 5065 5066 with ops.device('/device:GPU:0'): 5067 g1 = constant_op.constant(11.0) 5068 g2 = constant_op.constant(13.0) 5069 g3 = constant_op.constant(17.0) 5070 5071 @function.defun 5072 def func(g1, g2, c1, g3, c2): # pylint: disable=unused-argument 5073 # arguments g1 and g2 are unused and can be pruned by grappler. 5074 return c1 * g3 * c2 5075 5076 result = func(g1, g2, c1, g3, c2) 5077 self.assertEqual(result.numpy(), 5.0 * 7.0 * 17.0) 5078 5079 def testNestedCallWatchedVariables(self): 5080 5081 v = variables.Variable(4.) 5082 5083 @def_function.function 5084 def f(): 5085 return v ** 2. 5086 5087 with backprop.GradientTape() as tape: 5088 f() 5089 5090 self.assertEqual((v,), tape.watched_variables()) 5091 5092 @def_function.function 5093 def g(): 5094 return f() 5095 5096 with backprop.GradientTape() as tape: 5097 g() 5098 5099 self.assertEqual((v,), tape.watched_variables()) 5100 5101 # f() can rely on the variable being read during its trace. g() checks that 5102 # variables from a function which knows about them are recorded on the 5103 # tape. h() tests that functions forward knowledge of variables to callers. 5104 5105 @def_function.function 5106 def h(): 5107 return g() 5108 5109 with backprop.GradientTape() as tape: 5110 h() 5111 5112 self.assertEqual((v,), tape.watched_variables()) 5113 5114 def testDeferredCapture(self): 5115 value = 1.0 5116 5117 @def_function.function 5118 def lazy_capture(x): 5119 y = ops.get_default_graph().capture_call_time_value( 5120 lambda: value, tensor_spec.TensorSpec(None)) 5121 return x + y 5122 5123 self.assertAllEqual(lazy_capture(2.0), 3.0) 5124 # After changing the value of `value` the function call should return a 5125 # different result. 5126 value = 2.0 5127 self.assertAllEqual(lazy_capture(2.0), 4.0) 5128 5129 def testDeferredCaptureWithKey(self): 5130 value0 = 1.0 5131 value1 = 2.0 5132 5133 @def_function.function 5134 def lazy_capture(x): 5135 w = ops.get_default_graph().capture_call_time_value( 5136 lambda: value0, tensor_spec.TensorSpec(None), key=0) 5137 y = ops.get_default_graph().capture_call_time_value( 5138 lambda: value1, tensor_spec.TensorSpec(None), key=1) 5139 def bad_closure(): 5140 raise ValueError('Should not run') 5141 z = ops.get_default_graph().capture_call_time_value( 5142 bad_closure, tensor_spec.TensorSpec(None), key=1) 5143 return x + y + w + z 5144 5145 self.assertAllEqual(lazy_capture(2.0), 7.0) 5146 value0 = 2.0 5147 value1 = 3.0 5148 self.assertAllEqual(lazy_capture(2.0), 10.0) 5149 5150 def testDeferredCaptureTypeError(self): 5151 value = constant_op.constant(1.0) 5152 5153 @def_function.function 5154 def lazy_capture(x): 5155 y = ops.get_default_graph().capture_call_time_value( 5156 lambda: value, tensor_spec.TensorSpec(())) 5157 return x + y 5158 5159 self.assertAllEqual(lazy_capture(2.0), 3.0) 5160 5161 # dtype mismatch 5162 value = constant_op.constant(1) 5163 with self.assertRaisesRegex(ValueError, 'Value .* to a tensor with dtype'): 5164 lazy_capture(2.0) 5165 5166 # shape mismatch 5167 value = constant_op.constant([1.0]) 5168 with self.assertRaisesRegex(ValueError, 'Value .* shape'): 5169 lazy_capture(2.0) 5170 5171 def testDeferredCaptureReturnNestWithCompositeTensor(self): 5172 i_s = indexed_slices.IndexedSlices( 5173 constant_op.constant([1, 2]), 5174 constant_op.constant([0, 1], dtype=dtypes.int64), 5175 constant_op.constant([2])) 5176 r_t = ragged_factory_ops.constant([[[1, 2], [3]], [[4, 5, 6]]]) 5177 s_t = sparse_tensor.SparseTensor( 5178 values=[1, 2, 3], indices=[[0], [8], [10]], dense_shape=[20]) 5179 5180 @def_function.function 5181 def lazy_capture(): 5182 y = ops.get_default_graph().capture_call_time_value( 5183 lambda: {'i': i_s, 't': (r_t, s_t)}, 5184 {'i': indexed_slices.IndexedSlicesSpec( 5185 dtype=dtypes.int32, dense_shape_dtype=dtypes.int32), 5186 't': (ragged_tensor.RaggedTensorSpec([2, None, None], dtypes.int32), 5187 sparse_tensor.SparseTensorSpec([None], dtypes.int32))}) 5188 return y['i'], y['t'] 5189 5190 i, (r, s) = lazy_capture() 5191 self.assertAllEqual(i_s.values, i.values) 5192 self.assertAllEqual(i_s.indices, i.indices) 5193 self.assertAllEqual(i_s.dense_shape, i.dense_shape) 5194 self.assertAllEqual(r_t, r) 5195 self.assertAllEqual(s_t.indices, s.indices) 5196 self.assertAllEqual(s_t.values, s.values) 5197 self.assertAllEqual(s_t.dense_shape, s.dense_shape) 5198 5199 def testDeferredCaptureCompositeTensorSpecTypeMismatch(self): 5200 value = indexed_slices.IndexedSlices( 5201 constant_op.constant([1, 2]), 5202 constant_op.constant([0, 1], dtype=dtypes.int64)) 5203 5204 @def_function.function 5205 def lazy_capture(): 5206 return ops.get_default_graph().capture_call_time_value( 5207 lambda: value, 5208 indexed_slices.IndexedSlicesSpec(dtype=dtypes.int32)) 5209 5210 # Type matches spec. 5211 lazy_capture() 5212 5213 # Extra dense shape component. 5214 value = indexed_slices.IndexedSlices( 5215 constant_op.constant([1, 2]), 5216 constant_op.constant([0, 1], dtype=dtypes.int64), 5217 constant_op.constant([2])) 5218 with self.assertRaises(ValueError): 5219 lazy_capture() 5220 5221 # Index dtype mismatch int32 vs. int64. 5222 value = indexed_slices.IndexedSlices( 5223 constant_op.constant([1, 2]), 5224 constant_op.constant([0, 1])) 5225 with self.assertRaises(ValueError): 5226 lazy_capture() 5227 5228 def testFunctoolsLruCache(self): 5229 self.skipTest( 5230 "b/194845243: inspect.getfullargspec doesn't unwrap Python decorators.") 5231 5232 @def_function.function 5233 @functools.lru_cache(maxsize=2) 5234 def f(a): 5235 return 2 * a 5236 5237 self.assertAllEqual(f(1), array_ops.constant(2)) 5238 5239if __name__ == '__main__': 5240 ops.enable_eager_execution() 5241 test.main() 5242