# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for eager execution using XLA."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.layers import convolutional
from tensorflow.python.layers import pooling
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import googletest
from tensorflow.python.training import adam


class EagerTest(xla_test.XLATestCase):

  def testBasic(self):
    with self.test_scope():
      three = constant_op.constant(3)
      five = constant_op.constant(5)
      product = three * five
      self.assertAllEqual(15, product)

  def testGradientTape(self):
    with self.test_scope():

      x = constant_op.constant(1.0)
      y = constant_op.constant(10.0)
      with backprop.GradientTape(persistent=True) as tape:
        tape.watch(x)
        tape.watch(y)
        a = x + y + x * y
      da_dx = tape.gradient(a, x)
      da_dy = tape.gradient(a, y)

    self.assertEqual(11.0, da_dx.numpy())
    self.assertEqual(2.0, da_dy.numpy())

  def testExecuteListOutputLen0(self):
    with self.test_scope():
      empty = constant_op.constant([], dtype=dtypes.float32)
      result = array_ops.unstack(empty, 0)
      self.assertTrue(isinstance(result, list))
      self.assertEqual(0, len(result))

  def testExecuteListOutputLen1(self):
    with self.test_scope():
      split_dim = constant_op.constant(1)
      value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]])
      result = array_ops.split(value, 1, axis=split_dim)
      self.assertTrue(isinstance(result, list))
      self.assertEqual(1, len(result))
      self.assertAllEqual([[0, 1, 2], [3, 4, 5]], result[0])

  def testExecuteListOutputLen3(self):
    with self.test_scope():
      split_dim = constant_op.constant(1)
      value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]])
      result = array_ops.split(value, 3, axis=split_dim)
      self.assertTrue(isinstance(result, list))
      self.assertEqual(3, len(result))
      self.assertAllEqual([[0], [3]], result[0])
      self.assertAllEqual([[1], [4]], result[1])
      self.assertAllEqual([[2], [5]], result[2])

  def testBasicGraph(self):
    # Run some ops eagerly
    with self.test_scope():
      three = constant_op.constant(3)
      five = constant_op.constant(5)
      product = three * five
      self.assertAllEqual(15, product)

    # Run the same ops in graph mode
    with context.graph_mode(), self.session():
      with self.test_scope():
        three = constant_op.constant(3)
        five = constant_op.constant(5)
        product = three * five
        self.assertAllEqual(15, self.evaluate(product))

  def testDegenerateSlices(self):
    with self.test_scope():
      npt = np.arange(1, 19, dtype=np.float32).reshape(3, 2, 3)
      t = constant_op.constant(npt)
      # degenerate by offering a forward interval with a negative stride
      self.assertAllEqual(npt[0:-1:-1, :, :], t[0:-1:-1, :, :])
      # degenerate with a reverse interval with a positive stride
      self.assertAllEqual(npt[-1:0, :, :], t[-1:0, :, :])
      # empty interval in every dimension
      self.assertAllEqual(npt[-1:0, 2:2, 2:3:-1], t[-1:0, 2:2, 2:3:-1])

  def testIdentity(self):
    with self.test_scope():
      self.assertAllEqual(2, array_ops.identity(2))

  def testRandomOps(self):
    with self.test_scope():
      tensor = gen_random_ops.random_uniform((2, 2), dtypes.float32)
      row0 = tensor[0].numpy()
      row1 = tensor[1].numpy()
      # It should be very unlikely for the rng to generate two equal rows.
      self.assertFalse((row0 == row1).all())

  def testIdentityOnVariable(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(True)
      i = array_ops.identity(v)
    self.assertAllEqual(True, i.numpy())

  def testAssignAddVariable(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(1.0)
      v.assign_add(2.0)
    self.assertEqual(3.0, v.numpy())

  def testReadAssignRead(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(1.0)
      val1 = v.read_value()
      v.assign_add(2.0)
      val2 = v.read_value()
    self.assertEqual(1.0, val1.numpy())
    self.assertEqual(3.0, val2.numpy())

  def testGradient(self):
    def f(x):
      return x

    with self.test_scope():
      grad_fn = backprop.gradients_function(f)
      self.assertAllEqual(2., grad_fn(1., dy=2.)[0])

  def testVariableGradient(self):
    with self.test_scope():
      v0 = resource_variable_ops.ResourceVariable(1.0)

      def f():
        x = v0 * v0
        return x

      grads = backprop.implicit_grad(f)()
    self.assertEqual(2., grads[0][0].numpy())

  def testMultipleVariableReads(self):
    # This test makes sure consecutive variable reads don't copy
    # the underlying memory.
    with self.test_scope():
      # Create a 128MiB variable
      var = resource_variable_ops.ResourceVariable(
          array_ops.ones([32, 1024, 1024]))

      # Read the same variable 100 times. If the underlying tensor
      # is not copied, this is a trivial operation. If it is copied,
      # this will eat over 13GB and OOM.
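      # (Each copy of the [32, 1024, 1024] float32 tensor is 128 MiB, so 100
      # copies would be roughly 12.5 GiB.)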
      values = []
      for _ in range(100):
        values.append(var.value())

  # The shape, shape_n, size, and rank ops are tested here because their
  # eager execution kernels (as opposed to the compilation-only tf2xla
  # kernels) are distinct from the tf2xla kernels.

  def testShape(self):
    def const(value):
      return array_ops.shape(
          constant_op.constant(value)).numpy()

    def ones(value):
      return array_ops.shape(
          array_ops.ones(value)).numpy()

    with self.test_scope():
      # Shapes of directly constructed tensors
      self.assertAllEqual([], const(3))
      self.assertAllEqual([3], const([1.0, 2.0, 3.0]))
      self.assertAllEqual([2, 2], const([[1.0, 2.0], [3.0, 4.0]]))
      self.assertAllEqual([2, 1, 2], const([[[1.0, 2.0]], [[3.0, 4.0]]]))

      # Shapes of tensors created by op running on device
      # We make this distinction because directly constructed tensors
      # are treated differently in a few places that can influence shape:
      #  - they always have on_host_tensor
      #  - they and their shapes can be cached
      #  - they end up on device via a copy, instead of as program output
      self.assertAllEqual([], ones([]))
      self.assertAllEqual([3], ones([3]))
      self.assertAllEqual([2, 2], ones([2, 2]))
      self.assertAllEqual([2, 1, 2], ones([2, 1, 2]))

  def testShapeN(self):
    with self.test_scope():
      # Shapes of directly constructed tensors
      shapes = array_ops.shape_n([
          constant_op.constant(1.0),
          constant_op.constant([1.0, 2.0, 3.0]),
          constant_op.constant([[1.0, 2.0], [3.0, 4.0]])])
      self.assertAllEqual(
          [[], [3], [2, 2]],
          [x.numpy().tolist() for x in shapes])

      # Shapes of tensors created by op running on device
      shapes = array_ops.shape_n([
          array_ops.ones([]),
          array_ops.ones([3]),
          array_ops.ones([2, 2])])
      self.assertAllEqual(
          [[], [3], [2, 2]],
          [x.numpy().tolist() for x in shapes])

  def testSize(self):
    with self.test_scope():
      self.assertEqual(
          1, array_ops.size(constant_op.constant(1.0)).numpy())
      self.assertEqual(
          3, array_ops.size(constant_op.constant([1.0, 2.0, 3.0])).numpy())
      self.assertEqual(
          4, array_ops.size(
              constant_op.constant([[1.0, 2.0], [3.0, 4.0]])).numpy())

  def testRank(self):
    with self.test_scope():
      self.assertEqual(
          0, array_ops.rank(constant_op.constant(1.0)).numpy())
      self.assertEqual(
          1, array_ops.rank(constant_op.constant([1.0, 2.0, 3.0])).numpy())
      self.assertEqual(
          2, array_ops.rank(
              constant_op.constant([[1.0, 2.0], [3.0, 4.0]])).numpy())

  def testAdam(self):
    with self.test_scope():
      optimizer = adam.AdamOptimizer(0.1)
      x = resource_variable_ops.ResourceVariable(10.0)
      with backprop.GradientTape() as tape:
        y = x * x
      dy_dx = tape.gradient(y, x)
      optimizer.apply_gradients([(dy_dx, x)])
      self.assertAlmostEqual(9.9, x.numpy(), places=3)

  def testAdamSparse(self):
    with ops.device('/cpu:0'):
      # Create a 2-D embedding for 3 objects on CPU because sparse/sliced
      # updates are not implemented on TPU.
      embedding_matrix = resource_variable_ops.ResourceVariable(
          array_ops.ones([3, 2]))

    with self.test_scope():
      with backprop.GradientTape() as tape:
        embedding = embedding_ops.embedding_lookup(embedding_matrix, [1])
        y = math_ops.reduce_sum(embedding)
      dy_dx = tape.gradient(y, embedding_matrix)
      self.assertIsInstance(dy_dx, ops.IndexedSlices)
      optimizer = adam.AdamOptimizer(0.1)
      # The gradient application operations will run on CPU because optimizer
      # updates are always collocated with the variable.
      optimizer.apply_gradients([(dy_dx, embedding_matrix)])

      # This assign_add will run on CPU because when an input to an operation
      # is a resource, the eager runtime places the operation on the
      # resource's device.
      embedding_matrix.assign_add(array_ops.ones([3, 2]))

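    # Adam's first step moves a row by roughly the learning rate (0.1), so
    # row 1 goes from 1.0 to ~0.9 before the assign_add of ones, while the
    # rows that received no gradient stay at 1.0 and end up at 2.0.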
    self.assertAllClose([[2.0, 2.0],
                         [1.9, 1.9],
                         [2.0, 2.0]], embedding_matrix.numpy())


class EagerFunctionTest(xla_test.XLATestCase):

  def testBasic(self):
    with self.test_scope():
      matmul = function.defun(math_ops.matmul)
      t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
      sq = matmul(t, t, transpose_a=True)
      self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20])

  def testConv(self):
    if 'GPU' in self.device:
      # TODO(b/32333178)
      self.skipTest('Current implementation of RandomStandardNormal kernel '
                    'is very slow on GPU, and has been denylisted.')
    with self.test_scope():
      data_format = 'channels_last'
      conv = convolutional.Conv2D(
          filters=1, kernel_size=2, padding='VALID',
          data_format=data_format, activation=nn_ops.relu,
          kernel_initializer=init_ops.ones_initializer(),
          bias_initializer=init_ops.zeros_initializer())
      pool = pooling.MaxPooling2D(2, 2, data_format=data_format)

      def model(x):
        x = conv(x)
        return pool(x)
      model = function.defun(model)

      x = array_ops.ones([1, 4, 4, 1])
      y = model(x)
      self.assertAllEqual(y.numpy(), [[[[4.]]]])

  def testReadVariable(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(1.0)

      @function.defun
      def f():
        return v.read_value()

      var = f()
      self.assertEqual(1.0, var.numpy())

  def testResourceVariableNoInlineReadWrite(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(1.0)
      w = resource_variable_ops.ResourceVariable(0.0)

      @function.defun_with_attributes(attributes={'_noinline': True})
      def g(x):
        w.assign(w.read_value() + x)
        return v.read_value() + x * w.read_value()

      @function.defun_with_attributes(attributes={'_noinline': True})
      def f():
        return g(1.0) + g(2.0) + g(3.0) + g(4.0) + g(5.0)

      # 1 + 1*1 + 1 + 2*3 + 1 + 3*6 + 1 + 4*10 + 1 + 5*15
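      # (w takes the values 1, 3, 6, 10, 15 across the five calls, so the
      # terms above are 2 + 7 + 19 + 41 + 76 = 145.)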
      self.assertEqual(145.0, f().numpy())
      self.assertEqual(15.0, w.read_value().numpy())

  def testResourceVariableNoInlineReadOnly(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(10.0)

      @function.defun_with_attributes(attributes={'_noinline': True})
      def g():
        return v.read_value()

      @function.defun_with_attributes(attributes={'_noinline': True})
      def f():
        return g() + g() + g() + g() + g()

      self.assertEqual(50.0, f().numpy())

  def testResourceVariableNoInlineWriteOnly(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(0.0)

      @function.defun_with_attributes(attributes={'_noinline': True})
      def g(x):
        v.assign(x)

      @function.defun_with_attributes(attributes={'_noinline': True})
      def f():
        g(1.0)
        g(2.0)
        g(3.0)
        g(4.0)
        g(5.0)

      f()
      self.assertEqual(5.0, v.read_value().numpy())

  def testUpdateVariable(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(1.0)

      def f(v):
        v.assign_add(1.0)
        return v

      f = function.defun(f)

      var = f(v)
      self.assertEqual(2.0, var.numpy())

  def testReturnResourceHandle(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable([[1.0, 2.0], [3.0, 4.0]])

      def f(v):
        return v.handle

      f = function.defun(f)
      handle = f(v)
      self.assertAllEqual(v.numpy(),
                          resource_variable_ops.read_variable_op(
                              handle, dtypes.float32).numpy())

  def testReturnMultipleResourceHandles(self):
    with self.test_scope():
      v1 = resource_variable_ops.ResourceVariable(1.25)
      v2 = resource_variable_ops.ResourceVariable(2.0)

      def f(v):
        return v.handle, 3.0 * v, v2.handle, v + v2

      f = function.defun(f)
      v1_handle, v1_times_3, v2_handle, variable_sum = f(v1)
      self.assertAllEqual(v1.numpy(),
                          resource_variable_ops.read_variable_op(
                              v1_handle, dtypes.float32).numpy())
      self.assertEqual(3.75, v1_times_3.numpy())
      self.assertAllEqual(v2.numpy(),
                          resource_variable_ops.read_variable_op(
                              v2_handle, dtypes.float32).numpy())
      self.assertEqual(3.25, variable_sum.numpy())

  def testAllArgumentKinds(self):
441    """Test a complex function that takes different argument kinds.
442
443    tf2xla machinery that translates, compiles, and runs defuns
444    classifies arguments into: compile-time constants, regular tensors,
445    and resources. This test creates a function with a mix of all these
446    kinds. Moreover, the order of function arguments is intentionally mixed up.
447
448    This also tests the case when the same argument is a compile-time constant
449    as well as used in an operation that normally expects its inputs to be
450    in device memory - addition in this case.
451    """
    with self.test_scope():
      def foo(c1, r1, v1, c2, v2, r2):
        # c1 and c2 are compile-time constants
        # r1 and r2 are regular tensors
        # v1 and v2 are resource variables
        a = c1 + r1
        b = math_ops.cast(c2, dtypes.float32) + v2
        c = array_ops.slice(v1, c1, c2)
        d = r2 * v2
        return a, b, c, d

      foo = function.defun(foo)

      c1 = [0, 0]
      c2 = array_ops.ones([2], dtype=dtypes.int32)

      r1 = array_ops.ones([2])
      r2 = [[2., 2.], [3., 3.]]

      v1 = resource_variable_ops.ResourceVariable([[1., 2.], [3., 4.]])
      v2 = resource_variable_ops.ResourceVariable([[10., 20.], [30., 40.]])

      a, b, c, d = foo(c1, r1, v1, c2, v2, r2)

      self.assertAllEqual([1, 1], a.numpy())
      self.assertAllEqual([[11., 21.], [31., 41.]], b.numpy())
      self.assertAllEqual([[1.]], c.numpy())
      self.assertAllEqual([[20., 40.], [90., 120.]], d.numpy())

  def testDefunInGradientTape(self):
    with self.test_scope():
      v0 = resource_variable_ops.ResourceVariable(5.0)

      @function.defun
      def f(x):
        x = v0 * v0 * x
        return x

      x = constant_op.constant(3.0)
      with backprop.GradientTape() as tape:
        y = f(x)
      dy = tape.gradient(y, v0)

    self.assertEqual(75, y.numpy())
    self.assertEqual(30, dy.numpy())

  def testGradientTapeInDefun(self):
    with self.test_scope():
      v0 = resource_variable_ops.ResourceVariable(5.0)

      @function.defun
      def f():
        x = constant_op.constant(1.0)
        with backprop.GradientTape() as tape:
          y = v0 * x
        dy = tape.gradient(y, v0)
        return dy

      dy = f()
      self.assertEqual(1.0, dy.numpy())

  def testSliceInDefun(self):
    with self.test_scope():

      @function.defun
      def f(x, y):
        return x[0::2, y:, ...]

      x = array_ops.ones([2, 3, 4], dtype=dtypes.float32)
      y = array_ops.ones([], dtype=dtypes.int32)
      with backprop.GradientTape() as tape:
        tape.watch(x)
        tape.watch(y)
        z = f(x, y)
      dz = tape.gradient(z, x)

      self.assertAllEqual(np.ones([1, 2, 4]), z.numpy())
      self.assertAllEqual((2, 3, 4), dz.shape.as_list())

  def testNestedDefun(self):
    with self.test_scope():

      @function.defun
      def times_two(x):
        return 2. * x

      @function.defun
      def two_x_plus_1(x):
        return times_two(x) + 1.

      x = constant_op.constant([2., 3., 4.])
      y = two_x_plus_1(x)
      self.assertAllEqual([5., 7., 9.], y.numpy())

  def testNestedDefunWithVariable(self):
    with self.test_scope():
      v0 = resource_variable_ops.ResourceVariable(5.0)

      @function.defun
      def g(x):
        x = v0 * x
        return x

      @function.defun
      def f(x):
        x = g(v0 * x)
        return x

      x = constant_op.constant(3.0)
      y = f(x)

    self.assertEqual(75.0, y.numpy())

  def testNestedDefunInGradientTape(self):
    with self.test_scope():
      v0 = resource_variable_ops.ResourceVariable(5.0)

      @function.defun
      def g(x):
        x = v0 * x
        return x

      @function.defun
      def f(x):
        x = g(v0 * x)
        return x

      x = constant_op.constant(3.0)
      with backprop.GradientTape() as tape:
        y = f(x)
      dy = tape.gradient(y, v0)

    self.assertEqual(75, y.numpy())
    self.assertEqual(30, dy.numpy())

  def testNestedDefunInGradientTapeDifferentVars(self):
    with self.test_scope():
      v0 = resource_variable_ops.ResourceVariable(5.0)
      v1 = resource_variable_ops.ResourceVariable(3.0)

      @function.defun
      def g(x):
        x = v1 * x
        return x

      @function.defun
      def f(x):
        x = g(v0 * x)
        return x

      x = constant_op.constant(3.0)
      with backprop.GradientTape(persistent=True) as tape:
        y = f(x)
      dy_v0 = tape.gradient(y, v0)
      dy_v1 = tape.gradient(y, v1)

    self.assertEqual(45, y.numpy())
    self.assertEqual(9, dy_v0.numpy())
    self.assertEqual(15, dy_v1.numpy())

  def testWhileInDefun(self):
    with self.test_scope():
      @def_function.function
      def f(start):
        c = lambda x: math_ops.less(x, 13.0)
        b = lambda x: math_ops.add(x, 1.0)
        return control_flow_ops.while_loop(c, b, [start])

      y = f(constant_op.constant(3.0))
    self.assertEqual(13.0, y.numpy())

  def testAutoGraphWhileInDefun(self):
    with self.test_scope():
      @def_function.function
      def f(start):
        x = start
        while x < 13.0:
          x += 1.0
        return x

      y = f(constant_op.constant(3.0))
    self.assertEqual(13.0, y.numpy())

  def testCondInDefun(self):
    with self.test_scope():
      @def_function.function
      def f(pred, value):
        fn1 = lambda: math_ops.add(value, 1.0)
        fn2 = lambda: math_ops.subtract(value, 1.0)
        return control_flow_ops.cond(pred, fn1, fn2)

      plus_one = f(constant_op.constant(True), constant_op.constant(10.0))
      minus_one = f(constant_op.constant(False), constant_op.constant(10.0))
    self.assertEqual(11.0, plus_one.numpy())
    self.assertEqual(9.0, minus_one.numpy())

  def testAutoGraphCondInDefun(self):
    with self.test_scope():
      @def_function.function
      def f(pred, value):
        if pred:
          return value + 1.0
        else:
          return value - 1.0

      plus_one = f(constant_op.constant(True), constant_op.constant(10.0))
      minus_one = f(constant_op.constant(False), constant_op.constant(10.0))
    self.assertEqual(11.0, plus_one.numpy())
    self.assertEqual(9.0, minus_one.numpy())

  def testScanInDefun(self):
    with self.test_scope():
      elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name='data')
      v = constant_op.constant(2.0, name='v')

      @def_function.function
      def f(y):
        # pylint: disable=unnecessary-lambda
        return functional_ops.scan(
            lambda a, x: math_ops.multiply(a, x), y, initializer=v)
        # pylint: enable=unnecessary-lambda

      r = f(elems)
      self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r))

  def testFeedDeviceMemoryToOpExpectingHostMemory(self):
    @function.defun
    def f(dims, value):
      return array_ops.fill(dims, value)

    with self.test_scope():
      x = constant_op.constant([4], dtype=dtypes.int64)

    y = f(x, 3)
    self.assertAllEqual([3, 3, 3, 3], y)

  def testRequestNotToCompile(self):
    with self.test_scope():
      def f(x):
        with ops.device('device:CPU:0'):
          y = 2.0 * x
        return x, y

      wholly_compiled_f = def_function.function(f)
      op_by_op_f = def_function.function(f, jit_compile=False)

      x = array_ops.identity([0.0, 2.0], name='data')

      # When function is wholly compiled, all outputs will be on the
      # device on which it is run.
      r_x, r_y = wholly_compiled_f(x)
      self.assertAllEqual([0.0, 2.0], r_x)
      self.assertAllEqual([0.0, 4.0], r_y)
      if context.executing_eagerly():
        # backing_device is only available for eager tensors.
        self.assertRegex(r_x.backing_device, self.device)
        self.assertRegex(r_y.backing_device, self.device)

      # When function is executed op-by-op, requested devices will be
      # respected.
      r_x, r_y = op_by_op_f(x)
      self.assertAllEqual([0.0, 2.0], r_x)
      self.assertAllEqual([0.0, 4.0], r_y)
      if context.executing_eagerly():
        # backing_device is only available for eager tensors.
        self.assertRegex(r_x.backing_device, self.device)
        self.assertRegex(r_y.backing_device, 'device:CPU:0')


class ExcessivePaddingTest(xla_test.XLATestCase):
722  """Test that eager execution works with TPU flattened tensors.
723
724  Tensors that would normally be excessively padded when written
725  to TPU memory are reshaped to 1-D flat tensors.
726
727  This test case verifies that such tensors work with eager execution.
728
729  The flattening currently only happens on TPU, but tests should work
730  fine with all backends as flattening is transparent.
731  """

  def testFromConstant(self):
    with self.test_scope():
      # Create a constant of shape [100, 2, 1]. This tensor would be
      # excessively padded on TPU.
      tensor = constant_op.constant(100 * [[[10.0], [2.0]]])
      # Use reduce_sum since it requires working correctly along
      # a particular dimension.
      reduced = math_ops.reduce_sum(tensor, axis=1)
      self.assertAllEqual(100 * [[12.0]], reduced)

  def testFromOperation(self):
    with self.test_scope():
      tensor = array_ops.ones([3, 100, 2, 2])
      reduced = math_ops.reduce_sum(tensor, axis=[0, 2, 3])
      self.assertAllEqual(100 * [12.0], reduced)

  def testAsFunctionInput(self):
    with self.test_scope():

      @function.defun
      def f(x):
        return math_ops.reduce_sum(x, axis=2)

      tensor = constant_op.constant(100 * [[[10.0, 2.0]]])
      reduced = f(tensor)
      self.assertAllEqual(100 * [[12.0]], reduced)

  def testAsFunctionOutput(self):
    with self.test_scope():

      @function.defun
      def f(x):
        return x * constant_op.constant(100 * [[[10.0, 2.0]]])

      y = f(3)
      reduced = math_ops.reduce_sum(y, axis=2)
      self.assertAllEqual(100 * [[36.0]], reduced)


def multiple_tpus():
  devices = context.context().devices()
  return len([d for d in devices if 'device:TPU:' in d]) > 1


class MultiDeviceTest(xla_test.XLATestCase):
  """Test running TPU computation on more than one core."""

  def testBasic(self):
    if not multiple_tpus():
      self.skipTest('MultiDeviceTest requires multiple TPU devices.')

    # Compute 10 on TPU core 0
    with ops.device('device:TPU:0'):
      two = constant_op.constant(2)
      five = constant_op.constant(5)
      ten = two * five
      self.assertAllEqual(10, ten)

    # Compute 6 on TPU core 1
    with ops.device('device:TPU:1'):
      two = constant_op.constant(2)
      three = constant_op.constant(3)
      six = two * three
      self.assertAllEqual(6, six)

    # Copy 10 and 6 to CPU and sum them
    self.assertAllEqual(16, ten + six)


if __name__ == '__main__':
  ops.enable_eager_execution(
      config=config_pb2.ConfigProto(log_device_placement=True))
  googletest.main()