• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Lint as: python3
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Tests for the Python extension-based XLA client."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import functools
23import itertools
24import threading
25
26from absl.testing import absltest
27from absl.testing import parameterized
28import numpy as np
29
30from tensorflow.compiler.xla.python import custom_call_for_test
31from tensorflow.compiler.xla.python import xla_client
32
# Module-level shorthand for xla_client.bfloat16; tests in this file call it
# to build bfloat16 scalar values.
bfloat16 = xla_client.bfloat16
34
35
class ComputationTest(absltest.TestCase):
  """Base class for running an XLA Computation through the local client.

  Subclasses build a computation with `_NewComputation()` and then check its
  result with `_ExecuteAndCompareExact` or `_ExecuteAndCompareClose`.
  Subclasses may override `_Execute` to change how the computation is run.
  """

  def _NewComputation(self, name=None):
    """Returns a new ComputationBuilder.

    Args:
      name: Name for the builder. Defaults to the id of the running test.
    """
    if name is None:
      name = self.id()
    return xla_client.ComputationBuilder(name)

  def _Execute(self, c, arguments):
    """Builds and compiles `c`, runs it on `arguments`, returns the result."""
    compiled_c = c.Build().Compile()
    return xla_client.execute_with_python_values(compiled_c, arguments)

  def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected):
    """Runs `c` and checks its result against `expected`.

    Args:
      assert_func: Callable invoked as `assert_func(result, expected)`; it is
        expected to raise when the two values do not match.
      c: The computation builder to build, compile and execute.
      arguments: Argument values forwarded to `_Execute`.
      expected: The expected result. Must not be None.
    """
    # Use a unittest assertion instead of a bare `assert` so that the check is
    # not stripped when Python runs with -O.
    self.assertIsNotNone(expected)
    result = self._Execute(c, arguments)
    # Numpy's comparison methods are a bit too lenient by treating inputs as
    # "array-like", meaning that scalar 4 will be happily compared equal to
    # [[4]]. We'd like to be more strict so assert shapes as well.
    self.assertEqual(np.asanyarray(result).shape, np.asanyarray(expected).shape)
    assert_func(result, expected)

  def _ExecuteAndCompareExact(self, c, arguments=(), expected=None):
    """Runs `c` and compares the result to `expected` with assert_equal."""
    self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, expected)

  def _ExecuteAndCompareClose(self,
                              c,
                              arguments=(),
                              expected=None,
                              rtol=1e-7,
                              atol=0):
    """Runs `c` and compares the result to `expected` with assert_allclose.

    Args:
      c: The computation builder to build, compile and execute.
      arguments: Argument values forwarded to `_Execute`.
      expected: The expected result. Must not be None.
      rtol: Relative tolerance forwarded to np.testing.assert_allclose.
      atol: Absolute tolerance forwarded to np.testing.assert_allclose.
    """
    self._ExecuteAndAssertWith(
        functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol), c,
        arguments, expected)
69
70
def NumpyArrayF32(*args, **kwargs):
  """Forwards all arguments to np.array with the dtype fixed to np.float32."""
  f32_array = np.array(*args, dtype=np.float32, **kwargs)
  return f32_array
74
75
def NumpyArrayF64(*args, **kwargs):
  """Forwards all arguments to np.array with the dtype fixed to np.float64."""
  f64_array = np.array(*args, dtype=np.float64, **kwargs)
  return f64_array
79
80
def NumpyArrayS32(*args, **kwargs):
  """Forwards all arguments to np.array with the dtype fixed to np.int32."""
  s32_array = np.array(*args, dtype=np.int32, **kwargs)
  return s32_array
84
85
def NumpyArrayS64(*args, **kwargs):
  """Forwards all arguments to np.array with the dtype fixed to np.int64."""
  s64_array = np.array(*args, dtype=np.int64, **kwargs)
  return s64_array
89
90
def NumpyArrayBool(*args, **kwargs):
  """Convenience wrapper to create Numpy arrays with a boolean dtype.

  All positional and keyword arguments are forwarded to np.array.
  """
  # np.bool_ is NumPy's own boolean scalar type. The bare `np.bool` alias was
  # deprecated in NumPy 1.20 and removed in 1.24, so it is avoided here.
  return np.array(*args, dtype=np.bool_, **kwargs)
94
95
class ComputationPrinting(absltest.TestCase):
  """Checks the HLO text and dot-graph renderings of a built computation."""

  def ExampleComputation(self):
    """Builds "acomputation": an f32 scalar parameter times a 4-element one."""
    b = xla_client.ComputationBuilder("acomputation")
    scalar_param = b.ParameterFromNumpy(np.float32(0))
    vector_param = b.ParameterFromNumpy(np.zeros((4,), np.float32))
    b.Mul(scalar_param, vector_param)
    return b.Build()

  def testComputationToHloText(self):
    # The HLO text must begin with the module header naming the computation.
    text = self.ExampleComputation().GetHloText()
    has_module_header = text.startswith("HloModule acomputation")
    self.assertTrue(has_module_header)

  def testComputationToHloGraph(self):
    # The dot rendering must begin with a "digraph " declaration.
    dot_source = self.ExampleComputation().GetHloDotGraph()
    is_digraph = dot_source.startswith("digraph ")
    self.assertTrue(is_digraph)
114
115
class ComputationHashTest(absltest.TestCase):
  """Checks that the computation hash does not depend on the builder name."""

  def testHash(self):
    # Two builders that differ only in name are given identical contents; the
    # resulting computations must hash to the same value.
    def build_scalar_times_vector(name):
      builder = xla_client.ComputationBuilder(name)
      scalar = builder.ParameterFromNumpy(np.float32(0))
      vector = builder.ParameterFromNumpy(np.zeros((4,), np.float32))
      builder.Mul(scalar, vector)
      return builder.Build()

    first = build_scalar_times_vector("computation0")
    second = build_scalar_times_vector("computation1")

    self.assertEqual(first.Hash(), second.Hash())
132
133
class ComputationsWithConstantsTest(ComputationTest):
  """Exercises ops whose operands are all constants baked into the computation."""

  def testConstantScalarSumS8(self):
    builder = self._NewComputation()
    lhs = builder.Constant(np.int8(1))
    rhs = builder.Constant(np.int8(2))
    builder.Add(lhs, rhs)
    self._ExecuteAndCompareExact(builder, expected=np.int8(3))

  def testConstantScalarSumBF16(self):
    builder = self._NewComputation()
    lhs = builder.Constant(bfloat16(1.11))
    rhs = builder.Constant(bfloat16(3.14))
    builder.Add(lhs, rhs)
    self._ExecuteAndCompareClose(builder, expected=bfloat16(4.25))

  def testConstantScalarSumF32(self):
    builder = self._NewComputation()
    lhs = builder.ConstantF32Scalar(1.11)
    rhs = builder.ConstantF32Scalar(3.14)
    builder.Add(lhs, rhs)
    self._ExecuteAndCompareClose(builder, expected=4.25)

  def testConstantScalarSumF64(self):
    builder = self._NewComputation()
    lhs = builder.ConstantF64Scalar(1.11)
    rhs = builder.ConstantF64Scalar(3.14)
    builder.Add(lhs, rhs)
    self._ExecuteAndCompareClose(builder, expected=4.25)

  def testConstantScalarSumS32(self):
    builder = self._NewComputation()
    lhs = builder.ConstantS32Scalar(1)
    rhs = builder.ConstantS32Scalar(2)
    builder.Add(lhs, rhs)
    self._ExecuteAndCompareClose(builder, expected=3)

  def testConstantScalarSumS64(self):
    builder = self._NewComputation()
    lhs = builder.ConstantS64Scalar(1)
    rhs = builder.ConstantS64Scalar(2)
    builder.Add(lhs, rhs)
    self._ExecuteAndCompareClose(builder, expected=3)

  def testConstantVectorMulF16(self):
    builder = self._NewComputation()
    lhs = builder.Constant(np.array([2.5, 3.3, -1.2, 0.7], np.float16))
    rhs = builder.Constant(np.array([-1.2, 2, -2, -3], np.float16))
    builder.Mul(lhs, rhs)
    # float16 products are only compared to a loose relative tolerance.
    product = np.array([-3, 6.6, 2.4, -2.1], np.float16)
    self._ExecuteAndCompareClose(builder, expected=product, rtol=2e-3)

  def testConstantVectorMulF32(self):
    builder = self._NewComputation()
    lhs = builder.Constant(NumpyArrayF32([2.5, 3.3, -1.2, 0.7]))
    rhs = builder.Constant(NumpyArrayF32([-1.2, 2, -2, -3]))
    builder.Mul(lhs, rhs)
    self._ExecuteAndCompareClose(builder, expected=[-3, 6.6, 2.4, -2.1])

  def testConstantVectorMulF64(self):
    builder = self._NewComputation()
    lhs = builder.Constant(NumpyArrayF64([2.5, 3.3, -1.2, 0.7]))
    rhs = builder.Constant(NumpyArrayF64([-1.2, 2, -2, -3]))
    builder.Mul(lhs, rhs)
    self._ExecuteAndCompareClose(builder, expected=[-3, 6.6, 2.4, -2.1])

  def testConstantVectorScalarDivF32(self):
    builder = self._NewComputation()
    numerator = builder.Constant(NumpyArrayF32([1.5, 2.5, 3.0, -10.8]))
    denominator = builder.ConstantF32Scalar(2.0)
    builder.Div(numerator, denominator)
    self._ExecuteAndCompareClose(builder, expected=[0.75, 1.25, 1.5, -5.4])

  def testConstantVectorScalarDivF64(self):
    builder = self._NewComputation()
    numerator = builder.Constant(NumpyArrayF64([1.5, 2.5, 3.0, -10.8]))
    denominator = builder.ConstantF64Scalar(2.0)
    builder.Div(numerator, denominator)
    self._ExecuteAndCompareClose(builder, expected=[0.75, 1.25, 1.5, -5.4])

  def testConstantVectorScalarPowF32(self):
    builder = self._NewComputation()
    base = builder.Constant(NumpyArrayF32([1.5, 2.5, 3.0]))
    exponent = builder.ConstantF32Scalar(2.)
    builder.Pow(base, exponent)
    self._ExecuteAndCompareClose(builder, expected=[2.25, 6.25, 9.])

  def testConstantVectorScalarPowF64(self):
    builder = self._NewComputation()
    base = builder.Constant(NumpyArrayF64([1.5, 2.5, 3.0]))
    exponent = builder.ConstantF64Scalar(2.)
    builder.Pow(base, exponent)
    self._ExecuteAndCompareClose(builder, expected=[2.25, 6.25, 9.])

  def testIota(self):
    builder = self._NewComputation()
    builder.Iota(np.float32, 10)
    counting = np.arange(10, dtype=np.float32)
    self._ExecuteAndCompareExact(builder, expected=counting)

  def testBroadcastedIota(self):
    builder = self._NewComputation()
    builder.BroadcastedIota(np.int64, (2, 3), 1)
    # Values count up along dimension 1 and repeat along dimension 0.
    counting_rows = np.array([[0, 1, 2], [0, 1, 2]], dtype=np.int64)
    self._ExecuteAndCompareExact(builder, expected=counting_rows)

  def testBooleanAnd(self):
    builder = self._NewComputation()
    lhs = builder.Constant(NumpyArrayBool([True, False, True, False]))
    rhs = builder.Constant(NumpyArrayBool([True, True, False, False]))
    builder.And(lhs, rhs)
    self._ExecuteAndCompareExact(
        builder, expected=[True, False, False, False])

  def testBooleanOr(self):
    builder = self._NewComputation()
    lhs = builder.Constant(NumpyArrayBool([True, False, True, False]))
    rhs = builder.Constant(NumpyArrayBool([True, True, False, False]))
    builder.Or(lhs, rhs)
    self._ExecuteAndCompareExact(builder, expected=[True, True, True, False])

  def testBooleanXor(self):
    builder = self._NewComputation()
    lhs = builder.Constant(NumpyArrayBool([True, False, True, False]))
    rhs = builder.Constant(NumpyArrayBool([True, True, False, False]))
    builder.Xor(lhs, rhs)
    self._ExecuteAndCompareExact(builder, expected=[False, True, True, False])

  def testSum2DF32(self):
    builder = self._NewComputation()
    lhs = builder.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]))
    rhs = builder.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]]))
    builder.Add(lhs, rhs)
    self._ExecuteAndCompareClose(builder, expected=[[2, 1, 4], [3, 6, 5]])

  def testShiftLeft(self):
    builder = self._NewComputation()
    value = builder.Constant(NumpyArrayS32([3]))
    amount = builder.Constant(NumpyArrayS32([2]))
    builder.ShiftLeft(value, amount)
    self._ExecuteAndCompareClose(builder, expected=[12])

  def testShiftRightArithmetic(self):
    builder = self._NewComputation()
    value = builder.Constant(NumpyArrayS32([-2]))
    amount = builder.Constant(NumpyArrayS32([1]))
    builder.ShiftRightArithmetic(value, amount)
    self._ExecuteAndCompareClose(builder, expected=[-1])

  def testShiftRightLogical(self):
    builder = self._NewComputation()
    value = builder.Constant(NumpyArrayS32([-1]))
    amount = builder.Constant(NumpyArrayS32([1]))
    builder.ShiftRightLogical(value, amount)
    self._ExecuteAndCompareClose(builder, expected=[2**31 - 1])

  def testSum2DF64(self):
    builder = self._NewComputation()
    lhs = builder.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6]]))
    rhs = builder.Constant(NumpyArrayF64([[1, -1, 1], [-1, 1, -1]]))
    builder.Add(lhs, rhs)
    self._ExecuteAndCompareClose(builder, expected=[[2, 1, 4], [3, 6, 5]])

  def testSum2DWith1DBroadcastDim0F32(self):
    # Adds a 1D array to a 2D array with broadcast_dimensions=(0,): element i
    # of the 1D operand is added to every entry of row i.
    builder = self._NewComputation()
    matrix = builder.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
    vector = builder.Constant(NumpyArrayF32([10, 20, 30]))
    builder.Add(matrix, vector, broadcast_dimensions=(0,))
    row_shifted = [[11, 12, 13], [24, 25, 26], [37, 38, 39]]
    self._ExecuteAndCompareClose(builder, expected=row_shifted)

  def testSum2DWith1DBroadcastDim0F64(self):
    # Adds a 1D array to a 2D array with broadcast_dimensions=(0,): element i
    # of the 1D operand is added to every entry of row i.
    builder = self._NewComputation()
    matrix = builder.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
    vector = builder.Constant(NumpyArrayF64([10, 20, 30]))
    builder.Add(matrix, vector, broadcast_dimensions=(0,))
    row_shifted = [[11, 12, 13], [24, 25, 26], [37, 38, 39]]
    self._ExecuteAndCompareClose(builder, expected=row_shifted)

  def testSum2DWith1DBroadcastDim1F32(self):
    # Adds a 1D array to a 2D array with broadcast_dimensions=(1,): element j
    # of the 1D operand is added to every entry of column j.
    builder = self._NewComputation()
    matrix = builder.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
    vector = builder.Constant(NumpyArrayF32([10, 20, 30]))
    builder.Add(matrix, vector, broadcast_dimensions=(1,))
    column_shifted = [[11, 22, 33], [14, 25, 36], [17, 28, 39]]
    self._ExecuteAndCompareClose(builder, expected=column_shifted)

  def testSum2DWith1DBroadcastDim1F64(self):
    # Adds a 1D array to a 2D array with broadcast_dimensions=(1,): element j
    # of the 1D operand is added to every entry of column j.
    builder = self._NewComputation()
    matrix = builder.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
    vector = builder.Constant(NumpyArrayF64([10, 20, 30]))
    builder.Add(matrix, vector, broadcast_dimensions=(1,))
    column_shifted = [[11, 22, 33], [14, 25, 36], [17, 28, 39]]
    self._ExecuteAndCompareClose(builder, expected=column_shifted)

  def testConstantAxpyF32(self):
    # Computes a * x + y with a scalar `a` and vector `x`, `y` constants.
    builder = self._NewComputation()
    a = builder.ConstantF32Scalar(2)
    x = builder.Constant(NumpyArrayF32([2.2, 3.3, 4.4, 5.5]))
    scaled = builder.Mul(a, x)
    y = builder.Constant(NumpyArrayF32([100, -100, 200, -200]))
    builder.Add(scaled, y)
    self._ExecuteAndCompareClose(
        builder, expected=[104.4, -93.4, 208.8, -189])

  def testConstantAxpyF64(self):
    # Computes a * x + y with a scalar `a` and vector `x`, `y` constants.
    builder = self._NewComputation()
    a = builder.ConstantF64Scalar(2)
    x = builder.Constant(NumpyArrayF64([2.2, 3.3, 4.4, 5.5]))
    scaled = builder.Mul(a, x)
    y = builder.Constant(NumpyArrayF64([100, -100, 200, -200]))
    builder.Add(scaled, y)
    self._ExecuteAndCompareClose(
        builder, expected=[104.4, -93.4, 208.8, -189])

  def testCustomCall(self):
    builder = self._NewComputation()
    # Register every CPU custom-call target shipped with the test helper
    # module before emitting the call that refers to one of them by name.
    targets = custom_call_for_test.cpu_custom_call_targets
    for target_name, target_fn in targets.items():
      xla_client.register_custom_call_target(
          target_name, target_fn, platform="cpu")

    def f32_scalar_shape():
      return xla_client.Shape.array_shape(np.dtype(np.float32), (), ())

    minuend = builder.ConstantF32Scalar(1.25)
    subtrahend = builder.ConstantF32Scalar(0.5)
    result_shape = f32_scalar_shape()
    operand_shapes = (f32_scalar_shape(), f32_scalar_shape())
    builder.CustomCall(
        b"test_subtract_f32",
        operands=(minuend, subtrahend),
        shape_with_layout=result_shape,
        operand_shapes_with_layout=operand_shapes)
    self._ExecuteAndCompareClose(builder, expected=0.75)
