# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for eager execution using XLA."""

import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.layers import convolutional
from tensorflow.python.layers import pooling
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import googletest
from tensorflow.python.training import adam


class EagerTest(xla_test.XLATestCase):

  def testBasic(self):
    with self.test_scope():
      three = constant_op.constant(3)
      five = constant_op.constant(5)
      product = three * five
      self.assertAllEqual(15, product)

  def testGradientTape(self):
    with self.test_scope():

      x = constant_op.constant(1.0)
      y = constant_op.constant(10.0)
      with backprop.GradientTape(persistent=True) as tape:
        tape.watch(x)
        tape.watch(y)
        a = x + y + x * y
      da_dx = tape.gradient(a, x)
      da_dy = tape.gradient(a, y)

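    # a = x + y + x * y, so da/dx = 1 + y = 11 and da/dy = 1 + x = 2.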
    self.assertEqual(11.0, da_dx.numpy())
    self.assertEqual(2.0, da_dy.numpy())

  def testExecuteListOutputLen0(self):
    with self.test_scope():
      empty = constant_op.constant([], dtype=dtypes.float32)
      result = array_ops.unstack(empty, 0)
      self.assertTrue(isinstance(result, list))
      self.assertEqual(0, len(result))

  def testExecuteListOutputLen1(self):
    with self.test_scope():
      split_dim = constant_op.constant(1)
      value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]])
      result = array_ops.split(value, 1, axis=split_dim)
      self.assertTrue(isinstance(result, list))
      self.assertEqual(1, len(result))
      self.assertAllEqual([[0, 1, 2], [3, 4, 5]], result[0])

  def testExecuteListOutputLen3(self):
    with self.test_scope():
      split_dim = constant_op.constant(1)
      value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]])
      result = array_ops.split(value, 3, axis=split_dim)
      self.assertTrue(isinstance(result, list))
      self.assertEqual(3, len(result))
      self.assertAllEqual([[0], [3]], result[0])
      self.assertAllEqual([[1], [4]], result[1])
      self.assertAllEqual([[2], [5]], result[2])

  def testBasicGraph(self):
    # Run some ops eagerly
    with self.test_scope():
      three = constant_op.constant(3)
      five = constant_op.constant(5)
      product = three * five
      self.assertAllEqual(15, product)

    # Run the same ops in graph mode
    with context.graph_mode(), self.session():
      with self.test_scope():
        three = constant_op.constant(3)
        five = constant_op.constant(5)
        product = three * five
        self.assertAllEqual(15, self.evaluate(product))

  def testDegenerateSlices(self):
    with self.test_scope():
      npt = np.arange(1, 19, dtype=np.float32).reshape(3, 2, 3)
      t = constant_op.constant(npt)
      # degenerate by offering a forward interval with a negative stride
      self.assertAllEqual(npt[0:-1:-1, :, :], t[0:-1:-1, :, :])
      # degenerate with a reverse interval with a positive stride
      self.assertAllEqual(npt[-1:0, :, :], t[-1:0, :, :])
      # empty interval in every dimension
      self.assertAllEqual(npt[-1:0, 2:2, 2:3:-1], t[-1:0, 2:2, 2:3:-1])

  def testIdentity(self):
    with self.test_scope():
      self.assertAllEqual(2, array_ops.identity(2))

  def testRandomOps(self):
    with self.test_scope():
      tensor = gen_random_ops.random_uniform((2, 2), dtypes.float32)
      row0 = tensor[0].numpy()
      row1 = tensor[1].numpy()
      # It should be very unlikely for the RNG to generate two equal rows.
      self.assertFalse((row0 == row1).all())

  def testIdentityOnVariable(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(True)
      i = array_ops.identity(v)
    self.assertAllEqual(True, i.numpy())

  def testAssignAddVariable(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(1.0)
      v.assign_add(2.0)
    self.assertEqual(3.0, v.numpy())

  def testReadAssignRead(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(1.0)
      val1 = v.read_value()
      v.assign_add(2.0)
      val2 = v.read_value()
    self.assertEqual(1.0, val1.numpy())
    self.assertEqual(3.0, val2.numpy())

  def testGradient(self):
    def f(x):
      return x

    with self.test_scope():
      grad_fn = backprop.gradients_function(f)
      self.assertAllEqual(2., grad_fn(1., dy=2.)[0])

  def testVariableGradient(self):
    with self.test_scope():
      v0 = resource_variable_ops.ResourceVariable(1.0)

      def f():
        x = v0 * v0
        return x

      grads = backprop.implicit_grad(f)()
    self.assertEqual(2., grads[0][0].numpy())

  def testMultipleVariableReads(self):
    # This test makes sure consecutive variable reads don't copy
    # the underlying memory.
    with self.test_scope():
      # Create a 128MiB variable
      var = resource_variable_ops.ResourceVariable(
          array_ops.ones([32, 1024, 1024]))

      # Read the same variable 100 times. If the underlying tensor
      # is not copied, this is a trivial operation. If it is copied,
      # this will eat over 13GB and OOM.
      values = []
      for _ in range(100):
        values.append(var.value())

  # The shape, shape_n, size, and rank ops are tested here because their
  # execution kernels are distinct from the compile-only tf2xla kernels.

  def testShape(self):
    def const(value):
      return array_ops.shape(
          constant_op.constant(value)).numpy()

    def ones(value):
      return array_ops.shape(
          array_ops.ones(value)).numpy()

    with self.test_scope():
      # Shapes of directly constructed tensors
      self.assertAllEqual([], const(3))
      self.assertAllEqual([3], const([1.0, 2.0, 3.0]))
      self.assertAllEqual([2, 2], const([[1.0, 2.0], [3.0, 4.0]]))
      self.assertAllEqual([2, 1, 2], const([[[1.0, 2.0]], [[3.0, 4.0]]]))

      # Shapes of tensors created by op running on device
      # We make this distinction because directly constructed tensors
      # are treated differently in a few places that can influence shape:
      #  - they always have on_host_tensor
      #  - they and their shapes can be cached
      #  - they end up on device via a copy, instead of as program output
      self.assertAllEqual([], ones([]))
      self.assertAllEqual([3], ones([3]))
      self.assertAllEqual([2, 2], ones([2, 2]))
      self.assertAllEqual([2, 1, 2], ones([2, 1, 2]))

  def testShapeN(self):
    with self.test_scope():
      # Shapes of directly constructed tensors
      shapes = array_ops.shape_n([
          constant_op.constant(1.0),
          constant_op.constant([1.0, 2.0, 3.0]),
          constant_op.constant([[1.0, 2.0], [3.0, 4.0]])])
      self.assertAllEqual(
          [[], [3], [2, 2]],
          [x.numpy().tolist() for x in shapes])

      # Shapes of tensors created by op running on device
      shapes = array_ops.shape_n([
          array_ops.ones([]),
          array_ops.ones([3]),
          array_ops.ones([2, 2])])
      self.assertAllEqual(
          [[], [3], [2, 2]],
          [x.numpy().tolist() for x in shapes])

  def testSize(self):
    with self.test_scope():
      self.assertEqual(
          1, array_ops.size(constant_op.constant(1.0)).numpy())
      self.assertEqual(
          3, array_ops.size(constant_op.constant([1.0, 2.0, 3.0])).numpy())
      self.assertEqual(
          4, array_ops.size(
              constant_op.constant([[1.0, 2.0], [3.0, 4.0]])).numpy())

  def testRank(self):
    with self.test_scope():
      self.assertEqual(
          0, array_ops.rank(constant_op.constant(1.0)).numpy())
      self.assertEqual(
          1, array_ops.rank(constant_op.constant([1.0, 2.0, 3.0])).numpy())
      self.assertEqual(
          2, array_ops.rank(
              constant_op.constant([[1.0, 2.0], [3.0, 4.0]])).numpy())

  def testAdam(self):
    with self.test_scope():
      optimizer = adam.AdamOptimizer(0.1)
      x = resource_variable_ops.ResourceVariable(10.0)
      with backprop.GradientTape() as tape:
        y = x * x
      dy_dx = tape.gradient(y, x)
      optimizer.apply_gradients([(dy_dx, x)])
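      # dy/dx = 2 * x = 20. Adam's first step moves the variable by roughly
      # the learning rate (0.1) regardless of gradient magnitude, so x goes
      # from 10.0 to about 9.9.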
      self.assertAlmostEqual(9.9, x.numpy(), places=3)

  def testAdamSparse(self):
    with ops.device('/cpu:0'):
      # Create 2-D embedding for 3 objects on CPU because sparse/sliced updates
      # are not implemented on TPU.
      embedding_matrix = resource_variable_ops.ResourceVariable(
          array_ops.ones([3, 2]))

    with self.test_scope():
      with backprop.GradientTape() as tape:
        embedding = embedding_ops.embedding_lookup(embedding_matrix, [1])
        y = math_ops.reduce_sum(embedding)
      dy_dx = tape.gradient(y, embedding_matrix)
      self.assertIsInstance(dy_dx, indexed_slices.IndexedSlices)
      optimizer = adam.AdamOptimizer(0.1)
      # The gradient application operations will run on CPU because optimizer
      # updates are always collocated with the variable.
      optimizer.apply_gradients([(dy_dx, embedding_matrix)])

      # This assign_add will run on CPU because when an input to an
      # operation is a resource, this operation is placed on the resource's
      # device by the eager runtime.
      embedding_matrix.assign_add(array_ops.ones([3, 2]))

    self.assertAllClose([[2.0, 2.0],
                         [1.9, 1.9],
                         [2.0, 2.0]], embedding_matrix.numpy())


class EagerFunctionTest(xla_test.XLATestCase):

  def testBasic(self):
    with self.test_scope():
      matmul = function.defun(math_ops.matmul)
      t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
      sq = matmul(t, t, transpose_a=True)
      self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20])

  def testConv(self):
    if 'GPU' in self.device:
      # TODO(b/32333178)
      self.skipTest('Current implementation of RandomStandardNormal kernel '
                    'is very slow on GPU, and has been denylisted.')
    with self.test_scope():
      data_format = 'channels_last'
      conv = convolutional.Conv2D(
          filters=1, kernel_size=2, padding='VALID',
          data_format=data_format, activation=nn_ops.relu,
          kernel_initializer=init_ops.ones_initializer(),
          bias_initializer=init_ops.zeros_initializer())
      pool = pooling.MaxPooling2D(2, 2, data_format=data_format)

      def model(x):
        x = conv(x)
        return pool(x)
      model = function.defun(model)

      x = array_ops.ones([1, 4, 4, 1])
      y = model(x)
      self.assertAllEqual(y.numpy(), [[[[4.]]]])

  def testReadVariable(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(1.0)

      @function.defun
      def f():
        return v.read_value()

      var = f()
      self.assertEqual(1.0, var.numpy())

  def testResourceVariableNoInlineReadWrite(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(1.0)
      w = resource_variable_ops.ResourceVariable(0.0)

      @function.defun_with_attributes(attributes={'_noinline': True})
      def g(x):
        w.assign(w.read_value() + x)
        return v.read_value() + x * w.read_value()

      @function.defun_with_attributes(attributes={'_noinline': True})
      def f():
        return g(1.0) + g(2.0) + g(3.0) + g(4.0) + g(5.0)

      # 1 + 1*1 + 1 + 2*3 + 1 + 3*6 + 1 + 4*10 + 1 + 5*15
      self.assertEqual(145.0, f().numpy())
      self.assertEqual(15.0, w.read_value().numpy())

  def testResourceVariableNoInlineReadOnly(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(10.0)

      @function.defun_with_attributes(attributes={'_noinline': True})
      def g():
        return v.read_value()

      @function.defun_with_attributes(attributes={'_noinline': True})
      def f():
        return g() + g() + g() + g() + g()

      self.assertEqual(50.0, f().numpy())

  def testResourceVariableNoInlineWriteOnly(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(0.0)

      @function.defun_with_attributes(attributes={'_noinline': True})
      def g(x):
        v.assign(x)

      @function.defun_with_attributes(attributes={'_noinline': True})
      def f():
        g(1.0)
        g(2.0)
        g(3.0)
        g(4.0)
        g(5.0)

      f()
      self.assertEqual(5.0, v.read_value().numpy())

  def testUpdateVariable(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable(1.0)

      def f(v):
        v.assign_add(1.0)
        return v

      f = function.defun(f)

      var = f(v)
      self.assertEqual(2.0, var.numpy())

  def testReturnResourceHandle(self):
    with self.test_scope():
      v = resource_variable_ops.ResourceVariable([[1.0, 2.0], [3.0, 4.0]])

      def f(v):
        return v.handle

      f = function.defun(f)
      handle = f(v)
      self.assertAllEqual(v.numpy(),
                          resource_variable_ops.read_variable_op(
                              handle, dtypes.float32).numpy())

  def testReturnMultipleResourceHandles(self):
    with self.test_scope():
      v1 = resource_variable_ops.ResourceVariable(1.25)
      v2 = resource_variable_ops.ResourceVariable(2.0)

      def f(v):
        return v.handle, 3.0 * v, v2.handle, v + v2

      f = function.defun(f)
      v1_handle, v1_times_3, v2_handle, variable_sum = f(v1)
      self.assertAllEqual(v1.numpy(),
                          resource_variable_ops.read_variable_op(
                              v1_handle, dtypes.float32).numpy())
      self.assertEqual(3.75, v1_times_3.numpy())
      self.assertAllEqual(v2.numpy(),
                          resource_variable_ops.read_variable_op(
                              v2_handle, dtypes.float32).numpy())
      self.assertEqual(3.25, variable_sum.numpy())

  def testAllArgumentKinds(self):
    """Test a complex function that takes different argument kinds.

    The tf2xla machinery that translates, compiles, and runs defuns
    classifies arguments into compile-time constants, regular tensors,
    and resources. This test creates a function with a mix of all these
    kinds, and the order of function arguments is intentionally mixed up.

    This also tests the case where the same argument is both a compile-time
    constant and an input to an operation that normally expects its inputs
    to be in device memory - addition in this case.
448    """
449    with self.test_scope():
450      def foo(c1, r1, v1, c2, v2, r2):
451        # c1 and c2 are compile-time constants
452        # r1 and r2 are regular tensors
453        # v1 and v2 are resource variables
454        a = c1 + r1
455        b = math_ops.cast(c2, dtypes.float32) + v2
456        c = array_ops.slice(v1, c1, c2)
457        d = r2 * v2
458        return a, b, c, d
459
460      foo = function.defun(foo)
461
462      c1 = [0, 0]
463      c2 = array_ops.ones([2], dtype=dtypes.int32)
464
465      r1 = array_ops.ones([2])
466      r2 = [[2., 2.], [3., 3.]]
467
468      v1 = resource_variable_ops.ResourceVariable([[1., 2.], [3., 4.]])
469      v2 = resource_variable_ops.ResourceVariable([[10., 20.], [30., 40.]])
470
471      a, b, c, d = foo(c1, r1, v1, c2, v2, r2)
472
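      # a = c1 + r1 = [1, 1]; b adds c2 (cast to float) to v2;
      # c = slice(v1, [0, 0], [1, 1]) = [[1.]]; d = r2 * v2 elementwise.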
      self.assertAllEqual([1, 1], a.numpy())
      self.assertAllEqual([[11., 21.], [31., 41.]], b.numpy())
      self.assertAllEqual([[1.]], c.numpy())
      self.assertAllEqual([[20., 40.], [90., 120.]], d.numpy())

  def testDefunInGradientTape(self):
    with self.test_scope():
      v0 = resource_variable_ops.ResourceVariable(5.0)

      @function.defun
      def f(x):
        x = v0 * v0 * x
        return x

      x = constant_op.constant(3.0)
      with backprop.GradientTape() as tape:
        y = f(x)
      dy = tape.gradient(y, v0)

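    # y = v0**2 * x = 25 * 3 = 75 and dy/dv0 = 2 * v0 * x = 30.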
    self.assertEqual(75, y.numpy())
    self.assertEqual(30, dy.numpy())

  def testGradientTapeInDefun(self):
    with self.test_scope():
      v0 = resource_variable_ops.ResourceVariable(5.0)

      @function.defun
      def f():
        x = constant_op.constant(1.0)
        with backprop.GradientTape() as tape:
          y = v0 * x
        dy = tape.gradient(y, v0)
        return dy

      dy = f()
      self.assertEqual(1.0, dy.numpy())

  def testSliceInDefun(self):
    with self.test_scope():

      @function.defun
      def f(x, y):
        return x[0::2, y:, ...]

      x = array_ops.ones([2, 3, 4], dtype=dtypes.float32)
      y = array_ops.ones([], dtype=dtypes.int32)
      with backprop.GradientTape() as tape:
        tape.watch(x)
        tape.watch(y)
        z = f(x, y)
      dz = tape.gradient(z, x)

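      # x[0::2, 1:, ...] keeps 1 of 2 rows and 2 of 3 columns, so z has
      # shape [1, 2, 4]; the gradient has the same shape as x.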
      self.assertAllEqual(np.ones([1, 2, 4]), z.numpy())
      self.assertAllEqual((2, 3, 4), dz.shape.as_list())

  def testNestedDefun(self):
    with self.test_scope():

      @function.defun
      def times_two(x):
        return 2. * x

      @function.defun
      def two_x_plus_1(x):
        return times_two(x) + 1.

      x = constant_op.constant([2., 3., 4.])
      y = two_x_plus_1(x)
      self.assertAllEqual([5., 7., 9.], y.numpy())

  def testNestedDefunWithVariable(self):
    with self.test_scope():
      v0 = resource_variable_ops.ResourceVariable(5.0)

      @function.defun
      def g(x):
        x = v0 * x
        return x

      @function.defun
      def f(x):
        x = g(v0 * x)
        return x

      x = constant_op.constant(3.0)
      y = f(x)

    self.assertEqual(75.0, y.numpy())

  def testNestedDefunInGradientTape(self):
    with self.test_scope():
      v0 = resource_variable_ops.ResourceVariable(5.0)

      @function.defun
      def g(x):
        x = v0 * x
        return x

      @function.defun
      def f(x):
        x = g(v0 * x)
        return x

      x = constant_op.constant(3.0)
      with backprop.GradientTape() as tape:
        y = f(x)
      dy = tape.gradient(y, v0)

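    # y = v0 * (v0 * x) = 75 and dy/dv0 = 2 * v0 * x = 30.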
    self.assertEqual(75, y.numpy())
    self.assertEqual(30, dy.numpy())

  def testNestedDefunInGradientTapeDifferentVars(self):
    with self.test_scope():
      v0 = resource_variable_ops.ResourceVariable(5.0)
      v1 = resource_variable_ops.ResourceVariable(3.0)

      @function.defun
      def g(x):
        x = v1 * x
        return x

      @function.defun
      def f(x):
        x = g(v0 * x)
        return x

      x = constant_op.constant(3.0)
      with backprop.GradientTape(persistent=True) as tape:
        y = f(x)
      dy_v0 = tape.gradient(y, v0)
      dy_v1 = tape.gradient(y, v1)

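    # y = v1 * v0 * x = 3 * 5 * 3 = 45, so dy/dv0 = v1 * x = 9 and
    # dy/dv1 = v0 * x = 15.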
    self.assertEqual(45, y.numpy())
    self.assertEqual(9, dy_v0.numpy())
    self.assertEqual(15, dy_v1.numpy())

  def testWhileInDefun(self):
    with self.test_scope():
      @def_function.function
      def f(start):
        c = lambda x: math_ops.less(x, 13.0)
        b = lambda x: math_ops.add(x, 1.0)
        return control_flow_ops.while_loop(c, b, [start])

      y = f(constant_op.constant(3.0))
    self.assertEqual(13.0, y.numpy())

  def testAutoGraphWhileInDefun(self):
    with self.test_scope():
      @def_function.function
      def f(start):
        x = start
        while x < 13.0:
          x += 1.0
        return x

      y = f(constant_op.constant(3.0))
    self.assertEqual(13.0, y.numpy())

  def testCondInDefun(self):
    with self.test_scope():
      @def_function.function
      def f(pred, value):
        fn1 = lambda: math_ops.add(value, 1.0)
        fn2 = lambda: math_ops.subtract(value, 1.0)
        return control_flow_ops.cond(pred, fn1, fn2)

      plus_one = f(constant_op.constant(True), constant_op.constant(10.0))
      minus_one = f(constant_op.constant(False), constant_op.constant(10.0))
    self.assertEqual(11.0, plus_one.numpy())
    self.assertEqual(9.0, minus_one.numpy())

  def testAutoGraphCondInDefun(self):
    with self.test_scope():
      @def_function.function
      def f(pred, value):
        if pred:
          return value + 1.0
        else:
          return value - 1.0

      plus_one = f(constant_op.constant(True), constant_op.constant(10.0))
      minus_one = f(constant_op.constant(False), constant_op.constant(10.0))
    self.assertEqual(11.0, plus_one.numpy())
    self.assertEqual(9.0, minus_one.numpy())

  def testScanInDefun(self):
    with self.test_scope():
      elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name='data')
      v = constant_op.constant(2.0, name='v')

      @def_function.function
      def f(y):
        # pylint: disable=unnecessary-lambda
        return functional_ops.scan(
            lambda a, x: math_ops.multiply(a, x), y, initializer=v)
        # pylint: enable=unnecessary-lambda

      r = f(elems)
      self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r))

  def testFeedDeviceMemoryToOpExpectingHostMemory(self):
    @function.defun
    def f(dims, value):
      return array_ops.fill(dims, value)

    with self.test_scope():
      x = constant_op.constant([4], dtype=dtypes.int64)

    y = f(x, 3)
    self.assertAllEqual([3, 3, 3, 3], y)

  def testRequestNotToCompile(self):
    with self.test_scope():
      def f(x):
        with ops.device('device:CPU:0'):
          y = 2.0 * x
        return x, y

      wholly_compiled_f = def_function.function(f)
      op_by_op_f = def_function.function(f, jit_compile=False)

      x = array_ops.identity([0.0, 2.0], name='data')

      # When function is wholly compiled, all outputs will be on the
      # device on which it is run.
      r_x, r_y = wholly_compiled_f(x)
      self.assertAllEqual([0.0, 2.0], r_x)
      self.assertAllEqual([0.0, 4.0], r_y)
      if context.executing_eagerly():
        # backing_device is only available for eager tensors.
        self.assertRegex(r_x.backing_device, self.device)
        self.assertRegex(r_y.backing_device, self.device)

      # When function is executed op-by-op, requested devices will be
      # respected.
      r_x, r_y = op_by_op_f(x)
      self.assertAllEqual([0.0, 2.0], r_x)
      self.assertAllEqual([0.0, 4.0], r_y)
      if context.executing_eagerly():
        # backing_device is only available for eager tensors.
        self.assertRegex(r_x.backing_device, self.device)
        self.assertRegex(r_y.backing_device, 'device:CPU:0')


class ExcessivePaddingTest(xla_test.XLATestCase):
  """Test that eager execution works with TPU flattened tensors.

  Tensors that would normally be excessively padded when written
  to TPU memory are reshaped to 1-D flat tensors.

  This test case verifies that such tensors work with eager execution.

  The flattening currently only happens on TPU, but tests should work
  fine with all backends as flattening is transparent.
  """

  def testFromConstant(self):
    with self.test_scope():
      # Create constant of shape [100, 2, 1]. This tensor would be
      # excessively padded on TPU.
      tensor = constant_op.constant(100 * [[[10.0], [2.0]]])
      # Use reduce_sum since it requires correctly working with
      # a particular dimension.
      reduced = math_ops.reduce_sum(tensor, axis=1)
      self.assertAllEqual(100 * [[12.0]], reduced)

  def testFromOperation(self):
    with self.test_scope():
      tensor = array_ops.ones([3, 100, 2, 2])
      reduced = math_ops.reduce_sum(tensor, axis=[0, 2, 3])
      self.assertAllEqual(100 * [12.0], reduced)

  def testAsFunctionInput(self):
    with self.test_scope():

      @function.defun
      def f(x):
        return math_ops.reduce_sum(x, axis=2)

      tensor = constant_op.constant(100 * [[[10.0, 2.0]]])
      reduced = f(tensor)
      self.assertAllEqual(100 * [[12.0]], reduced)

  def testAsFunctionOutput(self):
    with self.test_scope():

      @function.defun
      def f(x):
        return x * constant_op.constant(100 * [[[10.0, 2.0]]])

      y = f(3)
      reduced = math_ops.reduce_sum(y, axis=2)
      self.assertAllEqual(100 * [[36.0]], reduced)


def multiple_tpus():
  devices = context.context().devices()
  return len([d for d in devices if 'device:TPU:' in d]) > 1


class MultiDeviceTest(xla_test.XLATestCase):
  """Test running TPU computation on more than one core."""

  def testBasic(self):
    if not multiple_tpus():
      self.skipTest('MultiDeviceTest requires multiple TPU devices.')

    # Compute 10 on TPU core 0
    with ops.device('device:TPU:0'):
      two = constant_op.constant(2)
      five = constant_op.constant(5)
      ten = two * five
      self.assertAllEqual(10, ten)

    # Compute 6 on TPU core 1
    with ops.device('device:TPU:1'):
      two = constant_op.constant(2)
      three = constant_op.constant(3)
      six = two * three
      self.assertAllEqual(6, six)

    # Copy 10 and 6 to CPU and sum them
    self.assertAllEqual(16, ten + six)


if __name__ == '__main__':
  ops.enable_eager_execution(
      config=config_pb2.ConfigProto(log_device_placement=True))
  googletest.main()