# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for eager execution using XLA."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.layers import convolutional
from tensorflow.python.layers import pooling
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import googletest
from tensorflow.python.training import adam


class EagerTest(xla_test.XLATestCase):

  def testBasic(self):
    with self.test_scope():
      three = constant_op.constant(3)
      five = constant_op.constant(5)
      product = three * five
      self.assertAllEqual(15, product)

  def testGradientTape(self):
    with self.test_scope():

      x = constant_op.constant(1.0)
      y = constant_op.constant(10.0)
      with backprop.GradientTape(persistent=True) as tape:
        tape.watch(x)
        tape.watch(y)
        a = x + y + x * y
      da_dx = tape.gradient(a, x)
      da_dy = tape.gradient(a, y)

    self.assertEqual(11.0, da_dx.numpy())
    self.assertEqual(2.0, da_dy.numpy())

  def testExecuteListOutputLen0(self):
    with self.test_scope():
      empty = constant_op.constant([], dtype=dtypes.float32)
      result = array_ops.unstack(empty, 0)
      self.assertTrue(isinstance(result, list))
      self.assertEqual(0, len(result))

  def testExecuteListOutputLen1(self):
    with self.test_scope():
      split_dim = constant_op.constant(1)
      value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]])
      result = array_ops.split(value, 1, axis=split_dim)
      self.assertTrue(isinstance(result, list))
      self.assertEqual(1, len(result))
      self.assertAllEqual([[0, 1, 2], [3, 4, 5]], result[0])

  def testExecuteListOutputLen3(self):
    with self.test_scope():
      split_dim = constant_op.constant(1)
      value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]])
      result = array_ops.split(value, 3, axis=split_dim)
      self.assertTrue(isinstance(result, list))
      self.assertEqual(3, len(result))
      self.assertAllEqual([[0], [3]], result[0])
      self.assertAllEqual([[1], [4]], result[1])
      self.assertAllEqual([[2], [5]], result[2])

  def testBasicGraph(self):
    # Run some ops eagerly
    with self.test_scope():
      three = constant_op.constant(3)
      five = constant_op.constant(5)
      product = three * five
      self.assertAllEqual(15, product)

    # Run some ops graphly
    with context.graph_mode(), self.session():
      with self.test_scope():
        three = constant_op.constant(3)
        five = constant_op.constant(5)
        product = three * five
        self.assertAllEqual(15, self.evaluate(product))

  def testDegenerateSlices(self):
    with self.test_scope():
      npt = np.arange(1, 19, dtype=np.float32).reshape(3, 2, 3)
      t = constant_op.constant(npt)
      # degenerate by offering a forward interval with a negative stride
      self.assertAllEqual(npt[0:-1:-1, :, :], t[0:-1:-1, :, :])
      # degenerate with a reverse interval with a positive stride
      self.assertAllEqual(npt[-1:0, :, :], t[-1:0, :, :])
      # empty interval in every dimension
      self.assertAllEqual(npt[-1:0, 2:2, 2:3:-1], t[-1:0, 2:2, 2:3:-1])

  def testIdentity(self):
    with self.test_scope():
      self.assertAllEqual(2, array_ops.identity(2))

  def testRandomOps(self):
    with self.test_scope():
      tensor = gen_random_ops.random_uniform((2, 2), dtypes.float32)
      row0 = tensor[0].numpy()
      row1 = tensor[1].numpy()
      # It should be very unlikely for the rng to generate two equal rows.
      self.assertFalse((row0 == row1).all())

  def testIdentityOnVariable(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(True)
      i = array_ops.identity(v)
      self.assertAllEqual(True, i.numpy())

  def testAssignAddVariable(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(1.0)
      v.assign_add(2.0)
      self.assertEqual(3.0, v.numpy())

  def testReadAssignRead(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(1.0)
      val1 = v.read_value()
      v.assign_add(2.0)
      val2 = v.read_value()
      self.assertEqual(1.0, val1.numpy())
      self.assertEqual(3.0, val2.numpy())

  def testGradient(self):
    def f(x):
      return x

    with self.test_scope():
      grad_fn = backprop.gradients_function(f)
      self.assertAllEqual(2., grad_fn(1., dy=2.)[0])

  def testVariableGradient(self):
    with self.test_scope():
      v0 = resource_variable_ops.ResourceVariable(1.0)

      def f():
        x = v0 * v0
        return x

      grads = backprop.implicit_grad(f)()
      self.assertEqual(2., grads[0][0].numpy())

  def testMultipleVariableReads(self):
    # This test makes sure consecutive variable reads don't copy
    # the underlying memory.
    with self.test_scope():
      # Create a 128MiB variable.
      var = resource_variable_ops.ResourceVariable(
          array_ops.ones([32, 1024, 1024]))

      # Read the same variable 100 times. If the underlying tensor
      # is not copied, this is a trivial operation. If it is copied,
      # this will eat over 13GB and OOM.
      values = []
      for _ in range(100):
        values.append(var.value())

  # The shape, shape_n, size, and rank ops are tested here because their
  # eager execution kernels are distinct from the compilation-only tf2xla
  # kernels.

  def testShape(self):
    def const(value):
      return array_ops.shape(
          constant_op.constant(value)).numpy()

    def ones(value):
      return array_ops.shape(
          array_ops.ones(value)).numpy()

    with self.test_scope():
      # Shapes of directly constructed tensors
      self.assertAllEqual([], const(3))
      self.assertAllEqual([3], const([1.0, 2.0, 3.0]))
      self.assertAllEqual([2, 2], const([[1.0, 2.0], [3.0, 4.0]]))
      self.assertAllEqual([2, 1, 2], const([[[1.0, 2.0]], [[3.0, 4.0]]]))

      # Shapes of tensors created by ops running on the device.
      # We make this distinction because directly constructed tensors
      # are treated differently in a few places that can influence shape:
      #  - they always have on_host_tensor
      #  - they and their shapes can be cached
      #  - they end up on device via a copy, instead of as program output
      self.assertAllEqual([], ones([]))
      self.assertAllEqual([3], ones([3]))
      self.assertAllEqual([2, 2], ones([2, 2]))
      self.assertAllEqual([2, 1, 2], ones([2, 1, 2]))

  def testShapeN(self):
    with self.test_scope():
      # Shapes of directly constructed tensors
      shapes = array_ops.shape_n([
          constant_op.constant(1.0),
          constant_op.constant([1.0, 2.0, 3.0]),
          constant_op.constant([[1.0, 2.0], [3.0, 4.0]])])
      self.assertAllEqual(
          [[], [3], [2, 2]],
          [x.numpy().tolist() for x in shapes])

      # Shapes of tensors created by ops running on the device.
      shapes = array_ops.shape_n([
          array_ops.ones([]),
          array_ops.ones([3]),
          array_ops.ones([2, 2])])
      self.assertAllEqual(
          [[], [3], [2, 2]],
          [x.numpy().tolist() for x in shapes])

  def testSize(self):
    with self.test_scope():
      self.assertEqual(
          1, array_ops.size(constant_op.constant(1.0)).numpy())
      self.assertEqual(
          3, array_ops.size(constant_op.constant([1.0, 2.0, 3.0])).numpy())
      self.assertEqual(
          4, array_ops.size(
              constant_op.constant([[1.0, 2.0], [3.0, 4.0]])).numpy())

  def testRank(self):
    with self.test_scope():
      self.assertEqual(
          0, array_ops.rank(constant_op.constant(1.0)).numpy())
      self.assertEqual(
          1, array_ops.rank(constant_op.constant([1.0, 2.0, 3.0])).numpy())
      self.assertEqual(
          2, array_ops.rank(
              constant_op.constant([[1.0, 2.0], [3.0, 4.0]])).numpy())

  def testAdam(self):
    with self.test_scope():
      optimizer = adam.AdamOptimizer(0.1)
      x = resource_variable_ops.ResourceVariable(10.0)
      with backprop.GradientTape() as tape:
        y = x * x
      dy_dx = tape.gradient(y, x)
      optimizer.apply_gradients([(dy_dx, x)])
      self.assertAlmostEqual(9.9, x.numpy(), places=3)

  def testAdamSparse(self):
    with ops.device('/cpu:0'):
      # Create a 2-D embedding for 3 objects on CPU because sparse/sliced
      # updates are not implemented on TPU.
      embedding_matrix = resource_variable_ops.ResourceVariable(
          array_ops.ones([3, 2]))

    with self.test_scope():
      with backprop.GradientTape() as tape:
        embedding = embedding_ops.embedding_lookup(embedding_matrix, [1])
        y = math_ops.reduce_sum(embedding)
      dy_dx = tape.gradient(y, embedding_matrix)
      self.assertIsInstance(dy_dx, ops.IndexedSlices)
      optimizer = adam.AdamOptimizer(0.1)
      # The gradient application operations will run on CPU because optimizer
      # updates are always collocated with the variable.
      optimizer.apply_gradients([(dy_dx, embedding_matrix)])

      # This assign_add will run on CPU because when an input to an
      # operation is a resource, this operation is placed on the resource's
      # device by the eager runtime.
      embedding_matrix.assign_add(array_ops.ones([3, 2]))

    self.assertAllClose([[2.0, 2.0],
                         [1.9, 1.9],
                         [2.0, 2.0]], embedding_matrix.numpy())


class EagerFunctionTest(xla_test.XLATestCase):

  def testBasic(self):
    with self.test_scope():
      matmul = function.defun(math_ops.matmul)
      t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
      sq = matmul(t, t, transpose_a=True)
      self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20])

  def testConv(self):
    if 'GPU' in self.device:
      # TODO(b/32333178)
      self.skipTest('Current implementation of RandomStandardNormal kernel '
                    'is very slow on GPU, and has been denylisted.')
    with self.test_scope():
      data_format = 'channels_last'
      conv = convolutional.Conv2D(
          filters=1, kernel_size=2, padding='VALID',
          data_format=data_format, activation=nn_ops.relu,
          kernel_initializer=init_ops.ones_initializer(),
          bias_initializer=init_ops.zeros_initializer())
      pool = pooling.MaxPooling2D(2, 2, data_format=data_format)

      def model(x):
        x = conv(x)
        return pool(x)
      model = function.defun(model)

      x = array_ops.ones([1, 4, 4, 1])
      y = model(x)
      self.assertAllEqual(y.numpy(), [[[[4.]]]])

  def testReadVariable(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(1.0)

      @function.defun
      def f():
        return v.read_value()

      var = f()
      self.assertEqual(1.0, var.numpy())

  def testResourceVariableNoInlineReadWrite(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(1.0)
      w = resource_variable_ops.ResourceVariable(0.0)

      @function.defun_with_attributes(attributes={'_noinline': True})
      def g(x):
        w.assign(w.read_value() + x)
        return v.read_value() + x * w.read_value()

      @function.defun_with_attributes(attributes={'_noinline': True})
      def f():
        return g(1.0) + g(2.0) + g(3.0) + g(4.0) + g(5.0)

      # 1 + 1*1 + 1 + 2*3 + 1 + 3*6 + 1 + 4*10 + 1 + 5*15
      self.assertEqual(145.0, f().numpy())
      self.assertEqual(15.0, w.read_value().numpy())

  def testResourceVariableNoInlineReadOnly(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(10.0)

      @function.defun_with_attributes(attributes={'_noinline': True})
      def g():
        return v.read_value()

      @function.defun_with_attributes(attributes={'_noinline': True})
      def f():
        return g() + g() + g() + g() + g()

      self.assertEqual(50.0, f().numpy())

  def testResourceVariableNoInlineWriteOnly(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(0.0)

      @function.defun_with_attributes(attributes={'_noinline': True})
      def g(x):
        v.assign(x)

      @function.defun_with_attributes(attributes={'_noinline': True})
      def f():
        g(1.0)
        g(2.0)
        g(3.0)
        g(4.0)
        g(5.0)

      f()
      self.assertEqual(5.0, v.read_value().numpy())

  def testUpdateVariable(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(1.0)

      def f(v):
        v.assign_add(1.0)
        return v

      f = function.defun(f)

      var = f(v)
      self.assertEqual(2.0, var.numpy())

  def testReturnResourceHandle(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable([[1.0, 2.0], [3.0, 4.0]])

      def f(v):
        return v.handle

      f = function.defun(f)
      handle = f(v)
      self.assertAllEqual(v.numpy(),
                          resource_variable_ops.read_variable_op(
                              handle, dtypes.float32).numpy())

  def testReturnMultipleResourceHandles(self):
    with self.test_scope():
      v1 = resource_variable_ops.ResourceVariable(1.25)
      v2 = resource_variable_ops.ResourceVariable(2.0)

      def f(v):
        return v.handle, 3.0 * v, v2.handle, v + v2

      f = function.defun(f)
      v1_handle, v1_times_3, v2_handle, variable_sum = f(v1)
      self.assertAllEqual(v1.numpy(),
                          resource_variable_ops.read_variable_op(
                              v1_handle, dtypes.float32).numpy())
      self.assertEqual(3.75, v1_times_3.numpy())
      self.assertAllEqual(v2.numpy(),
                          resource_variable_ops.read_variable_op(
                              v2_handle, dtypes.float32).numpy())
      self.assertEqual(3.25, variable_sum.numpy())

  def testAllArgumentKinds(self):
    """Test a complex function that takes different argument kinds.

    The tf2xla machinery that translates, compiles, and runs defuns
    classifies arguments into compile-time constants, regular tensors,
    and resources. This test creates a function with a mix of all these
    kinds, and the order of the function arguments is intentionally mixed up.

    It also covers the case where the same argument is used both as a
    compile-time constant and in an operation that normally expects its
    inputs to be in device memory (addition in this case).
    """
    with self.test_scope():
      def foo(c1, r1, v1, c2, v2, r2):
        # c1 and c2 are compile-time constants
        # r1 and r2 are regular tensors
        # v1 and v2 are resource variables
        a = c1 + r1
        b = math_ops.cast(c2, dtypes.float32) + v2
        c = array_ops.slice(v1, c1, c2)
        d = r2 * v2
        return a, b, c, d

      foo = function.defun(foo)

      c1 = [0, 0]
      c2 = array_ops.ones([2], dtype=dtypes.int32)

      r1 = array_ops.ones([2])
      r2 = [[2., 2.], [3., 3.]]

      v1 = resource_variable_ops.ResourceVariable([[1., 2.], [3., 4.]])
      v2 = resource_variable_ops.ResourceVariable([[10., 20.], [30., 40.]])

      a, b, c, d = foo(c1, r1, v1, c2, v2, r2)

      self.assertAllEqual([1, 1], a.numpy())
      self.assertAllEqual([[11., 21.], [31., 41.]], b.numpy())
      self.assertAllEqual([[1.]], c.numpy())
      self.assertAllEqual([[20., 40.], [90., 120.]], d.numpy())

  def testDefunInGradientTape(self):
    with self.test_scope():
      v0 = resource_variable_ops.ResourceVariable(5.0)

      @function.defun
      def f(x):
        x = v0 * v0 * x
        return x

      x = constant_op.constant(3.0)
      with backprop.GradientTape() as tape:
        y = f(x)
      dy = tape.gradient(y, v0)

      self.assertEqual(75, y.numpy())
      self.assertEqual(30, dy.numpy())

  def testGradientTapeInDefun(self):
    with self.test_scope():
      v0 = resource_variable_ops.ResourceVariable(5.0)

      @function.defun
      def f():
        x = constant_op.constant(1.0)
        with backprop.GradientTape() as tape:
          y = v0 * x
        dy = tape.gradient(y, v0)
        return dy

      dy = f()
      self.assertEqual(1.0, dy.numpy())

  def testSliceInDefun(self):
    with self.test_scope():

      @function.defun
      def f(x, y):
        return x[0::2, y:, ...]

      x = array_ops.ones([2, 3, 4], dtype=dtypes.float32)
      y = array_ops.ones([], dtype=dtypes.int32)
      with backprop.GradientTape() as tape:
        tape.watch(x)
        tape.watch(y)
        z = f(x, y)
      dz = tape.gradient(z, x)

      self.assertAllEqual(np.ones([1, 2, 4]), z.numpy())
      self.assertAllEqual((2, 3, 4), dz.shape.as_list())

  def testNestedDefun(self):
    with self.test_scope():

      @function.defun
      def times_two(x):
        return 2. * x

      @function.defun
      def two_x_plus_1(x):
        return times_two(x) + 1.

      x = constant_op.constant([2., 3., 4.])
      y = two_x_plus_1(x)
      self.assertAllEqual([5., 7., 9.], y.numpy())

  def testNestedDefunWithVariable(self):
    with self.test_scope():
      v0 = resource_variable_ops.ResourceVariable(5.0)

      @function.defun
      def g(x):
        x = v0 * x
        return x

      @function.defun
      def f(x):
        x = g(v0 * x)
        return x

      x = constant_op.constant(3.0)
      y = f(x)

      self.assertEqual(75.0, y.numpy())

  def testNestedDefunInGradientTape(self):
    with self.test_scope():
      v0 = resource_variable_ops.ResourceVariable(5.0)

      @function.defun
      def g(x):
        x = v0 * x
        return x

      @function.defun
      def f(x):
        x = g(v0 * x)
        return x

      x = constant_op.constant(3.0)
      with backprop.GradientTape() as tape:
        y = f(x)
      dy = tape.gradient(y, v0)

      self.assertEqual(75, y.numpy())
      self.assertEqual(30, dy.numpy())

  def testNestedDefunInGradientTapeDifferentVars(self):
    with self.test_scope():
      v0 = resource_variable_ops.ResourceVariable(5.0)
      v1 = resource_variable_ops.ResourceVariable(3.0)

      @function.defun
      def g(x):
        x = v1 * x
        return x

      @function.defun
      def f(x):
        x = g(v0 * x)
        return x

      x = constant_op.constant(3.0)
      with backprop.GradientTape(persistent=True) as tape:
        y = f(x)
      dy_v0 = tape.gradient(y, v0)
      dy_v1 = tape.gradient(y, v1)

      self.assertEqual(45, y.numpy())
      self.assertEqual(9, dy_v0.numpy())
      self.assertEqual(15, dy_v1.numpy())

  def testWhileInDefun(self):
    with self.test_scope():
      @def_function.function
      def f(start):
        c = lambda x: math_ops.less(x, 13.0)
        b = lambda x: math_ops.add(x, 1.0)
        return control_flow_ops.while_loop(c, b, [start])

      y = f(constant_op.constant(3.0))
      self.assertEqual(13.0, y.numpy())

  def testAutoGraphWhileInDefun(self):
    with self.test_scope():
      @def_function.function
      def f(start):
        x = start
        while x < 13.0:
          x += 1.0
        return x

      y = f(constant_op.constant(3.0))
      self.assertEqual(13.0, y.numpy())

  def testCondInDefun(self):
    with self.test_scope():
      @def_function.function
      def f(pred, value):
        fn1 = lambda: math_ops.add(value, 1.0)
        fn2 = lambda: math_ops.subtract(value, 1.0)
        return control_flow_ops.cond(pred, fn1, fn2)

      plus_one = f(constant_op.constant(True), constant_op.constant(10.0))
      minus_one = f(constant_op.constant(False), constant_op.constant(10.0))
      self.assertEqual(11.0, plus_one.numpy())
      self.assertEqual(9.0, minus_one.numpy())

  def testAutoGraphCondInDefun(self):
    with self.test_scope():
      @def_function.function
      def f(pred, value):
        if pred:
          return value + 1.0
        else:
          return value - 1.0

      plus_one = f(constant_op.constant(True), constant_op.constant(10.0))
      minus_one = f(constant_op.constant(False), constant_op.constant(10.0))
      self.assertEqual(11.0, plus_one.numpy())
      self.assertEqual(9.0, minus_one.numpy())

  def testScanInDefun(self):
    with self.test_scope():
      elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name='data')
      v = constant_op.constant(2.0, name='v')

      @def_function.function
      def f(y):
        # pylint: disable=unnecessary-lambda
        return functional_ops.scan(
            lambda a, x: math_ops.multiply(a, x), y, initializer=v)
        # pylint: enable=unnecessary-lambda

      r = f(elems)
      self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r))

  def testFeedDeviceMemoryToOpExpectingHostMemory(self):
    @function.defun
    def f(dims, value):
      return array_ops.fill(dims, value)

    with self.test_scope():
      x = constant_op.constant([4], dtype=dtypes.int64)

    y = f(x, 3)
    self.assertAllEqual([3, 3, 3, 3], y)

  def testRequestNotToCompile(self):
    with self.test_scope():
      def f(x):
        with ops.device('device:CPU:0'):
          y = 2.0 * x
        return x, y

      wholly_compiled_f = def_function.function(f)
      op_by_op_f = def_function.function(f, jit_compile=False)

      x = array_ops.identity([0.0, 2.0], name='data')

      # When the function is wholly compiled, all outputs will be on the
      # device on which it is run.
      r_x, r_y = wholly_compiled_f(x)
      self.assertAllEqual([0.0, 2.0], r_x)
      self.assertAllEqual([0.0, 4.0], r_y)
      if context.executing_eagerly():
        # backing_device is only available for eager tensors.
        self.assertRegex(r_x.backing_device, self.device)
        self.assertRegex(r_y.backing_device, self.device)

      # When the function is executed op-by-op, requested devices will be
      # respected.
      r_x, r_y = op_by_op_f(x)
      self.assertAllEqual([0.0, 2.0], r_x)
      self.assertAllEqual([0.0, 4.0], r_y)
      if context.executing_eagerly():
        # backing_device is only available for eager tensors.
        self.assertRegex(r_x.backing_device, self.device)
        self.assertRegex(r_y.backing_device, 'device:CPU:0')


class ExcessivePaddingTest(xla_test.XLATestCase):
  """Test that eager execution works with TPU flattened tensors.

  Tensors that would normally be excessively padded when written
  to TPU memory are reshaped to 1-D flat tensors.

  This test case verifies that such tensors work with eager execution.

  The flattening currently only happens on TPU, but tests should work
  fine with all backends as flattening is transparent.
  """

  def testFromConstant(self):
    with self.test_scope():
      # Create a constant of shape [100, 2, 1]. This tensor would be
      # excessively padded on TPU.
      tensor = constant_op.constant(100 * [[[10.0], [2.0]]])
      # Use reduce_sum since it requires correctly working with
      # a particular dimension.
      reduced = math_ops.reduce_sum(tensor, axis=1)
      self.assertAllEqual(100 * [[12.0]], reduced)

  def testFromOperation(self):
    with self.test_scope():
      tensor = array_ops.ones([3, 100, 2, 2])
      reduced = math_ops.reduce_sum(tensor, axis=[0, 2, 3])
      self.assertAllEqual(100 * [12.0], reduced)

  def testAsFunctionInput(self):
    with self.test_scope():

      @function.defun
      def f(x):
        return math_ops.reduce_sum(x, axis=2)

      tensor = constant_op.constant(100 * [[[10.0, 2.0]]])
      reduced = f(tensor)
      self.assertAllEqual(100 * [[12.0]], reduced)

  def testAsFunctionOutput(self):
    with self.test_scope():

      @function.defun
      def f(x):
        return x * constant_op.constant(100 * [[[10.0, 2.0]]])

      y = f(3)
      reduced = math_ops.reduce_sum(y, axis=2)
      self.assertAllEqual(100 * [[36.0]], reduced)


def multiple_tpus():
  devices = context.context().devices()
  return len([d for d in devices if 'device:TPU:' in d]) > 1


class MultiDeviceTest(xla_test.XLATestCase):
  """Test running TPU computation on more than one core."""

  def testBasic(self):
    if not multiple_tpus():
      self.skipTest('MultiDeviceTest requires multiple TPU devices.')

    # Compute 10 on TPU core 0
    with ops.device('device:TPU:0'):
      two = constant_op.constant(2)
      five = constant_op.constant(5)
      ten = two * five
      self.assertAllEqual(10, ten)

    # Compute 6 on TPU core 1
    with ops.device('device:TPU:1'):
      two = constant_op.constant(2)
      three = constant_op.constant(3)
      six = two * three
      self.assertAllEqual(6, six)

    # Copy 10 and 6 to CPU and sum them
    self.assertAllEqual(16, ten + six)


if __name__ == '__main__':
  ops.enable_eager_execution(
      config=config_pb2.ConfigProto(log_device_placement=True))
  googletest.main()