352
353
class ParametersTest(ComputationTest):
  """Tests focusing on Parameter ops and argument-passing."""

  def setUp(self):
    """Creates the scalar and 4-element vector fixtures used by the tests."""
    # Chain to the base class so that any per-test state set up by the parent
    # TestCase classes is initialized before the fixtures below.
    super(ParametersTest, self).setUp()
    self.f32_scalar_2 = NumpyArrayF32(2.0)
    self.f32_4vector = NumpyArrayF32([-2.3, 3.3, -4.3, 5.3])
    self.f64_scalar_2 = NumpyArrayF64(2.0)
    self.f64_4vector = NumpyArrayF64([-2.3, 3.3, -4.3, 5.3])
    self.s32_scalar_3 = NumpyArrayS32(3)
    self.s32_4vector = NumpyArrayS32([10, 15, -2, 7])
    self.s64_scalar_3 = NumpyArrayS64(3)
    self.s64_4vector = NumpyArrayS64([10, 15, -2, 7])

  def testScalarTimesVectorAutonumberF32(self):
    # Parameters are declared without explicit numbers, in argument order.
    c = self._NewComputation()
    p0 = c.ParameterFromNumpy(self.f32_scalar_2)
    p1 = c.ParameterFromNumpy(self.f32_4vector)
    c.Mul(p0, p1)
    self._ExecuteAndCompareClose(
        c,
        arguments=[self.f32_scalar_2, self.f32_4vector],
        expected=[-4.6, 6.6, -8.6, 10.6])

  def testScalarTimesVectorAutonumberF64(self):
    # Parameters are declared without explicit numbers, in argument order.
    c = self._NewComputation()
    p0 = c.ParameterFromNumpy(self.f64_scalar_2)
    p1 = c.ParameterFromNumpy(self.f64_4vector)
    c.Mul(p0, p1)
    self._ExecuteAndCompareClose(
        c,
        arguments=[self.f64_scalar_2, self.f64_4vector],
        expected=[-4.6, 6.6, -8.6, 10.6])

  def testScalarTimesVectorS32(self):
    c = self._NewComputation()
    p0 = c.ParameterFromNumpy(self.s32_scalar_3)
    p1 = c.ParameterFromNumpy(self.s32_4vector)
    c.Mul(p0, p1)
    self._ExecuteAndCompareExact(
        c,
        arguments=[self.s32_scalar_3, self.s32_4vector],
        expected=[30, 45, -6, 21])

  def testScalarTimesVectorS64(self):
    c = self._NewComputation()
    p0 = c.ParameterFromNumpy(self.s64_scalar_3)
    p1 = c.ParameterFromNumpy(self.s64_4vector)
    c.Mul(p0, p1)
    self._ExecuteAndCompareExact(
        c,
        arguments=[self.s64_scalar_3, self.s64_4vector],
        expected=[30, 45, -6, 21])

  def testScalarMinusVectorExplicitNumberingF32(self):
    # Use explicit numbering and pass parameter_num first. Sub is used since
    # it's not commutative and can help catch parameter reversal within the
    # computation.
    c = self._NewComputation()
    p1 = c.ParameterFromNumpy(self.f32_4vector, parameter_num=1)
    p0 = c.ParameterFromNumpy(self.f32_scalar_2, parameter_num=0)
    # The computation is (vector - scalar): p1 is the vector parameter.
    c.Sub(p1, p0)
    self._ExecuteAndCompareClose(
        c,
        arguments=[self.f32_scalar_2, self.f32_4vector],
        expected=[-4.3, 1.3, -6.3, 3.3])

  def testScalarMinusVectorExplicitNumberingF64(self):
    # Use explicit numbering and pass parameter_num first. Sub is used since
    # it's not commutative and can help catch parameter reversal within the
    # computation.
    c = self._NewComputation()
    p1 = c.ParameterFromNumpy(self.f64_4vector, parameter_num=1)
    p0 = c.ParameterFromNumpy(self.f64_scalar_2, parameter_num=0)
    # The computation is (vector - scalar): p1 is the vector parameter.
    c.Sub(p1, p0)
    self._ExecuteAndCompareClose(
        c,
        arguments=[self.f64_scalar_2, self.f64_4vector],
        expected=[-4.3, 1.3, -6.3, 3.3])
