1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import collections 21import copy 22import functools 23import itertools 24import multiprocessing.pool 25import os 26import sys 27import time 28import weakref 29 30from absl.testing import parameterized 31import numpy 32 33from tensorflow.core.protobuf import config_pb2 34from tensorflow.core.protobuf import rewriter_config_pb2 35from tensorflow.python.autograph.core import ag_ctx 36from tensorflow.python.data.ops import dataset_ops 37from tensorflow.python.data.ops import iterator_ops 38from tensorflow.python.eager import backprop 39from tensorflow.python.eager import cancellation 40from tensorflow.python.eager import context 41from tensorflow.python.eager import def_function 42from tensorflow.python.eager import function 43from tensorflow.python.framework import composite_tensor 44from tensorflow.python.framework import config 45from tensorflow.python.framework import constant_op 46from tensorflow.python.framework import dtypes 47from tensorflow.python.framework import errors 48from tensorflow.python.framework import func_graph 49from tensorflow.python.framework import function as tf_function 50from tensorflow.python.framework import indexed_slices 51from tensorflow.python.framework import ops 52from tensorflow.python.framework import random_seed 53from tensorflow.python.framework import sparse_tensor 54from tensorflow.python.framework import tensor_shape 55from tensorflow.python.framework import tensor_spec 56from tensorflow.python.framework import test_ops 57from tensorflow.python.framework import test_util 58from tensorflow.python.framework import type_spec 59from tensorflow.python.layers import convolutional 60from tensorflow.python.module import module 61from tensorflow.python.ops import array_ops 62from tensorflow.python.ops import check_ops 63from tensorflow.python.ops import clip_ops 64from tensorflow.python.ops import control_flow_ops 65from tensorflow.python.ops import data_flow_ops 66from tensorflow.python.ops import functional_ops 67from tensorflow.python.ops import gen_functional_ops 68from tensorflow.python.ops import gen_random_ops 69from tensorflow.python.ops import gen_resource_variable_ops 70from tensorflow.python.ops import gen_sendrecv_ops 71from tensorflow.python.ops import gradients_impl 72from tensorflow.python.ops import init_ops 73from tensorflow.python.ops import list_ops 74from tensorflow.python.ops import logging_ops 75from tensorflow.python.ops import math_ops 76from tensorflow.python.ops import random_ops 77from tensorflow.python.ops import resource_variable_ops 78from tensorflow.python.ops import string_ops 79from tensorflow.python.ops import variable_scope 80from tensorflow.python.ops import variables 81from tensorflow.python.ops.ragged import 
ragged_factory_ops 82from tensorflow.python.ops.ragged import ragged_tensor 83from tensorflow.python.ops.structured import structured_tensor 84from tensorflow.python.platform import test 85from tensorflow.python.training import training_ops 86from tensorflow.python.util import compat 87from tensorflow.python.util import nest 88from tensorflow.python.util import tf_decorator 89from tensorflow.python.util import tf_inspect 90 91try: 92 import attr # pylint:disable=g-import-not-at-top 93except ImportError: 94 attr = None 95 96 97def total_function_cache(defined): 98 # pylint: disable=protected-access 99 return (set(defined._function_cache.primary) 100 | set(defined._function_cache.arg_relaxed)) 101 # pylint: enable=protected-access 102 103 104def _example_indexed_slices_with_dense_shape(): 105 return indexed_slices.IndexedSlices( 106 constant_op.constant([1, 2]), constant_op.constant([0, 1]), 107 constant_op.constant([2])) 108 109 110def _example_indexed_slices_without_dense_shape(): 111 return indexed_slices.IndexedSlices( 112 constant_op.constant([1, 2]), constant_op.constant([0, 1])) 113 114 115def _spec_for_value(value): 116 """Returns the (nested) TypeSpec for a value.""" 117 if nest.is_sequence(value): 118 return nest.map_structure(_spec_for_value, value) 119 elif isinstance(value, (ops.Tensor, composite_tensor.CompositeTensor)): 120 return type_spec.type_spec_from_value(value) 121 else: 122 return value 123 124 125# This dummy decorator imitates ordinary decorators utilizing tf_decorator. 126def dummy_tf_decorator(method): 127 128 def wrapper(*args, **kwargs): 129 return method(*args, **kwargs) 130 131 return tf_decorator.make_decorator(method, wrapper) 132 133 134class FunctionTest(test.TestCase, parameterized.TestCase): 135 136 def setUp(self): 137 super(FunctionTest, self).setUp() 138 cpus = config.list_physical_devices('CPU') 139 # Set 4 virtual CPUs 140 config.set_logical_device_configuration(cpus[0], [ 141 context.LogicalDeviceConfiguration(), 142 context.LogicalDeviceConfiguration(), 143 context.LogicalDeviceConfiguration(), 144 context.LogicalDeviceConfiguration() 145 ]) 146 147 def testBasic(self): 148 matmul = def_function.function(math_ops.matmul) 149 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 150 sq = matmul(t, t, transpose_a=True) 151 sq2 = matmul(sq, t, transpose_a=True) 152 self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20]) 153 self.assertAllEqual(sq2.numpy().reshape(-1), [52, 76, 74, 108]) 154 155 def testOnExitCallback(self): 156 values = [] 157 def append_1(): 158 values.append(1) 159 160 def append_2(): 161 values.append(2) 162 163 def g(x): 164 old_values = list(values) 165 ops.add_exit_callback_to_default_func_graph(append_1) 166 self.assertEqual(old_values, values) 167 return x + 1 168 169 tf_g = def_function.function(g) 170 171 def f(x): 172 old_values = list(values) 173 ops.add_exit_callback_to_default_func_graph(append_2) 174 self.assertEqual(old_values, values) 175 return tf_g(x) 176 177 tf_f = def_function.function(f) 178 self.assertEmpty(values) 179 tf_f(constant_op.constant(1.0)) 180 self.assertEqual(values, [1, 2]) # Once for g, once for f. 181 tf_f(constant_op.constant([1.0])) # force a retrace 182 self.assertEqual(values, [1, 2, 1, 2]) # And again. 
183 184 def testCannotAddExitCallbackWhenNotInFunctionScope(self): 185 with self.assertRaisesRegex(RuntimeError, 'when not building a function.'): 186 ops.add_exit_callback_to_default_func_graph(lambda: None) 187 188 def testVariable(self): 189 v1 = variables.Variable(1.0) 190 add = def_function.function(lambda x, v: x + v1 + v) 191 v2 = variables.Variable(1.0) 192 x = constant_op.constant(1.0) 193 r = add(x, v2) 194 self.assertEqual(3.0, self.evaluate(r)) 195 196 def testVariableOnly(self): 197 v = variables.Variable(1.0) 198 add = def_function.function(lambda x: x.assign_add(1.0)) 199 r1 = add(v) 200 self.assertEqual(2.0, self.evaluate(r1)) 201 c = constant_op.constant(1.0) 202 with self.assertRaisesRegex(AttributeError, 'no attribute'): 203 add(c) 204 205 @test_util.disable_tfrt('Packed tensor is not supported in tfrt yet.') 206 def testPackedVariable(self): 207 with ops.device('/cpu:0'): 208 v0_0 = resource_variable_ops.ResourceVariable(1.0) 209 with ops.device('/cpu:1'): 210 v0_1 = resource_variable_ops.ResourceVariable(2.0) 211 v1_0 = resource_variable_ops.ResourceVariable(3.0) 212 with ops.device('/cpu:2'): 213 v1_1 = resource_variable_ops.ResourceVariable(4.0) 214 215 packed_var_0 = ops.pack_eager_tensors([v0_0.handle, v0_1.handle]) 216 packed_var_1 = ops.pack_eager_tensors([v1_0.handle, v1_1.handle]) 217 218 # TODO(b/145922293): use ResourceVariable.assign_add and 219 # ResourceVariable.read_value directly once we support packing multiple 220 # ResourceVariable into one ResourceVariable. 221 @def_function.function 222 def read_var(): 223 resource_variable_ops.assign_add_variable_op( 224 packed_var_0, constant_op.constant(5.0)) 225 resource_variable_ops.assign_add_variable_op( 226 packed_var_1, constant_op.constant(6.0)) 227 with ops.device('/cpu:0'): 228 read0 = resource_variable_ops.read_variable_op( 229 packed_var_0, dtype=dtypes.float32) 230 with ops.device('/cpu:1'): 231 read1 = resource_variable_ops.read_variable_op( 232 packed_var_0, dtype=dtypes.float32) 233 read2 = resource_variable_ops.read_variable_op( 234 packed_var_1, dtype=dtypes.float32) 235 with ops.device('/cpu:2'): 236 read3 = resource_variable_ops.read_variable_op( 237 packed_var_1, dtype=dtypes.float32) 238 239 return read0, read1, read2, read3 240 241 arg_attrs = read_var.get_concrete_function().function_def.arg_attr 242 self.assertLen(arg_attrs, 2) 243 self.assertEqual(arg_attrs[0].attr['_composite_device'].s, 244 compat.as_bytes(packed_var_0.device)) 245 self.assertEqual(arg_attrs[1].attr['_composite_device'].s, 246 compat.as_bytes(packed_var_1.device)) 247 248 self.assertAllEqual(read_var(), (1 + 5, 2 + 5, 3 + 6, 4 + 6)) 249 250 def testImplementsAttributeBasic(self): 251 v = def_function.function( 252 experimental_implements='func')(lambda x, y: x + y) 253 with context.graph_mode(), self.cached_session(): 254 a = array_ops.placeholder(dtypes.float32, ()) 255 b = array_ops.placeholder(dtypes.float32, ()) 256 v(a, b) 257 gradients_impl.gradients(v(a, b), [a, b]) 258 fdefs = ops.get_default_graph().as_graph_def().library.function 259 self.assertLen(fdefs, 3) 260 not_present = 0 261 present = 0 262 for f in fdefs: 263 name = f.signature.name 264 if 'forward' in name or 'backward' in name: 265 not_present += 1 266 self.assertNotIn(function.IMPLEMENTS_ATTRIBUTE_NAME, f.attr, f) 267 else: 268 present += 1 269 self.assertEqual(f.attr[function.IMPLEMENTS_ATTRIBUTE_NAME].s, 270 'func'.encode('ascii'), f) 271 self.assertEqual(not_present, 2, fdefs) 272 self.assertEqual(present, 1, fdefs) 273 274 def 
testImplementsAttributeAssertsOnSideInput(self): 275 with context.graph_mode(), self.cached_session(): 276 z = array_ops.zeros(0) 277 v = def_function.function( 278 experimental_implements='func')(lambda x, y: x + y + z) 279 a = array_ops.ones((1.0,)) 280 b = array_ops.ones((1.0,)) 281 with self.assertRaisesRegex(AssertionError, 282 'variables are always captured'): 283 v(a, b) 284 functions = ops.get_default_graph().as_graph_def().library.function 285 self.assertEmpty(functions) 286 287 def testImplementsAttributeWorksWithGradientTape(self): 288 add = lambda x, y: x + y ** 2 289 add = def_function.function(experimental_implements='MyFunc')(add) 290 x = variables.Variable(3.0) 291 y = variables.Variable(2.0) 292 293 with backprop.GradientTape() as tape: 294 g = add(x, y) 295 296 dg_dy, dg_dx = tape.gradient(g, [y, x]) 297 self.assertEqual(dg_dy.numpy(), 4.0) 298 self.assertEqual(dg_dx.numpy(), 1.0) 299 300 def testImplementsAttributeWorksOnVariables(self): 301 with context.graph_mode(), self.cached_session(): 302 v = def_function.function( 303 experimental_implements='func')(lambda x, y: x + y) 304 a = variables.Variable((1.0,)) 305 b = variables.Variable((1.0,)) 306 r1 = v(a, b) 307 _ = v(a, a) 308 functions = ops.get_default_graph().as_graph_def().library.function 309 # Verify that we created only one function 310 self.assertLen(functions, 1) 311 # Verify that eval() reads the current values. 312 a.initializer.run() 313 b.initializer.run() 314 self.assertEqual(r1.eval(), 2) 315 316 a.assign_add([1]).eval() 317 self.assertEqual(r1.eval(), 3) 318 319 def testImplementsAttributeWorksOnConstants(self): 320 with context.graph_mode(), self.cached_session(): 321 v = def_function.function( 322 experimental_implements='func')(lambda x, y: x + y) 323 a = variables.Variable(1.0) 324 r1 = v(a, 2.) 325 r2 = v(2., a) 326 functions = ops.get_default_graph().as_graph_def().library.function 327 self.assertLen(functions, 1) 328 self.assertLen(functions[0].signature.input_arg, 2) 329 # Verify that eval() reads the current values. 330 a.initializer.run() 331 self.assertEqual(r1.eval(), 3) 332 self.assertEqual(r2.eval(), 3) 333 334 def testImplementsAttributeSpecializes(self): 335 with context.graph_mode(), self.cached_session(): 336 v = def_function.function( 337 experimental_implements='func')(lambda x, y: x + y) 338 a = variables.Variable(1.0) 339 r1 = v(a, [2.]) 340 r2 = v([2., 2], a) 341 functions = ops.get_default_graph().as_graph_def().library.function 342 self.assertLen(functions, 2) 343 # Ensure that all parameters are still there and haven't been inlined! 344 345 self.assertLen(functions[0].signature.input_arg, 2) 346 self.assertLen(functions[1].signature.input_arg, 2) 347 # Verify that eval() reads the current values. 348 a.initializer.run() 349 numpy.testing.assert_equal(r1.eval(), [3.]) 350 numpy.testing.assert_equal(r2.eval(), [3., 3.]) 351 352 def testImplementsWorksWithTensorSpec(self): 353 v = def_function.function( 354 experimental_implements='func')(lambda x, y: x + y) 355 v = v.get_concrete_function( 356 tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32), 357 tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32)) 358 x = v(1., 2.) 359 self.assertEqual(x.numpy(), 3.) 
360 361 def testImplementsAttributeAsNameAttrList(self): 362 implements_attr = ( 363 'name: "embedding_matmul" attr { key: "key1" value { i: 2 } ' 364 '} attr { key: "key2" value { b: false } }') 365 v = def_function.function( 366 experimental_implements=implements_attr)(lambda x, y: x + y) 367 with context.graph_mode(), self.cached_session(): 368 a = array_ops.placeholder(dtypes.float32, ()) 369 b = array_ops.placeholder(dtypes.float32, ()) 370 v(a, b) 371 gradients_impl.gradients(v(a, b), [a, b]) 372 fdefs = ops.get_default_graph().as_graph_def().library.function 373 self.assertLen(fdefs, 3) 374 not_present = 0 375 present = 0 376 for f in fdefs: 377 name = f.signature.name 378 if 'forward' in name or 'backward' in name: 379 not_present += 1 380 self.assertNotIn(function.IMPLEMENTS_ATTRIBUTE_NAME, f.attr, f) 381 else: 382 present += 1 383 attr_value = f.attr[function.IMPLEMENTS_ATTRIBUTE_NAME] 384 self.assertIsNotNone(attr_value.func, f) 385 self.assertEqual(attr_value.func.name, 'embedding_matmul') 386 name_attrs = attr_value.func.attr 387 self.assertLen(name_attrs, 2) 388 self.assertEqual(not_present, 2, fdefs) 389 self.assertEqual(present, 1, fdefs) 390 391 def testExternalControlDependency(self): 392 with ops.Graph().as_default(), self.test_session(): 393 v = variables.Variable(1.0) 394 v.initializer.run() 395 396 op = v.assign_add(1.0) 397 398 @function.defun 399 def f(): 400 with ops.control_dependencies([op]): 401 return 1.0 402 403 self.evaluate(f()) 404 self.assertAllEqual(self.evaluate(v), 2.0) 405 406 def testInputShapeFunctionRelaxation(self): 407 unknown_dim = [False] 408 409 @function.defun(experimental_relax_shapes=True) 410 def func(a): 411 if a._shape_tuple()[0] is None: 412 unknown_dim[0] = True 413 return a + 1 414 415 func(constant_op.constant([])) 416 self.assertFalse(unknown_dim[0]) 417 self.assertLen(total_function_cache(func), 1) 418 419 func(constant_op.constant([1.0])) 420 self.assertFalse(unknown_dim[0]) 421 self.assertLen(total_function_cache(func), 2) 422 423 func(constant_op.constant([1.0, 2.0])) 424 self.assertTrue(unknown_dim[0]) 425 self.assertLen(total_function_cache(func), 2) 426 427 def testInputShapeRelaxationOnInstanceMethod(self): 428 # Test that experimental_relax_shapes is passed during 429 # instance method bounding. 430 unknown_dim = [False] 431 432 class Foo(object): 433 434 @def_function.function(experimental_relax_shapes=True) 435 def func(self, a): 436 if a._shape_tuple()[0] is None: 437 unknown_dim[0] = True 438 return a + 1 439 440 foo = Foo() 441 foo.func(constant_op.constant([])) 442 self.assertFalse(unknown_dim[0]) 443 444 foo.func(constant_op.constant([1.0])) 445 self.assertFalse(unknown_dim[0]) 446 447 foo.func(constant_op.constant([1.0, 2.0])) 448 self.assertTrue(unknown_dim[0]) 449 450 def testInputShapeFunctionRelaxationWithRaggedTensors(self): 451 traced_type_spec = [None] 452 453 @def_function.function(experimental_relax_shapes=True) 454 def func(x): 455 traced_type_spec[0] = x._type_spec 456 return x 457 458 def check_trace(x, expected_trace): 459 traced_type_spec[0] = None 460 func(x) 461 self.assertEqual(traced_type_spec[0], expected_trace) 462 463 check_trace( # Initial call gets traced. 464 ragged_factory_ops.constant([[1], [2, 3, 4]]), 465 ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32)) 466 check_trace( # Input TypeSpec is the same -> no retrace. 467 ragged_factory_ops.constant([[1, 2], [3, 4]]), None) 468 check_trace( # Even if component tensor shapes change -> no retrace. 
469 ragged_factory_ops.constant([[1, 2], [3, 4, 5, 6]]), None) 470 check_trace( # Different TypeSpec shape (nrows): retrace 471 ragged_factory_ops.constant([[1], [2], [3]]), 472 ragged_tensor.RaggedTensorSpec([3, None], dtypes.int32)) 473 check_trace( # Different nrows again: relax & retrace 474 ragged_factory_ops.constant([[1], [2], [3], [4]]), 475 ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)) 476 check_trace( # Different nrows yet again: not retrace 477 ragged_factory_ops.constant([[1]]), None) 478 check_trace( # Different ragged_rank: retrace 479 ragged_factory_ops.constant([[[1]]]), 480 ragged_tensor.RaggedTensorSpec([1, None, None], dtypes.int32)) 481 check_trace( # Different ragged_rank again: retrace & relax 482 ragged_factory_ops.constant([[[1]], [[2]]]), 483 ragged_tensor.RaggedTensorSpec([None, None, None], dtypes.int32)) 484 485 def testInputShapeFunctionRelaxationWithStructuredTensors(self): 486 traced_type_spec = [None] 487 488 @def_function.function(experimental_relax_shapes=True) 489 def func(x): 490 traced_type_spec[0] = x._type_spec 491 return x 492 493 def check_trace(x, expected_trace): 494 traced_type_spec[0] = None 495 func(x) 496 self.assertEqual(traced_type_spec[0], expected_trace) 497 498 # If we have TypeSpecs that differ in ways other than just their shape, 499 # then retrace each time. 500 check_trace( 501 structured_tensor.StructuredTensor.from_pyval({'a': [1]}), 502 structured_tensor.StructuredTensorSpec( 503 [], {'a': tensor_spec.TensorSpec((1,), dtypes.int32)})) 504 check_trace( 505 structured_tensor.StructuredTensor.from_pyval({'b': [1]}), 506 structured_tensor.StructuredTensorSpec( 507 [], {'b': tensor_spec.TensorSpec((1,), dtypes.int32)})) 508 check_trace( 509 structured_tensor.StructuredTensor.from_pyval({'c': [1]}), 510 structured_tensor.StructuredTensorSpec( 511 [], {'c': tensor_spec.TensorSpec((1,), dtypes.int32)})) 512 513 # But if we call again with only shape different, then do relax: 514 check_trace( # retrace 515 structured_tensor.StructuredTensor.from_pyval({'a': [1, 2]}), 516 structured_tensor.StructuredTensorSpec( 517 [], {'a': tensor_spec.TensorSpec((2,), dtypes.int32)})) 518 check_trace( # relax & retrace 519 structured_tensor.StructuredTensor.from_pyval({'a': [1, 2, 3]}), 520 structured_tensor.StructuredTensorSpec( 521 [], {'a': tensor_spec.TensorSpec((None,), dtypes.int32)})) 522 check_trace( # use relaxed graph 523 structured_tensor.StructuredTensor.from_pyval({'a': [1, 2, 3, 4]}), 524 None) 525 526 def testInputShapeFunctionRelaxationWithDatasetIterators(self): 527 # For dataset iterators, the TypeSpec includes type information that's 528 # not derivable from the component tensors. Make sure that the TypeSpec 529 # shapes get relaxed as appropriate. 
530 531 traced_type_spec = [None] 532 533 @def_function.function(experimental_relax_shapes=True) 534 def func(x): 535 traced_type_spec[0] = x._type_spec 536 return x 537 538 def check_trace(x, expected_trace): 539 traced_type_spec[0] = None 540 func(x) 541 self.assertEqual(traced_type_spec[0], expected_trace) 542 543 ds_1_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([1, 2])) 544 ds_2_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([2, 2])) 545 ds_3_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([3, 2])) 546 ds_4_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([4, 2])) 547 ds_2_1 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([2, 1])) 548 check_trace( # shape=[1, 2]: retrace 549 dataset_ops.make_one_shot_iterator(ds_1_2), 550 iterator_ops.IteratorSpec( 551 tensor_spec.TensorSpec([1, 2], dtypes.float32))) 552 check_trace( # shape=[1, 2]: no retrace (use the [1, 2] graph) 553 dataset_ops.make_one_shot_iterator(ds_1_2), None) 554 check_trace( # shape=[2, 2]: retrace 555 dataset_ops.make_one_shot_iterator(ds_2_2), 556 iterator_ops.IteratorSpec( 557 tensor_spec.TensorSpec([2, 2], dtypes.float32))) 558 check_trace( # shape=[3, 2]: relax to [None, 2] and retrace 559 dataset_ops.make_one_shot_iterator(ds_3_2), 560 iterator_ops.IteratorSpec( 561 tensor_spec.TensorSpec([None, 2], dtypes.float32))) 562 check_trace( # shape=[4, 2]: no retrace (use the [None, 2] graph) 563 dataset_ops.make_one_shot_iterator(ds_4_2), None) 564 check_trace( # shape=[2, 1]: relax to [None, None] and retrace 565 dataset_ops.make_one_shot_iterator(ds_2_1), 566 iterator_ops.IteratorSpec( 567 tensor_spec.TensorSpec([None, None], dtypes.float32))) 568 569 def testCapturesVariables(self): 570 a = variables.Variable(1.0, trainable=False) 571 b = variables.Variable(1.0) 572 cc = [None] 573 574 @def_function.function 575 def f(): 576 c = cc[0] 577 if c is None: 578 c = cc[0] = variables.Variable(1.) 579 return a + b + c + 1 580 581 cf = f.get_concrete_function() 582 c = cc[0] 583 584 captured_variables = {v.ref() for v in (a, b, c)} 585 trainable_variables = {v.ref() for v in (b, c)} 586 self.assertEqual({v.ref() for v in cf.variables}, captured_variables) 587 self.assertEqual({v.ref() for v in cf.trainable_variables}, 588 trainable_variables) 589 self.assertEqual(cf.variables, cf.graph.variables) 590 self.assertEqual(cf.trainable_variables, cf.graph.trainable_variables) 591 592 def testNestedInputShapeFunctionRelaxation(self): 593 unknown_dim = [False] 594 595 @function.defun(experimental_relax_shapes=True) 596 def func(a_, b_=None): 597 del a_ # Only used to check which cache is used. 598 self.assertEqual(b_[0]._shape_tuple(), ()) 599 if b_[1]._shape_tuple()[0] is None: 600 unknown_dim[0] = True 601 return b_[0] + 1 602 603 a = 'hi' 604 b0 = constant_op.constant(1.0) 605 func(a, b_=[b0, constant_op.constant([])]) 606 self.assertFalse(unknown_dim[0]) 607 self.assertLen(total_function_cache(func), 1) 608 609 func(a, b_=[b0, constant_op.constant([1.0])]) 610 self.assertFalse(unknown_dim[0]) 611 self.assertLen(total_function_cache(func), 2) 612 613 func(a, b_=[b0, constant_op.constant([1.0, 1.0])]) 614 self.assertTrue(unknown_dim[0]) 615 self.assertLen(total_function_cache(func), 2) 616 617 unknown_dim[0] = False 618 619 # Now do the same except with a new a which is not a tensor; this should 620 # change the cache key. 
621 a = 'bye' 622 func(a, b_=[b0, constant_op.constant([])]) 623 self.assertFalse(unknown_dim[0]) 624 self.assertLen(total_function_cache(func), 3) 625 626 # Since we already marked a cache miss for a function with the same 627 # non-input signatures, here we will immediately start relaxing shapes. 628 func(a, b_=[b0, constant_op.constant([1.0])]) 629 self.assertTrue(unknown_dim[0]) 630 self.assertLen(total_function_cache(func), 3) 631 632 def testNestedShapeFunctionRelaxation(self): 633 634 got_shape = [None] 635 636 # The inner function will go through shape relaxation because the shapes it 637 # receives will be [1], [2], [3], ... 638 @def_function.function(experimental_relax_shapes=True) 639 def bar(x_shape): 640 got_shape[0] = x_shape._shape_tuple() 641 return x_shape 642 643 # The outer function will not go through shape relaxation because the shapes 644 # it receives will be [1], [[1]], [[[1]]], ... 645 @def_function.function(experimental_relax_shapes=True) 646 def foo(ones): 647 return bar(array_ops.shape(ones)) 648 649 for rank in range(1, 6): 650 x_shape = self.evaluate(foo(array_ops.ones([1] * rank))) 651 self.assertAllEqual(x_shape, [1] * rank) 652 if rank < 3: 653 self.assertEqual(got_shape[0], (rank,)) 654 else: 655 self.assertEqual(got_shape[0], (None,)) 656 657 def testNoHash(self): 658 659 @def_function.function() 660 def f(_): 661 return 1.0 662 663 with self.assertRaisesRegex(ValueError, r'Got type: set'): 664 f(set([])) 665 666 def testFuncName(self): 667 668 @function.defun_with_attributes(attributes={'func_name': 'multiply'}) 669 def add(x, y): 670 _ = x * y 671 return x + y 672 673 @function.defun 674 def add_2(x, y): 675 _ = x * y 676 return x + y 677 678 self.assertEqual(add._name, 'multiply') 679 self.assertEqual(add_2._name, 'add_2') 680 681 def testBasicGraphMode(self): 682 matmul = def_function.function(math_ops.matmul) 683 684 @def_function.function 685 def sq(a): 686 return matmul(a, a) 687 688 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 689 out = sq(t) 690 self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) 691 692 def testNestedInputsGraphMode(self): 693 matmul = def_function.function(math_ops.matmul) 694 695 pair = collections.namedtuple('pair', ['a', 'b']) 696 697 @def_function.function 698 def a_times_b(inputs): 699 return matmul(inputs.a['a'], inputs.b['b']) 700 701 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 702 703 out = a_times_b(pair({'a': t}, {'b': t})) 704 self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) 705 706 def testNestedOutputsGraphMode(self): 707 matmul = def_function.function(math_ops.matmul) 708 709 pair = collections.namedtuple('pair', ['a', 'b']) 710 711 @def_function.function() 712 def pairs_mul(pair_a, pair_b): 713 return pair(matmul(pair_a.a, pair_b.a), matmul(pair_a.b, pair_b.b)) 714 715 a = constant_op.constant([[1.0, 2.0], [1.0, 2.0]]) 716 b = constant_op.constant([[3.0, 4.0], [3.0, 4.0]]) 717 718 out = pairs_mul(pair(a, b), pair(b, a)) 719 expected = pair(math_ops.matmul(a, b).numpy(), 720 math_ops.matmul(b, a).numpy()) 721 self.assertAllClose(out, expected) 722 723 @parameterized.named_parameters( 724 dict(testcase_name='Defun', 725 function_decorator=function.defun), 726 dict(testcase_name='DefFunction', 727 function_decorator=def_function.function)) 728 def testNestedFunctionGraphNotOutOfDate(self, function_decorator): 729 @function_decorator 730 def f(): 731 return constant_op.constant(1.) 
732 733 class _Model(object): 734 735 @function_decorator 736 def g(self): 737 self.f = f.get_concrete_function() 738 739 model = _Model() 740 model.g() 741 concrete = model.f 742 weak_g_graph = weakref.ref(model.g.get_concrete_function().graph) 743 self.assertIs(weak_g_graph(), concrete.graph.outer_graph) 744 weak_g = weakref.ref(model.g) 745 del model 746 self.assertIsNone(weak_g()) 747 self.assertIsNone(weak_g_graph()) 748 self.assertIsNotNone(concrete.graph.outer_graph) 749 self.assertIs(ops.get_default_graph(), concrete.graph.outer_graph) 750 751 def testGraphEagerIsolation(self): 752 753 @function.defun 754 def f(): 755 self.v = variables.Variable(1.0) 756 return self.v.read_value() 757 758 self.assertAllEqual(f(), 1.0) 759 760 with ops.Graph().as_default(): 761 self.assertEqual(f().shape, ()) 762 763 def testBasicGraphFunction(self): 764 matmul = def_function.function(math_ops.matmul) 765 766 @def_function.function 767 def sq(a): 768 return matmul(a, a) 769 770 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 771 772 sq_op = sq.get_concrete_function(t) 773 self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2])) 774 out = sq_op(t) 775 self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) 776 777 def testGetConcreteFunctionThreadSafety(self): 778 779 @def_function.function 780 def sq(): 781 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 782 return math_ops.matmul(t, t) 783 784 concrete_functions = [] 785 786 def thread_func(_): 787 cf = sq.get_concrete_function() 788 concrete_functions.append(cf) 789 790 num_threads = 100 791 pool = multiprocessing.pool.ThreadPool(num_threads) 792 _ = pool.map(thread_func, list(range(num_threads))) 793 794 self.assertLen(set(concrete_functions), 1) 795 796 def testGetConcreteFunctionThreadSafetyWithArgs(self): 797 @def_function.function 798 def add_100(*args): 799 return math_ops.add_n(args) 800 801 p = multiprocessing.pool.ThreadPool(2) 802 args = (constant_op.constant(1.),) * 100 803 f1, f2 = p.map(add_100.get_concrete_function, [args] * 2) 804 # I see about len(args) + max(0, len(args) - 3) arguments expected. 
805 f1(*args) 806 del f2 807 808 def testInputSpecGraphFunction(self): 809 matmul = def_function.function(math_ops.matmul) 810 811 @def_function.function 812 def sq(a): 813 return matmul(a, a) 814 815 sq_op = sq.get_concrete_function( 816 tensor_spec.TensorSpec((None, None), dtypes.float32)) 817 self.assertEqual([None, None], sq_op.output_shapes.as_list()) 818 819 t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 820 out1 = sq_op(t1) 821 self.assertAllEqual(out1, math_ops.matmul(t1, t1).numpy()) 822 823 t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 824 out2 = sq_op(t2) 825 self.assertAllEqual(out2, math_ops.matmul(t2, t2).numpy()) 826 827 def testNestedInputSpecGraphFunction(self): 828 matmul = def_function.function(math_ops.matmul) 829 830 @def_function.function 831 def sq(mats): 832 ((a, b),) = mats 833 return matmul(a, b) 834 835 sq_op_autonamed = sq.get_concrete_function( 836 [(tensor_spec.TensorSpec((None, None), dtypes.float32), 837 tensor_spec.TensorSpec((None, None), dtypes.float32))]) 838 self.assertEqual([None, None], sq_op_autonamed.output_shapes.as_list()) 839 840 sq_op = sq.get_concrete_function( 841 [(tensor_spec.TensorSpec((None, None), dtypes.float32, 842 name='first_mat'), 843 tensor_spec.TensorSpec((None, None), dtypes.float32, 844 name='second_mat'))]) 845 self.assertEqual([None, None], sq_op.output_shapes.as_list()) 846 847 t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 848 t2 = constant_op.constant([[1.4, 2.4], [3.4, 4.4]]) 849 out = sq_op(first_mat=t1, second_mat=t2) 850 self.assertAllEqual(out, math_ops.matmul(t1, t2).numpy()) 851 self.assertAllEqual(sq_op_autonamed(t1, t2), 852 math_ops.matmul(t1, t2).numpy()) 853 854 def testExecutingStatelessDefunConcurrently(self): 855 856 @def_function.function 857 def stateless(x): 858 return math_ops.multiply(2.0, x) 859 860 pool = multiprocessing.pool.ThreadPool() 861 inputs = [constant_op.constant(1.0 * x) for x in range(100)] 862 outputs = [float(out) for out in pool.map(stateless, inputs)] 863 expected = [float(2.0 * x) for x in inputs] 864 self.assertSequenceEqual(outputs, expected) 865 866 def testExecutingManyStatelessDefunsConcurrently(self): 867 868 @def_function.function 869 def stateless(x): 870 del x 871 return math_ops.multiply(2.0, 2.0) 872 873 pool = multiprocessing.pool.ThreadPool() 874 # `pool.map` below instantiates 100 functions, one for each object. 875 objects = [object() for _ in range(100)] 876 outputs = [float(out) for out in pool.map(stateless, objects)] 877 expected = [4.0] * 100 878 self.assertSequenceEqual(outputs, expected) 879 880 @test_util.disable_tfrt('b/169431085: This test is flaky on tfrt') 881 def testExecutingStatefulDefunConcurrently(self): 882 883 v = resource_variable_ops.ResourceVariable(1.0) 884 885 @def_function.function 886 def stateful(x): 887 v.assign(x) 888 889 pool = multiprocessing.pool.ThreadPool() 890 inputs = [constant_op.constant(0.0)] * 100 891 pool.map(stateful, inputs) 892 self.assertEqual(float(v.read_value()), 0.0) 893 894 def testExecutingManyStatefulDefunsConcurrently(self): 895 896 v = resource_variable_ops.ResourceVariable(1.0) 897 898 @def_function.function 899 def stateful(x): 900 del x 901 return v.assign(0.0) 902 903 pool = multiprocessing.pool.ThreadPool() 904 # `pool.map` below instantiates 100 functions, one for each object. 905 pool.map(stateful, [object() for _ in range(100)]) 906 self.assertEqual(float(v.read_value()), 0.0) 907 908 def testShareRendezvous(self): 909 910 # Disable grappler from inlining the functions. 
Note we run the send & recv 911 # in graph mode since with eager mode the function should automatically be 912 # inlined. 913 context.context().set_optimizer_experimental_options( 914 {'disable_meta_optimizer': True}) 915 916 cpu = '/device:CPU:0' 917 918 signature = [tensor_spec.TensorSpec([], dtypes.int32)] 919 920 @def_function.function 921 def send(): 922 x = constant_op.constant(1) 923 gen_sendrecv_ops.send(x, 'x', cpu, 0, cpu) 924 return x 925 926 send._shared_rendezvous = True # pylint: disable=protected-access 927 928 @def_function.function(input_signature=signature) 929 def send_body(n): 930 send() 931 return n - 1 932 933 @def_function.function 934 def recv(): 935 return gen_sendrecv_ops.recv(dtypes.int32, 'x', cpu, 0, cpu) 936 937 recv._shared_rendezvous = True # pylint: disable=protected-access 938 939 @def_function.function(input_signature=signature) 940 def recv_body(n): 941 recv() 942 return n - 1 943 944 @def_function.function(input_signature=signature) 945 def cond(n): 946 return n > 0 947 948 # Instead of calling the send & recv functions directly we want to call them 949 # through a functional while to ensure the rendezvous is shared across the 950 # while boundary. 951 @def_function.function 952 def fn(n): 953 functional_ops.While([n], cond.get_concrete_function(), 954 send_body.get_concrete_function()) 955 return functional_ops.While([n], cond.get_concrete_function(), 956 recv_body.get_concrete_function()) 957 958 # Use a graph context since functions will not be automatically inlined 959 with context.graph_mode(), self.cached_session(): 960 self.evaluate(fn(2)) 961 962 def disabled_testRandomSeed(self): 963 964 @def_function.function 965 def f(): 966 return random_ops.random_normal(()) 967 968 random_seed.set_random_seed(1) 969 x = f() 970 self.assertNotEqual(x, f()) 971 random_seed.set_random_seed(1) 972 self.assertAllEqual(f(), x) 973 974 def testNestedInputsGraphFunction(self): 975 matmul = def_function.function(math_ops.matmul) 976 977 pair = collections.namedtuple('pair', ['a', 'b']) 978 979 @def_function.function 980 def a_times_b(inputs): 981 return matmul(inputs.a['a'], inputs.b['b']) 982 983 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 984 sq_op = a_times_b.get_concrete_function( 985 pair(dict(a=tensor_spec.TensorSpec([2, 2], dtypes.float32, 'a')), 986 dict(b=tensor_spec.TensorSpec([2, 2], dtypes.float32, 'b')))) 987 self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2])) 988 out = sq_op(a=t, b=t) 989 self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) 990 991 def testNestedOutputGraphFunction(self): 992 matmul = def_function.function(math_ops.matmul) 993 994 @def_function.function 995 def sq(a): 996 return (matmul(a, a), {'b': constant_op.constant(1.0)}) 997 998 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 999 1000 sq_op = sq.get_concrete_function(t) 1001 self.assertEqual(sq_op.output_shapes, 1002 (tensor_shape.TensorShape([2, 2]), 1003 {'b': tensor_shape.TensorShape([])})) 1004 self.assertEqual(sq_op.output_dtypes, 1005 (dtypes.float32, {'b': dtypes.float32})) 1006 (a, b) = sq_op(t) 1007 self.assertAllEqual(a, math_ops.matmul(t, t).numpy()) 1008 self.assertAllEqual(b['b'].numpy(), 1.0) 1009 1010 def testGraphFunctionNoneOutput(self): 1011 @def_function.function 1012 def fn(unused_a, unused_b): 1013 return None 1014 1015 x = constant_op.constant(1) 1016 fn_op = fn.get_concrete_function(x, x) 1017 self.assertEqual(fn_op.output_dtypes, None) 1018 self.assertEqual(fn_op.output_shapes, None) 1019 self.assertAllEqual(fn_op(x, x), 
None) 1020 1021 def testDefunNumpyArraysConvertedToTensors(self): 1022 1023 def f(x): 1024 self.assertIsInstance(x, ops.Tensor) 1025 return x 1026 1027 x = random_ops.random_uniform([2, 2]).numpy() 1028 defined = function.defun(f) 1029 defined(x) 1030 self.assertLen(total_function_cache(defined), 1) 1031 1032 x = random_ops.random_uniform([2, 2]).numpy() 1033 defined(x) 1034 # A NumPy array with different values but the same shape and dtype 1035 # shouldn't trigger another function definition. 1036 self.assertLen(total_function_cache(defined), 1) 1037 1038 np_ones = numpy.ones([], numpy.float32) 1039 np_zeros = numpy.zeros([], numpy.float32) 1040 tf_ones = array_ops.ones([]) 1041 tf_zeros = array_ops.zeros([]) 1042 1043 # Test that the numpy array is properly an argument to the graph function. 1044 self.assertEqual(1., defined(np_ones).numpy()) 1045 self.assertLen(total_function_cache(defined), 2) 1046 self.assertEqual(0., defined(np_zeros).numpy()) 1047 self.assertEqual(1., defined(tf_ones).numpy()) 1048 self.assertEqual(0., defined(tf_zeros).numpy()) 1049 self.assertLen(total_function_cache(defined), 2) 1050 1051 # Test that mutable inputs are supported. 1052 mutable = numpy.ones([], numpy.float32) 1053 self.assertEqual(1., defined(mutable).numpy()) 1054 mutable.fill(0) 1055 self.assertEqual(0., defined(mutable).numpy()) 1056 1057 class MyNdarray(numpy.ndarray): 1058 pass 1059 1060 # Test that the subclasses of ndarray are converted too. 1061 self.assertEqual(1., defined(np_ones.view(MyNdarray)).numpy()) 1062 self.assertEqual(0., defined(np_zeros.view(MyNdarray)).numpy()) 1063 1064 # We should not have triggered any re-tracing of the python function. 1065 self.assertLen(total_function_cache(defined), 2) 1066 1067 def testNumpyDtypeInputSupported(self): 1068 @function.defun 1069 def f(x, dtype): 1070 return constant_op.constant(dtype(x)) 1071 1072 self.assertEqual(f(1, numpy.float32).numpy(), numpy.float32(1)) 1073 self.assertEqual(f(2, numpy.float32).numpy(), numpy.float32(2)) 1074 self.assertEqual(f(1, numpy.int32).numpy(), numpy.int32(1)) 1075 self.assertEqual(f(2, numpy.int32).numpy(), numpy.int32(2)) 1076 1077 def testDefunNumpyArraysConvertedToTensorsInKwargs(self): 1078 1079 def f(**kwargs): 1080 x = kwargs.pop('x') 1081 self.assertIsInstance(x, ops.Tensor) 1082 return x 1083 1084 x = random_ops.random_uniform([2, 2]).numpy() 1085 defined = function.defun(f) 1086 defined(x=x) 1087 self.assertLen(total_function_cache(defined), 1) 1088 1089 x = random_ops.random_uniform([2, 2]).numpy() 1090 defined(x=x) 1091 # A NumPy array with different values but the same shape and dtype 1092 # shouldn't trigger another function definition. 1093 self.assertLen(total_function_cache(defined), 1) 1094 1095 # Test that the numpy array is properly an argument to the graph function. 
1096 self.assertEqual(1., defined(x=numpy.ones([])).numpy()) 1097 self.assertEqual(0., defined(x=numpy.zeros([])).numpy()) 1098 self.assertEqual(1., defined(x=array_ops.ones([])).numpy()) 1099 self.assertEqual(0., defined(x=array_ops.zeros([])).numpy()) 1100 1101 def testDefunCapturedInt32(self): 1102 x = constant_op.constant(1, dtype=dtypes.int32) 1103 1104 @def_function.function 1105 def add_int32s(): 1106 return x + x 1107 1108 self.assertEqual(2, int(add_int32s())) 1109 1110 def testDefunReadVariable(self): 1111 v = resource_variable_ops.ResourceVariable(1.0) 1112 1113 @def_function.function 1114 def f(): 1115 return v.read_value() 1116 1117 self.assertEqual(1.0, float(f())) 1118 1119 def testDefunAssignAddVariable(self): 1120 v = resource_variable_ops.ResourceVariable(1.0) 1121 x = constant_op.constant(2.0) 1122 1123 @def_function.function 1124 def test_assign_add(): 1125 v.assign_add(x) 1126 return v.read_value() 1127 1128 self.assertEqual(3.0, float(test_assign_add())) 1129 1130 @test_util.run_in_graph_and_eager_modes 1131 def testTensorInitializationInFunctionRaisesError(self): 1132 error_msg = ('Tensor-typed variable initializers must either be ' 1133 'wrapped in an init_scope or callable.*') 1134 1135 @def_function.function 1136 def tensor_init(): 1137 with self.assertRaisesRegex(ValueError, error_msg): 1138 resource_variable_ops.ResourceVariable(constant_op.constant(2.0)) 1139 1140 tensor_init() 1141 1142 @test_util.run_in_graph_and_eager_modes 1143 def testCallableTensorInitializationInFunction(self): 1144 1145 @def_function.function 1146 def tensor_init(): 1147 self.v = resource_variable_ops.ResourceVariable( 1148 lambda: constant_op.constant(2.0)) 1149 return self.v.read_value() 1150 1151 value = tensor_init() 1152 if not context.executing_eagerly(): 1153 self.evaluate(variables.global_variables_initializer()) 1154 self.assertEqual(self.evaluate(value), 2.0) 1155 1156 @test_util.also_run_as_tf_function 1157 def testInitScopeTensorInitializationInFunction(self): 1158 1159 @def_function.function 1160 def tensor_init(): 1161 with ops.init_scope(): 1162 const = constant_op.constant(2.0) 1163 # Note: this variable bypasses tf.function's variable creation 1164 # requirements by bypassing variable_creator_scope by using 1165 # ResourceVariable instead of Variable. 
1166 self.v = resource_variable_ops.ResourceVariable(const) 1167 return self.v.read_value() 1168 1169 value = tensor_init() 1170 self.assertAllEqual(value, 2.0) 1171 1172 @test_util.run_in_graph_and_eager_modes 1173 def testGetConcreteFunctionCreatesVariables(self): 1174 1175 v_holder = [] 1176 1177 @def_function.function 1178 def tensor_init(): 1179 if not v_holder: 1180 v_holder.append(variables.Variable(5.)) 1181 return v_holder[0].read_value() 1182 1183 concrete = tensor_init.get_concrete_function() 1184 self.evaluate(variables.global_variables_initializer()) 1185 self.assertAllEqual(5., self.evaluate(concrete())) 1186 self.assertAllEqual(5., self.evaluate(tensor_init())) 1187 1188 def testFuncGraphCaptureByValue(self): 1189 v = variables.Variable(1.0) 1190 1191 def trivial_function(): 1192 return v.read_value() 1193 1194 graph_function = function.Function( 1195 trivial_function, 'test', capture_by_value=True) 1196 1197 self.assertAllEqual(graph_function(), 1.0) 1198 v.assign(2.0) 1199 self.assertAllEqual(graph_function(), 1.0) 1200 1201 def testFuncGraphCaptureByValueNested(self): 1202 v = variables.Variable(1.0) 1203 1204 def trivial_function(): 1205 return control_flow_ops.cond( 1206 array_ops.placeholder_with_default(True, ()), 1207 v.read_value, v.read_value) 1208 1209 graph_function = function.Function( 1210 trivial_function, 'test', capture_by_value=True) 1211 1212 self.assertAllEqual(graph_function(), 1.0) 1213 v.assign(2.0) 1214 self.assertAllEqual(graph_function(), 1.0) 1215 1216 def testDefunShapeInferenceWithCapturedResourceVariable(self): 1217 v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]]) 1218 1219 def f(): 1220 x = constant_op.constant([[1, 2], [3, 4]]) 1221 out = math_ops.matmul(v, x) 1222 self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2])) 1223 # We do not return v directly since the tensor conversion function of 1224 # ResourceVariable returns the read value and not the resource itself. 
1225 return v._handle 1226 1227 compiled = def_function.function(f) 1228 var_handle = compiled() 1229 self.assertEqual(var_handle.dtype, dtypes.resource) 1230 self.assertEqual(var_handle.shape, tensor_shape.TensorShape([])) 1231 var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype) 1232 self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2])) 1233 1234 def testShapeInferenceForMoreSpecificInput(self): 1235 1236 def f(a): 1237 return array_ops.reshape(a, [-1, 3]) 1238 1239 signature = [tensor_spec.TensorSpec(None, dtypes.float32)] 1240 compiled = def_function.function(f, input_signature=signature) 1241 1242 @def_function.function 1243 def use_f(): 1244 inputs = array_ops.zeros([10, 10, 3]) 1245 self.assertAllEqual(f(inputs).shape, compiled(inputs).shape) 1246 1247 use_f() 1248 1249 def testFuncListAttr(self): 1250 1251 @function.defun 1252 def test_function(val): 1253 1254 def fn1(): 1255 return array_ops.ones([10]) 1256 1257 fn2 = lambda: array_ops.ones([10]) * 2 1258 1259 def fn3(x=3): 1260 return array_ops.ones([10]) * x 1261 fn4 = functools.partial(fn3, x=4) 1262 fn5 = functools.partial(fn3, 5) 1263 1264 return gen_functional_ops.case(val, [], [dtypes.float32], 1265 [function.defun(f).get_concrete_function() 1266 for f in (fn1, fn2, fn3, fn4, fn5)]) 1267 1268 ones = array_ops.ones([10]) 1269 self.assertAllEqual([ones], test_function(0)) 1270 self.assertAllEqual([ones * 2], test_function(1)) 1271 self.assertAllEqual([ones * 3], test_function(2)) 1272 self.assertAllEqual([ones * 4], test_function(3)) 1273 self.assertAllEqual([ones * 5], test_function(4)) 1274 self.assertAllEqual([ones * 5], test_function(22)) # default branch 1275 1276 @test_util.enable_control_flow_v2 1277 def testVariableInLoopInFunction(self): 1278 1279 @function.defun 1280 def test_function(): 1281 1282 def loop_test(_): 1283 return False 1284 1285 def loop_body(_): 1286 return variable_scope.get_variable('a', shape=()) 1287 1288 return control_flow_ops.while_loop(loop_test, loop_body, [0.0]) 1289 1290 self.assertEqual(test_function().shape, []) 1291 1292 def testDefunShapeInferenceWithCapturedResourceVariableInGraphMode(self): 1293 with context.graph_mode(): 1294 v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]]) 1295 1296 def f(): 1297 x = constant_op.constant([[1, 2], [3, 4]]) 1298 out = math_ops.matmul(v, x) 1299 self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2])) 1300 # We do not return v directly since the tensor conversion function of 1301 # ResourceVariable returns the read value and not the resource itself. 
1302 return v._handle 1303 1304 compiled = def_function.function(f) 1305 var_handle = compiled() 1306 self.assertEqual(var_handle.dtype, dtypes.resource) 1307 self.assertEqual(var_handle.shape, tensor_shape.TensorShape([])) 1308 var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype) 1309 self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2])) 1310 1311 def testDefunShapeInferenceWithCapturedVariableInGraphMode(self): 1312 with context.graph_mode(): 1313 v = variables.Variable([[1, 2], [3, 4]]) 1314 1315 def f(): 1316 x = constant_op.constant([[1, 2], [3, 4]]) 1317 out = math_ops.matmul(v, x) 1318 self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2])) 1319 1320 # Check that shape inference works while creating the defun 1321 compiled = def_function.function(f) 1322 compiled() 1323 1324 def testDefunShapeInferenceWithCapturedTensorListInGraphMode(self): 1325 with context.graph_mode(): 1326 tensor_list = list_ops.empty_tensor_list( 1327 element_dtype=dtypes.float32, 1328 element_shape=ops.convert_to_tensor([], dtype=dtypes.int32)) 1329 tensor_list = list_ops.tensor_list_push_back(tensor_list, 1330 constant_op.constant(1.0)) 1331 tensor_list = list_ops.tensor_list_push_back(tensor_list, 1332 constant_op.constant(2.0)) 1333 1334 def f(): 1335 tl, value = list_ops.tensor_list_pop_back( 1336 tensor_list, element_dtype=dtypes.float32) 1337 self.assertEqual(value.shape, tensor_shape.TensorShape([])) 1338 return tl 1339 1340 compiled = def_function.function(f) 1341 output_tensor_list = compiled() 1342 _, value = list_ops.tensor_list_pop_back( 1343 output_tensor_list, element_dtype=dtypes.float32) 1344 self.assertEqual(value.shape, tensor_shape.TensorShape([])) 1345 1346 @test_util.run_in_graph_and_eager_modes 1347 def testDefunForcesResourceVariables(self): 1348 1349 def variable_creator(): 1350 self.v = variables.Variable(0.0) 1351 return self.v.read_value() 1352 1353 self.v = None 1354 defined = function.defun(variable_creator) 1355 defined() # Create the variable. 
1356 self.assertIsInstance( 1357 self.v, resource_variable_ops.ResourceVariable) 1358 1359 def testRunMetadata(self): 1360 1361 @def_function.function 1362 def f(x): 1363 return x * x 1364 1365 with ops.device('cpu:0'): 1366 context.enable_run_metadata() 1367 f(constant_op.constant(1.0)) 1368 run_metadata = context.export_run_metadata() 1369 context.disable_run_metadata() 1370 self.assertLen(run_metadata.partition_graphs, 1) 1371 1372 def testGraphModeCaptureVariable(self): 1373 with context.graph_mode(), self.cached_session(): 1374 1375 class HasAVar(object): 1376 1377 def __init__(self): 1378 self.v = resource_variable_ops.ResourceVariable(1.0) 1379 1380 def call(self): 1381 return self.v * 2 1382 1383 o = HasAVar() 1384 self.evaluate(variables.global_variables_initializer()) 1385 call = def_function.function(o.call) 1386 op = call() 1387 self.assertAllEqual(self.evaluate(op), 2.0) 1388 1389 def testGraphModeManyFunctions(self): 1390 with ops.Graph().as_default(), self.cached_session(): 1391 1392 @def_function.function 1393 def f(x): 1394 return x * x 1395 1396 @def_function.function 1397 def g(x): 1398 return f(x) + 1 1399 1400 self.assertAllEqual(g(constant_op.constant(2.0)), 5.0) 1401 1402 def testDict(self): 1403 1404 @def_function.function 1405 def f(x): 1406 return {'name': x + 1} 1407 1408 self.assertAllEqual(f(constant_op.constant(1.0))['name'], 2.0) 1409 1410 def testTensorConversionWithDefun(self): 1411 1412 @def_function.function 1413 def f(x): 1414 return math_ops.add(x, constant_op.constant(3)) 1415 1416 self.assertAllEqual(5, f(constant_op.constant(2))) 1417 1418 def testTensorConversionCall(self): 1419 1420 @def_function.function 1421 def f(x): 1422 return math_ops.add(x, constant_op.constant(3)) 1423 1424 @def_function.function 1425 def g(x): 1426 return f(f(x)) 1427 1428 self.assertAllEqual(8, g(constant_op.constant(2))) 1429 1430 def testCallShape(self): 1431 1432 @def_function.function 1433 def f(x): 1434 return x + 1 1435 1436 @def_function.function 1437 def g(x): 1438 x = f(x) 1439 self.assertEqual(x.shape.as_list(), []) 1440 return None 1441 1442 g(constant_op.constant(1.0)) 1443 1444 def testNestedDefunWithNoOutputAndTapedInput(self): 1445 three = resource_variable_ops.ResourceVariable(3.0, name='v') 1446 1447 @def_function.function 1448 def f(x): 1449 # This function intentionally takes a taped variable as input, 1450 # but does not return any values 1451 math_ops.add(x, three) 1452 1453 @def_function.function 1454 def g(x): 1455 y = math_ops.add(x, three) 1456 f(y) 1457 1458 g(three) 1459 1460 def testGatherResourceWithDefun(self): 1461 with ops.device('cpu:0'): 1462 v = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0]) 1463 1464 def sum_gather(): 1465 return math_ops.reduce_sum(array_ops.gather(v, [1, 2])) 1466 1467 defined = def_function.function(sum_gather) 1468 self.assertAllEqual(sum_gather(), defined()) 1469 1470 @parameterized.named_parameters([ 1471 ('IndexedSlicesWithDenseShape', 1472 _example_indexed_slices_with_dense_shape,), 1473 ('IndexedSlicesWithoutDenseShape', 1474 _example_indexed_slices_without_dense_shape,), 1475 ('RaggedTensorRaggedRank1', ragged_tensor.RaggedTensor.from_row_lengths, 1476 {'values': [1, 2, 3], 'row_lengths': [2, 0, 1]}), 1477 ('RaggedTensorRaggedRank2', 1478 ragged_tensor.RaggedTensor.from_nested_row_lengths, 1479 {'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}), 1480 ('SparseTensor', sparse_tensor.SparseTensor, 1481 {'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}), 1482 ]) # 
pyformat: disable 1483 def testReturnCompositeTensorWithDefun(self, 1484 factory_fn, 1485 factory_kwargs={}, 1486 input_signature=None): 1487 input_ct = factory_fn(**factory_kwargs) 1488 1489 @def_function.function(input_signature=input_signature) 1490 def f(): 1491 return input_ct 1492 1493 output_ct = f() 1494 self.assertIsInstance(output_ct, type(input_ct)) 1495 nest.assert_same_structure(input_ct, output_ct, expand_composites=True) 1496 1497 input_flat = nest.flatten(input_ct, expand_composites=True) 1498 output_flat = nest.flatten(output_ct, expand_composites=True) 1499 for (input_component, output_component) in zip(input_flat, output_flat): 1500 self.assertAllEqual(input_component, output_component) 1501 1502 @parameterized.named_parameters([ 1503 ('IndexedSlicesWithDenseShape', 1504 _example_indexed_slices_with_dense_shape,), 1505 ('IndexedSlicesWithoutDenseShape', 1506 _example_indexed_slices_without_dense_shape,), 1507 ('RaggedTensorRaggedRank1', 1508 ragged_tensor.RaggedTensor.from_row_lengths, 1509 {'values': [1, 2, 3], 'row_lengths': [2, 0, 1]}), 1510 ('RaggedTensorRaggedRank2', 1511 ragged_tensor.RaggedTensor.from_nested_row_lengths, 1512 {'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}), 1513 ('SparseTensor', 1514 sparse_tensor.SparseTensor, 1515 {'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}), 1516 ('RaggedTensorRaggedRank1WithSignature', 1517 ragged_tensor.RaggedTensor.from_row_lengths, 1518 {'values': [1, 2, 3], 'row_lengths': [2, 0, 1]}, 1519 [ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)]), 1520 ('RaggedTensorRaggedRank2WithSignature', 1521 ragged_tensor.RaggedTensor.from_nested_row_lengths, 1522 {'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}, 1523 [ragged_tensor.RaggedTensorSpec([None, None, None], dtypes.int32)]), 1524 ('SparseTensorWithSignature', 1525 sparse_tensor.SparseTensor, 1526 {'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}, 1527 [sparse_tensor.SparseTensorSpec([None], dtypes.int32)]), 1528 ]) # pyformat: disable 1529 def testCompositeAsArgumentTensorWithDefun(self, 1530 factory_fn, 1531 factory_kwargs={}, 1532 input_signature=None): 1533 input_ct = factory_fn(**factory_kwargs) 1534 1535 @def_function.function(input_signature=input_signature) 1536 def f(x): 1537 return x 1538 1539 output_ct = f(input_ct) 1540 self.assertIsInstance(output_ct, type(input_ct)) 1541 nest.assert_same_structure(input_ct, output_ct, expand_composites=True) 1542 1543 input_flat = nest.flatten(input_ct, expand_composites=True) 1544 output_flat = nest.flatten(output_ct, expand_composites=True) 1545 for (input_component, output_component) in zip(input_flat, output_flat): 1546 self.assertAllEqual(input_component, output_component) 1547 1548 def testTracedCompositeDiscardsShapeInfo(self): 1549 # SparseTensorSpec intentionally excludes info about the number of elements 1550 # that are in a sparse tensor (which is recorded as st.indices.shape[0] and 1551 # st.values.shape[0]). Similarly, RaggedTensorSpec intentionally excludes 1552 # info about the total number of values in a RaggedTensor (stored as 1553 # rt.values.shape[0]). This test checks that the placeholders created by 1554 # tf.function() properly mask this shape info. 
1555 @def_function.function 1556 def f(rt, st): 1557 self.assertEqual(st.indices.shape.as_list()[:1], [None]) 1558 self.assertEqual(st.values.shape.as_list(), [None]) 1559 return (rt, st) 1560 1561 rt = ragged_factory_ops.constant([[1, 2], [3]]) 1562 st = sparse_tensor.SparseTensor([[0]], [0], [10]) 1563 f(rt, st) 1564 1565 @test_util.run_gpu_only 1566 def testFunctionOnDevice(self): 1567 x = constant_op.constant([1.]).gpu() 1568 f = def_function.function(math_ops.add) 1569 y = f(x, x).cpu() 1570 self.assertAllEqual(y, [2.]) 1571 1572 @test_util.run_gpu_only 1573 @test_util.run_in_graph_and_eager_modes 1574 def testFunctionWithResourcesOnDifferentDevices(self): 1575 with ops.device('/cpu:0'): 1576 v_cpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0]) 1577 1578 with ops.device('/gpu:0'): 1579 v_gpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0]) 1580 1581 def sum_gather(): 1582 cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu, [1, 2])) 1583 gpu_result = math_ops.reduce_sum(array_ops.gather(v_gpu, [1, 2])) 1584 return cpu_result, gpu_result 1585 1586 defined = function.defun(sum_gather) 1587 if not context.executing_eagerly(): 1588 self.evaluate(variables.global_variables_initializer()) 1589 expected = self.evaluate(sum_gather()) 1590 self.assertAllEqual(expected, self.evaluate(defined())) 1591 1592 @test_util.run_gpu_only 1593 @test_util.run_in_graph_and_eager_modes 1594 def testOpInFunctionWithConflictingResourceInputs(self): 1595 with ops.device('/cpu:0'): 1596 v_cpu = resource_variable_ops.ResourceVariable( 1597 [0.0, 1.0, 2.0], name='cpu') 1598 v_also_cpu = resource_variable_ops.ResourceVariable( 1599 [0.0, 1.0, 2.0], name='also_cpu') 1600 1601 with ops.device('/gpu:0'): 1602 v_gpu = resource_variable_ops.ResourceVariable( 1603 [0.0, 1.0, 2.0], name='gpu') 1604 1605 @def_function.function 1606 def resource_apply_adam(): 1607 training_ops.resource_apply_adam( 1608 v_cpu.handle, 1609 v_gpu.handle, 1610 v_also_cpu.handle, 1611 1.0, # beta1_power 1612 1.0, # beta2_power 1613 1.0, # learning_rate 1614 1.0, # beta1 1615 1.0, # beta2 1616 1.0, # epsilon, 1617 [1.0, 1.0, 1.0], # grad 1618 False) # use_locking 1619 return None 1620 1621 with self.assertRaisesRegex( 1622 errors.InvalidArgumentError, 1623 'Cannot place the graph because a reference or resource edge connects ' 1624 'colocation groups with incompatible assigned devices'): 1625 if not context.executing_eagerly(): 1626 self.evaluate(variables.global_variables_initializer()) 1627 self.evaluate(resource_apply_adam()) 1628 1629 @test_util.run_gpu_only 1630 def testFunctionHandlesInputsOnDifferentDevices(self): 1631 # The Reshape op requires the shape tensor to be placed in host memory. 1632 reshape = def_function.function(array_ops.reshape) 1633 value = constant_op.constant([1., 2.]).gpu() 1634 shape = constant_op.constant([2, 1]) 1635 reshaped = reshape(value, shape).cpu() 1636 self.assertAllEqual(reshaped, [[1], [2]]) 1637 1638 @test_util.run_gpu_only 1639 def testFunctionHandlesInputsPlacedOnTheWrongDeviceGracefully(self): 1640 # The Reshape op requires the shape tensor to be placed in host memory. 
    reshape = def_function.function(array_ops.reshape)
    value = constant_op.constant([1., 2.])
    shape = constant_op.constant([2, 1]).gpu()
    reshape(value, shape)  # No error is raised

  def testNoneOutput(self):

    @def_function.function
    def my_function(_):
      return None

    self.assertAllEqual(my_function(1), None)

  def testNestedFunctions(self):
    # TensorFlow function (which is what would be used in TensorFlow graph
    # construction).
    @tf_function.Defun(dtypes.int32, dtypes.int32)
    def add(a, b):
      return math_ops.add(a, b)

    @def_function.function
    def add_one(x):
      return add(x, 1)

    self.assertAllEqual(3, add_one(constant_op.constant(2)))

  def testVariableCaptureInNestedFunctions(self):
    v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.int32)

    @def_function.function
    def inner_read():
      return v.read_value()

    @def_function.function
    def outer():
      return inner_read()

    self.assertEqual(1, int(outer()))

  def testReturnCapturedEagerTensor(self):
    t = constant_op.constant(1)

    @def_function.function
    def read():
      return t

    self.assertEqual(1, int(read()))

  def testReturnCapturedGraphTensor(self):
    with context.graph_mode(), self.cached_session():
      t = constant_op.constant(1)

      @def_function.function
      def read():
        return t

      self.assertEqual(1, int(self.evaluate(read())))

  def testSequenceInputs(self):
    clip_by_global_norm = def_function.function(clip_ops.clip_by_global_norm)
    t_list = [constant_op.constant(1.0), constant_op.constant(2.0)]
    clipped_list, global_norm = clip_by_global_norm(t_list,
                                                    constant_op.constant(.2))
    for t in clipped_list:
      self.assertIsInstance(t, ops.Tensor)
    self.assertIsInstance(global_norm, ops.Tensor)

  def testNestedSequenceInputs(self):

    def my_op(inputs):
      a, b, c = inputs
      e, f = b
      g, h = e
      return [a + a, [tuple([f + f, g + g]), h + h], c + c], a + f + g + h + c

    my_eager_op = def_function.function(my_op)
    ret = my_eager_op([
        constant_op.constant(1), [(constant_op.constant(2),
                                   constant_op.constant(3)),
                                  constant_op.constant(4)],
        constant_op.constant(5)
    ])
    self.assertLen(ret, 2)
    self.assertAllEqual(ret[0][0], 2)
    self.assertAllEqual(ret[0][1][0][0], 8)
    self.assertAllEqual(ret[0][1][0][1], 4)
    self.assertIsInstance(ret[0][1][0], tuple)
    self.assertAllEqual(ret[0][1][1], 6)
    self.assertAllEqual(ret[0][2], 10)
    self.assertAllEqual(ret[1], 15)

  def testVariableNamesRespectNameScopesWithDefun(self):
    @def_function.function
    def create_variable():
      with ops.name_scope('foo', skip_on_eager=False):
        v = resource_variable_ops.ResourceVariable(0.0, name='bar')
      self.assertEqual(v.name, 'foo/bar:0')

    create_variable()

  def testVariableNamesRespectNameScopesWithDefunInGraph(self):
    with context.graph_mode():
      @def_function.function
      def create_variable():
        with ops.name_scope('foo', skip_on_eager=False):
          v = resource_variable_ops.ResourceVariable([1.0, 2.0], name='bar')
        self.assertEqual(v.name, 'foo/bar:0')

      with ops.get_default_graph().as_default():
        create_variable()

  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
  def testLayerInDefun(self):
    conv = convolutional.Conv2D(
        filters=1,
        kernel_size=2,
        kernel_initializer=init_ops.ones_initializer(),
        bias_initializer=init_ops.zeros_initializer())

    @function.defun
    def model(x):
      return conv(x)

    x = array_ops.ones([1, 2, 2, 1])
    y = model(x)

    if not context.executing_eagerly():
      self.evaluate(variables.global_variables_initializer())

    self.assertAllClose([[[[4.0]]]], self.evaluate(y))

  # Variable lifting is somewhat different between defun/tf.function, so
  # testing device placement on both makes sense.
  @parameterized.named_parameters(
      dict(testcase_name='Defun',
           function_decorator=function.defun),
      dict(testcase_name='DefFunction',
           function_decorator=def_function.function))
  @test_util.run_in_graph_and_eager_modes
  def testVariablesPlacedOnOutsideDevice(self, function_decorator):

    class _Obj(object):

      def __init__(self):
        self.v = None

      @function_decorator
      def f(self):
        if self.v is None:
          self.v = variables.Variable(1.)
        return self.v + 1.

    has_device = _Obj()
    with ops.device('cpu:0'):
      has_device.f()
    self.assertIn('CPU', has_device.v.device)

  @test_util.run_in_graph_and_eager_modes
  def testMultipleDeviceCheck(self):

    def f():
      with ops.device('cpu'):
        return test_ops.device_placement_op()

    func = function.defun(f)
    with ops.device('cpu:0'):
      output = self.evaluate(func())
      self.assertIn(compat.as_bytes('CPU:0'), output)

  @test_util.run_in_graph_and_eager_modes
  def testDeviceAnnotationsRespected(self):

    def multi_device_fn():
      with ops.device('/cpu:0'):
        s0 = test_ops.device_placement_op()
      with ops.device('/cpu:1'):
        s1 = test_ops.device_placement_op()
      with ops.device('/cpu:2'):
        s2 = test_ops.device_placement_op()
      s3 = test_ops.device_placement_op()
      return s0, s1, s2, s3

    defined = function.defun(multi_device_fn)
    outputs = self.evaluate(defined())
    self.assertLen(total_function_cache(defined), 1)
    self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
    self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
    self.assertIn(compat.as_bytes('CPU:2'), outputs[2])

    with ops.device('/cpu:3'):
      outputs = self.evaluate(defined())
    # All function definitions are agnostic to call site devices.
    self.assertLen(total_function_cache(defined), 1)
    self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
    self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
    self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
    self.assertIn(compat.as_bytes('CPU:3'), outputs[3])

    with ops.device('/cpu:0'):
      outputs = self.evaluate(defined())
    self.assertLen(total_function_cache(defined), 1)
    self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
    self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
    self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
    self.assertIn(compat.as_bytes('CPU:0'), outputs[3])

  @test_util.run_in_graph_and_eager_modes
  def testCallingGraphFunctionOnDifferentDevice(self):

    def func():
      return constant_op.constant(0)

    defined = def_function.function(func)
    with ops.device('cpu:0'):
      cpu_graph_function = defined.get_concrete_function()

    with ops.device('cpu:0'):
      self.assertEqual(
          self.evaluate(cpu_graph_function()), self.evaluate(func()))

    with ops.device('cpu:1'):
      self.assertEqual(0., self.evaluate(cpu_graph_function()))

    with ops.device(None):
      self.assertEqual(0., self.evaluate(cpu_graph_function()))

    default_graph_function = defined.get_concrete_function()
    self.assertEqual(
        self.evaluate(default_graph_function()), self.evaluate(func()))

    with ops.device('cpu:1'):
      self.assertEqual(0., self.evaluate(default_graph_function()))

  @test_util.run_gpu_only
  @test_util.run_in_graph_and_eager_modes
  def testColocateWithRespected(self):
    # TODO(b/113291792): Use multiple CPUs instead of a GPU.
    with ops.device('cpu:0'):
      x = array_ops.identity(1.0)

    with ops.device('gpu:0'):
      y = array_ops.identity(1.0)

    @def_function.function
    def foo():
      return test_ops.device_placement_op()

    with ops.colocate_with(x):
      self.assertIn(compat.as_bytes('CPU:0'), self.evaluate(foo()))

    with ops.colocate_with(y):
      self.assertIn(compat.as_bytes('GPU:0'), self.evaluate(foo()))

  def testVariablesAreTracked(self):
    v = resource_variable_ops.ResourceVariable(1.0)

    def foo(x):
      return v * x

    defined = def_function.function(foo)

    x = constant_op.constant([1.0])
    self.assertEqual(1., self.evaluate(defined(x)))
    v.assign(2.)

    x = constant_op.constant([1.0, 2.0])
    self.assertAllEqual([2., 4.], self.evaluate(defined(x)))

  def testCacheObjectHashCollisions(self):

    class Foo(object):

      def __hash__(self):
        return 42

    def func(foo):
      del foo
      return

    defined = function.defun(func)
    defined(Foo())
    self.assertLen(total_function_cache(defined), 1)

    defined(Foo())
    self.assertLen(total_function_cache(defined), 2)

  def testCacheTensorDtypeCollision(self):

    def func(t):
      return t + t

    defined = function.defun(func)
    t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
    defined(t)
    self.assertLen(total_function_cache(defined), 1)

    t = constant_op.constant([[1.0]], dtype=dtypes.complex128)
    defined(t)
    self.assertLen(total_function_cache(defined), 2)

  def testCacheTensorShapeCollision(self):

    def func(t):
      return t + t

    defined = function.defun(func)
    t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
    defined(t)
    self.assertLen(total_function_cache(defined), 1)

    t = constant_op.constant([1.0], dtype=dtypes.complex64)
    defined(t)
    self.assertLen(total_function_cache(defined), 2)

  def testCacheTensorShapeDtypeCollision(self):

    def func(t):
      return t + t

    defined = function.defun(func)
    t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
    defined(t)
    self.assertLen(total_function_cache(defined), 1)

    t = constant_op.constant([1.0], dtype=dtypes.complex128)
    defined(t)
    self.assertLen(total_function_cache(defined), 2)

  def testCacheTensorUnknownShapesCollisionRelaxedShapes(self):

    def func(t):
      return t + t

    with context.graph_mode(), self.cached_session():
      defined = function.defun(func, experimental_relax_shapes=True)

      p = array_ops.placeholder(dtype=dtypes.float32, shape=[])
      defined(p)
      self.assertLen(total_function_cache(defined), 1)

      p = array_ops.placeholder(dtype=dtypes.float32, shape=[1])
      defined(p)
      self.assertLen(total_function_cache(defined), 2)

      p = array_ops.placeholder(dtype=dtypes.float32, shape=[2])
      defined(p)
      # Gradual shape relaxation is performed, and the common shape between
      # [1] and [2] is one containing unknown dimensions.
      self.assertLen(total_function_cache(defined), 2)

      # pylint: disable=protected-access
      self.assertLen(defined._function_cache.arg_relaxed_specs, 1)
      relaxed_specs = (
          list(defined._function_cache.arg_relaxed_specs.values())[0])
      self.assertLen(relaxed_specs, 1)
      relaxed_shape = relaxed_specs[0].shape
      # pylint: enable=protected-access
      self.assertEqual(relaxed_shape.rank, 1)
      self.assertEqual(tensor_shape.dimension_value(relaxed_shape[0]), None)

      t = constant_op.constant([1.0, 1.0, 1.0], dtype=dtypes.float32)
      defined(t)
      # Shape (3,) matches the relaxed shape TensorShape([None])
      self.assertLen(total_function_cache(defined), 2)
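
  # --- Added illustrative sketch (not part of the original test suite). ---
  # A hedged, user-facing restatement of the relaxation behavior asserted
  # above: with `experimental_relax_shapes=True`, after a couple of calls
  # with new rank-1 shapes the function settles on a shape-relaxed trace, so
  # further rank-1 shapes stop triggering retraces. The helper name is ours.
  def _shape_relaxation_sketch(self):
    trace_count = [0]

    @def_function.function(experimental_relax_shapes=True)
    def identity(x):
      trace_count[0] += 1  # incremented at trace time only
      return x

    for n in range(1, 6):
      identity(array_ops.ones([n]))
    # Expected to stay well below 5: later calls reuse the relaxed trace.
    return trace_count[0]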

  def testPythonFunctionWithDefaultArgs(self):

    def func(foo, bar=1, baz=2):
      del foo
      del bar
      del baz
      return

    defined = function.defun(func)
    defined(0, baz=20)

    def cache_keys():
      """Sanitizes cache keys of non-input metadata."""
      return tuple(key[0] for key in total_function_cache(defined))

    # `True` corresponds to the fact that we're executing eagerly
    self.assertIn(('URRRu', (0, 1, 20)), cache_keys())

    defined(1)  # bar=1, baz=2
    self.assertIn(('URRRu', (1, 1, 2)), cache_keys())

    # This matches the previous call.
    defined(foo=1)
    self.assertLen(total_function_cache(defined), 2)

    defined(1, 2, 3)
    self.assertLen(total_function_cache(defined), 3)
    self.assertIn(('URRRu', (1, 2, 3)), cache_keys())

    # This matches the previous call.
    defined(1, bar=2, baz=3)
    self.assertLen(total_function_cache(defined), 3)

    # This matches the previous call.
    defined(1, baz=3, bar=2)
    self.assertLen(total_function_cache(defined), 3)

  def testFunctoolsPartialUnwrappedCorrectly(self):

    def full_function(a, b, c=3):
      return a, b, c

    partial = functools.partial(full_function, 1, c=4)
    a, b, c = partial(2)

    defined = function.defun(partial)
    func_a, func_b, func_c = defined(2)
    self.assertEqual(func_a.numpy(), a)
    self.assertEqual(func_b.numpy(), b)
    self.assertEqual(func_c.numpy(), c)

  def testInputSignatureWithMatchingInputs(self):

    def foo(a):
      self.assertEqual(a.shape, (2,))
      return a

    signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)]
    defined = function.defun(foo, input_signature=signature)
    a = array_ops.ones([2])
    self.assertAllEqual(a, defined(a))
    self.assertLen(total_function_cache(defined), 1)
    self.assertAllEqual(a, defined.get_concrete_function()(a))
    self.assertAllEqual(a, defined.get_concrete_function(a)(a))
    self.assertAllEqual(a, defined.get_concrete_function(
        tensor_spec.TensorSpec((2,), dtype=dtypes.float32))(a))
    self.assertLen(total_function_cache(defined), 1)

    def bar(a):
      self.assertEqual(a._shape_tuple(), (2, None))
      return a

    signature = [tensor_spec.TensorSpec((2, None), dtypes.float32)]
    defined = function.defun(bar, input_signature=signature)
    a = array_ops.ones([2, 1])
    out = defined(a)
    self.assertLen(total_function_cache(defined), 1)
    self.assertAllEqual(out, a)

    # Changing the second dimension shouldn't create a new function.
    b = array_ops.ones([2, 3])
    out = defined(b)
    self.assertLen(total_function_cache(defined), 1)
    self.assertAllEqual(out, b)

  def testInputSignatureWithCompatibleInputs(self):

    rank2_spec = tensor_spec.TensorSpec(shape=(None, None),
                                        dtype=dtypes.float32)

    @function.defun(input_signature=[rank2_spec])
    def func(a):
      self.assertEqual([None, None], a.shape.as_list())
      return array_ops.shape(a)

    self.assertAllEqual([3, 1], func([[0], [1.0], [1]]))
    self.assertAllEqual([2, 2], func(numpy.array([[1, 1], [2, 2]])))

    with self.assertRaisesRegex(ValueError, 'incompatible'):
      func([0.0, 1.0, 2.0])  # Wrong shape.

    with self.assertRaisesRegex(ValueError, 'incompatible'):
      func([['wrong dtype']])
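
  # --- Added illustrative sketch (not part of the original test suite). ---
  # A hedged summary of the input_signature behavior exercised above: a
  # signature pins exactly one trace, compatible calls (including plain
  # Python lists) reuse it, and incompatible calls raise instead of
  # retracing. The helper name is ours.
  def _input_signature_sketch(self):
    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=[None], dtype=dtypes.float32)
    ])
    def double(x):
      return x * 2.0

    double(constant_op.constant([1.0]))       # traces once
    double(constant_op.constant([1.0, 2.0]))  # compatible: reuses the trace
    return double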

  def testNoKeywordOnlyArgumentsWithInputSignature(self):
    if sys.version_info[0] < 3:
      self.skipTest('keyword_only arguments only exist in Python 3.')

    func = eval('lambda x, *, y: x')  # pylint: disable=eval-used
    signature = [tensor_spec.TensorSpec(None, dtypes.int32)]
    with self.assertRaisesRegex(
        ValueError, 'Cannot define a TensorFlow function from a Python '
        'function with keyword-only arguments when input_signature is '
        'provided.'):
      def_function.function(func, signature)

  def testNestedInputSignatures(self):

    def expected_foo(a, b):
      return [a, b]

    @function.defun(input_signature=[
        [tensor_spec.TensorSpec((2, None), dtypes.float32)] * 2,
        tensor_spec.TensorSpec((1,), dtypes.float32),
    ])
    def foo(a, b):
      self.assertEqual(a[0]._shape_tuple(), (2, None))
      self.assertEqual(a[1]._shape_tuple(), (2, None))
      self.assertEqual(b._shape_tuple(), (1,))
      return [a, b]

    a = array_ops.ones([2, 1])
    b = array_ops.ones([1])
    expected = expected_foo([a, a], b)
    out = foo([a, a], b)
    self.assertLen(total_function_cache(foo), 1)
    nest.assert_same_structure(out, expected)
    self.assertAllEqual(out[0][0], a)
    self.assertAllEqual(out[0][1], a)
    self.assertAllEqual(out[1], b)

    # Changing the unspecified dimensions shouldn't create a new function.
    a = array_ops.ones([2, 3])
    b = array_ops.ones([2, 5])
    c = array_ops.ones([1])
    expected = expected_foo([a, b], c)
    out = foo([a, b], c)
    self.assertLen(total_function_cache(foo), 1)
    nest.assert_same_structure(out, expected)
    self.assertAllEqual(out[0][0], a)
    self.assertAllEqual(out[0][1], b)
    self.assertAllEqual(out[1], c)

    # Passing compatible inputs should work.
    a = a.numpy().tolist()
    b = b.numpy().tolist()
    c = c.numpy().tolist()
    out = foo([a, b], c)
    self.assertLen(total_function_cache(foo), 1)
    nest.assert_same_structure(out, expected)
    self.assertAllEqual(out[0][0], a)
    self.assertAllEqual(out[0][1], b)
    self.assertAllEqual(out[1], c)

  def testNestedInputSignaturesWithDict(self):
    def expected_bar(a):
      return a

    @function.defun(input_signature=[{
        'a': tensor_spec.TensorSpec((2, None), dtypes.float32),
        'b': tensor_spec.TensorSpec((2, None), dtypes.float32),
        'c': tensor_spec.TensorSpec((1,), dtypes.float32)}])
    def bar(a):
      self.assertEqual(a['a']._shape_tuple(), (2, None))
      self.assertEqual(a['b']._shape_tuple(), (2, None))
      self.assertEqual(a['c']._shape_tuple(), (1,))
      return a

    a = array_ops.ones([2, 3])
    b = array_ops.ones([1])
    inputs = {'a': a, 'b': a, 'c': b}
    expected = expected_bar(inputs)
    out = bar(inputs)
    nest.assert_same_structure(out, expected)
    self.assertAllEqual(out['a'], expected['a'])
    self.assertAllEqual(out['b'], expected['b'])
    self.assertAllEqual(out['c'], expected['c'])

    # Passing compatible inputs should work.
    a = a.numpy().tolist()
    b = b.numpy().tolist()
    inputs = {'a': a, 'b': a, 'c': b}
    out = bar(inputs)
    nest.assert_same_structure(out, expected)
    self.assertAllEqual(out['a'], expected['a'])
    self.assertAllEqual(out['b'], expected['b'])
    self.assertAllEqual(out['c'], expected['c'])

  def testInputSignatureMustBeSequenceOfTensorSpecs(self):

    def foo(a, b):
      del a
      del b

    # Signatures must consist exclusively of `TensorSpec` objects.
    signature = [(2, 3), tensor_spec.TensorSpec([2, 3], dtypes.float32)]
    with self.assertRaisesRegex(TypeError, 'Invalid input_signature.*'):
      def_function.function(foo, input_signature=signature)

    # Signatures must be either lists or tuples on their outermost levels.
    signature = {'t1': tensor_spec.TensorSpec([], dtypes.float32)}
    with self.assertRaisesRegex(
        TypeError, 'input_signature must be either a '
        'tuple or a list.*'):
      function.defun(foo, input_signature=signature)

  @test_util.run_in_graph_and_eager_modes
  def testInputsIncompatibleWithSignatureRaisesError(self):

    def foo(a):
      return a

    signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)]
    defined = def_function.function(foo, input_signature=signature)

    # Invalid shapes.
    with self.assertRaisesRegex(ValueError, 'Python inputs incompatible.*'):
      defined(array_ops.ones([3]))

    with self.assertRaisesRegex(ValueError, 'Python inputs incompatible.*'):
      defined(array_ops.ones([2, 1]))

    # Wrong number of arguments.
    with self.assertRaisesRegex(
        TypeError, r'takes 1 positional arguments \(as specified by the '
        r'input_signature\) but 2 were given'):
      defined(array_ops.ones([2]), array_ops.ones([2]))
    with self.assertRaisesRegex(ValueError,
                                'Structure of Python function inputs.*'):
      defined()

    with self.assertRaisesRegex(ValueError,
                                'inputs incompatible with input_signature'):
      defined.get_concrete_function(
          tensor_spec.TensorSpec(shape=(3,), dtype=dtypes.float32))

  def testInputsIncompatibleWithNestedSignatureRaisesError(self):

    def foo(a, b):
      return [a, b]

    signature = [[tensor_spec.TensorSpec((1,), dtypes.float32)] * 2,
                 [tensor_spec.TensorSpec((1,), dtypes.float32)] * 2]
    defined = function.defun(foo, input_signature=signature)
    a = array_ops.ones([1])

    with self.assertRaisesRegex(ValueError,
                                'Structure of Python function inputs.*'):
      defined([a, a, a], [a])

    with self.assertRaisesRegex(ValueError,
                                'Structure of Python function inputs.*'):
      defined([a], [a, a, a])
    defined([a, a], [a, a])

  def testUnderspecifiedInputSignature(self):
    @function.defun(input_signature=[
        tensor_spec.TensorSpec([], dtypes.float32),
    ])
    def foo(a, training=True):
      if training:
        return a
      else:
        return -1.0 * a

    x = constant_op.constant(1.0)
    with self.assertRaisesRegex(
        TypeError, 'got keyword argument `training` '
        'that was not included in input_signature'):
      foo(x, training=True)

    with self.assertRaisesRegex(
        TypeError, 'got keyword argument `training` '
        'that was not included in input_signature'):
      foo(x, training=False)

    self.assertAllEqual(x.numpy(), foo(x).numpy())

  def testInputSignatureWithPartialFunction(self):
    def full_function(a, b, c=3.0):
      return a, b, c

    partial = functools.partial(full_function, 1, c=4)
    a, b, c = partial(2.0)
    signature = [tensor_spec.TensorSpec([], dtypes.float32)]
    defined = function.defun(partial, input_signature=signature)
    x = constant_op.constant(2.0)
    func_a, func_b, func_c = defined(x)
    self.assertEqual(func_a.numpy(), a)
    self.assertEqual(func_b.numpy(), b)
    self.assertEqual(func_c.numpy(), c)

  def testInputSignatureConversionWithDefaultArg(self):

    def foo(a, training=True):
      if training:
        return a
      else:
        return -1.0 * a

    signature = [
        tensor_spec.TensorSpec([], dtypes.float32),
        tensor_spec.TensorSpec([], dtypes.bool),
    ]
    defined = def_function.function(foo, input_signature=signature)
    a = constant_op.constant(1.0)
    self.assertAllEqual(a.numpy(), defined(a))
    self.assertAllEqual(a.numpy(), defined(a, training=True))
    self.assertAllEqual(-a.numpy(), defined(a, training=False))

  def testInputSignatureWithKeywordPositionalArgs(self):

    @function.defun(input_signature=[
        tensor_spec.TensorSpec([], dtypes.float32),
        tensor_spec.TensorSpec([], dtypes.int64)
    ])
    def foo(flt, integer):
      return flt, integer

    flt = constant_op.constant(1.0)
    integer = constant_op.constant(2, dtypes.int64)

    out1, out2 = foo(flt, integer)
    self.assertLen(total_function_cache(foo), 1)
    self.assertEqual(out1.numpy(), 1.0)
    self.assertEqual(out2.numpy(), 2)

    out1, out2 = foo(flt=flt, integer=integer)
    self.assertLen(total_function_cache(foo), 1)
    self.assertEqual(out1.numpy(), 1.0)
    self.assertEqual(out2.numpy(), 2)

    out1, out2 = foo(integer=integer, flt=flt)
    self.assertLen(total_function_cache(foo), 1)
    self.assertEqual(out1.numpy(), 1.0)
    self.assertEqual(out2.numpy(), 2)

    out1, out2 = foo(flt, integer=integer)
    self.assertLen(total_function_cache(foo), 1)
    self.assertEqual(out1.numpy(), 1.0)
    self.assertEqual(out2.numpy(), 2)

  def testInputSignatureWithKeywordArgs(self):
    def foo(a, b, **kwargs):
      del kwargs
      return a, b

    x = function.defun(
        foo,
        input_signature=[
            tensor_spec.TensorSpec([], dtypes.float32),
            tensor_spec.TensorSpec([], dtypes.int32)
        ]).get_concrete_function()
    result = x(constant_op.constant(5.0), constant_op.constant(5))
    self.assertAllEqual(result, [5.0, 5])

  def testInputSignatureWithCompositeTensors(self):
    def f(rt):
      self.assertEqual(rt.values.shape.as_list(), [None])
      self.assertEqual(rt.row_splits.shape.as_list(), [4])
      return rt

    signature = [ragged_tensor.RaggedTensorSpec(
        shape=[3, None], dtype=dtypes.int32)]
    defined = function.defun(f, input_signature=signature)
    rt1 = ragged_factory_ops.constant([[1], [], [2, 3, 4]])
    out1 = defined(rt1)
    self.assertLen(total_function_cache(defined), 1)
    self.assertAllEqual(out1.values, rt1.values)
    self.assertAllEqual(out1.row_splits, rt1.row_splits)

    # Changing the row lengths shouldn't create a new function.
    rt2 = ragged_factory_ops.constant([[1, 2], [3, 4], [5]])
    out2 = defined(rt2)
    self.assertLen(total_function_cache(defined), 1)
    self.assertAllEqual(out2.values, rt2.values)
    self.assertAllEqual(out2.row_splits, rt2.row_splits)

    # Different number of rows
    rt3 = ragged_factory_ops.constant([[1, 2], [3, 4], [5], [6]])
    with self.assertRaisesRegex(ValueError, 'incompatible'):
      defined(rt3)

    # Different dtype
    rt4 = ragged_factory_ops.constant([[1.0, 2.0], [], [3.0]])
    with self.assertRaisesRegex(ValueError, 'Structure .* does not match'):
      defined(rt4)

    # Different rank
    rt5 = ragged_factory_ops.constant([[[1]], [[2]], [[3]]])
    with self.assertRaisesRegex(ValueError, 'does not match'):
      defined(rt5)

  def testInputSignatureWithVariableArgs(self):

    def f(v):
      v.assign_add(1)

    signature = [
        resource_variable_ops.VariableSpec(shape=[], dtype=dtypes.int32)
    ]
    defined = function.defun(f, input_signature=signature)

    v1 = variables.Variable(0)
    v2 = variables.Variable(0)

    defined(v1)
    self.assertEqual(v1.numpy(), 1)
    self.assertEqual(v2.numpy(), 0)

    defined(v=v2)
    self.assertEqual(v1.numpy(), 1)
    self.assertEqual(v2.numpy(), 1)

  def testTensorKeywordArguments(self):

    def foo(a, b):
      del a
      return b

    defined = function.defun(foo)
    a = constant_op.constant(2.0)
    b = constant_op.constant([1.0, 2.0])
    one = defined(a, b)
    self.assertLen(total_function_cache(defined), 1)

    two = defined(a=a, b=b)
    self.assertLen(total_function_cache(defined), 1)

    three = defined(b=b, a=a)
    self.assertLen(total_function_cache(defined), 1)

    four = defined(a, b=b)
    self.assertLen(total_function_cache(defined), 1)

    # The next call corresponds to a new input signature, hence
    # we expect another function to be defined.
    five = defined(b, a)
    self.assertLen(total_function_cache(defined), 2)

    six = defined(a=b, b=a)
    self.assertLen(total_function_cache(defined), 2)

    seven = defined(b=a, a=b)
    self.assertLen(total_function_cache(defined), 2)

    self.assertAllEqual(one, [1.0, 2.0])
    self.assertAllEqual(two, [1.0, 2.0])
    self.assertAllEqual(three, [1.0, 2.0])
    self.assertAllEqual(four, [1.0, 2.0])
    self.assertAllEqual(five, 2.0)
    self.assertAllEqual(six, 2.0)
    self.assertAllEqual(seven, 2.0)

  def testDefuningInstanceMethod(self):

    integer = constant_op.constant(2, dtypes.int64)

    class Foo(object):

      def one(self, tensor):
        return tensor

      @def_function.function
      def two(self, tensor, other=integer):
        return self.one(tensor), other

    foo = Foo()
    t = constant_op.constant(1.0)
    one, two = foo.two(t)
    self.assertEqual(one.numpy(), 1.0)
    self.assertEqual(two.numpy(), 2)

  def testDefuningInstanceMethodWithDefaultArgument(self):

    integer = constant_op.constant(2, dtypes.int64)

    class Foo(object):

      @def_function.function
      def func(self, other=integer):
        return other

    foo = Foo()
    self.assertEqual(foo.func().numpy(), int(integer))

  def testPythonCallWithSideEffects(self):
    state = []

    @def_function.function
    def side_effecting_function():
      state.append(0)

    side_effecting_function()
    self.assertAllEqual(state, [0])

    # The second invocation should call the graph function, which shouldn't
    # trigger the list append.
    side_effecting_function()
    self.assertAllEqual(state, [0])

    # Whereas calling the python function directly should create a side-effect.
    side_effecting_function.python_function()
    self.assertAllEqual(state, [0, 0])
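
  # --- Added illustrative sketch (not part of the original test suite). ---
  # A hedged contrast between trace-time and run-time effects, restating the
  # test above: Python side effects run only while tracing, whereas graph ops
  # such as print_v2 execute on every call. The helper name is ours.
  def _side_effect_sketch(self):
    state = []

    @def_function.function
    def traced():
      state.append('traced')            # Python effect: runs once, at trace
      logging_ops.print_v2('executed')  # graph op: runs on every call

    traced()
    traced()
    return state  # expected: ['traced'] -- appended once, printed twice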

  def testFunctionWithNestedFunctionCallAndSideEffects(self):
    v1 = variables.Variable(1.0)
    v2 = variables.Variable(1.0)

    @def_function.function
    def add_one(a):
      a.assign_add(1.0)

    # Grappler will inline calls to `add_one` into the function body; we check
    # that all side-effects were executed.
    @def_function.function
    def side_effecting_function(a, b):
      add_one(a)
      add_one(b)
      return a + b

    result = side_effecting_function(v1, v2)
    self.assertEqual(result.numpy(), 4.0)

  def testFunctionWithExtraAttributes(self):
    @function.defun_with_attributes(attributes={'experimental_1': 'value1',
                                                'experimental_2': 2})
    def matmul(x, y):
      return math_ops.matmul(x, y)

    def add(x, y):
      return math_ops.add(x, y)
    defun_add = function.defun_with_attributes(
        add, attributes={'experimental_3': True, 'experimental_4': 1.0})

    with context.graph_mode(), self.cached_session():
      with ops.get_default_graph().as_default():
        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
        sq = matmul(t, t)
        double = defun_add(t, t)
        self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22])
        self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8])

        graph = ops.get_default_graph()
        # pylint: disable=protected-access
        self.assertLen(graph._functions, 2)
        functions = list(graph._functions.values())
        self.assertRegex(functions[0].definition.signature.name, '.*matmul.*')
        attrs = functions[0].definition.attr
        self.assertLen(attrs, 2)
        self.assertEqual(attrs['experimental_1'].s, b'value1')
        self.assertEqual(attrs['experimental_2'].i, 2)

        self.assertRegex(functions[1].definition.signature.name, '.*add.*')
        attrs = functions[1].definition.attr
        self.assertLen(attrs, 2)
        self.assertEqual(attrs['experimental_3'].b, True)
        self.assertEqual(attrs['experimental_4'].f, 1.0)
        # pylint: enable=protected-access

  def testFunctionWithInvalidAttribute(self):
    @function.defun_with_attributes(attributes={'experimental_1': ['value1']})
    def add(x, y):
      return math_ops.add(x, y)

    with self.assertRaisesRegex(ValueError, '.*Unsupported attribute type.*'):
      with context.graph_mode(), self.cached_session():
        with ops.get_default_graph().as_default():
          t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
          add(t, t)

  def testRegisterFunction(self):

    @function.defun
    def add(x, y):
      return math_ops.add(x, y)

    def matmul(x, y):
      return math_ops.matmul(x, y)
    defun_matmul = function.defun(matmul)

    with context.graph_mode(), self.cached_session():
      with ops.get_default_graph().as_default():
        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
        function.register(defun_matmul, t, t)
        function.register(add, t, t)

        graph = ops.get_default_graph()
        # pylint: disable=protected-access
        self.assertLen(graph._functions, 6)
        # Two sets of functions, each consisting of
        # (inference, forward, backward).
        functions = list(graph._functions.values())
        captured_function_names = [
            f.definition.signature.name for f in functions
        ]
        expected_func_name_regex = [
            '.*inference.*matmul.*',
            '.*forward.*matmul.*',
            '.*inference.*backward.*matmul.*',
            '.*inference.*add.*',
            '.*forward.*add.*',
            '.*inference.*backward.*add.*',
        ]
        for i in range(len(functions)):
          self.assertRegex(captured_function_names[i],
                           expected_func_name_regex[i])

        # Check that the forward and backward functions have the correct
        # attributes.
        self.assertEqual(
            functions[1].definition.attr['backward_function_name'].s,
            functions[2].name)
        self.assertEqual(
            functions[2].definition.attr['forward_function_name'].s,
            functions[1].name)

        self.assertEqual(
            functions[4].definition.attr['backward_function_name'].s,
            functions[5].name)
        self.assertEqual(
            functions[5].definition.attr['forward_function_name'].s,
            functions[4].name)

        sq = defun_matmul(t, t)
        double = add(t, t)
        self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22])
        self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8])
        # Make sure the pre-registered function is used, and no other function
        # is added.
        self.assertLen(graph._functions, 6)
        functions = list(graph._functions.values())
        for i in range(len(functions)):
          self.assertEqual(captured_function_names[i],
                           functions[i].definition.signature.name)

  @parameterized.named_parameters(
      dict(testcase_name='Defun',
           function_decorator=function.defun),
      dict(testcase_name='DefFunction',
           function_decorator=def_function.function))
  def testRegisterConcreteFunction(self, function_decorator):
    @function_decorator
    def py_add(x, y):
      return math_ops.add(x, y)

    py_add(array_ops.ones([]), array_ops.ones([]))
    add = py_add.get_concrete_function(
        tensor_spec.TensorSpec(None, dtypes.float32),
        tensor_spec.TensorSpec(None, dtypes.float32))

    @function_decorator
    def py_composite(x, y):
      return x, add(x, y)

    py_composite(array_ops.ones([]), array_ops.ones([]))
    composite = py_composite.get_concrete_function(
        tensor_spec.TensorSpec(None, dtypes.float32),
        tensor_spec.TensorSpec(None, dtypes.float32))

    with context.graph_mode(), self.cached_session():
      with ops.get_default_graph().as_default():
        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
        composite.add_to_graph()
        composite.add_gradient_functions_to_graph()

        graph = ops.get_default_graph()
        # pylint: disable=protected-access
        self.assertLen(graph._functions, 6)
        # Two sets of functions, each consisting of
        # (inference, forward, backward).
        functions = list(graph._functions.values())
        captured_function_names = [
            f.definition.signature.name for f in functions
        ]
        expected_func_name_regex = [
            '.*inference.*py_composite.*',
            '.*inference.*py_add.*',
            '.*forward.*py_composite.*',
            '.*forward.*py_add.*',
            '.*inference.*backward.*py_composite.*',
            '.*inference.*backward.*py_add.*',
        ]
        for expected, found in zip(
            expected_func_name_regex,
            captured_function_names):
          self.assertRegex(found, expected)

        composite_t, composite_double = composite(t, t)
        double = add(t, t)
        self.assertAllEqual([[2, 4], [6, 8]], self.evaluate(double))
        self.assertAllEqual([[2, 4], [6, 8]], self.evaluate(composite_double))
        self.assertAllEqual([[1, 2], [3, 4]], self.evaluate(composite_t))
        # Make sure the pre-registered function is used, and no other function
        # is added.
        self.assertLen(graph._functions, 6)

  @parameterized.named_parameters(
      dict(testcase_name='Defun',
           function_decorator=function.defun),
      dict(testcase_name='DefFunction',
           function_decorator=def_function.function))
  def testEagerCaptures(self, function_decorator):
    with context.eager_mode():
      large_tensor = array_ops.ones(shape=(256,))
      self.assertGreater(256, func_graph._EAGER_CONST_THRESHOLD)

      small_tensor = array_ops.ones(shape=(4,))
      self.assertLessEqual(4, func_graph._EAGER_CONST_THRESHOLD)

      v = resource_variable_ops.ResourceVariable(0.0)

    for captured, op_type in [(large_tensor, 'Placeholder'),
                              (small_tensor, 'Const'), (v, 'Placeholder')]:
      @function_decorator
      def test_fn():
        return captured + 1  # pylint: disable=cell-var-from-loop

      g = test_fn.get_concrete_function().graph
      internal_captures = g.internal_captures
      self.assertLen(internal_captures, 1)
      self.assertEqual(internal_captures[0].op.type, op_type)

  def testRegisterFunctionWithInputSignature(self):
    def matmul(x, y):
      return math_ops.matmul(x, y)
    defun_matmul = function.defun(
        matmul,
        input_signature=[
            tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32),
            tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32)
        ])
    with context.graph_mode(), self.cached_session():
      with ops.get_default_graph().as_default():
        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
        function.register(defun_matmul, t, t)

        graph = ops.get_default_graph()
        # pylint: disable=protected-access
        self.assertLen(graph._functions, 3)

        # Test registering the function with the cache; note that the inputs
        # are ignored.
        function.register(defun_matmul)
        graph = ops.get_default_graph()
        self.assertLen(graph._functions, 3)

  def testRegisterFunctionWithCache(self):
    def matmul(x, y):
      return math_ops.matmul(x, y)
    defun_matmul = function.defun(matmul)

    with context.graph_mode(), self.cached_session():
      with ops.get_default_graph().as_default():
        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
        t2 = constant_op.constant([[2.0, 3.0], [4.0, 5.0]])
        function.register(defun_matmul, t, t)
        function.register(defun_matmul, t2, t2)

        graph = ops.get_default_graph()
        # Only one function is registered since the input params are of the
        # same type.
        # pylint: disable=protected-access
        self.assertLen(graph._functions, 3)

  def testCallingFunctionWithDifferentVariables(self):

    @function.defun
    def foo(v):
      v.assign_add(1.0)
      return v.read_value()

    v = resource_variable_ops.ResourceVariable(0.0)
    graph_function = foo.get_concrete_function(v)
    self.assertLen(graph_function.inputs, 1)
    self.assertEmpty(graph_function.captured_inputs)

    self.assertEqual(float(graph_function(v)), 1.0)
    self.assertEqual(float(graph_function(v)), 2.0)

    w = resource_variable_ops.ResourceVariable(0.0)

    @function.defun
    def bar(v):
      del v
      return constant_op.constant(1.0)

    graph_function = bar.get_concrete_function(v)
    self.assertEqual(float(graph_function(v)), 1.0)
    self.assertEqual(float(graph_function(w)), 1.0)

  def testCallingFunctionWithNonTensorsFails(self):

    @function.defun
    def foo(x):
      return x

    graph_function = foo.get_concrete_function(constant_op.constant(1.0))
    with self.assertRaises((TypeError, ValueError)):
      graph_function('Not a Tensor.')

  def testSwapImplementationWithGrapplerPlugin(self):
    # Set min_graph_nodes to -1 since the graph in this test is too small and
    # would otherwise be ignored by grappler.
    rewrites = rewriter_config_pb2.RewriterConfig()
    rewrites.implementation_selector = rewriter_config_pb2.RewriterConfig.ON
    rewrites.min_graph_nodes = -1
    graph_options = config_pb2.GraphOptions(
        rewrite_options=rewrites, build_cost_model=1)
    config_proto = config_pb2.ConfigProto(graph_options=graph_options)

    with context.graph_mode(), self.cached_session(
        config=config_proto, graph=ops.Graph(), use_gpu=True):

      @function.defun_with_attributes(
          attributes={
              'api_implements': 'random_boost',
              'api_preferred_device': 'CPU'
          })
      def cpu_boost(x):
        return math_ops.add(x, 2.0)

      @function.defun_with_attributes(
          attributes={
              'api_implements': 'random_boost',
              'api_preferred_device': 'GPU'
          })
      def gpu_boost(x):
        return math_ops.add(x, 4.0)

      x = constant_op.constant(1.0)

      function.register(cpu_boost, x)
      y = gpu_boost(x)
      y_value = self.evaluate(y)

      if test.is_gpu_available():
        self.assertEqual(y_value, 5.0)
      else:
        # Grappler falls back to the CPU implementation even when the GPU
        # function is called.
        self.assertEqual(y_value, 3.0)

  @test_util.disable_tfrt('b/174712583: TFRT doesn\'t support behavior '
                          'equivalent to implementation_selector for function')
  def testSwapImplementationInEager(self):
    if not context.executing_eagerly():
      self.skipTest('eager only')

    # testSharedRendezvous sets the disable_meta_optimizer flag to True. If
    # that subtest runs before this one, having the flag set to True will
    # cause this subtest to fail. To avoid that scenario, explicitly set the
    # disable_meta_optimizer flag to False here.
    context.context().set_optimizer_experimental_options({
        'min_graph_nodes': -1,
        'implementation_selector': True,
        'disable_meta_optimizer': False
    })

    @function.defun_with_attributes(
        attributes={'api_implements': 'foo',
                    'api_preferred_device': 'CPU'})
    def on_cpu(x):
      return x + 2

    @function.defun_with_attributes(
        attributes={'api_implements': 'foo',
                    'api_preferred_device': 'GPU'})
    def on_gpu(x):
      return x + 4

    @function.defun
    def run_on_cpu(t):
      function.register(on_cpu, t)
      with ops.device('CPU:0'):
        return on_gpu(t)

    # Expect the on_cpu branch to run, regardless of whether a GPU is
    # available.
    self.assertEqual(run_on_cpu(constant_op.constant(1)).numpy(), 3)

  def testDefunFunctionSeparateGraphs(self):
    with context.graph_mode():

      @function.defun
      def add(x):
        return x + 5

      @function.defun
      def maybe_add(x, should_add):
        if should_add:
          return add(x)
        else:
          return x

      with ops.Graph().as_default():
        x = constant_op.constant(11)
        maybe_add(x, True)
        self.assertLen(total_function_cache(maybe_add), 1)
        self.assertLen(total_function_cache(add), 1)

        maybe_add(x, False)
        self.assertLen(total_function_cache(maybe_add), 2)
        self.assertLen(total_function_cache(add), 1)

      with ops.Graph().as_default():
        x = constant_op.constant(11)
        maybe_add(x, True)
        self.assertLen(total_function_cache(maybe_add), 3)
        self.assertLen(total_function_cache(add), 2)

  def testCacheKeyOverlappingShapes(self):
    @function.defun
    def defined(t):
      return t

    defined(array_ops.zeros([12, 1]))
    self.assertLen(total_function_cache(defined), 1)

    defined(array_ops.zeros([1, 21]))
    self.assertLen(total_function_cache(defined), 2)

  def testCacheKeyNestedLists(self):
    @function.defun
    def defined(l):
      return l

    a = constant_op.constant(1.)
    b = constant_op.constant(2.)
    c = constant_op.constant(3.)
    defined([[a], b, c])
    self.assertLen(total_function_cache(defined), 1)

    defined([[a, b], c])
    self.assertLen(total_function_cache(defined), 2)

  def testCacheKeyAttrsClass(self):
    if attr is None:
      self.skipTest('attr module is unavailable.')

    @attr.s
    class TestClass(object):
      a = attr.ib()
      b = attr.ib()

    @function.defun
    def defined(l):
      return l

    defined(
        TestClass(
            constant_op.constant(1.),
            [constant_op.constant(2.),
             constant_op.constant(3.)]))
    self.assertLen(total_function_cache(defined), 1)
    defined(
        TestClass(
            constant_op.constant(1.),
            [constant_op.constant(2.),
             constant_op.constant(3.)]))
    self.assertLen(total_function_cache(defined), 1)

    defined(
        TestClass([constant_op.constant(1.),
                   constant_op.constant(2.)], constant_op.constant(3.)))
    self.assertLen(total_function_cache(defined), 2)

  def testCacheKeyVariables(self):
    @function.defun
    def defined(a, b, c):
      return a + b + c

    x = resource_variable_ops.ResourceVariable(0.0)
    y = resource_variable_ops.ResourceVariable(0.0)
    z = resource_variable_ops.ResourceVariable(0.0)

    # If tensor equality is not enabled, we always get a cache miss if the
    # function is called with different variables. With equality enabled we
    # should only get a miss if the aliasing changed.
    defined(x, y, z)
    self.assertLen(total_function_cache(defined), 1)
    defined(x, y, z)
    self.assertLen(total_function_cache(defined), 1)

    # Re-arranging arguments causes a cache miss.
    defined(z, y, x)
    self.assertLen(total_function_cache(defined), 2)
    defined(z, y, x)
    self.assertLen(total_function_cache(defined), 2)

    # Aliasing causes a cache miss.
    defined(x, x, z)
    self.assertLen(total_function_cache(defined), 3)
    defined(x, x, z)
    self.assertLen(total_function_cache(defined), 3)

    # Re-arranging arguments causes a cache miss.
    defined(y, y, z)
    self.assertLen(total_function_cache(defined), 4)
    defined(y, y, z)
    self.assertLen(total_function_cache(defined), 4)

    # Different alias positions cause a cache miss.
    defined(z, y, y)
    self.assertLen(total_function_cache(defined), 5)
    defined(z, y, y)
    self.assertLen(total_function_cache(defined), 5)

    x_copy = copy.deepcopy(x)

    # A deep copy causes a cache miss.
    defined(x_copy, y, z)
    self.assertLen(total_function_cache(defined), 6)
    defined(x_copy, y, z)
    self.assertLen(total_function_cache(defined), 6)

  def testVariableRetracing(self):
    v1 = variables.Variable(1.)
    v2 = variables.Variable(1.)
    v3 = copy.deepcopy(variables.Variable(1.))

    var_dict = {id(v1): constant_op.constant(1),
                id(v2): constant_op.constant(2),
                id(v3): constant_op.constant(3)}

    @function.defun
    def lookup_tensor(v):
      return var_dict[id(v)]

    self.assertEqual(1, lookup_tensor(v1).numpy())
    self.assertEqual(2, lookup_tensor(v2).numpy())
    self.assertEqual(3, lookup_tensor(v3).numpy())

  def testDecoratedMethodInspect(self):

    class DefunnedMiniModel(object):

      @function.defun
      def call(self, inputs, training=True):
        pass

    m = DefunnedMiniModel()
    fullargspec = tf_inspect.getfullargspec(m.call)
    self.assertIn('training', fullargspec.args)

  def testFunctionModifiesInputList(self):
    # Tests on `list` methods that do in-place modification, except
    # `list.sort`, since it cannot even be "defunned" in the first place.

    def get_list():
      return [constant_op.constant(0.), constant_op.constant(1.)]

    expected_msg = '.*() should not modify'

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def append(l):
        l.append(constant_op.constant(0.))

      append(get_list())

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def extend(l):
        l.extend([constant_op.constant(0.)])

      extend(get_list())

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def insert(l):
        l.insert(0, constant_op.constant(0.))

      insert(get_list())

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def pop(l):
        l.pop()

      pop(get_list())

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def reverse(l):
        l.reverse()

      reverse(get_list())

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def remove(l):
        l.remove(l[0])

      remove(get_list())

    # `list.clear` is a method that is in Py3 but not Py2.
    if sys.version.startswith('3'):

      with self.assertRaisesRegex(ValueError, expected_msg):

        @def_function.function
        def clear(l):
          l.clear()

        clear(get_list())

    # One last test for keyword arguments.
    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def kwdappend(**kwargs):
        l = kwargs['l']
        l.append(constant_op.constant(0.))

      kwdappend(l=get_list())

  def testFunctionModifiesInputDict(self):

    def get_dict():
      return {'t1': constant_op.constant(0.), 't2': constant_op.constant(1.)}

    expected_msg = '.* should not modify'

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def clear(m):
        m.clear()

      clear(get_dict())

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def pop(m):
        m.pop('t1')

      pop(get_dict())

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def popitem(m):
        m.popitem()

      popitem(get_dict())

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def update(m):
        m.update({'t1': constant_op.constant(3.)})

      update(get_dict())

    with self.assertRaisesRegex(ValueError, expected_msg):

      @def_function.function
      def setdefault(m):
        m.setdefault('t3', constant_op.constant(3.))

      setdefault(get_dict())

  def testFunctionModifiesInputNest(self):
    with self.assertRaisesRegex(ValueError, 'modify.* should not modify'):

      @def_function.function
      def modify(n):
        n[0]['t1'].append(constant_op.constant(1.))

      nested_input = [{
          't1': [constant_op.constant(0.),
                 constant_op.constant(1.)],
      },
                      constant_op.constant(2.)]

      modify(nested_input)

    with self.assertRaisesRegex(ValueError,
                                'modify_same_flat.* should not modify'):

      # The flat list doesn't change whereas the true structure changes.
      @def_function.function
      def modify_same_flat(n):
        n[0].append(n[1].pop(0))

      nested_input = [[constant_op.constant(0.)],
                      [constant_op.constant(1.),
                       constant_op.constant(2.)]]

      modify_same_flat(nested_input)
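
  # --- Added illustrative sketch (not part of the original test suite). ---
  # A hedged sketch of the supported alternative to the in-place mutations
  # rejected above: build and return new structures instead of mutating the
  # inputs. The helper name is ours.
  def _no_mutation_sketch(self):
    @def_function.function
    def appended(l):
      # Returns a fresh list rather than calling l.append(...).
      return l + [constant_op.constant(0.)]

    return appended([constant_op.constant(1.)])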

  @test_util.disable_tfrt('b/173429686')
  def testExecutorType(self):
    @function.defun
    def add_five(x):
      return x + 5

    self.assertEqual(
        5,
        add_five(constant_op.constant(0, dtype=dtypes.int32)).numpy())

    with self.assertRaisesRegex(errors.NotFoundError, 'NON_EXISTENT_EXECUTOR'):
      with context.function_executor_type('NON_EXISTENT_EXECUTOR'):
        add_five(constant_op.constant(0, dtype=dtypes.int32))

    for executor_type in ('', 'DEFAULT', None):
      with context.function_executor_type(executor_type):
        self.assertAllEqual(
            5,
            add_five(constant_op.constant(0, dtype=dtypes.int32)).numpy())

  @test_util.assert_no_garbage_created
  def testReferenceCycles(self):

    fn = function.defun(lambda x: 2. * x)

    fn(constant_op.constant(4.0))
    weak_fn = weakref.ref(fn)
    del fn
    # Tests that the weak reference we made to the function is now dead, which
    # means the object has been deleted. This should be true as long as the
    # function itself is not involved in a reference cycle.
    self.assertIs(None, weak_fn())

  def testFunctionStackInErrorMessage(self):
    if context.executing_eagerly():
      # TODO(b/122736651): Remove this skipTest once fixed.
      self.skipTest('Error interpolation is not working when function is '
                    'invoked without PartitionedCallOp.')

    @def_function.function()
    def fn3(x):
      return x + 2

    @def_function.function()
    def fn2(x):
      check_ops.assert_equal(fn3(x), 3)
      return 2

    @def_function.function()
    def fn(x):
      return fn2(x)

    with self.assertRaises(errors.InvalidArgumentError) as cm:
      fn(2)
    e = cm.exception
    self.assertIn('fn -> fn2', e.message)
    self.assertIn('node assert_equal/Assert/Assert (defined at', e.message)
    self.assertNotIn('fn3', e.message)

  @test_util.run_gpu_only
  def testFunctionIsNotPinned(self):
    """Tests that functions aren't pinned to the CPU by the eager runtime."""
    seed1, seed2 = 79, 25
    shape = constant_op.constant([4, 7])
    dtype = dtypes.float32

    @def_function.function
    def func():
      with ops.device('GPU:0'):
        return gen_random_ops.random_standard_normal(
            shape, dtype=dtype, seed=seed1, seed2=seed2)

    with ops.device('GPU:0'):
      x = func()
      self.assertRegex(x.device, 'GPU')

  @test_util.run_in_graph_and_eager_modes
  def testShapeCaching(self):

    @function.defun
    def func(x):
      return array_ops.shape(x)

    @function.defun(
        input_signature=[tensor_spec.TensorSpec([None, None], dtypes.float32)])
    def calls_func(x):
      return func(x)

    self.assertAllEqual([1, 1], self.evaluate(func(array_ops.zeros([1, 1]))))
    self.assertAllEqual([2, 2], self.evaluate(func(array_ops.zeros([2, 2]))))
    self.assertAllEqual(
        [3, 3],
        self.evaluate(calls_func(array_ops.zeros([3, 3]))))

  def testLimitedRetracing(self):
    trace_count = [0]
    @function.defun
    def func(x):
      trace_count[0] += 1
      return x

    for _ in range(50):
      func(constant_op.constant(3.))
      func(constant_op.constant(4.))
      func(constant_op.constant([[1., 2.]]))
      func(constant_op.constant([[]]))
      func(constant_op.constant([[3., 4.], [5., 6.]]))
      func(constant_op.constant([[3., 4.], [5., 6.], [7., 8.]]))
    # Tracing more than twice per input doesn't make sense.
    self.assertLess(trace_count[0], 13)

  def testLimitedRetracingWithCompositeTensors(self):
    trace_count = [0]

    @def_function.function
    def f(x):
      trace_count[0] += 1
      return x

    for i in range(10):
      f(ragged_factory_ops.constant([[1, 2], [i]]))
      f(ragged_factory_ops.constant([[1, 2], [], [3, 4, 5]]))
      f(ragged_factory_ops.constant([[[1, 2], [3]], [[4, 5, 6]]]))
    self.assertEqual(trace_count[0], 3)

  def test_concrete_function_shape_mismatch(self):

    @def_function.function
    def f(argument_name):
      return argument_name + 1.

    f_concrete = f.get_concrete_function(constant_op.constant([1.]))

    # Calling a function from eager doesn't do any shape checking above what
    # kernels do while executing.
    self.assertAllEqual(
        [2., 3.],
        f_concrete(constant_op.constant([1., 2.])).numpy())

    @def_function.function
    def g():
      f_concrete(constant_op.constant([1., 2.]))

    with self.assertRaisesRegex(ValueError, 'argument_name'):
      g()

  @test_util.run_in_graph_and_eager_modes
  def test_shape_inference_with_symbolic_shapes(self):

    @def_function.function
    def _uses_symbolic_shapes(w, x, y):
      x = array_ops.identity(x, name='name_collision')
      x = array_ops.transpose(x, [1, 0, 2])
      x_batch = array_ops.shape(x)[0]
      y_batch = array_ops.shape(y)[0]
      y *= w
      n = y_batch // x_batch
      return array_ops.reshape(y, [n, x_batch, -1])

    conc = _uses_symbolic_shapes.get_concrete_function(
        tensor_spec.TensorSpec(None, dtypes.float32),
        tensor_spec.TensorSpec(None, dtypes.float32),
        tensor_spec.TensorSpec(None, dtypes.float32))

    @def_function.function
    def _call_concrete():
      c = constant_op.constant(1.)
      array_ops.identity(c, name='name_collision')
      output1 = conc(array_ops.ones([2]),
                     array_ops.ones([5, 4, 2]),
                     array_ops.ones([20, 2]))
      self.assertEqual([5, 4, 2], output1.shape)
      output2 = conc(array_ops.ones([3]),
                     array_ops.ones([5, 4, 3]),
                     array_ops.ones([40, 3]))
      self.assertEqual([10, 4, 3], output2.shape)
      return output1, output2

    output1, output2 = _call_concrete()
    self.assertEqual((5, 4, 2), self.evaluate(output1).shape)
    self.assertEqual((10, 4, 3), self.evaluate(output2).shape)

  def testAutoGraphContext(self):

    @def_function.function
    def test_fn():
      self.assertEqual(
          ag_ctx.control_status_ctx().status, ag_ctx.Status.ENABLED)

    prev_status = ag_ctx.control_status_ctx().status
    test_fn()
    self.assertEqual(ag_ctx.control_status_ctx().status, prev_status)

  @test_util.disable_tfrt('b/170435618')
  def testCancelBeforeFunctionExecution(self):
    if not context.executing_eagerly():
      self.skipTest('eager only')

    q = data_flow_ops.FIFOQueue(1, dtypes.int32)

    @def_function.function
    def f():
      return q.dequeue()

    c_mgr = cancellation.CancellationManager()
    cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function())

    c_mgr.start_cancel()
    with self.assertRaises(errors.CancelledError):
      cancelable_func()

  @test_util.disable_tfrt('b/170435618')
  def testCancelBlockedFunctionExecution(self):
    if not context.executing_eagerly():
      self.skipTest('eager only')

    q = data_flow_ops.FIFOQueue(1, dtypes.int32)

    @def_function.function
    def f():
      return q.dequeue()

    c_mgr = cancellation.CancellationManager()
    cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function())

    def cancel_thread():
      time.sleep(0.5)
      c_mgr.start_cancel()

    t = self.checkedThread(cancel_thread)
    t.start()
    with self.assertRaises(errors.CancelledError):
      cancelable_func()
    t.join()

  @test_util.disable_tfrt('b/170435618')
  def testCancelAfterFunctionExecution(self):
    if not context.executing_eagerly():
      self.skipTest('eager only')

    q = data_flow_ops.FIFOQueue(1, dtypes.int32)
    q.enqueue(37)

    @def_function.function
    def f():
      return q.dequeue()

    c_mgr = cancellation.CancellationManager()
    cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function())

    self.assertAllEqual(37, cancelable_func().numpy())

    # Cancellation after the function executes is a no-op.
    c_mgr.start_cancel()
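
  # --- Added illustrative sketch (not part of the original test suite). ---
  # A hedged recap of the cancellation API used above: a CancellationManager
  # wraps a concrete function, and start_cancel() makes pending or subsequent
  # invocations raise CancelledError. The helper name is ours.
  def _cancellation_sketch(self):
    @def_function.function
    def f(x):
      return x + 1

    c_mgr = cancellation.CancellationManager()
    cancelable = c_mgr.get_cancelable_function(
        f.get_concrete_function(tensor_spec.TensorSpec([], dtypes.int32)))
    c_mgr.start_cancel()
    # Calling cancelable(...) now raises errors.CancelledError.
    return cancelable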
c_mgr.get_cancelable_function(f.get_concrete_function())

    self.assertAllEqual(37, cancelable_func().numpy())

    # Cancellation after the function executes is a no-op.
    c_mgr.start_cancel()

  def testAddFunctionCallback(self):
    functions = []
    def function_callback(f, name, graph, inputs, outputs):
      del name, graph, inputs, outputs
      functions.append(f)

    @def_function.function
    def plus_one(x):
      return x + 1

    try:
      function.add_function_callback(function_callback)
      x_float32 = numpy.array(3.0, dtype=numpy.float32)
      self.assertAllClose(plus_one(x_float32), 4.0)
      self.assertLen(functions, 1)
      # Function is already created. Executing it again should not invoke the
      # function callback.
      self.assertAllClose(plus_one(x_float32), 4.0)
      self.assertLen(functions, 1)
      # Signature change leads to a new Function being built.
      x_float64 = numpy.array(3.0, dtype=numpy.float64)
      self.assertAllClose(plus_one(x_float64), 4.0)
      self.assertLen(functions, 2)
    finally:
      function.clear_function_callbacks()

  def testFunctionCallbackAddOps(self):
    file_name = os.path.join(self.get_temp_dir(), 'test')

    def function_callback(f, name, graph, inputs, outputs):
      del f, name, inputs

      with graph.as_default():
        printer = logging_ops.print_v2(
            'hello',
            output_stream='file://' + file_name
        )
        outputs[0].op._add_control_input(printer)

    @def_function.function
    def plus_one(x):
      return x + 1

    self.addCleanup(function.clear_function_callbacks)
    function.add_function_callback(function_callback)
    x_float32 = numpy.array(3.0, dtype=numpy.float32)

    self.assertAllClose(plus_one(x_float32), 4.0)

    with open(file_name, 'r') as f:
      self.assertEqual(f.read().strip(), 'hello')

  def testRemoveFunctionCallback(self):
    functions_1 = []
    def function_callback_1(f, name, graph, inputs, outputs):
      del name, graph, inputs, outputs
      functions_1.append(f)

    functions_2 = []
    def function_callback_2(f, name, graph, inputs, outputs):
      del name, graph, inputs, outputs
      functions_2.append(f)

    @def_function.function
    def plus_one(x):
      return x + 1

    try:
      function.add_function_callback(function_callback_1)
      function.add_function_callback(function_callback_2)
      self.assertAllClose(plus_one(numpy.array(3.0, dtype=numpy.float32)), 4.0)
      self.assertLen(functions_1, 1)
      self.assertLen(functions_2, 1)
      function.remove_function_callback(function_callback_1)
      # The first callback should not be invoked after
      # remove_function_callback() is called.
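      # The float64 call below forces a fresh trace, so only the remaining
      # callback (function_callback_2) should observe the newly built
      # function.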
3539 self.assertAllClose(plus_one(numpy.array(3.0, dtype=numpy.float64)), 4.0) 3540 self.assertLen(functions_1, 1) 3541 self.assertLen(functions_2, 2) 3542 finally: 3543 function.clear_function_callbacks() 3544 3545 def testClearFunctionCallbacks(self): 3546 function.add_function_callback(lambda f: None) 3547 function.add_function_callback(lambda f: None) 3548 self.assertLen(function._function_callbacks, 2) 3549 function.clear_function_callbacks() 3550 self.assertEmpty(function._function_callbacks) # pylint:disable=protected-access 3551 3552 @test_util.run_in_graph_and_eager_modes 3553 def testConcreteFunctionWithNestedTensorInputs(self): 3554 3555 @def_function.function 3556 def f(x, y): 3557 return (x['a'] + x['b'], y[0] + y[1]) 3558 3559 a = constant_op.constant(1000) 3560 b = constant_op.constant(200) 3561 c = constant_op.constant(30) 3562 d = {'a': a, 'b': b} 3563 e = (c, 4) 3564 3565 # Test different argument signatures when constructing the concrete func. 3566 for cf in [ 3567 f.get_concrete_function(d, e), 3568 f.get_concrete_function(d, y=e), 3569 f.get_concrete_function(y=e, x=d), 3570 f.get_concrete_function(_spec_for_value(d), _spec_for_value(e)), 3571 f.get_concrete_function(_spec_for_value(d), y=_spec_for_value(e)), 3572 f.get_concrete_function(y=_spec_for_value(e), x=_spec_for_value(d)) 3573 ]: 3574 # Test different calling conventions when calling the concrete func. 3575 for output in [ 3576 cf(d, e), # structured signature 3577 cf(d, y=e), # structured signature w/ kwarg 3578 cf(y=e, x=d), # structured signature w/ 2 kwargs 3579 cf(a, b, c), # flat signature 3580 cf(x=a, x_1=b, y=c) # flat signature w/ kwargs 3581 ]: 3582 self.assertIsInstance(output, tuple) 3583 self.assertLen(output, 2) 3584 self.assertAllEqual(output[0], 1200) 3585 self.assertAllEqual(output[1], 34) 3586 3587 @test_util.run_in_graph_and_eager_modes 3588 def testConcreteFunctionWithNestedNonTensorInputs(self): 3589 3590 @def_function.function 3591 def f(x, y): 3592 return (x['a'] + x['b'], y[0] + y[1]) 3593 3594 a = {'a': constant_op.constant(1000), 'b': constant_op.constant(200)} 3595 b = (50, 3) 3596 3597 for cf in [ # argument y is bound to non-Tensor value (50, 3). 
3598 f.get_concrete_function(a, b), 3599 f.get_concrete_function(a, y=b), 3600 f.get_concrete_function(x=a, y=b) 3601 ]: 3602 for output in [cf(a), cf(x=a), cf(a, b), cf(x=a, y=b)]: 3603 self.assertAllEqual(output[0] + output[1], 1253) 3604 3605 @test_util.run_in_graph_and_eager_modes 3606 def testConcreteFunctionWithNonTensorStringInputs(self): 3607 3608 @def_function.function 3609 def f(x, y): 3610 return string_ops.string_join([x, y]) 3611 3612 a = constant_op.constant('a') 3613 b = 'b' 3614 3615 cf = f.get_concrete_function(a, b) 3616 for output in [cf(a), cf(x=a), cf(a, b), cf(x=a, y=b)]: 3617 self.assertAllEqual(output, b'ab') 3618 3619 @test_util.run_in_graph_and_eager_modes 3620 def testConcreteFunctionWithBoundNestedNonTensorInputs(self): 3621 3622 @def_function.function 3623 def f(x, y): 3624 return (x['a'] + x['b'], y[0] + y[1]) 3625 3626 a = {'a': 3000, 'b': 200, 'c': 9000} 3627 b = (constant_op.constant(30), 4) 3628 3629 for cf in [ # argument x is bound to non-tensor value `a` 3630 f.get_concrete_function(a, b), 3631 f.get_concrete_function(a, y=b), 3632 f.get_concrete_function(x=a, y=b) 3633 ]: 3634 for output in [cf(a, b), cf(a, y=b), cf(y=b), cf(x=a, y=b)]: 3635 self.assertAllEqual(output[0] + output[1], 3234) 3636 3637 @test_util.run_in_graph_and_eager_modes 3638 def testConcreteFunctionWithAllBoundNestedNonTensorInputs(self): 3639 3640 @def_function.function 3641 def f(x, y): 3642 return (x['a'] + x['b'], y[0] + y[1]) 3643 3644 a = {'a': 5000, 'b': 500} 3645 b = (50, 5) 3646 3647 cf = f.get_concrete_function(a, b) 3648 for output in [cf(), cf(a), cf(y=b)]: 3649 self.assertAllEqual(output[0] + output[1], 5555) 3650 3651 @test_util.run_in_graph_and_eager_modes 3652 def testConcreteFunctionMethodWithVarargs(self): 3653 float32_scalar = tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32) 3654 3655 class MyModel(module.Module): 3656 3657 @def_function.function(input_signature=[float32_scalar, float32_scalar]) 3658 def add(self, *arg): 3659 return math_ops.add(*arg) 3660 3661 m = MyModel() 3662 cf = m.add.get_concrete_function() 3663 cf(-12.0, 3.0) 3664 3665 @test_util.run_in_graph_and_eager_modes 3666 def testConcreteFunctionStructuredSignatureKeywordOrder(self): 3667 # Check that keyword-only arguments are sorted appropriately, so that they 3668 # feed the right tensor into each input. 
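    # If the ordering were wrong, the same logical call written with kwargs in
    # different orders could feed values into the wrong placeholders; the
    # calls below therefore all expect the identical joined string.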
3669 @def_function.function 3670 def g(**kwargs): 3671 return string_ops.reduce_join( 3672 string_ops.reduce_join( 3673 ops.convert_to_tensor(sorted(kwargs.items())), 3674 axis=1, 3675 separator='='), 3676 axis=0, 3677 separator=', ') 3678 3679 s = constant_op.constant('s') 3680 g.get_concrete_function(q=s, a=s, p=s, r=s, v=s, m=s, l=s) 3681 self.assertAllEqual( 3682 g(m='a', r='b', v='c', q='d', l='e', a='f', p='g'), 3683 b'a=f, l=e, m=a, p=g, q=d, r=b, v=c') 3684 self.assertAllEqual( 3685 g(q='d', a='f', p='g', r='b', v='c', m='a', l='e'), 3686 b'a=f, l=e, m=a, p=g, q=d, r=b, v=c') 3687 self.assertAllEqual( 3688 g(a='f', l='e', m='a', p='g', q='d', r='b', v='c'), 3689 b'a=f, l=e, m=a, p=g, q=d, r=b, v=c') 3690 3691 # pylint: disable=g-long-lambda 3692 @parameterized.named_parameters([ 3693 dict( 3694 testcase_name='MissingArg', 3695 conc_args=lambda: (1, constant_op.constant(2)), 3696 call_args=lambda: (1,), 3697 error=r'func\(x, y\) missing required arguments: y'), 3698 dict( 3699 testcase_name='MissingVararg', 3700 conc_args=lambda: (1, 2, constant_op.constant(1.0)), 3701 call_args=lambda: (1, 2), 3702 error=r'func\(x, y, <arg3>\) missing required arguments: <arg3>'), 3703 dict( 3704 testcase_name='ExtraPositionalArg', 3705 conc_args=lambda: (1, 2), 3706 call_args=lambda: (1, 2, 3), 3707 error=r'func\(x, y\) takes 2 positional arguments but 3 were given'), 3708 dict( 3709 testcase_name='MissingKeywordOnlyArg', 3710 conc_args=lambda: (1, 2), 3711 conc_kwargs=lambda: {'c': constant_op.constant(1.0)}, 3712 call_args=lambda: (1, 2), 3713 error=r'func\(x, y, \*, c\) missing required arguments: c'), 3714 dict( 3715 testcase_name='ExtraKeywordArg', 3716 conc_args=lambda: (1, 2), 3717 call_args=lambda: (1, 2), 3718 call_kwargs=lambda: {'c': constant_op.constant(1.0)}, 3719 error=r'func\(x, y\) got unexpected keyword arguments: c'), 3720 dict( 3721 testcase_name='ExpectedRaggedGotNest', 3722 conc_args=lambda: (ragged_factory_ops.constant([[1, 2], [3]]),), 3723 call_args=lambda: ({ 3724 'a': constant_op.constant([1, 2, 3]) 3725 },), 3726 error=r'func\(x, y\): argument x had incorrect type\n' 3727 r' expected: RaggedTensor\n' 3728 r" got: {'a': (Eager)?Tensor}"), 3729 dict( 3730 testcase_name='WrongRaggedRank', 3731 conc_args=lambda: (ragged_factory_ops.constant([[1, 2], [3]]),), 3732 call_args=lambda: (ragged_factory_ops.constant([[[1]]]),), 3733 error=r'func\(x, y\): argument x had incorrect type\n'), 3734 dict( 3735 testcase_name='WrongRaggedDType', 3736 conc_args=lambda: (ragged_factory_ops.constant([[1]]),), 3737 call_args=lambda: (ragged_factory_ops.constant([[1.0]]),), 3738 error=r'func\(x, y\): argument x had incorrect type\n'), 3739 dict( 3740 testcase_name='ExpectedDictGotTensor', 3741 conc_args=lambda: ({ 3742 'a': constant_op.constant(1), 3743 'b': constant_op.constant(1) 3744 },), 3745 call_args=lambda: (constant_op.constant(1),), 3746 error=r'func\(x, y\): argument x had incorrect type\n'), 3747 dict( 3748 testcase_name='ExpectedTupleGotTensor', 3749 conc_args=lambda: 3750 ((constant_op.constant(1), constant_op.constant(2)),), 3751 call_args=lambda: (constant_op.constant(1),), 3752 error=r'func\(x, y\): argument x had incorrect type\n'), 3753 dict( 3754 testcase_name='WrongDType', 3755 conc_args=lambda: (constant_op.constant(1),), 3756 call_args=lambda: (constant_op.constant(1.0),), 3757 exception=(ValueError, errors.InvalidArgumentError, 3758 # on xla_gpu, we get InternalError instead. 
                     errors.InternalError)),
      dict(
          testcase_name='ExpectedTensorGotInt',
          conc_args=lambda: (constant_op.constant(1),),
          call_args=lambda: (5,),
          error=r'func\(x, y\) expected a Tensor in x, but got int value 5'),
      dict(
          testcase_name='ExpectedIntGotDifferentInt',
          conc_args=lambda: (5,),
          call_args=lambda: (8,),
          error=r'ConcreteFunction func\(x, y\) was constructed with int '
          r'value 5 in x, but was called with int value 8'),
      dict(
          testcase_name='ExpectedIntGotTensor',
          conc_args=lambda: (5,),
          call_args=lambda: (constant_op.constant(6),),
          error=r'ConcreteFunction func\(x, y\) was constructed with int '
          'value 5 in x, but was called with (Eager)?Tensor value .*'),
      dict(
          testcase_name='TwoValuesForArgument',
          conc_args=lambda: (1, 2),
          call_args=lambda: (1, 2),
          call_kwargs=lambda: {'x': 3},
          error=r"func\(x, y\) got two values for argument 'x'"),
  ])
  # pylint: enable=g-long-lambda
  @test_util.run_in_graph_and_eager_modes
  def testConcreteFunctionStructuredSignatureError(self,
                                                   conc_args=(),
                                                   conc_kwargs=None,
                                                   call_args=(),
                                                   call_kwargs=None,
                                                   error='.*',
                                                   exception=TypeError):
    """Tests for errors in the structured signature.

    Args:
      conc_args: Positional arguments used for get_concrete_function.
      conc_kwargs: Keyword arguments used for get_concrete_function.
      call_args: Positional arguments used to call the function.
      call_kwargs: Keyword arguments used to call the function.
      error: Expected exception message.
      exception: Expected exception type.
    """
    conc_args = conc_args() if callable(conc_args) else conc_args
    conc_kwargs = conc_kwargs() if callable(conc_kwargs) else conc_kwargs or {}
    call_args = call_args() if callable(call_args) else call_args
    call_kwargs = call_kwargs() if callable(call_kwargs) else call_kwargs or {}
    self.assertIsInstance(conc_args, tuple)
    self.assertIsInstance(call_args, tuple)
    self.assertIsInstance(conc_kwargs, dict)
    self.assertIsInstance(call_kwargs, dict)

    @def_function.function
    def func(x, y=5, *varargs, **kwargs):  # pylint: disable=keyword-arg-before-vararg
      del y, varargs, kwargs
      return x

    conc = func.get_concrete_function(*conc_args, **conc_kwargs)
    with self.assertRaisesRegex(exception, error):
      self.evaluate(conc(*call_args, **call_kwargs))

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters([
      dict(
          testcase_name='MissingArg',
          conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
          call_args=lambda: (constant_op.constant(1),),
          error=r'func\(x, y\) missing required arguments: y'),
      dict(
          testcase_name='TwoValuesForArg',
          conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
          call_args=lambda: (constant_op.constant(1),),
          call_kwargs=lambda: {
              'x': constant_op.constant(1),
              'y': constant_op.constant(1)
          },
          error=r"func\(x, y\) got two values for argument 'x'"),
      dict(
          testcase_name='ExtraPositionalArg',
          conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
          call_args=lambda: (constant_op.constant(1), constant_op.constant(2),
                             constant_op.constant(3)),
          error=r'func\(x, y\) takes 2 positional arguments but 3 were given'),
      dict(
          testcase_name='UnexpectedKeywordArg',
          conc_args=lambda: (constant_op.constant(1),),
call_args=lambda: (constant_op.constant(1),), 3847 call_kwargs=lambda: {'c': constant_op.constant(1)}, 3848 error=r'func\(x\) got unexpected keyword arguments: c'), 3849 dict( 3850 testcase_name='MissingVararg', 3851 conc_args=lambda: (constant_op.constant(1), constant_op.constant(2), 3852 constant_op.constant(3)), 3853 call_args=lambda: (constant_op.constant(1), constant_op.constant(2)), 3854 error=r'func\(x, y, varargs_0\) missing required ' 3855 r'arguments: varargs_0'), 3856 dict( 3857 testcase_name='MissingKeywordArg', 3858 conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)), 3859 conc_kwargs=lambda: {'c': constant_op.constant(1)}, 3860 call_args=lambda: (constant_op.constant(1), constant_op.constant(2)), 3861 error=r'func\(x, y, c\) missing required arguments: c'), 3862 dict( 3863 testcase_name='ExpectedTensorGotInt', 3864 conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)), 3865 call_args=lambda: (5, constant_op.constant(2)), 3866 error=r'func\(x, y\): expected argument #0\(zero-based\) to be ' 3867 r'a Tensor; got int \(5\)'), 3868 dict( 3869 testcase_name='WrongDType', 3870 conc_args=lambda: (constant_op.constant(1),), 3871 call_args=lambda: (constant_op.constant(1.0),), 3872 exception=(ValueError, errors.InvalidArgumentError, 3873 # on xla_gpu, we get InternalError instead. 3874 errors.InternalError)), 3875 dict( 3876 testcase_name='MissingKeywordArgNestPiece', 3877 conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)), 3878 conc_kwargs=lambda: {'c': ragged_factory_ops.constant([[1]])}, 3879 call_args=lambda: (constant_op.constant(1), constant_op.constant(2)), 3880 call_kwargs=lambda: {'c': constant_op.constant(1)}, 3881 error=r'func\(x, y, c, c_1\) missing required arguments: c_1'), 3882 ]) 3883 # pylint: enable=g-long-lambda 3884 @test_util.run_in_graph_and_eager_modes 3885 def testConcreteFunctionFlatSignatureError(self, 3886 conc_args=(), 3887 conc_kwargs=None, 3888 call_args=(), 3889 call_kwargs=None, 3890 error='.*', 3891 exception=TypeError): 3892 """Tests for errors in the flat signature. 3893 3894 Args: 3895 conc_args: Positional arguments used for get_concrete_function. 3896 conc_kwargs: Keyword arguments used for get_concrete_function. 3897 call_args: Positional arguments used to call the function. 3898 call_kwargs: Keyword arguments used to call the function. 3899 error: Expected exception message. 3900 exception: Expected exception type. 3901 """ 3902 conc_args = conc_args() if callable(conc_args) else conc_args 3903 conc_kwargs = conc_kwargs() if callable(conc_kwargs) else conc_kwargs or {} 3904 call_args = call_args() if callable(call_args) else call_args 3905 call_kwargs = call_kwargs() if callable(call_kwargs) else call_kwargs or {} 3906 self.assertIsInstance(conc_args, tuple) 3907 self.assertIsInstance(call_args, tuple) 3908 self.assertIsInstance(conc_kwargs, dict) 3909 self.assertIsInstance(call_kwargs, dict) 3910 3911 @def_function.function 3912 def func(x, y=5, *varargs, **kwargs): # pylint: disable=keyword-arg-before-vararg 3913 del y, varargs, kwargs 3914 return x 3915 3916 conc = func.get_concrete_function(*conc_args, **conc_kwargs) 3917 3918 # Remove _function_spec, to disable the structured signature. 
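    # With the structured signature disabled, calls are validated against the
    # flat signature, where each flattened tensor component becomes its own
    # parameter (e.g. the ragged `c` above shows up as `c` and `c_1`).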
    conc._set_function_spec(None)  # pylint: disable=protected-access

    with self.assertRaisesRegex(exception, error):
      self.evaluate(conc(*call_args, **call_kwargs))

  @test_util.run_in_graph_and_eager_modes
  def testConcreteFunctionAmbiguousSignature(self):
    # When both the flat & structured signatures are applicable, but they
    # give different results, we use the structured signature. Note: we expect
    # this to be extremely rare.
    @def_function.function
    def f(x, y):
      return x * 10 + y

    conc = f.get_concrete_function(
        x=tensor_spec.TensorSpec(None, dtypes.int32, name='y'),
        y=tensor_spec.TensorSpec(None, dtypes.int32, name='x'))

    result = conc(x=constant_op.constant(5), y=constant_op.constant(6))
    self.assertAllEqual(result, 56)

  def testPrettyPrintedSignature(self):

    @def_function.function
    def func(x, kangaroo=None, octopus=7):
      del octopus, kangaroo
      return x

    scalar = constant_op.constant(5)
    vector = constant_op.constant([10, 10, 20])
    ragged = ragged_factory_ops.constant([[10, 20], [40]])

    c1 = func.get_concrete_function(scalar, vector)
    c1_summary = r'func\(x, kangaroo, octopus=7\)'
    c1_details = (r'  Args:\n'
                  r'    x: int32 Tensor, shape=\(\)\n'
                  r'    kangaroo: int32 Tensor, shape=\(3,\)\n'
                  r'  Returns:\n'
                  r'    int32 Tensor, shape=\(\)')
    self.assertRegex(c1.pretty_printed_signature(verbose=False), c1_summary)
    self.assertRegex(
        c1.pretty_printed_signature(verbose=True),
        c1_summary + '\n' + c1_details)
    self.assertRegex(
        repr(c1), r'<ConcreteFunction func\(x, kangaroo, octopus=7\) at .*>')
    self.assertRegex(
        str(c1), 'ConcreteFunction {}\n{}'.format(c1_summary, c1_details))

    c2 = func.get_concrete_function(scalar, ragged, 3)
    c2_summary = r'func\(x, kangaroo, octopus=3\)'
    c2_details = (r'  Args:\n'
                  r'    x: int32 Tensor, shape=\(\)\n'
                  r'    kangaroo: RaggedTensorSpec\(.*\)\n'
                  r'  Returns:\n'
                  r'    int32 Tensor, shape=\(\)')
    self.assertRegex(c2.pretty_printed_signature(),
                     c2_summary + '\n' + c2_details)

    c3 = func.get_concrete_function({'a': scalar, 'b': [ragged, ragged]})
    c3_summary = r'func\(x, kangaroo=None, octopus=7\)'
    c3_details = (r'  Args:\n'
                  r"    x: {'a': <1>, 'b': \[<2>, <3>\]}\n"
                  r'      <1>: int32 Tensor, shape=\(\)\n'
                  r'      <2>: RaggedTensorSpec\(.*\)\n'
                  r'      <3>: RaggedTensorSpec\(.*\)\n'
                  r'  Returns:\n'
                  r"    {'a': <1>, 'b': \[<2>, <3>\]}\n"
                  r'      <1>: int32 Tensor, shape=\(\)\n'
                  r'      <2>: RaggedTensorSpec\(.*\)\n'
                  r'      <3>: RaggedTensorSpec\(.*\)')

    # Python 3.5 does not guarantee deterministic iteration of dict contents,
    # which can lead to a mismatch in the pretty_printed_signature output for
    # "Args".
    if sys.version_info >= (3, 6):
      self.assertRegex(c3.pretty_printed_signature(),
                       c3_summary + '\n' + c3_details)

    # pylint: disable=keyword-arg-before-vararg
    @def_function.function
    def func2(x, y=3, *args, **kwargs):
      return (x, y, args, kwargs)

    c4 = func2.get_concrete_function(scalar, 4, 5, a=scalar)
    c4_summary = 'func2(x, y=4, <arg3>=5, *, a)'
    self.assertEqual(c4.pretty_printed_signature(verbose=False), c4_summary)

    c5 = func2.get_concrete_function(8, vector)
    c5_summary = 'func2(x=8, y)'
    self.assertEqual(c5.pretty_printed_signature(verbose=False), c5_summary)

  def testPrettyPrintedExplicitSignatureWithKeywordArg(self):
# b/159639913 4010 4011 @def_function.function(input_signature=[tensor_spec.TensorSpec(None)]) 4012 def fn(a, b=1): 4013 return a + b 4014 4015 concrete_fn = fn.get_concrete_function() 4016 self.assertEqual(concrete_fn.pretty_printed_signature(False), 'fn(a)') 4017 self.assertEqual( 4018 concrete_fn.pretty_printed_signature(True), 'fn(a)\n' 4019 ' Args:\n' 4020 ' a: float32 Tensor, shape=<unknown>\n' 4021 ' Returns:\n' 4022 ' float32 Tensor, shape=<unknown>') 4023 4024 @test_util.run_in_graph_and_eager_modes 4025 def testIndexedSlicesAsGradientsForConcreteFunctions(self): 4026 4027 @def_function.function 4028 def summing_rnn(inputs): 4029 return math_ops.reduce_sum(inputs, axis=1) 4030 4031 @def_function.function 4032 def gradients(inputs): 4033 with backprop.GradientTape() as tape: 4034 tape.watch(inputs) 4035 hidden = summing_rnn(inputs) 4036 hidden = array_ops.gather(hidden, constant_op.constant([0])) 4037 loss = math_ops.reduce_mean(hidden) 4038 return tape.gradient(loss, inputs) 4039 4040 gradients(constant_op.constant([[[1.0], [2.0]]])) # No error is raised 4041 4042 def testFollowTypeHintsTraceBasic(self): 4043 trace_count = [0] 4044 4045 def func(x: ops.Tensor): 4046 trace_count[0] += 1 4047 return x 4048 4049 enabled = def_function.function(func, experimental_follow_type_hints=True) 4050 disabled = def_function.function(func, experimental_follow_type_hints=False) 4051 4052 enabled(1) # Initial call gets traced 4053 enabled(2) 4054 enabled(3) 4055 self.assertEqual(trace_count[0], 1) 4056 4057 trace_count = [0] 4058 disabled(1) 4059 disabled(2) # Retrace 4060 disabled(3) # Retrace 4061 self.assertEqual(trace_count[0], 3) 4062 4063 def testFollowTypeHintsTraceWithArgs(self): 4064 trace_count = [0] 4065 4066 def func(*args: ops.Tensor): 4067 trace_count[0] += 1 4068 return args 4069 4070 enabled = def_function.function(func, experimental_follow_type_hints=True) 4071 disabled = def_function.function(func, experimental_follow_type_hints=False) 4072 4073 args = ( 4074 'abc', 4075 'def', 4076 ) * 20 4077 args2 = ( 4078 'def', 4079 'abc', 4080 ) * 20 4081 4082 enabled(args) 4083 enabled(args2) 4084 self.assertEqual(trace_count[0], 1) 4085 4086 trace_count = [0] 4087 disabled(args) 4088 disabled(args2) # Retrace 4089 self.assertEqual(trace_count[0], 2) 4090 4091 def testFollowTypeHintsTraceWithKwargs(self): 4092 trace_count = [0] 4093 4094 def func(t: ops.Tensor, **kwargs: ops.Tensor): 4095 del kwargs 4096 trace_count[0] += 1 4097 return t 4098 4099 enabled = def_function.function(func, experimental_follow_type_hints=True) 4100 disabled = def_function.function(func, experimental_follow_type_hints=False) 4101 4102 enabled(1, x=1, y=1.0, z='one') 4103 enabled(2, x=2, y=2.0, z='two') 4104 self.assertEqual(trace_count[0], 1) 4105 4106 trace_count = [0] 4107 disabled(1, x=1, y=1.0, z='one') 4108 disabled(2, x=2, y=2.0, z='two') # Retrace 4109 self.assertEqual(trace_count[0], 2) 4110 4111 def testFollowTypeHintsTraceWithMultipleInputTypes(self): 4112 trace_count = [0] 4113 4114 def func(t: ops.Tensor, *args: ops.Tensor, **kwargs: ops.Tensor): 4115 del args, kwargs 4116 trace_count[0] += 1 4117 return t 4118 4119 enabled = def_function.function(func, experimental_follow_type_hints=True) 4120 disabled = def_function.function(func, experimental_follow_type_hints=False) 4121 4122 enabled(1, constant_op.constant(1), 'str', x=4.0) 4123 enabled(2, constant_op.constant(2), 'str2', x=5.0) 4124 self.assertEqual(trace_count[0], 1) 4125 4126 trace_count = [0] 4127 disabled(1, constant_op.constant(1), 'str', 
x=4.0) 4128 disabled(2, constant_op.constant(2), 'str2', x=5.0) # Retrace 4129 self.assertEqual(trace_count[0], 2) 4130 4131 def testFollowTypeHintsTraceWithOnlyArgNamed(self): 4132 trace_count = [0] 4133 4134 def func(t: ops.Tensor, i: int = 1, **kwargs): # pylint: disable=bad-whitespace 4135 del i, kwargs 4136 trace_count[0] += 1 4137 return t 4138 4139 enabled = def_function.function(func, experimental_follow_type_hints=True) 4140 4141 enabled(1, 3, x=4.0, y='str') 4142 enabled(2, 4, x=4.0, y='str') # Retrace 4143 self.assertEqual(trace_count[0], 2) 4144 4145 def testFollowTypeHintsTraceWithNotAllNamed(self): 4146 trace_count = [0] 4147 4148 def func(x, y: ops.Tensor, z: int): 4149 del y, z 4150 trace_count[0] += 1 4151 return x 4152 4153 enabled = def_function.function(func, experimental_follow_type_hints=True) 4154 4155 enabled(1, 2, 3) 4156 enabled(1, 20, 3) # No retrace - change in ops.Tensor typed arg 4157 enabled(2, 2, 3) # Retrace - change in untyped arg 4158 enabled(2, 2, 4) # Retrace - change in typed arg 4159 self.assertEqual(trace_count[0], 3) 4160 4161 def testFollowTypeHintsTraceWithOnlyArgsNamed(self): 4162 trace_count = [0] 4163 4164 def func(x, y, *args: ops.Tensor): 4165 del y, args 4166 trace_count[0] += 1 4167 return x 4168 4169 enabled = def_function.function(func, experimental_follow_type_hints=True) 4170 4171 enabled(1, 20, 3, 4, 5, 6) 4172 enabled(1, 20, 3, 4, 5, 60) # No retrace - change in *args 4173 enabled(1, 30, 7, 8, 9, 10) # Retrace - change in args 4174 self.assertEqual(trace_count[0], 2) 4175 4176 def testFollowTypeHintsTraceWithOnlyKwargsNamed(self): 4177 trace_count = [0] 4178 4179 def func(x, y, *args, **kwargs: ops.Tensor): 4180 del y, args, kwargs 4181 trace_count[0] += 1 4182 return x 4183 4184 enabled = def_function.function(func, experimental_follow_type_hints=True) 4185 4186 enabled(1, 2, 3, 4, 5, 6, a=1.0, b=2.0, c=3.0) 4187 enabled( 4188 1, 2, 3, 4, 5, 6, a=1.5, b=2.5, 4189 c=3.5) # No retrace - change in **kwargs 4190 enabled(100, 2, 3, 4, 5, 6, a=1.0, b=2.0, c=3.0) # Retrace - change in args 4191 enabled( 4192 1, 2, 3, 4, 5, 100, a=1.0, b=2.0, c=3.0) # Retrace - change in *args 4193 self.assertEqual(trace_count[0], 3) 4194 4195 def testFollowTypeHintsTraceWithArgsEquals(self): 4196 trace_count = [0] 4197 4198 def func( 4199 x: ops.Tensor = 0, # pylint:disable=bad-whitespace 4200 y: int = 1, # pylint:disable=bad-whitespace 4201 **kwargs: ops.Tensor): 4202 del y, kwargs 4203 trace_count[0] += 1 4204 return x 4205 4206 enabled = def_function.function(func, experimental_follow_type_hints=True) 4207 4208 enabled(x=1, y=2, z=3) 4209 enabled(x=1, y=3, z=3) # Retrace - change in args 4210 enabled(x=2, y=2, z=4) # No retrace - change in args and **kwargs 4211 enabled(x=2, y=2, z=4, u=5) # Retrace - change in **kwargs 4212 self.assertEqual(trace_count[0], 3) 4213 4214 def testFollowTypeHintsWithTensorSpec(self): 4215 def func(x: ops.Tensor, y): 4216 return x + y 4217 v = def_function.function(experimental_follow_type_hints=True)(func) 4218 v = v.get_concrete_function( 4219 tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32), 3) 4220 x = v(constant_op.constant(1.), 3) 4221 self.assertEqual(x.numpy(), 4.) 
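
  # A recurring pattern in the type-hint tests above and below: with
  # experimental_follow_type_hints=True, arguments annotated as ops.Tensor are
  # converted to tensors before the trace-cache lookup, so changing only their
  # Python values reuses an existing trace, while unannotated arguments still
  # key the cache by value and force a retrace when they change.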
4222 4223 def testFollowTypeHintsTraceWithKwArgsAndNoVarKws(self): 4224 trace_count = [0] 4225 4226 def func(a: int, b: ops.Tensor, 4227 x: ops.Tensor = 0, y: int = 1): 4228 del a, b, y 4229 trace_count[0] += 1 4230 return x 4231 4232 enabled = def_function.function(func, experimental_follow_type_hints=True) 4233 4234 enabled(0, 0, x=1, y=2) 4235 enabled(0, 0, x=2, y=2,) # No retrace, since only tensor changed 4236 self.assertEqual(trace_count[0], 1) 4237 4238 # Pass args as keyword args. 4239 enabled(a=0, b=0, x=2, y=2,) # No retrace, args are the same 4240 self.assertEqual(trace_count[0], 1) 4241 4242 enabled(a=1, b=0, x=2, y=2,) # Retrace, since non-tensor arg changed 4243 self.assertEqual(trace_count[0], 2) 4244 4245 enabled(a=1, b=2, x=2, y=2) # No retrace, since only tensor changed 4246 self.assertEqual(trace_count[0], 2) 4247 4248 trace_count[0] = 0 4249 disabled = def_function.function(func, experimental_follow_type_hints=False) 4250 disabled(0, 0, x=1, y=2) 4251 disabled(0, 0, x=2, y=2,) # Retrace 4252 self.assertEqual(trace_count[0], 2) 4253 4254 def testFollowTypeHintsTraceWithArgsEqualsTypedKwargs(self): 4255 trace_count = [0] 4256 4257 def func(x, y, **kwargs: ops.Tensor): 4258 del y, kwargs 4259 trace_count[0] += 1 4260 return x 4261 4262 enabled = def_function.function(func, experimental_follow_type_hints=True) 4263 4264 enabled(x=1, y=2, z=3) 4265 enabled(x=1, y=3, z=3) # Retrace 4266 enabled(x=1, y=2, z=4) # No retrace 4267 enabled(x=2, y=2, z=4) # Retrace 4268 enabled(x=2, y=2, z=4, u=5) # Retrace 4269 self.assertEqual(trace_count[0], 4) 4270 4271 def testFollowTypeHintsTraceWithArgsEqualsTypedArgs(self): 4272 trace_count = [0] 4273 4274 def func(x: ops.Tensor, y: int, **kwargs): 4275 del y, kwargs 4276 trace_count[0] += 1 4277 return x 4278 4279 enabled = def_function.function(func, experimental_follow_type_hints=True) 4280 4281 enabled(x=1, y=2, z=3) 4282 enabled(x=1, y=3, z=3) # Retrace 4283 enabled(x=1, y=2, z=4) # Retrace 4284 enabled(x=2, y=2, z=3) # No retrace 4285 enabled(x=2, y=2, z=4, u=5) # Retrace 4286 self.assertEqual(trace_count[0], 4) 4287 4288 def testFollowTypeHintsTraceWithKwOnlyArgsBasic(self): 4289 trace_count = [0] 4290 4291 def func(*, a: ops.Tensor = None, b=1): # pylint: disable=bad-whitespace 4292 del b 4293 trace_count[0] += 1 4294 return a 4295 4296 enabled = def_function.function(func, experimental_follow_type_hints=True) 4297 4298 enabled(a=1, b=2) 4299 enabled(a=2, b=2) # No retrace 4300 enabled(a=1, b=1) # Retrace 4301 self.assertEqual(trace_count[0], 2) 4302 4303 def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedArg(self): 4304 trace_count = [0] 4305 4306 def func(arg: ops.Tensor, *args, kwonly, **kwargs): 4307 del args, kwonly, kwargs 4308 trace_count[0] += 1 4309 return arg 4310 4311 enabled = def_function.function(func, experimental_follow_type_hints=True) 4312 4313 enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) 4314 enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # No retrace 4315 enabled(1000, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # No retrace 4316 enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7) # Retrace 4317 enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7) # Retrace 4318 enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70) # Retrace 4319 self.assertEqual(trace_count[0], 4) 4320 4321 def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedArgs(self): 4322 trace_count = [0] 4323 4324 def func(arg, *args: ops.Tensor, kwonly, **kwargs): 4325 del args, kwonly, kwargs 4326 trace_count[0] += 1 4327 return arg 
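    # Only *args carries the ops.Tensor annotation here, so changing just
    # those positional values should reuse the trace, while changes to arg,
    # kwonly, or the keyword arguments should still retrace.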
4328 4329 enabled = def_function.function(func, experimental_follow_type_hints=True) 4330 4331 enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) 4332 enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # Retrace 4333 enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7) # No retrace 4334 enabled(1, 200, 300, 400, kwonly=5, kwarg1=6, kwarg2=7) # No retrace 4335 enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7) # Retrace 4336 enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70) # Retrace 4337 self.assertEqual(trace_count[0], 4) 4338 4339 def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedKwOnlyArg(self): 4340 trace_count = [0] 4341 4342 def func(arg, *args, kwonly: ops.Tensor, **kwargs): 4343 del args, kwonly, kwargs 4344 trace_count[0] += 1 4345 return arg 4346 4347 enabled = def_function.function(func, experimental_follow_type_hints=True) 4348 4349 enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) 4350 enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # Retrace 4351 enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7) # Retrace 4352 enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7) # No retrace 4353 enabled(1, 2, 3, 4, kwonly=500, kwarg1=6, kwarg2=7) # No retrace 4354 enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70) # Retrace 4355 self.assertEqual(trace_count[0], 4) 4356 4357 def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedKwargs(self): 4358 trace_count = [0] 4359 4360 def func(arg, *args, kwonly, **kwargs: ops.Tensor): 4361 del args, kwonly, kwargs 4362 trace_count[0] += 1 4363 return arg 4364 4365 enabled = def_function.function(func, experimental_follow_type_hints=True) 4366 4367 enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) 4368 enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # Retrace 4369 enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7) # Retrace 4370 enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7) # Retrace 4371 enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70) # No retrace 4372 enabled(1, 2, 3, 4, kwonly=5, kwarg1=600, kwarg2=700) # No retrace 4373 self.assertEqual(trace_count[0], 4) 4374 4375 def testWithExtraWrapper(self): 4376 4377 class Foo(module.Module): 4378 4379 def __init__(self): 4380 super().__init__() 4381 self.var = None 4382 4383 @def_function.function 4384 @dummy_tf_decorator 4385 def add(self, x, y, z=1): 4386 if self.var is None: 4387 return x + y + z 4388 4389 foo = Foo() 4390 self.assertEqual(foo.add(2, 3).numpy(), 6) 4391 4392 @parameterized.parameters([(def_function.function, dummy_tf_decorator), 4393 (dummy_tf_decorator, def_function.function), 4394 (def_function.function, def_function.function)]) 4395 def testWithExtraWrapperRedundantArgs(self, decorator1, decorator2): 4396 4397 class Foo(module.Module): 4398 4399 def __init__(self): 4400 super().__init__() 4401 self.var = None 4402 4403 @decorator1 4404 @decorator2 4405 def add1(self, x, y): 4406 if self.var is None: 4407 return x + y 4408 4409 foo = Foo() 4410 with self.assertRaisesRegex(TypeError, 'got two values for argument'): 4411 foo.add1(2, x=3) # pylint: disable=redundant-keyword-arg,no-value-for-parameter 4412 4413 def testWithExtraWrapperMissingArgs(self): 4414 4415 class Foo(module.Module): 4416 4417 def __init__(self): 4418 super().__init__() 4419 self.var = None 4420 4421 @def_function.function 4422 @dummy_tf_decorator 4423 def add1(self, x, y): 4424 if self.var is None: 4425 return x + y 4426 4427 @def_function.function 4428 @dummy_tf_decorator 4429 def add2(self, x, y): 4430 if self.var is None: 4431 return x + y 4432 4433 @def_function.function 4434 
@def_function.function 4435 def add3(self, x, y): 4436 if self.var is None: 4437 return x + y 4438 4439 foo = Foo() 4440 with self.assertRaisesRegex( 4441 TypeError, 'missing 1 required positional argument: \'y\''): 4442 foo.add1(2) # pylint: disable=no-value-for-parameter 4443 4444 with self.assertRaisesRegex(TypeError, 'missing 1 required argument: x'): 4445 foo.add1(y=2) # pylint: disable=no-value-for-parameter 4446 4447 with self.assertRaisesRegex( 4448 TypeError, 'missing 1 required positional argument: \'y\''): 4449 foo.add2(2) # pylint: disable=no-value-for-parameter 4450 4451 with self.assertRaisesRegex(TypeError, 'missing 1 required argument: x'): 4452 foo.add2(y=2) # pylint: disable=no-value-for-parameter 4453 4454 with self.assertRaisesRegex( 4455 TypeError, 'missing 1 required positional argument: \'y\''): 4456 foo.add3(2) # pylint: disable=no-value-for-parameter 4457 4458 with self.assertRaisesRegex(TypeError, 'missing 1 required argument: x'): 4459 foo.add3(y=2) # pylint: disable=no-value-for-parameter 4460 4461 def testMissingArgsTfFunctionedMethod(self): 4462 4463 class A(object): 4464 4465 def func(self, position_arg1, position_arg2): 4466 return position_arg1, position_arg2 4467 4468 @def_function.function 4469 def decorated_method(self, position_arg1, position_arg2): 4470 return position_arg1, position_arg2 4471 4472 a_instance = A() 4473 tf_method_pos = def_function.function(a_instance.func) 4474 with self.assertRaisesRegex( 4475 TypeError, '.* missing 1 required argument: position_arg1'): 4476 tf_method_pos(position_arg2='foo') 4477 4478 # tf.function-decorated instance methods need to be tested because of 4479 # the __get__ method implementation. 4480 tf_func_decorated_method = def_function.function( 4481 a_instance.decorated_method) 4482 tf_func_decorated_method(position_arg1='foo', position_arg2='bar') 4483 with self.assertRaisesRegex( 4484 TypeError, '.* missing 1 required argument: position_arg1'): 4485 tf_func_decorated_method(position_arg2='bar') 4486 4487 def testMissingArgsTfFunctionedObject(self): 4488 4489 class A(object): 4490 4491 def __call__(self, position_arg1, position_arg2): 4492 return position_arg1, position_arg2 4493 4494 a_instance = A() 4495 4496 # A tf.function-decorated callable object needs to be tested because of 4497 # the special inspect results. 
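    # (For a callable object, tf_inspect has to recover the signature from
    # A.__call__ with `self` already bound, so the error below should name
    # position_arg1 rather than `self`.)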
4498 tf_func_obj = def_function.function(a_instance) 4499 tf_func_obj(position_arg1=1, position_arg2=2) 4500 with self.assertRaisesRegex( 4501 TypeError, '.* missing 1 required argument: position_arg1'): 4502 tf_func_obj(position_arg2='bar') 4503 4504 def testMissingArgsTfFunctionedFunctions(self): 4505 4506 def func_pos(position_arg1, position_arg2): 4507 return position_arg1, position_arg2 4508 4509 def func_with_default(position_arg, named_arg=None): 4510 return position_arg, named_arg 4511 4512 def func_pos_3args(position_arg1, position_arg2, position_arg3): 4513 return position_arg1, position_arg2, position_arg3 4514 4515 tf_func_pos = def_function.function(func_pos) 4516 with self.assertRaisesRegex( 4517 TypeError, '.* missing 1 required argument: position_arg1'): 4518 tf_func_pos(position_arg2='foo') 4519 4520 tf_func_with_default = def_function.function(func_with_default) 4521 tf_func_with_default(position_arg='bar') 4522 with self.assertRaisesRegex(TypeError, 4523 '.* missing 1 required argument: position_arg'): 4524 tf_func_with_default(named_arg='foo') 4525 4526 tf_func_pos_3args = def_function.function(func_pos_3args) 4527 with self.assertRaisesRegex( 4528 TypeError, 4529 '.* missing required arguments: position_arg1, position_arg3'): 4530 tf_func_pos_3args(position_arg2='foo') 4531 4532 def testShapeInferencePropagateConstNestedStack(self): 4533 4534 @def_function.function(input_signature=[ 4535 tensor_spec.TensorSpec((None, None), dtype=dtypes.int32), 4536 tensor_spec.TensorSpec((), dtype=dtypes.int32), 4537 ]) 4538 def f(x, s): 4539 old_shape = array_ops.shape(x) 4540 new_shape = array_ops.stack([old_shape[0], s], axis=0) 4541 y = array_ops.ones(shape=new_shape, dtype=dtypes.int32) 4542 return y 4543 4544 @def_function.function(input_signature=[ 4545 tensor_spec.TensorSpec(shape=(3, 6), dtype=dtypes.int32) 4546 ]) 4547 def g(x): 4548 y = f(x, s=5) 4549 assert y.shape.as_list() == [3, 5], y.shape.as_list() 4550 return y 4551 4552 self.assertAllEqual( 4553 g(array_ops.zeros([3, 6], dtype=dtypes.int32)), array_ops.ones([3, 5])) 4554 4555 def testShapeInferencePropagateConstNestedUnstackStack(self): 4556 4557 @def_function.function(input_signature=[ 4558 tensor_spec.TensorSpec((None, None), dtype=dtypes.int32), 4559 tensor_spec.TensorSpec((), dtype=dtypes.int32), 4560 ]) 4561 def f(x, s): 4562 s0, _ = array_ops.unstack(array_ops.shape(x), axis=0) 4563 new_shape = array_ops.stack([s0, s], axis=0) 4564 y = array_ops.ones(shape=new_shape, dtype=dtypes.int32) 4565 return y 4566 4567 @def_function.function(input_signature=[ 4568 tensor_spec.TensorSpec(shape=(3, 6), dtype=dtypes.int32) 4569 ]) 4570 def g(x): 4571 y = f(x, s=5) 4572 assert y.shape.as_list() == [3, 5], y.shape.as_list() 4573 return y 4574 4575 self.assertAllEqual( 4576 g(array_ops.zeros([3, 6], dtype=dtypes.int32)), array_ops.ones([3, 5])) 4577 4578 def testShapeInferencePropagateConstNestedConcat(self): 4579 4580 @def_function.function(input_signature=[ 4581 tensor_spec.TensorSpec((), dtype=dtypes.int32), 4582 tensor_spec.TensorSpec((), dtype=dtypes.int32), 4583 tensor_spec.TensorSpec((), dtype=dtypes.int32), 4584 ]) 4585 def f(d1, d2, d3): 4586 new_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1) 4587 y = array_ops.ones(shape=new_shape, dtype=dtypes.int32) 4588 return y 4589 4590 @def_function.function() 4591 def g(): 4592 y = f(1, 2, 3) 4593 assert y.shape.as_list() == [1, 2, 3], y.shape.as_list() 4594 return y 4595 4596 self.assertAllEqual(g(), array_ops.ones([1, 2, 3])) 4597 4598 def 
testShapeInferencePropagateConstDoubleNested(self): 4599 4600 @def_function.function(input_signature=[ 4601 tensor_spec.TensorSpec((), dtype=dtypes.int32), 4602 tensor_spec.TensorSpec((), dtype=dtypes.int32), 4603 tensor_spec.TensorSpec((), dtype=dtypes.int32), 4604 ]) 4605 def f(d1, d2, d3): 4606 new_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1) 4607 y = array_ops.ones(shape=new_shape, dtype=dtypes.int32) 4608 return y 4609 4610 @def_function.function() 4611 def g(): 4612 y = def_function.function(f)(1, 2, 3) 4613 assert y.shape.as_list() == [1, 2, 3], y.shape.as_list() 4614 return y 4615 4616 self.assertAllEqual(g(), array_ops.ones([1, 2, 3])) 4617 4618 @test_util.run_v2_only 4619 def testControlDependencyAfterInline(self): 4620 v = variables.Variable(0.) 4621 4622 @def_function.function 4623 def assign(): 4624 return v.assign(1.) 4625 4626 @def_function.function 4627 def assign_add(): 4628 return v.assign_add(1.) 4629 4630 @def_function.function 4631 def f(): 4632 check_ops.assert_equal_v2(assign(), 1.) 4633 check_ops.assert_equal_v2(assign_add(), 2.) 4634 4635 # We don't have a way to inspect the inlined graph in Python, so we run it 4636 # multiple times to have more confidence the dependency is correct. 4637 for _ in range(30): 4638 f() 4639 4640 @test_util.run_v2_only 4641 def testReadInFuncWriteOutside(self): 4642 # Run many times since we are testing for a potential race condition. 4643 for _ in range(30): 4644 # pylint: disable=cell-var-from-loop 4645 v = variables.Variable(1.) 4646 4647 @def_function.function 4648 def add_one(): 4649 return v + 1. 4650 4651 @def_function.function 4652 def get_v_plus_one(): 4653 v_plus_one = add_one() 4654 v.assign_add(2.0) 4655 return v_plus_one 4656 4657 self.assertAllEqual(get_v_plus_one(), 2.0) 4658 4659 4660class MultiDeviceTest(test.TestCase, parameterized.TestCase): 4661 4662 @test_util.run_gpu_only 4663 def testMultiDeviceOutput(self): 4664 """Tests that functions can produce outputs on multiple devices.""" 4665 @function.defun 4666 def func(a, b, transpose_a): 4667 with ops.device('/device:CPU:0'): 4668 m1 = math_ops.matmul(a, b, transpose_a=transpose_a) 4669 with ops.device('/device:GPU:0'): 4670 m2 = math_ops.matmul(a, b, transpose_a=transpose_a) 4671 return m1, m2 4672 4673 t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) 4674 m1, m2 = func(t, t, transpose_a=True) 4675 self.assertAllEqual(m1.numpy(), [[10, 14], [14, 20]]) 4676 self.assertRegex(m1.backing_device, 'CPU') 4677 self.assertAllEqual(m2.numpy(), [[10, 14], [14, 20]]) 4678 self.assertRegex(m2.backing_device, 'GPU') 4679 4680 @test_util.run_gpu_only 4681 def testEmptyBody(self): 4682 @function.defun 4683 def func(a, b): 4684 return b, a 4685 4686 with ops.device('/device:CPU:0'): 4687 a = array_ops.identity(3.0) 4688 with ops.device('/device:GPU:0'): 4689 b = array_ops.identity(5.0) 4690 4691 m1, m2 = func(a, b) 4692 self.assertAllEqual(m1.numpy(), 5.0) 4693 self.assertRegex(m1.backing_device, 'GPU') 4694 self.assertAllEqual(m2.numpy(), 3.0) 4695 self.assertRegex(m2.backing_device, 'CPU') 4696 4697 @test_util.run_gpu_only 4698 def testMultiDeviceInt32(self): 4699 """Tests that multi-device functions can take and output INT32s. 4700 4701 When an INT32 device tensor is fed into a function, it is copied to CPU 4702 by the eager runtime. The function sees all INT32 inputs on CPU. 4703 4704 We set allocator attribute 'on_host' for INT32 outputs. They can be 4705 partitioned into the GPU component function, but will be allocated on 4706 CPU nevertheless. 
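    (Both outputs here are int32, so the assertions below expect m1 and m2 to
    come back on CPU, even though m2 is computed on the GPU.)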

    There is experimental support for `ints_on_device` in
    FunctionLibraryRuntime now; this test could eventually be migrated to it.
    """
    with ops.device('/device:CPU:0'):
      int_cpu = constant_op.constant(3, dtype=dtypes.int32)
      resource = resource_variable_ops.ResourceVariable(5, dtype=dtypes.int32)
    with ops.device('/device:GPU:0'):
      int_gpu = constant_op.constant(7, dtype=dtypes.int32)

    @function.defun
    def func(int_cpu, resource, int_gpu):
      with ops.device('/device:CPU:0'):
        m1 = int_cpu * resource + int_gpu
      with ops.device('/device:GPU:0'):
        # This computation will happen on GPU but m2 will be copied to CPU.
        m2 = int_gpu * resource + int_cpu + 1
      return m1, m2

    m1, m2 = func(int_cpu, resource, int_gpu)
    self.assertAllEqual(m1.numpy(), 22)
    self.assertRegex(m1.backing_device, 'CPU')
    self.assertAllEqual(m2.numpy(), 39)
    self.assertRegex(m2.backing_device, 'CPU')

    # Flip the arguments.
    m1, m2 = func(int_gpu, resource, int_cpu)
    self.assertAllEqual(m1.numpy(), 38)
    self.assertRegex(m1.backing_device, 'CPU')
    self.assertAllEqual(m2.numpy(), 23)
    self.assertRegex(m2.backing_device, 'CPU')

  @test_util.run_gpu_only
  def testMultiDeviceColocateWith(self):
    """Tests that a function's outputs respect colocation constraints."""
    @function.defun
    def func(a, b):
      with ops.colocate_with(a):
        ra = 2 * a
      with ops.colocate_with(b):
        rb = 3 * b
      return ra, rb

    devices = ['/device:CPU:0', '/device:GPU:0']
    for dev1, dev2 in itertools.product(devices, devices):
      with ops.device(dev1):
        a = array_ops.identity(1.0)
      with ops.device(dev2):
        b = array_ops.identity(10.0)

      ra, rb = func(a, b)
      self.assertEqual(ra.numpy(), 2.0)
      self.assertRegex(ra.backing_device, dev1)
      self.assertEqual(rb.numpy(), 30.0)
      self.assertRegex(rb.backing_device, dev2)

  @test_util.run_gpu_only
  def testMultiDeviceResources(self):
    with ops.device('/device:CPU:0'):
      c1 = resource_variable_ops.ResourceVariable(2.0)
      c2 = resource_variable_ops.ResourceVariable(7.0)
    with ops.device('/device:GPU:0'):
      g1 = resource_variable_ops.ResourceVariable(3.0)
      g2 = resource_variable_ops.ResourceVariable(5.0)

    @function.defun
    def func(resource1, resource2):
      with ops.device('/device:CPU:0'):
        result1 = resource1 * g2
      with ops.device('/device:GPU:0'):
        result2 = resource2 * c2
      return result1, result2

    r1, r2 = func(c1, g1)
    self.assertEqual(r1.numpy(), 10.0)
    self.assertRegex(r1.backing_device, 'CPU')
    self.assertEqual(r2.numpy(), 21.0)
    self.assertRegex(r2.backing_device, 'GPU')

    # Call with flipped inputs. Check that we look at the resources' devices
    # and reinstantiate the function when the inputs' devices change.
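    # With the arguments swapped, r1 = g1 * g2 = 15 and r2 = c1 * c2 = 14;
    # these values only come out right if a fresh concrete function was
    # instantiated for the new input devices.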
    r1, r2 = func(g1, c1)
    self.assertEqual(r1.numpy(), 15.0)
    self.assertRegex(r1.backing_device, 'CPU')
    self.assertEqual(r2.numpy(), 14.0)
    self.assertRegex(r2.backing_device, 'GPU')

  @test_util.run_gpu_only
  def testOutputResources(self):
    with ops.device('/device:CPU:0'):
      c1 = resource_variable_ops.ResourceVariable(2.0)
    with ops.device('/device:GPU:0'):
      g1 = resource_variable_ops.ResourceVariable(3.0)

    @function.defun
    def func(resource1, resource2):
      with ops.device('/device:CPU:0'):
        result1 = resource1 * 5
      with ops.device('/device:GPU:0'):
        result2 = resource2 * 7
      return result1, resource1.handle, result2, resource2.handle

    r1, res1, r2, res2 = func(c1, g1)
    self.assertEqual(r1.numpy(), 10.0)
    self.assertRegex(r1.backing_device, 'CPU')
    self.assertEqual(r2.numpy(), 21.0)
    self.assertRegex(r2.backing_device, 'GPU')

    def check_handle(handle, expected_value):
      self.assertRegex(handle.backing_device, 'CPU')
      tensor = gen_resource_variable_ops.read_variable_op(
          handle, dtypes.float32)
      self.assertEqual(tensor.numpy(), expected_value)

    # Check that handles returned from functions are on CPU and an op using
    # the resource handle is correctly placed on the device backing the
    # resource.
    check_handle(res1, 2.0)
    check_handle(res2, 3.0)

    # Call with flipped inputs to make sure the function is reinstantiated and
    # the eager runtime does not mess up the device assignment for ops
    # consuming handles returned from defuns.
    r1, res1, r2, res2 = func(g1, c1)
    self.assertEqual(r1.numpy(), 15.0)
    self.assertRegex(r1.backing_device, 'CPU')
    self.assertEqual(r2.numpy(), 14.0)
    self.assertRegex(r2.backing_device, 'GPU')
    check_handle(res1, 3.0)
    check_handle(res2, 2.0)

  @test_util.run_gpu_only
  def testPassResourceThroughNestedFunctionCall(self):
    """Test passing a GPU resource to a noinline function call placed on CPU.

    PartitionedCallOp must not enforce any particular device assignment for
    the resource output. The inner function is marked `_nospecialize`, so
    Grappler will not prune the unused function output.
    """

    with ops.device('/device:GPU:0'):
      g1 = resource_variable_ops.ResourceVariable(3.0)

    @function.defun_with_attributes(attributes={
        '_noinline': True,
        '_nospecialize': True
    })
    def inner(resource1):
      return resource1 * 2, resource1.handle

    @function.defun
    def outer(resource1):
      with ops.device('/device:CPU:0'):
        r1, _ = inner(resource1)
      return r1

    r1 = outer(g1)

    self.assertEqual(r1.numpy(), 6.0)
    self.assertRegex(r1.backing_device, 'CPU')

  @test_util.run_gpu_only
  def testReturnResourceFromNestedFunctionCall(self):
    """Test returning a GPU resource from a noinline call placed on CPU.

    When inferring output devices for the return value, do not set a device
    for returns of DT_RESOURCE data type based on the device assignment of the
    node that produced that resource. For example, a function call placed on
    CPU can return resources that live on GPU.
4877 """ 4878 4879 with ops.device('/device:GPU:0'): 4880 g1 = resource_variable_ops.ResourceVariable(3.0) 4881 4882 @function.defun_with_attributes(attributes={ 4883 '_noinline': True 4884 }) 4885 def inner(resource1): 4886 resource1.assign_add(2.0) 4887 return resource1 * 2, resource1.handle 4888 4889 @function.defun 4890 def outer(resource1): 4891 with ops.device('/device:CPU:0'): 4892 r1, res1 = inner(resource1) 4893 return r1, res1 4894 4895 r1, res1 = outer(g1) 4896 4897 self.assertEqual(r1.numpy(), 10.0) 4898 self.assertRegex(r1.backing_device, 'CPU') 4899 4900 def check_handle(handle, expected_value): 4901 self.assertRegex(handle.backing_device, 'CPU') 4902 tensor = gen_resource_variable_ops.read_variable_op( 4903 handle, dtypes.float32) 4904 self.assertEqual(tensor.numpy(), expected_value) 4905 4906 # Check that handles returned from functions are on CPU and an op using 4907 # the resource handle is correctly placed on the device backing the 4908 # resource. 4909 check_handle(res1, 5.0) 4910 4911 @test_util.run_gpu_only 4912 def testComplexInputOutputDevicePattern(self): 4913 """Tests input/output mapping logic in partitioning.""" 4914 with ops.device('/device:CPU:0'): 4915 rc0 = resource_variable_ops.ResourceVariable(2.0) 4916 rc1 = resource_variable_ops.ResourceVariable(3.0) 4917 cc0 = array_ops.identity(5.0) 4918 cc1 = array_ops.identity(7.0) 4919 with ops.device('/device:GPU:0'): 4920 rg0 = resource_variable_ops.ResourceVariable(11.0) 4921 rg1 = resource_variable_ops.ResourceVariable(13.0) 4922 cg0 = array_ops.identity(17.0) 4923 cg1 = array_ops.identity(19.0) 4924 4925 # Make sure tensors are on expected devices. 4926 for tensor in [cc0, cc1]: 4927 self.assertRegex(tensor.backing_device, 'CPU:0') 4928 for tensor in [cg0, cg1]: 4929 self.assertRegex(tensor.backing_device, 'GPU:0') 4930 4931 @function.defun 4932 def func(rc0, cc0, cg0, rc1, cg1, rg0, rg1, cc1): 4933 with ops.device('/device:CPU:0'): 4934 m1 = rc0 * cg0 4935 with ops.device('/device:GPU:0'): 4936 m2 = rg0 * cc0 4937 4938 with ops.device('/device:CPU:0'): 4939 r1 = 1000.0 * m2 + rc1 * cg1 4940 with ops.device('/device:GPU:0'): 4941 r2 = 1000.0 * m1 + rg1 * cc1 4942 4943 return r1, r2, m2, m1 4944 4945 r1, r2, m2, m1 = func(rc0, cc0, cg0, rc1, cg1, rg0, rg1, cc1) 4946 self.assertRegex(m1.backing_device, 'CPU') 4947 self.assertRegex(r1.backing_device, 'CPU') 4948 self.assertRegex(m2.backing_device, 'GPU') 4949 self.assertRegex(r2.backing_device, 'GPU') 4950 self.assertEqual(m1.numpy(), 34.0) 4951 self.assertEqual(r1.numpy(), 55000.0 + 3.0 * 19.0) 4952 self.assertEqual(m2.numpy(), 55.0) 4953 self.assertEqual(r2.numpy(), 34000.0 + 13.0 * 7.0) 4954 4955 @test_util.run_gpu_only 4956 def testArgumentPruning(self): 4957 """Tests functions taking unnecessary arguments.""" 4958 with ops.device('/device:CPU:0'): 4959 c1 = constant_op.constant(5.0) 4960 c2 = constant_op.constant(7.0) 4961 4962 with ops.device('/device:GPU:0'): 4963 g1 = constant_op.constant(11.0) 4964 g2 = constant_op.constant(13.0) 4965 g3 = constant_op.constant(17.0) 4966 4967 @function.defun 4968 def func(g1, g2, c1, g3, c2): # pylint: disable=unused-argument 4969 # arguments g1 and g2 are unused and can be pruned by grappler. 4970 return c1 * g3 * c2 4971 4972 result = func(g1, g2, c1, g3, c2) 4973 self.assertEqual(result.numpy(), 5.0 * 7.0 * 17.0) 4974 4975 def testNestedCallWatchedVariables(self): 4976 4977 v = variables.Variable(4.) 4978 4979 @def_function.function 4980 def f(): 4981 return v ** 2. 
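    # Reading v inside the traced function should register v with any active
    # tape each time f() is called, even though f was traced only once.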
4982 4983 with backprop.GradientTape() as tape: 4984 f() 4985 4986 self.assertEqual((v,), tape.watched_variables()) 4987 4988 @def_function.function 4989 def g(): 4990 return f() 4991 4992 with backprop.GradientTape() as tape: 4993 g() 4994 4995 self.assertEqual((v,), tape.watched_variables()) 4996 4997 # f() can rely on the variable being read during its trace. g() checks that 4998 # variables from a function which knows about them are recorded on the 4999 # tape. h() tests that functions forward knowledge of variables to callers. 5000 5001 @def_function.function 5002 def h(): 5003 return g() 5004 5005 with backprop.GradientTape() as tape: 5006 h() 5007 5008 self.assertEqual((v,), tape.watched_variables()) 5009 5010 def testDeferredCapture(self): 5011 value = 1.0 5012 5013 @def_function.function 5014 def lazy_capture(x): 5015 y = ops.get_default_graph().capture_call_time_value( 5016 lambda: value, tensor_spec.TensorSpec(None)) 5017 return x + y 5018 5019 self.assertAllEqual(lazy_capture(2.0), 3.0) 5020 # After changing the value of `value` the function call should return a 5021 # different result. 5022 value = 2.0 5023 self.assertAllEqual(lazy_capture(2.0), 4.0) 5024 5025 def testDeferredCaptureWithKey(self): 5026 value0 = 1.0 5027 value1 = 2.0 5028 5029 @def_function.function 5030 def lazy_capture(x): 5031 w = ops.get_default_graph().capture_call_time_value( 5032 lambda: value0, tensor_spec.TensorSpec(None), key=0) 5033 y = ops.get_default_graph().capture_call_time_value( 5034 lambda: value1, tensor_spec.TensorSpec(None), key=1) 5035 def bad_closure(): 5036 raise ValueError('Should not run') 5037 z = ops.get_default_graph().capture_call_time_value( 5038 bad_closure, tensor_spec.TensorSpec(None), key=1) 5039 return x + y + w + z 5040 5041 self.assertAllEqual(lazy_capture(2.0), 7.0) 5042 value0 = 2.0 5043 value1 = 3.0 5044 self.assertAllEqual(lazy_capture(2.0), 10.0) 5045 5046 def testDeferredCaptureTypeError(self): 5047 value = constant_op.constant(1.0) 5048 5049 @def_function.function 5050 def lazy_capture(x): 5051 y = ops.get_default_graph().capture_call_time_value( 5052 lambda: value, tensor_spec.TensorSpec(())) 5053 return x + y 5054 5055 self.assertAllEqual(lazy_capture(2.0), 3.0) 5056 5057 # dtype mismatch 5058 value = constant_op.constant(1) 5059 with self.assertRaisesRegex(ValueError, 'Value .* to a tensor with dtype'): 5060 lazy_capture(2.0) 5061 5062 # shape mismatch 5063 value = constant_op.constant([1.0]) 5064 with self.assertRaisesRegex(ValueError, 'Value .* shape'): 5065 lazy_capture(2.0) 5066 5067 def testDeferredCaptureReturnNestWithCompositeTensor(self): 5068 i_s = indexed_slices.IndexedSlices( 5069 constant_op.constant([1, 2]), 5070 constant_op.constant([0, 1], dtype=dtypes.int64), 5071 constant_op.constant([2])) 5072 r_t = ragged_factory_ops.constant([[[1, 2], [3]], [[4, 5, 6]]]) 5073 s_t = sparse_tensor.SparseTensor( 5074 values=[1, 2, 3], indices=[[0], [8], [10]], dense_shape=[20]) 5075 5076 @def_function.function 5077 def lazy_capture(): 5078 y = ops.get_default_graph().capture_call_time_value( 5079 lambda: {'i': i_s, 't': (r_t, s_t)}, 5080 {'i': indexed_slices.IndexedSlicesSpec( 5081 dtype=dtypes.int32, dense_shape_dtype=dtypes.int32), 5082 't': (ragged_tensor.RaggedTensorSpec([2, None, None], dtypes.int32), 5083 sparse_tensor.SparseTensorSpec([None], dtypes.int32))}) 5084 return y['i'], y['t'] 5085 5086 i, (r, s) = lazy_capture() 5087 self.assertAllEqual(i_s.values, i.values) 5088 self.assertAllEqual(i_s.indices, i.indices) 5089 self.assertAllEqual(i_s.dense_shape, 
i.dense_shape) 5090 self.assertAllEqual(r_t, r) 5091 self.assertAllEqual(s_t.indices, s.indices) 5092 self.assertAllEqual(s_t.values, s.values) 5093 self.assertAllEqual(s_t.dense_shape, s.dense_shape) 5094 5095 def testDeferredCaptureCompositeTensorSpecTypeMismatch(self): 5096 value = indexed_slices.IndexedSlices( 5097 constant_op.constant([1, 2]), 5098 constant_op.constant([0, 1], dtype=dtypes.int64)) 5099 5100 @def_function.function 5101 def lazy_capture(): 5102 return ops.get_default_graph().capture_call_time_value( 5103 lambda: value, 5104 indexed_slices.IndexedSlicesSpec(dtype=dtypes.int32)) 5105 5106 # Type matches spec. 5107 lazy_capture() 5108 5109 # Extra dense shape component. 5110 value = indexed_slices.IndexedSlices( 5111 constant_op.constant([1, 2]), 5112 constant_op.constant([0, 1], dtype=dtypes.int64), 5113 constant_op.constant([2])) 5114 with self.assertRaises(ValueError): 5115 lazy_capture() 5116 5117 # Index dtype mismatch int32 vs. int64. 5118 value = indexed_slices.IndexedSlices( 5119 constant_op.constant([1, 2]), 5120 constant_op.constant([0, 1])) 5121 with self.assertRaises(ValueError): 5122 lazy_capture() 5123 5124 5125if __name__ == '__main__': 5126 ops.enable_eager_execution() 5127 test.main() 5128