# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for eager execution using XLA."""

import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.layers import convolutional
from tensorflow.python.layers import pooling
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import googletest
from tensorflow.python.training import adam


class EagerTest(xla_test.XLATestCase):

  def testBasic(self):
    with self.test_scope():
      three = constant_op.constant(3)
      five = constant_op.constant(5)
      product = three * five
      self.assertAllEqual(15, product)

  def testGradientTape(self):
    with self.test_scope():

      x = constant_op.constant(1.0)
      y = constant_op.constant(10.0)
      with backprop.GradientTape(persistent=True) as tape:
        tape.watch(x)
        tape.watch(y)
        a = x + y + x * y
      da_dx = tape.gradient(a, x)
      da_dy = tape.gradient(a, y)

    self.assertEqual(11.0, da_dx.numpy())
    self.assertEqual(2.0, da_dy.numpy())

  def testExecuteListOutputLen0(self):
    with self.test_scope():
      empty = constant_op.constant([], dtype=dtypes.float32)
      result = array_ops.unstack(empty, 0)
      self.assertTrue(isinstance(result, list))
      self.assertEqual(0, len(result))

  def testExecuteListOutputLen1(self):
    with self.test_scope():
      split_dim = constant_op.constant(1)
      value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]])
      result = array_ops.split(value, 1, axis=split_dim)
      self.assertTrue(isinstance(result, list))
      self.assertEqual(1, len(result))
      self.assertAllEqual([[0, 1, 2], [3, 4, 5]], result[0])

  def testExecuteListOutputLen3(self):
    with self.test_scope():
      split_dim = constant_op.constant(1)
      value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]])
      result = array_ops.split(value, 3, axis=split_dim)
      self.assertTrue(isinstance(result, list))
      self.assertEqual(3, len(result))
      self.assertAllEqual([[0], [3]], result[0])
      self.assertAllEqual([[1], [4]], result[1])
      self.assertAllEqual([[2], [5]], result[2])

  def testBasicGraph(self):
    # Run some ops eagerly
    with self.test_scope():
      three = constant_op.constant(3)
      five = constant_op.constant(5)
      product = three * five
      self.assertAllEqual(15, product)

    # Run the same ops in graph mode
    with context.graph_mode(), self.session():
      with self.test_scope():
        three = constant_op.constant(3)
        five = constant_op.constant(5)
        product = three * five
        self.assertAllEqual(15, self.evaluate(product))

  def testDegenerateSlices(self):
    with self.test_scope():
      npt = np.arange(1, 19, dtype=np.float32).reshape(3, 2, 3)
      t = constant_op.constant(npt)
      # degenerate by offering a forward interval with a negative stride
      self.assertAllEqual(npt[0:-1:-1, :, :], t[0:-1:-1, :, :])
      # degenerate with a reverse interval with a positive stride
      self.assertAllEqual(npt[-1:0, :, :], t[-1:0, :, :])
      # empty interval in every dimension
      self.assertAllEqual(npt[-1:0, 2:2, 2:3:-1], t[-1:0, 2:2, 2:3:-1])

  def testIdentity(self):
    with self.test_scope():
      self.assertAllEqual(2, array_ops.identity(2))

  def testRandomOps(self):
    with self.test_scope():
      tensor = gen_random_ops.random_uniform((2, 2), dtypes.float32)
      row0 = tensor[0].numpy()
      row1 = tensor[1].numpy()
      # It should be very unlikely for the RNG to generate two equal rows.
      self.assertFalse((row0 == row1).all())

  def testIdentityOnVariable(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(True)
      i = array_ops.identity(v)
      self.assertAllEqual(True, i.numpy())

  def testAssignAddVariable(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(1.0)
      v.assign_add(2.0)
      self.assertEqual(3.0, v.numpy())

  def testReadAssignRead(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(1.0)
      val1 = v.read_value()
      v.assign_add(2.0)
      val2 = v.read_value()
      self.assertEqual(1.0, val1.numpy())
      self.assertEqual(3.0, val2.numpy())

  def testGradient(self):
    def f(x):
      return x

    with self.test_scope():
      grad_fn = backprop.gradients_function(f)
      self.assertAllEqual(2., grad_fn(1., dy=2.)[0])

  def testVariableGradient(self):
    with self.test_scope():
      v0 = resource_variable_ops.ResourceVariable(1.0)

      def f():
        x = v0 * v0
        return x

      grads = backprop.implicit_grad(f)()
      self.assertEqual(2., grads[0][0].numpy())

  def testMultipleVariableReads(self):
    # This test makes sure consecutive variable reads don't copy
    # the underlying memory.
    with self.test_scope():
      # Create a 128MiB variable.
      var = resource_variable_ops.ResourceVariable(
          array_ops.ones([32, 1024, 1024]))

      # Read the same variable 100 times. If the underlying tensor
      # is not copied, this is a trivial operation. If it is copied,
      # this will eat over 13GB and OOM.
      values = []
      for _ in range(100):
        values.append(var.value())

  # The shape, shape_n, size, and rank ops are tested here because their
  # eager execution kernels (as opposed to the compilation-only tf2xla
  # kernels) are distinct from the tf2xla kernels.

  def testShape(self):
    def const(value):
      return array_ops.shape(
          constant_op.constant(value)).numpy()

    def ones(value):
      return array_ops.shape(
          array_ops.ones(value)).numpy()

    with self.test_scope():
      # Shapes of directly constructed tensors
      self.assertAllEqual([], const(3))
      self.assertAllEqual([3], const([1.0, 2.0, 3.0]))
      self.assertAllEqual([2, 2], const([[1.0, 2.0], [3.0, 4.0]]))
      self.assertAllEqual([2, 1, 2], const([[[1.0, 2.0]], [[3.0, 4.0]]]))

      # Shapes of tensors created by an op running on the device.
      # We make this distinction because directly constructed tensors
      # are treated differently in a few places that can influence shape:
      # - they always have on_host_tensor
      # - they and their shapes can be cached
      # - they end up on device via a copy, instead of as program output
      self.assertAllEqual([], ones([]))
      self.assertAllEqual([3], ones([3]))
      self.assertAllEqual([2, 2], ones([2, 2]))
      self.assertAllEqual([2, 1, 2], ones([2, 1, 2]))

  def testShapeN(self):
    with self.test_scope():
      # Shapes of directly constructed tensors
      shapes = array_ops.shape_n([
          constant_op.constant(1.0),
          constant_op.constant([1.0, 2.0, 3.0]),
          constant_op.constant([[1.0, 2.0], [3.0, 4.0]])])
      self.assertAllEqual(
          [[], [3], [2, 2]],
          [x.numpy().tolist() for x in shapes])

      # Shapes of tensors created by an op running on the device.
      shapes = array_ops.shape_n([
          array_ops.ones([]),
          array_ops.ones([3]),
          array_ops.ones([2, 2])])
      self.assertAllEqual(
          [[], [3], [2, 2]],
          [x.numpy().tolist() for x in shapes])

  def testSize(self):
    with self.test_scope():
      self.assertEqual(
          1, array_ops.size(constant_op.constant(1.0)).numpy())
      self.assertEqual(
          3, array_ops.size(constant_op.constant([1.0, 2.0, 3.0])).numpy())
      self.assertEqual(
          4, array_ops.size(
              constant_op.constant([[1.0, 2.0], [3.0, 4.0]])).numpy())

  def testRank(self):
    with self.test_scope():
      self.assertEqual(
          0, array_ops.rank(constant_op.constant(1.0)).numpy())
      self.assertEqual(
          1, array_ops.rank(constant_op.constant([1.0, 2.0, 3.0])).numpy())
      self.assertEqual(
          2, array_ops.rank(
              constant_op.constant([[1.0, 2.0], [3.0, 4.0]])).numpy())

  def testAdam(self):
    with self.test_scope():
      optimizer = adam.AdamOptimizer(0.1)
      x = resource_variable_ops.ResourceVariable(10.0)
      with backprop.GradientTape() as tape:
        y = x * x
      dy_dx = tape.gradient(y, x)
      optimizer.apply_gradients([(dy_dx, x)])
      self.assertAlmostEqual(9.9, x.numpy(), places=3)

  def testAdamSparse(self):
    with ops.device('/cpu:0'):
      # Create a 2-D embedding for 3 objects on CPU because sparse/sliced
      # updates are not implemented on TPU.
      embedding_matrix = resource_variable_ops.ResourceVariable(
          array_ops.ones([3, 2]))

    with self.test_scope():
      with backprop.GradientTape() as tape:
        embedding = embedding_ops.embedding_lookup(embedding_matrix, [1])
        y = math_ops.reduce_sum(embedding)
      dy_dx = tape.gradient(y, embedding_matrix)
      self.assertIsInstance(dy_dx, indexed_slices.IndexedSlices)
      optimizer = adam.AdamOptimizer(0.1)
      # The gradient application operations will run on CPU because optimizer
      # updates are always collocated with the variable.
      optimizer.apply_gradients([(dy_dx, embedding_matrix)])

      # This assign_add will run on CPU because when an input to an
      # operation is a resource, this operation is placed on the resource's
      # device by the eager runtime.
      embedding_matrix.assign_add(array_ops.ones([3, 2]))

    self.assertAllClose([[2.0, 2.0],
                         [1.9, 1.9],
                         [2.0, 2.0]], embedding_matrix.numpy())


class EagerFunctionTest(xla_test.XLATestCase):

  def testBasic(self):
    with self.test_scope():
      matmul = function.defun(math_ops.matmul)
      t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
      sq = matmul(t, t, transpose_a=True)
      self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20])

  def testConv(self):
    if 'GPU' in self.device:
      # TODO(b/32333178)
      self.skipTest('Current implementation of RandomStandardNormal kernel '
                    'is very slow on GPU, and has been denylisted.')
    with self.test_scope():
      data_format = 'channels_last'
      conv = convolutional.Conv2D(
          filters=1, kernel_size=2, padding='VALID',
          data_format=data_format, activation=nn_ops.relu,
          kernel_initializer=init_ops.ones_initializer(),
          bias_initializer=init_ops.zeros_initializer())
      pool = pooling.MaxPooling2D(2, 2, data_format=data_format)

      def model(x):
        x = conv(x)
        return pool(x)
      model = function.defun(model)

      x = array_ops.ones([1, 4, 4, 1])
      y = model(x)
      self.assertAllEqual(y.numpy(), [[[[4.]]]])

  def testReadVariable(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(1.0)

      @function.defun
      def f():
        return v.read_value()

      var = f()
      self.assertEqual(1.0, var.numpy())

  def testResourceVariableNoInlineReadWrite(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(1.0)
      w = resource_variable_ops.ResourceVariable(0.0)

      @function.defun_with_attributes(attributes={'_noinline': True})
      def g(x):
        w.assign(w.read_value() + x)
        return v.read_value() + x * w.read_value()

      @function.defun_with_attributes(attributes={'_noinline': True})
      def f():
        return g(1.0) + g(2.0) + g(3.0) + g(4.0) + g(5.0)

      # 1 + 1*1 + 1 + 2*3 + 1 + 3*6 + 1 + 4*10 + 1 + 5*15
      self.assertEqual(145.0, f().numpy())
      self.assertEqual(15.0, w.read_value().numpy())

  def testResourceVariableNoInlineReadOnly(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(10.0)

      @function.defun_with_attributes(attributes={'_noinline': True})
      def g():
        return v.read_value()

      @function.defun_with_attributes(attributes={'_noinline': True})
      def f():
        return g() + g() + g() + g() + g()

      self.assertEqual(50.0, f().numpy())

  def testResourceVariableNoInlineWriteOnly(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(0.0)

      @function.defun_with_attributes(attributes={'_noinline': True})
      def g(x):
        v.assign(x)

      @function.defun_with_attributes(attributes={'_noinline': True})
      def f():
        g(1.0)
        g(2.0)
        g(3.0)
        g(4.0)
        g(5.0)

      f()
      self.assertEqual(5.0, v.read_value().numpy())

  def testUpdateVariable(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(1.0)

      def f(v):
        v.assign_add(1.0)
        return v

      f = function.defun(f)

      var = f(v)
      self.assertEqual(2.0, var.numpy())

  def testReturnResourceHandle(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable([[1.0, 2.0], [3.0, 4.0]])

      def f(v):
        return v.handle

      f = function.defun(f)
      handle = f(v)
      self.assertAllEqual(v.numpy(),
                          resource_variable_ops.read_variable_op(
                              handle, dtypes.float32).numpy())

  def testReturnMultipleResourceHandles(self):
    with self.test_scope():
      v1 = resource_variable_ops.ResourceVariable(1.25)
      v2 = resource_variable_ops.ResourceVariable(2.0)

      def f(v):
        return v.handle, 3.0 * v, v2.handle, v + v2

      f = function.defun(f)
      v1_handle, v1_times_3, v2_handle, variable_sum = f(v1)
      self.assertAllEqual(v1.numpy(),
                          resource_variable_ops.read_variable_op(
                              v1_handle, dtypes.float32).numpy())
      self.assertEqual(3.75, v1_times_3.numpy())
      self.assertAllEqual(v2.numpy(),
                          resource_variable_ops.read_variable_op(
                              v2_handle, dtypes.float32).numpy())
      self.assertEqual(3.25, variable_sum.numpy())

  def testAllArgumentKinds(self):
    """Test a complex function that takes different argument kinds.

    The tf2xla machinery that translates, compiles, and runs defuns
    classifies arguments into compile-time constants, regular tensors,
    and resources. This test creates a function with a mix of all these
    kinds, and the order of the function arguments is intentionally
    mixed up.

    It also covers the case where the same argument is both a compile-time
    constant and an input to an operation that normally expects its inputs
    to be in device memory - addition in this case.
    """
    with self.test_scope():
      def foo(c1, r1, v1, c2, v2, r2):
        # c1 and c2 are compile-time constants
        # r1 and r2 are regular tensors
        # v1 and v2 are resource variables
        a = c1 + r1
        b = math_ops.cast(c2, dtypes.float32) + v2
        c = array_ops.slice(v1, c1, c2)
        d = r2 * v2
        return a, b, c, d

      foo = function.defun(foo)

      c1 = [0, 0]
      c2 = array_ops.ones([2], dtype=dtypes.int32)

      r1 = array_ops.ones([2])
      r2 = [[2., 2.], [3., 3.]]

      v1 = resource_variable_ops.ResourceVariable([[1., 2.], [3., 4.]])
      v2 = resource_variable_ops.ResourceVariable([[10., 20.], [30., 40.]])

      a, b, c, d = foo(c1, r1, v1, c2, v2, r2)

      self.assertAllEqual([1, 1], a.numpy())
      self.assertAllEqual([[11., 21.], [31., 41.]], b.numpy())
      self.assertAllEqual([[1.]], c.numpy())
      self.assertAllEqual([[20., 40.], [90., 120.]], d.numpy())

  def testDefunInGradientTape(self):
    with self.test_scope():
      v0 = resource_variable_ops.ResourceVariable(5.0)

      @function.defun
      def f(x):
        x = v0 * v0 * x
        return x

      x = constant_op.constant(3.0)
      with backprop.GradientTape() as tape:
        y = f(x)
      dy = tape.gradient(y, v0)

      self.assertEqual(75, y.numpy())
      self.assertEqual(30, dy.numpy())

  def testGradientTapeInDefun(self):
    with self.test_scope():
      v0 = resource_variable_ops.ResourceVariable(5.0)

      @function.defun
      def f():
        x = constant_op.constant(1.0)
        with backprop.GradientTape() as tape:
          y = v0 * x
        dy = tape.gradient(y, v0)
        return dy

      dy = f()
      self.assertEqual(1.0, dy.numpy())

  def testSliceInDefun(self):
    with self.test_scope():

      @function.defun
      def f(x, y):
        return x[0::2, y:, ...]

      x = array_ops.ones([2, 3, 4], dtype=dtypes.float32)
      y = array_ops.ones([], dtype=dtypes.int32)
      with backprop.GradientTape() as tape:
        tape.watch(x)
        tape.watch(y)
        z = f(x, y)
      dz = tape.gradient(z, x)

      self.assertAllEqual(np.ones([1, 2, 4]), z.numpy())
      self.assertAllEqual((2, 3, 4), dz.shape.as_list())

  def testNestedDefun(self):
    with self.test_scope():

      @function.defun
      def times_two(x):
        return 2. * x

      @function.defun
      def two_x_plus_1(x):
        return times_two(x) + 1.

      x = constant_op.constant([2., 3., 4.])
      y = two_x_plus_1(x)
      self.assertAllEqual([5., 7., 9.], y.numpy())

  def testNestedDefunWithVariable(self):
    with self.test_scope():
      v0 = resource_variable_ops.ResourceVariable(5.0)

      @function.defun
      def g(x):
        x = v0 * x
        return x

      @function.defun
      def f(x):
        x = g(v0 * x)
        return x

      x = constant_op.constant(3.0)
      y = f(x)

      self.assertEqual(75.0, y.numpy())

  def testNestedDefunInGradientTape(self):
    with self.test_scope():
      v0 = resource_variable_ops.ResourceVariable(5.0)

      @function.defun
      def g(x):
        x = v0 * x
        return x

      @function.defun
      def f(x):
        x = g(v0 * x)
        return x

      x = constant_op.constant(3.0)
      with backprop.GradientTape() as tape:
        y = f(x)
      dy = tape.gradient(y, v0)

      self.assertEqual(75, y.numpy())
      self.assertEqual(30, dy.numpy())

  def testNestedDefunInGradientTapeDifferentVars(self):
    with self.test_scope():
      v0 = resource_variable_ops.ResourceVariable(5.0)
      v1 = resource_variable_ops.ResourceVariable(3.0)

      @function.defun
      def g(x):
        x = v1 * x
        return x

      @function.defun
      def f(x):
        x = g(v0 * x)
        return x

      x = constant_op.constant(3.0)
      with backprop.GradientTape(persistent=True) as tape:
        y = f(x)
      dy_v0 = tape.gradient(y, v0)
      dy_v1 = tape.gradient(y, v1)

      self.assertEqual(45, y.numpy())
      self.assertEqual(9, dy_v0.numpy())
      self.assertEqual(15, dy_v1.numpy())

  def testWhileInDefun(self):
    with self.test_scope():
      @def_function.function
      def f(start):
        c = lambda x: math_ops.less(x, 13.0)
        b = lambda x: math_ops.add(x, 1.0)
        return control_flow_ops.while_loop(c, b, [start])

      y = f(constant_op.constant(3.0))
      self.assertEqual(13.0, y.numpy())

  def testAutoGraphWhileInDefun(self):
    with self.test_scope():
      @def_function.function
      def f(start):
        x = start
        while x < 13.0:
          x += 1.0
        return x

      y = f(constant_op.constant(3.0))
      self.assertEqual(13.0, y.numpy())

  def testCondInDefun(self):
    with self.test_scope():
      @def_function.function
      def f(pred, value):
        fn1 = lambda: math_ops.add(value, 1.0)
        fn2 = lambda: math_ops.subtract(value, 1.0)
        return control_flow_ops.cond(pred, fn1, fn2)

      plus_one = f(constant_op.constant(True), constant_op.constant(10.0))
      minus_one = f(constant_op.constant(False), constant_op.constant(10.0))
      self.assertEqual(11.0, plus_one.numpy())
      self.assertEqual(9.0, minus_one.numpy())

  def testAutoGraphCondInDefun(self):
    with self.test_scope():
      @def_function.function
      def f(pred, value):
        if pred:
          return value + 1.0
        else:
          return value - 1.0

      plus_one = f(constant_op.constant(True), constant_op.constant(10.0))
      minus_one = f(constant_op.constant(False), constant_op.constant(10.0))
      self.assertEqual(11.0, plus_one.numpy())
      self.assertEqual(9.0, minus_one.numpy())

  def testScanInDefun(self):
    with self.test_scope():
      elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name='data')
      v = constant_op.constant(2.0, name='v')

      @def_function.function
      def f(y):
        # pylint: disable=unnecessary-lambda
        return functional_ops.scan(
            lambda a, x: math_ops.multiply(a, x), y, initializer=v)
        # pylint: enable=unnecessary-lambda

      r = f(elems)
      self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r))

  def testFeedDeviceMemoryToOpExpectingHostMemory(self):
    @function.defun
    def f(dims, value):
      return array_ops.fill(dims, value)

    with self.test_scope():
      x = constant_op.constant([4], dtype=dtypes.int64)

    y = f(x, 3)
    self.assertAllEqual([3, 3, 3, 3], y)

  def testRequestNotToCompile(self):
    with self.test_scope():
      def f(x):
        with ops.device('device:CPU:0'):
          y = 2.0 * x
        return x, y

      wholly_compiled_f = def_function.function(f)
      op_by_op_f = def_function.function(f, jit_compile=False)

      x = array_ops.identity([0.0, 2.0], name='data')

      # When the function is wholly compiled, all outputs will be on the
      # device on which it is run.
      r_x, r_y = wholly_compiled_f(x)
      self.assertAllEqual([0.0, 2.0], r_x)
      self.assertAllEqual([0.0, 4.0], r_y)
      if context.executing_eagerly():
        # backing_device is only available for eager tensors.
        self.assertRegex(r_x.backing_device, self.device)
        self.assertRegex(r_y.backing_device, self.device)

      # When the function is executed op-by-op, the requested devices will
      # be respected.
      r_x, r_y = op_by_op_f(x)
      self.assertAllEqual([0.0, 2.0], r_x)
      self.assertAllEqual([0.0, 4.0], r_y)
      if context.executing_eagerly():
        # backing_device is only available for eager tensors.
        self.assertRegex(r_x.backing_device, self.device)
        self.assertRegex(r_y.backing_device, 'device:CPU:0')


class ExcessivePaddingTest(xla_test.XLATestCase):
  """Test that eager execution works with TPU flattened tensors.

  Tensors that would normally be excessively padded when written
  to TPU memory are reshaped to 1-D flat tensors.

  This test case verifies that such tensors work with eager execution.

  The flattening currently only happens on TPU, but the tests should work
  fine with all backends, as the flattening is transparent.
  """

  def testFromConstant(self):
    with self.test_scope():
      # Create a constant of shape [100, 2, 1]. This tensor would be
      # excessively padded on TPU.
      tensor = constant_op.constant(100 * [[[10.0], [2.0]]])
      # Use reduce_sum since it requires handling a particular dimension
      # correctly.
      reduced = math_ops.reduce_sum(tensor, axis=1)
      self.assertAllEqual(100 * [[12.0]], reduced)

  def testFromOperation(self):
    with self.test_scope():
      tensor = array_ops.ones([3, 100, 2, 2])
      reduced = math_ops.reduce_sum(tensor, axis=[0, 2, 3])
      self.assertAllEqual(100 * [12.0], reduced)

  def testAsFunctionInput(self):
    with self.test_scope():

      @function.defun
      def f(x):
        return math_ops.reduce_sum(x, axis=2)

      tensor = constant_op.constant(100 * [[[10.0, 2.0]]])
      reduced = f(tensor)
      self.assertAllEqual(100 * [[12.0]], reduced)

  def testAsFunctionOutput(self):
    with self.test_scope():

      @function.defun
      def f(x):
        return x * constant_op.constant(100 * [[[10.0, 2.0]]])

      y = f(3)
      reduced = math_ops.reduce_sum(y, axis=2)
      self.assertAllEqual(100 * [[36.0]], reduced)


def multiple_tpus():
  devices = context.context().devices()
  return len([d for d in devices if 'device:TPU:' in d]) > 1


class MultiDeviceTest(xla_test.XLATestCase):
  """Test running TPU computation on more than one core."""

  def testBasic(self):
    if not multiple_tpus():
      self.skipTest('MultiDeviceTest requires multiple TPU devices.')

    # Compute 10 on TPU core 0
    with ops.device('device:TPU:0'):
      two = constant_op.constant(2)
      five = constant_op.constant(5)
      ten = two * five
      self.assertAllEqual(10, ten)

    # Compute 6 on TPU core 1
    with ops.device('device:TPU:1'):
      two = constant_op.constant(2)
      three = constant_op.constant(3)
      six = two * three
      self.assertAllEqual(6, six)

    # Copy 10 and 6 to CPU and sum them
    self.assertAllEqual(16, ten + six)


if __name__ == '__main__':
  ops.enable_eager_execution(
      config=config_pb2.ConfigProto(log_device_placement=True))
  googletest.main()