432
433
class BufferTest(ComputationTest):
  """Exercises execution on explicitly created xla_client.Buffer objects."""

  def _Execute(self, c, arguments):
    """Runs `c` on Buffers made from `arguments`; returns the result value."""
    executable = c.Build().Compile()
    device_args = []
    for value in arguments:
      device_args.append(xla_client.Buffer.from_pyval(value))
    output = executable.Execute(device_args)
    return output.to_py()

  def testConstantSum(self):
    builder = self._NewComputation()
    lhs = builder.ConstantF32Scalar(1.11)
    rhs = builder.ConstantF32Scalar(3.14)
    builder.Add(lhs, rhs)
    self._ExecuteAndCompareClose(builder, expected=4.25)

  def testOneParameterSum(self):
    builder = self._NewComputation()
    param = builder.ParameterFromNumpy(NumpyArrayF32(0.))
    constant = builder.ConstantF32Scalar(3.14)
    builder.Add(param, constant)
    inputs = [NumpyArrayF32(1.11)]
    self._ExecuteAndCompareClose(builder, arguments=inputs, expected=4.25)

  def testTwoParameterSum(self):
    builder = self._NewComputation()
    first = builder.ParameterFromNumpy(NumpyArrayF32(0.))
    second = builder.ParameterFromNumpy(NumpyArrayF32(0.))
    builder.Add(first, second)
    inputs = [NumpyArrayF32(1.11), NumpyArrayF32(3.14)]
    self._ExecuteAndCompareClose(builder, arguments=inputs, expected=4.25)

  def testCannotCallWithDeletedBuffers(self):
    # Executing with an argument buffer that was already deleted must raise.
    builder = self._NewComputation()
    param = builder.ParameterFromNumpy(NumpyArrayF32(0.))
    constant = builder.ConstantF32Scalar(3.14)
    builder.Add(param, constant)
    value = NumpyArrayF32(1.11)
    executable = builder.Build().Compile()
    stale_buffer = xla_client.Buffer.from_pyval(value)
    stale_buffer.delete()
    with self.assertRaises(RuntimeError):
      executable.Execute([stale_buffer])

  def testDestructureTupleEmpty(self):
    first_device = xla_client.get_local_backend().devices()[0]
    empty_tuple = xla_client.Buffer.make_tuple((), device=first_device)
    elements = empty_tuple.destructure()
    self.assertFalse(empty_tuple.is_deleted())
    self.assertEmpty(elements)

  def testDestructureTupleOneArrayElement(self):
    first_device = xla_client.get_local_backend().devices()[0]
    element = xla_client.Buffer.from_pyval(
        np.array([1, 2, 3, 4], dtype=np.int32))
    tuple_buffer = xla_client.Buffer.make_tuple((element,), first_device)
    elements = tuple_buffer.destructure()
    self.assertFalse(tuple_buffer.is_deleted())
    self.assertLen(elements, 1)
    actual = elements[0].to_py()
    np.testing.assert_equal(NumpyArrayS32([1, 2, 3, 4]), actual)

  def testDestructureTupleTwoArrayElementDifferentType(self):
    first_device = xla_client.get_local_backend().devices()[0]
    float_element = xla_client.Buffer.from_pyval(
        np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32))
    int_element = xla_client.Buffer.from_pyval(
        np.array([2, 3, 4, 5], dtype=np.int32))
    tuple_buffer = xla_client.Buffer.make_tuple(
        (float_element, int_element), first_device)
    # Destructure twice: the tuple buffer has to stay usable after the first
    # destructure() call.
    for _ in range(2):
      elements = tuple_buffer.destructure()
      self.assertFalse(tuple_buffer.is_deleted())
      self.assertLen(elements, 2)
      float_piece, int_piece = elements
      actual_floats = float_piece.to_py()
      np.testing.assert_equal(
          NumpyArrayF32([1.0, 2.0, 3.0, 4.0]), actual_floats)
      actual_ints = int_piece.to_py()
      np.testing.assert_equal(NumpyArrayS32([2, 3, 4, 5]), actual_ints)

  def testDestructureTupleNested(self):
    first_device = xla_client.get_local_backend().devices()[0]
    inner_float = xla_client.Buffer.from_pyval(NumpyArrayF32([1.0, 2.0]))
    inner_int = xla_client.Buffer.from_pyval(NumpyArrayS32([3, 4]))
    inner_tuple = xla_client.Buffer.make_tuple(
        (inner_float, inner_int), first_device)
    outer_int = xla_client.Buffer.from_pyval(NumpyArrayS32([5]))
    outer_tuple = xla_client.Buffer.make_tuple(
        (inner_tuple, outer_int), first_device)
    elements = outer_tuple.destructure()
    self.assertFalse(outer_tuple.is_deleted())
    self.assertLen(elements, 2)
    nested_piece, array_piece = elements
    actual_array = array_piece.to_py()
    np.testing.assert_equal(NumpyArrayS32([5]), actual_array)
    # The nested tuple converts to a Python tuple holding its two arrays.
    actual_nested = nested_piece.to_py()
    self.assertEqual(type(actual_nested), tuple)
    self.assertLen(actual_nested, 2)
    np.testing.assert_equal(NumpyArrayF32([1.0, 2.0]), actual_nested[0])
    np.testing.assert_equal(NumpyArrayS32([3, 4]), actual_nested[1])

  def testMakeTuple(self):
    float_values = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
    int_values = np.array([2, 3, 4, 5], dtype=np.int32)
    float_buffer = xla_client.Buffer.from_pyval(float_values)
    int_buffer = xla_client.Buffer.from_pyval(int_values)
    first_device = xla_client.get_local_backend().local_devices()[0]
    tuple_buffer = xla_client.Buffer.make_tuple(
        [float_buffer, int_buffer], device=first_device)
    elements = tuple_buffer.destructure()
    self.assertLen(elements, 2)
    float_piece, int_piece = elements
    np.testing.assert_equal(
        np.array([1, 2, 3, 4], dtype=np.float32), float_piece.to_py())
    np.testing.assert_equal(
        np.array([2, 3, 4, 5], dtype=np.int32), int_piece.to_py())

  def testShape(self):
    host_value = np.array([[1., 2.]], np.float32)
    shape = xla_client.Buffer.from_pyval(host_value).shape()
    self.assertEqual(shape.dimensions(), (1, 2))
    self.assertEqual(np.dtype(shape.element_type()), np.dtype(np.float32))

  def testTupleShape(self):
    float_values = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)
    int_values = np.array([2, 3, 4, 5], dtype=np.int32)
    float_buffer = xla_client.Buffer.from_pyval(float_values)
    int_buffer = xla_client.Buffer.from_pyval(int_values)
    first_device = xla_client.get_local_backend().local_devices()[0]
    tuple_buffer = xla_client.Buffer.make_tuple(
        [float_buffer, int_buffer], device=first_device)
    shape = tuple_buffer.shape()
    self.assertEqual(shape.leaf_count(), 2)
    element_shapes = shape.tuple_shapes()
    self.assertLen(element_shapes, 2)
    float_shape, int_shape = element_shapes
    self.assertEqual(float_shape.dimensions(), (1, 4))
    self.assertEqual(int_shape.dimensions(), (4,))

  def testBlockHostUntilReadyWorks(self):
    # Smoke test only: the call is expected to complete without raising;
    # nothing else about it is checked here.
    host_value = np.array([[1., 2.]], np.float32)
    device_buffer = xla_client.Buffer.from_pyval(host_value)
    device_buffer.block_host_until_ready()

  def testCopyToHost(self):
    first_value = np.array([[1., 2.]], np.float32)
    second_value = np.array([[3., 4.]], np.float32)
    first_buffer = xla_client.Buffer.from_pyval(first_value)
    second_buffer = xla_client.Buffer.from_pyval(second_value)
    # Start prefetching both buffers with copy_to_host_async before reading
    # them back with to_py. The first buffer is requested twice on purpose:
    # a repeated request must be harmless.
    first_buffer.copy_to_host_async()
    first_buffer.copy_to_host_async()
    second_buffer.copy_to_host_async()
    np.testing.assert_equal(first_value, first_buffer.to_py())
    np.testing.assert_equal(second_value, second_buffer.to_py())
    # Requesting another copy after to_py must leave the value unchanged.
    first_buffer.copy_to_host_async()
    np.testing.assert_equal(first_value, first_buffer.to_py())

  def testDevice(self):
    # A buffer placed on a given local device reports that device and
    # round-trips its contents.
    host_value = np.arange(8)
    local_devices = xla_client.get_local_backend().local_devices()
    for target in local_devices:
      placed = xla_client.Buffer.from_pyval(host_value, device=target)
      self.assertEqual(placed.device(), target)
      np.testing.assert_equal(host_value, placed.to_py())
