# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import functools
import itertools
from multiprocessing.pool import ThreadPool
import sys
import weakref

from absl.testing import parameterized
import numpy

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function as tf_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine import training as keras_training
from tensorflow.python.layers import convolutional
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import training_ops
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect


def total_function_cache(defined):
  # pylint: disable=protected-access
  return (set(defined._function_cache.primary)
          | set(defined._function_cache.arg_relaxed))
  # pylint: enable=protected-access


class MiniModel(keras_training.Model):
  """Minimal model for mnist.

  Useful for testing and debugging on slow TPU simulators.
  """

  def __init__(self):
    super(MiniModel, self).__init__(name='')
    self.fc = keras.layers.Dense(1, name='fc', kernel_initializer='ones',
                                 bias_initializer='ones')

  def call(self, inputs, training=True):
    return self.fc(inputs)


class DefunnedMiniModel(MiniModel):

  @function.defun
  def call(self, inputs, training=True):
    return super(DefunnedMiniModel, self).call(inputs, training=training)


class FunctionTest(test.TestCase, parameterized.TestCase):

  def testBasic(self):
    # TODO(b/121134877): Remove the autograph override.
    matmul = def_function.function(math_ops.matmul, autograph=False)
    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    sq = matmul(t, t, transpose_a=True)
    sq2 = matmul(sq, t, transpose_a=True)
    self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20])
    self.assertAllEqual(sq2.numpy().reshape(-1), [52, 76, 74, 108])

  def testVariable(self):
    v1 = variables.Variable(1.0)
    add = def_function.function(lambda x, v: x + v1 + v)
    v2 = variables.Variable(1.0)
    x = constant_op.constant(1.0)
    r = add(x, v2)
    self.assertEqual(3.0, self.evaluate(r))

  def testExternalControlDependency(self):
    with ops.Graph().as_default(), self.test_session():
      v = variables.Variable(1.0)
      v.initializer.run()

      op = v.assign_add(1.0)

      @function.defun
      def f():
        with ops.control_dependencies([op]):
          return 1.0

      self.evaluate(f())
      self.assertAllEqual(self.evaluate(v), 2.0)

  def testInputShapeFunctionRelaxation(self):
    unknown_dim = [False]

    @function.defun
    def func(a):
      if a._shape_tuple()[0] is None:
        unknown_dim[0] = True
      return a + 1

    func(constant_op.constant([]))
    self.assertFalse(unknown_dim[0])
    self.assertLen(total_function_cache(func), 1)

    func(constant_op.constant([1.0]))
    self.assertFalse(unknown_dim[0])
    self.assertLen(total_function_cache(func), 2)

    func(constant_op.constant([1.0, 2.0]))
    self.assertTrue(unknown_dim[0])
    self.assertLen(total_function_cache(func), 2)

  def testNestedInputShapeFunctionRelaxation(self):
    unknown_dim = [False]

    @function.defun
    def func(a_, b_=None):
      del a_  # Only used to check which cache is used.
      self.assertEqual(b_[0]._shape_tuple(), ())
      if b_[1]._shape_tuple()[0] is None:
        unknown_dim[0] = True
      return b_[0] + 1

    a = 'hi'
    b0 = constant_op.constant(1.0)
    func(a, b_=[b0, constant_op.constant([])])
    self.assertFalse(unknown_dim[0])
    self.assertLen(total_function_cache(func), 1)

    func(a, b_=[b0, constant_op.constant([1.0])])
    self.assertFalse(unknown_dim[0])
    self.assertLen(total_function_cache(func), 2)

    func(a, b_=[b0, constant_op.constant([1.0, 1.0])])
    self.assertTrue(unknown_dim[0])
    self.assertLen(total_function_cache(func), 2)

    unknown_dim[0] = False

    # Now do the same except with a new a which is not a tensor; this should
    # change the cache key.
    a = 'bye'
    func(a, b_=[b0, constant_op.constant([])])
    self.assertFalse(unknown_dim[0])
    self.assertLen(total_function_cache(func), 3)

    # Since we already marked a cache miss for a function with the same
    # non-input signatures, here we will immediately start relaxing shapes.
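    # (In other words, the call below should reuse the relaxed trace with an
    # unknown leading dimension for the second list element, rather than first
    # tracing a specialization for shape [1].)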
    func(a, b_=[b0, constant_op.constant([1.0])])
    self.assertTrue(unknown_dim[0])
    self.assertLen(total_function_cache(func), 3)

  def testFunctionRelaxationLosesInnerDimWithKerasLayer(self):
    layer = keras.layers.Dense(1)
    fn = def_function.function()(layer)

    with self.captureWritesToStream(sys.stderr) as printed:
      fn(array_ops.ones((3, 2)))
      self.assertNotIn('ValueError', printed.contents())
    with self.captureWritesToStream(sys.stderr) as printed:
      # Use batch size 2 to trigger a second cache miss on the shape.
      fn(array_ops.ones((2, 2)))
      self.assertNotIn('ValueError', printed.contents())

    # Shape relaxation passes TensorShape([None, None]), which causes the
    # layer's matmul to fail due to incompatible dims: what would have been a
    # graph-build-time error (the layer would complain about the inner dim
    # being 4) surfaces as a runtime error instead.
    with self.captureWritesToStream(sys.stderr) as printed:
      with self.assertRaisesRegexp(errors.InvalidArgumentError, r'MatMul'):
        fn(array_ops.ones((3, 4)))

  def testNestedShapeFunctionRelaxation(self):

    got_shape = [None]

    # The inner function will go through shape relaxation because the shapes
    # it receives will be [1], [2], [3], ...
    @def_function.function
    def bar(x_shape):
      got_shape[0] = x_shape._shape_tuple()
      return x_shape

    # The outer function will not go through shape relaxation because the
    # shapes it receives will be [1], [[1]], [[[1]]], ...
    @def_function.function
    def foo(ones):
      return bar(array_ops.shape(ones))

    for rank in range(1, 6):
      x_shape = self.evaluate(foo(array_ops.ones([1] * rank)))
      self.assertAllEqual(x_shape, [1] * rank)
      if rank < 3:
        self.assertEqual(got_shape[0], (rank,))
      else:
        self.assertEqual(got_shape[0], (None,))

  def testNoHash(self):

    @def_function.function()
    def f(_):
      return 1.0

    with self.assertRaisesRegexp(TypeError, 'set'):
      f(set([]))

  def testFuncName(self):

    @function.defun_with_attributes(attributes={'func_name': 'multiply'})
    def add(x, y):
      _ = x * y
      return x + y

    @function.defun
    def add_2(x, y):
      _ = x * y
      return x + y

    self.assertEqual(add._name, 'multiply')
    self.assertEqual(add_2._name, 'add_2')

  def testBasicGraphMode(self):
    # TODO(b/121134877): Remove the autograph override.
    matmul = def_function.function(math_ops.matmul, autograph=False)

    @def_function.function
    def sq(a):
      return matmul(a, a)

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    out = sq(t)
    self.assertAllEqual(out, math_ops.matmul(t, t).numpy())

  def testNestedInputsGraphMode(self):
    # TODO(b/121134877): Remove the autograph override.
    matmul = def_function.function(math_ops.matmul, autograph=False)

    pair = collections.namedtuple('pair', ['a', 'b'])

    @def_function.function
    def a_times_b(inputs):
      return matmul(inputs.a['a'], inputs.b['b'])

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])

    out = a_times_b(pair({'a': t}, {'b': t}))
    self.assertAllEqual(out, math_ops.matmul(t, t).numpy())

  def testNestedOutputsGraphMode(self):
    # TODO(b/121134877): Remove the autograph override.
    matmul = def_function.function(math_ops.matmul, autograph=False)

    pair = collections.namedtuple('pair', ['a', 'b'])

    @def_function.function()
    def pairs_mul(pair_a, pair_b):
      return pair(matmul(pair_a.a, pair_b.a), matmul(pair_a.b, pair_b.b))

    a = constant_op.constant([[1.0, 2.0], [1.0, 2.0]])
    b = constant_op.constant([[3.0, 4.0], [3.0, 4.0]])

    out = pairs_mul(pair(a, b), pair(b, a))
    expected = pair(math_ops.matmul(a, b).numpy(),
                    math_ops.matmul(b, a).numpy())
    self.assertAllClose(out, expected)

  def testGraphEagerIsolation(self):

    @function.defun
    def f():
      self.v = variables.Variable(1.0)
      return self.v.read_value()

    self.assertAllEqual(f(), 1.0)

    with ops.Graph().as_default():
      self.assertEqual(f().shape, ())

  def testBasicGraphFunction(self):
    # TODO(b/121134877): Remove the autograph override.
    matmul = def_function.function(math_ops.matmul, autograph=False)

    @def_function.function
    def sq(a):
      return matmul(a, a)

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])

    sq_op = sq.get_concrete_function(t)
    self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2]))
    out = sq_op(t)
    self.assertAllEqual(out, math_ops.matmul(t, t).numpy())

  def testInputSpecGraphFunction(self):
    # TODO(b/121134877): Remove the autograph override.
    matmul = def_function.function(math_ops.matmul, autograph=False)

    @def_function.function
    def sq(a):
      return matmul(a, a)

    sq_op = sq.get_concrete_function(
        tensor_spec.TensorSpec((None, None), dtypes.float32))
    self.assertEqual([None, None], sq_op.output_shapes.as_list())

    t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    out1 = sq_op(t1)
    self.assertAllEqual(out1, math_ops.matmul(t1, t1).numpy())

    t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    out2 = sq_op(t2)
    self.assertAllEqual(out2, math_ops.matmul(t2, t2).numpy())

  def testNestedInputSpecGraphFunction(self):
    # TODO(b/121134877): Remove the autograph override.
    matmul = def_function.function(math_ops.matmul, autograph=False)

    @def_function.function
    def sq(mats):
      ((a, b),) = mats
      return matmul(a, b)

    with self.assertRaisesRegexp(ValueError, "two arguments named 'mats'"):
      sq.get_concrete_function(
          [(tensor_spec.TensorSpec((None, None), dtypes.float32),
            tensor_spec.TensorSpec((None, None), dtypes.float32))])
    sq_op = sq.get_concrete_function(
        [(tensor_spec.TensorSpec((None, None), dtypes.float32,
                                 name='first_mat'),
          tensor_spec.TensorSpec((None, None), dtypes.float32,
                                 name='second_mat'))])
    self.assertEqual([None, None], sq_op.output_shapes.as_list())

    t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    t2 = constant_op.constant([[1.4, 2.4], [3.4, 4.4]])
    with self.assertRaisesRegexp(
        TypeError, 'bound to Tensors within nested structures'):
      sq_op(t1, t2)
    out = sq_op(first_mat=t1, second_mat=t2)
    self.assertAllEqual(out, math_ops.matmul(t1, t2).numpy())

  def testExecutingStatelessDefunConcurrently(self):

    @def_function.function
    def stateless(x):
      return math_ops.multiply(2.0, x)

    pool = ThreadPool()
    inputs = [constant_op.constant(1.0 * x) for x in range(100)]
    outputs = [float(out) for out in pool.map(stateless, inputs)]
    expected = [float(2.0 * x) for x in inputs]
    self.assertSequenceEqual(outputs, expected)

  def testExecutingManyStatelessDefunsConcurrently(self):

    @def_function.function
    def stateless(x):
      del x
      return math_ops.multiply(2.0, 2.0)

    pool = ThreadPool()
    # `pool.map` below instantiates 100 functions, one for each object.
    outputs = [
        float(out)
        for out in pool.map(stateless, [object() for _ in range(100)])
    ]
    expected = [4.0] * 100
    self.assertSequenceEqual(outputs, expected)

  def testExecutingStatefulDefunConcurrently(self):

    v = resource_variable_ops.ResourceVariable(1.0)

    @def_function.function
    def stateful(x):
      v.assign(x)

    pool = ThreadPool()
    inputs = [constant_op.constant(0.0)] * 100
    pool.map(stateful, inputs)
    self.assertEqual(float(v.read_value()), 0.0)

  def testExecutingManyStatefulDefunsConcurrently(self):

    v = resource_variable_ops.ResourceVariable(1.0)

    @def_function.function
    def stateful(x):
      del x
      return v.assign(0.0)

    pool = ThreadPool()
    # `pool.map` below instantiates 100 functions, one for each object.
    pool.map(stateful, [object() for _ in range(100)])
    self.assertEqual(float(v.read_value()), 0.0)

  def disabled_testRandomSeed(self):

    @def_function.function
    def f():
      return random_ops.random_normal(())

    random_seed.set_random_seed(1)
    x = f()
    self.assertNotEqual(x, f())
    random_seed.set_random_seed(1)
    self.assertAllEqual(f(), x)

  def testNestedInputsGraphFunction(self):
    # TODO(b/121134877): Remove the autograph override.
    matmul = def_function.function(math_ops.matmul, autograph=False)

    pair = collections.namedtuple('pair', ['a', 'b'])

    @def_function.function
    def a_times_b(inputs):
      return matmul(inputs.a['a'], inputs.b['b'])

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    sq_op = a_times_b.get_concrete_function(
        pair(dict(a=tensor_spec.TensorSpec([2, 2], dtypes.float32, 'a')),
             dict(b=tensor_spec.TensorSpec([2, 2], dtypes.float32, 'b'))))
    self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2]))
    out = sq_op(a=t, b=t)
    self.assertAllEqual(out, math_ops.matmul(t, t).numpy())

  def testNestedOutputGraphFunction(self):
    # TODO(b/121134877): Remove the autograph override.
    matmul = def_function.function(math_ops.matmul, autograph=False)

    @def_function.function
    def sq(a):
      return (matmul(a, a), {'b': constant_op.constant(1.0)})

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])

    sq_op = sq.get_concrete_function(t)
    self.assertEqual(sq_op.output_shapes,
                     (tensor_shape.TensorShape([2, 2]),
                      {'b': tensor_shape.TensorShape([])}))
    self.assertEqual(sq_op.output_dtypes,
                     (dtypes.float32, {'b': dtypes.float32}))
    (a, b) = sq_op(t)
    self.assertAllEqual(a, math_ops.matmul(t, t).numpy())
    self.assertAllEqual(b['b'].numpy(), 1.0)

  def testGraphFunctionNoneOutput(self):
    @def_function.function
    def fn(unused_a, unused_b):
      return None

    x = constant_op.constant(1)
    fn_op = fn.get_concrete_function(x, x)
    self.assertEqual(fn_op.output_dtypes, None)
    self.assertEqual(fn_op.output_shapes, None)
    self.assertAllEqual(fn_op(x, x), None)

  def testDefunNumpyArraysConvertedToTensors(self):

    def f(x):
      self.assertIsInstance(x, ops.Tensor)
      return x

    x = random_ops.random_uniform([2, 2]).numpy()
    defined = function.defun(f)
    defined(x)
    self.assertLen(total_function_cache(defined), 1)

    x = random_ops.random_uniform([2, 2]).numpy()
    defined(x)
    # A NumPy array with different values but the same shape and dtype
    # shouldn't trigger another function definition.
    self.assertLen(total_function_cache(defined), 1)

    # Test that the numpy array is properly an argument to the graph function.
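    # NumPy scalars and eager tensors of the same shape and dtype should flow
    # through the same traced function interchangeably.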
    self.assertEqual(1., defined(numpy.ones([])).numpy())
    self.assertEqual(0., defined(numpy.zeros([])).numpy())
    self.assertEqual(1., defined(array_ops.ones([])).numpy())
    self.assertEqual(0., defined(array_ops.zeros([])).numpy())

  def testDefunCapturedInt32(self):
    x = constant_op.constant(1, dtype=dtypes.int32)

    @def_function.function
    def add_int32s():
      return x + x

    self.assertEqual(2, int(add_int32s()))

  def testDefunReadVariable(self):
    v = resource_variable_ops.ResourceVariable(1.0)

    @def_function.function
    def f():
      return v.read_value()

    self.assertEqual(1.0, float(f()))

  def testDefunAssignAddVariable(self):
    v = resource_variable_ops.ResourceVariable(1.0)
    x = constant_op.constant(2.0)

    @def_function.function
    def test_assign_add():
      v.assign_add(x)
      return v.read_value()

    self.assertEqual(3.0, float(test_assign_add()))

  @test_util.run_in_graph_and_eager_modes
  def testTensorInitializationInFunctionRaisesError(self):
    error_msg = ('Tensor-typed variable initializers must either be '
                 'wrapped in an init_scope or callable.*')

    @def_function.function
    def tensor_init():
      with self.assertRaisesRegexp(ValueError, error_msg):
        resource_variable_ops.ResourceVariable(constant_op.constant(2.0))

    tensor_init()

  @test_util.run_in_graph_and_eager_modes
  def testCallableTensorInitializationInFunction(self):

    @def_function.function
    def tensor_init():
      self.v = resource_variable_ops.ResourceVariable(
          lambda: constant_op.constant(2.0))
      return self.v.read_value()

    value = tensor_init()
    if not context.executing_eagerly():
      self.evaluate(variables.global_variables_initializer())
    self.assertEqual(self.evaluate(value), 2.0)

  @test_util.also_run_as_tf_function
  def testInitScopeTensorInitializationInFunction(self):

    @def_function.function
    def tensor_init():
      with ops.init_scope():
        const = constant_op.constant(2.0)
      # Note: this variable sidesteps tf.function's variable-creation
      # requirements by using ResourceVariable directly instead of Variable,
      # which avoids variable_creator_scope.
      self.v = resource_variable_ops.ResourceVariable(const)
      return self.v.read_value()

    value = tensor_init()
    self.assertAllEqual(value, 2.0)

  @test_util.run_in_graph_and_eager_modes
  def testGetConcreteFunctionCreatesVariables(self):

    v_holder = []

    @def_function.function
    def tensor_init():
      if not v_holder:
        v_holder.append(variables.Variable(5.))
      return v_holder[0].read_value()

    concrete = tensor_init.get_concrete_function()
    self.evaluate(variables.global_variables_initializer())
    self.assertAllEqual(5., self.evaluate(concrete()))
    self.assertAllEqual(5., self.evaluate(tensor_init()))

  def testFuncGraphCaptureByValue(self):
    v = variables.Variable(1.0)

    def trivial_function():
      return v.read_value()

    graph_function = function.Function(
        trivial_function, 'test', capture_by_value=True)

    self.assertAllEqual(graph_function(), 1.0)
    v.assign(2.0)
    self.assertAllEqual(graph_function(), 1.0)

  def testFuncGraphCaptureByValueNested(self):
    v = variables.Variable(1.0)

    def trivial_function():
      return control_flow_ops.cond(
          array_ops.placeholder_with_default(True, ()),
          v.read_value, v.read_value)

    graph_function = function.Function(
        trivial_function, 'test', capture_by_value=True)

    self.assertAllEqual(graph_function(), 1.0)
    v.assign(2.0)
    self.assertAllEqual(graph_function(), 1.0)

  def testDefunShapeInferenceWithCapturedResourceVariable(self):
    v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])

    def f():
      x = constant_op.constant([[1, 2], [3, 4]])
      out = math_ops.matmul(v, x)
      self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
      # We do not return v directly since the tensor conversion function of
      # ResourceVariable returns the read value and not the resource itself.
      return v._handle

    compiled = def_function.function(f)
    var_handle = compiled()
    self.assertEqual(var_handle.dtype, dtypes.resource)
    self.assertEqual(var_handle.shape, tensor_shape.scalar())
    var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
    self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))

  def testShapeInferenceForMoreSpecificInput(self):
    self.skipTest('b/124219898')

    def f(a):
      return array_ops.reshape(a, [-1, 3])

    signature = [tensor_spec.TensorSpec(None, dtypes.float32)]
    compiled = def_function.function(f, input_signature=signature)

    with ops.Graph().as_default():
      inputs = array_ops.zeros([10, 10, 3])
      self.assertAllEqual(f(inputs).shape, compiled(inputs).shape)

  def testFuncListAttr(self):

    @function.defun
    def test_function(val):

      def fn1():
        return array_ops.ones([10])

      fn2 = lambda: array_ops.ones([10]) * 2

      def fn3(x=2):
        return array_ops.ones([10]) * x
      fn3 = functools.partial(fn3, x=3)

      return gen_functional_ops.case(val, [], [dtypes.float32],
                                     [function.defun(f).get_concrete_function()
                                      for f in (fn1, fn2, fn3)])

    ones = array_ops.ones([10])
    self.assertAllEqual([ones], test_function(0))
    self.assertAllEqual([ones * 2], test_function(1))
    self.assertAllEqual([ones * 3], test_function(2))
    self.assertAllEqual([ones * 3], test_function(22))  # default branch

  @test_util.enable_control_flow_v2
  def testVariableInLoopInFunction(self):

    @function.defun
    def test_function():

      def loop_test(_):
        return False

      def loop_body(_):
        return variable_scope.get_variable('a', shape=())

      return control_flow_ops.while_loop(loop_test, loop_body, [0.0])

    self.assertEqual(test_function().shape, [])

  def testDefunShapeInferenceWithCapturedResourceVariableInGraphMode(self):
    with context.graph_mode():
      v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])

      def f():
        x = constant_op.constant([[1, 2], [3, 4]])
        out = math_ops.matmul(v, x)
        self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
        # We do not return v directly since the tensor conversion function of
        # ResourceVariable returns the read value and not the resource itself.
        return v._handle

      compiled = def_function.function(f)
      var_handle = compiled()
      self.assertEqual(var_handle.dtype, dtypes.resource)
      self.assertEqual(var_handle.shape, tensor_shape.scalar())
      var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
      self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))

  def testDefunShapeInferenceWithCapturedVariableInGraphMode(self):
    with context.graph_mode():
      v = variables.Variable([[1, 2], [3, 4]])

      def f():
        x = constant_op.constant([[1, 2], [3, 4]])
        out = math_ops.matmul(v, x)
        self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))

      # Check that shape inference works while creating the defun.
      compiled = def_function.function(f)
      compiled()

  def testDefunShapeInferenceWithCapturedTensorListInGraphMode(self):
    with context.graph_mode():
      tensor_list = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
      tensor_list = list_ops.tensor_list_push_back(tensor_list,
                                                   constant_op.constant(1.0))
      tensor_list = list_ops.tensor_list_push_back(tensor_list,
                                                   constant_op.constant(2.0))

      def f():
        tl, value = list_ops.tensor_list_pop_back(
            tensor_list, element_dtype=dtypes.float32)
        self.assertEqual(value.shape, tensor_shape.scalar())
        return tl

      compiled = def_function.function(f)
      output_tensor_list = compiled()
      _, value = list_ops.tensor_list_pop_back(
          output_tensor_list, element_dtype=dtypes.float32)
      self.assertEqual(value.shape, tensor_shape.scalar())

  @test_util.run_in_graph_and_eager_modes
  def testDefunForcesResourceVariables(self):

    def variable_creator():
      self.v = variables.Variable(0.0)
      return self.v.read_value()

    self.v = None
    defined = function.defun(variable_creator)
    defined()  # Create the variable.
    self.assertIsInstance(
        self.v, resource_variable_ops.ResourceVariable)

  def testRunMetadata(self):

    @def_function.function
    def f(x):
      return x * x

    with ops.device('cpu:0'):
      context.enable_run_metadata()
      f(constant_op.constant(1.0))
    run_metadata = context.export_run_metadata()
    context.disable_run_metadata()
    step_stats = run_metadata.step_stats
    self.assertNotEmpty(step_stats.dev_stats)
    cpu_stats = step_stats.dev_stats[0]
    self.assertEqual('/job:localhost/replica:0/task:0/device:CPU:0',
                     cpu_stats.device)
    # Test for at least 2 because the function call should generate at most
    # one entry in the step_stats, while the ops inside the function can
    # generate arbitrarily many (placeholders, return identities, etc. might
    # or might not be included in the future), so we shouldn't test for an
    # exact count.
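    # (Two is a conservative lower bound: the call itself plus at least one op
    # executed inside the function.)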
    self.assertGreaterEqual(len(cpu_stats.node_stats), 2)
    self.assertLen(run_metadata.partition_graphs, 1)

  def testGraphModeCaptureVariable(self):
    with context.graph_mode(), self.cached_session():

      class HasAVar(object):

        def __init__(self):
          self.v = resource_variable_ops.ResourceVariable(1.0)

        def call(self):
          return self.v * 2

      o = HasAVar()
      self.evaluate(variables.global_variables_initializer())
      call = def_function.function(o.call)
      op = call()
      self.assertAllEqual(self.evaluate(op), 2.0)

  def testGraphModeManyFunctions(self):
    with ops.Graph().as_default(), self.cached_session():

      @def_function.function
      def f(x):
        return x * x

      @def_function.function
      def g(x):
        return f(x) + 1

      self.assertAllEqual(g(constant_op.constant(2.0)).eval(), 5.0)

  def testDict(self):

    @def_function.function
    def f(x):
      return {'name': x + 1}

    self.assertAllEqual(f(constant_op.constant(1.0))['name'], 2.0)

  def testTensorConversionWithDefun(self):

    @def_function.function
    def f(x):
      return math_ops.add(x, constant_op.constant(3))

    self.assertAllEqual(5, f(constant_op.constant(2)))

  def testTensorConversionCall(self):

    @def_function.function
    def f(x):
      return math_ops.add(x, constant_op.constant(3))

    @def_function.function
    def g(x):
      return f(f(x))

    self.assertAllEqual(8, g(constant_op.constant(2)))

  def testCallShape(self):

    @def_function.function
    def f(x):
      return x + 1

    @def_function.function
    def g(x):
      x = f(x)
      self.assertEqual(x.shape.as_list(), [])
      return None

    g(constant_op.constant(1.0))

  def testNestedDefunWithNoOutputAndTapedInput(self):
    three = resource_variable_ops.ResourceVariable(3.0, name='v')

    @def_function.function
    def f(x):
      # This function intentionally takes a taped variable as input,
      # but does not return any values.
      math_ops.add(x, three)

    @def_function.function
    def g(x):
      y = math_ops.add(x, three)
      f(y)

    g(three)

  def testGatherResourceWithDefun(self):
    with ops.device('cpu:0'):
      v = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])

    def sum_gather():
      return math_ops.reduce_sum(array_ops.gather(v, [1, 2]))

    defined = def_function.function(sum_gather)
    self.assertAllEqual(sum_gather(), defined())

  def testReturningIndexedSlicesWithDefun(self):

    def validate(indexed_slice):
      @def_function.function
      def f():
        return indexed_slice

      output = f()
      self.assertIsInstance(output, ops.IndexedSlices)
      self.assertAllEqual(indexed_slice.values, output.values)
      self.assertAllEqual(indexed_slice.indices, output.indices)
      self.assertAllEqual(indexed_slice.dense_shape, output.dense_shape)

      self.assertEqual(
          f.get_concrete_function().output_shapes,
          indexed_slice.values.shape)

    arg = ops.IndexedSlices(
        values=constant_op.constant([1, 2]),
        indices=constant_op.constant([0, 1]),
        dense_shape=constant_op.constant([2]))
    validate(arg)

    arg = ops.IndexedSlices(
        values=constant_op.constant([1, 2]),
        indices=constant_op.constant([0, 1]),
        dense_shape=None)
    validate(arg)

  def testIndexedSliceAsArgumentWithDefun(self):

    @def_function.function
    def f(indexed_slice):
      return indexed_slice

    def validate(arg):
      output = f(arg)
      self.assertIsInstance(output, ops.IndexedSlices)
      self.assertAllEqual(arg.values, output.values)
      self.assertAllEqual(arg.indices, output.indices)
      self.assertAllEqual(arg.dense_shape, output.dense_shape)

    indexed_slice = ops.IndexedSlices(
        values=constant_op.constant([1]),
        indices=constant_op.constant([0]),
        dense_shape=constant_op.constant([1]))
    validate(indexed_slice)

    # Test that `f` works even when `dense_shape` is None.
    indexed_slice = ops.IndexedSlices(
        values=constant_op.constant([1]),
        indices=constant_op.constant([0]),
        dense_shape=None)
    validate(indexed_slice)

  def testFunctionOnDevice(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found')

    x = constant_op.constant([1.]).gpu()
    # TODO(b/121134877): Remove the autograph override.
    f = def_function.function(math_ops.add, autograph=False)
    y = f(x, x).cpu()
    self.assertAllEqual(y, [2.])

  @test_util.run_in_graph_and_eager_modes
  def testFunctionWithResourcesOnDifferentDevices(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found.')

    with ops.device('/cpu:0'):
      v_cpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])

    with ops.device('/gpu:0'):
      v_gpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])

    def sum_gather():
      cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu, [1, 2]))
      gpu_result = math_ops.reduce_sum(array_ops.gather(v_gpu, [1, 2]))
      return cpu_result, gpu_result

    defined = function.defun(sum_gather)
    if not context.executing_eagerly():
      self.evaluate(variables.global_variables_initializer())
    expected = self.evaluate(sum_gather())
    self.assertAllEqual(expected, self.evaluate(defined()))

  @test_util.run_in_graph_and_eager_modes
  def testOpInFunctionWithConflictingResourceInputs(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found.')

    with ops.device('/cpu:0'):
      v_cpu = resource_variable_ops.ResourceVariable(
          [0.0, 1.0, 2.0], name='cpu')
      v_also_cpu = resource_variable_ops.ResourceVariable(
          [0.0, 1.0, 2.0], name='also_cpu')

    with ops.device('/gpu:0'):
      v_gpu = resource_variable_ops.ResourceVariable(
          [0.0, 1.0, 2.0], name='gpu')

    @def_function.function
    def resource_apply_adam():
      training_ops.resource_apply_adam(
          v_cpu.handle,
          v_gpu.handle,
          v_also_cpu.handle,
          1.0,  # beta1_power
          1.0,  # beta2_power
          1.0,  # learning_rate
          1.0,  # beta1
          1.0,  # beta2
          1.0,  # epsilon
          [1.0, 1.0, 1.0],  # grad
          False)  # use_locking
      return None

    with self.assertRaisesRegexp(
        errors.InvalidArgumentError,
        'Cannot place the graph because a reference or resource edge connects '
        'colocation groups with incompatible assigned devices'):
      if not context.executing_eagerly():
        self.evaluate(variables.global_variables_initializer())
      self.evaluate(resource_apply_adam())

  def testFunctionHandlesInputsOnDifferentDevices(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found')

    # The Reshape op requires the shape tensor to be placed in host memory.
    # TODO(b/121134877): Remove the autograph override.
    reshape = def_function.function(array_ops.reshape, autograph=False)
    value = constant_op.constant([1., 2.]).gpu()
    shape = constant_op.constant([2, 1])
    reshaped = reshape(value, shape).cpu()
    self.assertAllEqual(reshaped, [[1], [2]])

  def testFunctionHandlesInputsPlacedOnTheWrongDeviceGracefully(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found')

    # The Reshape op requires the shape tensor to be placed in host memory.
    # TODO(b/121134877): Remove the autograph override.
    reshape = def_function.function(array_ops.reshape, autograph=False)
    value = constant_op.constant([1., 2.])
    shape = constant_op.constant([2, 1]).gpu()
    reshape(value, shape)  # No error is raised

  def testNoneOutput(self):

    @def_function.function
    def my_function(_):
      return None

    self.assertAllEqual(my_function(1), None)

  def testNestedFunctions(self):
    # TensorFlow function (which is what would be used in TensorFlow graph
    # construction).
    @tf_function.Defun(dtypes.int32, dtypes.int32)
    def add(a, b):
      return math_ops.add(a, b)

    @def_function.function
    def add_one(x):
      return add(x, 1)

    self.assertAllEqual(3, add_one(constant_op.constant(2)))

  def testVariableCaptureInNestedFunctions(self):
    v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.int32)

    @def_function.function
    def inner_read():
      return v.read_value()

    @def_function.function
    def outer():
      return inner_read()

    self.assertEqual(1, int(outer()))

  def testReturnCapturedEagerTensor(self):
    t = constant_op.constant(1)

    @def_function.function
    def read():
      return t

    self.assertEqual(1, int(read()))

  def testReturnCapturedGraphTensor(self):
    with context.graph_mode(), self.cached_session():
      t = constant_op.constant(1)

      @def_function.function
      def read():
        return t

      self.assertEqual(1, int(self.evaluate(read())))

  def testSequenceInputs(self):
    # TODO(b/121134877): Remove the autograph override.
    clip_by_global_norm = def_function.function(
        clip_ops.clip_by_global_norm, autograph=False)
    t_list = [constant_op.constant(1.0), constant_op.constant(2.0)]
    clipped_list, global_norm = clip_by_global_norm(t_list,
                                                    constant_op.constant(.2))
    for t in clipped_list:
      self.assertIsInstance(t, ops.Tensor)
    self.assertIsInstance(global_norm, ops.Tensor)

  def testNestedSequenceInputs(self):

    def my_op(inputs):
      a, b, c = inputs
      e, f = b
      g, h = e
      return [a + a, [tuple([f + f, g + g]), h + h], c + c], a + f + g + h + c

    my_eager_op = def_function.function(my_op)
    ret = my_eager_op([
        constant_op.constant(1), [(constant_op.constant(2),
                                   constant_op.constant(3)),
                                  constant_op.constant(4)],
        constant_op.constant(5)
    ])
    self.assertLen(ret, 2)
    self.assertAllEqual(ret[0][0], 2)
    self.assertAllEqual(ret[0][1][0][0], 8)
    self.assertAllEqual(ret[0][1][0][1], 4)
    self.assertIsInstance(ret[0][1][0], tuple)
    self.assertAllEqual(ret[0][1][1], 6)
    self.assertAllEqual(ret[0][2], 10)
    self.assertAllEqual(ret[1], 15)

  def testVariableNamesRespectNameScopesWithDefun(self):
    @def_function.function
    def create_variable():
      with ops.name_scope('foo'):
        v = resource_variable_ops.ResourceVariable(0.0, name='bar')
      self.assertEqual(v.name, 'foo/bar:0')

    create_variable()

  def testVariableNamesRespectNameScopesWithDefunInGraph(self):
    with context.graph_mode():
      @def_function.function
      def create_variable():
        with ops.name_scope('foo'):
          v = resource_variable_ops.ResourceVariable([1.0, 2.0], name='bar')
        self.assertEqual(v.name, 'foo/bar:0')

      with ops.get_default_graph().as_default():
        create_variable()

  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
  def testLayerInDefun(self):
    conv = convolutional.Conv2D(
        filters=1,
        kernel_size=2,
        kernel_initializer=init_ops.ones_initializer(),
        bias_initializer=init_ops.zeros_initializer())

    @function.defun
    def model(x):
      return conv(x)

    x = array_ops.ones([1, 2, 2, 1])
    y = model(x)

    if not context.executing_eagerly():
      self.evaluate(variables.global_variables_initializer())

    self.assertAllClose([[[[4.0]]]], self.evaluate(y))

  # Variable lifting is somewhat different between defun/tf.function, so
  # testing device placement on both makes sense.
  @parameterized.named_parameters(
      dict(testcase_name='Defun',
           function_decorator=function.defun),
      dict(testcase_name='DefFunction',
           function_decorator=def_function.function))
  @test_util.run_in_graph_and_eager_modes
  def testVariablesPlacedOnOutsideDevice(self, function_decorator):

    class _Obj(object):

      def __init__(self):
        self.v = None

      @function_decorator
      def f(self):
        if self.v is None:
          self.v = variables.Variable(1.)
        return self.v + 1.

    has_device = _Obj()
    with ops.device('cpu:0'):
      has_device.f()
    self.assertIn('CPU', has_device.v.device)

  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
  def testDefunKerasModelCall(self):
    model = MiniModel()
    model.call = function.defun(model.call)

    x = array_ops.ones([1, 2])
    y = model(x)

    if not context.executing_eagerly():
      self.evaluate(variables.global_variables_initializer())

    self.assertAllEqual([[3.0]], self.evaluate(y))

    # Break the reference cycle between the MiniModel and the defun:
    # `MiniModel` --(through its `call` method)--> `Function`
    # `Function` --(instancemethod on `MiniModel`)--> `MiniModel`
    del model.call

  # Note: The ConfigProto below unfortunately only configures graph
  # construction. Eager's configuration is controlled in `__main__`.
  @test_util.run_in_graph_and_eager_modes(
      config=config_pb2.ConfigProto(device_count={'CPU': 4}))
  @test_util.run_v1_only('b/120545219')
  def testDeviceAnnotationsRespected(self):

    def multi_device_fn():
      with ops.device('/cpu:0'):
        s0 = test_ops.device_placement_op()
      with ops.device('/cpu:1'):
        s1 = test_ops.device_placement_op()
      with ops.device('/cpu:2'):
        s2 = test_ops.device_placement_op()
      s3 = test_ops.device_placement_op()
      return s0, s1, s2, s3

    defined = function.defun(multi_device_fn)
    outputs = self.evaluate(defined())
    self.assertLen(total_function_cache(defined), 1)
    self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
    self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
    self.assertIn(compat.as_bytes('CPU:2'), outputs[2])

    with ops.device('/cpu:3'):
      outputs = self.evaluate(defined())
    # All function definitions are agnostic to call site devices.
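    # Calling under a new device scope changes where the ops run, not which
    # trace is used, so the cache should still hold a single function.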
    self.assertLen(total_function_cache(defined), 1)
    self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
    self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
    self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
    self.assertIn(compat.as_bytes('CPU:3'), outputs[3])

    with ops.device('/cpu:0'):
      outputs = self.evaluate(defined())
    self.assertLen(total_function_cache(defined), 1)
    self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
    self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
    self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
    self.assertIn(compat.as_bytes('CPU:0'), outputs[3])

  @test_util.run_in_graph_and_eager_modes(
      config=config_pb2.ConfigProto(device_count={'CPU': 2}))
  @test_util.run_v1_only('b/120545219')
  def testCallingGraphFunctionOnDifferentDevice(self):

    def func():
      return constant_op.constant(0)

    defined = def_function.function(func)
    with ops.device('cpu:0'):
      cpu_graph_function = defined.get_concrete_function()

    with ops.device('cpu:0'):
      self.assertEqual(
          self.evaluate(cpu_graph_function()), self.evaluate(func()))

    with ops.device('cpu:1'):
      self.assertEqual(0., self.evaluate(cpu_graph_function()))

    with ops.device(None):
      self.assertEqual(0., self.evaluate(cpu_graph_function()))

    default_graph_function = defined.get_concrete_function()
    self.assertEqual(
        self.evaluate(default_graph_function()), self.evaluate(func()))

    with ops.device('cpu:1'):
      self.assertEqual(0., self.evaluate(default_graph_function()))

  @test_util.run_in_graph_and_eager_modes
  def testColocateWithRespected(self):
    # TODO(b/113291792): Use multiple CPUs instead of a GPU.
    if not context.context().num_gpus():
      self.skipTest('No GPUs found.')

    with ops.device('cpu:0'):
      x = constant_op.constant(1.0)

    with ops.device('gpu:0'):
      y = constant_op.constant(1.0)

    @def_function.function
    def foo():
      return test_ops.device_placement_op()

    with ops.colocate_with(x):
      self.assertIn(compat.as_bytes('CPU:0'), self.evaluate(foo()))

    with ops.colocate_with(y):
      self.assertIn(compat.as_bytes('GPU:0'), self.evaluate(foo()))

  def testVariablesAreTracked(self):
    v = resource_variable_ops.ResourceVariable(1.0)

    def foo(x):
      return v * x

    defined = def_function.function(foo)

    x = constant_op.constant([1.0])
    self.assertEqual(1., self.evaluate(defined(x)))
    v.assign(2.)

    x = constant_op.constant([1.0, 2.0])
    self.assertAllEqual([2., 4.], self.evaluate(defined(x)))

  def testCacheObjectHashCollisions(self):

    class Foo(object):

      def __hash__(self):
        return 42

    def func(foo):
      del foo
      return

    defined = function.defun(func)
    defined(Foo())
    self.assertLen(total_function_cache(defined), 1)

    defined(Foo())
    self.assertLen(total_function_cache(defined), 2)

  def testCacheTensorDtypeCollision(self):

    def func(t):
      return t + t

    defined = function.defun(func)
    t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
    defined(t)
    self.assertLen(total_function_cache(defined), 1)

    t = constant_op.constant([[1.0]], dtype=dtypes.complex128)
    defined(t)
    self.assertLen(total_function_cache(defined), 2)

  def testCacheTensorShapeCollision(self):

    def func(t):
      return t + t

    defined = function.defun(func)
    t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
    defined(t)
    self.assertLen(total_function_cache(defined), 1)

    t = constant_op.constant([1.0], dtype=dtypes.complex64)
    defined(t)
    self.assertLen(total_function_cache(defined), 2)

  def testCacheTensorShapeDtypeCollision(self):

    def func(t):
      return t + t

    defined = function.defun(func)
    t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
    defined(t)
    self.assertLen(total_function_cache(defined), 1)

    t = constant_op.constant([1.0], dtype=dtypes.complex128)
    defined(t)
    self.assertLen(total_function_cache(defined), 2)

  def testCacheTensorUnknownShapesCollision(self):

    def func(t):
      return t + t

    with context.graph_mode(), self.cached_session():
      defined = function.defun(func)

      p = array_ops.placeholder(dtype=dtypes.float32, shape=[])
      defined(p)
      self.assertLen(total_function_cache(defined), 1)

      p = array_ops.placeholder(dtype=dtypes.float32, shape=[1])
      defined(p)
      self.assertLen(total_function_cache(defined), 2)

      p = array_ops.placeholder(dtype=dtypes.float32, shape=[2])
      defined(p)
      # Gradual shape relaxation is performed; and the common shape between
      # [1] and [2] is one containing unknown dimensions.
      self.assertLen(total_function_cache(defined), 2)

      # pylint: disable=protected-access
      self.assertLen(defined._function_cache.arg_relaxed_shapes, 1)
      relaxed_shapes = (
          list(defined._function_cache.arg_relaxed_shapes.values())[0])
      self.assertEqual(len(relaxed_shapes), 1)
      relaxed_shape = relaxed_shapes[0]
      # pylint: enable=protected-access
      self.assertEqual(relaxed_shape.rank, 1)
      self.assertEqual(tensor_shape.dimension_value(relaxed_shape[0]), None)

      t = constant_op.constant([1.0, 1.0, 1.0], dtype=dtypes.float32)
      defined(t)
      # Shape (3,) matches the relaxed shape TensorShape([None]).
      self.assertLen(total_function_cache(defined), 2)

  def testPythonFunctionWithDefaultArgs(self):

    def func(foo, bar=1, baz=2):
      del foo
      del bar
      del baz
      return

    defined = function.defun(func)
    defined(0, baz=20)

    def cache_keys():
      """Strips non-input metadata from the cache keys."""
      return tuple(key[0] for key in total_function_cache(defined))

    # `True` corresponds to the fact that we're executing eagerly.
    self.assertIn(('URRRu', (0, 1, 20)), cache_keys())

    defined(1)  # bar=1, baz=2
    self.assertIn(('URRRu', (1, 1, 2)), cache_keys())

    # This matches the previous call.
    defined(foo=1)
    self.assertLen(total_function_cache(defined), 2)

    defined(1, 2, 3)
    self.assertLen(total_function_cache(defined), 3)
    self.assertIn(('URRRu', (1, 2, 3)), cache_keys())

    # This matches the previous call.
    defined(1, bar=2, baz=3)
    self.assertLen(total_function_cache(defined), 3)

    # This matches the previous call.
    defined(1, baz=3, bar=2)
    self.assertLen(total_function_cache(defined), 3)

  def testFunctoolsPartialUnwrappedCorrectly(self):

    def full_function(a, b, c=3):
      return a, b, c

    partial = functools.partial(full_function, 1, c=4)
    a, b, c = partial(2)

    defined = function.defun(partial)
    func_a, func_b, func_c = defined(2)
    self.assertEqual(func_a.numpy(), a)
    self.assertEqual(func_b.numpy(), b)
    self.assertEqual(func_c.numpy(), c)

  def testInputSignatureWithMatchingInputs(self):

    def foo(a):
      self.assertEqual(a.shape, (2,))
      return a

    signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)]
    defined = function.defun(foo, input_signature=signature)
    a = array_ops.ones([2])
    self.assertAllEqual(a, defined(a))
    self.assertLen(total_function_cache(defined), 1)
    self.assertAllEqual(a, defined.get_concrete_function()(a))
    self.assertAllEqual(a, defined.get_concrete_function(a)(a))
    self.assertAllEqual(a, defined.get_concrete_function(
        tensor_spec.TensorSpec((2,), dtype=dtypes.float32))(a))
    self.assertLen(total_function_cache(defined), 1)

    def bar(a):
      self.assertEqual(a._shape_tuple(), (2, None))
      return a

    signature = [tensor_spec.TensorSpec((2, None), dtypes.float32)]
    defined = function.defun(bar, input_signature=signature)
    a = array_ops.ones([2, 1])
    out = defined(a)
    self.assertLen(total_function_cache(defined), 1)
    self.assertAllEqual(out, a)

    # Changing the second dimension shouldn't create a new function.
    b = array_ops.ones([2, 3])
    out = defined(b)
    self.assertLen(total_function_cache(defined), 1)
    self.assertAllEqual(out, b)

  def testInputSignatureWithCompatibleInputs(self):

    rank2_spec = tensor_spec.TensorSpec(shape=(None, None),
                                        dtype=dtypes.float32)

    @function.defun(input_signature=[rank2_spec])
    def func(a):
      self.assertEqual([None, None], a.shape.as_list())
      return array_ops.shape(a)

    self.assertAllEqual([3, 1], func([[0], [1.0], [1]]))
    self.assertAllEqual([2, 2], func(numpy.array([[1, 1], [2, 2]])))

    with self.assertRaisesRegexp(ValueError, 'incompatible'):
      func([0.0, 1.0, 2.0])  # Wrong shape.

    with self.assertRaisesRegexp(ValueError, 'incompatible'):
      func([['wrong dtype']])

  def testNestedInputSignatures(self):

    def expected_foo(a, b):
      return [a, b]

    @function.defun(input_signature=[
        [tensor_spec.TensorSpec((2, None), dtypes.float32)] * 2,
        tensor_spec.TensorSpec((1,), dtypes.float32),
    ])
    def foo(a, b):
      self.assertEqual(a[0]._shape_tuple(), (2, None))
      self.assertEqual(a[1]._shape_tuple(), (2, None))
      self.assertEqual(b._shape_tuple(), (1,))
      return [a, b]

    a = array_ops.ones([2, 1])
    b = array_ops.ones([1])
    expected = expected_foo([a, a], b)
    out = foo([a, a], b)
    self.assertLen(total_function_cache(foo), 1)
    nest.assert_same_structure(out, expected)
    self.assertAllEqual(out[0][0], a)
    self.assertAllEqual(out[0][1], a)
    self.assertAllEqual(out[1], b)

    # Changing the unspecified dimensions shouldn't create a new function.
    a = array_ops.ones([2, 3])
    b = array_ops.ones([2, 5])
    c = array_ops.ones([1])
    expected = expected_foo([a, b], c)
    out = foo([a, b], c)
    self.assertLen(total_function_cache(foo), 1)
    nest.assert_same_structure(out, expected)
    self.assertAllEqual(out[0][0], a)
    self.assertAllEqual(out[0][1], b)
    self.assertAllEqual(out[1], c)

    # Passing compatible inputs should work.
    a = a.numpy().tolist()
    b = b.numpy().tolist()
    c = c.numpy().tolist()
    out = foo([a, b], c)
    self.assertLen(total_function_cache(foo), 1)
    nest.assert_same_structure(out, expected)
    self.assertAllEqual(out[0][0], a)
    self.assertAllEqual(out[0][1], b)
    self.assertAllEqual(out[1], c)

  def testNestedInputSignaturesWithDict(self):
    def expected_bar(a):
      return a

    @function.defun(input_signature=[{
        'a': tensor_spec.TensorSpec((2, None), dtypes.float32),
        'b': tensor_spec.TensorSpec((2, None), dtypes.float32),
        'c': tensor_spec.TensorSpec((1,), dtypes.float32)}])
    def bar(a):
      self.assertEqual(a['a']._shape_tuple(), (2, None))
      self.assertEqual(a['b']._shape_tuple(), (2, None))
      self.assertEqual(a['c']._shape_tuple(), (1,))
      return a

    a = array_ops.ones([2, 3])
    b = array_ops.ones([1])
    inputs = {'a': a, 'b': a, 'c': b}
    expected = expected_bar(inputs)
    out = bar(inputs)
    nest.assert_same_structure(out, expected)
    self.assertAllEqual(out['a'], expected['a'])
    self.assertAllEqual(out['b'], expected['b'])
    self.assertAllEqual(out['c'], expected['c'])

    # Passing compatible inputs should work.
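    # Plain Python lists are converted against the declared TensorSpecs, so
    # they are accepted just like tensors of matching shape and dtype.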
    a = a.numpy().tolist()
    b = b.numpy().tolist()
    inputs = {'a': a, 'b': a, 'c': b}
    out = bar(inputs)
    nest.assert_same_structure(out, expected)
    self.assertAllEqual(out['a'], expected['a'])
    self.assertAllEqual(out['b'], expected['b'])
    self.assertAllEqual(out['c'], expected['c'])

  def testInputSignatureMustBeSequenceOfTensorSpecs(self):

    def foo(a, b):
      del a
      del b

    # Signatures must consist exclusively of `TensorSpec` objects.
    signature = [(2, 3), tensor_spec.TensorSpec([2, 3], dtypes.float32)]
    with self.assertRaisesRegexp(TypeError, 'Invalid input_signature.*'):
      def_function.function(foo, input_signature=signature)

    # Signatures must be either lists or tuples on their outermost levels.
    signature = {'t1': tensor_spec.TensorSpec([], dtypes.float32)}
    with self.assertRaisesRegexp(TypeError, 'input_signature must be either a '
                                 'tuple or a list.*'):
      function.defun(foo, input_signature=signature)

  @test_util.run_in_graph_and_eager_modes
  def testInputsIncompatibleWithSignatureRaisesError(self):

    def foo(a):
      return a

    signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)]
    defined = def_function.function(foo, input_signature=signature)

    # Invalid shapes.
    with self.assertRaisesRegexp(ValueError, 'Python inputs incompatible.*'):
      defined(array_ops.ones([3]))

    with self.assertRaisesRegexp(ValueError, 'Python inputs incompatible.*'):
      defined(array_ops.ones([2, 1]))

    # Wrong number of arguments.
    with self.assertRaisesRegexp(TypeError, r'Received 2 argument\(s\)'):
      defined(array_ops.ones([2]), array_ops.ones([2]))
    with self.assertRaisesRegexp(ValueError,
                                 'Structure of Python function inputs.*'):
      defined()

    with self.assertRaisesRegexp(ValueError,
                                 'inputs incompatible with input_signature'):
      defined.get_concrete_function(
          tensor_spec.TensorSpec(shape=(3,), dtype=dtypes.float32))

  def testInputsIncompatibleWithNestedSignatureRaisesError(self):

    def foo(a, b):
      return [a, b]

    signature = [[tensor_spec.TensorSpec((1,), dtypes.float32)] * 2,
                 [tensor_spec.TensorSpec((1,), dtypes.float32)] * 2]
    defined = function.defun(foo, input_signature=signature)
    a = array_ops.ones([1])

    with self.assertRaisesRegexp(ValueError,
                                 'Structure of Python function inputs.*'):
      defined([a, a, a], [a])

    with self.assertRaisesRegexp(ValueError,
                                 'Structure of Python function inputs.*'):
      defined([a], [a, a, a])
    defined([a, a], [a, a])

  def testUnderspecifiedInputSignature(self):
    @function.defun(input_signature=[
        tensor_spec.TensorSpec([], dtypes.float32),
    ])
    def foo(a, training=True):
      if training:
        return a
      else:
        return -1.0 * a

    x = constant_op.constant(1.0)
    with self.assertRaisesRegexp(TypeError, 'only pass arguments'):
      foo(x, training=True)

    with self.assertRaisesRegexp(TypeError, 'only pass arguments'):
      foo(x, training=False)

    self.assertAllEqual(x.numpy(), foo(x).numpy())

  def testInputSignatureWithPartialFunction(self):
    self.skipTest('b/124441704')
    def full_function(a, b, c=3.0):
      return a, b, c

    partial = functools.partial(full_function, 1, c=4)
    a, b, c = partial(2.0)
    signature = [tensor_spec.TensorSpec([], dtypes.float32)]
    defined = function.defun(partial, input_signature=signature)
    x = constant_op.constant(2.0)
    func_a, func_b, func_c = defined(x)
    self.assertEqual(func_a.numpy(), a)
    self.assertEqual(func_b.numpy(), b)
    self.assertEqual(func_c.numpy(), c)

  def testInputSignatureConversionWithDefaultArg(self):

    def foo(a, training=True):
      if training:
        return a
      else:
        return -1.0 * a

    signature = [
        tensor_spec.TensorSpec([], dtypes.float32),
        tensor_spec.TensorSpec([], dtypes.bool),
    ]
    defined = def_function.function(foo, input_signature=signature)
    a = constant_op.constant(1.0)
    self.assertAllEqual(a.numpy(), defined(a))
    self.assertAllEqual(a.numpy(), defined(a, training=True))
    self.assertAllEqual(-a.numpy(), defined(a, training=False))

  def testInputSignatureWithKeywordPositionalArgs(self):

    @function.defun(input_signature=[
        tensor_spec.TensorSpec([], dtypes.float32),
        tensor_spec.TensorSpec([], dtypes.int64)
    ])
    def foo(flt, integer):
      return flt, integer

    flt = constant_op.constant(1.0)
    integer = constant_op.constant(2, dtypes.int64)

    out1, out2 = foo(flt, integer)
    self.assertLen(total_function_cache(foo), 1)
    self.assertEqual(out1.numpy(), 1.0)
    self.assertEqual(out2.numpy(), 2)

    out1, out2 = foo(flt=flt, integer=integer)
    self.assertLen(total_function_cache(foo), 1)
    self.assertEqual(out1.numpy(), 1.0)
    self.assertEqual(out2.numpy(), 2)

    out1, out2 = foo(integer=integer, flt=flt)
    self.assertLen(total_function_cache(foo), 1)
    self.assertEqual(out1.numpy(), 1.0)
    self.assertEqual(out2.numpy(), 2)

    out1, out2 = foo(flt, integer=integer)
    self.assertLen(total_function_cache(foo), 1)
    self.assertEqual(out1.numpy(), 1.0)
    self.assertEqual(out2.numpy(), 2)

  def testInputSignatureWithKeywordArgsFails(self):

    def foo(a, **kwargs):
      del a
      del kwargs

    with self.assertRaisesRegexp(
        ValueError, 'Cannot define a TensorFlow function from a Python '
        'function with keyword arguments when input_signature.*'):
      function.defun(
          foo,
          input_signature=[
              tensor_spec.TensorSpec([], dtypes.float32),
              tensor_spec.TensorSpec([], dtypes.int64)
          ])

  def testTensorKeywordArguments(self):

    def foo(a, b):
      del a
      return b

    defined = function.defun(foo)
    a = constant_op.constant(2.0)
    b = constant_op.constant([1.0, 2.0])
    one = defined(a, b)
    self.assertLen(total_function_cache(defined), 1)

    two = defined(a=a, b=b)
    self.assertLen(total_function_cache(defined), 1)

    three = defined(b=b, a=a)
    self.assertLen(total_function_cache(defined), 1)

    four = defined(a, b=b)
    self.assertLen(total_function_cache(defined), 1)

    # The next call corresponds to a new input signature, hence
    # we expect another function to be defined.
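    # (Swapping the positional arguments swaps a scalar and a vector, so the
    # shapes in the signature change.)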

  def testTensorKeywordArguments(self):

    def foo(a, b):
      del a
      return b

    defined = function.defun(foo)
    a = constant_op.constant(2.0)
    b = constant_op.constant([1.0, 2.0])
    one = defined(a, b)
    self.assertLen(total_function_cache(defined), 1)

    two = defined(a=a, b=b)
    self.assertLen(total_function_cache(defined), 1)

    three = defined(b=b, a=a)
    self.assertLen(total_function_cache(defined), 1)

    four = defined(a, b=b)
    self.assertLen(total_function_cache(defined), 1)

    # The next call corresponds to a new input signature, hence
    # we expect another function to be defined.
    five = defined(b, a)
    self.assertLen(total_function_cache(defined), 2)

    six = defined(a=b, b=a)
    self.assertLen(total_function_cache(defined), 2)

    seven = defined(b=a, a=b)
    self.assertLen(total_function_cache(defined), 2)

    self.assertAllEqual(one, [1.0, 2.0])
    self.assertAllEqual(two, [1.0, 2.0])
    self.assertAllEqual(three, [1.0, 2.0])
    self.assertAllEqual(four, [1.0, 2.0])
    self.assertAllEqual(five, 2.0)
    self.assertAllEqual(six, 2.0)
    self.assertAllEqual(seven, 2.0)

  def testDefuningInstanceMethod(self):

    integer = constant_op.constant(2, dtypes.int64)

    class Foo(object):

      def one(self, tensor):
        return tensor

      @def_function.function
      def two(self, tensor, other=integer):
        return self.one(tensor), other

    foo = Foo()
    t = constant_op.constant(1.0)
    one, two = foo.two(t)
    self.assertEqual(one.numpy(), 1.0)
    self.assertEqual(two.numpy(), 2)

  def testDefuningInstanceMethodWithDefaultArgument(self):

    integer = constant_op.constant(2, dtypes.int64)

    class Foo(object):

      @def_function.function
      def func(self, other=integer):
        return other

    foo = Foo()
    self.assertEqual(foo.func().numpy(), int(integer))

  def testPythonCallWithSideEffects(self):
    state = []

    @def_function.function
    def side_effecting_function():
      state.append(0)

    side_effecting_function()
    self.assertAllEqual(state, [0])

    # The second invocation should call the graph function, which shouldn't
    # trigger the list append.
    side_effecting_function()
    self.assertAllEqual(state, [0])

    # Whereas calling the python function directly should create a side-effect.
    side_effecting_function.python_function()
    self.assertAllEqual(state, [0, 0])
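
  # Added illustration (hypothetical helper, not collected by the runner):
  # the counterpart of the Python side effect above. State that must change
  # on every call belongs in TensorFlow ops, such as a variable update,
  # because those are baked into the traced graph.
  def _sketchVariableSideEffectRunsEveryCall(self):
    counter = variables.Variable(0.0)

    @def_function.function
    def bump():
      counter.assign_add(1.0)

    bump()
    bump()
    # Unlike the list append above, the assign_add runs on both calls.
    self.assertEqual(2.0, counter.numpy())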

  def testFunctionWithNestedFunctionCallAndSideEffects(self):
    v1 = variables.Variable(1.0)
    v2 = variables.Variable(1.0)

    @def_function.function
    def add_one(a):
      a.assign_add(1.0)

    # Grappler will inline calls to `add_one` into the function body; we check
    # that all side effects were executed.
    @def_function.function
    def side_effecting_function(a, b):
      add_one(a)
      add_one(b)
      return a + b

    result = side_effecting_function(v1, v2)
    self.assertEqual(result.numpy(), 4.0)

  def testFunctionWithExtraAttributes(self):

    @function.defun_with_attributes(attributes={'experimental_1': 'value1',
                                                'experimental_2': 2})
    def matmul(x, y):
      return math_ops.matmul(x, y)

    def add(x, y):
      return math_ops.add(x, y)

    defun_add = function.defun_with_attributes(
        add, attributes={'experimental_3': True, 'experimental_4': 1.0})

    with context.graph_mode(), self.cached_session():
      with ops.get_default_graph().as_default():
        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
        sq = matmul(t, t)
        double = defun_add(t, t)
        self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22])
        self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8])

        graph = ops.get_default_graph()
        # pylint: disable=protected-access
        self.assertLen(graph._functions, 2)
        functions = list(graph._functions.values())
        self.assertRegexpMatches(
            functions[0].definition.signature.name, '.*matmul.*')
        attrs = functions[0].definition.attr
        self.assertLen(attrs, 2)
        self.assertEqual(attrs['experimental_1'].s, b'value1')
        self.assertEqual(attrs['experimental_2'].i, 2)

        self.assertRegexpMatches(
            functions[1].definition.signature.name, '.*add.*')
        attrs = functions[1].definition.attr
        self.assertLen(attrs, 2)
        self.assertEqual(attrs['experimental_3'].b, True)
        self.assertEqual(attrs['experimental_4'].f, 1.0)
        # pylint: enable=protected-access

  def testFunctionWithInvalidAttribute(self):

    @function.defun_with_attributes(attributes={'experimental_1': ['value1']})
    def add(x, y):
      return math_ops.add(x, y)

    with self.assertRaisesRegexp(ValueError,
                                 '.*Unsupported attribute type.*'):
      with context.graph_mode(), self.cached_session():
        with ops.get_default_graph().as_default():
          t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
          add(t, t)

  def testRegisterFunction(self):

    @function.defun
    def add(x, y):
      return math_ops.add(x, y)

    def matmul(x, y):
      return math_ops.matmul(x, y)

    defun_matmul = function.defun(matmul)

    with context.graph_mode(), self.cached_session():
      with ops.get_default_graph().as_default():
        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
        function.register(defun_matmul, t, t)
        function.register(add, t, t)

        graph = ops.get_default_graph()
        # pylint: disable=protected-access
        self.assertLen(graph._functions, 6)
        # Two sets of functions, each of which is an
        # (inference, forward, backward) triple.
        functions = list(graph._functions.values())
        captured_function_names = [
            f.definition.signature.name for f in functions
        ]
        expected_func_name_regex = [
            '.*inference.*matmul.*',
            '.*forward.*matmul.*',
            '.*inference.*backward.*matmul.*',
            '.*inference.*add.*',
            '.*forward.*add.*',
            '.*inference.*backward.*add.*',
        ]
        for i in range(len(functions)):
          self.assertRegexpMatches(captured_function_names[i],
                                   expected_func_name_regex[i])

        # Check that the forward and backward functions have the correct
        # attributes.
        self.assertEqual(
            functions[1].definition.attr['backward_function_name'].s,
            functions[2].name)
        self.assertEqual(
            functions[2].definition.attr['forward_function_name'].s,
            functions[1].name)

        self.assertEqual(
            functions[4].definition.attr['backward_function_name'].s,
            functions[5].name)
        self.assertEqual(
            functions[5].definition.attr['forward_function_name'].s,
            functions[4].name)

        sq = defun_matmul(t, t)
        double = add(t, t)
        self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22])
        self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8])
        # Make sure the pre-registered function is used, and no other function
        # is added.
        self.assertLen(graph._functions, 6)
        functions = list(graph._functions.values())
        for i in range(len(functions)):
          self.assertEqual(captured_function_names[i],
                           functions[i].definition.signature.name)

  @parameterized.named_parameters(
      dict(testcase_name='Defun',
           function_decorator=function.defun),
      dict(testcase_name='DefFunction',
           function_decorator=def_function.function))
  def testRegisterConcreteFunction(self, function_decorator):

    @function_decorator
    def py_add(x, y):
      return math_ops.add(x, y)

    py_add(array_ops.ones([]), array_ops.ones([]))
    add = py_add.get_concrete_function(
        tensor_spec.TensorSpec(None, dtypes.float32),
        tensor_spec.TensorSpec(None, dtypes.float32))

    @function_decorator
    def py_composite(x, y):
      return x, add(x, y)

    py_composite(array_ops.ones([]), array_ops.ones([]))
    composite = py_composite.get_concrete_function(
        tensor_spec.TensorSpec(None, dtypes.float32),
        tensor_spec.TensorSpec(None, dtypes.float32))

    with context.graph_mode(), self.cached_session():
      with ops.get_default_graph().as_default():
        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
        composite.add_to_graph(register_gradient_functions=True)

        graph = ops.get_default_graph()
        # pylint: disable=protected-access
        self.assertLen(graph._functions, 6)
        # Two sets of functions, each of which is an
        # (inference, forward, backward) triple.
        functions = list(graph._functions.values())
        captured_function_names = [
            f.definition.signature.name for f in functions
        ]
        expected_func_name_regex = [
            '.*inference.*py_composite.*',
            '.*inference.*py_add.*',
            '.*forward.*py_composite.*',
            '.*forward.*py_add.*',
            '.*inference.*backward.*py_composite.*',
            '.*inference.*backward.*py_add.*',
        ]
        for expected, found in zip(
            expected_func_name_regex,
            captured_function_names):
          self.assertRegexpMatches(found, expected)

        composite_t, composite_double = composite(t, t)
        double = add(t, t)
        self.assertAllEqual([[2, 4], [6, 8]], self.evaluate(double))
        self.assertAllEqual([[2, 4], [6, 8]], self.evaluate(composite_double))
        self.assertAllEqual([[1, 2], [3, 4]], self.evaluate(composite_t))
        # Make sure the pre-registered function is used, and no other function
        # is added.
        self.assertLen(graph._functions, 6)

  def testRegisterFunctionWithInputSignature(self):

    def matmul(x, y):
      return math_ops.matmul(x, y)

    defun_matmul = function.defun(
        matmul,
        input_signature=[
            tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32),
            tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32)
        ])
    with context.graph_mode(), self.cached_session():
      with ops.get_default_graph().as_default():
        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
        function.register(defun_matmul, t, t)

        graph = ops.get_default_graph()
        # pylint: disable=protected-access
        self.assertLen(graph._functions, 3)

        # Test registering a function that hits the cache; note that the
        # inputs are ignored.
        function.register(defun_matmul)
        graph = ops.get_default_graph()
        self.assertLen(graph._functions, 3)

  def testRegisterFunctionWithCache(self):

    def matmul(x, y):
      return math_ops.matmul(x, y)

    defun_matmul = function.defun(matmul)

    with context.graph_mode(), self.cached_session():
      with ops.get_default_graph().as_default():
        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
        t2 = constant_op.constant([[2.0, 3.0], [4.0, 5.0]])
        function.register(defun_matmul, t, t)
        function.register(defun_matmul, t2, t2)

        graph = ops.get_default_graph()
        # Only one function is registered since the input parameters have the
        # same type.
        # pylint: disable=protected-access
        self.assertLen(graph._functions, 3)

  def testCallingFunctionWithDifferentVariables(self):

    @function.defun
    def foo(v):
      v.assign_add(1.0)
      return v.read_value()

    v = resource_variable_ops.ResourceVariable(0.0)
    graph_function = foo.get_concrete_function(v)
    self.assertLen(graph_function.inputs, 1)
    self.assertEmpty(graph_function.captured_inputs)

    self.assertEqual(float(graph_function(v)), 1.0)
    self.assertEqual(float(graph_function(v)), 2.0)

    w = resource_variable_ops.ResourceVariable(0.0)

    @function.defun
    def bar(v):
      del v
      return constant_op.constant(1.0)

    graph_function = bar.get_concrete_function(v)
    self.assertEqual(float(graph_function(v)), 1.0)
    self.assertEqual(float(graph_function(w)), 1.0)

  def testCallingFunctionWithNonTensorsFails(self):

    @function.defun
    def foo(x):
      return x

    graph_function = foo.get_concrete_function(constant_op.constant(1.0))
    with self.assertRaisesRegexp(
        ValueError, 'All inputs to `ConcreteFunction`s must be Tensors;.*'):
      graph_function('Not a Tensor.')

  def testSwapImplementationWithGrapplerPlugin(self):
    # Set min_graph_nodes to -1: the graph in this test is too small and would
    # otherwise be ignored by grappler.
    rewrites = rewriter_config_pb2.RewriterConfig()
    rewrites.implementation_selector = rewriter_config_pb2.RewriterConfig.ON
    rewrites.min_graph_nodes = -1
    graph_options = config_pb2.GraphOptions(
        rewrite_options=rewrites, build_cost_model=1)
    config = config_pb2.ConfigProto(graph_options=graph_options)

    with context.graph_mode(), self.cached_session(
        config=config, graph=ops.Graph(), use_gpu=True):

      @function.defun_with_attributes(
          attributes={
              'api_implements': 'random_boost',
              'api_preferred_device': 'CPU'
          })
      def cpu_boost(x):
        return math_ops.add(x, 2.0)

      @function.defun_with_attributes(
          attributes={
              'api_implements': 'random_boost',
              'api_preferred_device': 'GPU'
          })
      def gpu_boost(x):
        return math_ops.add(x, 4.0)

      x = constant_op.constant(1.0)

      function.register(cpu_boost, x)
      y = gpu_boost(x)
      y_value = self.evaluate(y)

      if test.is_gpu_available():
        self.assertEqual(y_value, 5.0)
      else:
        # Grappler falls back to the CPU implementation even when the GPU
        # function is the one called.
        self.assertEqual(y_value, 3.0)
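
  # Added commentary on the test above: `api_implements` marks functions as
  # interchangeable implementations of one interface, `api_preferred_device`
  # states which device each variant targets, and grappler's implementation
  # selector picks the variant matching the available device at
  # graph-rewrite time. Registering `cpu_boost` is what makes the CPU
  # variant visible as a fallback when no GPU is present.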

  def testDefunFunctionSeparateGraphs(self):
    with context.graph_mode():

      @function.defun
      def add(x):
        return x + 5

      @function.defun
      def maybe_add(x, should_add):
        if should_add:
          return add(x)
        else:
          return x

      with ops.Graph().as_default():
        x = constant_op.constant(11)
        maybe_add(x, True)
        self.assertLen(total_function_cache(maybe_add), 1)
        self.assertLen(total_function_cache(add), 1)

        maybe_add(x, False)
        self.assertLen(total_function_cache(maybe_add), 2)
        self.assertLen(total_function_cache(add), 1)

      with ops.Graph().as_default():
        x = constant_op.constant(11)
        maybe_add(x, True)
        self.assertLen(total_function_cache(maybe_add), 3)
        self.assertLen(total_function_cache(add), 2)

  def testCacheKeyOverlappingShapes(self):

    @function.defun
    def defined(t):
      return t

    defined(array_ops.zeros([12, 1]))
    self.assertLen(total_function_cache(defined), 1)

    defined(array_ops.zeros([1, 21]))
    self.assertLen(total_function_cache(defined), 2)

  def testCacheKeyNestedLists(self):

    @function.defun
    def defined(l):
      return l

    a = constant_op.constant(1.)
    b = constant_op.constant(2.)
    c = constant_op.constant(3.)
    defined([[a], b, c])
    self.assertLen(total_function_cache(defined), 1)

    defined([[a, b], c])
    self.assertLen(total_function_cache(defined), 2)

  def testDecoratedMethod(self):
    m = DefunnedMiniModel()
    instance_call_one = m.call(array_ops.ones([1, 2]), training=True)
    instance_call_two = m.call(
        inputs=array_ops.ones([1, 2]), training=True)
    class_call = DefunnedMiniModel.call(m, array_ops.ones([1, 2]),
                                        training=True)
    self.assertAllEqual(instance_call_one, instance_call_two)
    self.assertAllEqual(instance_call_one, class_call)

  def testDecoratedMethodUniqueFunctionPerInstance(self):
    m = DefunnedMiniModel()
    n = DefunnedMiniModel()

    class_method_one = DefunnedMiniModel.call
    class_method_two = DefunnedMiniModel.call

    m_method_one = m.call
    m_method_two = m.call

    n_method_one = n.call
    n_method_two = n.call

    self.assertEqual(class_method_one, class_method_two)
    self.assertEqual(m_method_one, m_method_two)
    self.assertEqual(n_method_one, n_method_two)
    self.assertNotEqual(m.call, n.call)

  def testDecoratedMethodInspect(self):
    m = DefunnedMiniModel()
    fullargspec = tf_inspect.getfullargspec(m.call)
    self.assertIn('training', fullargspec.args)

  def testDecoratedMethodGetConcreteFunction(self):
    m = DefunnedMiniModel()
    instance_call_one = m.call.get_concrete_function(
        array_ops.ones([1, 2]), training=False)
    instance_call_two = m.call.get_concrete_function(
        inputs=array_ops.ones([1, 2]), training=False)
    self.assertAllEqual(instance_call_one(array_ops.ones([1, 2])),
                        instance_call_two(array_ops.ones([1, 2])))

    # Also make sure get_concrete_function works on the class method.
    DefunnedMiniModel.call.get_concrete_function(
        m, array_ops.ones([1, 2]), training=False)
    DefunnedMiniModel.call.get_concrete_function(
        m, inputs=array_ops.ones([1, 2]), training=True)

  def testFunctionModifiesInputList(self):
    # Tests `list` methods that do in-place modification, except `list.sort`,
    # since it cannot even be "defunned" in the first place.

    def get_list():
      return [constant_op.constant(0.), constant_op.constant(1.)]

    expected_msg = (
        'Function to be traced should not modify structure of input '
        'arguments. Check if your function has list and dictionary '
        'operations that alter input arguments, '
        'such as `list.pop`, `list.append`')

    with self.assertRaisesRegexp(ValueError, expected_msg):

      @def_function.function
      def append(l):
        l.append(constant_op.constant(0.))

      append(get_list())

    with self.assertRaisesRegexp(ValueError, expected_msg):

      @def_function.function
      def extend(l):
        l.extend([constant_op.constant(0.)])

      extend(get_list())

    with self.assertRaisesRegexp(ValueError, expected_msg):

      @def_function.function
      def insert(l):
        l.insert(0, constant_op.constant(0.))

      insert(get_list())

    with self.assertRaisesRegexp(ValueError, expected_msg):

      @def_function.function
      def pop(l):
        l.pop()

      pop(get_list())

    with self.assertRaisesRegexp(ValueError, expected_msg):

      @def_function.function
      def reverse(l):
        l.reverse()

      reverse(get_list())

    with self.assertRaisesRegexp(ValueError, expected_msg):

      @def_function.function
      def remove(l):
        l.remove(l[0])

      remove(get_list())

    # `list.clear` is a method that is in Py3 but not Py2.
    if sys.version.startswith('3'):

      with self.assertRaisesRegexp(ValueError, expected_msg):

        @def_function.function
        def clear(l):
          l.clear()

        clear(get_list())

    # One last test for keyword arguments.
    with self.assertRaisesRegexp(ValueError, expected_msg):

      @def_function.function
      def kwdappend(**kwargs):
        l = kwargs['l']
        l.append(constant_op.constant(0.))

      kwdappend(l=get_list())

  def testFunctionModifiesInputDict(self):

    def get_dict():
      return {'t1': constant_op.constant(0.), 't2': constant_op.constant(1.)}

    expected_msg = (
        'Function to be traced should not modify structure of input '
        'arguments. Check if your function has list and dictionary '
        'operations that alter input arguments, '
        'such as `list.pop`, `list.append`')

    with self.assertRaisesRegexp(ValueError, expected_msg):

      @def_function.function
      def clear(m):
        m.clear()

      clear(get_dict())

    with self.assertRaisesRegexp(ValueError, expected_msg):

      @def_function.function
      def pop(m):
        m.pop('t1')

      pop(get_dict())

    with self.assertRaisesRegexp(ValueError, expected_msg):

      @def_function.function
      def popitem(m):
        m.popitem()

      popitem(get_dict())

    with self.assertRaisesRegexp(ValueError, expected_msg):

      @def_function.function
      def update(m):
        m.update({'t1': constant_op.constant(3.)})

      update(get_dict())

    with self.assertRaisesRegexp(ValueError, expected_msg):

      @def_function.function
      def setdefault(m):
        m.setdefault('t3', constant_op.constant(3.))

      setdefault(get_dict())

  def testFunctionModifiesInputNest(self):
    # Tests functions that modify the structure of nested input arguments.
    expected_msg = (
        'Function to be traced should not modify structure of input '
        'arguments. Check if your function has list and dictionary '
        'operations that alter input arguments, '
        'such as `list.pop`, `list.append`')

    with self.assertRaisesRegexp(ValueError, expected_msg):

      @def_function.function
      def modify(n):
        n[0]['t1'].append(constant_op.constant(1.))

      nested_input = [{
          't1': [constant_op.constant(0.),
                 constant_op.constant(1.)],
      },
                      constant_op.constant(2.)]

      modify(nested_input)

    with self.assertRaisesRegexp(ValueError, expected_msg):

      # The flat list doesn't change, whereas the true structure changes.
      @def_function.function
      def modify_same_flat(n):
        n[0].append(n[1].pop(0))

      nested_input = [[constant_op.constant(0.)],
                      [constant_op.constant(1.),
                       constant_op.constant(2.)]]

      modify_same_flat(nested_input)
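
  # Added illustration (hypothetical helper, not collected by the runner):
  # a sketch of the supported alternative to the in-place mutations rejected
  # above, namely building and returning a new structure.
  def _sketchReturnNewListInsteadOfMutating(self):

    @def_function.function
    def append_via_copy(l):
      # Copying first leaves the input structure untouched.
      return list(l) + [constant_op.constant(0.)]

    out = append_via_copy([constant_op.constant(0.), constant_op.constant(1.)])
    self.assertLen(out, 3)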

  def testDecoratedMethodVariableCleanup(self):
    m = DefunnedMiniModel()
    m(array_ops.ones([1, 2]))
    weak_variables = weakref.WeakSet(m.variables)
    self.assertLen(weak_variables, 2)
    del m
    self.assertEqual([], list(weak_variables))

  def testExecutorType(self):

    @function.defun
    def add_five(x):
      return x + 5

    self.assertEqual(
        5,
        add_five(constant_op.constant(0, dtype=dtypes.int32)).numpy())

    with self.assertRaisesRegexp(errors.NotFoundError, 'NON_EXISTENT_EXECUTOR'):
      with context.function_executor_type('NON_EXISTENT_EXECUTOR'):
        add_five(constant_op.constant(0, dtype=dtypes.int32))

    for executor_type in ('', 'DEFAULT', None):
      with context.function_executor_type(executor_type):
        self.assertAllEqual(
            5,
            add_five(constant_op.constant(0, dtype=dtypes.int32)).numpy())

  @test_util.assert_no_garbage_created
  def testReferenceCycles(self):

    fn = function.defun(lambda x: 2. * x)

    fn(constant_op.constant(4.0))
    weak_fn = weakref.ref(fn)
    del fn
    # Tests that the weak reference we made to the function is now dead, which
    # means the object has been deleted. This should be true as long as the
    # function itself is not involved in a reference cycle.
    self.assertIs(None, weak_fn())

  def testFunctionStackInErrorMessage(self):
    if context.executing_eagerly():
      # TODO(b/122736651): Remove this skipTest once fixed.
      self.skipTest('Error interpolation is not working when function is '
                    'invoked without PartitionedCallOp.')

    @def_function.function()
    def fn3(x):
      return x + 2

    @def_function.function()
    def fn2(x):
      check_ops.assert_equal(fn3(x), 3)
      return 2

    @def_function.function()
    def fn(x):
      return fn2(x)

    with self.assertRaises(errors.InvalidArgumentError) as cm:
      fn(2)
    e = cm.exception
    self.assertIn('fn -> fn2', e.message)
    self.assertIn('node assert_equal/Assert/Assert (defined at', e.message)
    self.assertNotIn('fn3', e.message)

  def testFunctionIsNotPinned(self):
    """Tests that functions aren't pinned to the CPU by the eager runtime."""
    if not context.context().num_gpus():
      self.skipTest('No GPUs found.')
    seed1, seed2 = 79, 25
    shape = constant_op.constant([4, 7])
    dtype = dtypes.float32

    @def_function.function
    def func():
      with ops.device('GPU:0'):
        return gen_random_ops.random_standard_normal(
            shape, dtype=dtype, seed=seed1, seed2=seed2)

    with ops.device('GPU:0'):
      x = func()
      self.assertRegexpMatches(x.device, 'GPU')

  @test_util.run_in_graph_and_eager_modes
  def testShapeCaching(self):

    @function.defun
    def func(x):
      return array_ops.shape(x)

    @function.defun(
        input_signature=[tensor_spec.TensorSpec([None, None], dtypes.float32)])
    def calls_func(x):
      return func(x)

    self.assertAllEqual([1, 1], self.evaluate(func(array_ops.zeros([1, 1]))))
    self.assertAllEqual([2, 2], self.evaluate(func(array_ops.zeros([2, 2]))))
    self.assertAllEqual(
        [3, 3],
        self.evaluate(calls_func(array_ops.zeros([3, 3]))))

  def testLimitedRetracing(self):
    trace_count = [0]

    @function.defun
    def func(x):
      trace_count[0] += 1
      return x

    for _ in range(50):
      func(constant_op.constant(3.))
      func(constant_op.constant(4.))
      func(constant_op.constant([[1., 2.]]))
      func(constant_op.constant([[]]))
      func(constant_op.constant([[3., 4.], [5., 6.]]))
      func(constant_op.constant([[3., 4.], [5., 6.], [7., 8.]]))
    # Tracing more than twice per input doesn't make sense.
    self.assertLess(trace_count[0], 13)
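
  # Added commentary on the bound above: the loop feeds six distinct inputs,
  # and with at most one concrete trace plus one shape-relaxed retrace per
  # input, about a dozen traces is the expected ceiling; hence the assertion
  # allows at most 12 rather than anything near 50 * 6.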


class MultiDeviceTest(test.TestCase, parameterized.TestCase):

  def testMultiDeviceOutput(self):
    """Tests that functions can produce outputs on multiple devices."""
    if not context.context().num_gpus():
      self.skipTest('No GPUs found.')

    @function.defun
    def func(a, b, transpose_a):
      with ops.device('/device:CPU:0'):
        m1 = math_ops.matmul(a, b, transpose_a=transpose_a)
      with ops.device('/device:GPU:0'):
        m2 = math_ops.matmul(a, b, transpose_a=transpose_a)
      return m1, m2

    t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
    m1, m2 = func(t, t, transpose_a=True)
    self.assertAllEqual(m1.numpy(), [[10, 14], [14, 20]])
    self.assertRegexpMatches(m1.backing_device, 'CPU')
    self.assertAllEqual(m2.numpy(), [[10, 14], [14, 20]])
    self.assertRegexpMatches(m2.backing_device, 'GPU')

  def testEmptyBody(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found.')

    @function.defun
    def func(a, b):
      return b, a

    with ops.device('/device:CPU:0'):
      a = constant_op.constant(3.0)
    with ops.device('/device:GPU:0'):
      b = constant_op.constant(5.0)

    m1, m2 = func(a, b)
    self.assertAllEqual(m1.numpy(), 5.0)
    self.assertRegexpMatches(m1.backing_device, 'GPU')
    self.assertAllEqual(m2.numpy(), 3.0)
    self.assertRegexpMatches(m2.backing_device, 'CPU')

  def testMultiDeviceInt32(self):
    """Tests that multi-device functions can take and output INT32s.

    When an INT32 device tensor is fed into a function, it is copied to CPU
    by the eager runtime. The function sees all INT32 inputs on CPU.

    We set the allocator attribute 'on_host' for INT32 outputs: they can be
    partitioned into the GPU component function, but will be allocated on
    CPU nevertheless.

    FunctionLibraryRuntime now has experimental support for `ints_on_device`,
    which this test could eventually exercise.
    """
    if not context.context().num_gpus():
      self.skipTest('No GPUs found.')

    with ops.device('/device:CPU:0'):
      int_cpu = constant_op.constant(3, dtype=dtypes.int32)
      resource = resource_variable_ops.ResourceVariable(5, dtype=dtypes.int32)
    with ops.device('/device:GPU:0'):
      int_gpu = constant_op.constant(7, dtype=dtypes.int32)

    @function.defun
    def func(int_cpu, resource, int_gpu):
      with ops.device('/device:CPU:0'):
        m1 = int_cpu * resource + int_gpu
      with ops.device('/device:GPU:0'):
        # This computation happens on the GPU, but m2 is copied to CPU.
        m2 = int_gpu * resource + int_cpu + 1
      return m1, m2

    m1, m2 = func(int_cpu, resource, int_gpu)
    self.assertAllEqual(m1.numpy(), 22)
    self.assertRegexpMatches(m1.backing_device, 'CPU')
    self.assertAllEqual(m2.numpy(), 39)
    self.assertRegexpMatches(m2.backing_device, 'CPU')

    # Flip the arguments and check again.
    m1, m2 = func(int_gpu, resource, int_cpu)
    self.assertAllEqual(m1.numpy(), 38)
    self.assertRegexpMatches(m1.backing_device, 'CPU')
    self.assertAllEqual(m2.numpy(), 23)
    self.assertRegexpMatches(m2.backing_device, 'CPU')

  def testMultiDeviceColocateWith(self):
    """Tests that a function's outputs respect colocation constraints."""
    if not context.context().num_gpus():
      self.skipTest('No GPUs found.')

    @function.defun
    def func(a, b):
      with ops.colocate_with(a):
        ra = 2 * a
      with ops.colocate_with(b):
        rb = 3 * b
      return ra, rb

    devices = ['/device:CPU:0', '/device:GPU:0']
    for dev1, dev2 in itertools.product(devices, devices):
      with ops.device(dev1):
        a = constant_op.constant(1.0)
      with ops.device(dev2):
        b = constant_op.constant(10.0)

      ra, rb = func(a, b)
      self.assertEqual(ra.numpy(), 2.0)
      self.assertRegexpMatches(ra.backing_device, dev1)
      self.assertEqual(rb.numpy(), 30.0)
      self.assertRegexpMatches(rb.backing_device, dev2)

  def testMultiDeviceResources(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found.')

    with ops.device('/device:CPU:0'):
      c1 = resource_variable_ops.ResourceVariable(2.0)
      c2 = resource_variable_ops.ResourceVariable(7.0)
    with ops.device('/device:GPU:0'):
      g1 = resource_variable_ops.ResourceVariable(3.0)
      g2 = resource_variable_ops.ResourceVariable(5.0)

    @function.defun
    def func(resource1, resource2):
      with ops.device('/device:CPU:0'):
        result1 = resource1 * g2
      with ops.device('/device:GPU:0'):
        result2 = resource2 * c2
      return result1, result2

    r1, r2 = func(c1, g1)
    self.assertEqual(r1.numpy(), 10.0)
    self.assertRegexpMatches(r1.backing_device, 'CPU')
    self.assertEqual(r2.numpy(), 21.0)
    self.assertRegexpMatches(r2.backing_device, 'GPU')

    # Call with flipped inputs. Check that we look at the resources' devices
    # and reinstantiate the function when the inputs' devices change.
    r1, r2 = func(g1, c1)
    self.assertEqual(r1.numpy(), 15.0)
    self.assertRegexpMatches(r1.backing_device, 'CPU')
    self.assertEqual(r2.numpy(), 14.0)
    self.assertRegexpMatches(r2.backing_device, 'GPU')

  def testOutputResources(self):
    if not context.context().num_gpus():
      self.skipTest('No GPUs found.')

    with ops.device('/device:CPU:0'):
      c1 = resource_variable_ops.ResourceVariable(2.0)
    with ops.device('/device:GPU:0'):
      g1 = resource_variable_ops.ResourceVariable(3.0)

    @function.defun
    def func(resource1, resource2):
      with ops.device('/device:CPU:0'):
        result1 = resource1 * 5
      with ops.device('/device:GPU:0'):
        result2 = resource2 * 7
      return result1, resource1.handle, result2, resource2.handle

    r1, res1, r2, res2 = func(c1, g1)
    self.assertEqual(r1.numpy(), 10.0)
    self.assertRegexpMatches(r1.backing_device, 'CPU')
    self.assertEqual(r2.numpy(), 21.0)
    self.assertRegexpMatches(r2.backing_device, 'GPU')

    def check_handle(handle, expected_value):
      self.assertRegexpMatches(handle.backing_device, 'CPU')
      tensor = gen_resource_variable_ops.read_variable_op(
          handle, dtypes.float32)
      self.assertEqual(tensor.numpy(), expected_value)

    # Check that handles returned from functions are on CPU and that an op
    # using the resource handle is correctly placed on the device backing the
    # resource.
    check_handle(res1, 2.0)
    check_handle(res2, 3.0)

    # Call with flipped inputs to make sure the function is reinstantiated and
    # the eager runtime does not mess up the device assignment for ops
    # consuming handles returned from defuns.
    r1, res1, r2, res2 = func(g1, c1)
    self.assertEqual(r1.numpy(), 15.0)
    self.assertRegexpMatches(r1.backing_device, 'CPU')
    self.assertEqual(r2.numpy(), 14.0)
    self.assertRegexpMatches(r2.backing_device, 'GPU')
    check_handle(res1, 3.0)
    check_handle(res2, 2.0)
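
  # Added commentary on the handle checks above: a resource handle is a
  # small CPU-resident scalar that points at a device-resident buffer, which
  # is why handle outputs are expected on CPU even for GPU-backed variables.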

  def testComplexInputOutputDevicePattern(self):
    """Tests input/output mapping logic in partitioning."""
    if not context.context().num_gpus():
      self.skipTest('No GPUs found.')

    with ops.device('/device:CPU:0'):
      rc0 = resource_variable_ops.ResourceVariable(2.0)
      rc1 = resource_variable_ops.ResourceVariable(3.0)
      cc0 = constant_op.constant(5.0)
      cc1 = constant_op.constant(7.0)
    with ops.device('/device:GPU:0'):
      rg0 = resource_variable_ops.ResourceVariable(11.0)
      rg1 = resource_variable_ops.ResourceVariable(13.0)
      cg0 = constant_op.constant(17.0)
      cg1 = constant_op.constant(19.0)

    # Make sure tensors are on the expected devices.
    for tensor in [cc0, cc1]:
      self.assertRegexpMatches(tensor.backing_device, 'CPU:0')
    for tensor in [cg0, cg1]:
      self.assertRegexpMatches(tensor.backing_device, 'GPU:0')

    @function.defun
    def func(rc0, cc0, cg0, rc1, cg1, rg0, rg1, cc1):
      with ops.device('/device:CPU:0'):
        m1 = rc0 * cg0
      with ops.device('/device:GPU:0'):
        m2 = rg0 * cc0

      with ops.device('/device:CPU:0'):
        r1 = 1000.0 * m2 + rc1 * cg1
      with ops.device('/device:GPU:0'):
        r2 = 1000.0 * m1 + rg1 * cc1

      return r1, r2, m2, m1

    r1, r2, m2, m1 = func(rc0, cc0, cg0, rc1, cg1, rg0, rg1, cc1)
    self.assertRegexpMatches(m1.backing_device, 'CPU')
    self.assertRegexpMatches(r1.backing_device, 'CPU')
    self.assertRegexpMatches(m2.backing_device, 'GPU')
    self.assertRegexpMatches(r2.backing_device, 'GPU')
    self.assertEqual(m1.numpy(), 34.0)
    self.assertEqual(r1.numpy(), 55000.0 + 3.0 * 19.0)
    self.assertEqual(m2.numpy(), 55.0)
    self.assertEqual(r2.numpy(), 34000.0 + 13.0 * 7.0)

  def testArgumentPruning(self):
    """Tests functions taking unnecessary arguments."""
    if not context.context().num_gpus():
      self.skipTest('No GPUs found.')

    with ops.device('/device:CPU:0'):
      c1 = constant_op.constant(5.0)
      c2 = constant_op.constant(7.0)

    with ops.device('/device:GPU:0'):
      g1 = constant_op.constant(11.0)
      g2 = constant_op.constant(13.0)
      g3 = constant_op.constant(17.0)

    @function.defun
    def func(g1, g2, c1, g3, c2):  # pylint: disable=unused-argument
      # Arguments g1 and g2 are unused and can be pruned by grappler.
      return c1 * g3 * c2

    result = func(g1, g2, c1, g3, c2)
    self.assertEqual(result.numpy(), 5.0 * 7.0 * 17.0)


if __name__ == '__main__':
  ops.enable_eager_execution(
      config=config_pb2.ConfigProto(device_count={'CPU': 4}))
  test.main()