# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Backend-dependent tests for the Python XLA client."""

import functools
import itertools
import re
import threading
import unittest

from absl import flags
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

from tensorflow.compiler.xla.python import xla_client

# pylint: disable=g-import-not-at-top
try:
  from tensorflow.compiler.xla.python import custom_call_for_test
except ImportError:
  custom_call_for_test = None

bfloat16 = xla_client.bfloat16
ops = xla_client.ops

FLAGS = flags.FLAGS

# We choose to ignore pylint's complaints about complex comprehensions, which we
# use widely for parameterizing tests.
# pylint: disable=g-complex-comprehension


def TestFactory(xla_backend, cloud_tpu=False):
  tests = []

  if not cloud_tpu:
    int_dtypes = [np.int32, np.int64, np.uint32, np.uint64]
    # TODO(phawkins): test np.float16, where supported.
    float_dtypes = [bfloat16, np.float32, np.float64]
    complex_dtypes = [np.complex64, np.complex128]
    standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_]
  else:
    int_dtypes = [np.int32, np.uint32]
    float_dtypes = [np.float32]
    complex_dtypes = [np.complex64]
    standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_]
  dlpack_dtypes = int_dtypes + float_dtypes + [np.bool_]

  class ComputationTest(parameterized.TestCase):
    """Base class for running an XLA Computation through the local client."""

    def setUp(self):
      super(ComputationTest, self).setUp()
      self.backend = xla_backend()

    def _NewComputation(self, name=None):
      if name is None:
        name = self.id()
      return xla_client.XlaBuilder(name)

    def _Execute(self, c, arguments):
      compiled_c = self.backend.compile(c.build())
      return xla_client.execute_with_python_values(
          compiled_c, arguments, backend=self.backend)

    def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected):
      assert expected is not None
      results = self._Execute(c, arguments)
      self.assertLen(results, len(expected))
      for result, e in zip(results, expected):
        # Numpy's comparison methods are a bit too lenient: they treat inputs
        # as "array-like", so a scalar 4 happily compares equal to [[4]]. We
        # want to be stricter, so we assert shapes as well.
        self.assertEqual(np.asanyarray(result).shape, np.asanyarray(e).shape)
        assert_func(result, e)

    def _ExecuteAndCompareExact(self, c, arguments=(), expected=None):
      self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments,
                                 expected)

    def _ExecuteAndCompareClose(self,
                                c,
                                arguments=(),
                                expected=None,
                                rtol=1e-4,
                                atol=0):
      self._ExecuteAndAssertWith(
          functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol),
          c, arguments, expected)
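
    # The tests below follow the pattern these helpers support: build a
    # computation on an XlaBuilder via `ops`, then execute and compare, e.g.:
    #
    #   c = self._NewComputation()
    #   ops.Add(ops.Constant(c, np.float32(1.0)),
    #           ops.Constant(c, np.float32(2.0)))
    #   self._ExecuteAndCompareClose(c, expected=[3.0])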

  def NumpyArrayF32(*args, **kwargs):
    """Convenience wrapper to create Numpy arrays with a np.float32 dtype."""
    return np.array(*args, dtype=np.float32, **kwargs)

  def NumpyArrayF64(*args, **kwargs):
    """Convenience wrapper to create Numpy arrays with a np.float64 dtype."""
    return np.array(*args, dtype=np.float64, **kwargs)

  def NumpyArrayS32(*args, **kwargs):
    """Convenience wrapper to create Numpy arrays with a np.int32 dtype."""
    return np.array(*args, dtype=np.int32, **kwargs)

  def NumpyArrayBool(*args, **kwargs):
    """Convenience wrapper to create Numpy arrays with a np.bool_ dtype."""
    return np.array(*args, dtype=np.bool_, **kwargs)

  class ComputationPrinting(absltest.TestCase):

    def setUp(self):
      super(ComputationPrinting, self).setUp()
      self.backend = xla_backend()

    def ExampleComputation(self):
      builder = xla_client.XlaBuilder("acomputation")
      p0 = ops.Parameter(builder, 0, xla_client.shape_from_pyval(np.float32(0)))
      p1 = ops.Parameter(
          builder, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32)))
      x = ops.Mul(p0, p1)
      ops.Add(x, x)
      return builder.build()

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testCompiledHloModuleToHloText(self):
      computation = self.ExampleComputation()
      executable = self.backend.compile(computation)
      hlo_modules = executable.hlo_modules()
      self.assertLen(hlo_modules, 1)
      hlo_text = hlo_modules[0].to_string()
      self.assertTrue(hlo_text.startswith("HloModule acomputation"))
      self.assertIn("fusion", hlo_text)

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testFlopEstimate(self):
      computation = self.ExampleComputation()
      properties = xla_client._xla.hlo_module_cost_analysis(
          self.backend, computation.as_hlo_module())
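      # ExampleComputation multiplies a scalar by a length-4 vector and adds
      # the product to itself: 4 multiplies plus 4 adds, hence 8 flops.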
      self.assertEqual(properties["flops"], 8.0)

  tests.append(ComputationPrinting)

  class ComputationsWithConstantsTest(ComputationTest):
    """Tests focusing on Constant ops."""

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in int_dtypes + float_dtypes)
    def testConstantScalarSum(self, dtype):
      if dtype == np.int8 and self.backend.platform == "tpu":
        self.skipTest("TPU doesn't support int8")
      c = self._NewComputation()
      ops.Add(ops.Constant(c, dtype(1.11)), ops.Constant(c, dtype(3.14)))
      self._ExecuteAndCompareClose(c, expected=[dtype(1.11) + dtype(3.14)])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testConstantVectorMul(self, dtype):
      c = self._NewComputation()
      ops.Mul(
          ops.Constant(c, np.array([2.5, 3.3, -1.2, 0.7], dtype)),
          ops.Constant(c, np.array([-1.2, 2, -2, -3], dtype)))
      self._ExecuteAndCompareClose(
          c, expected=[[-3, 6.6, 2.4, -2.1]], rtol=3e-3)

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testConstantVectorScalarDiv(self, dtype):
      c = self._NewComputation()
      ops.Div(
          ops.Constant(c, np.array([1.5, 2.5, 3.0, -10.8], dtype=dtype)),
          ops.Constant(c, dtype(2.0)))
      self._ExecuteAndCompareClose(
          c, expected=[[0.75, 1.25, 1.5, -5.4]], rtol=2e-3)

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testConstantVectorScalarPow(self, dtype):
      c = self._NewComputation()
      ops.Pow(
          ops.Constant(c, np.array([1.5, 2.5, 3.0], dtype=dtype)),
          ops.Constant(c, dtype(2.)))
      self._ExecuteAndCompareClose(c, expected=[[2.25, 6.25, 9.]])

    def testIota(self):
      c = self._NewComputation()
      ops.Iota(c, xla_client.PrimitiveType.F32, 10)
      self._ExecuteAndCompareExact(
          c, expected=[np.arange(10, dtype=np.float32)])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in int_dtypes)
    def testBroadcastedIota(self, dtype):
      c = self._NewComputation()
      shape = xla_client.Shape.array_shape(
          xla_client.dtype_to_etype(dtype), (2, 3))
      ops.Iota(c, shape, 1)
      expected = np.array([[0, 1, 2], [0, 1, 2]], dtype=dtype)
      self._ExecuteAndCompareExact(c, expected=[expected])

    def testBooleanAnd(self):
      c = self._NewComputation()
      ops.And(
          ops.Constant(c, NumpyArrayBool([True, False, True, False])),
          ops.Constant(c, NumpyArrayBool([True, True, False, False])))
      self._ExecuteAndCompareExact(c, expected=[[True, False, False, False]])

    def testBooleanOr(self):
      c = self._NewComputation()
      ops.Or(
          ops.Constant(c, NumpyArrayBool([True, False, True, False])),
          ops.Constant(c, NumpyArrayBool([True, True, False, False])))
      self._ExecuteAndCompareExact(c, expected=[[True, True, True, False]])

    def testBooleanXor(self):
      c = self._NewComputation()
      ops.Xor(
          ops.Constant(c, NumpyArrayBool([True, False, True, False])),
          ops.Constant(c, NumpyArrayBool([True, True, False, False])))
      self._ExecuteAndCompareExact(c, expected=[[False, True, True, False]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testSum2D(self, dtype):
      c = self._NewComputation()
      ops.Add(
          ops.Constant(c, np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)),
          ops.Constant(c, np.array([[1, -1, 1], [-1, 1, -1]], dtype=dtype)))
      self._ExecuteAndCompareClose(c, expected=[[[2, 1, 4], [3, 6, 5]]])

    def testShiftLeft(self):
      c = self._NewComputation()
      ops.ShiftLeft(
          ops.Constant(c, NumpyArrayS32([3])),
          ops.Constant(c, NumpyArrayS32([2])))
      self._ExecuteAndCompareClose(c, expected=[[12]])

    def testShiftRightArithmetic(self):
      c = self._NewComputation()
      ops.ShiftRightArithmetic(
          ops.Constant(c, NumpyArrayS32([-2])),
          ops.Constant(c, NumpyArrayS32([1])))
      self._ExecuteAndCompareClose(c, expected=[[-1]])

    def testShiftRightLogical(self):
      c = self._NewComputation()
      ops.ShiftRightLogical(
          ops.Constant(c, NumpyArrayS32([-1])),
          ops.Constant(c, NumpyArrayS32([1])))
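      # A logical (unsigned) right shift of int32 -1 (all bits set) by one
      # yields 0x7FFFFFFF == 2**31 - 1, unlike the arithmetic shift above.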
      self._ExecuteAndCompareClose(c, expected=[[2**31 - 1]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testSum2DWith1DBroadcastDim0(self, dtype):
      # sum of a 2D array with a 1D array where the latter is replicated across
      # dimension 0 to match the former's shape.
      c = self._NewComputation()
      ops.Add(
          ops.Constant(c,
                       np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                                dtype=dtype)),
          ops.Constant(c, np.array([10, 20, 30], dtype=dtype)),
          broadcast_dimensions=(0,))
      self._ExecuteAndCompareClose(
          c, expected=[[[11, 12, 13], [24, 25, 26], [37, 38, 39]]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testSum2DWith1DBroadcastDim1(self, dtype):
      # sum of a 2D array with a 1D array where the latter is replicated across
      # dimension 1 to match the former's shape.
      c = self._NewComputation()
      ops.Add(
          ops.Constant(c,
                       np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                                dtype=dtype)),
          ops.Constant(c, np.array([10, 20, 30], dtype=dtype)),
          broadcast_dimensions=(1,))
      self._ExecuteAndCompareClose(
          c, expected=[[[11, 22, 33], [14, 25, 36], [17, 28, 39]]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testConstantAxpy(self, dtype):
      c = self._NewComputation()
      ops.Add(
          ops.Mul(
              ops.Constant(c, dtype(2)),
              ops.Constant(c, np.array([2.2, 3.3, 4.4, 5.5], dtype=dtype))),
          ops.Constant(c, np.array([100, -100, 200, -200], dtype)))
      self._ExecuteAndCompareClose(
          c, expected=[[104.4, -93.4, 208.8, -189]], rtol=2e-3)

    def testCustomCall(self):
      if self.backend.platform != "cpu":
        self.skipTest("Test requires cpu platform")
      if custom_call_for_test is None:
        self.skipTest("Test requires the custom_call_for_test module")
      c = self._NewComputation()
      for name, fn in custom_call_for_test.cpu_custom_call_targets.items():
        xla_client.register_custom_call_target(name, fn, platform="cpu")
      ops.CustomCallWithLayout(
          c,
          b"test_subtract_f32",
          operands=[
              ops.Constant(c, np.float32(1.25)),
              ops.Constant(c, np.float32(0.5))
          ],
          shape_with_layout=xla_client.Shape.array_shape(
              np.dtype(np.float32), (), ()),
          operand_shapes_with_layout=[
              xla_client.Shape.array_shape(np.dtype(np.float32), (), ()),
              xla_client.Shape.array_shape(np.dtype(np.float32), (), ()),
          ])
      self._ExecuteAndCompareClose(c, expected=[0.75])

  tests.append(ComputationsWithConstantsTest)

  class ComputationFromProtoTest(absltest.TestCase):
    """Test computation execution from HLO proto."""

    def setUp(self):
      super(ComputationFromProtoTest, self).setUp()
      self.backend = xla_backend()

    def testExecuteFromProto(self):
      # Build the HLO proto
      b = xla_client.XlaBuilder("computation")
      ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2)))
      serialized_proto = b.build().as_serialized_hlo_module_proto()

      # Load and execute the proto
      c = xla_client.XlaComputation(serialized_proto)
      ans, = xla_client.execute_with_python_values(
          self.backend.compile(c), (), backend=self.backend)
      np.testing.assert_equal(ans, np.int32(3))

  tests.append(ComputationFromProtoTest)

  class ParametersTest(ComputationTest):
    """Tests focusing on Parameter ops and argument-passing."""

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in int_dtypes)
    def testScalarTimesVector(self, dtype):
      c = self._NewComputation()
      arg0 = np.array(3, dtype=dtype)
      arg1 = np.array([10, 15, -2, 7], dtype=dtype)
      p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0))
      p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1))
      ops.Mul(p0, p1)
      self._ExecuteAndCompareExact(
          c, arguments=[arg0, arg1], expected=[arg0 * arg1])

    # TODO(phawkins): test comparison harness doesn't support bfloat16
    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes if dtype != bfloat16)
    def testScalarMinusVectorExplicitNumbering(self, dtype):
      # Use explicit numbering and pass parameter_num first. Sub is used since
      # it's not commutative and can help catch parameter reversal within the
      # computation.
      c = self._NewComputation()
      arg0 = np.array(2.0, dtype=dtype)
      arg1 = np.array([-2.3, 3.3, -4.3, 5.3], dtype=dtype)
      p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1))
      p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0))
      ops.Sub(p1, p0)
      self._ExecuteAndCompareClose(
          c, arguments=[arg0, arg1], expected=[arg1 - arg0])

  tests.append(ParametersTest)

  class BufferTest(ComputationTest):
    """Tests focusing on execution with Buffers."""

    def testConstantSum(self):
      c = self._NewComputation()
      ops.Add(
          ops.Constant(c, np.float32(1.11)), ops.Constant(c, np.float32(3.14)))
      self._ExecuteAndCompareClose(c, expected=[4.25])

    def testOneParameterSum(self):
      c = self._NewComputation()
      ops.Add(
          ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))),
          ops.Constant(c, np.float32(3.14)))
      self._ExecuteAndCompareClose(
          c, arguments=[NumpyArrayF32(1.11)], expected=[4.25])

    def testTwoParameterSum(self):
      c = self._NewComputation()
      ops.Add(
          ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))),
          ops.Parameter(c, 1, xla_client.shape_from_pyval(NumpyArrayF32(0.))))
      self._ExecuteAndCompareClose(
          c,
          arguments=[NumpyArrayF32(1.11),
                     NumpyArrayF32(3.14)],
          expected=[4.25])

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testCannotCallWithDeletedBuffers(self):
      c = self._NewComputation()
      ops.Add(
          ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))),
          ops.Constant(c, np.float32(3.14)))
      arg = NumpyArrayF32(1.11)
      compiled_c = self.backend.compile(c.build())
      arg_buffer = self.backend.buffer_from_pyval(arg)
      arg_buffer.delete()
      with self.assertRaises(RuntimeError):
        compiled_c.execute([arg_buffer])

    def testXlaShape(self):
      pyval = np.array([[1., 2.]], np.float32)
      local_buffer = self.backend.buffer_from_pyval(pyval)
      xla_shape = local_buffer.xla_shape()
      self.assertEqual(xla_shape.dimensions(), (1, 2))
      self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32))

    def testBlockHostUntilReadyWorks(self):
      arg = np.array([[1., 2.]], np.float32)
      arg_buffer = self.backend.buffer_from_pyval(arg)
      arg_buffer.block_host_until_ready()
      # This test merely checks that nothing goes awry when we call
      # block_host_until_ready(); it's difficult to test anything else.

    def testBlockHostUntilReadyRaisesOnDeletedBuffer(self):
      arg = np.array([[1., 2.]], np.float32)
      buffer = self.backend.buffer_from_pyval(arg)
      buffer.delete()
      with self.assertRaisesRegex(
          RuntimeError,
          re.escape(
              "BlockHostUntilReady() called on deleted or donated buffer")):
        buffer.block_host_until_ready()

    def testDeviceArrayBaseSignatures(self):
      # When extending `DeviceArrayBase`, the object behaves as a `DeviceArray`
      # and thus needs to correctly implement the following methods.
      arg = np.array([[1., 2., 3.]], np.float32)
      buffer = self.backend.buffer_from_pyval(arg)
      if not isinstance(buffer, xla_client.DeviceArrayBase):
        raise unittest.SkipTest(
            "The object of type {} does not extend DeviceArrayBase".format(
                type(buffer)))

      self.assertEqual(buffer.__array_priority__, 100)
      self.assertEqual(buffer.shape, (1, 3))
      self.assertEqual(buffer.dtype, np.float32)
      self.assertEqual(buffer.size, 3)
      self.assertEqual(buffer.ndim, 2)

      self.assertIs(buffer, buffer.block_until_ready())
      buffer.delete()
      with self.assertRaises(RuntimeError):
        buffer.block_until_ready()

    def testOnDeviceSizeInBytes(self):
      if not isinstance(self.backend, xla_client.Client):
        self.skipTest("TPU Driver doesn't support OnDeviceSizeInBytes.")
      arg0 = np.array([])
      arg1 = np.array([[0., 1., 2.]], np.float32)
      arg2 = np.array([[3., 4., 5.]], bfloat16)
      arg0_buffer = self.backend.buffer_from_pyval(arg0)
      arg1_buffer = self.backend.buffer_from_pyval(arg1)
      arg2_buffer = self.backend.buffer_from_pyval(arg2)
      self.assertEqual(arg0_buffer.on_device_size_in_bytes(), 0)
      # OnDeviceSizeInBytes varies depending on the platform. Confirm the
      # values are reasonable.
      self.assertGreater(arg1_buffer.on_device_size_in_bytes(), 0)
      self.assertGreater(arg2_buffer.on_device_size_in_bytes(), 0)

    def testLiveBuffers(self):
      if not isinstance(self.backend, xla_client.Client):
        self.skipTest("TPU Driver doesn't support LiveBuffers().")
      self.assertEmpty(self.backend.live_buffers())
      arg0 = np.array([])
      arg1 = np.array([[0., 1., 2.]], np.float32)
      arg2 = np.array([[3., 4., 5.]], bfloat16)
      arg0_buffer = self.backend.buffer_from_pyval(arg0)
      arg1_buffer = self.backend.buffer_from_pyval(arg1)
      arg2_buffer = self.backend.buffer_from_pyval(arg2)
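      # live_buffers() lists the most recently created buffer first; the
      # ordering assertions below rely on that.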
      self.assertLen(self.backend.live_buffers(), 3)
      self.assertIs(self.backend.live_buffers()[0], arg2_buffer)
      self.assertIs(self.backend.live_buffers()[1], arg1_buffer)
      self.assertIs(self.backend.live_buffers()[2], arg0_buffer)

      arg1_buffer.delete()
      self.assertLen(self.backend.live_buffers(), 2)
      self.assertIs(self.backend.live_buffers()[0], arg2_buffer)
      self.assertIs(self.backend.live_buffers()[1], arg0_buffer)

      arg0_buffer.delete()
      arg2_buffer.delete()
      self.assertEmpty(self.backend.live_buffers())

    def testCopyToHost(self):
      arg0 = np.array([[1., 2.]], np.float32)
      arg1 = np.array([[3., 4.]], np.float32)
      arg0_buffer = self.backend.buffer_from_pyval(arg0)
      arg1_buffer = self.backend.buffer_from_pyval(arg1)
      # Prefetch two buffers using copy_to_host_async, and then retrieve their
      # values using to_py.
      arg0_buffer.copy_to_host_async()
      arg0_buffer.copy_to_host_async()  # Duplicate calls don't do anything.
      arg1_buffer.copy_to_host_async()
      np.testing.assert_equal(arg0, arg0_buffer.to_py())
      np.testing.assert_equal(arg1, arg1_buffer.to_py())
      # copy_to_host_async does nothing after to_py is called.
      arg0_buffer.copy_to_host_async()
      np.testing.assert_equal(arg0, arg0_buffer.to_py())

    def testDevice(self):
      x = np.arange(8, dtype=np.int32)
      for device in self.backend.local_devices():
        buf = self.backend.buffer_from_pyval(x, device=device)
        self.assertEqual(buf.device(), device)
        np.testing.assert_equal(x, buf.to_py())

    def testStandardTypes(self):
      for dtype in standard_dtypes:
        if dtype == bfloat16 or dtype == np.complex128:
          continue
        arr = self.backend.buffer_from_pyval(np.array([0, 1], dtype))
        arr = arr.to_py()
        self.assertEqual(dtype, type(arr[0]))

  tests.append(BufferTest)

  class SingleOpTest(ComputationTest):
562    """Tests for single ops.
563
564    The goal here is smoke testing - to exercise the most basic functionality of
565    single XLA ops. As minimal as possible number of additional ops are added
566    around the op being tested.
567    """

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testConcatenate(self, dtype):
      c = self._NewComputation()
      args = (
          ops.Constant(c, np.array([1.0, 2.0, 3.0], dtype=dtype)),
          ops.Constant(c, np.array([4.0, 5.0, 6.0], dtype=dtype)),
      )
      ops.ConcatInDim(c, args, dimension=0)
      self._ExecuteAndCompareExact(
          c, expected=[np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=dtype)])

    # pyformat: disable
    @parameterized.named_parameters({
        "testcase_name": "_{}_{}".format(src_dtype.__name__,
                                         dst_dtype.__name__),
        "src_dtype": src_dtype,
        "dst_dtype": dst_dtype,
    } for src_dtype, dst_dtype in itertools.permutations(
        [np.bool_, np.int32, np.int64, np.float32, np.float64], 2))
    # pyformat: enable
    def testConvertElementType(self, src_dtype, dst_dtype):
      if ((src_dtype in [np.int64, np.float64] or
           dst_dtype in [np.int64, np.float64]) and
          self.backend.platform == "tpu"):
        self.skipTest("TPU doesn't support float64")
      c = self._NewComputation()
      x = np.array([0, 1, 0, 0, 1], dtype=src_dtype)
      ops.ConvertElementType(
          ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype))

      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      self.assertLen(result, 1)
      expected = np.array(x, dtype=dst_dtype)

      self.assertEqual(result[0].shape, expected.shape)
      self.assertEqual(result[0].dtype, expected.dtype)
      np.testing.assert_equal(result[0], expected)

    # pyformat: disable
    @parameterized.named_parameters(
        {
            "testcase_name": "_{}_{}".format(src_dtype.__name__,
                                             dst_dtype.__name__),
            "src_dtype": src_dtype,
            "dst_dtype": dst_dtype,
        }
        for dtypes in [[np.int32, np.float32], [np.int64, np.float64]]
        for src_dtype, dst_dtype in itertools.permutations(dtypes, 2))
    # pyformat: enable
    def testBitcastConvertType(self, src_dtype, dst_dtype):
      if (np.float64 in (src_dtype, dst_dtype) and
          self.backend.platform == "tpu"):
        self.skipTest("TPU doesn't support float64")
      c = self._NewComputation()
      x = np.array([0, 1, 0, 0, 1], dtype=src_dtype)
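      # Unlike ConvertElementType above, BitcastConvertType reinterprets the
      # underlying bits in the new type (the numpy analogue is ndarray.view,
      # used for the expected value below) rather than converting values.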
      ops.BitcastConvertType(
          ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype))

      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      self.assertLen(result, 1)
      expected = x.view(dst_dtype)

      self.assertEqual(result[0].shape, expected.shape)
      self.assertEqual(result[0].dtype, expected.dtype)
      np.testing.assert_equal(result[0], expected)

    # TODO(b/123523486) implement AllToAll on CPU
    def DISABLED_testAllToAllOneReplica(self):
      samples = [
          NumpyArrayF32([97.0]),
          NumpyArrayF32([64.0, 117.0]),
          NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]),
      ]
      for lhs in samples[:1]:
        c = self._NewComputation()
        ops.AllToAll(ops.Constant(c, lhs), 0, 0)
        self._ExecuteAndCompareExact(c, expected=[lhs])

    def testCrossReplicaSumOneReplica(self):
      samples = [
          NumpyArrayF32(42.0),
          NumpyArrayF32([97.0]),
          NumpyArrayF32([64.0, 117.0]),
          NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]),
      ]
      for lhs in samples:
        c = self._NewComputation()
        ops.CrossReplicaSum(ops.Constant(c, lhs))
        self._ExecuteAndCompareExact(c, expected=[lhs])

    def testReplicaId(self):
      c = self._NewComputation()
      _ = ops.ReplicaId(c)
      self._ExecuteAndCompareExact(c, expected=[0])

    def testCrossReplicaSumOneReplicaWithSingletonGroup(self):
      samples = [
          NumpyArrayF32(42.0),
          NumpyArrayF32([97.0]),
          NumpyArrayF32([64.0, 117.0]),
          NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]),
      ]
      for lhs in samples:
        c = self._NewComputation()
        ops.CrossReplicaSum(
            ops.Constant(c, lhs), xla_client.make_replica_groups([[0]]))
        self._ExecuteAndCompareExact(c, expected=[lhs])

    # TODO(phawkins): np.dot implementation doesn't support bfloat16
    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes if dtype != bfloat16)
    def testDotMatrixVector(self, dtype):
      c = self._NewComputation()
      lhs = np.array([[2.0, 3.0], [4.0, 5.0]], dtype=dtype)
      rhs = np.array([[10.0], [20.0]], dtype=dtype)
      ops.Dot(ops.Constant(c, lhs), ops.Constant(c, rhs))
      self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)])

    # TODO(phawkins): np.dot implementation doesn't support bfloat16
    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes if dtype != bfloat16)
    def testDotMatrixMatrix(self, dtype):
      c = self._NewComputation()
      lhs = np.array([[2.0, 3.0], [4.0, 5.0]], dtype=dtype)
      rhs = np.array([[10.0, 20.0], [100.0, 200.0]], dtype=dtype)
      ops.Dot(ops.Constant(c, lhs), ops.Constant(c, rhs))
      self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)])

    def testDotGeneral(self):
      c = self._NewComputation()
      rng = np.random.RandomState(0)
      lhs = NumpyArrayF32(rng.randn(10, 3, 4))
      rhs = NumpyArrayF32(rng.randn(10, 4, 5))
      dimension_numbers = xla_client.make_dot_dimension_numbers(
          (([2], [1]), ([0], [0])))
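      # (([2], [1]), ([0], [0])) contracts lhs dimension 2 with rhs dimension
      # 1 and treats dimension 0 of both operands as a batch dimension, i.e. a
      # batched matrix multiply, matching np.matmul below.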
      ops.DotGeneral(
          ops.Constant(c, lhs), ops.Constant(c, rhs), dimension_numbers)
      self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6)

    def testDotGeneralWithDotDimensionNumbersProto(self):
      c = self._NewComputation()
      rng = np.random.RandomState(0)
      lhs = NumpyArrayF32(rng.randn(10, 3, 4))
      rhs = NumpyArrayF32(rng.randn(10, 4, 5))

      dimension_numbers = xla_client.DotDimensionNumbers()
      dimension_numbers.lhs_contracting_dimensions.append(2)
      dimension_numbers.rhs_contracting_dimensions.append(1)
      dimension_numbers.lhs_batch_dimensions.append(0)
      dimension_numbers.rhs_batch_dimensions.append(0)

      ops.DotGeneral(
          ops.Constant(c, lhs), ops.Constant(c, rhs), dimension_numbers)
      self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6)

    def testDotGeneralWithPrecisionConfig(self):
      c = self._NewComputation()
      rng = np.random.RandomState(0)
      lhs = NumpyArrayF32(rng.randn(10, 3, 4))
      rhs = NumpyArrayF32(rng.randn(10, 4, 5))
      dimension_numbers = xla_client.make_dot_dimension_numbers(
          (([2], [1]), ([0], [0])))
      config = xla_client.PrecisionConfig()
      config.operand_precision.append(config.Precision.HIGH)
      config.operand_precision.append(config.Precision.HIGHEST)
      ops.DotGeneral(
          ops.Constant(c, lhs),
          ops.Constant(c, rhs),
          dimension_numbers,
          precision_config=config)
      self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6)

    def testConvGeneralDilatedF32(self):
      c = self._NewComputation()
      a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
      lhs = a(1, 1, 2, 3)
      rhs = a(1, 1, 1, 2) * 10
      strides = [1, 1]
      pads = [(1, 0), (0, 1)]
      lhs_dilation = (2, 1)
      rhs_dilation = (1, 1)
      dimension_numbers = xla_client.make_convolution_dimension_numbers(
          ("NCHW", "OIHW", "NCHW"), 2)
      ops.ConvGeneralDilated(
          ops.Constant(c, lhs), ops.Constant(c, rhs), strides, pads,
          lhs_dilation, rhs_dilation, dimension_numbers)
      result = np.array([[[
          [0., 0., 0.],
          [10., 20., 0.],
          [0., 0., 0.],
          [40., 50., 0.],
      ]]])
      self._ExecuteAndCompareClose(c, expected=[result])

    def testConvGeneralDilatedF32WithPrecisionConfig(self):
      c = self._NewComputation()
      a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
      lhs = a(1, 1, 2, 3)
      rhs = a(1, 1, 1, 2) * 10
      strides = [1, 1]
      pads = [(1, 0), (0, 1)]
      lhs_dilation = (2, 1)
      rhs_dilation = (1, 1)
      dimension_numbers = xla_client.make_convolution_dimension_numbers(
          ("NCHW", "OIHW", "NCHW"), 2)
      config = xla_client.PrecisionConfig()
      config.operand_precision.append(config.Precision.HIGHEST)
      config.operand_precision.append(config.Precision.DEFAULT)
      ops.ConvGeneralDilated(
          ops.Constant(c, lhs),
          ops.Constant(c, rhs),
          strides,
          pads,
          lhs_dilation,
          rhs_dilation,
          dimension_numbers,
          precision_config=config)
      result = np.array([[[
          [0., 0., 0.],
          [10., 20., 0.],
          [0., 0., 0.],
          [40., 50., 0.],
      ]]])
      self._ExecuteAndCompareClose(c, expected=[result])

    def testConvGeneralDilatedPermutedF32(self):
      c = self._NewComputation()
      a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
      lhs = a(1, 1, 2, 3)
      rhs = a(1, 1, 1, 2) * 10
      strides = [1, 1]
      pads = [(1, 0), (0, 1)]
      lhs_dilation = (2, 1)
      rhs_dilation = (1, 1)

      dimension_numbers = xla_client.make_convolution_dimension_numbers(
          ("NHWC", "OIHW", "CWNH"), 2)
      ops.ConvGeneralDilated(
          ops.Constant(c, np.transpose(lhs,
                                       (0, 2, 3, 1))), ops.Constant(c, rhs),
          strides, pads, lhs_dilation, rhs_dilation, dimension_numbers)
      result = np.array([[[[0., 0., 0.], [10., 20., 0.], [0., 0., 0.],
                           [40., 50., 0.]]]])
      self._ExecuteAndCompareClose(
          c, expected=[np.transpose(result, (1, 3, 0, 2))])

    def testConvGeneralDilatedGroupedConvolutionF32(self):
      c = self._NewComputation()
      a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
      lhs = a(1, 2, 2, 3)
      rhs = a(2, 1, 1, 2) * 10
      strides = [1, 1]
      pads = [(1, 0), (0, 1)]
      lhs_dilation = (2, 1)
      rhs_dilation = (1, 1)
      dimension_numbers = xla_client.make_convolution_dimension_numbers(
          ("NCHW", "OIHW", "NCHW"), 2)
      feature_group_count = 2
      ops.ConvGeneralDilated(
          ops.Constant(c, lhs), ops.Constant(c, rhs), strides, pads,
          lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count)
      result = np.array([[[
          [0., 0., 0.],
          [10., 20., 0.],
          [0., 0., 0.],
          [40., 50., 0.],
      ], [
          [0., 0., 0.],
          [330., 380., 160.],
          [0., 0., 0.],
          [480., 530., 220.],
      ]]])
      self._ExecuteAndCompareClose(c, expected=[result])

    def testBooleanNot(self):
      c = self._NewComputation()
      arr = NumpyArrayBool([True, False, True])
      ops.Not(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[~arr])

    def testPopulationCount(self):
      c = self._NewComputation()
      arr = NumpyArrayS32([3, 0, 1])
      ops.PopulationCount(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.array([2, 0, 1])])

    def testCountLeadingZeros(self):
      c = self._NewComputation()
      arr = NumpyArrayS32([0x7FFF, 0x12345678])
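      # 0x7FFF occupies 15 bits and 0x12345678 occupies 29 bits, leaving 17
      # and 3 leading zeros respectively in a 32-bit integer.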
      ops.Clz(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[[17, 3]])

    def testExp(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Exp(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.exp(arr)])

    def testExpm1(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Expm1(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.expm1(arr)])

    def testRound(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Round(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.round(arr)])

    def testLog(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Log(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.log(arr)])

    def testLog1p(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Log1p(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.log1p(arr)])

    def testNeg(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Neg(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[-arr])

    def testFloor(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Floor(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.floor(arr)])

    def testCeil(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Ceil(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.ceil(arr)])

    def testAbs(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, -12.1, 2.4, -1.])
      ops.Abs(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.abs(arr)])

    def testTanhF32(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([-0.2, 3.3, 12.1, 0.1, 0.0001])
      ops.Tanh(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.tanh(arr)])

    def testTanhF64(self):
      if self.backend.platform == "tpu":
        self.skipTest("TPU doesn't support 64bit tanh")
      c = self._NewComputation()
      arr = NumpyArrayF64([-0.2, 3.3, 12.1, 0.1, 0.0001])
      ops.Tanh(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.tanh(arr)], rtol=1e-12)

    def testTranspose(self):

      def _TransposeAndTest(array, permutation):
        c = self._NewComputation()
        ops.Transpose(ops.Constant(c, array), permutation)
        expected = np.transpose(array, permutation)
        self._ExecuteAndCompareClose(c, expected=[expected])

      _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [0, 1])
      _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [1, 0])
      _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [0, 1])
      _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [1, 0])

      arr = np.random.RandomState(0).randn(2, 3, 4).astype(np.float32)
      for permutation in itertools.permutations(range(arr.ndim)):
        _TransposeAndTest(arr, permutation)
        _TransposeAndTest(np.asfortranarray(arr), permutation)

    def testEq(self):
      c = self._NewComputation()
      ops.Eq(
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4])),
          ops.Constant(c, NumpyArrayS32([4, 2, 3, 1])))
      self._ExecuteAndCompareExact(c, expected=[[False, True, True, False]])

    def testNe(self):
      c = self._NewComputation()
      ops.Ne(
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4])),
          ops.Constant(c, NumpyArrayS32([4, 2, 3, 1])))
      self._ExecuteAndCompareExact(c, expected=[[True, False, False, True]])

      ops.Ne(
          ops.Constant(c, NumpyArrayF32([-2.0, 0.0,
                                         float("nan"),
                                         float("nan")])),
          ops.Constant(c, NumpyArrayF32([2.0, -0.0, 1.0,
                                         float("nan")])))
      self._ExecuteAndAssertWith(
          np.testing.assert_allclose,
          c, (),
          expected=[[True, False, True, True]])

    def testGt(self):
      c = self._NewComputation()
      ops.Gt(
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])),
          ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12])))
      self._ExecuteAndCompareExact(
          c, expected=[[False, True, True, False, False]])

    def testGe(self):
      c = self._NewComputation()
      ops.Ge(
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])),
          ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12])))
      self._ExecuteAndCompareExact(
          c, expected=[[True, True, True, False, False]])

    def testLt(self):
      c = self._NewComputation()
      ops.Lt(
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])),
          ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12])))
      self._ExecuteAndCompareExact(
          c, expected=[[False, False, False, True, True]])

    def testLe(self):
      c = self._NewComputation()
      ops.Le(
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])),
          ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12])))
      self._ExecuteAndCompareExact(
          c, expected=[[True, False, False, True, True]])

    def testMax(self):
      c = self._NewComputation()
      ops.Max(
          ops.Constant(c, NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])),
          ops.Constant(c, NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0])))
      self._ExecuteAndCompareExact(c, expected=[[1.0, 2.0, 3.0, 7.0, 12.0]])

    def testMaxExplicitBroadcastDim0(self):
      c = self._NewComputation()
      ops.Max(
          ops.Constant(c, NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
          ops.Constant(c, NumpyArrayF32([3, 4, 5])),
          broadcast_dimensions=(0,))
      self._ExecuteAndCompareExact(
          c, expected=[[[3, 3, 3], [4, 5, 6], [7, 8, 9]]])

    def testMaxExplicitBroadcastDim1(self):
      c = self._NewComputation()
      ops.Max(
          ops.Constant(c, NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
          ops.Constant(c, NumpyArrayF32([3, 4, 5])),
          broadcast_dimensions=(1,))
      self._ExecuteAndCompareExact(
          c, expected=[[[3, 4, 5], [4, 5, 6], [7, 8, 9]]])

    def testMin(self):
      c = self._NewComputation()
      ops.Min(
          ops.Constant(c, NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])),
          ops.Constant(c, NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0])))
      self._ExecuteAndCompareExact(c, expected=[[1.0, 0.0, 2.0, 4.0, 9.0]])

    def testPad(self):
      c = self._NewComputation()
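      # Each (low, high, interior) tuple below configures one dimension's
      # edge_padding_low, edge_padding_high, and interior_padding; the same
      # config is spelled out field-by-field in testPadWithPaddingConfig.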
      ops.Pad(
          ops.Constant(c, NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])),
          ops.Constant(c, NumpyArrayF32(0.0)),
          xla_client.make_padding_config([(1, 2, 1), (0, 1, 0)]))
      self._ExecuteAndCompareClose(
          c,
          expected=[[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0],
                     [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]])

    def testPadWithPaddingConfig(self):
      c = self._NewComputation()
      padding_config = xla_client.PaddingConfig()
      for lo, hi, interior in [(1, 2, 1), (0, 1, 0)]:
        dimension = xla_client.PaddingConfigDimension()
        dimension.edge_padding_low = lo
        dimension.edge_padding_high = hi
        dimension.interior_padding = interior
        padding_config.dimensions.append(dimension)
      ops.Pad(
          ops.Constant(c, NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])),
          ops.Constant(c, NumpyArrayF32(0.0)), padding_config)
      self._ExecuteAndCompareClose(
          c,
          expected=[[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0],
                     [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]])

    def testReshape(self):
      c = self._NewComputation()
      ops.Reshape(
          ops.Constant(c, NumpyArrayS32([[1, 2], [3, 4], [5, 6]])),
          dimensions=[0, 1],
          new_sizes=[2, 3])
      self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3], [4, 5, 6]]])

    def testCollapse(self):
      c = self._NewComputation()
      ops.Collapse(
          ops.Constant(c, NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])),
          dimensions=[1, 2])
      self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3, 4], [5, 6, 7, 8]]])

    def testRev(self):
      c = self._NewComputation()
      ops.Rev(
          ops.Constant(c, NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])),
          dimensions=[0, 2])
      self._ExecuteAndCompareExact(
          c, expected=[[[[6, 5], [8, 7]], [[2, 1], [4, 3]]]])

    def testReducePrecision(self):
      c = self._NewComputation()
      ops.ReducePrecision(
          ops.Constant(c, NumpyArrayF32([float.fromhex("0x1.32fffep-3")])),
          exponent_bits=8,
          mantissa_bits=7)
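      # Eight exponent bits and seven mantissa bits match the bfloat16 format:
      # the input is rounded to the nearest value representable with a 7-bit
      # mantissa.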
      self._ExecuteAndCompareClose(c, expected=[[float.fromhex("0x1.32p-3")]])

    def testClampF32(self):
      c = self._NewComputation()
      ops.Clamp(
          ops.Constant(c, NumpyArrayF32(-1)),
          ops.Constant(c, NumpyArrayF32([-2, -1, 0, 1, 2, 3])),
          ops.Constant(c, NumpyArrayF32(2)))
      self._ExecuteAndCompareExact(c, expected=[[-1, -1, 0, 1, 2, 2]])

    def testClampS32(self):
      c = self._NewComputation()
      ops.Clamp(
          ops.Constant(c, NumpyArrayS32(-1)),
          ops.Constant(c, NumpyArrayS32([-2, -1, 0, 1, 2, 3])),
          ops.Constant(c, NumpyArrayS32(2)))
      self._ExecuteAndCompareExact(c, expected=[[-1, -1, 0, 1, 2, 2]])

    def testSelect(self):
      c = self._NewComputation()
      ops.Select(
          ops.Constant(c, NumpyArrayBool([True, False, False, True, False])),
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 5])),
          ops.Constant(c, NumpyArrayS32([-1, -2, -3, -4, -5])))
      self._ExecuteAndCompareExact(c, expected=[[1, -2, -3, 4, -5]])

    def testSlice(self):
      c = self._NewComputation()
      ops.Slice(
          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
          [1, 0], [3, 2], [1, 1])
      self._ExecuteAndCompareExact(c, expected=[[[4, 5], [7, 8]]])

    def testSliceInDim(self):
      c = self._NewComputation()
      ops.SliceInDim(
          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
          start_index=1,
          limit_index=2,
          stride=1,
          dimno=1)
      self._ExecuteAndCompareExact(c, expected=[[[2], [5], [8]]])
      ops.SliceInDim(
          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
          start_index=0,
          limit_index=3,
          stride=2,
          dimno=0)
      self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3], [7, 8, 9]]])

    def testDynamicSlice(self):
      c = self._NewComputation()
      ops.DynamicSlice(
          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
          [ops.Constant(c, NumpyArrayS32([1, 0]))], [2, 2])
      self._ExecuteAndCompareExact(c, expected=[[[4, 5], [7, 8]]])

    def testDynamicUpdateSlice(self):
      c = self._NewComputation()
      ops.DynamicUpdateSlice(
          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
          ops.Constant(c, NumpyArrayS32([[1, 2], [3, 4]])),
          [ops.Constant(c, NumpyArrayS32([1, 1]))])
      self._ExecuteAndCompareExact(
          c, expected=[[[1, 2, 3], [4, 1, 2], [7, 3, 4]]])

    def testTuple(self):
      c = self._NewComputation()
      ops.Tuple(c, [
          ops.Constant(c, np.int32(42)),
          ops.Constant(c, NumpyArrayF32([1.0, 2.0])),
          ops.Constant(c, NumpyArrayBool([True, False, False, True]))
      ])
      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      self.assertLen(result, 3)
      np.testing.assert_equal(result[0], 42)
      np.testing.assert_allclose(result[1], [1.0, 2.0])
      np.testing.assert_equal(result[2], [True, False, False, True])

    def testGetTupleElement(self):
      c = self._NewComputation()
      ops.GetTupleElement(
          ops.Tuple(c, [
              ops.Constant(c, np.int32(42)),
              ops.Constant(c, NumpyArrayF32([1.0, 2.0])),
              ops.Constant(c, NumpyArrayBool([True, False, False, True]))
          ]), 1)
      self._ExecuteAndCompareClose(c, expected=[[1.0, 2.0]])

    def testBroadcast(self):
      c = self._NewComputation()
      ops.Broadcast(
          ops.Constant(c, NumpyArrayS32([10, 20, 30, 40])), sizes=(3,))
      self._ExecuteAndCompareExact(
          c, expected=[[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]]])

    def testBroadcastInDim(self):
      c = self._NewComputation()
      ops.BroadcastInDim(ops.Constant(c, NumpyArrayS32([1, 2])), [2, 2], [0])
      self._ExecuteAndCompareExact(c, expected=[[[1, 1], [2, 2]]])
      ops.BroadcastInDim(ops.Constant(c, NumpyArrayS32([1, 2])), [2, 2], [1])
      self._ExecuteAndCompareExact(c, expected=[[[1, 2], [1, 2]]])

    def testRngNormal(self):
      shape = (2, 3)
      c = self._NewComputation()
      ops.RngNormal(
          ops.Constant(c, NumpyArrayF32(0.)),
          ops.Constant(c, NumpyArrayF32(1.)),
          shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32,
                                             shape))
      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      # since the result is random, we just check shape and uniqueness
      self.assertLen(result, 1)
      self.assertEqual(result[0].shape, shape)
      self.assertLen(np.unique(result[0]), np.prod(shape))

    def testRngUniformF32(self):
      lo, hi = 2., 4.
      shape = (2, 3)
      c = self._NewComputation()
      ops.RngUniform(
          ops.Constant(c, NumpyArrayF32(lo)),
          ops.Constant(c, NumpyArrayF32(hi)),
          shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32,
                                             shape))
      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      # since the result is random, we just check shape, uniqueness, and range
      self.assertLen(result, 1)
      self.assertEqual(result[0].shape, shape)
      self.assertLen(np.unique(result[0]), np.prod(shape))
      self.assertTrue(np.all(lo <= result[0]))
      self.assertTrue(np.all(result[0] < hi))

    def testRngUniformS32(self):
      lo, hi = 2, 4
      shape = (2, 3)
      c = self._NewComputation()
      ops.RngUniform(
          ops.Constant(c, NumpyArrayS32(lo)),
          ops.Constant(c, NumpyArrayS32(hi)),
          shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.S32,
                                             shape))
      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      # since the result is random, we just check shape, integrality, and range
      self.assertLen(result, 1)
      self.assertEqual(result[0].shape, shape)
      self.assertEqual(result[0].dtype, np.int32)
      self.assertTrue(np.all(lo <= result[0]))
      self.assertTrue(np.all(result[0] < hi))

    def testCholesky(self):
      l = np.array([[4, 0, 0, 0], [6, 5, 0, 0], [2, 14, 16, 0], [3, 6, 1, 4]],
                   dtype=np.float32)
      c = self._NewComputation()
      ops.Cholesky(ops.Constant(c, np.tril(np.dot(l, l.T))))
      self._ExecuteAndCompareClose(c, expected=[l], rtol=1e-4)

    def testSort(self):
      keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32)
      c = self._NewComputation()
      ops.Sort(c, [ops.Constant(c, keys)], is_stable=True)
      self._ExecuteAndCompareClose(
          c,
          expected=[np.array([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=np.float32)])

    def testSortKeyVal(self):
      keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32)
      values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32)
      c = self._NewComputation()
      ops.Sort(c, (ops.Constant(c, keys), ops.Constant(c, values)), dimension=0)
      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      self.assertLen(result, 2)
      np.testing.assert_allclose(result[0], [[2, 1, 1, 2], [3, 4, 4, 3]])
      np.testing.assert_equal(result[1], [[0, 5, 2, 7], [4, 1, 6, 3]])

    def testSortCustomComparator(self):
      b = self._NewComputation("comparator")
      p0 = ops.Parameter(b, 0, xla_client.shape_from_pyval(NumpyArrayF32(0)))
      q0 = ops.Parameter(b, 1, xla_client.shape_from_pyval(NumpyArrayF32(0)))
      p1 = ops.Parameter(b, 2, xla_client.shape_from_pyval(NumpyArrayS32(0)))
      q1 = ops.Parameter(b, 3, xla_client.shape_from_pyval(NumpyArrayS32(0)))
      ops.Or(ops.Lt(p0, q0), ops.And(ops.Eq(p0, q0), ops.Gt(p1, q1)))
      comparator = b.build()
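      # The comparator orders pairs by key ascending and breaks key ties by
      # value descending.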

      keys = np.array([[2, 3, 1, 3], [3, 1, 2, 2]], dtype=np.float32)
      values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32)
      c = self._NewComputation()
      ops.Sort(
          c, (ops.Constant(c, keys), ops.Constant(c, values)),
          dimension=1,
          comparator=comparator)
      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      self.assertLen(result, 2)
      np.testing.assert_allclose(result[0], [[1, 2, 3, 3], [1, 2, 2, 3]])
      np.testing.assert_equal(result[1], [[2, 0, 3, 1], [5, 7, 6, 4]])

    def testQR(self):
      a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166],
                    [10, 63, 166, 310]],
                   dtype=np.float32)
      c = self._NewComputation()
      ops.Tuple(c, ops.QR(ops.Constant(c, a), full_matrices=True))
      q, r = self._Execute(c, ())
      np.testing.assert_allclose(np.dot(q, r), a, rtol=1e-4)

    def testEigh(self):
      a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166],
                    [10, 63, 166, 310]],
                   dtype=np.float32)
      a = (a + a.T) / 2

      c = self._NewComputation()
      ops.Tuple(c, ops.Eigh(ops.Constant(c, a), lower=True))
      # TODO(b/129396575): Turn this test back on when it passes without
      # fastmath.
      # v, w = self._Execute(c, ())
      # self.assertLess(np.linalg.norm(np.dot(a, v) - w * v), 1e-3)

    def testSVD(self):
      a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166],
                    [10, 63, 166, 310]],
                   dtype=np.float32)
      c = self._NewComputation()
      ops.Tuple(c, ops.SVD(ops.Constant(c, a)))
      u, d, v = self._Execute(c, ())
      self.assertLess(np.linalg.norm(a - np.matmul(u * d, v.T)), 1e-3)

    def testTriangularSolve(self):
      a_vals = np.array(
          [[2, 0, 0, 0], [3, 6, 0, 0], [4, 7, 9, 0], [5, 8, 10, 11]],
          dtype=np.float32)
      b_vals = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
                        dtype=np.float32)

      c = self._NewComputation()
1344      ops.TriangularSolve(
1345          ops.Constant(c, a_vals),
1346          ops.Constant(c, b_vals),
1347          left_side=False,
1348          lower=True,
1349          transpose_a=ops.TriangularSolveOptions_Transpose.TRANSPOSE,
1350          unit_diagonal=False)
1351      self._ExecuteAndCompareClose(
1352          c,
1353          expected=[
1354              np.array([
1355                  [0.5, 0.08333334, 0.04629629, 0.03367003],
1356                  [2.5, -0.25, -0.1388889, -0.1010101],
1357                  [4.5, -0.58333331, -0.32407406, -0.23569024],
1358              ],
1359                       dtype=np.float32)
1360          ],
1361          rtol=1e-4)
1362
1363    def testIsConstant(self):
1364      c = self._NewComputation()
1365      a = ops.Constant(c, np.int32(3))
1366      b = ops.Constant(c, np.int32(1))
1367      x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayS32(0)))
1368      const_expr = ops.Sub(b, a)
1369      non_const_expr = ops.Mul(const_expr, x)
1370      self.assertTrue(c.is_constant(const_expr))
1371      self.assertFalse(c.is_constant(non_const_expr))
1372
1373    def testGather(self):
1374      a = np.arange(9).astype(np.int32).reshape((3, 3))
1375      indices = np.array([[[0, 2], [2, 1]], [[1, 2], [2, 0]]], dtype=np.int32)
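      # With slice_sizes=[1, 1], each (row, col) pair in `indices` gathers a
      # single element of `a`.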
      dnums = xla_client.GatherDimensionNumbers()
      dnums.offset_dims.append(1)
      dnums.offset_dims.append(2)
      dnums.start_index_map.append(0)
      dnums.start_index_map.append(1)
      dnums.index_vector_dim = 2
      c = self._NewComputation()
      ops.Gather(
          ops.Constant(c, a),
          ops.Constant(c, indices),
          dnums,
          slice_sizes=[1, 1])
      g, = self._Execute(c, ())
      expected = np.array([[[[2, 7]]], [[[5, 6]]]], dtype=np.int32)
      np.testing.assert_allclose(g, expected, rtol=1e-4)

    def testFft(self):
      if self.backend.platform == "tpu":
        self.skipTest("TPU only supports 1D FFT")
      shape = [2, 3, 4, 5]
      rng = np.random.RandomState(0)
      a = rng.randn(*shape) + 1.0j * rng.randn(*shape)
      a = a.astype(np.complex64)
      # FFT
      c = self._NewComputation()
      ops.Fft(ops.Constant(c, a), xla_client.FftType.FFT, shape[-3:])
      self._ExecuteAndCompareClose(
          c, expected=[np.fft.fftn(a, axes=(1, 2, 3))], rtol=1e-4)
      # IFFT
      c = self._NewComputation()
      ops.Fft(ops.Constant(c, a), xla_client.FftType.IFFT, shape[-3:])
      self._ExecuteAndCompareClose(
          c, expected=[np.fft.ifftn(a, axes=(1, 2, 3))], rtol=1e-4)
      # RFFT
      b = rng.randn(*shape).astype(np.float32)
      c = self._NewComputation()
      ops.Fft(ops.Constant(c, b), xla_client.FftType.RFFT, shape[-3:])
      self._ExecuteAndCompareClose(
          c, expected=[np.fft.rfftn(b, axes=(1, 2, 3))], rtol=1e-4)
      # IRFFT
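      # For a complex input of length 5 on the innermost axis, IRFFT produces
      # a real output of length 2 * (5 - 1) = 8, hence fft_lengths [3, 4, 8].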
      c = self._NewComputation()
      ops.Fft(ops.Constant(c, a), xla_client.FftType.IRFFT, [3, 4, 8])
      self._ExecuteAndCompareClose(
          c, expected=[np.fft.irfftn(a, axes=(1, 2, 3))], rtol=1e-4)

    def testNextAfter(self):
      c = self._NewComputation()
      ops.NextAfter(
          ops.Constant(c, np.array([1, 2], dtype=np.float32)),
          ops.Constant(c, np.array([2, 1], dtype=np.float32)))
      out, = self._Execute(c, ())
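      # nextafter(1, 2) is 1 + eps and nextafter(2, 1) is 2 - eps, where eps
      # is the float32 machine epsilon.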
      eps = np.finfo(np.float32).eps
      np.testing.assert_equal(
          np.array([eps + 1, 2 - eps], dtype=np.float32), out)

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testRegularizedIncompleteBeta(self, dtype):
      x = np.array([0.53787335, 0.24015466, 0.47494545, 0.13567594, 0.95114538],
                   dtype=dtype)
      a = np.array([0.00753073, 0.34813385, 0.30485708, 1.29298632, 0.51472606],
                   dtype=dtype)
      b = np.array([0.55688389, 0.59794214, 0.42661022, 1.59748339, 0.95047677],
                   dtype=dtype)
      c = self._NewComputation()
      ops.RegularizedIncompleteBeta(
          ops.Constant(c, a), ops.Constant(c, b), ops.Constant(c, x))
      expected = np.array(
          [0.98923271, 0.48575411, 0.57952568, 0.12579775, 0.96989155])
      self._ExecuteAndCompareClose(c, expected=[expected], rtol=2e-2)

  tests.append(SingleOpTest)

  class EmbeddedComputationsTest(ComputationTest):
    """Tests for XLA graphs with embedded computations (such as maps)."""

    def _CreateConstantComputation(self, in_dtype, out_dtype):
      """Computation (A) -> B that returns a constant 1 for any input."""
      c = self._NewComputation("constant_{}_{}_one".format(
          in_dtype.__name__, out_dtype.__name__))
      ops.Parameter(c, 0,
                    xla_client.shape_from_pyval(np.array(0, dtype=in_dtype)))
      ops.Constant(c, out_dtype(1))
      return c.build()

    def _CreateMulBy2Computation(self, dtype):
      """Computation (dtype) -> dtype that multiplies its parameter by 2."""
      c = self._NewComputation("mul_{}_by2".format(dtype.__name__))
      ops.Mul(
          ops.Parameter(
              c, 0,
              xla_client.shape_from_pyval(np.array(
                  0, dtype=dtype)).with_major_to_minor_layout_if_absent()),
          ops.Constant(c, dtype(2.0)))
      return c.build()

    def _CreateMulF32ByParamComputation(self):
1475      """Computation (f32) -> f32 that multiplies one parameter by the other."""
      c = self._NewComputation("mul_f32_by_param")
      ops.Mul(
          ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0))),
          ops.Parameter(c, 1, xla_client.shape_from_pyval(NumpyArrayF32(0))))
      return c.build()

    def _CreateBinaryAddComputation(self, dtype):
      """Computation (dtype, dtype) -> dtype that adds its two parameters."""
      c = self._NewComputation("add_param0_by_param1")
      shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype))
      shape = shape.with_major_to_minor_layout_if_absent()
      ops.Add(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape))
      return c.build()

    def _CreateBinaryGeComputation(self, dtype):
      """Computation (dtype, dtype) -> bool that tests param0 >= param1."""
      c = self._NewComputation("param0_ge_param1")
      shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype))
      shape = shape.with_major_to_minor_layout_if_absent()
      ops.Ge(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape))
      return c.build()

    def _MakeSample3DArray(self, dtype):
      return np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]],
                       [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]],
                      dtype=dtype)

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testCall(self, dtype):
      c = self._NewComputation()
      ops.Call(
          c,
          self._CreateMulBy2Computation(dtype),
          operands=(ops.Constant(c, dtype(5.0)),))
      self._ExecuteAndCompareClose(c, expected=[10.0])

    @parameterized.named_parameters({
        "testcase_name": "_{}_{}".format(in_dtype.__name__, out_dtype.__name__),
        "in_dtype": in_dtype,
        "out_dtype": out_dtype,
    } for in_dtype, out_dtype in [[np.float32, np.int32]])
    def testMapEachElementToConstant(self, in_dtype, out_dtype):
      c = self._NewComputation()
      ops.Map(c,
              [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=in_dtype))],
              self._CreateConstantComputation(in_dtype, out_dtype), [0])
      self._ExecuteAndCompareExact(c, expected=[[1, 1, 1, 1]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testMapMulBy2(self, dtype):
      if dtype == np.float64 and self.backend.platform == "tpu":
        self.skipTest("TPU doesn't support float64")
      c = self._NewComputation()
      ops.Map(c, [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype))],
              self._CreateMulBy2Computation(dtype), [0])
      self._ExecuteAndCompareClose(c, expected=[[2.0, 4.0, 6.0, 8.0]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testSimpleMapChain(self, dtype):
      if dtype == np.float64 and self.backend.platform == "tpu":
        self.skipTest("TPU doesn't support float64")
      # Chains a map of constant-out with a map of mul-by-2
      c = self._NewComputation()
      const = ops.Map(
          c, [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype))],
          self._CreateConstantComputation(dtype, dtype), [0])
      ops.Map(c, [const], self._CreateMulBy2Computation(dtype), [0])
      self._ExecuteAndCompareClose(c, expected=[[2.0, 2.0, 2.0, 2.0]])

    # TODO(b/154752816): bfloat16 crashes in evaluator.
    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes if dtype != bfloat16)
    def testDivVectorsWithMap(self, dtype):

      def DivComputation():
        c = self._NewComputation("div_param0_by_param1")
        shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype))
        ops.Div(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape))
        return c.build()

      c = self._NewComputation()
      ops.Map(c, (ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype)),
                  ops.Constant(c, np.array([5.0, 5.0, 4.0, 4.0], dtype=dtype))),
              DivComputation(), [0])
      self._ExecuteAndCompareClose(
          c, expected=[[0.2, 0.4, 0.75, 1.0]], rtol=1e-3)

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testSelectAndScatter(self, dtype):
      if dtype == np.float64 and self.backend.platform == "tpu":
        self.skipTest("TPU doesn't support float64")
      c = self._NewComputation()
      operand = ops.Constant(
          c, np.array([[1., 2., 6.], [4., 5., 3.]], dtype=dtype))
      window_dimensions = (2, 1)
      window_strides = (1, 2)
      padding = xla_client.window_padding_type_to_pad_values(
          xla_client.PaddingType.VALID,
          c.get_shape(operand).dimensions(), window_dimensions, window_strides)
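      # With (2, 1) windows and column stride 2, the window maxima are 4. at
      # (1, 0) and 6. at (0, 2); the source values 0.1 and 0.2 are added there
      # on top of the init value 1.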
      ops.SelectAndScatterWithGeneralPadding(
          operand,
          select=self._CreateBinaryGeComputation(dtype),
          window_dimensions=window_dimensions,
          window_strides=window_strides,
          padding=padding,
          source=ops.Constant(c, np.array([[0.1, 0.2]], dtype=dtype)),
          init_value=ops.Constant(c, np.array(1, dtype=dtype)),
          scatter=self._CreateBinaryAddComputation(dtype))
      self._ExecuteAndCompareClose(
          c, expected=[[[1., 1., 1.2], [1.1, 1., 1.]]], rtol=5e-3)

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testReduce1DtoScalar(self, dtype):
      c = self._NewComputation()
      ops.Reduce(
          c,
          operands=[
              ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype))
          ],
          init_values=[ops.Constant(c, dtype(0))],
          computation=self._CreateBinaryAddComputation(dtype),
          dimensions_to_reduce=[0])
      self._ExecuteAndCompareClose(c, expected=[10])

    # TODO(phawkins): test comparison harness doesn't support bfloat16
    @parameterized.named_parameters({
        "testcase_name": "_{}_dim{}".format(dtype.__name__, dim),
        "dtype": dtype,
        "dim": dim,
    } for dtype in float_dtypes if dtype != bfloat16 for dim in range(2))
    def testReduce2DTo1D(self, dtype, dim):
      input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype)
      c = self._NewComputation()
      ops.Reduce(
          c,
          operands=[ops.Constant(c, input_array)],
          init_values=[ops.Constant(c, dtype(0))],
          computation=self._CreateBinaryAddComputation(dtype),
          dimensions_to_reduce=[dim])
      self._ExecuteAndCompareClose(c, expected=[np.sum(input_array, axis=dim)])

    @parameterized.named_parameters({
        "testcase_name": "_{}_dims[{}]".format(dtype.__name__, dims),
        "dtype": dtype,
        "dims": tuple(dims)
    } for dtype in float_dtypes for dims in itertools.permutations(range(3)))
    def testReduce3DAllPossibleWaysF32(self, dtype, dims):
      input_array = self._MakeSample3DArray(dtype)
      c = self._NewComputation()
      ops.Reduce(
          c,
          operands=[ops.Constant(c, input_array)],
          init_values=[ops.Constant(c, dtype(0))],
          computation=self._CreateBinaryAddComputation(dtype),
          dimensions_to_reduce=dims)
      self._ExecuteAndCompareClose(c, expected=[np.sum(input_array, axis=dims)])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testReduceWindowValidUnitStrides(self, dtype):
      if dtype == np.float64 and self.backend.platform == "tpu":
        self.skipTest("TPU doesn't support float64")
      input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype)
      c = self._NewComputation()
      window_dimensions = (2, 1)
      window_strides = (1, 1)
      padding = xla_client.window_padding_type_to_pad_values(
          xla_client.PaddingType.VALID, input_array.shape, window_dimensions,
          window_strides)
      ops.ReduceWindowWithGeneralPadding(
          operand=ops.Constant(c, input_array),
          init_value=ops.Constant(c, dtype(0)),
          computation=self._CreateBinaryAddComputation(dtype),
          window_dimensions=window_dimensions,
          window_strides=window_strides,
          base_dilations=[],
          window_dilations=[],
          padding=padding)
      self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.]]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testReduceWindowSameUnitStrides(self, dtype):
      if dtype == np.float64 and self.backend.platform == "tpu":
        self.skipTest("TPU doesn't support float64")
      input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype)
      c = self._NewComputation()
      window_dimensions = (2, 1)
      window_strides = (1, 1)
      padding = xla_client.window_padding_type_to_pad_values(
          xla_client.PaddingType.SAME, input_array.shape, window_dimensions,
          window_strides)
      ops.ReduceWindowWithGeneralPadding(
          operand=ops.Constant(c, input_array),
          init_value=ops.Constant(c, dtype(0)),
          computation=self._CreateBinaryAddComputation(dtype),
          window_dimensions=window_dimensions,
          window_strides=window_strides,
          base_dilations=[],
          window_dilations=[],
          padding=padding)
      self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.], [4., 5., 6.]]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testReduceWindowValidGeneralStrides(self, dtype):
      if dtype == np.float64 and self.backend.platform == "tpu":
        self.skipTest("TPU doesn't support float64")
      input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype)
      c = self._NewComputation()
      window_dimensions = (2, 1)
      window_strides = (1, 2)
      padding = xla_client.window_padding_type_to_pad_values(
          xla_client.PaddingType.VALID, input_array.shape, window_dimensions,
          window_strides)
      ops.ReduceWindowWithGeneralPadding(
          operand=ops.Constant(c, input_array),
          init_value=ops.Constant(c, dtype(0)),
          computation=self._CreateBinaryAddComputation(dtype),
          window_dimensions=window_dimensions,
          window_strides=window_strides,
          base_dilations=[],
          window_dilations=[],
          padding=padding)
      self._ExecuteAndCompareClose(c, expected=[[[5., 9.]]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testWhile(self, dtype):

      def LessThan10Cond():
        c = self._NewComputation("test_lt_10")
        shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype))
        ops.Lt(ops.Parameter(c, 0, shape), ops.Constant(c, dtype(10.)))
        return c.build()

      cond = LessThan10Cond()
      body = self._CreateMulBy2Computation(dtype)
      c = self._NewComputation()
      init = ops.Constant(c, dtype(1.))
      ops.While(cond, body, init)
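      # Doubling from 1 while the value is < 10 yields 1, 2, 4, 8, 16.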
      self._ExecuteAndCompareClose(c, expected=[16.])

    def testConditionalTrue(self):
      c = self._NewComputation()
      pred = ops.Constant(c, np.bool_(True))
      true_operand = ops.Constant(c, np.float32(3.))
      true_computation = self._CreateMulBy2Computation(np.float32)
      false_operand = ops.Constant(c, np.float32(2.))
      false_computation = self._CreateConstantComputation(
          np.float32, np.float32)
      ops.Conditional(pred, true_operand, true_computation, false_operand,
                      false_computation)
      self._ExecuteAndCompareClose(c, expected=[6.])

    def testConditionalFalse(self):
      c = self._NewComputation()
      pred = ops.Constant(c, np.bool_(False))
      true_operand = ops.Constant(c, np.float32(3.))
      true_computation = self._CreateMulBy2Computation(np.float32)
      false_operand = ops.Constant(c, np.float32(2.))
      false_computation = self._CreateConstantComputation(
          np.float32, np.float32)
      ops.Conditional(pred, true_operand, true_computation, false_operand,
                      false_computation)
      self._ExecuteAndCompareClose(c, expected=[1.])

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testInfeedS32Values(self):
      to_infeed = NumpyArrayS32([1, 2, 3, 4])
      c = self._NewComputation()
      ops.GetTupleElement(
          ops.InfeedWithToken(
              ops.CreateToken(c),
              xla_client.shape_from_pyval(
                  to_infeed[0]).with_major_to_minor_layout_if_absent()), 0)
      compiled_c = self.backend.compile(c.build())
      device = self.backend.local_devices()[0]
      for item in to_infeed:
        device.transfer_to_infeed(item)

      for item in to_infeed:
        result, = xla_client.execute_with_python_values(
            compiled_c, (), backend=self.backend)
        self.assertEqual(result, item)

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testInfeedTuple(self):
      to_infeed = (NumpyArrayS32([1, 2, 3, 4]), NumpyArrayS32([[7], [8]]))
      c = self._NewComputation()
      ops.GetTupleElement(
          ops.InfeedWithToken(
              ops.CreateToken(c),
              xla_client.shape_from_pyval(
                  to_infeed).with_major_to_minor_layout_if_absent()), 0)
      compiled_c = self.backend.compile(c.build())
      device = self.backend.local_devices()[0]
      device.transfer_to_infeed(to_infeed)

      result = xla_client.execute_with_python_values(
          compiled_c, (), backend=self.backend)
      self.assertLen(result, 2)
      np.testing.assert_equal(result[0], to_infeed[0])
      np.testing.assert_equal(result[1], to_infeed[1])

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testInfeedThenOutfeedS32(self):
      to_round_trip = NumpyArrayS32([1, 2, 3, 4])
      c = self._NewComputation()
      x_and_token = ops.InfeedWithToken(
          ops.CreateToken(c),
          xla_client.shape_from_pyval(
              to_round_trip[0]).with_major_to_minor_layout_if_absent())
      x = ops.GetTupleElement(x_and_token, 0)
      token = ops.GetTupleElement(x_and_token, 1)
      outfeed_shape = xla_client.shape_from_pyval(
          to_round_trip[0]).with_major_to_minor_layout_if_absent()
      ops.OutfeedWithToken(x, token, outfeed_shape)

      compiled_c = self.backend.compile(c.build())
      device = self.backend.local_devices()[0]

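      # The computation blocks on infeed, so run it in a background thread
      # while this thread feeds the infeed and drains the outfeed.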
      for want in to_round_trip:
        execution = threading.Thread(target=lambda: compiled_c.execute([]))
        execution.start()
        device.transfer_to_infeed(want)
        got = device.transfer_from_outfeed(outfeed_shape)
        execution.join()
        self.assertEqual(want, got)

    def testScatter(self):
      a = np.arange(9).astype(np.int32).reshape((3, 3))
      scatter_indices = np.array([0, 2], dtype=np.int32)
      updates = np.array([[10, 20, 30], [70, 80, 90]], dtype=np.int32)

      dnums = xla_client.ScatterDimensionNumbers()
      dnums.update_window_dims.append(1)
      dnums.inserted_window_dims.append(0)
      dnums.scatter_dims_to_operand_dims.append(0)
      dnums.index_vector_dim = 1

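      # scatter_indices selects rows 0 and 2 of `a`; the matching rows of
      # `updates` are added to them element-wise.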
      c = self._NewComputation()
      ops.Scatter(
          ops.Constant(c, a), ops.Constant(c, scatter_indices),
          ops.Constant(c, updates), self._CreateBinaryAddComputation(np.int32),
          dnums)
      expected = np.array([[10, 21, 32], [3, 4, 5], [76, 87, 98]],
                          dtype=np.int32)
      self._ExecuteAndCompareClose(c, expected=[expected])

  tests.append(EmbeddedComputationsTest)

  class DeviceTest(ComputationTest):

    def testPlatform(self):
      for device in self.backend.local_devices():
        self.assertEqual(device.platform, self.backend.platform)

  tests.append(DeviceTest)

  class ErrorTest(ComputationTest):

    def setUp(self):
      super(ErrorTest, self).setUp()
      self.f32_scalar_2 = NumpyArrayF32(2.0)
      self.s32_scalar_2 = NumpyArrayS32(2)

    def testCompileWithWrongElementTypeInLayout(self):
      c = self._NewComputation()
      c.set_op_metadata(xla_client.CurrentSourceInfoMetadata())
      ops.Parameter(c, 0, xla_client.shape_from_pyval(self.s32_scalar_2))
      c.clear_op_metadata()

      options = xla_client.CompileOptions()
      options.argument_layouts = [
          xla_client.Shape.array_shape(np.dtype(np.float32), [])
      ]

      def TestFun():
        return self.backend.compile(c.build(), compile_options=options)

      self.assertRaisesRegex(
          RuntimeError, r".*Invalid argument shape.*"
          r"expected s32\[\], got f32\[\].*", TestFun)

    def testInvokeWithWrongElementType(self):
      c = self._NewComputation()
      c.set_op_metadata(xla_client.CurrentSourceInfoMetadata())
      ops.Parameter(c, 0, xla_client.shape_from_pyval(self.s32_scalar_2))
      c.clear_op_metadata()

      def TestFun():
        return xla_client.execute_with_python_values(
            self.backend.compile(c.build()), [self.f32_scalar_2], self.backend)

      self.assertRaisesRegex(
          RuntimeError, r"Invalid argument: Argument does not match.*"
          r"want s32\[\], got f32\[\].*", TestFun)
  tests.append(ErrorTest)

  class ComputationRootTest(ComputationTest):
    """Tests related to setting the root of the computation."""

    def testComputationRootDifferentFromLastOp(self):
      c = self._NewComputation()
      x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(2.0)))
      result = ops.Add(x, ops.Constant(c, np.float32(3.14)))
      ops.Add(result, ops.Constant(c, np.float32(1.618)))

      arg = NumpyArrayF32(1.0)
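      # Building with `result` as the explicit root drops the second Add from
      # the computation.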
      compiled_c = self.backend.compile(c.build(result))
      ans, = xla_client.execute_with_python_values(
          compiled_c, [arg], backend=self.backend)
      np.testing.assert_allclose(ans, 4.14)

  tests.append(ComputationRootTest)

  class SetShardingTest(ComputationTest):
    """Tests related to setting OpSharding."""

    def testSetSharding(self):
      c = self._NewComputation()
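      # Mark the next op (the parameter) as replicated, then clear the
      # sharding so subsequent ops are unaffected.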
      sharding = xla_client.OpSharding()
      sharding.type = sharding.type.REPLICATED
      sharding.tile_assignment_dimensions.extend([1])
      sharding.tile_assignment_devices.extend([0])
      c.set_sharding(sharding)
      x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(2.0)))
      c.clear_sharding()

      result = ops.Add(x, ops.Constant(c, np.float32(3.14)))
      ops.Add(result, ops.Constant(c, np.float32(1.618)))
      arg = NumpyArrayF32(1.0)
      compiled_c = self.backend.compile(c.build(result))
      ans, = xla_client.execute_with_python_values(
          compiled_c, [arg], backend=self.backend)
      np.testing.assert_allclose(ans, 4.14)

  tests.append(SetShardingTest)

  testcase_shapes = [
      (),
      (1,),
      (2, 3),
      (2, 0),
      (0, 7),
      (4, 1, 2),
      (2, 1, 3),
      (2, 4, 1),
      (3, 1),
      (1, 3),
  ]

  def FormatShapeAndDtype(shape, dtype):
    return "_{}[{}]".format(np.dtype(dtype).name, ",".join(map(str, shape)))

  class DLPackTest(parameterized.TestCase):

    def setUp(self):
      super(DLPackTest, self).setUp()
      self.backend = xla_backend()
      if self.backend.platform not in ("cpu", "gpu"):
        self.skipTest("DLPack requires CPU or GPU")

    # pylint: disable=g-complex-comprehension
    # pyformat: disable
    @parameterized.named_parameters({
        "testcase_name": "{}_own={}".format(FormatShapeAndDtype(shape, dtype),
                                            take_ownership),
        "dtype": dtype,
        "shape": shape,
        "take_ownership": take_ownership
    } for dtype in dlpack_dtypes for shape in testcase_shapes
                                    for take_ownership in [False, True])
    # pyformat: enable
    def testRoundTrip(self, dtype, shape, take_ownership):
      if dtype == np.bool_:
        x = np.random.randint(0, 2, size=shape).astype(np.bool_)
      else:
        x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
      buffer = self.backend.buffer_from_pyval(x)
      dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(
          buffer, take_ownership=take_ownership)
      del buffer  # Free "buffer" to make sure dlt retains ownership.
      self.assertEqual(type(dlt).__name__, "PyCapsule")
      y = xla_client._xla.dlpack_managed_tensor_to_buffer(dlt, self.backend)
      np.testing.assert_array_equal(
          x.astype(np.uint8) if dtype == np.bool_ else x, y.to_py())

    def testTensorsCanBeConsumedOnceOnly(self):
      x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32)
      buffer = self.backend.buffer_from_pyval(x)
      dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(
          buffer, take_ownership=True)

      def ConsumeDLPackTensor():
        _ = xla_client._xla.dlpack_managed_tensor_to_buffer(dlt, self.backend)

      ConsumeDLPackTensor()
      self.assertRaisesRegex(
          RuntimeError, ".*a DLPack tensor may be consumed at most once.*",
          ConsumeDLPackTensor)

    def testTensorsCanBeOwnedOnceOnly(self):
      x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32)
      buffer = self.backend.buffer_from_pyval(x)
      _ = xla_client._xla.buffer_to_dlpack_managed_tensor(
          buffer, take_ownership=True)
      self.assertTrue(buffer.is_deleted())
      with self.assertRaisesRegex(
          RuntimeError,
          "Cannot convert deleted/invalid buffer to DLPack tensor.*"):
        _ = xla_client._xla.buffer_to_dlpack_managed_tensor(
            buffer, take_ownership=True)

    def testNonOwnedDlpackCanBeViewedTwice(self):
      x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32)
      buffer = self.backend.buffer_from_pyval(x)
      d1 = xla_client._xla.buffer_to_dlpack_managed_tensor(
          buffer, take_ownership=False)
      d2 = xla_client._xla.buffer_to_dlpack_managed_tensor(
          buffer, take_ownership=False)

      y = xla_client._xla.dlpack_managed_tensor_to_buffer(d1, self.backend)
      z = xla_client._xla.dlpack_managed_tensor_to_buffer(d2, self.backend)
      del d1, d2
      np.testing.assert_array_equal(x, buffer.to_py())
      np.testing.assert_array_equal(x, y.to_py())
      np.testing.assert_array_equal(x, z.to_py())

  tests.append(DLPackTest)

  class BufferProtocolTest(parameterized.TestCase):

    def setUp(self):
      super(BufferProtocolTest, self).setUp()
      self.backend = xla_backend()
      if self.backend.platform != "cpu":
        self.skipTest("Test requires CPU")

    # pylint: disable=g-complex-comprehension
    @parameterized.named_parameters({
        "testcase_name": FormatShapeAndDtype(shape, dtype),
        "dtype": dtype,
        "shape": shape
    } for dtype in standard_dtypes if dtype != bfloat16
                                    for shape in testcase_shapes)
    def testRoundTrip(self, dtype, shape):
      x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
      x_ptr = x.__array_interface__["data"][0]
      buffer = self.backend.buffer_from_pyval(
          x, host_buffer_semantics=xla_client.HostBufferSemantics.ZERO_COPY)
      y = np.array(buffer, copy=False)
      y_ptr = y.__array_interface__["data"][0]
      np.testing.assert_array_equal(x, y)
      # If the input was sufficiently aligned, the input and output should
      # alias.
      self.assertTrue((x_ptr & 15) != 0 or x_ptr == y_ptr)
      self.assertEqual(y_ptr, buffer.unsafe_buffer_pointer())

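      # IMMUTABLE_ONLY_DURING_CALL semantics permit XLA to copy the data, so
      # the resulting view must not alias the original array.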
      during_call = xla_client.HostBufferSemantics.IMMUTABLE_ONLY_DURING_CALL
      buffer2 = self.backend.buffer_from_pyval(
          x, host_buffer_semantics=during_call)
      z = np.array(buffer2, copy=False)
      self.assertNotEqual(x.__array_interface__["data"][0],
                          z.__array_interface__["data"][0])

    def testDeleteWithActiveView(self):
      x = np.random.randn(20, 10)
      buffer = self.backend.buffer_from_pyval(x)
      buffer_ptr = buffer.unsafe_buffer_pointer()
      y = np.array(buffer, copy=False)
      buffer.delete()
      # It is still legal to access `y`; the array view must keep it alive.
      np.testing.assert_array_equal(x, y)
      self.assertEqual(y.__array_interface__["data"][0], buffer_ptr)

  tests.append(BufferProtocolTest)

  class TracebackTest(absltest.TestCase):

    def setUp(self):
      super(TracebackTest, self).setUp()
      self.backend = xla_backend()

    def testNoTracebacksIfDisabled(self):
      with xla_client.tracebacks(enabled=False):
        self.assertEqual(None, xla_client.Traceback.get_traceback())
        buffer = self.backend.buffer_from_pyval(np.array(7, np.int32))
        self.assertEqual(None, buffer.traceback)

        b = xla_client.XlaBuilder("computation")
        ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2)))
        e = self.backend.compile(b.build())
        self.assertEqual(None, e.traceback)

    def assertIsTracebackContaining(self, tb, function):
      self.assertIsInstance(tb, xla_client.Traceback)
      self.assertIn(function, str(tb))
      self.assertTrue(any(f.function_name == function for f in tb.frames))

    def testTracebacks(self):
      with xla_client.tracebacks(enabled=True):
        tb = xla_client.Traceback.get_traceback()
        self.assertIsTracebackContaining(tb, "testTracebacks")

        # Tracebacks are not implemented on the TPU driver extension's variant
        # of buffers and executables.
        if not isinstance(self.backend, xla_client.Client):
          return

        buffer = self.backend.buffer_from_pyval(np.array(7, np.int32))
        self.assertIsTracebackContaining(buffer.traceback, "testTracebacks")

        b = xla_client.XlaBuilder("computation")
        ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2)))
        e = self.backend.compile(b.build())
        self.assertIsTracebackContaining(e.traceback, "testTracebacks")

    def testNestedFunction(self):

      def AFunction():

        def AnotherFunction():
          return xla_client.Traceback.get_traceback()

        return AnotherFunction()

      with xla_client.tracebacks(enabled=True):
        tb = AFunction()
        self.assertIsInstance(tb, xla_client.Traceback)
        frames = tb.frames
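        # Traceback frames are ordered from the innermost call outwards.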
        i = next(
            i for (i, f) in enumerate(frames) if f.function_name == "AFunction")
        self.assertEqual(frames[i - 1].function_name, "AnotherFunction")
        self.assertEqual(frames[i + 1].function_name, "testNestedFunction")

  tests.append(TracebackTest)

  return tests


def InstantiateTests(globals_dict, backend_fn, test_prefix="", **kw):
  # Avoid creating a new backend per test (this causes GPU OOM, and is probably
  # inefficient).
  backend_fn = functools.lru_cache(maxsize=None)(backend_fn)
  for klass in TestFactory(backend_fn, **kw):
    test = type(test_prefix + klass.__name__, (klass,), {})
    # Clean up the qualified names of the tests to not include the test factory.
    test.__qualname__ = test.__name__
    globals_dict[test.__name__] = test


if __name__ == "__main__":
  flags.DEFINE_string("backend", "cpu", "Target backend.")
  InstantiateTests(globals(),
                   lambda: xla_client.get_local_backend(FLAGS.backend))
  absltest.main()