604
605
606class SingleOpTest(ComputationTest):
607  """Tests for single ops.
608
609  The goal here is smoke testing - to exercise the most basic functionality of
610  single XLA ops. As minimal as possible number of additional ops are added
611  around the op being tested.
612  """
613
614  def testConcatenateF32(self):
615    c = self._NewComputation()
616    args = (
617        c.Constant(NumpyArrayF32([1.0, 2.0, 3.0])),
618        c.Constant(NumpyArrayF32([4.0, 5.0, 6.0])),
619    )
620    c.Concatenate(args, dimension=0)
621    self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
622
623  def testConcatenateF64(self):
624    c = self._NewComputation()
625    args = (
626        c.Constant(NumpyArrayF64([1.0, 2.0, 3.0])),
627        c.Constant(NumpyArrayF64([4.0, 5.0, 6.0])),
628    )
629    c.Concatenate(args, dimension=0)
630    self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
631
632  def testConvertElementType(self):
633    xla_types = {
634        np.bool: xla_client.PrimitiveType.PRED,
635        np.int32: xla_client.PrimitiveType.S32,
636        np.int64: xla_client.PrimitiveType.S64,
637        np.float32: xla_client.PrimitiveType.F32,
638        np.float64: xla_client.PrimitiveType.F64,
639    }
640
641    def _ConvertAndTest(template, src_dtype, dst_dtype):
642      c = self._NewComputation()
643      x = c.Constant(np.array(template, dtype=src_dtype))
644      c.ConvertElementType(x, xla_types[dst_dtype])
645
646      result = xla_client.execute_with_python_values(c.Build().Compile())
647      expected = np.array(template, dtype=dst_dtype)
648
649      self.assertEqual(result.shape, expected.shape)
650      self.assertEqual(result.dtype, expected.dtype)
651      np.testing.assert_equal(result, expected)
652
653    x = [0, 1, 0, 0, 1]
654    for src_dtype, dst_dtype in itertools.product(xla_types, xla_types):
655      _ConvertAndTest(x, src_dtype, dst_dtype)
656
657  def testBitcastConvertType(self):
658    xla_x32_types = {
659        np.int32: xla_client.PrimitiveType.S32,
660        np.float32: xla_client.PrimitiveType.F32,
661    }
662
663    xla_x64_types = {
664        np.int64: xla_client.PrimitiveType.S64,
665        np.float64: xla_client.PrimitiveType.F64,
666    }
667
668    def _ConvertAndTest(template, src_dtype, dst_dtype, dst_etype):
669      c = self._NewComputation()
670      x = c.Constant(np.array(template, dtype=src_dtype))
671      c.BitcastConvertType(x, dst_etype)
672
673      result = xla_client.execute_with_python_values(c.Build().Compile())
674      expected = np.array(template, src_dtype).view(dst_dtype)
675
676      self.assertEqual(result.shape, expected.shape)
677      self.assertEqual(result.dtype, expected.dtype)
678      np.testing.assert_equal(result, expected)
679
680    x = [0, 1, 0, 0, 1]
681    for xla_types in [xla_x32_types, xla_x64_types]:
682      for src_dtype, dst_dtype in itertools.product(xla_types, xla_types):
683        _ConvertAndTest(x, src_dtype, dst_dtype, xla_types[dst_dtype])
684
685  # TODO(b/123523486) implement AllToAll on CPU
686  def DISABLED_testAllToAllOneReplica(self):
687    samples = [
688        NumpyArrayF32([97.0]),
689        NumpyArrayF32([64.0, 117.0]),
690        NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]),
691    ]
692    for lhs in samples[:1]:
693      c = self._NewComputation()
694      c.AllToAll(c.Constant(lhs), 0, 0)
695      self._ExecuteAndCompareExact(c, expected=lhs)
696
697  def testCrossReplicaSumOneReplica(self):
698    samples = [
699        NumpyArrayF32(42.0),
700        NumpyArrayF32([97.0]),
701        NumpyArrayF32([64.0, 117.0]),
702        NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]),
703    ]
704    for lhs in samples:
705      c = self._NewComputation()
706      c.CrossReplicaSum(c.Constant(lhs))
707      self._ExecuteAndCompareExact(c, expected=lhs)
708
709  def testReplicaId(self):
710    c = self._NewComputation()
711    _ = c.ReplicaId()
712    self._ExecuteAndCompareExact(c, expected=0)
713
714  def testCrossReplicaSumOneReplicaWithSingletonGroup(self):
715    samples = [
716        NumpyArrayF32(42.0),
717        NumpyArrayF32([97.0]),
718        NumpyArrayF32([64.0, 117.0]),
719        NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]),
720    ]
721    for lhs in samples:
722      c = self._NewComputation()
723      c.CrossReplicaSum(c.Constant(lhs), [[0]])
724      self._ExecuteAndCompareExact(c, expected=lhs)
725
726  def testDotMatrixVectorF32(self):
727    c = self._NewComputation()
728    lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]])
729    rhs = NumpyArrayF32([[10.0], [20.0]])
730    c.Dot(c.Constant(lhs), c.Constant(rhs))
731    self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
732
733  def testDotMatrixVectorF64(self):
734    c = self._NewComputation()
735    lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]])
736    rhs = NumpyArrayF64([[10.0], [20.0]])
737    c.Dot(c.Constant(lhs), c.Constant(rhs))
738    self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
739
740  def testDotMatrixMatrixF32(self):
741    c = self._NewComputation()
742    lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]])
743    rhs = NumpyArrayF32([[10.0, 20.0], [100.0, 200.0]])
744    c.Dot(c.Constant(lhs), c.Constant(rhs))
745    self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
746
747  def testDotMatrixMatrixF64(self):
748    c = self._NewComputation()
749    lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]])
750    rhs = NumpyArrayF64([[10.0, 20.0], [100.0, 200.0]])
751    c.Dot(c.Constant(lhs), c.Constant(rhs))
752    self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
753
754  def testDotGeneral(self):
755    c = self._NewComputation()
756    rng = np.random.RandomState(0)
757    lhs = NumpyArrayF32(rng.randn(10, 3, 4))
758    rhs = NumpyArrayF32(rng.randn(10, 4, 5))
759    dimension_numbers = (([2], [1]), ([0], [0]))
760    c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers)
761    self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs), rtol=1e-6)
762
763  def testDotGeneralWithDotDimensionNumbersProto(self):
764    c = self._NewComputation()
765    rng = np.random.RandomState(0)
766    lhs = NumpyArrayF32(rng.randn(10, 3, 4))
767    rhs = NumpyArrayF32(rng.randn(10, 4, 5))
768
769    dimension_numbers = xla_client.DotDimensionNumbers()
770    dimension_numbers.lhs_contracting_dimensions.append(2)
771    dimension_numbers.rhs_contracting_dimensions.append(1)
772    dimension_numbers.lhs_batch_dimensions.append(0)
773    dimension_numbers.rhs_batch_dimensions.append(0)
774
775    c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers)
776    self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs), rtol=1e-6)
777
778  def testDotGeneralWithPrecisionConfig(self):
779    c = self._NewComputation()
780    rng = np.random.RandomState(0)
781    lhs = NumpyArrayF32(rng.randn(10, 3, 4))
782    rhs = NumpyArrayF32(rng.randn(10, 4, 5))
783    dimension_numbers = (([2], [1]), ([0], [0]))
784    config = xla_client.PrecisionConfig()
785    config.operand_precision.append(config.Precision.HIGH)
786    config.operand_precision.append(config.Precision.HIGHEST)
787    c.DotGeneral(
788        c.Constant(lhs),
789        c.Constant(rhs),
790        dimension_numbers,
791        precision_config=config)
792    self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs), rtol=1e-6)
793
794  def testConvF32Same(self):
795    c = self._NewComputation()
796    a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
797    lhs = a(1, 2, 3, 4)
798    rhs = a(1, 2, 1, 2) * 10
799    c.Conv(
800        c.Constant(lhs), c.Constant(rhs), [1, 1], xla_client.PaddingType.SAME)
801    result = np.array([[[
802        [640., 700., 760., 300.],
803        [880., 940., 1000., 380.],
804        [1120., 1180., 1240., 460.],
805    ]]])
806    self._ExecuteAndCompareClose(c, expected=result)
807
808  def testConvF32Valid(self):
809    c = self._NewComputation()
810    a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
811    lhs = a(1, 2, 3, 4)
812    rhs = a(1, 2, 1, 2) * 10
813    c.Conv(
814        c.Constant(lhs), c.Constant(rhs), [2, 1], xla_client.PaddingType.VALID)
815    result = np.array([[[
816        [640., 700., 760.],
817        [1120., 1180., 1240.],
818    ]]])
819    self._ExecuteAndCompareClose(c, expected=result)
820
821  def testConvWithGeneralPaddingF32(self):
822    c = self._NewComputation()
823    a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
824    lhs = a(1, 1, 2, 3)
825    rhs = a(1, 1, 1, 2) * 10
826    strides = [1, 1]
827    pads = [(1, 0), (0, 1)]
828    lhs_dilation = (2, 1)
829    rhs_dilation = (1, 1)
830    c.ConvWithGeneralPadding(
831        c.Constant(lhs), c.Constant(rhs), strides, pads, lhs_dilation,
832        rhs_dilation)
833    result = np.array([[[
834        [0., 0., 0.],
835        [10., 20., 0.],
836        [0., 0., 0.],
837        [40., 50., 0.],
838    ]]])
839    self._ExecuteAndCompareClose(c, expected=result)
840
841  def testConvGeneralDilatedF32(self):
842    c = self._NewComputation()
843    a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
844    lhs = a(1, 1, 2, 3)
845    rhs = a(1, 1, 1, 2) * 10
846    strides = [1, 1]
847    pads = [(1, 0), (0, 1)]
848    lhs_dilation = (2, 1)
849    rhs_dilation = (1, 1)
850    dimension_numbers = ("NCHW", "OIHW", "NCHW")
851    c.ConvGeneralDilated(
852        c.Constant(lhs), c.Constant(rhs), strides, pads, lhs_dilation,
853        rhs_dilation, dimension_numbers)
854    result = np.array([[[
855        [0., 0., 0.],
856        [10., 20., 0.],
857        [0., 0., 0.],
858        [40., 50., 0.],
859    ]]])
860    self._ExecuteAndCompareClose(c, expected=result)
861
862  def testConvGeneralDilatedF32WithPrecisionConfig(self):
863    c = self._NewComputation()
864    a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
865    lhs = a(1, 1, 2, 3)
866    rhs = a(1, 1, 1, 2) * 10
867    strides = [1, 1]
868    pads = [(1, 0), (0, 1)]
869    lhs_dilation = (2, 1)
870    rhs_dilation = (1, 1)
871    dimension_numbers = ("NCHW", "OIHW", "NCHW")
872    config = xla_client.PrecisionConfig()
873    config.operand_precision.append(config.Precision.HIGHEST)
874    config.operand_precision.append(config.Precision.DEFAULT)
875    c.ConvGeneralDilated(
876        c.Constant(lhs),
877        c.Constant(rhs),
878        strides,
879        pads,
880        lhs_dilation,
881        rhs_dilation,
882        dimension_numbers,
883        precision_config=config)
884    result = np.array([[[
885        [0., 0., 0.],
886        [10., 20., 0.],
887        [0., 0., 0.],
888        [40., 50., 0.],
889    ]]])
890    self._ExecuteAndCompareClose(c, expected=result)
891
892  def testConvGeneralDilatedPermutedF32(self):
893    c = self._NewComputation()
894    a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
895    lhs = a(1, 1, 2, 3)
896    rhs = a(1, 1, 1, 2) * 10
897    strides = [1, 1]
898    pads = [(1, 0), (0, 1)]
899    lhs_dilation = (2, 1)
900    rhs_dilation = (1, 1)
901
902    dimension_numbers = ("NHWC", "OIHW", "CWNH")
903    c.ConvGeneralDilated(
904        c.Constant(np.transpose(lhs, (0, 2, 3, 1))), c.Constant(rhs), strides,
905        pads, lhs_dilation, rhs_dilation, dimension_numbers)
906    result = np.array([[[[0., 0., 0.], [10., 20., 0.], [0., 0., 0.],
907                         [40., 50., 0.]]]])
908    self._ExecuteAndCompareClose(c, expected=np.transpose(result, (1, 3, 0, 2)))
909
910  def testConvGeneralDilatedGroupedConvolutionF32(self):
911    c = self._NewComputation()
912    a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
913    lhs = a(1, 2, 2, 3)
914    rhs = a(2, 1, 1, 2) * 10
915    strides = [1, 1]
916    pads = [(1, 0), (0, 1)]
917    lhs_dilation = (2, 1)
918    rhs_dilation = (1, 1)
919    dimension_numbers = ("NCHW", "OIHW", "NCHW")
920    feature_group_count = 2
921    c.ConvGeneralDilated(
922        c.Constant(lhs), c.Constant(rhs), strides, pads, lhs_dilation,
923        rhs_dilation, dimension_numbers, feature_group_count)
924    result = np.array([[[
925        [0., 0., 0.],
926        [10., 20., 0.],
927        [0., 0., 0.],
928        [40., 50., 0.],
929    ], [
930        [0., 0., 0.],
931        [330., 380., 160.],
932        [0., 0., 0.],
933        [480., 530., 220.],
934    ]]])
935    self._ExecuteAndCompareClose(c, expected=result)
936
937  def testBooleanNot(self):
938    c = self._NewComputation()
939    arr = NumpyArrayBool([True, False, True])
940    c.Not(c.Constant(arr))
941    self._ExecuteAndCompareClose(c, expected=~arr)
942
943  def testCountLeadingZeros(self):
944    c = self._NewComputation()
945    arr = NumpyArrayS32([0x7FFF, 0x12345678])
946    c.Clz(c.Constant(arr))
947    self._ExecuteAndCompareClose(c, expected=[17, 3])
948
949  def testExp(self):
950    c = self._NewComputation()
951    arr = NumpyArrayF32([3.3, 12.1])
952    c.Exp(c.Constant(arr))
953    self._ExecuteAndCompareClose(c, expected=np.exp(arr))
954
955  def testExpm1(self):
956    c = self._NewComputation()
957    arr = NumpyArrayF32([3.3, 12.1])
958    c.Expm1(c.Constant(arr))
959    self._ExecuteAndCompareClose(c, expected=np.expm1(arr))
960
961  def testRound(self):
962    c = self._NewComputation()
963    arr = NumpyArrayF32([3.3, 12.1])
964    c.Round(c.Constant(arr))
965    self._ExecuteAndCompareClose(c, expected=np.round(arr))
966
967  def testLog(self):
968    c = self._NewComputation()
969    arr = NumpyArrayF32([3.3, 12.1])
970    c.Log(c.Constant(arr))
971    self._ExecuteAndCompareClose(c, expected=np.log(arr))
972
973  def testLog1p(self):
974    c = self._NewComputation()
975    arr = NumpyArrayF32([3.3, 12.1])
976    c.Log1p(c.Constant(arr))
977    self._ExecuteAndCompareClose(c, expected=np.log1p(arr))
978
979  def testNeg(self):
980    c = self._NewComputation()
981    arr = NumpyArrayF32([3.3, 12.1])
982    c.Neg(c.Constant(arr))
983    self._ExecuteAndCompareClose(c, expected=-arr)
984
985  def testFloor(self):
986    c = self._NewComputation()
987    arr = NumpyArrayF32([3.3, 12.1])
988    c.Floor(c.Constant(arr))
989    self._ExecuteAndCompareClose(c, expected=np.floor(arr))
990
991  def testCeil(self):
992    c = self._NewComputation()
993    arr = NumpyArrayF32([3.3, 12.1])
994    c.Ceil(c.Constant(arr))
995    self._ExecuteAndCompareClose(c, expected=np.ceil(arr))
996
997  def testAbs(self):
998    c = self._NewComputation()
999    arr = NumpyArrayF32([3.3, -12.1, 2.4, -1.])
1000    c.Abs(c.Constant(arr))
1001    self._ExecuteAndCompareClose(c, expected=np.abs(arr))
1002
1003  def testTanh(self):
1004    c = self._NewComputation()
1005    arr = NumpyArrayF32([3.3, 12.1])
1006    c.Tanh(c.Constant(arr))
1007    self._ExecuteAndCompareClose(c, expected=np.tanh(arr))
1008
1009  def testTrans(self):
1010
1011    def _TransposeAndTest(array):
1012      c = self._NewComputation()
1013      c.Trans(c.Constant(array))
1014      self._ExecuteAndCompareClose(c, expected=array.T)
1015
1016    # Test square and non-square matrices in both default (C) and F orders.
1017    for array_fun in [NumpyArrayF32, NumpyArrayF64]:
1018      _TransposeAndTest(array_fun([[1, 2, 3], [4, 5, 6]]))
1019      _TransposeAndTest(array_fun([[1, 2, 3], [4, 5, 6]], order="F"))
1020      _TransposeAndTest(array_fun([[1, 2], [4, 5]]))
1021      _TransposeAndTest(array_fun([[1, 2], [4, 5]], order="F"))
1022
1023  def testTranspose(self):
1024
1025    def _TransposeAndTest(array, permutation):
1026      c = self._NewComputation()
1027      c.Transpose(c.Constant(array), permutation)
1028      expected = np.transpose(array, permutation)
1029      self._ExecuteAndCompareClose(c, expected=expected)
1030
1031    _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [0, 1])
1032    _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [1, 0])
1033    _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [0, 1])
1034    _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [1, 0])
1035
1036    arr = np.random.RandomState(0).randn(2, 3, 4).astype(np.float32)
1037    for permutation in itertools.permutations(range(arr.ndim)):
1038      _TransposeAndTest(arr, permutation)
1039      _TransposeAndTest(np.asfortranarray(arr), permutation)
1040
1041  def testEq(self):
1042    c = self._NewComputation()
1043    c.Eq(
1044        c.Constant(NumpyArrayS32([1, 2, 3, 4])),
1045        c.Constant(NumpyArrayS32([4, 2, 3, 1])))
1046    self._ExecuteAndCompareExact(c, expected=[False, True, True, False])
1047
1048  def testNe(self):
1049    c = self._NewComputation()
1050    c.Ne(
1051        c.Constant(NumpyArrayS32([1, 2, 3, 4])),
1052        c.Constant(NumpyArrayS32([4, 2, 3, 1])))
1053    self._ExecuteAndCompareExact(c, expected=[True, False, False, True])
1054
1055    c.Ne(
1056        c.Constant(NumpyArrayF32([-2.0, 0.0,
1057                                  float("nan"),
1058                                  float("nan")])),
1059        c.Constant(NumpyArrayF32([2.0, -0.0, 1.0, float("nan")])))
1060    self._ExecuteAndAssertWith(
1061        np.testing.assert_allclose, c, (), expected=[True, False, True, True])
1062
1063  def testGt(self):
1064    c = self._NewComputation()
1065    c.Gt(
1066        c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])),
1067        c.Constant(NumpyArrayS32([1, 0, 2, 7, 12])))
1068    self._ExecuteAndCompareExact(c, expected=[False, True, True, False, False])
1069
1070  def testGe(self):
1071    c = self._NewComputation()
1072    c.Ge(
1073        c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])),
1074        c.Constant(NumpyArrayS32([1, 0, 2, 7, 12])))
1075    self._ExecuteAndCompareExact(c, expected=[True, True, True, False, False])
1076
1077  def testLt(self):
1078    c = self._NewComputation()
1079    c.Lt(
1080        c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])),
1081        c.Constant(NumpyArrayS32([1, 0, 2, 7, 12])))
1082    self._ExecuteAndCompareExact(c, expected=[False, False, False, True, True])
1083
1084  def testLe(self):
1085    c = self._NewComputation()
1086    c.Le(
1087        c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])),
1088        c.Constant(NumpyArrayS32([1, 0, 2, 7, 12])))
1089    self._ExecuteAndCompareExact(c, expected=[True, False, False, True, True])
1090
1091  def testMax(self):
1092    c = self._NewComputation()
1093    c.Max(
1094        c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])),
1095        c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0])))
1096    self._ExecuteAndCompareExact(c, expected=[1.0, 2.0, 3.0, 7.0, 12.0])
1097
1098  def testMaxExplicitBroadcastDim0(self):
1099    c = self._NewComputation()
1100    c.Max(
1101        c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
1102        c.Constant(NumpyArrayF32([3, 4, 5])),
1103        broadcast_dimensions=(0,))
1104    self._ExecuteAndCompareExact(c, expected=[[3, 3, 3], [4, 5, 6], [7, 8, 9]])
1105
1106  def testMaxExplicitBroadcastDim1(self):
1107    c = self._NewComputation()
1108    c.Max(
1109        c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
1110        c.Constant(NumpyArrayF32([3, 4, 5])),
1111        broadcast_dimensions=(1,))
1112    self._ExecuteAndCompareExact(c, expected=[[3, 4, 5], [4, 5, 6], [7, 8, 9]])
1113
1114  def testMin(self):
1115    c = self._NewComputation()
1116    c.Min(
1117        c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])),
1118        c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0])))
1119    self._ExecuteAndCompareExact(c, expected=[1.0, 0.0, 2.0, 4.0, 9.0])
1120
1121  def testPad(self):
1122    c = self._NewComputation()
1123    c.Pad(
1124        c.Constant(NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])),
1125        c.Constant(NumpyArrayF32(0.0)), [(1, 2, 1), (0, 1, 0)])
1126    self._ExecuteAndCompareClose(
1127        c,
1128        expected=[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0],
1129                  [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
1130
1131  def testPadWithPaddingConfig(self):
1132    c = self._NewComputation()
1133    padding_config = xla_client.PaddingConfig()
1134    for lo, hi, interior in [(1, 2, 1), (0, 1, 0)]:
1135      dimension = xla_client.PaddingConfigDimension()
1136      dimension.edge_padding_low = lo
1137      dimension.edge_padding_high = hi
1138      dimension.interior_padding = interior
1139      padding_config.dimensions.append(dimension)
1140    c.Pad(
1141        c.Constant(NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])),
1142        c.Constant(NumpyArrayF32(0.0)), padding_config)
1143    self._ExecuteAndCompareClose(
1144        c,
1145        expected=[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0],
1146                  [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
1147
1148  def testReshape(self):
1149    c = self._NewComputation()
1150    c.Reshape(
1151        c.Constant(NumpyArrayS32([[1, 2], [3, 4], [5, 6]])),
1152        dimensions=[0, 1],
1153        new_sizes=[2, 3])
1154    self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 5, 6]])
1155
1156  def testCollapse(self):
1157    c = self._NewComputation()
1158    c.Collapse(
1159        c.Constant(NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])),
1160        dimensions=[1, 2])
1161    self._ExecuteAndCompareExact(c, expected=[[1, 2, 3, 4], [5, 6, 7, 8]])
1162
1163  def testRev(self):
1164    c = self._NewComputation()
1165    c.Rev(
1166        c.Constant(NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])),
1167        dimensions=[0, 2])
1168    self._ExecuteAndCompareExact(
1169        c, expected=[[[6, 5], [8, 7]], [[2, 1], [4, 3]]])
1170
1171  def testReducePrecision(self):
1172    c = self._NewComputation()
1173    c.ReducePrecision(
1174        c.Constant(NumpyArrayF32([float.fromhex("0x1.32fffep-3")])),
1175        exponent_bits=8,
1176        mantissa_bits=7)
1177    self._ExecuteAndCompareClose(c, expected=[float.fromhex("0x1.32p-3")])
1178
1179  def testClampF32(self):
1180    c = self._NewComputation()
1181    c.Clamp(
1182        c.Constant(NumpyArrayF32(-1)),
1183        c.Constant(NumpyArrayF32([-2, -1, 0, 1, 2, 3])),
1184        c.Constant(NumpyArrayF32(2)))
1185    self._ExecuteAndCompareExact(c, expected=[-1, -1, 0, 1, 2, 2])
1186
1187  def testClampS32(self):
1188    c = self._NewComputation()
1189    c.Clamp(
1190        c.Constant(NumpyArrayS32(-1)),
1191        c.Constant(NumpyArrayS32([-2, -1, 0, 1, 2, 3])),
1192        c.Constant(NumpyArrayS32(2)))
1193    self._ExecuteAndCompareExact(c, expected=[-1, -1, 0, 1, 2, 2])
1194
1195  def testSelect(self):
1196    c = self._NewComputation()
1197    c.Select(
1198        c.Constant(NumpyArrayBool([True, False, False, True, False])),
1199        c.Constant(NumpyArrayS32([1, 2, 3, 4, 5])),
1200        c.Constant(NumpyArrayS32([-1, -2, -3, -4, -5])))
1201    self._ExecuteAndCompareExact(c, expected=[1, -2, -3, 4, -5])
1202
1203  def testSlice(self):
1204    c = self._NewComputation()
1205    c.Slice(
1206        c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), [1, 0],
1207        [3, 2])
1208    self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]])
1209
1210  def testSliceInDim(self):
1211    c = self._NewComputation()
1212    c.SliceInDim(
1213        c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
1214        start_index=1,
1215        limit_index=2,
1216        stride=1,
1217        dimno=1)
1218    self._ExecuteAndCompareExact(c, expected=[[2], [5], [8]])
1219    c.SliceInDim(
1220        c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
1221        start_index=0,
1222        limit_index=3,
1223        stride=2,
1224        dimno=0)
1225    self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [7, 8, 9]])
1226
1227  def testDynamicSlice(self):
1228    c = self._NewComputation()
1229    c.DynamicSlice(
1230        c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
1231        c.Constant(NumpyArrayS32([1, 0])), [2, 2])
1232    self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]])
1233
1234  def testDynamicUpdateSlice(self):
1235    c = self._NewComputation()
1236    c.DynamicUpdateSlice(
1237        c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
1238        c.Constant(NumpyArrayS32([[1, 2], [3, 4]])),
1239        c.Constant(NumpyArrayS32([1, 1])))
1240    self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 1, 2], [7, 3, 4]])
1241
1242  def testTuple(self):
1243    c = self._NewComputation()
1244    c.Tuple(
1245        c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])),
1246        c.Constant(NumpyArrayBool([True, False, False, True])))
1247    result = xla_client.execute_with_python_values(c.Build().Compile())
1248    self.assertIsInstance(result, tuple)
1249    np.testing.assert_equal(result[0], 42)
1250    np.testing.assert_allclose(result[1], [1.0, 2.0])
1251    np.testing.assert_equal(result[2], [True, False, False, True])
1252
1253  def testGetTupleElement(self):
1254    c = self._NewComputation()
1255    c.GetTupleElement(
1256        c.Tuple(
1257            c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])),
1258            c.Constant(NumpyArrayBool([True, False, False, True]))), 1)
1259    self._ExecuteAndCompareClose(c, expected=[1.0, 2.0])
1260
1261  def testBroadcast(self):
1262    c = self._NewComputation()
1263    c.Broadcast(c.Constant(NumpyArrayS32([10, 20, 30, 40])), sizes=(3,))
1264    self._ExecuteAndCompareExact(
1265        c, expected=[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]])
1266
1267  def testBroadcastInDim(self):
1268    c = self._NewComputation()
1269    c.BroadcastInDim(c.Constant(NumpyArrayS32([1, 2])), [2, 2], [0])
1270    self._ExecuteAndCompareExact(c, expected=[[1, 1], [2, 2]])
1271    c.BroadcastInDim(c.Constant(NumpyArrayS32([1, 2])), [2, 2], [1])
1272    self._ExecuteAndCompareExact(c, expected=[[1, 2], [1, 2]])
1273
1274  def testRngNormal(self):
1275    shape = (2, 3)
1276    c = self._NewComputation()
1277    c.RngNormal(
1278        c.Constant(NumpyArrayF32(0.)),
1279        c.Constant(NumpyArrayF32(1.)),
1280        dims=shape)
1281    result = xla_client.execute_with_python_values(c.Build().Compile())
1282    # since the result is random, we just check shape and uniqueness
1283    self.assertEqual(result.shape, shape)
1284    self.assertLen(np.unique(result), np.prod(shape))
1285
1286  def testRngUniformF32(self):
1287    lo, hi = 2., 4.
1288    shape = (2, 3)
1289    c = self._NewComputation()
1290    c.RngUniform(
1291        c.Constant(NumpyArrayF32(lo)),
1292        c.Constant(NumpyArrayF32(hi)),
1293        dims=shape)
1294    result = xla_client.execute_with_python_values(c.Build().Compile())
1295    # since the result is random, we just check shape, uniqueness, and range
1296    self.assertEqual(result.shape, shape)
1297    self.assertLen(np.unique(result), np.prod(shape))
1298    self.assertTrue(np.all(lo <= result))
1299    self.assertTrue(np.all(result < hi))
1300
1301  def testRngUniformS32(self):
1302    lo, hi = 2, 4
1303    shape = (2, 3)
1304    c = self._NewComputation()
1305    c.RngUniform(
1306        c.Constant(NumpyArrayS32(lo)),
1307        c.Constant(NumpyArrayS32(hi)),
1308        dims=shape)
1309    result = xla_client.execute_with_python_values(c.Build().Compile())
1310    # since the result is random, we just check shape, integrality, and range
1311    self.assertEqual(result.shape, shape)
1312    self.assertEqual(result.dtype, np.int32)
1313    self.assertTrue(np.all(lo <= result))
1314    self.assertTrue(np.all(result < hi))
1315
1316  def testCholesky(self):
1317    l = np.array([[4, 0, 0, 0], [6, 5, 0, 0], [2, 14, 16, 0], [3, 6, 1, 4]],
1318                 dtype=np.float32)
1319    c = self._NewComputation()
1320    c.Cholesky(c.Constant(np.dot(l, l.T)))
1321    self._ExecuteAndCompareClose(c, expected=l, rtol=1e-4)
1322
1323  def testSort(self):
1324    keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32)
1325    c = self._NewComputation()
1326    c.Sort(c.Constant(keys))
1327    self._ExecuteAndCompareClose(
1328        c, expected=np.array([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=np.float32))
1329
1330  def testSortKeyVal(self):
1331    keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32)
1332    values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32)
1333    c = self._NewComputation()
1334    c.Sort((c.Constant(keys), c.Constant(values)), dimension=0)
1335    result = xla_client.execute_with_python_values(c.Build().Compile())
1336    self.assertIsInstance(result, tuple)
1337    np.testing.assert_allclose(result[0], [[2, 1, 1, 2], [3, 4, 4, 3]])
1338    np.testing.assert_equal(result[1], [[0, 5, 2, 7], [4, 1, 6, 3]])
1339
1340  def testSortCustomComparator(self):
1341    b = self._NewComputation("comparator")
1342    p0 = b.ParameterFromNumpy(NumpyArrayF32(0))
1343    q0 = b.ParameterFromNumpy(NumpyArrayF32(0))
1344    p1 = b.ParameterFromNumpy(NumpyArrayS32(0))
1345    q1 = b.ParameterFromNumpy(NumpyArrayS32(0))
1346    b.Or(b.Lt(p0, q0), b.And(b.Eq(p0, q0), b.Gt(p1, q1)))
1347    comparator = b.Build()
1348
1349    keys = np.array([[2, 3, 1, 3], [3, 1, 2, 2]], dtype=np.float32)
1350    values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32)
1351    c = self._NewComputation()
1352    c.Sort((c.Constant(keys), c.Constant(values)),
1353           dimension=1,
1354           comparator=comparator)
1355    result = xla_client.execute_with_python_values(c.Build().Compile())
1356    self.assertIsInstance(result, tuple)
1357    np.testing.assert_allclose(result[0], [[1, 2, 3, 3], [1, 2, 2, 3]])
1358    np.testing.assert_equal(result[1], [[2, 0, 3, 1], [5, 7, 6, 4]])
1359
1360  def testQR(self):
1361    a = np.array(
1362        [[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]],
1363        dtype=np.float32)
1364    c = self._NewComputation()
1365    c.QR(c.Constant(a), full_matrices=True)
1366    q, r = self._Execute(c, ())
1367    np.testing.assert_allclose(np.dot(q, r), a, rtol=1e-4)
1368
1369  def testEigh(self):
1370    a = np.array(
1371        [[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]],
1372        dtype=np.float32)
1373    a = (a + a.T) / 2
1374
1375    c = self._NewComputation()
1376    c.Eigh(c.Constant(a), full_matrices=True)
1377    # TODO(b/129396575): Turn this test back on when it passes without fastmath.
1378    # v, w = self._Execute(c, ())
1379    # self.assertLess(np.linalg.norm(np.dot(a, v) - w * v), 1e-3)
1380
1381  def testSVD(self):
1382    a = np.array(
1383        [[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]],
1384        dtype=np.float32)
1385    c = self._NewComputation()
1386    c.SVD(c.Constant(a))
1387    u, d, v = self._Execute(c, ())
1388    self.assertLess(np.linalg.norm(a - np.matmul(u * d, v.T)), 1e-3)
1389
1390  def testTriangularSolve(self):
1391    a_vals = np.array(
1392        [[2, 0, 0, 0], [3, 6, 0, 0], [4, 7, 9, 0], [5, 8, 10, 11]],
1393        dtype=np.float32)
1394    b_vals = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
1395                      dtype=np.float32)
1396
1397    c = self._NewComputation()
1398    c.TriangularSolve(
1399        c.Constant(a_vals),
1400        c.Constant(b_vals),
1401        left_side=False,
1402        lower=True,
1403        transpose_a=True)
1404    self._ExecuteAndCompareClose(
1405        c,
1406        expected=np.array([
1407            [0.5, 0.08333334, 0.04629629, 0.03367003],
1408            [2.5, -0.25, -0.1388889, -0.1010101],
1409            [4.5, -0.58333331, -0.32407406, -0.23569024],
1410        ],
1411                          dtype=np.float32),
1412        rtol=1e-4)
1413
1414  def testIsConstant(self):
1415    c = self._NewComputation()
1416    a = c.ConstantS32Scalar(3)
1417    b = c.ConstantS32Scalar(1)
1418    x = c.ParameterFromNumpy(NumpyArrayS32(0))
1419    const_expr = c.Sub(b, a)
1420    non_const_expr = c.Mul(const_expr, x)
1421    self.assertTrue(c.IsConstant(const_expr))
1422    self.assertFalse(c.IsConstant(non_const_expr))
1423    # self.assertTrue(c.IsConstant(c.Sub(c.Add(x, a), x)))  # TODO(b/77245564)
1424
1425  def testGather(self):
1426    a = np.arange(9).astype(np.int32).reshape((3, 3))
1427    indices = np.array([[[0, 2], [2, 1]], [[1, 2], [2, 0]]], dtype=np.int32)
1428    dnums = xla_client.GatherDimensionNumbers()
1429    dnums.offset_dims.append(1)
1430    dnums.offset_dims.append(2)
1431    dnums.start_index_map.append(0)
1432    dnums.start_index_map.append(1)
1433    dnums.index_vector_dim = 2
1434    c = self._NewComputation()
1435    c.Gather(c.Constant(a), c.Constant(indices), dnums, slice_sizes=[1, 1])
1436    g = self._Execute(c, ())
1437    expected = np.array([[[[2, 7]]], [[[5, 6]]]], dtype=np.int32)
1438    np.testing.assert_allclose(g, expected, rtol=1e-4)
1439
1440  def testFft(self):
1441    shape = [2, 3, 4, 5]
1442    rng = np.random.RandomState(0)
1443    a = rng.randn(*shape) + 1.0j * rng.randn(*shape)
1444    a = a.astype(np.complex64)
1445    # FFT
1446    c = self._NewComputation()
1447    c.Fft(c.Constant(a), xla_client.FftType.FFT, shape[-3:])
1448    self._ExecuteAndCompareClose(
1449        c, expected=np.fft.fftn(a, axes=(1, 2, 3)), rtol=1e-4)
1450    # IFFT
1451    c = self._NewComputation()
1452    c.Fft(c.Constant(a), xla_client.FftType.IFFT, shape[-3:])
1453    self._ExecuteAndCompareClose(
1454        c, expected=np.fft.ifftn(a, axes=(1, 2, 3)), rtol=1e-4)
1455    # RFFT
1456    b = rng.randn(*shape).astype(np.float32)
1457    c = self._NewComputation()
1458    c.Fft(c.Constant(b), xla_client.FftType.RFFT, shape[-3:])
1459    self._ExecuteAndCompareClose(
1460        c, expected=np.fft.rfftn(b, axes=(1, 2, 3)), rtol=1e-4)
1461    # IRFFT
1462    c = self._NewComputation()
1463    c.Fft(c.Constant(a), xla_client.FftType.IRFFT, [3, 4, 8])
1464    self._ExecuteAndCompareClose(
1465        c, expected=np.fft.irfftn(a, axes=(1, 2, 3)), rtol=1e-4)
1466
1467  def testNextAfter(self):
1468    c = self._NewComputation()
1469    c.NextAfter(
1470        c.Constant(np.array([1, 2], dtype=np.float32)),
1471        c.Constant(np.array([2, 1], dtype=np.float32)))
1472    out = self._Execute(c, ())
1473    eps = np.finfo(np.float32).eps
1474    np.testing.assert_equal(np.array([eps + 1, 2 - eps], dtype=np.float32), out)
1475
1476  def testRegularizedIncompleteBeta(self):
1477    x = np.array([0.53787335, 0.24015466, 0.47494545, 0.13567594, 0.95114538])
1478    a = np.array([0.00753073, 0.34813385, 0.30485708, 1.29298632, 0.51472606])
1479    b = np.array([0.55688389, 0.59794214, 0.42661022, 1.59748339, 0.95047677])
1480    c = self._NewComputation()
1481    c.RegularizedIncompleteBeta(c.Constant(a), c.Constant(b), c.Constant(x))
1482    expected = np.array(
1483        [0.98923271, 0.48575411, 0.57952568, 0.12579775, 0.96989155])
1484    self._ExecuteAndCompareClose(c, expected=expected, rtol=1e-4)
1485
1486
1487class EmbeddedComputationsTest(ComputationTest):
1488  """Tests for XLA graphs with embedded computations (such as maps)."""
1489
1490  def _CreateConstantS32Computation(self):
1491    """Computation (f32) -> s32 that returns a constant 1 for any input."""
1492    c = self._NewComputation("constant_s32_one")
1493    # TODO(eliben): consider adding a nicer way to create new parameters without
1494    # having to create dummy Numpy arrays or populating Shape messages. Perhaps
1495    # we need our own (Python-client-own) way to represent Shapes conveniently.
1496    c.ParameterFromNumpy(NumpyArrayF32(0))
1497    c.ConstantS32Scalar(1)
1498    return c.Build()
1499
1500  def _CreateConstantS64Computation(self):
1501    """Computation (f64) -> s64 that returns a constant 1 for any input."""
1502    c = self._NewComputation("constant_s64_one")
1503    # TODO(eliben): consider adding a nicer way to create new parameters without
1504    # having to create dummy Numpy arrays or populating Shape messages. Perhaps
1505    # we need our own (Python-client-own) way to represent Shapes conveniently.
1506    c.ParameterFromNumpy(NumpyArrayF64(0))
1507    c.ConstantS64Scalar(1)
1508    return c.Build()
1509
1510  def _CreateConstantF32Computation(self):
1511    """Computation (f32) -> f32 that returns a constant 1.0 for any input."""
1512    c = self._NewComputation("constant_f32_one")
1513    c.ParameterFromNumpy(NumpyArrayF32(0))
1514    c.ConstantF32Scalar(1.0)
1515    return c.Build()
1516
1517  def _CreateConstantF64Computation(self):
1518    """Computation (f64) -> f64 that returns a constant 1.0 for any input."""
1519    c = self._NewComputation("constant_f64_one")
1520    c.ParameterFromNumpy(NumpyArrayF64(0))
1521    c.ConstantF64Scalar(1.0)
1522    return c.Build()
1523
1524  def _CreateMulF32By2Computation(self):
1525    """Computation (f32) -> f32 that multiplies its parameter by 2."""
1526    c = self._NewComputation("mul_f32_by2")
1527    c.Mul(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(2.0))
1528    return c.Build()
1529
1530  def _CreateMulF32ByParamComputation(self):
1531    """Computation (f32) -> f32 that multiplies one parameter by the other."""
1532    c = self._NewComputation("mul_f32_by_param")
1533    c.Mul(
1534        c.ParameterFromNumpy(NumpyArrayF32(0)),
1535        c.ParameterFromNumpy(NumpyArrayF32(0)))
1536    return c.Build()
1537
1538  def _CreateMulF64By2Computation(self):
1539    """Computation (f64) -> f64 that multiplies its parameter by 2."""
1540    c = self._NewComputation("mul_f64_by2")
1541    c.Mul(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(2.0))
1542    return c.Build()
1543
1544  def _CreateBinaryAddS32Computation(self):
1545    """Computation (s32, s32) -> s32 that adds its two parameters."""
1546    c = self._NewComputation("add_param0_by_param1")
1547    c.Add(
1548        c.ParameterFromNumpy(NumpyArrayS32(0)),
1549        c.ParameterFromNumpy(NumpyArrayS32(0)))
1550    return c.Build()
1551
1552  def _CreateBinaryAddF32Computation(self):
1553    """Computation (f32, f32) -> f32 that adds its two parameters."""
1554    c = self._NewComputation("add_param0_by_param1")
1555    c.Add(
1556        c.ParameterFromNumpy(NumpyArrayF32(0)),
1557        c.ParameterFromNumpy(NumpyArrayF32(0)))
1558    return c.Build()
1559
1560  def _CreateBinaryAddF64Computation(self):
1561    """Computation (f64, f64) -> f64 that adds its two parameters."""
1562    c = self._NewComputation("add_param0_by_param1")
1563    c.Add(
1564        c.ParameterFromNumpy(NumpyArrayF64(0)),
1565        c.ParameterFromNumpy(NumpyArrayF64(0)))
1566    return c.Build()
1567
1568  def _CreateBinaryDivF32Computation(self):
1569    """Computation (f32, f32) -> f32 that divides its two parameters."""
1570    c = self._NewComputation("div_param0_by_param1")
1571    c.Div(
1572        c.ParameterFromNumpy(NumpyArrayF32(0)),
1573        c.ParameterFromNumpy(NumpyArrayF32(0)))
1574    return c.Build()
1575
1576  def _CreateBinaryDivF64Computation(self):
1577    """Computation (f64, f64) -> f64 that divides its two parameters."""
1578    c = self._NewComputation("div_param0_by_param1")
1579    c.Div(
1580        c.ParameterFromNumpy(NumpyArrayF64(0)),
1581        c.ParameterFromNumpy(NumpyArrayF64(0)))
1582    return c.Build()
1583
1584  def _CreateTestF32Lt10Computation(self):
1585    """Computation (f32) -> bool that tests if its parameter is less than 10."""
1586    c = self._NewComputation("test_f32_lt_10")
1587    c.Lt(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(10.))
1588    return c.Build()
1589
1590  def _CreateTestF64Lt10Computation(self):
1591    """Computation (f64) -> bool that tests if its parameter is less than 10."""
1592    c = self._NewComputation("test_f64_lt_10")
1593    c.Lt(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(10.))
1594    return c.Build()
1595
1596  def _CreateBinaryGeF32Computation(self):
1597    """Computation (f32, f32) -> bool that tests first_param >= second_param."""
1598    c = self._NewComputation("param0_lt_param1")
1599    c.Ge(
1600        c.ParameterFromNumpy(NumpyArrayF32(0)),
1601        c.ParameterFromNumpy(NumpyArrayF32(0)))
1602    return c.Build()
1603
1604  def _CreateBinaryGeF64Computation(self):
1605    """Computation (f64, f64) -> bool that tests first_param >= second_param."""
1606    c = self._NewComputation("param0_lt_param1")
1607    c.Ge(
1608        c.ParameterFromNumpy(NumpyArrayF64(0)),
1609        c.ParameterFromNumpy(NumpyArrayF64(0)))
1610    return c.Build()
1611
1612  def _MakeSample3DArrayF32(self):
1613    return NumpyArrayF32([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]],
1614                          [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])
1615
1616  def _MakeSample3DArrayF64(self):
1617    return NumpyArrayF64([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]],
1618                          [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])
1619
1620  def testCallF32(self):
1621    c = self._NewComputation()
1622    c.Call(
1623        self._CreateMulF32By2Computation(),
1624        operands=(c.ConstantF32Scalar(5.0),))
1625    self._ExecuteAndCompareClose(c, expected=10.0)
1626
1627  def testCallF64(self):
1628    c = self._NewComputation()
1629    c.Call(
1630        self._CreateMulF64By2Computation(),
1631        operands=(c.ConstantF64Scalar(5.0),))
1632    self._ExecuteAndCompareClose(c, expected=10.0)
1633
1634  def testMapEachElementToS32Constant(self):
1635    c = self._NewComputation()
1636    c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))],
1637          self._CreateConstantS32Computation(), [0])
1638    self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1])
1639
1640  def testMapEachElementToS64Constant(self):
1641    c = self._NewComputation()
1642    c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))],
1643          self._CreateConstantS64Computation(), [0])
1644    self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1])
1645
1646  def testMapMulBy2F32(self):
1647    c = self._NewComputation()
1648    c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))],
1649          self._CreateMulF32By2Computation(), [0])
1650    self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0])
1651
1652  def testMapMulBy2F64(self):
1653    c = self._NewComputation()
1654    c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))],
1655          self._CreateMulF64By2Computation(), [0])
1656    self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0])
1657
1658  def testSimpleMapChainF32(self):
1659    # Chains a map of constant-f32 with a map of mul-by-2
1660    c = self._NewComputation()
1661    const_f32 = c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))],
1662                      self._CreateConstantF32Computation(), [0])
1663    c.Map([const_f32], self._CreateMulF32By2Computation(), [0])
1664    self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0])
1665
1666  def testSimpleMapChainF64(self):
1667    # Chains a map of constant-f64 with a map of mul-by-2
1668    c = self._NewComputation()
1669    const_f64 = c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))],
1670                      self._CreateConstantF64Computation(), [0])
1671    c.Map([const_f64], self._CreateMulF64By2Computation(), [0])
1672    self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0])
1673
1674  def testDivVectorsWithMapF32(self):
1675    c = self._NewComputation()
1676    c.Map((c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])),
1677           c.Constant(NumpyArrayF32([5.0, 5.0, 4.0, 4.0]))),
1678          self._CreateBinaryDivF32Computation(), [0])
1679    self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0])
1680
1681  def testDivVectorsWithMapF64(self):
1682    c = self._NewComputation()
1683    c.Map((c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])),
1684           c.Constant(NumpyArrayF64([5.0, 5.0, 4.0, 4.0]))),
1685          self._CreateBinaryDivF64Computation(), [0])
1686    self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0])
1687
1688  def testSelectAndScatterF32(self):
1689    c = self._NewComputation()
1690    c.SelectAndScatter(
1691        c.Constant(NumpyArrayF32([[1., 2., 6.], [4., 5., 3.]])),
1692        select=self._CreateBinaryGeF32Computation(),
1693        window_dimensions=(2, 1),
1694        window_strides=(1, 2),
1695        padding=xla_client.PaddingType.VALID,
1696        source=c.Constant(NumpyArrayF32([[0.1, 0.2]])),
1697        init_value=c.Constant(NumpyArrayF32(1)),
1698        scatter=self._CreateBinaryAddF32Computation())
1699    self._ExecuteAndCompareClose(c, expected=[[1., 1., 1.2], [1.1, 1., 1.]])
1700
1701  def testSelectAndScatterF64(self):
1702    c = self._NewComputation()
1703    c.SelectAndScatter(
1704        c.Constant(NumpyArrayF64([[1., 2., 6.], [4., 5., 3.]])),
1705        select=self._CreateBinaryGeF64Computation(),
1706        window_dimensions=(2, 1),
1707        window_strides=(1, 2),
1708        padding=xla_client.PaddingType.VALID,
1709        source=c.Constant(NumpyArrayF64([[0.1, 0.2]])),
1710        init_value=c.Constant(NumpyArrayF64(1)),
1711        scatter=self._CreateBinaryAddF64Computation())
1712    self._ExecuteAndCompareClose(c, expected=[[1., 1., 1.2], [1.1, 1., 1.]])
1713
1714  def testReduce1DtoScalarF32(self):
1715    c = self._NewComputation()
1716    c.Reduce(
1717        operand=c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])),
1718        init_value=c.ConstantF32Scalar(0),
1719        computation_to_apply=self._CreateBinaryAddF32Computation(),
1720        dimensions=[0])
1721    self._ExecuteAndCompareClose(c, expected=10)
1722
1723  def testReduce1DtoScalarF64(self):
1724    c = self._NewComputation()
1725    c.Reduce(
1726        operand=c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])),
1727        init_value=c.ConstantF64Scalar(0),
1728        computation_to_apply=self._CreateBinaryAddF64Computation(),
1729        dimensions=[0])
1730    self._ExecuteAndCompareClose(c, expected=10)
1731
1732  def testReduce2DTo1DDim0F32(self):
1733    input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
1734    c = self._NewComputation()
1735    c.Reduce(
1736        operand=c.Constant(input_array),
1737        init_value=c.ConstantF32Scalar(0),
1738        computation_to_apply=self._CreateBinaryAddF32Computation(),
1739        dimensions=[0])
1740    self._ExecuteAndCompareClose(c, expected=[5, 7, 9])
1741
1742  def testReduce2DTo1DDim0F64(self):
1743    input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
1744    c = self._NewComputation()
1745    c.Reduce(
1746        operand=c.Constant(input_array),
1747        init_value=c.ConstantF64Scalar(0),
1748        computation_to_apply=self._CreateBinaryAddF64Computation(),
1749        dimensions=[0])
1750    self._ExecuteAndCompareClose(c, expected=[5, 7, 9])
1751
1752  def testReduce2DTo1DDim1F32(self):
1753    input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
1754    c = self._NewComputation()
1755    c.Reduce(
1756        operand=c.Constant(input_array),
1757        init_value=c.ConstantF32Scalar(0),
1758        computation_to_apply=self._CreateBinaryAddF32Computation(),
1759        dimensions=[1])
1760    self._ExecuteAndCompareClose(c, expected=[6, 15])
1761
1762  def testReduce2DTo1DDim1F64(self):
1763    input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
1764    c = self._NewComputation()
1765    c.Reduce(
1766        operand=c.Constant(input_array),
1767        init_value=c.ConstantF64Scalar(0),
1768        computation_to_apply=self._CreateBinaryAddF64Computation(),
1769        dimensions=[1])
1770    self._ExecuteAndCompareClose(c, expected=[6, 15])
1771
1772  def testReduce3DAllPossibleWaysF32(self):
1773    input_array = self._MakeSample3DArrayF32()
1774
1775    def _ReduceAndTest(*dims):
1776      c = self._NewComputation()
1777      c.Reduce(
1778          operand=c.Constant(input_array),
1779          init_value=c.ConstantF32Scalar(0),
1780          computation_to_apply=self._CreateBinaryAddF32Computation(),
1781          dimensions=dims)
1782      self._ExecuteAndCompareClose(
1783          c, expected=np.sum(input_array, axis=tuple(dims)))
1784
1785    _ReduceAndTest(0)
1786    _ReduceAndTest(0, 1)
1787    _ReduceAndTest(0, 2)
1788    _ReduceAndTest(1, 2)
1789    _ReduceAndTest(0, 1, 2)
1790
1791  def testReduce3DAllPossibleWaysF64(self):
1792    input_array = self._MakeSample3DArrayF64()
1793
1794    def _ReduceAndTest(*dims):
1795      c = self._NewComputation()
1796      c.Reduce(
1797          operand=c.Constant(input_array),
1798          init_value=c.ConstantF64Scalar(0),
1799          computation_to_apply=self._CreateBinaryAddF64Computation(),
1800          dimensions=dims)
1801      self._ExecuteAndCompareClose(
1802          c, expected=np.sum(input_array, axis=tuple(dims)))
1803
1804    _ReduceAndTest(0)
1805    _ReduceAndTest(0)
1806    _ReduceAndTest(0, 1)
1807    _ReduceAndTest(0, 2)
1808    _ReduceAndTest(1, 2)
1809    _ReduceAndTest(0, 1, 2)
1810
1811  def testReduceWindowValidUnitStridesF32(self):
1812    input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
1813    c = self._NewComputation()
1814    c.ReduceWindow(
1815        operand=c.Constant(input_array),
1816        init_value=c.ConstantF32Scalar(0),
1817        computation_to_apply=self._CreateBinaryAddF32Computation(),
1818        window_dimensions=(2, 1),
1819        window_strides=(1, 1),
1820        padding=xla_client.PaddingType.VALID)
1821    self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.]])
1822
1823  def testReduceWindowSameUnitStridesF32(self):
1824    input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
1825    c = self._NewComputation()
1826    c.ReduceWindow(
1827        operand=c.Constant(input_array),
1828        init_value=c.ConstantF32Scalar(0),
1829        computation_to_apply=self._CreateBinaryAddF32Computation(),
1830        window_dimensions=(2, 1),
1831        window_strides=(1, 1),
1832        padding=xla_client.PaddingType.SAME)
1833    self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.], [4., 5., 6.]])
1834
1835  def testReduceWindowValidGeneralStridesF32(self):
1836    input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
1837    c = self._NewComputation()
1838    c.ReduceWindow(
1839        operand=c.Constant(input_array),
1840        init_value=c.ConstantF32Scalar(0),
1841        computation_to_apply=self._CreateBinaryAddF32Computation(),
1842        window_dimensions=(2, 1),
1843        window_strides=(1, 2),
1844        padding=xla_client.PaddingType.VALID)
1845    self._ExecuteAndCompareClose(c, expected=[[5., 9.]])
1846
1847  def testReduceWindowValidUnitStridesF64(self):
1848    input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
1849    c = self._NewComputation()
1850    c.ReduceWindow(
1851        operand=c.Constant(input_array),
1852        init_value=c.ConstantF64Scalar(0),
1853        computation_to_apply=self._CreateBinaryAddF64Computation(),
1854        window_dimensions=(2, 1),
1855        window_strides=(1, 1),
1856        padding=xla_client.PaddingType.VALID)
1857    self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.]])
1858
1859  def testReduceWindowSameUnitStridesF64(self):
1860    input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
1861    c = self._NewComputation()
1862    c.ReduceWindow(
1863        operand=c.Constant(input_array),
1864        init_value=c.ConstantF64Scalar(0),
1865        computation_to_apply=self._CreateBinaryAddF64Computation(),
1866        window_dimensions=(2, 1),
1867        window_strides=(1, 1),
1868        padding=xla_client.PaddingType.SAME)
1869    self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.], [4., 5., 6.]])
1870
1871  def testReduceWindowValidGeneralStridesF64(self):
1872    input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
1873    c = self._NewComputation()
1874    c.ReduceWindow(
1875        operand=c.Constant(input_array),
1876        init_value=c.ConstantF64Scalar(0),
1877        computation_to_apply=self._CreateBinaryAddF64Computation(),
1878        window_dimensions=(2, 1),
1879        window_strides=(1, 2),
1880        padding=xla_client.PaddingType.VALID)
1881    self._ExecuteAndCompareClose(c, expected=[[5., 9.]])
1882
1883  def testWhileF32(self):
1884    cond = self._CreateTestF32Lt10Computation()
1885    body = self._CreateMulF32By2Computation()
1886    c = self._NewComputation()
1887    init = c.ConstantF32Scalar(1.)
1888    c.While(cond, body, init)
1889    self._ExecuteAndCompareClose(c, expected=16.)
1890
1891  def testWhileF64(self):
1892    cond = self._CreateTestF64Lt10Computation()
1893    body = self._CreateMulF64By2Computation()
1894    c = self._NewComputation()
1895    init = c.ConstantF64Scalar(1.)
1896    c.While(cond, body, init)
1897    self._ExecuteAndCompareClose(c, expected=16.)
1898
1899  def testConditionalTrue(self):
1900    c = self._NewComputation()
1901    pred = c.ConstantPredScalar(True)
1902    true_operand = c.ConstantF32Scalar(3.)
1903    true_computation = self._CreateMulF32By2Computation()
1904    false_operand = c.ConstantF32Scalar(2.)
1905    false_computation = self._CreateConstantF32Computation()
1906    c.Conditional(pred, true_operand, true_computation, false_operand,
1907                  false_computation)
1908    self._ExecuteAndCompareClose(c, expected=6.)
1909
1910  def testConditionalFalse(self):
1911    c = self._NewComputation()
1912    pred = c.ConstantPredScalar(False)
1913    true_operand = c.ConstantF32Scalar(3.)
1914    true_computation = self._CreateMulF32By2Computation()
1915    false_operand = c.ConstantF32Scalar(2.)
1916    false_computation = self._CreateConstantF32Computation()
1917    c.Conditional(pred, true_operand, true_computation, false_operand,
1918                  false_computation)
1919    self._ExecuteAndCompareClose(c, expected=1.)
1920
1921  def testInfeedS32Values(self):
1922    to_infeed = NumpyArrayS32([1, 2, 3, 4])
1923    c = self._NewComputation()
1924    c.GetTupleElement(c.Infeed(xla_client.shape_from_pyval(to_infeed[0])), 0)
1925    compiled_c = c.Build().Compile()
1926    for item in to_infeed:
1927      xla_client.transfer_to_infeed(item)
1928
1929    for item in to_infeed:
1930      result = xla_client.execute_with_python_values(compiled_c)
1931      self.assertEqual(result, item)
1932
1933  def testInfeedTuple(self):
1934    to_infeed = (NumpyArrayS32([1, 2, 3, 4]), NumpyArrayS32([[7], [8]]))
1935    c = self._NewComputation()
1936    c.GetTupleElement(c.Infeed(xla_client.shape_from_pyval(to_infeed)), 0)
1937    compiled_c = c.Build().Compile()
1938    xla_client.transfer_to_infeed(to_infeed)
1939
1940    result = xla_client.execute_with_python_values(compiled_c)
1941    np.testing.assert_equal(result[0], to_infeed[0])
1942    np.testing.assert_equal(result[1], to_infeed[1])
1943
1944  def testInfeedThenOutfeedS32(self):
1945    to_round_trip = NumpyArrayS32([1, 2, 3, 4])
1946    c = self._NewComputation()
1947    x_and_token = c.Infeed(xla_client.shape_from_pyval(to_round_trip[0]))
1948    x = c.GetTupleElement(x_and_token, 0)
1949    token = c.GetTupleElement(x_and_token, 1)
1950    c.Outfeed(x, token)
1951
1952    compiled_c = c.Build().Compile()
1953
1954    for want in to_round_trip:
1955      execution = threading.Thread(target=lambda: compiled_c.Execute([]))
1956      execution.start()
1957      xla_client.transfer_to_infeed(want)
1958      got = xla_client.transfer_from_outfeed(
1959          xla_client.shape_from_pyval(to_round_trip[0]))
1960      execution.join()
1961      self.assertEqual(want, got)
1962
1963  def testScatter(self):
1964    a = np.arange(9).astype(np.int32).reshape((3, 3))
1965    scatter_indices = np.array([0, 2], dtype=np.int32)
1966    updates = np.array([[10, 20, 30], [70, 80, 90]], dtype=np.int32)
1967
1968    dnums = xla_client.ScatterDimensionNumbers()
1969    dnums.update_window_dims.append(1)
1970    dnums.inserted_window_dims.append(0)
1971    dnums.scatter_dims_to_operand_dims.append(0)
1972    dnums.index_vector_dim = 1
1973
1974    c = self._NewComputation()
1975    c.Scatter(
1976        c.Constant(a), c.Constant(scatter_indices), c.Constant(updates),
1977        self._CreateBinaryAddS32Computation(), dnums)
1978    expected = np.array([[10, 21, 32], [3, 4, 5], [76, 87, 98]], dtype=np.int32)
1979    self._ExecuteAndCompareClose(c, expected=expected)
1980
1981
class ErrorTest(ComputationTest):
  """Tests that element-type mismatches are reported as RuntimeError."""

  def setUp(self):
    # Run the base-class fixture setup first so anything the test framework
    # initializes in setUp is still in place for these tests.
    super(ErrorTest, self).setUp()
    self.f32_scalar_2 = NumpyArrayF32(2.0)
    self.s32_scalar_2 = NumpyArrayS32(2)

  def testCompileWithWrongElementTypeInLayout(self):
    # The computation declares an s32 parameter, but the compile options
    # request an f32 argument layout, so compilation must fail.
    c = self._NewComputation()
    c.SetOpMetadata(xla_client.CurrentSourceInfoMetadata())
    c.ParameterFromNumpy(self.s32_scalar_2)
    c.ClearOpMetadata()

    options = xla_client.CompileOptions()
    options.argument_layouts = [
        xla_client.Shape.array_shape(np.dtype(np.float32), [])
    ]

    def TestFun():
      return c.Build().Compile(compile_options=options)

    self.assertRaisesRegex(
        RuntimeError, r".*Invalid argument shape.*"
        r"expected s32\[\], got f32\[\].*", TestFun)

  def testInvokeWithWrongElementType(self):
    # The computation declares an s32 parameter, but is executed with an f32
    # argument, so execution must fail.
    c = self._NewComputation()
    c.SetOpMetadata(xla_client.CurrentSourceInfoMetadata())
    c.ParameterFromNumpy(self.s32_scalar_2)
    c.ClearOpMetadata()

    def TestFun():
      return xla_client.execute_with_python_values(c.Build().Compile(),
                                                   [self.f32_scalar_2])

    self.assertRaisesRegex(
        RuntimeError, r"Invalid argument: Argument does not match.*"
        r"want s32\[\], got f32\[\].*", TestFun)
