# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Backend-dependent tests for the Python XLA client."""

import functools
import itertools
import re
import threading
import unittest

from absl import flags
from absl import logging
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

from tensorflow.compiler.xla.python import xla_client

# pylint: disable=g-import-not-at-top
try:
  # This import is only used for GPU; the dependency is incompatible with TPU
  # so it results in an import error.
  from tensorflow.python.framework import test_util
except ImportError:
  test_util = None

# pylint: disable=g-import-not-at-top
try:
  from tensorflow.compiler.xla.python import custom_call_for_test
except ImportError:
  custom_call_for_test = None

bfloat16 = xla_client.bfloat16
ops = xla_client.ops

FLAGS = flags.FLAGS

# We choose to ignore pylint's complaints about complex comprehensions, which we
# use widely for parameterizing tests.
# pylint: disable=g-complex-comprehension

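# A minimal usage sketch (hypothetical, for illustration only): a
# backend-specific test module would typically instantiate the factory and
# export the returned classes so the test runner can discover them, e.g.
#
#   for klass in TestFactory(some_backend_factory):
#     globals()[klass.__name__] = klass
#
# where `some_backend_factory` is a callable returning an xla_client backend.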

def TestFactory(xla_backend,
                cloud_tpu=False,
                tfrt_tpu=False,
                external_tpu=False):
  tests = []

  if not cloud_tpu:
    int_dtypes = [np.int32, np.int64, np.uint32, np.uint64]
    # TODO(phawkins): test np.float16, where supported.
    float_dtypes = [bfloat16, np.float32, np.float64]
    complex_dtypes = [np.complex64, np.complex128]
    standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_]
  else:
    int_dtypes = [np.int32, np.uint32]
    float_dtypes = [np.float32]
    complex_dtypes = [np.complex64]
    standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_]
  dlpack_dtypes = int_dtypes + float_dtypes + [np.bool_]

  class ComputationTest(parameterized.TestCase):
    """Base class for running an XLA Computation through the local client."""

    def setUp(self):
      super(ComputationTest, self).setUp()
      self.backend = xla_backend()

    def _NewComputation(self, name=None):
      if name is None:
        name = self.id()
      return xla_client.XlaBuilder(name)

    def _Execute(self, c, arguments):
      compiled_c = self.backend.compile(c.build())
      return xla_client.execute_with_python_values(
          compiled_c, arguments, backend=self.backend)

    def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected):
      assert expected is not None
      results = self._Execute(c, arguments)
      self.assertLen(results, len(expected))
      for result, e in zip(results, expected):
        # Numpy's comparison methods are a bit too lenient by treating inputs as
        # "array-like", meaning that scalar 4 will be happily compared equal to
        # [[4]]. We'd like to be more strict so assert shapes as well.
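        # For example, np.testing.assert_allclose(4, [[4]]) would pass, since
        # the scalar broadcasts against the nested list.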
        self.assertEqual(np.asanyarray(result).shape, np.asanyarray(e).shape)
        assert_func(result, e)

    def _ExecuteAndCompareExact(self, c, arguments=(), expected=None):
      self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments,
                                 expected)

    def _ExecuteAndCompareClose(self,
                                c,
                                arguments=(),
                                expected=None,
                                rtol=1e-4,
                                atol=0):
      self._ExecuteAndAssertWith(
          functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol),
          c, arguments, expected)

  def NumpyArrayF32(*args, **kwargs):
    """Convenience wrapper to create Numpy arrays with a np.float32 dtype."""
    return np.array(*args, dtype=np.float32, **kwargs)

  def NumpyArrayF64(*args, **kwargs):
    """Convenience wrapper to create Numpy arrays with a np.float64 dtype."""
    return np.array(*args, dtype=np.float64, **kwargs)

  def NumpyArrayS32(*args, **kwargs):
    """Convenience wrapper to create Numpy arrays with a np.int32 dtype."""
    return np.array(*args, dtype=np.int32, **kwargs)

  def NumpyArrayBool(*args, **kwargs):
    """Convenience wrapper to create Numpy arrays with a np.bool_ dtype."""
    return np.array(*args, dtype=np.bool_, **kwargs)

  class ComputationPrinting(absltest.TestCase):

    def setUp(self):
      super(ComputationPrinting, self).setUp()
      self.backend = xla_backend()

    def ExampleComputation(self):
      builder = xla_client.XlaBuilder("acomputation")
      p0 = ops.Parameter(builder, 0, xla_client.shape_from_pyval(np.float32(0)))
      p1 = ops.Parameter(
          builder, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32)))
      x = ops.Mul(p0, p1)
      ops.Add(x, x)
      return builder.build()

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testCompiledHloModuleToHloText(self):
      computation = self.ExampleComputation()
      executable = self.backend.compile(computation)
      hlo_modules = executable.hlo_modules()
      self.assertLen(hlo_modules, 1)
      hlo_text = hlo_modules[0].to_string()
      self.assertTrue(hlo_text.startswith("HloModule acomputation"))
      self.assertIn("fusion", hlo_text)

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testCompiledHloModuleAsSerializedProto(self):
      computation = self.ExampleComputation()
      executable = self.backend.compile(computation)
      hlo_modules = executable.hlo_modules()
      self.assertLen(hlo_modules, 1)
      hlo_text = hlo_modules[0].to_string()
      proto = hlo_modules[0].as_serialized_hlo_module_proto()
      hlo_module_roundtrip = xla_client.XlaComputation(proto).get_hlo_module()
      hlo_text_roundtrip = hlo_module_roundtrip.to_string()
      self.assertEqual(hlo_text, hlo_text_roundtrip)

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testStableComputationSerialization(self):
      # Ideally we would test identical computations produced in different
      # processes. For now we have this limited smoke test.
      computation = self.ExampleComputation()
      ref = computation.as_serialized_hlo_module_proto()
      for _ in range(10):
        self.assertEqual(computation.as_serialized_hlo_module_proto(), ref)

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testFlopEstimate(self):
      computation = self.ExampleComputation()
      properties = xla_client._xla.hlo_module_cost_analysis(
          self.backend, computation.as_hlo_module())
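      # ExampleComputation multiplies and then adds 4-element f32 vectors,
      # so the analysis should count 4 mul flops + 4 add flops = 8.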
      self.assertEqual(properties["flops"], 8.0)

    def testFingerprint(self):
      computation = self.ExampleComputation()
      executable = self.backend.compile(computation)
      fingerprint = executable.fingerprint
      if self.backend.platform == "tpu" and not cloud_tpu:
        logging.info("fingerprint: %s", fingerprint)
        self.assertNotEmpty(fingerprint)
      else:
        self.assertIsNone(fingerprint)

  tests.append(ComputationPrinting)

  class ComputationsWithConstantsTest(ComputationTest):
    """Tests focusing on Constant ops."""

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in int_dtypes + float_dtypes)
    def testConstantScalarSum(self, dtype):
      if dtype == np.int8 and self.backend.platform == "tpu":
        self.skipTest("TPU doesn't support int8")
      c = self._NewComputation()
      ops.Add(ops.Constant(c, dtype(1.11)), ops.Constant(c, dtype(3.14)))
      self._ExecuteAndCompareClose(c, expected=[dtype(1.11) + dtype(3.14)])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testConstantVectorMul(self, dtype):
      c = self._NewComputation()
      ops.Mul(
          ops.Constant(c, np.array([2.5, 3.3, -1.2, 0.7], dtype)),
          ops.Constant(c, np.array([-1.2, 2, -2, -3], dtype)))
      self._ExecuteAndCompareClose(
          c, expected=[[-3, 6.6, 2.4, -2.1]], rtol=3e-3)

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testConstantVectorScalarDiv(self, dtype):
      c = self._NewComputation()
      ops.Div(
          ops.Constant(c, np.array([1.5, 2.5, 3.0, -10.8], dtype=dtype)),
          ops.Constant(c, dtype(2.0)))
      self._ExecuteAndCompareClose(
          c, expected=[[0.75, 1.25, 1.5, -5.4]], rtol=2e-3)

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testConstantVectorScalarPow(self, dtype):
      c = self._NewComputation()
      ops.Pow(
          ops.Constant(c, np.array([1.5, 2.5, 3.0], dtype=dtype)),
          ops.Constant(c, dtype(2.)))
      self._ExecuteAndCompareClose(c, expected=[[2.25, 6.25, 9.]])

    def testIota(self):
      c = self._NewComputation()
      ops.Iota(c, xla_client.PrimitiveType.F32, 10)
      self._ExecuteAndCompareExact(
          c, expected=[np.arange(10, dtype=np.float32)])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in int_dtypes)
    def testBroadcastedIota(self, dtype):
      c = self._NewComputation()
      shape = xla_client.Shape.array_shape(
          xla_client.dtype_to_etype(dtype), (2, 3))
      ops.Iota(c, shape, 1)
      expected = np.array([[0, 1, 2], [0, 1, 2]], dtype=dtype)
      self._ExecuteAndCompareExact(c, expected=[expected])

    def testBooleanAnd(self):
      c = self._NewComputation()
      ops.And(
          ops.Constant(c, NumpyArrayBool([True, False, True, False])),
          ops.Constant(c, NumpyArrayBool([True, True, False, False])))
      self._ExecuteAndCompareExact(c, expected=[[True, False, False, False]])

    def testBooleanOr(self):
      c = self._NewComputation()
      ops.Or(
          ops.Constant(c, NumpyArrayBool([True, False, True, False])),
          ops.Constant(c, NumpyArrayBool([True, True, False, False])))
      self._ExecuteAndCompareExact(c, expected=[[True, True, True, False]])

    def testBooleanXor(self):
      c = self._NewComputation()
      ops.Xor(
          ops.Constant(c, NumpyArrayBool([True, False, True, False])),
          ops.Constant(c, NumpyArrayBool([True, True, False, False])))
      self._ExecuteAndCompareExact(c, expected=[[False, True, True, False]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testSum2D(self, dtype):
      c = self._NewComputation()
      ops.Add(
          ops.Constant(c, np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)),
          ops.Constant(c, np.array([[1, -1, 1], [-1, 1, -1]], dtype=dtype)))
      self._ExecuteAndCompareClose(c, expected=[[[2, 1, 4], [3, 6, 5]]])

    def testShiftLeft(self):
      c = self._NewComputation()
      ops.ShiftLeft(
          ops.Constant(c, NumpyArrayS32([3])),
          ops.Constant(c, NumpyArrayS32([2])))
      self._ExecuteAndCompareClose(c, expected=[[12]])

    def testShiftRightArithmetic(self):
      c = self._NewComputation()
      ops.ShiftRightArithmetic(
          ops.Constant(c, NumpyArrayS32([-2])),
          ops.Constant(c, NumpyArrayS32([1])))
      self._ExecuteAndCompareClose(c, expected=[[-1]])

    def testShiftRightLogical(self):
      c = self._NewComputation()
      ops.ShiftRightLogical(
          ops.Constant(c, NumpyArrayS32([-1])),
          ops.Constant(c, NumpyArrayS32([1])))
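      # As a 32-bit pattern, -1 is 0xFFFFFFFF; a logical (unsigned) right
      # shift by 1 yields 0x7FFFFFFF == 2**31 - 1.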
      self._ExecuteAndCompareClose(c, expected=[[2**31 - 1]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testSum2DWith1DBroadcastDim0(self, dtype):
      # sum of a 2D array with a 1D array where the latter is replicated across
      # dimension 0 to match the former's shape.
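      # In NumPy terms this computes a + b[:, np.newaxis].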
      c = self._NewComputation()
      ops.Add(
          ops.Constant(c,
                       np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                                dtype=dtype)),
          ops.Constant(c, np.array([10, 20, 30], dtype=dtype)),
          broadcast_dimensions=(0,))
      self._ExecuteAndCompareClose(
          c, expected=[[[11, 12, 13], [24, 25, 26], [37, 38, 39]]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testSum2DWith1DBroadcastDim1(self, dtype):
      # sum of a 2D array with a 1D array where the latter is replicated across
      # dimension 1 to match the former's shape.
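      # In NumPy terms this computes a + b[np.newaxis, :], i.e. ordinary row
      # broadcasting.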
      c = self._NewComputation()
      ops.Add(
          ops.Constant(c,
                       np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                                dtype=dtype)),
          ops.Constant(c, np.array([10, 20, 30], dtype=dtype)),
          broadcast_dimensions=(1,))
      self._ExecuteAndCompareClose(
          c, expected=[[[11, 22, 33], [14, 25, 36], [17, 28, 39]]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testConstantAxpy(self, dtype):
      c = self._NewComputation()
      ops.Add(
          ops.Mul(
              ops.Constant(c, dtype(2)),
              ops.Constant(c, np.array([2.2, 3.3, 4.4, 5.5], dtype=dtype))),
          ops.Constant(c, np.array([100, -100, 200, -200], dtype)))
      self._ExecuteAndCompareClose(
          c, expected=[[104.4, -93.4, 208.8, -189]], rtol=2e-3)

    def testCustomCall(self):
      if self.backend.platform != "cpu":
        self.skipTest("Test requires cpu platform")
      c = self._NewComputation()
      for name, fn in custom_call_for_test.cpu_custom_call_targets.items():
        xla_client.register_custom_call_target(name, fn, platform="cpu")
      ops.CustomCallWithLayout(
          c,
          b"test_subtract_f32",
          operands=[
              ops.Constant(c, np.float32(1.25)),
              ops.Constant(c, np.float32(0.5))
          ],
          shape_with_layout=xla_client.Shape.array_shape(
              np.dtype(np.float32), (), ()),
          operand_shapes_with_layout=[
              xla_client.Shape.array_shape(np.dtype(np.float32), (), ()),
              xla_client.Shape.array_shape(np.dtype(np.float32), (), ()),
          ])
      self._ExecuteAndCompareClose(c, expected=[0.75])

  tests.append(ComputationsWithConstantsTest)

  class PythonCallbackTest(ComputationTest):

    def testPythonCallback(self):
      if self.backend.platform != "cpu":
        self.skipTest("Test requires cpu platform")
      c = self._NewComputation()

      f = lambda x, y: (x + y, x - y)

      arg0 = np.array([9, 43, -101, 22], dtype=np.int32)
      arg1 = np.array([10, 15, -2, 7], dtype=np.int32)
      shape = xla_client.shape_from_pyval(arg0)
      shape = shape.with_major_to_minor_layout_if_absent()
      p0 = ops.Parameter(c, 0, shape)
      p1 = ops.Parameter(c, 1, shape)
      out, keepalive = self.backend.emit_python_callback(
          f, c, [p0, p1], [shape, shape])
      self._ExecuteAndCompareExact(
          c, arguments=[arg0, arg1], expected=[arg0 + arg1, arg0 - arg1])
      del out, keepalive

    def testTokens(self):
      if self.backend.platform != "cpu":
        self.skipTest("Test requires cpu platform")
      c = self._NewComputation()

      def _Callback(x, y):
        assert y is None, y
        return None, x + 1

      arg0 = np.array([9, 43, -101, 22], dtype=np.int32)
      shape = xla_client.shape_from_pyval(arg0)
      token_shape = xla_client.Shape.token_shape()
      p0 = ops.Parameter(c, 0, shape)
      token = ops.CreateToken(c)
      out, keepalive = self.backend.emit_python_callback(
          _Callback, c, [p0, token], [token_shape, shape])
      out = ops.GetTupleElement(out, 1)
      self._ExecuteAndCompareExact(c, arguments=[arg0], expected=[arg0 + 1])
      del out, keepalive

    def testStriding(self):
      if self.backend.platform != "cpu":
        self.skipTest("Test requires cpu platform")
      c = self._NewComputation()

      def _Callback(x):
        assert x.flags.f_contiguous, x.strides
        # Force the output array to have C layout, which will require a
        # transpose back to the expected Fortran layout.
        return np.ascontiguousarray(x * 2),

      arg0 = np.arange(12, dtype=np.int16).reshape(3, 4)
      shape_f_layout = xla_client.Shape.array_shape(
          arg0.dtype, arg0.shape, layout=(0, 1))
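      # layout=(0, 1) lists dimensions minor-to-major, making dimension 0 the
      # fastest-varying one: column-major (Fortran) order.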
      p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0))
      out, keepalive = self.backend.emit_python_callback(
          _Callback, c, [p0], [shape_f_layout], [shape_f_layout])
      self._ExecuteAndCompareExact(c, arguments=[arg0], expected=[arg0 * 2])
      del out, keepalive

  tests.append(PythonCallbackTest)

  class ComputationFromProtoTest(absltest.TestCase):
    """Test computation execution from HLO proto."""

    def setUp(self):
      super(ComputationFromProtoTest, self).setUp()
      self.backend = xla_backend()

    def testExecuteFromProto(self):
      # Build the HLO proto
      b = xla_client.XlaBuilder("computation")
      ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2)))
      serialized_proto = b.build().as_serialized_hlo_module_proto()

      # Load and execute the proto
      c = xla_client.XlaComputation(serialized_proto)
      ans, = xla_client.execute_with_python_values(
          self.backend.compile(c), (), backend=self.backend)
      np.testing.assert_equal(ans, np.int32(3))

  tests.append(ComputationFromProtoTest)

  class ParametersTest(ComputationTest):
    """Tests focusing on Parameter ops and argument-passing."""

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in int_dtypes)
    def testScalarTimesVector(self, dtype):
      c = self._NewComputation()
      arg0 = np.array(3, dtype=dtype)
      arg1 = np.array([10, 15, -2, 7], dtype=dtype)
      p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0))
      p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1))
      ops.Mul(p0, p1)
      self._ExecuteAndCompareExact(
          c, arguments=[arg0, arg1], expected=[arg0 * arg1])

    # TODO(phawkins): test comparison harness doesn't support bfloat16
    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes if dtype != bfloat16)
    def testScalarMinusVectorExplicitNumbering(self, dtype):
      # Use explicit numbering and pass parameter_num first. Sub is used since
      # it's not commutative and can help catch parameter reversal within the
      # computation.
      c = self._NewComputation()
      arg0 = np.array(2.0, dtype=dtype)
      arg1 = np.array([-2.3, 3.3, -4.3, 5.3], dtype=dtype)
      p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1))
      p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0))
      ops.Sub(p1, p0)
      self._ExecuteAndCompareClose(
          c, arguments=[arg0, arg1], expected=[arg1 - arg0])

  tests.append(ParametersTest)

  class BufferTest(ComputationTest):
    """Tests focusing on execution with Buffers."""

    def testConstantSum(self):
      c = self._NewComputation()
      ops.Add(
          ops.Constant(c, np.float32(1.11)), ops.Constant(c, np.float32(3.14)))
      self._ExecuteAndCompareClose(c, expected=[4.25])

    def testOneParameterSum(self):
      c = self._NewComputation()
      ops.Add(
          ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))),
          ops.Constant(c, np.float32(3.14)))
      self._ExecuteAndCompareClose(
          c, arguments=[NumpyArrayF32(1.11)], expected=[4.25])

    def testTwoParameterSum(self):
      c = self._NewComputation()
      ops.Add(
          ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))),
          ops.Parameter(c, 1, xla_client.shape_from_pyval(NumpyArrayF32(0.))))
      self._ExecuteAndCompareClose(
          c,
          arguments=[NumpyArrayF32(1.11),
                     NumpyArrayF32(3.14)],
          expected=[4.25])

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testCannotCallWithDeletedBuffers(self):
      c = self._NewComputation()
      ops.Add(
          ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))),
          ops.Constant(c, np.float32(3.14)))
      arg = NumpyArrayF32(1.11)
      compiled_c = self.backend.compile(c.build())
      arg_buffer = self.backend.buffer_from_pyval(arg)
      arg_buffer.delete()
      with self.assertRaises(RuntimeError):
        compiled_c.execute([arg_buffer])

    def testXlaShape(self):
      pyval = np.array([[1., 2.]], np.float32)
      local_buffer = self.backend.buffer_from_pyval(pyval)
      xla_shape = local_buffer.xla_shape()
      self.assertEqual(xla_shape.dimensions(), (1, 2))
      self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32))

    def testXlaShapeIndex(self):
      a = xla_client.ShapeIndex((1, 2))
      b = xla_client.ShapeIndex((1, 2))
      c = xla_client.ShapeIndex((2, 3))
      self.assertEqual(a, b)
      self.assertNotEqual(b, c)

    def testBlockHostUntilReadyWorks(self):
      arg = np.array([[1., 2.]], np.float32)
      arg_buffer = self.backend.buffer_from_pyval(arg)
      arg_buffer.block_host_until_ready()
      # This test merely checks that nothing goes awry when we call
      # block_host_until_ready(); it's difficult to test anything else.

    def testBlockHostUntilReadyRaisesOnDeletedBuffer(self):
      arg = np.array([[1., 2.]], np.float32)
      buffer = self.backend.buffer_from_pyval(arg)
      buffer.delete()
      with self.assertRaisesRegex(
          RuntimeError,
          re.escape(
              "BlockHostUntilReady() called on deleted or donated buffer")):
        buffer.block_host_until_ready()

    def testDeviceArrayBaseSignatures(self):
      # When extending `DeviceArrayBase`, the object behaves as a `DeviceArray`
      # and thus needs to correctly implement the following methods.
      arg = np.array([[1., 2., 3.]], np.float32)
      buffer = self.backend.buffer_from_pyval(arg)
      if not isinstance(buffer, xla_client.DeviceArrayBase):
        raise unittest.SkipTest(
            "The object of type {} does not extend DeviceArrayBase".format(
                type(buffer)))

      self.assertEqual(buffer.__array_priority__, 100)
      self.assertEqual(buffer.shape, (1, 3))
      self.assertEqual(buffer.dtype, np.float32)
      self.assertEqual(buffer.size, 3)
      self.assertEqual(buffer.ndim, 2)

      self.assertIs(buffer, buffer.block_until_ready())
      buffer.delete()
      with self.assertRaises(RuntimeError):
        buffer.block_until_ready()

    def testOnDeviceSizeInBytes(self):
      if not isinstance(self.backend, xla_client.Client):
        self.skipTest("TPU Driver doesn't support OnDeviceSizeInBytes.")
      arg0 = np.array([])
      arg1 = np.array([[0., 1., 2.]], np.float32)
      arg2 = np.array([[3., 4., 5.]], bfloat16)
      arg0_buffer = self.backend.buffer_from_pyval(arg0)
      arg1_buffer = self.backend.buffer_from_pyval(arg1)
      arg2_buffer = self.backend.buffer_from_pyval(arg2)
      self.assertEqual(arg0_buffer.on_device_size_in_bytes(), 0)
      # OnDeviceSizeInBytes varies depending on the platform. Confirm there's
      # a reasonable value.
      self.assertGreater(arg1_buffer.on_device_size_in_bytes(), 0)
      self.assertGreater(arg2_buffer.on_device_size_in_bytes(), 0)

    def testLiveBuffers(self):
      if not isinstance(self.backend, xla_client.Client):
        self.skipTest("TPU Driver doesn't support LiveBuffers().")
      self.assertEmpty(self.backend.live_buffers())
      arg0 = np.array([])
      arg1 = np.array([[0., 1., 2.]], np.float32)
      arg2 = np.array([[3., 4., 5.]], bfloat16)
      arg0_buffer = self.backend.buffer_from_pyval(arg0)
      arg1_buffer = self.backend.buffer_from_pyval(arg1)
      arg2_buffer = self.backend.buffer_from_pyval(arg2)
      self.assertLen(self.backend.live_buffers(), 3)
      self.assertIs(self.backend.live_buffers()[0], arg2_buffer)
      self.assertIs(self.backend.live_buffers()[1], arg1_buffer)
      self.assertIs(self.backend.live_buffers()[2], arg0_buffer)
      self.assertEqual(self.backend.devices()[0].live_buffers(),
                       self.backend.live_buffers())

      arg1_buffer.delete()
      self.assertLen(self.backend.live_buffers(), 2)
      self.assertIs(self.backend.live_buffers()[0], arg2_buffer)
      self.assertIs(self.backend.live_buffers()[1], arg0_buffer)

      arg0_buffer.delete()
      arg2_buffer.delete()
      self.assertEmpty(self.backend.live_buffers())

    def testCopyToHost(self):
      arg0 = np.array([[1., 2.]], np.float32)
      arg1 = np.array([[3., 4.]], np.float32)
      arg0_buffer = self.backend.buffer_from_pyval(arg0)
      arg1_buffer = self.backend.buffer_from_pyval(arg1)
      # Prefetch two buffers using copy_to_host_async, and then retrieve their
      # values using to_py.
      arg0_buffer.copy_to_host_async()
      arg0_buffer.copy_to_host_async()  # Duplicate calls don't do anything.
      arg1_buffer.copy_to_host_async()
      np.testing.assert_equal(arg0, arg0_buffer.to_py())
      np.testing.assert_equal(arg1, arg1_buffer.to_py())
      # copy_to_host_async does nothing after to_py is called.
      arg0_buffer.copy_to_host_async()
      np.testing.assert_equal(arg0, arg0_buffer.to_py())

    def testDevice(self):
      x = np.arange(8, dtype=np.int32)
      for device in self.backend.local_devices():
        buf = self.backend.buffer_from_pyval(x, device=device)
        self.assertEqual(buf.device(), device)
        np.testing.assert_equal(x, buf.to_py())

    def testStandardTypes(self):
      for dtype in standard_dtypes:
        if dtype == bfloat16 or dtype == np.complex128:
          continue
        arr = self.backend.buffer_from_pyval(np.array([0, 1], dtype))
        arr = arr.to_py()
        self.assertEqual(dtype, type(arr[0]))

    def testUnsafeBufferPointer(self):
      if not isinstance(self.backend, xla_client.Client):
        self.skipTest("TPU Driver doesn't support UnsafeBufferPointer().")
      arg0 = np.array([])
      arg1 = np.array([[0., 1., 2.]], np.float32)
      arg2 = np.array([[3., 4., 5.]], bfloat16)
      arg0_buffer = self.backend.buffer_from_pyval(arg0)
      arg1_buffer = self.backend.buffer_from_pyval(arg1)
      arg2_buffer = self.backend.buffer_from_pyval(arg2)
      self.assertGreaterEqual(arg0_buffer.unsafe_buffer_pointer(), 0)
      self.assertGreaterEqual(arg1_buffer.unsafe_buffer_pointer(), 0)
      self.assertGreaterEqual(arg2_buffer.unsafe_buffer_pointer(), 0)

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testClone(self):
      x = np.array([[3., 4., 5.]], np.float32)
      y = self.backend.buffer_from_pyval(x)
      z = y.clone()
      self.assertNotEqual(id(x), id(y))
      np.testing.assert_array_equal(y.to_py(), z.to_py())
      self.assertEqual(y.unsafe_buffer_pointer(), z.unsafe_buffer_pointer())

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testJaxAttributesHaveCorrectDefaults(self):
      x = np.array([[3., 4., 5.]], np.float32)
      y = self.backend.buffer_from_pyval(x)
      self.assertIsNone(y.aval)
      self.assertIsNone(y._device)

  tests.append(BufferTest)

  class SingleOpTest(ComputationTest):
    """Tests for single ops.

    The goal here is smoke testing - to exercise the most basic functionality
    of single XLA ops, adding as few additional ops as possible around the op
    being tested.
    """

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testConcatenate(self, dtype):
      c = self._NewComputation()
      args = (
          ops.Constant(c, np.array([1.0, 2.0, 3.0], dtype=dtype)),
          ops.Constant(c, np.array([4.0, 5.0, 6.0], dtype=dtype)),
      )
      ops.ConcatInDim(c, args, dimension=0)
      self._ExecuteAndCompareExact(
          c, expected=[np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=dtype)])

    # pyformat: disable
    @parameterized.named_parameters({
        "testcase_name": "_{}_{}".format(src_dtype.__name__,
                                         dst_dtype.__name__),
        "src_dtype": src_dtype,
        "dst_dtype": dst_dtype,
    } for src_dtype, dst_dtype in itertools.permutations(
        [np.bool_, np.int32, np.int64, np.float32, np.float64], 2))
    # pyformat: enable
    def testConvertElementType(self, src_dtype, dst_dtype):
      if ((src_dtype in [np.int64, np.float64] or
           dst_dtype in [np.int64, np.float64]) and
          self.backend.platform == "tpu"):
        self.skipTest("TPU doesn't support float64")
      c = self._NewComputation()
      x = np.array([0, 1, 0, 0, 1], dtype=src_dtype)
      ops.ConvertElementType(
          ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype))

      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      self.assertLen(result, 1)
      expected = np.array(x, dtype=dst_dtype)

      self.assertEqual(result[0].shape, expected.shape)
      self.assertEqual(result[0].dtype, expected.dtype)
      np.testing.assert_equal(result[0], expected)

    # pyformat: disable
    @parameterized.named_parameters(
        {
            "testcase_name": "_{}_{}".format(src_dtype.__name__,
                                             dst_dtype.__name__),
            "src_dtype": src_dtype,
            "dst_dtype": dst_dtype,
        }
        for dtypes in [[np.int32, np.float32], [np.int64, np.float64]]
        for src_dtype, dst_dtype in itertools.permutations(dtypes, 2))
    # pyformat: enable
    def testBitcastConvertType(self, src_dtype, dst_dtype):
      if (np.float64 in (src_dtype, dst_dtype) and
          self.backend.platform == "tpu"):
        self.skipTest("TPU doesn't support float64")
      c = self._NewComputation()
      x = np.array([0, 1, 0, 0, 1], dtype=src_dtype)
      ops.BitcastConvertType(
          ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype))

      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      self.assertLen(result, 1)
      expected = x.view(dst_dtype)

      self.assertEqual(result[0].shape, expected.shape)
      self.assertEqual(result[0].dtype, expected.dtype)
      np.testing.assert_equal(result[0], expected)

    # TODO(b/123523486) implement AllToAll on CPU
    def DISABLED_testAllToAllOneReplica(self):
      samples = [
          NumpyArrayF32([97.0]),
          NumpyArrayF32([64.0, 117.0]),
          NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]),
      ]
      for lhs in samples[:1]:
        c = self._NewComputation()
        ops.AllToAll(ops.Constant(c, lhs), 0, 0)
        self._ExecuteAndCompareExact(c, expected=[lhs])

    def testCrossReplicaSumOneReplica(self):
      samples = [
          NumpyArrayF32(42.0),
          NumpyArrayF32([97.0]),
          NumpyArrayF32([64.0, 117.0]),
          NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]),
      ]
      for lhs in samples:
        c = self._NewComputation()
        ops.CrossReplicaSum(ops.Constant(c, lhs))
        self._ExecuteAndCompareExact(c, expected=[lhs])

    def testReplicaId(self):
      c = self._NewComputation()
      _ = ops.ReplicaId(c)
      self._ExecuteAndCompareExact(c, expected=[0])

    def testCrossReplicaSumOneReplicaWithSingletonGroup(self):
      samples = [
          NumpyArrayF32(42.0),
          NumpyArrayF32([97.0]),
          NumpyArrayF32([64.0, 117.0]),
          NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]),
      ]
      for lhs in samples:
        c = self._NewComputation()
        ops.CrossReplicaSum(
            ops.Constant(c, lhs), xla_client.make_replica_groups([[0]]))
        self._ExecuteAndCompareExact(c, expected=[lhs])

    # TODO(phawkins): np.dot implementation doesn't support bfloat16
    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes if dtype != bfloat16)
    def testDotMatrixVector(self, dtype):
      c = self._NewComputation()
      lhs = np.array([[2.0, 3.0], [4.0, 5.0]], dtype=dtype)
      rhs = np.array([[10.0], [20.0]], dtype=dtype)
      ops.Dot(ops.Constant(c, lhs), ops.Constant(c, rhs))
      self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)])

    # TODO(phawkins): np.dot implementation doesn't support bfloat16
    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes if dtype != bfloat16)
    def testDotMatrixMatrix(self, dtype):
      c = self._NewComputation()
      lhs = np.array([[2.0, 3.0], [4.0, 5.0]], dtype=dtype)
      rhs = np.array([[10.0, 20.0], [100.0, 200.0]], dtype=dtype)
      ops.Dot(ops.Constant(c, lhs), ops.Constant(c, rhs))
      self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)])

    def testDotGeneral(self):
      c = self._NewComputation()
      rng = np.random.RandomState(0)
      lhs = NumpyArrayF32(rng.randn(10, 3, 4))
      rhs = NumpyArrayF32(rng.randn(10, 4, 5))
      dimension_numbers = xla_client.make_dot_dimension_numbers(
          (([2], [1]), ([0], [0])))
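      # Contract lhs dimension 2 against rhs dimension 1 and treat dimension 0
      # of both operands as a batch dimension, i.e. a batched matrix multiply.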
      ops.DotGeneral(
          ops.Constant(c, lhs), ops.Constant(c, rhs), dimension_numbers)
      self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6)

    def testDotGeneralWithDotDimensionNumbersProto(self):
      c = self._NewComputation()
      rng = np.random.RandomState(0)
      lhs = NumpyArrayF32(rng.randn(10, 3, 4))
      rhs = NumpyArrayF32(rng.randn(10, 4, 5))

      dimension_numbers = xla_client.DotDimensionNumbers()
      dimension_numbers.lhs_contracting_dimensions.append(2)
      dimension_numbers.rhs_contracting_dimensions.append(1)
      dimension_numbers.lhs_batch_dimensions.append(0)
      dimension_numbers.rhs_batch_dimensions.append(0)

      ops.DotGeneral(
          ops.Constant(c, lhs), ops.Constant(c, rhs), dimension_numbers)
      self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6)

    def testDotGeneralWithPrecisionConfig(self):
      c = self._NewComputation()
      rng = np.random.RandomState(0)
      lhs = NumpyArrayF32(rng.randn(10, 3, 4))
      rhs = NumpyArrayF32(rng.randn(10, 4, 5))
      dimension_numbers = xla_client.make_dot_dimension_numbers(
          (([2], [1]), ([0], [0])))
      config = xla_client.PrecisionConfig()
      config.operand_precision.append(config.Precision.HIGH)
      config.operand_precision.append(config.Precision.HIGHEST)
      ops.DotGeneral(
          ops.Constant(c, lhs),
          ops.Constant(c, rhs),
          dimension_numbers,
          precision_config=config)
      self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6)

    def testConvGeneralDilatedF32(self):
      c = self._NewComputation()
      a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
      lhs = a(1, 1, 2, 3)
      rhs = a(1, 1, 1, 2) * 10
      strides = [1, 1]
      pads = [(1, 0), (0, 1)]
      lhs_dilation = (2, 1)
      rhs_dilation = (1, 1)
      dimension_numbers = xla_client.make_convolution_dimension_numbers(
          ("NCHW", "OIHW", "NCHW"), 2)
      ops.ConvGeneralDilated(
          ops.Constant(c, lhs), ops.Constant(c, rhs), strides, pads,
          lhs_dilation, rhs_dilation, dimension_numbers)
      result = np.array([[[
          [0., 0., 0.],
          [10., 20., 0.],
          [0., 0., 0.],
          [40., 50., 0.],
      ]]])
      self._ExecuteAndCompareClose(c, expected=[result])

    def testConvGeneralDilatedF32WithPrecisionConfig(self):
      c = self._NewComputation()
      a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
      lhs = a(1, 1, 2, 3)
      rhs = a(1, 1, 1, 2) * 10
      strides = [1, 1]
      pads = [(1, 0), (0, 1)]
      lhs_dilation = (2, 1)
      rhs_dilation = (1, 1)
      dimension_numbers = xla_client.make_convolution_dimension_numbers(
          ("NCHW", "OIHW", "NCHW"), 2)
      config = xla_client.PrecisionConfig()
      config.operand_precision.append(config.Precision.HIGHEST)
      config.operand_precision.append(config.Precision.DEFAULT)
      ops.ConvGeneralDilated(
          ops.Constant(c, lhs),
          ops.Constant(c, rhs),
          strides,
          pads,
          lhs_dilation,
          rhs_dilation,
          dimension_numbers,
          precision_config=config)
      result = np.array([[[
          [0., 0., 0.],
          [10., 20., 0.],
          [0., 0., 0.],
          [40., 50., 0.],
      ]]])
      self._ExecuteAndCompareClose(c, expected=[result])

    def testConvGeneralDilatedPermutedF32(self):
      c = self._NewComputation()
      a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
      lhs = a(1, 1, 2, 3)
      rhs = a(1, 1, 1, 2) * 10
      strides = [1, 1]
      pads = [(1, 0), (0, 1)]
      lhs_dilation = (2, 1)
      rhs_dilation = (1, 1)

      dimension_numbers = xla_client.make_convolution_dimension_numbers(
          ("NHWC", "OIHW", "CWNH"), 2)
      ops.ConvGeneralDilated(
          ops.Constant(c, np.transpose(lhs,
                                       (0, 2, 3, 1))), ops.Constant(c, rhs),
          strides, pads, lhs_dilation, rhs_dilation, dimension_numbers)
      result = np.array([[[[0., 0., 0.], [10., 20., 0.], [0., 0., 0.],
                           [40., 50., 0.]]]])
      self._ExecuteAndCompareClose(
          c, expected=[np.transpose(result, (1, 3, 0, 2))])

    def testConvGeneralDilatedGroupedConvolutionF32(self):
      c = self._NewComputation()
      a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
      lhs = a(1, 2, 2, 3)
      rhs = a(2, 1, 1, 2) * 10
      strides = [1, 1]
      pads = [(1, 0), (0, 1)]
      lhs_dilation = (2, 1)
      rhs_dilation = (1, 1)
      dimension_numbers = xla_client.make_convolution_dimension_numbers(
          ("NCHW", "OIHW", "NCHW"), 2)
      feature_group_count = 2
      ops.ConvGeneralDilated(
          ops.Constant(c, lhs), ops.Constant(c, rhs), strides, pads,
          lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count)
      result = np.array([[[
          [0., 0., 0.],
          [10., 20., 0.],
          [0., 0., 0.],
          [40., 50., 0.],
      ], [
          [0., 0., 0.],
          [330., 380., 160.],
          [0., 0., 0.],
          [480., 530., 220.],
      ]]])
      self._ExecuteAndCompareClose(c, expected=[result])

    def testBooleanNot(self):
      c = self._NewComputation()
      arr = NumpyArrayBool([True, False, True])
      ops.Not(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[~arr])

    def testPopulationCount(self):
      c = self._NewComputation()
      arr = NumpyArrayS32([3, 0, 1])
      ops.PopulationCount(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.array([2, 0, 1])])

    def testCountLeadingZeros(self):
      c = self._NewComputation()
      arr = NumpyArrayS32([0x7FFF, 0x12345678])
      ops.Clz(ops.Constant(c, arr))
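      # 0x7FFF occupies 15 bits, so it has 32 - 15 = 17 leading zeros;
      # 0x12345678 occupies 29 bits, leaving 3.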
      self._ExecuteAndCompareClose(c, expected=[[17, 3]])

    def testExp(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Exp(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.exp(arr)])

    def testExpm1(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Expm1(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.expm1(arr)])

    def testRound(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Round(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.round(arr)])

    def testLog(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Log(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.log(arr)])

    def testLog1p(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Log1p(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.log1p(arr)])

    def testNeg(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Neg(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[-arr])

    def testFloor(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Floor(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.floor(arr)])

    def testCeil(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Ceil(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.ceil(arr)])

    def testAbs(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, -12.1, 2.4, -1.])
      ops.Abs(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.abs(arr)])

    def testTanhF32(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([-0.2, 3.3, 12.1, 0.1, 0.0001])
      ops.Tanh(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.tanh(arr)])

    def testTanhF64(self):
      if self.backend.platform == "tpu":
        self.skipTest("TPU doesn't support 64bit tanh")
      c = self._NewComputation()
      arr = NumpyArrayF64([-0.2, 3.3, 12.1, 0.1, 0.0001])
      ops.Tanh(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.tanh(arr)], rtol=1e-12)

    def testTranspose(self):

      def _TransposeAndTest(array, permutation):
        c = self._NewComputation()
        ops.Transpose(ops.Constant(c, array), permutation)
        expected = np.transpose(array, permutation)
        self._ExecuteAndCompareClose(c, expected=[expected])

      _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [0, 1])
      _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [1, 0])
      _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [0, 1])
      _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [1, 0])

      arr = np.random.RandomState(0).randn(2, 3, 4).astype(np.float32)
      for permutation in itertools.permutations(range(arr.ndim)):
        _TransposeAndTest(arr, permutation)
        _TransposeAndTest(np.asfortranarray(arr), permutation)

    def testEq(self):
      c = self._NewComputation()
      ops.Eq(
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4])),
          ops.Constant(c, NumpyArrayS32([4, 2, 3, 1])))
      self._ExecuteAndCompareExact(c, expected=[[False, True, True, False]])

    def testNe(self):
      c = self._NewComputation()
      ops.Ne(
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4])),
          ops.Constant(c, NumpyArrayS32([4, 2, 3, 1])))
      self._ExecuteAndCompareExact(c, expected=[[True, False, False, True]])

      ops.Ne(
          ops.Constant(c, NumpyArrayF32([-2.0, 0.0,
                                         float("nan"),
                                         float("nan")])),
          ops.Constant(c, NumpyArrayF32([2.0, -0.0, 1.0,
                                         float("nan")])))
      self._ExecuteAndAssertWith(
          np.testing.assert_allclose,
          c, (),
          expected=[[True, False, True, True]])

    def testGt(self):
      c = self._NewComputation()
      ops.Gt(
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])),
          ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12])))
      self._ExecuteAndCompareExact(
          c, expected=[[False, True, True, False, False]])

    def testGe(self):
      c = self._NewComputation()
      ops.Ge(
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])),
          ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12])))
      self._ExecuteAndCompareExact(
          c, expected=[[True, True, True, False, False]])

    def testLt(self):
      c = self._NewComputation()
      ops.Lt(
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])),
          ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12])))
      self._ExecuteAndCompareExact(
          c, expected=[[False, False, False, True, True]])

    def testLe(self):
      c = self._NewComputation()
      ops.Le(
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])),
          ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12])))
      self._ExecuteAndCompareExact(
          c, expected=[[True, False, False, True, True]])

    def testMax(self):
      c = self._NewComputation()
      ops.Max(
          ops.Constant(c, NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])),
          ops.Constant(c, NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0])))
      self._ExecuteAndCompareExact(c, expected=[[1.0, 2.0, 3.0, 7.0, 12.0]])

    def testMaxExplicitBroadcastDim0(self):
      c = self._NewComputation()
      ops.Max(
          ops.Constant(c, NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
          ops.Constant(c, NumpyArrayF32([3, 4, 5])),
          broadcast_dimensions=(0,))
      self._ExecuteAndCompareExact(
          c, expected=[[[3, 3, 3], [4, 5, 6], [7, 8, 9]]])

    def testMaxExplicitBroadcastDim1(self):
      c = self._NewComputation()
      ops.Max(
          ops.Constant(c, NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
          ops.Constant(c, NumpyArrayF32([3, 4, 5])),
          broadcast_dimensions=(1,))
      self._ExecuteAndCompareExact(
          c, expected=[[[3, 4, 5], [4, 5, 6], [7, 8, 9]]])

    def testMin(self):
      c = self._NewComputation()
      ops.Min(
          ops.Constant(c, NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])),
          ops.Constant(c, NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0])))
      self._ExecuteAndCompareExact(c, expected=[[1.0, 0.0, 2.0, 4.0, 9.0]])

    def testPad(self):
      c = self._NewComputation()
      ops.Pad(
          ops.Constant(c, NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])),
          ops.Constant(c, NumpyArrayF32(0.0)),
          xla_client.make_padding_config([(1, 2, 1), (0, 1, 0)]))
      self._ExecuteAndCompareClose(
          c,
          expected=[[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0],
                     [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]])

    def testPadWithPaddingConfig(self):
      c = self._NewComputation()
      padding_config = xla_client.PaddingConfig()
      for lo, hi, interior in [(1, 2, 1), (0, 1, 0)]:
        dimension = xla_client.PaddingConfigDimension()
        dimension.edge_padding_low = lo
        dimension.edge_padding_high = hi
        dimension.interior_padding = interior
        padding_config.dimensions.append(dimension)
      ops.Pad(
          ops.Constant(c, NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])),
          ops.Constant(c, NumpyArrayF32(0.0)), padding_config)
      self._ExecuteAndCompareClose(
          c,
          expected=[[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0],
                     [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]])

    def testReshape(self):
      c = self._NewComputation()
      ops.Reshape(
          ops.Constant(c, NumpyArrayS32([[1, 2], [3, 4], [5, 6]])),
          dimensions=[0, 1],
          new_sizes=[2, 3])
      self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3], [4, 5, 6]]])

    def testCollapse(self):
      c = self._NewComputation()
      ops.Collapse(
          ops.Constant(c, NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])),
          dimensions=[1, 2])
      self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3, 4], [5, 6, 7, 8]]])

    def testRev(self):
      c = self._NewComputation()
      ops.Rev(
          ops.Constant(c, NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])),
          dimensions=[0, 2])
      self._ExecuteAndCompareExact(
          c, expected=[[[[6, 5], [8, 7]], [[2, 1], [4, 3]]]])

    def testReducePrecision(self):
      c = self._NewComputation()
      ops.ReducePrecision(
          ops.Constant(c, NumpyArrayF32([float.fromhex("0x1.32fffep-3")])),
          exponent_bits=8,
          mantissa_bits=7)
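      # (exponent_bits=8, mantissa_bits=7) matches bfloat16's widths, so the
      # f32 input is rounded to the nearest value with a 7-bit mantissa.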
1246      self._ExecuteAndCompareClose(c, expected=[[float.fromhex("0x1.32p-3")]])
1247
1248    def testClampF32(self):
1249      c = self._NewComputation()
1250      ops.Clamp(
1251          ops.Constant(c, NumpyArrayF32(-1)),
1252          ops.Constant(c, NumpyArrayF32([-2, -1, 0, 1, 2, 3])),
1253          ops.Constant(c, NumpyArrayF32(2)))
1254      self._ExecuteAndCompareExact(c, expected=[[-1, -1, 0, 1, 2, 2]])
1255
1256    def testClampS32(self):
1257      c = self._NewComputation()
1258      ops.Clamp(
1259          ops.Constant(c, NumpyArrayS32(-1)),
1260          ops.Constant(c, NumpyArrayS32([-2, -1, 0, 1, 2, 3])),
1261          ops.Constant(c, NumpyArrayS32(2)))
1262      self._ExecuteAndCompareExact(c, expected=[[-1, -1, 0, 1, 2, 2]])
1263
1264    def testSelect(self):
1265      c = self._NewComputation()
1266      ops.Select(
1267          ops.Constant(c, NumpyArrayBool([True, False, False, True, False])),
1268          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 5])),
1269          ops.Constant(c, NumpyArrayS32([-1, -2, -3, -4, -5])))
1270      self._ExecuteAndCompareExact(c, expected=[[1, -2, -3, 4, -5]])
1271
1272    def testSlice(self):
1273      c = self._NewComputation()
1274      ops.Slice(
1275          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
1276          [1, 0], [3, 2], [1, 1])
1277      self._ExecuteAndCompareExact(c, expected=[[[4, 5], [7, 8]]])
1278
1279    def testSliceInDim(self):
1280      c = self._NewComputation()
1281      ops.SliceInDim(
1282          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
1283          start_index=1,
1284          limit_index=2,
1285          stride=1,
1286          dimno=1)
1287      self._ExecuteAndCompareExact(c, expected=[[[2], [5], [8]]])
1288      ops.SliceInDim(
1289          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
1290          start_index=0,
1291          limit_index=3,
1292          stride=2,
1293          dimno=0)
1294      self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3], [7, 8, 9]]])
1295
1296    def testDynamicSlice(self):
1297      c = self._NewComputation()
1298      ops.DynamicSlice(
1299          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
1300          [ops.Constant(c, NumpyArrayS32([1, 0]))], [2, 2])
1301      self._ExecuteAndCompareExact(c, expected=[[[4, 5], [7, 8]]])
1302
1303    def testDynamicUpdateSlice(self):
1304      c = self._NewComputation()
1305      ops.DynamicUpdateSlice(
1306          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
1307          ops.Constant(c, NumpyArrayS32([[1, 2], [3, 4]])),
1308          [ops.Constant(c, NumpyArrayS32([1, 1]))])
1309      self._ExecuteAndCompareExact(
1310          c, expected=[[[1, 2, 3], [4, 1, 2], [7, 3, 4]]])
1311
    def testTuple(self):
      c = self._NewComputation()
      ops.Tuple(c, [
          ops.Constant(c, np.int32(42)),
          ops.Constant(c, NumpyArrayF32([1.0, 2.0])),
          ops.Constant(c, NumpyArrayBool([True, False, False, True]))
      ])
      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      self.assertLen(result, 3)
      np.testing.assert_equal(result[0], 42)
      np.testing.assert_allclose(result[1], [1.0, 2.0])
      np.testing.assert_equal(result[2], [True, False, False, True])

    def testGetTupleElement(self):
      c = self._NewComputation()
      ops.GetTupleElement(
          ops.Tuple(c, [
              ops.Constant(c, np.int32(42)),
              ops.Constant(c, NumpyArrayF32([1.0, 2.0])),
              ops.Constant(c, NumpyArrayBool([True, False, False, True]))
          ]), 1)
      self._ExecuteAndCompareClose(c, expected=[[1.0, 2.0]])

    def testBroadcast(self):
      c = self._NewComputation()
      ops.Broadcast(
          ops.Constant(c, NumpyArrayS32([10, 20, 30, 40])), sizes=(3,))
      self._ExecuteAndCompareExact(
          c, expected=[[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]]])

    def testBroadcastInDim(self):
      c = self._NewComputation()
      ops.BroadcastInDim(ops.Constant(c, NumpyArrayS32([1, 2])), [2, 2], [0])
      self._ExecuteAndCompareExact(c, expected=[[[1, 1], [2, 2]]])
      ops.BroadcastInDim(ops.Constant(c, NumpyArrayS32([1, 2])), [2, 2], [1])
      self._ExecuteAndCompareExact(c, expected=[[[1, 2], [1, 2]]])

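    # Editor's note: BroadcastInDim maps operand dimension i to output
    # dimension broadcast_dimensions[i], so [0] above treats [1, 2] as a
    # column and [1] as a row. A sketch of the NumPy analogue (an
    # illustrative test built on the helpers above):
    def testBroadcastInDimMatchesNumpySketch(self):
      x = np.array([1, 2], dtype=np.int32)
      c = self._NewComputation()
      ops.BroadcastInDim(ops.Constant(c, x), [2, 2], [0])
      self._ExecuteAndCompareExact(
          c, expected=[np.broadcast_to(x[:, np.newaxis], (2, 2))])
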
    def testRngNormal(self):
      shape = (2, 3)
      c = self._NewComputation()
      ops.RngNormal(
          ops.Constant(c, NumpyArrayF32(0.)),
          ops.Constant(c, NumpyArrayF32(1.)),
          shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32,
                                             shape))
      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      # Since the result is random, we just check shape and uniqueness.
      self.assertLen(result, 1)
      self.assertEqual(result[0].shape, shape)
      self.assertLen(np.unique(result[0]), np.prod(shape))

    def testRngUniformF32(self):
      lo, hi = 2., 4.
      shape = (2, 3)
      c = self._NewComputation()
      ops.RngUniform(
          ops.Constant(c, NumpyArrayF32(lo)),
          ops.Constant(c, NumpyArrayF32(hi)),
          shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32,
                                             shape))
      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      # Since the result is random, we just check shape, uniqueness, and range.
      self.assertLen(result, 1)
      self.assertEqual(result[0].shape, shape)
      self.assertLen(np.unique(result[0]), np.prod(shape))
      self.assertTrue(np.all(lo <= result[0]))
      self.assertTrue(np.all(result[0] < hi))

    def testRngUniformS32(self):
      lo, hi = 2, 4
      shape = (2, 3)
      c = self._NewComputation()
      ops.RngUniform(
          ops.Constant(c, NumpyArrayS32(lo)),
          ops.Constant(c, NumpyArrayS32(hi)),
          shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.S32,
                                             shape))
      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      # Since the result is random, we just check shape, integrality, and range.
      self.assertLen(result, 1)
      self.assertEqual(result[0].shape, shape)
      self.assertEqual(result[0].dtype, np.int32)
      self.assertTrue(np.all(lo <= result[0]))
      self.assertTrue(np.all(result[0] < hi))

    def testCholesky(self):
      l = np.array([[4, 0, 0, 0], [6, 5, 0, 0], [2, 14, 16, 0], [3, 6, 1, 4]],
                   dtype=np.float32)
      c = self._NewComputation()
      ops.Cholesky(ops.Constant(c, np.tril(np.dot(l, l.T))))
      self._ExecuteAndCompareClose(c, expected=[l], rtol=1e-4)

    def testSort(self):
      keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32)
      c = self._NewComputation()
      ops.Sort(c, [ops.Constant(c, keys)], is_stable=True)
      self._ExecuteAndCompareClose(
          c,
          expected=[np.array([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=np.float32)])

    def testSortKeyVal(self):
      keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32)
      values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32)
      c = self._NewComputation()
      ops.Sort(c, (ops.Constant(c, keys), ops.Constant(c, values)), dimension=0)
      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      self.assertLen(result, 2)
      np.testing.assert_allclose(result[0], [[2, 1, 1, 2], [3, 4, 4, 3]])
      np.testing.assert_equal(result[1], [[0, 5, 2, 7], [4, 1, 6, 3]])

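    # Editor's note: for data with no duplicate keys, a key/value Sort along
    # an axis matches NumPy argsort applied to both arrays. A minimal sketch
    # of that equivalence (an illustrative test reusing the data above):
    def testSortKeyValMatchesArgsortSketch(self):
      keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32)
      values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32)
      order = np.argsort(keys, axis=0, kind="stable")
      c = self._NewComputation()
      ops.Sort(c, (ops.Constant(c, keys), ops.Constant(c, values)), dimension=0)
      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      np.testing.assert_allclose(result[0],
                                 np.take_along_axis(keys, order, axis=0))
      np.testing.assert_equal(result[1],
                              np.take_along_axis(values, order, axis=0))
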
    def testSortCustomComparator(self):
      b = self._NewComputation("comparator")
      p0 = ops.Parameter(b, 0, xla_client.shape_from_pyval(NumpyArrayF32(0)))
      q0 = ops.Parameter(b, 1, xla_client.shape_from_pyval(NumpyArrayF32(0)))
      p1 = ops.Parameter(b, 2, xla_client.shape_from_pyval(NumpyArrayS32(0)))
      q1 = ops.Parameter(b, 3, xla_client.shape_from_pyval(NumpyArrayS32(0)))
      ops.Or(ops.Lt(p0, q0), ops.And(ops.Eq(p0, q0), ops.Gt(p1, q1)))
      comparator = b.build()

      keys = np.array([[2, 3, 1, 3], [3, 1, 2, 2]], dtype=np.float32)
      values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32)
      c = self._NewComputation()
      ops.Sort(
          c, (ops.Constant(c, keys), ops.Constant(c, values)),
          dimension=1,
          comparator=comparator)
      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      self.assertLen(result, 2)
      np.testing.assert_allclose(result[0], [[1, 2, 3, 3], [1, 2, 2, 3]])
      np.testing.assert_equal(result[1], [[2, 0, 3, 1], [5, 7, 6, 4]])

    def testQR(self):
      a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166],
                    [10, 63, 166, 310]],
                   dtype=np.float32)
      c = self._NewComputation()
      ops.Tuple(c, ops.QR(ops.Constant(c, a), full_matrices=True))
      q, r = self._Execute(c, ())
      np.testing.assert_allclose(np.dot(q, r), a, rtol=1e-4)

    def testEigh(self):
      a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166],
                    [10, 63, 166, 310]],
                   dtype=np.float32)
      a = (a + a.T) / 2

      c = self._NewComputation()
      ops.Tuple(c, ops.Eigh(ops.Constant(c, a), lower=True))
      # TODO(b/129396575): Turn this test back on when it passes without
      # fastmath.
      # v, w = self._Execute(c, ())
      # self.assertLess(np.linalg.norm(np.dot(a, v) - w * v), 1e-3)

    def testSVD(self):
      a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166],
                    [10, 63, 166, 310]],
                   dtype=np.float32)
      c = self._NewComputation()
      ops.Tuple(c, ops.SVD(ops.Constant(c, a)))
      u, d, v = self._Execute(c, ())
      self.assertLess(np.linalg.norm(a - np.matmul(u * d, v.T)), 1e-3)

    def testTriangularSolve(self):
      a_vals = np.array(
          [[2, 0, 0, 0], [3, 6, 0, 0], [4, 7, 9, 0], [5, 8, 10, 11]],
          dtype=np.float32)
      b_vals = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
                        dtype=np.float32)

      c = self._NewComputation()
      ops.TriangularSolve(
          ops.Constant(c, a_vals),
          ops.Constant(c, b_vals),
          left_side=False,
          lower=True,
          transpose_a=ops.TriangularSolveOptions_Transpose.TRANSPOSE,
          unit_diagonal=False)
      self._ExecuteAndCompareClose(
          c,
          expected=[
              np.array([
                  [0.5, 0.08333334, 0.04629629, 0.03367003],
                  [2.5, -0.25, -0.1388889, -0.1010101],
                  [4.5, -0.58333331, -0.32407406, -0.23569024],
              ],
                       dtype=np.float32)
          ],
          rtol=1e-4)

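    # Editor's note: with left_side=False and transpose_a=TRANSPOSE the op
    # solves x @ a.T == b for x, so the constants above can be cross-checked
    # against a dense NumPy solve. A minimal sketch of that check (an
    # illustrative test reusing the data above):
    def testTriangularSolveMatchesDenseSolveSketch(self):
      a = np.array([[2, 0, 0, 0], [3, 6, 0, 0], [4, 7, 9, 0], [5, 8, 10, 11]],
                   dtype=np.float32)
      b = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
                   dtype=np.float32)
      # x @ a.T == b  is equivalent to  a @ x.T == b.T.
      expected = np.linalg.solve(a, b.T).T
      c = self._NewComputation()
      ops.TriangularSolve(
          ops.Constant(c, a),
          ops.Constant(c, b),
          left_side=False,
          lower=True,
          transpose_a=ops.TriangularSolveOptions_Transpose.TRANSPOSE,
          unit_diagonal=False)
      self._ExecuteAndCompareClose(c, expected=[expected], rtol=1e-4)
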
    def testIsConstant(self):
      c = self._NewComputation()
      a = ops.Constant(c, np.int32(3))
      b = ops.Constant(c, np.int32(1))
      x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayS32(0)))
      const_expr = ops.Sub(b, a)
      non_const_expr = ops.Mul(const_expr, x)
      self.assertTrue(c.is_constant(const_expr))
      self.assertFalse(c.is_constant(non_const_expr))

    def testGather(self):
      a = np.arange(9).astype(np.int32).reshape((3, 3))
      indices = np.array([[[0, 2], [2, 1]], [[1, 2], [2, 0]]], dtype=np.int32)
      dnums = xla_client.GatherDimensionNumbers()
      dnums.offset_dims.append(1)
      dnums.offset_dims.append(2)
      dnums.start_index_map.append(0)
      dnums.start_index_map.append(1)
      dnums.index_vector_dim = 2
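      # Editor's note: with these dimension numbers, each index vector
      # [r, c] (the trailing axis of `indices`, per index_vector_dim=2)
      # selects the 1x1 window a[r:r+1, c:c+1]; the window axes surface as
      # output dimensions 1 and 2 (offset_dims), so the result below has
      # shape (2, 1, 1, 2).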
      c = self._NewComputation()
      ops.Gather(
          ops.Constant(c, a),
          ops.Constant(c, indices),
          dnums,
          slice_sizes=[1, 1])
      g, = self._Execute(c, ())
      expected = np.array([[[[2, 7]]], [[[5, 6]]]], dtype=np.int32)
      np.testing.assert_allclose(g, expected, rtol=1e-4)

    def testFft(self):
      if self.backend.platform == "tpu":
        self.skipTest("TPU only supports 1D FFT")
      shape = [2, 3, 4, 5]
      rng = np.random.RandomState(0)
      a = rng.randn(*shape) + 1.0j * rng.randn(*shape)
      a = a.astype(np.complex64)
      # FFT
      c = self._NewComputation()
      ops.Fft(ops.Constant(c, a), xla_client.FftType.FFT, shape[-3:])
      self._ExecuteAndCompareClose(
          c, expected=[np.fft.fftn(a, axes=(1, 2, 3))], rtol=1e-4)
      # IFFT
      c = self._NewComputation()
      ops.Fft(ops.Constant(c, a), xla_client.FftType.IFFT, shape[-3:])
      self._ExecuteAndCompareClose(
          c, expected=[np.fft.ifftn(a, axes=(1, 2, 3))], rtol=1e-4)
      # RFFT
      b = rng.randn(*shape).astype(np.float32)
      c = self._NewComputation()
      ops.Fft(ops.Constant(c, b), xla_client.FftType.RFFT, shape[-3:])
      self._ExecuteAndCompareClose(
          c, expected=[np.fft.rfftn(b, axes=(1, 2, 3))], rtol=1e-4)
      # IRFFT
      c = self._NewComputation()
      ops.Fft(ops.Constant(c, a), xla_client.FftType.IRFFT, [3, 4, 8])
      self._ExecuteAndCompareClose(
          c, expected=[np.fft.irfftn(a, axes=(1, 2, 3))], rtol=1e-4)

    def testNextAfter(self):
      c = self._NewComputation()
      ops.NextAfter(
          ops.Constant(c, np.array([1, 2], dtype=np.float32)),
          ops.Constant(c, np.array([2, 1], dtype=np.float32)))
      out, = self._Execute(c, ())
      eps = np.finfo(np.float32).eps
      np.testing.assert_equal(
          np.array([eps + 1, 2 - eps], dtype=np.float32), out)

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testRegularizedIncompleteBeta(self, dtype):
      x = np.array([0.53787335, 0.24015466, 0.47494545, 0.13567594, 0.95114538],
                   dtype=dtype)
      a = np.array([0.00753073, 0.34813385, 0.30485708, 1.29298632, 0.51472606],
                   dtype=dtype)
      b = np.array([0.55688389, 0.59794214, 0.42661022, 1.59748339, 0.95047677],
                   dtype=dtype)
      c = self._NewComputation()
      ops.RegularizedIncompleteBeta(
          ops.Constant(c, a), ops.Constant(c, b), ops.Constant(c, x))
      expected = np.array(
          [0.98923271, 0.48575411, 0.57952568, 0.12579775, 0.96989155])
      self._ExecuteAndCompareClose(c, expected=[expected], rtol=2e-2)

  tests.append(SingleOpTest)

  class EmbeddedComputationsTest(ComputationTest):
    """Tests for XLA graphs with embedded computations (such as maps)."""

    def _CreateConstantComputation(self, in_dtype, out_dtype):
      """Computation (A) -> B that returns a constant 1 for any input."""
      c = self._NewComputation("constant_{}_{}_one".format(
          in_dtype.__name__, out_dtype.__name__))
      ops.Parameter(
          c, 0,
          xla_client.shape_from_pyval(np.array(
              0, dtype=in_dtype)).with_major_to_minor_layout_if_absent())
      ops.Constant(c, out_dtype(1))
      return c.build()

    def _CreateMulBy2Computation(self, dtype):
      """Computation (dtype) -> dtype that multiplies its parameter by 2."""
      c = self._NewComputation("mul_{}_by2".format(dtype.__name__))
      ops.Mul(
          ops.Parameter(
              c, 0,
              xla_client.shape_from_pyval(np.array(
                  0, dtype=dtype)).with_major_to_minor_layout_if_absent()),
          ops.Constant(c, dtype(2.0)))
      return c.build()

    def _CreateMulF32ByParamComputation(self):
      """Computation (f32) -> f32 that multiplies one parameter by the other."""
      c = self._NewComputation("mul_f32_by_param")
      ops.Mul(
          ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0))),
          ops.Parameter(c, 1, xla_client.shape_from_pyval(NumpyArrayF32(0))))
      return c.build()

    def _CreateBinaryAddComputation(self, dtype):
      """Computation (dtype, dtype) -> dtype that adds its two parameters."""
      c = self._NewComputation("add_param0_by_param1")
      shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype))
      shape = shape.with_major_to_minor_layout_if_absent()
      ops.Add(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape))
      return c.build()

    def _CreateBinaryGeComputation(self, dtype):
      """Computation (dtype, dtype) -> bool that tests param0 >= param1."""
      c = self._NewComputation("param0_ge_param1")
      shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype))
      shape = shape.with_major_to_minor_layout_if_absent()
      ops.Ge(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape))
      return c.build()

    def _MakeSample3DArray(self, dtype):
      return np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]],
                       [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]],
                      dtype=dtype)

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testCall(self, dtype):
      c = self._NewComputation()
      ops.Call(
          c,
          self._CreateMulBy2Computation(dtype),
          operands=(ops.Constant(c, dtype(5.0)),))
      self._ExecuteAndCompareClose(c, expected=[10.0])

    @parameterized.named_parameters({
        "testcase_name": "_{}_{}".format(in_dtype.__name__, out_dtype.__name__),
        "in_dtype": in_dtype,
        "out_dtype": out_dtype,
    } for in_dtype, out_dtype in [[np.float32, np.int32]])
    def testMapEachElementToConstant(self, in_dtype, out_dtype):
      c = self._NewComputation()
      ops.Map(c,
              [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=in_dtype))],
              self._CreateConstantComputation(in_dtype, out_dtype), [0])
      self._ExecuteAndCompareExact(c, expected=[[1, 1, 1, 1]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testMapMulBy2(self, dtype):
      if dtype == np.float64 and self.backend.platform == "tpu":
        self.skipTest("TPU doesn't support float64")
      c = self._NewComputation()
      ops.Map(c, [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype))],
              self._CreateMulBy2Computation(dtype), [0])
      self._ExecuteAndCompareClose(c, expected=[[2.0, 4.0, 6.0, 8.0]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testSimpleMapChain(self, dtype):
      if dtype == np.float64 and self.backend.platform == "tpu":
        self.skipTest("TPU doesn't support float64")
      # Chains a map of constant-out with a map of mul-by-2.
      c = self._NewComputation()
      const = ops.Map(
          c, [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype))],
          self._CreateConstantComputation(dtype, dtype), [0])
      ops.Map(c, [const], self._CreateMulBy2Computation(dtype), [0])
      self._ExecuteAndCompareClose(c, expected=[[2.0, 2.0, 2.0, 2.0]])

    # TODO(b/154752816): bfloat16 crashes in evaluator.
    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes if dtype != bfloat16)
    def testDivVectorsWithMap(self, dtype):

      def DivComputation():
        c = self._NewComputation("div_param0_by_param1")
        shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype))
        ops.Div(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape))
        return c.build()

      c = self._NewComputation()
      ops.Map(c, (ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype)),
                  ops.Constant(c, np.array([5.0, 5.0, 4.0, 4.0], dtype=dtype))),
              DivComputation(), [0])
      self._ExecuteAndCompareClose(
          c, expected=[[0.2, 0.4, 0.75, 1.0]], rtol=1e-3)

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testSelectAndScatter(self, dtype):
      if dtype == np.float64 and self.backend.platform == "tpu":
        self.skipTest("TPU doesn't support float64")
      c = self._NewComputation()
      operand = ops.Constant(
          c, np.array([[1., 2., 6.], [4., 5., 3.]], dtype=dtype))
      window_dimensions = (2, 1)
      window_strides = (1, 2)
      padding = xla_client.window_padding_type_to_pad_values(
          xla_client.PaddingType.VALID,
          c.get_shape(operand).dimensions(), window_dimensions, window_strides)
      ops.SelectAndScatterWithGeneralPadding(
          operand,
          select=self._CreateBinaryGeComputation(dtype),
          window_dimensions=window_dimensions,
          window_strides=window_strides,
          padding=padding,
          source=ops.Constant(c, np.array([[0.1, 0.2]], dtype=dtype)),
          init_value=ops.Constant(c, np.array(1, dtype=dtype)),
          scatter=self._CreateBinaryAddComputation(dtype))
      self._ExecuteAndCompareClose(
          c, expected=[[[1., 1., 1.2], [1.1, 1., 1.]]], rtol=5e-3)

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testReduce1DtoScalar(self, dtype):
      c = self._NewComputation()
      ops.Reduce(
          c,
          operands=[
              ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype))
          ],
          init_values=[ops.Constant(c, dtype(0))],
          computation=self._CreateBinaryAddComputation(dtype),
          dimensions_to_reduce=[0])
      self._ExecuteAndCompareClose(c, expected=[10])

    # TODO(phawkins): test comparison harness doesn't support bfloat16
    @parameterized.named_parameters({
        "testcase_name": "_{}_dim{}".format(dtype.__name__, dim),
        "dtype": dtype,
        "dim": dim,
    } for dtype in float_dtypes if dtype != bfloat16 for dim in range(2))
    def testReduce2DTo1D(self, dtype, dim):
      input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype)
      c = self._NewComputation()
      ops.Reduce(
          c,
          operands=[ops.Constant(c, input_array)],
          init_values=[ops.Constant(c, dtype(0))],
          computation=self._CreateBinaryAddComputation(dtype),
          dimensions_to_reduce=[dim])
      self._ExecuteAndCompareClose(c, expected=[np.sum(input_array, axis=dim)])

    @parameterized.named_parameters({
        "testcase_name": "_{}_dims[{}]".format(dtype.__name__, dims),
        "dtype": dtype,
        "dims": tuple(dims)
    } for dtype in float_dtypes for dims in itertools.permutations(range(3)))
    def testReduce3DAllPossibleWaysF32(self, dtype, dims):
      input_array = self._MakeSample3DArray(dtype)
      c = self._NewComputation()
      ops.Reduce(
          c,
          operands=[ops.Constant(c, input_array)],
          init_values=[ops.Constant(c, dtype(0))],
          computation=self._CreateBinaryAddComputation(dtype),
          dimensions_to_reduce=dims)
      self._ExecuteAndCompareClose(c, expected=[np.sum(input_array, axis=dims)])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testReduceWindowValidUnitStrides(self, dtype):
      if dtype == np.float64 and self.backend.platform == "tpu":
        self.skipTest("TPU doesn't support float64")
      input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype)
      c = self._NewComputation()
      window_dimensions = (2, 1)
      window_strides = (1, 1)
      padding = xla_client.window_padding_type_to_pad_values(
          xla_client.PaddingType.VALID, input_array.shape, window_dimensions,
          window_strides)
      ops.ReduceWindowWithGeneralPadding(
          operand=ops.Constant(c, input_array),
          init_value=ops.Constant(c, dtype(0)),
          computation=self._CreateBinaryAddComputation(dtype),
          window_dimensions=window_dimensions,
          window_strides=window_strides,
          base_dilations=[],
          window_dilations=[],
          padding=padding)
      self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.]]])

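    # Editor's note: for VALID padding the output extent per dimension is
    # ceil((input - window + 1) / stride); above that is (2-2+1, 3-1+1) =
    # (1, 3), hence the single-row result. SAME padding below instead pads
    # so the output keeps extent ceil(input / stride).
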
    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testReduceWindowSameUnitStrides(self, dtype):
      if dtype == np.float64 and self.backend.platform == "tpu":
        self.skipTest("TPU doesn't support float64")
      input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype)
      c = self._NewComputation()
      window_dimensions = (2, 1)
      window_strides = (1, 1)
      padding = xla_client.window_padding_type_to_pad_values(
          xla_client.PaddingType.SAME, input_array.shape, window_dimensions,
          window_strides)
      ops.ReduceWindowWithGeneralPadding(
          operand=ops.Constant(c, input_array),
          init_value=ops.Constant(c, dtype(0)),
          computation=self._CreateBinaryAddComputation(dtype),
          window_dimensions=window_dimensions,
          window_strides=window_strides,
          base_dilations=[],
          window_dilations=[],
          padding=padding)
      self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.], [4., 5., 6.]]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testReduceWindowValidGeneralStrides(self, dtype):
      if dtype == np.float64 and self.backend.platform == "tpu":
        self.skipTest("TPU doesn't support float64")
      input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype)
      c = self._NewComputation()
      window_dimensions = (2, 1)
      window_strides = (1, 2)
      padding = xla_client.window_padding_type_to_pad_values(
          xla_client.PaddingType.VALID, input_array.shape, window_dimensions,
          window_strides)
      ops.ReduceWindowWithGeneralPadding(
          operand=ops.Constant(c, input_array),
          init_value=ops.Constant(c, dtype(0)),
          computation=self._CreateBinaryAddComputation(dtype),
          window_dimensions=window_dimensions,
          window_strides=window_strides,
          base_dilations=[],
          window_dilations=[],
          padding=padding)
      self._ExecuteAndCompareClose(c, expected=[[[5., 9.]]])

    def testReduceWindowVariadic(self):
      c = self._NewComputation("reducer")
      shape = xla_client.shape_from_pyval(np.array(0, dtype=np.int32))
      shape = shape.with_major_to_minor_layout_if_absent()
      ps = [ops.Parameter(c, i, shape) for i in range(4)]
      which = ops.Ge(ps[0], ps[2])
      ops.Tuple(
          c, [ops.Select(which, ps[0], ps[2]),
              ops.Select(which, ps[1], ps[3])])
      reducer = c.build()

      key_array = np.array([[1, 5, 6], [4, 2, 3]], dtype=np.int32)
      val_array = np.array([[7, 8, 9], [10, 11, 12]], dtype=np.int32)
      c = self._NewComputation()
      window_dimensions = (2, 1)
      window_strides = (1, 1)
      padding = xla_client.window_padding_type_to_pad_values(
          xla_client.PaddingType.VALID, key_array.shape, window_dimensions,
          window_strides)
      ops.ReduceWindowWithGeneralPadding(
          operands=[ops.Constant(c, key_array),
                    ops.Constant(c, val_array)],
          init_values=[
              ops.Constant(c, np.int32(0)),
              ops.Constant(c, np.int32(0))
          ],
          computation=reducer,
          window_dimensions=window_dimensions,
          window_strides=window_strides,
          base_dilations=[],
          window_dilations=[],
          padding=padding)
      self._ExecuteAndCompareClose(c, expected=[[[4, 5, 6]], [[10, 8, 9]]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testWhile(self, dtype):

      def LessThan10Cond():
        c = self._NewComputation("test_lt_10")
        shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype))
        ops.Lt(ops.Parameter(c, 0, shape), ops.Constant(c, dtype(10.)))
        return c.build()

      cond = LessThan10Cond()
      body = self._CreateMulBy2Computation(dtype)
      c = self._NewComputation()
      init = ops.Constant(c, dtype(1.))
      ops.While(cond, body, init)
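      # The loop doubles 1.0 while the running value stays below 10:
      # 1 -> 2 -> 4 -> 8 -> 16, exiting once the condition sees 16.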
      self._ExecuteAndCompareClose(c, expected=[16.])

    def testConditionalTrue(self):
      c = self._NewComputation()
      pred = ops.Constant(c, np.bool_(True))
      true_operand = ops.Constant(c, np.float32(3.))
      true_computation = self._CreateMulBy2Computation(np.float32)
      false_operand = ops.Constant(c, np.float32(2.))
      false_computation = self._CreateConstantComputation(
          np.float32, np.float32)
      ops.Conditional(pred, true_operand, true_computation, false_operand,
                      false_computation)
      self._ExecuteAndCompareClose(c, expected=[6.])

    def testConditionalFalse(self):
      c = self._NewComputation()
      pred = ops.Constant(c, np.bool_(False))
      true_operand = ops.Constant(c, np.float32(3.))
      true_computation = self._CreateMulBy2Computation(np.float32)
      false_operand = ops.Constant(c, np.float32(2.))
      false_computation = self._CreateConstantComputation(
          np.float32, np.float32)
      ops.Conditional(pred, true_operand, true_computation, false_operand,
                      false_computation)
      self._ExecuteAndCompareClose(c, expected=[1.])

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testInfeedS32Values(self):
      to_infeed = NumpyArrayS32([1, 2, 3, 4])
      c = self._NewComputation()
      ops.GetTupleElement(
          ops.InfeedWithToken(
              ops.CreateToken(c),
              xla_client.shape_from_pyval(
                  to_infeed[0]).with_major_to_minor_layout_if_absent()), 0)
      compiled_c = self.backend.compile(c.build())
      device = self.backend.local_devices()[0]
      for item in to_infeed:
        device.transfer_to_infeed(item)

      for item in to_infeed:
        result, = xla_client.execute_with_python_values(
            compiled_c, (), backend=self.backend)
        self.assertEqual(result, item)

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testInfeedTuple(self):
      to_infeed = (NumpyArrayS32([1, 2, 3, 4]), NumpyArrayS32([[7], [8]]))
      c = self._NewComputation()
      ops.GetTupleElement(
          ops.InfeedWithToken(
              ops.CreateToken(c),
              xla_client.shape_from_pyval(
                  to_infeed).with_major_to_minor_layout_if_absent()), 0)
      compiled_c = self.backend.compile(c.build())
      device = self.backend.local_devices()[0]
      device.transfer_to_infeed(to_infeed)

      result = xla_client.execute_with_python_values(
          compiled_c, (), backend=self.backend)
      self.assertLen(result, 2)
      np.testing.assert_equal(result[0], to_infeed[0])
      np.testing.assert_equal(result[1], to_infeed[1])

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testInfeedThenOutfeedS32(self):
      to_round_trip = NumpyArrayS32([1, 2, 3, 4])
      c = self._NewComputation()
      x_and_token = ops.InfeedWithToken(
          ops.CreateToken(c),
          xla_client.shape_from_pyval(
              to_round_trip[0]).with_major_to_minor_layout_if_absent())
      x = ops.GetTupleElement(x_and_token, 0)
      token = ops.GetTupleElement(x_and_token, 1)
      outfeed_shape = xla_client.shape_from_pyval(
          to_round_trip[0]).with_major_to_minor_layout_if_absent()
      ops.OutfeedWithToken(x, token, outfeed_shape)

      compiled_c = self.backend.compile(c.build())
      device = self.backend.local_devices()[0]

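      # Run the computation on a separate thread: the infeed and outfeed
      # transfers below block until the executing program consumes and
      # produces the corresponding values.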
      for want in to_round_trip:
        execution = threading.Thread(target=lambda: compiled_c.execute([]))
        execution.start()
        device.transfer_to_infeed(want)
        got = device.transfer_from_outfeed(outfeed_shape)
        execution.join()
        self.assertEqual(want, got)

    def testScatter(self):
      a = np.arange(9).astype(np.int32).reshape((3, 3))
      scatter_indices = np.array([0, 2], dtype=np.int32)
      updates = np.array([[10, 20, 30], [70, 80, 90]], dtype=np.int32)

      dnums = xla_client.ScatterDimensionNumbers()
      dnums.update_window_dims.append(1)
      dnums.inserted_window_dims.append(0)
      dnums.scatter_dims_to_operand_dims.append(0)
      dnums.index_vector_dim = 1
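      # Editor's note: each scatter index selects an entire row of `a`
      # (dimension 0 is the inserted window dimension), so rows 0 and 2
      # receive `updates` combined elementwise via the add computation below.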

      c = self._NewComputation()
      ops.Scatter(
          ops.Constant(c, a), ops.Constant(c, scatter_indices),
          ops.Constant(c, updates), self._CreateBinaryAddComputation(np.int32),
          dnums)
      expected = np.array([[10, 21, 32], [3, 4, 5], [76, 87, 98]],
                          dtype=np.int32)
      self._ExecuteAndCompareClose(c, expected=[expected])

  class DeviceTest(ComputationTest):

    def testPlatform(self):
      for device in self.backend.local_devices():
        self.assertEqual(device.platform, self.backend.platform)

  tests.append(DeviceTest)

  class ErrorTest(ComputationTest):

    def setUp(self):
      super(ErrorTest, self).setUp()
      self.f32_scalar_2 = NumpyArrayF32(2.0)
      self.s32_scalar_2 = NumpyArrayS32(2)

    def testCompileWithWrongElementTypeInLayout(self):
      c = self._NewComputation()
      c.set_op_metadata(xla_client.CurrentSourceInfoMetadata())
      ops.Parameter(c, 0, xla_client.shape_from_pyval(self.s32_scalar_2))
      c.clear_op_metadata()

      options = xla_client.CompileOptions()
      options.argument_layouts = [
          xla_client.Shape.array_shape(np.dtype(np.float32), [])
      ]

      def TestFun():
        return self.backend.compile(c.build(), compile_options=options)

      self.assertRaisesRegex(
          RuntimeError, r".*Invalid argument shape.*"
          r"expected s32\[\], got f32\[\].*", TestFun)

    def testInvokeWithWrongElementType(self):
      c = self._NewComputation()
      c.set_op_metadata(xla_client.CurrentSourceInfoMetadata())
      ops.Parameter(c, 0, xla_client.shape_from_pyval(self.s32_scalar_2))
      c.clear_op_metadata()

      def TestFun():
        return xla_client.execute_with_python_values(
            self.backend.compile(c.build()), [self.f32_scalar_2], self.backend)

      self.assertRaisesRegex(
          RuntimeError, r"Invalid argument: Argument does not match.*"
          r"want s32\[\], got f32\[\].*", TestFun)

  tests.append(ErrorTest)
  tests.append(EmbeddedComputationsTest)

  class ComputationRootTest(ComputationTest):
    """Tests related to setting the root of the computation."""

    def testComputationRootDifferentFromLastOp(self):
      c = self._NewComputation()
      x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(2.0)))
      result = ops.Add(x, ops.Constant(c, np.float32(3.14)))
      ops.Add(result, ops.Constant(c, np.float32(1.618)))

      arg = NumpyArrayF32(1.0)
      compiled_c = self.backend.compile(c.build(result))
      ans, = xla_client.execute_with_python_values(
          compiled_c, [arg], backend=self.backend)
      np.testing.assert_allclose(ans, 4.14)

  tests.append(ComputationRootTest)

  class SetShardingTest(ComputationTest):
    """Tests related to setting OpSharding."""

    def testSetSharding(self):
      c = self._NewComputation()
      sharding = xla_client.OpSharding()
      sharding.type = sharding.type.REPLICATED
      sharding.tile_assignment_dimensions.extend([1])
      sharding.tile_assignment_devices.extend([0])
      c.set_sharding(sharding)
      x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(2.0)))
      c.clear_sharding()

      result = ops.Add(x, ops.Constant(c, np.float32(3.14)))
      ops.Add(result, ops.Constant(c, np.float32(1.618)))
      arg = NumpyArrayF32(1.0)
      compiled_c = self.backend.compile(c.build(result))
      ans, = xla_client.execute_with_python_values(
          compiled_c, [arg], backend=self.backend)
      np.testing.assert_allclose(ans, 4.14)

  tests.append(SetShardingTest)

  testcase_shapes = [
      (),
      (1,),
      (2, 3),
      (2, 0),
      (0, 7),
      (4, 1, 2),
      (2, 1, 3),
      (2, 4, 1),
      (3, 1),
      (1, 3),
  ]

  def FormatShapeAndDtype(shape, dtype):
    return "_{}[{}]".format(np.dtype(dtype).name, ",".join(map(str, shape)))

  class DLPackTest(parameterized.TestCase):

    def setUp(self):
      super(DLPackTest, self).setUp()
      self.backend = xla_backend()
      if self.backend.platform not in ("cpu", "gpu"):
        self.skipTest("DLPack requires CPU or GPU")
      self.cpu_backend = (
          self.backend
          if self.backend.platform == "cpu" else xla_client.make_cpu_client())
      self.gpu_backend = (
          self.backend if self.backend.platform == "gpu" else None)

    # pylint: disable=g-complex-comprehension
    # pyformat: disable
    @parameterized.named_parameters({
        "testcase_name": "{}_own={}_gpu={}".format(
            FormatShapeAndDtype(shape, dtype), take_ownership, gpu),
        "dtype": dtype,
        "shape": shape,
        "take_ownership": take_ownership,
        "gpu": gpu
    } for dtype in dlpack_dtypes for shape in testcase_shapes
                                    for take_ownership in [False, True]
                                    for gpu in [False, True])
    # pyformat: enable
    def testRoundTrip(self, dtype, shape, take_ownership, gpu):
      if gpu and self.gpu_backend is None:
        raise unittest.SkipTest("Test not running with GPU support")
      backend = self.gpu_backend if gpu else self.cpu_backend
      if dtype == np.bool_:
        x = np.random.randint(0, 2, size=shape).astype(np.bool_)
      else:
        x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
      buffer = backend.buffer_from_pyval(x)
      dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(
          buffer, take_ownership=take_ownership)
      del buffer  # Free "buffer" to make sure dlt retains ownership.
      self.assertEqual(type(dlt).__name__, "PyCapsule")
      y = xla_client._xla.dlpack_managed_tensor_to_buffer(
          dlt, self.cpu_backend, self.gpu_backend)
      np.testing.assert_array_equal(
          x.astype(np.uint8) if dtype == np.bool_ else x, y.to_py())

    def testTensorsCanBeConsumedOnceOnly(self):
      x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32)
      buffer = self.backend.buffer_from_pyval(x)
      dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(
          buffer, take_ownership=True)

      def ConsumeDLPackTensor():
        _ = xla_client._xla.dlpack_managed_tensor_to_buffer(dlt, self.backend)

      ConsumeDLPackTensor()
      self.assertRaisesRegex(
          RuntimeError, ".*a DLPack tensor may be consumed at most once.*",
          ConsumeDLPackTensor)

    def testTensorsCanBeOwnedOnceOnly(self):
      x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32)
      buffer = self.backend.buffer_from_pyval(x)
      _ = xla_client._xla.buffer_to_dlpack_managed_tensor(
          buffer, take_ownership=True)
      self.assertTrue(buffer.is_deleted())
      with self.assertRaisesRegex(
          RuntimeError,
          "Cannot convert deleted/invalid buffer to DLPack tensor.*"):
        _ = xla_client._xla.buffer_to_dlpack_managed_tensor(
            buffer, take_ownership=True)

    def testNonOwnedDlpackCanBeViewedTwice(self):
      x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32)
      buffer = self.backend.buffer_from_pyval(x)
      d1 = xla_client._xla.buffer_to_dlpack_managed_tensor(
          buffer, take_ownership=False)
      d2 = xla_client._xla.buffer_to_dlpack_managed_tensor(
          buffer, take_ownership=False)

      y = xla_client._xla.dlpack_managed_tensor_to_buffer(d1, self.backend)
      z = xla_client._xla.dlpack_managed_tensor_to_buffer(d2, self.backend)
      del d1, d2
      np.testing.assert_array_equal(x, buffer.to_py())
      np.testing.assert_array_equal(x, y.to_py())
      np.testing.assert_array_equal(x, z.to_py())

  tests.append(DLPackTest)

  class BufferProtocolTest(parameterized.TestCase):

    def setUp(self):
      super(BufferProtocolTest, self).setUp()
      self.backend = xla_backend()
      if self.backend.platform != "cpu":
        self.skipTest("Test requires CPU")

    # pylint: disable=g-complex-comprehension
    @parameterized.named_parameters({
        "testcase_name": FormatShapeAndDtype(shape, dtype),
        "dtype": dtype,
        "shape": shape
    } for dtype in standard_dtypes if dtype != bfloat16
                                    for shape in testcase_shapes)
    def testRoundTrip(self, dtype, shape):
      x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
      x_ptr = x.__array_interface__["data"][0]
      buffer = self.backend.buffer_from_pyval(
          x, host_buffer_semantics=xla_client.HostBufferSemantics.ZERO_COPY)
      y = np.array(buffer, copy=False)
      y_ptr = y.__array_interface__["data"][0]
      np.testing.assert_array_equal(x, y)
      # If the input was sufficiently aligned, the input and output should
      # alias.
      self.assertTrue((x_ptr & 15) != 0 or x_ptr == y_ptr)
      self.assertEqual(y_ptr, buffer.unsafe_buffer_pointer())

      during_call = xla_client.HostBufferSemantics.IMMUTABLE_ONLY_DURING_CALL
      buffer2 = self.backend.buffer_from_pyval(
          x, host_buffer_semantics=during_call)
      z = np.array(buffer2, copy=False)
      self.assertNotEqual(x.__array_interface__["data"][0],
                          z.__array_interface__["data"][0])

    def testDeleteWithActiveView(self):
      x = np.random.randn(20, 10)
      buffer = self.backend.buffer_from_pyval(x)
      buffer_ptr = buffer.unsafe_buffer_pointer()
      y = np.array(buffer, copy=False)
      buffer.delete()
      # It is still legal to access `y`; the array view must keep it alive.
      np.testing.assert_array_equal(x, y)
      self.assertEqual(y.__array_interface__["data"][0], buffer_ptr)

  tests.append(BufferProtocolTest)

  class TracebackTest(absltest.TestCase):

    def setUp(self):
      super(TracebackTest, self).setUp()
      self.backend = xla_backend()

    def testNoTracebacksIfDisabled(self):
      with xla_client.tracebacks(enabled=False):
        self.assertEqual(None, xla_client.Traceback.get_traceback())
        buffer = self.backend.buffer_from_pyval(np.array(7, np.int32))
        self.assertEqual(None, buffer.traceback)

        b = xla_client.XlaBuilder("computation")
        ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2)))
        e = self.backend.compile(b.build())
        self.assertEqual(None, e.traceback)

    def assertIsTracebackContaining(self, tb, function):
      self.assertIsInstance(tb, xla_client.Traceback)
      self.assertIn(function, str(tb))
      self.assertTrue(any(f.function_name == function for f in tb.frames))

    def testTracebacks(self):
      with xla_client.tracebacks(enabled=True):
        tb = xla_client.Traceback.get_traceback()
        self.assertIsTracebackContaining(tb, "testTracebacks")

        # Tracebacks are not implemented on the TPU driver extension's variant
        # of buffers and executables.
        if not isinstance(self.backend, xla_client.Client):
          return

        buffer = self.backend.buffer_from_pyval(np.array(7, np.int32))
        self.assertIsTracebackContaining(buffer.traceback, "testTracebacks")

        b = xla_client.XlaBuilder("computation")
        ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2)))
        e = self.backend.compile(b.build())
        self.assertIsTracebackContaining(e.traceback, "testTracebacks")

    def testNestedFunction(self):

      def AFunction():

        def AnotherFunction():
          return xla_client.Traceback.get_traceback()

        return AnotherFunction()

      with xla_client.tracebacks(enabled=True):
        tb = AFunction()
        self.assertIsInstance(tb, xla_client.Traceback)
        frames = tb.frames
        i = next(
            i for (i, f) in enumerate(frames) if f.function_name == "AFunction")
        self.assertEqual(frames[i - 1].function_name, "AnotherFunction")
        self.assertEqual(frames[i + 1].function_name, "testNestedFunction")

  tests.append(TracebackTest)

  class ClientTest(ComputationTest):

    def setUp(self):
      super(ClientTest, self).setUp()
      self.backend = xla_backend()

    def testPlatformVersion(self):
      version = self.backend.platform_version
      if self.backend.platform == "cpu":
        self.assertEqual(version, "<unknown>")
      elif self.backend.platform == "gpu":
        # The following check is False unless built with --config=cuda.
        if test_util.is_gpu_available(cuda_only=True):
          self.assertTrue(
              re.match(r"^cuda \d{4,}$", version),
              msg=f"Expected CUDA version string; got {repr(version)}")
        else:
          self.assertEqual(version, "<unknown>")

    @unittest.skipIf(cloud_tpu or tfrt_tpu, "not implemented")
    def testExecutableSerialization(self):
      if self.backend.platform != "tpu":
        self.skipTest("Test requires tpu platform")

      c = self._NewComputation()
      ops.Add(
          ops.Constant(c, NumpyArrayS32([1, 2])),
          ops.Constant(c, NumpyArrayS32([3, 4])))

      options = xla_client.CompileOptions()
      executable = self.backend.compile(c.build(), options)
      self.assertLen(executable.hlo_modules(), 1)

      serialized = self.backend.serialize_executable(executable)
      deserialized = self.backend.deserialize_executable(
          serialized,
          executable.hlo_modules()[0], options)

      expected, = xla_client.execute_with_python_values(executable, (),
                                                        self.backend)
      actual, = xla_client.execute_with_python_values(deserialized, (),
                                                      self.backend)
      self.assertTrue(np.all(actual == expected))

  tests.append(ClientTest)

  # TODO(b/182461453): Add TFRT and cloud TPU implementation of
  # ReadDynamicShapes
  class DynamicReshapeTest(ComputationTest):
    """Tests related to DynamicReshape."""

    def _CompareToPyAndBufferProtocol(self, builder, args, expected_results,
                                      test_fn):
      compiled = self.backend.compile(builder.build())
      output_buffers = compiled.execute([
          self.backend.buffer_from_pyval(
              arg, device=compiled.local_devices()[0]) for arg in args
      ])
      self.assertLen(output_buffers, len(expected_results))
      for buf, expected in zip(output_buffers, expected_results):
        to_py_result = buf.to_py()
        self.assertEqual(expected.shape, to_py_result.shape)
        test_fn(expected, to_py_result)
        if self.backend.platform == "cpu" and buf.dtype != bfloat16:
          mview = memoryview(buf)
          self.assertEqual(expected.shape, mview.shape)
          test_fn(expected, np.asarray(mview))
        else:
          # The buffer protocol is expected to fail on non-CPU platforms and
          # for bfloat16. Note that np.asarray(buf) doesn't throw an
          # exception; to check that the error is raised properly we must use
          # memoryview(buf).
          with self.assertRaises(BufferError):
            memoryview(buf)

    # 1D reshape of full size, half size, and size of 0.
    @unittest.skipIf(cloud_tpu or tfrt_tpu or external_tpu, "not implemented")
    @parameterized.parameters((5), (3), (0))
    def testReshape1D(self, reshape_size):
      full_size = 5
      c = self._NewComputation()
      arg = np.array(reshape_size, dtype=np.int32)
      expected = np.array(range(reshape_size), dtype=np.int32)
      p = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg))
      ops.DynamicReshape(
          ops.Constant(c, NumpyArrayS32(range(full_size))), [p], [full_size],
          [True])
      self._CompareToPyAndBufferProtocol(c, [arg], [expected],
                                         np.testing.assert_equal)

    # 2D reshape with a slice on the minor dimension. We test different types
    # where the strides may differ between the host and the device. The
    # reshaped physical memory layout is not contiguous, and we test whether
    # the program can return the correct logical view of the data.
    @unittest.skipIf(cloud_tpu or tfrt_tpu or external_tpu, "not implemented")
    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in int_dtypes + float_dtypes)
    def testReshape2D(self, dtype):
      arg0 = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)
      arg1 = np.array(2, dtype=np.int32)
      expected = np.array([[1, 2], [4, 5]], dtype=np.int32)
      c = self._NewComputation()
      p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0))
      p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1))
      ops.DynamicReshape(p0, [p1, p1], [2, 3], [False, True])
      self._CompareToPyAndBufferProtocol(c, [arg0, arg1], [expected],
                                         np.testing.assert_equal)

    @unittest.skipIf(cloud_tpu or tfrt_tpu, "not implemented")
    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in int_dtypes + float_dtypes)
    def testDynamicShapeArgs(self, dtype):
      full_size = 10
      dynamic_shape_size = 4
      # subcomputation 1
      binary_add_builder = self._NewComputation()
      scalar_shape = xla_client.Shape.scalar_shape(np.dtype(dtype))
      ops.Add(
          ops.Parameter(binary_add_builder, 0, scalar_shape),
          ops.Parameter(binary_add_builder, 1, scalar_shape))
      # subcomputation 2
      reshape_reduce_builder = self._NewComputation()
      dshape = xla_client.Shape.array_shape(
          np.dtype(dtype), dims=[full_size], dynamic_dimensions=[True])
      reshape_reduce_p = ops.Parameter(reshape_reduce_builder, 0, dshape)
      ops.Reduce(
          reshape_reduce_builder,
          operands=[reshape_reduce_p],
          init_values=[ops.Constant(reshape_reduce_builder, dtype(0))],
          computation=binary_add_builder.build(),
          dimensions_to_reduce=[0])
      # main computation: sum(range(full_size)[:dynamic_shape_size])
      c = self._NewComputation()
      arg = np.array(dynamic_shape_size, dtype=np.int32)
      p = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg))
      reshaped = ops.DynamicReshape(
          ops.Constant(c, np.array(range(full_size), dtype=dtype)), [p],
          [full_size], [True])
      ops.Call(c, reshape_reduce_builder.build(), operands=(reshaped,))
      self._ExecuteAndCompareClose(c, [arg], [dtype(6)])

  tests.append(DynamicReshapeTest)

  class DeviceAssignmentTest(ComputationTest):

    def testSerialize(self):
      shape = (3, 4)
      device_assignment = xla_client.DeviceAssignment.create(
          np.arange(np.prod(shape)).reshape(*shape))
      self.assertEqual(device_assignment.replica_count(), shape[0])
      self.assertEqual(device_assignment.computation_count(), shape[1])
      serialized = device_assignment.serialize()
      self.assertIsInstance(serialized, bytes)
      self.assertNotEmpty(serialized)

  tests.append(DeviceAssignmentTest)

  return tests


def InstantiateTests(globals_dict, backend_fn, test_prefix="", **kw):
  # Avoid creating a new backend per test (this causes GPU OOM, and is probably
  # inefficient).
  backend_fn = functools.lru_cache(maxsize=None)(backend_fn)
  for klass in TestFactory(backend_fn, **kw):
    test = type(test_prefix + klass.__name__, (klass,), {})
    # Clean up the qualified names of the tests to not include the test factory.
    test.__qualname__ = test.__name__
    globals_dict[test.__name__] = test


backends = {
    "cpu": xla_client.make_cpu_client,
    "gpu": xla_client.make_gpu_client,
}

if __name__ == "__main__":
  flags.DEFINE_string("backend", "cpu", "Target platform.")
  # pylint: disable=unnecessary-lambda
  InstantiateTests(globals(), lambda: backends[FLAGS.backend]())
  # pylint: enable=unnecessary-lambda
  absltest.main()