2019
2020
class ComputationRootTest(ComputationTest):
  """Checks that an explicitly chosen op can serve as the computation root."""

  def testComputationRootDifferentFromLastOp(self):
    c = self._NewComputation()
    x = c.ParameterFromNumpy(NumpyArrayF32(2.0))
    pi_ish = c.ConstantF32Scalar(3.14)
    result = c.Add(x, pi_ish)
    # One more op is added after `result`; it must not become the root because
    # Build() is given `result` explicitly.
    golden_ish = c.ConstantF32Scalar(1.618)
    unused_extra = c.Add(result, golden_ish)

    arg = NumpyArrayF32(1.0)
    compiled_c = c.Build(result).Compile()
    ans = xla_client.execute_with_python_values(compiled_c, [arg])
    np.testing.assert_allclose(ans, 4.14)
2034
2035
class SetShardingTest(ComputationTest):
  """Checks that an OpSharding can be set and cleared on the builder."""

  def testSetSharding(self):
    c = self._NewComputation()
    sharding = xla_client.OpSharding()
    sharding.type = sharding.type.REPLICATED
    sharding.tile_assignment_dimensions.extend([1])
    sharding.tile_assignment_devices.extend([0])
    # The sharding is active only while the parameter is created; it is
    # cleared again before the remaining ops are added.
    c.SetSharding(sharding)
    x = c.ParameterFromNumpy(NumpyArrayF32(2.0))
    c.ClearSharding()

    pi_ish = c.ConstantF32Scalar(3.14)
    result = c.Add(x, pi_ish)
    # An op added after `result`; Build() below still uses `result` as root.
    golden_ish = c.ConstantF32Scalar(1.618)
    unused_extra = c.Add(result, golden_ish)
    arg = NumpyArrayF32(1.0)
    compiled_c = c.Build(result).Compile()
    ans = xla_client.execute_with_python_values(compiled_c, [arg])
    np.testing.assert_allclose(ans, 4.14)
2057
2058
# Dtype groups used to parameterize the round-trip tests below.
int_dtypes = [
    np.int8,
    np.int16,
    np.int32,
    np.int64,
    np.uint8,
    np.uint16,
    np.uint32,
    np.uint64,
]
float_dtypes = [
    np.float16,
    np.float32,
    np.float64,
]
complex_dtypes = [
    np.complex64,
    np.complex128,
]
# Dtypes exercised by the DLPack round-trip test.
dlpack_dtypes = int_dtypes + float_dtypes + [bfloat16]
# Dtypes exercised by the buffer-protocol round-trip test.
standard_dtypes = (
    int_dtypes + float_dtypes + complex_dtypes + [np.bool_])

# Array shapes for the round-trip tests: a scalar, shapes with a zero-sized
# dimension, and several 1-, 2- and 3-dimensional shapes.
testcase_shapes = [
    (), (1,),
    (2, 3), (2, 0), (0, 7),
    (4, 1, 2), (2, 1, 3), (2, 4, 1),
    (3, 1), (1, 3),
]
2080
2081
def FormatShapeAndDtype(shape, dtype):
  """Returns a test-name suffix such as "_float32[2,3]" for dtype and shape."""
  dtype_name = np.dtype(dtype).name
  dims = ",".join(str(dim) for dim in shape)
  return "_{}[{}]".format(dtype_name, dims)
2084
2085
class DLPackTest(parameterized.TestCase):
  """Tests converting buffers to DLPack managed tensors and back."""

  @parameterized.named_parameters(
      dict(
          testcase_name=FormatShapeAndDtype(shape, dtype),
          dtype=dtype,
          shape=shape)
      for dtype, shape in itertools.product(dlpack_dtypes, testcase_shapes))
  def testRoundTrip(self, dtype, shape):
    # Random values scaled into [0, 100) and cast to the dtype under test.
    x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
    backend = xla_client.get_local_backend()
    buf = xla_client.Buffer.from_pyval(x, backend=backend)
    capsule = xla_client._xla.BufferToDLPackManagedTensor(buf)
    del buf  # Drop the buffer to make sure the capsule retains ownership.
    self.assertEqual(type(capsule).__name__, "PyCapsule")
    round_tripped = xla_client._xla.DLPackManagedTensorToBuffer(
        capsule, backend.client)
    np.testing.assert_array_equal(x, round_tripped.to_py())

  def testTensorsCanBeConsumedOnceOnly(self):
    x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32)
    backend = xla_client.get_local_backend()
    buf = xla_client.Buffer.from_pyval(x, backend=backend)
    capsule = xla_client._xla.BufferToDLPackManagedTensor(buf)

    def ConsumeDLPackTensor():
      _ = xla_client._xla.DLPackManagedTensorToBuffer(capsule, backend.client)

    # The first consumption succeeds; a second one on the same capsule must
    # be rejected.
    ConsumeDLPackTensor()
    with self.assertRaisesRegex(
        RuntimeError, ".*a DLPack tensor may be consumed at most once.*"):
      ConsumeDLPackTensor()
2117
2118
class BufferProtocolTest(parameterized.TestCase):
  """Tests viewing CPU buffers as NumPy arrays without copying."""

  @parameterized.named_parameters(
      dict(
          testcase_name=FormatShapeAndDtype(shape, dtype),
          dtype=dtype,
          shape=shape)
      for dtype, shape in itertools.product(standard_dtypes, testcase_shapes))
  def testRoundTrip(self, dtype, shape):
    # Random values scaled into [0, 100) and cast to the dtype under test.
    x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
    x_ptr = x.__array_interface__["data"][0]
    backend = xla_client.get_local_backend("cpu")
    buf = xla_client.Buffer.from_pyval(x, backend=backend)
    y = np.array(buf, copy=False)
    y_ptr = y.__array_interface__["data"][0]
    np.testing.assert_array_equal(x, y)
    # When the input address is a multiple of 64 bytes, the input and the
    # output are expected to alias.
    if (x_ptr & 63) == 0:
      self.assertEqual(x_ptr, y_ptr)
    self.assertEqual(y_ptr, buf.unsafe_buffer_pointer())

    # With force_copy=True the buffer must not share memory with the input.
    copied = xla_client.Buffer.from_pyval(x, backend=backend, force_copy=True)
    z = np.array(copied, copy=False)
    z_ptr = z.__array_interface__["data"][0]
    self.assertNotEqual(x_ptr, z_ptr)

  def testDeleteWithActiveView(self):
    x = np.random.randn(20, 10)
    backend = xla_client.get_local_backend("cpu")
    buf = xla_client.Buffer.from_pyval(x, backend=backend)
    original_ptr = buf.unsafe_buffer_pointer()
    view = np.array(buf, copy=False)
    buf.delete()
    # Accessing `view` is still legal; the array view must keep the
    # underlying memory alive after the buffer is deleted.
    np.testing.assert_array_equal(x, view)
    self.assertEqual(view.__array_interface__["data"][0], original_ptr)
2154
2155
if __name__ == "__main__":
  # Invoke absltest.main() only when this file is run as a script.
  absltest.main()
2158