1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Backend-dependent tests for the Python XLA client.""" 16 17import functools 18import itertools 19import re 20import threading 21import unittest 22 23from absl import flags 24from absl import logging 25from absl.testing import absltest 26from absl.testing import parameterized 27import numpy as np 28 29from tensorflow.compiler.xla.python import xla_client 30 31# pylint: disable=g-import-not-at-top 32try: 33 # This import is only used for GPU; the dependency is incompatible with TPU 34 # so it results in an import error. 35 from tensorflow.python.framework import test_util 36except ImportError: 37 test_util = None 38 39# pylint: disable=g-import-not-at-top 40try: 41 from tensorflow.compiler.xla.python import custom_call_for_test 42except ImportError: 43 custom_call_for_test = None 44 45bfloat16 = xla_client.bfloat16 46ops = xla_client.ops 47 48FLAGS = flags.FLAGS 49 50# We choose to ignore pylint's complaints about complex comprehensions, which we 51# use widely for parameterizing tests. 
52# pylint: disable=g-complex-comprehension 53 54 55def TestFactory(xla_backend, 56 cloud_tpu=False, 57 tfrt_tpu=False, 58 external_tpu=False): 59 tests = [] 60 61 if not cloud_tpu: 62 int_dtypes = [np.int32, np.int64, np.uint32, np.uint64] 63 # TODO(phawkins): test np.float16, where supported. 64 float_dtypes = [bfloat16, np.float32, np.float64] 65 complex_dtypes = [np.complex64, np.complex128] 66 standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_] 67 else: 68 int_dtypes = [np.int32, np.uint32] 69 float_dtypes = [np.float32] 70 complex_dtypes = [np.complex64] 71 standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_] 72 dlpack_dtypes = int_dtypes + float_dtypes + [np.bool_] 73 74 class ComputationTest(parameterized.TestCase): 75 """Base class for running an XLA Computation through the local client.""" 76 77 def setUp(self): 78 super(ComputationTest, self).setUp() 79 self.backend = xla_backend() 80 81 def _NewComputation(self, name=None): 82 if name is None: 83 name = self.id() 84 return xla_client.XlaBuilder(name) 85 86 def _Execute(self, c, arguments): 87 compiled_c = self.backend.compile(c.build()) 88 return xla_client.execute_with_python_values( 89 compiled_c, arguments, backend=self.backend) 90 91 def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected): 92 assert expected is not None 93 results = self._Execute(c, arguments) 94 self.assertLen(results, len(expected)) 95 for result, e in zip(results, expected): 96 # Numpy's comparison methods are a bit too lenient by treating inputs as 97 # "array-like", meaning that scalar 4 will be happily compared equal to 98 # [[4]]. We'd like to be more strict so assert shapes as well. 
99 self.assertEqual(np.asanyarray(result).shape, np.asanyarray(e).shape) 100 assert_func(result, e) 101 102 def _ExecuteAndCompareExact(self, c, arguments=(), expected=None): 103 self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, 104 expected) 105 106 def _ExecuteAndCompareClose(self, 107 c, 108 arguments=(), 109 expected=None, 110 rtol=1e-4, 111 atol=0): 112 self._ExecuteAndAssertWith( 113 functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol), 114 c, arguments, expected) 115 116 def NumpyArrayF32(*args, **kwargs): 117 """Convenience wrapper to create Numpy arrays with a np.float32 dtype.""" 118 return np.array(*args, dtype=np.float32, **kwargs) 119 120 def NumpyArrayF64(*args, **kwargs): 121 """Convenience wrapper to create Numpy arrays with a np.float64 dtype.""" 122 return np.array(*args, dtype=np.float64, **kwargs) 123 124 def NumpyArrayS32(*args, **kwargs): 125 """Convenience wrapper to create Numpy arrays with a np.int32 dtype.""" 126 return np.array(*args, dtype=np.int32, **kwargs) 127 128 def NumpyArrayBool(*args, **kwargs): 129 """Convenience wrapper to create Numpy arrays with a np.bool_ dtype.""" 130 return np.array(*args, dtype=np.bool_, **kwargs) 131 132 class ComputationPrinting(absltest.TestCase): 133 134 def setUp(self): 135 super(ComputationPrinting, self).setUp() 136 self.backend = xla_backend() 137 138 def ExampleComputation(self): 139 builder = xla_client.XlaBuilder("acomputation") 140 p0 = ops.Parameter(builder, 0, xla_client.shape_from_pyval(np.float32(0))) 141 p1 = ops.Parameter( 142 builder, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32))) 143 x = ops.Mul(p0, p1) 144 ops.Add(x, x) 145 return builder.build() 146 147 @unittest.skipIf(cloud_tpu, "not implemented") 148 def testCompiledHloModuleToHloText(self): 149 computation = self.ExampleComputation() 150 executable = self.backend.compile(computation) 151 hlo_modules = executable.hlo_modules() 152 self.assertLen(hlo_modules, 1) 153 hlo_text = 
hlo_modules[0].to_string() 154 self.assertTrue(hlo_text.startswith("HloModule acomputation")) 155 self.assertIn("fusion", hlo_text) 156 157 @unittest.skipIf(cloud_tpu, "not implemented") 158 def testCompiledHloModuleAsSerializedProto(self): 159 computation = self.ExampleComputation() 160 executable = self.backend.compile(computation) 161 hlo_modules = executable.hlo_modules() 162 self.assertLen(hlo_modules, 1) 163 hlo_text = hlo_modules[0].to_string() 164 proto = hlo_modules[0].as_serialized_hlo_module_proto() 165 hlo_module_roundtrip = xla_client.XlaComputation(proto).get_hlo_module() 166 hlo_text_roundtrip = hlo_module_roundtrip.to_string() 167 self.assertEqual(hlo_text, hlo_text_roundtrip) 168 169 @unittest.skipIf(cloud_tpu, "not implemented") 170 def testStableComputationSerialization(self): 171 # Ideally we would test identical computations produced in different 172 # processes. For now we have this limited smoke test. 173 computation = self.ExampleComputation() 174 ref = computation.as_serialized_hlo_module_proto() 175 for _ in range(10): 176 self.assertEqual(computation.as_serialized_hlo_module_proto(), ref) 177 178 @unittest.skipIf(cloud_tpu, "not implemented") 179 def testFlopEstimate(self): 180 computation = self.ExampleComputation() 181 properties = xla_client._xla.hlo_module_cost_analysis( 182 self.backend, computation.as_hlo_module()) 183 self.assertEqual(properties["flops"], 8.0) 184 185 def testFingerprint(self): 186 computation = self.ExampleComputation() 187 executable = self.backend.compile(computation) 188 fingerprint = executable.fingerprint 189 if self.backend.platform == "tpu" and not cloud_tpu: 190 logging.info("fingerprint: %s", fingerprint) 191 self.assertNotEmpty(fingerprint) 192 else: 193 self.assertIsNone(fingerprint) 194 195 tests.append(ComputationPrinting) 196 197 class ComputationsWithConstantsTest(ComputationTest): 198 """Tests focusing on Constant ops.""" 199 200 @parameterized.named_parameters({ 201 "testcase_name": 
"_{}".format(dtype.__name__), 202 "dtype": dtype, 203 } for dtype in int_dtypes + float_dtypes) 204 def testConstantScalarSum(self, dtype): 205 if dtype == np.int8 and self.backend.platform == "tpu": 206 self.skipTest("TPU doesn't support int8") 207 c = self._NewComputation() 208 ops.Add(ops.Constant(c, dtype(1.11)), ops.Constant(c, dtype(3.14))) 209 self._ExecuteAndCompareClose(c, expected=[dtype(1.11) + dtype(3.14)]) 210 211 @parameterized.named_parameters({ 212 "testcase_name": "_{}".format(dtype.__name__), 213 "dtype": dtype, 214 } for dtype in float_dtypes) 215 def testConstantVectorMul(self, dtype): 216 c = self._NewComputation() 217 ops.Mul( 218 ops.Constant(c, np.array([2.5, 3.3, -1.2, 0.7], dtype)), 219 ops.Constant(c, np.array([-1.2, 2, -2, -3], dtype))) 220 self._ExecuteAndCompareClose( 221 c, expected=[[-3, 6.6, 2.4, -2.1]], rtol=3e-3) 222 223 @parameterized.named_parameters({ 224 "testcase_name": "_{}".format(dtype.__name__), 225 "dtype": dtype, 226 } for dtype in float_dtypes) 227 def testConstantVectorScalarDiv(self, dtype): 228 c = self._NewComputation() 229 ops.Div( 230 ops.Constant(c, np.array([1.5, 2.5, 3.0, -10.8], dtype=dtype)), 231 ops.Constant(c, dtype(2.0))) 232 self._ExecuteAndCompareClose( 233 c, expected=[[0.75, 1.25, 1.5, -5.4]], rtol=2e-3) 234 235 @parameterized.named_parameters({ 236 "testcase_name": "_{}".format(dtype.__name__), 237 "dtype": dtype, 238 } for dtype in float_dtypes) 239 def testConstantVectorScalarPow(self, dtype): 240 c = self._NewComputation() 241 ops.Pow( 242 ops.Constant(c, np.array([1.5, 2.5, 3.0], dtype=dtype)), 243 ops.Constant(c, dtype(2.))) 244 self._ExecuteAndCompareClose(c, expected=[[2.25, 6.25, 9.]]) 245 246 def testIota(self): 247 c = self._NewComputation() 248 ops.Iota(c, xla_client.PrimitiveType.F32, 10) 249 self._ExecuteAndCompareExact( 250 c, expected=[np.arange(10, dtype=np.float32)]) 251 252 @parameterized.named_parameters({ 253 "testcase_name": "_{}".format(dtype.__name__), 254 "dtype": dtype, 255 } 
for dtype in int_dtypes) 256 def testBroadcastedIota(self, dtype): 257 c = self._NewComputation() 258 shape = xla_client.Shape.array_shape( 259 xla_client.dtype_to_etype(dtype), (2, 3)) 260 ops.Iota(c, shape, 1) 261 expected = np.array([[0, 1, 2], [0, 1, 2]], dtype=dtype) 262 self._ExecuteAndCompareExact(c, expected=[expected]) 263 264 def testBooleanAnd(self): 265 c = self._NewComputation() 266 ops.And( 267 ops.Constant(c, NumpyArrayBool([True, False, True, False])), 268 ops.Constant(c, NumpyArrayBool([True, True, False, False]))) 269 self._ExecuteAndCompareExact(c, expected=[[True, False, False, False]]) 270 271 def testBooleanOr(self): 272 c = self._NewComputation() 273 ops.Or( 274 ops.Constant(c, NumpyArrayBool([True, False, True, False])), 275 ops.Constant(c, NumpyArrayBool([True, True, False, False]))) 276 self._ExecuteAndCompareExact(c, expected=[[True, True, True, False]]) 277 278 def testBooleanXor(self): 279 c = self._NewComputation() 280 ops.Xor( 281 ops.Constant(c, NumpyArrayBool([True, False, True, False])), 282 ops.Constant(c, NumpyArrayBool([True, True, False, False]))) 283 self._ExecuteAndCompareExact(c, expected=[[False, True, True, False]]) 284 285 @parameterized.named_parameters({ 286 "testcase_name": "_{}".format(dtype.__name__), 287 "dtype": dtype, 288 } for dtype in float_dtypes) 289 def testSum2D(self, dtype): 290 c = self._NewComputation() 291 ops.Add( 292 ops.Constant(c, np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)), 293 ops.Constant(c, np.array([[1, -1, 1], [-1, 1, -1]], dtype=dtype))) 294 self._ExecuteAndCompareClose(c, expected=[[[2, 1, 4], [3, 6, 5]]]) 295 296 def testShiftLeft(self): 297 c = self._NewComputation() 298 ops.ShiftLeft( 299 ops.Constant(c, NumpyArrayS32([3])), 300 ops.Constant(c, NumpyArrayS32([2]))) 301 self._ExecuteAndCompareClose(c, expected=[[12]]) 302 303 def testShiftRightArithmetic(self): 304 c = self._NewComputation() 305 ops.ShiftRightArithmetic( 306 ops.Constant(c, NumpyArrayS32([-2])), 307 ops.Constant(c, 
NumpyArrayS32([1]))) 308 self._ExecuteAndCompareClose(c, expected=[[-1]]) 309 310 def testShiftRightLogical(self): 311 c = self._NewComputation() 312 ops.ShiftRightLogical( 313 ops.Constant(c, NumpyArrayS32([-1])), 314 ops.Constant(c, NumpyArrayS32([1]))) 315 self._ExecuteAndCompareClose(c, expected=[[2**31 - 1]]) 316 317 @parameterized.named_parameters({ 318 "testcase_name": "_{}".format(dtype.__name__), 319 "dtype": dtype, 320 } for dtype in float_dtypes) 321 def testSum2DWith1DBroadcastDim0(self, dtype): 322 # sum of a 2D array with a 1D array where the latter is replicated across 323 # dimension 0 to match the former's shape. 324 c = self._NewComputation() 325 ops.Add( 326 ops.Constant(c, 327 np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], 328 dtype=dtype)), 329 ops.Constant(c, np.array([10, 20, 30], dtype=dtype)), 330 broadcast_dimensions=(0,)) 331 self._ExecuteAndCompareClose( 332 c, expected=[[[11, 12, 13], [24, 25, 26], [37, 38, 39]]]) 333 334 @parameterized.named_parameters({ 335 "testcase_name": "_{}".format(dtype.__name__), 336 "dtype": dtype, 337 } for dtype in float_dtypes) 338 def testSum2DWith1DBroadcastDim1(self, dtype): 339 # sum of a 2D array with a 1D array where the latter is replicated across 340 # dimension 1 to match the former's shape. 
341 c = self._NewComputation() 342 ops.Add( 343 ops.Constant(c, 344 np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], 345 dtype=dtype)), 346 ops.Constant(c, np.array([10, 20, 30], dtype=dtype)), 347 broadcast_dimensions=(1,)) 348 self._ExecuteAndCompareClose( 349 c, expected=[[[11, 22, 33], [14, 25, 36], [17, 28, 39]]]) 350 351 @parameterized.named_parameters({ 352 "testcase_name": "_{}".format(dtype.__name__), 353 "dtype": dtype, 354 } for dtype in float_dtypes) 355 def testConstantAxpy(self, dtype): 356 c = self._NewComputation() 357 ops.Add( 358 ops.Mul( 359 ops.Constant(c, dtype(2)), 360 ops.Constant(c, np.array([2.2, 3.3, 4.4, 5.5], dtype=dtype))), 361 ops.Constant(c, np.array([100, -100, 200, -200], dtype))) 362 self._ExecuteAndCompareClose( 363 c, expected=[[104.4, -93.4, 208.8, -189]], rtol=2e-3) 364 365 def testCustomCall(self): 366 if self.backend.platform != "cpu": 367 self.skipTest("Test requires cpu platform") 368 c = self._NewComputation() 369 for name, fn in custom_call_for_test.cpu_custom_call_targets.items(): 370 xla_client.register_custom_call_target(name, fn, platform="cpu") 371 ops.CustomCallWithLayout( 372 c, 373 b"test_subtract_f32", 374 operands=[ 375 ops.Constant(c, np.float32(1.25)), 376 ops.Constant(c, np.float32(0.5)) 377 ], 378 shape_with_layout=xla_client.Shape.array_shape( 379 np.dtype(np.float32), (), ()), 380 operand_shapes_with_layout=[ 381 xla_client.Shape.array_shape(np.dtype(np.float32), (), ()), 382 xla_client.Shape.array_shape(np.dtype(np.float32), (), ()), 383 ]) 384 self._ExecuteAndCompareClose(c, expected=[0.75]) 385 386 tests.append(ComputationsWithConstantsTest) 387 388 class PythonCallbackTest(ComputationTest): 389 390 def testPythonCallback(self): 391 if self.backend.platform != "cpu": 392 self.skipTest("Test requires cpu platform") 393 c = self._NewComputation() 394 395 f = lambda x, y: (x + y, x - y) 396 397 arg0 = np.array([9, 43, -101, 22], dtype=np.int32) 398 arg1 = np.array([10, 15, -2, 7], dtype=np.int32) 399 shape = 
xla_client.shape_from_pyval(arg0) 400 shape = shape.with_major_to_minor_layout_if_absent() 401 p0 = ops.Parameter(c, 0, shape) 402 p1 = ops.Parameter(c, 1, shape) 403 out, keepalive = self.backend.emit_python_callback( 404 f, c, [p0, p1], [shape, shape]) 405 self._ExecuteAndCompareExact( 406 c, arguments=[arg0, arg1], expected=[arg0 + arg1, arg0 - arg1]) 407 del out, keepalive 408 409 def testTokens(self): 410 if self.backend.platform != "cpu": 411 self.skipTest("Test requires cpu platform") 412 c = self._NewComputation() 413 414 def _Callback(x, y): 415 assert y is None, y 416 return None, x + 1 417 418 arg0 = np.array([9, 43, -101, 22], dtype=np.int32) 419 shape = xla_client.shape_from_pyval(arg0) 420 token_shape = xla_client.Shape.token_shape() 421 p0 = ops.Parameter(c, 0, shape) 422 token = ops.CreateToken(c) 423 out, keepalive = self.backend.emit_python_callback( 424 _Callback, c, [p0, token], [token_shape, shape]) 425 out = ops.GetTupleElement(out, 1) 426 self._ExecuteAndCompareExact(c, arguments=[arg0], expected=[arg0 + 1]) 427 del out, keepalive 428 429 def testStriding(self): 430 if self.backend.platform != "cpu": 431 self.skipTest("Test requires cpu platform") 432 c = self._NewComputation() 433 434 def _Callback(x): 435 assert x.flags.f_contiguous, x.strides 436 # Force the output array to have C layout, which will require a 437 # transpose back to the expected Fortran layout. 
438 return np.ascontiguousarray(x * 2), 439 440 arg0 = np.arange(12, dtype=np.int16).reshape(3, 4) 441 shape_f_layout = xla_client.Shape.array_shape( 442 arg0.dtype, arg0.shape, layout=(0, 1)) 443 p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0)) 444 out, keepalive = self.backend.emit_python_callback( 445 _Callback, c, [p0], [shape_f_layout], [shape_f_layout]) 446 self._ExecuteAndCompareExact(c, arguments=[arg0], expected=[arg0 * 2]) 447 del out, keepalive 448 449 tests.append(PythonCallbackTest) 450 451 class ComputationFromProtoTest(absltest.TestCase): 452 """Test computation execution from HLO proto.""" 453 454 def setUp(self): 455 super(ComputationFromProtoTest, self).setUp() 456 self.backend = xla_backend() 457 458 def testExecuteFromProto(self): 459 # Build the HLO proto 460 b = xla_client.XlaBuilder("computation") 461 ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2))) 462 serialized_proto = b.build().as_serialized_hlo_module_proto() 463 464 # Load and execute the proto 465 c = xla_client.XlaComputation(serialized_proto) 466 ans, = xla_client.execute_with_python_values( 467 self.backend.compile(c), (), backend=self.backend) 468 np.testing.assert_equal(ans, np.int32(3)) 469 470 tests.append(ComputationFromProtoTest) 471 472 class ParametersTest(ComputationTest): 473 """Tests focusing on Parameter ops and argument-passing.""" 474 475 @parameterized.named_parameters({ 476 "testcase_name": "_{}".format(dtype.__name__), 477 "dtype": dtype, 478 } for dtype in int_dtypes) 479 def testScalarTimesVector(self, dtype): 480 c = self._NewComputation() 481 arg0 = np.array(3, dtype=dtype) 482 arg1 = np.array([10, 15, -2, 7], dtype=dtype) 483 p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0)) 484 p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1)) 485 ops.Mul(p0, p1) 486 self._ExecuteAndCompareExact( 487 c, arguments=[arg0, arg1], expected=[arg0 * arg1]) 488 489 # TODO(phawkins): test comparison harness doesn't support bfloat16 490 
@parameterized.named_parameters({ 491 "testcase_name": "_{}".format(dtype.__name__), 492 "dtype": dtype, 493 } for dtype in float_dtypes if dtype != bfloat16) 494 def testScalarMinusVectorExplicitNumbering(self, dtype): 495 # Use explicit numbering and pass parameter_num first. Sub is used since 496 # it's not commutative and can help catch parameter reversal within the 497 # computation. 498 c = self._NewComputation() 499 arg0 = np.array(2.0, dtype=dtype) 500 arg1 = np.array([-2.3, 3.3, -4.3, 5.3], dtype=dtype) 501 p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1)) 502 p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0)) 503 ops.Sub(p1, p0) 504 self._ExecuteAndCompareClose( 505 c, arguments=[arg0, arg1], expected=[arg1 - arg0]) 506 507 tests.append(ParametersTest) 508 509 class BufferTest(ComputationTest): 510 """Tests focusing on execution with Buffers.""" 511 512 def testConstantSum(self): 513 c = self._NewComputation() 514 ops.Add( 515 ops.Constant(c, np.float32(1.11)), ops.Constant(c, np.float32(3.14))) 516 self._ExecuteAndCompareClose(c, expected=[4.25]) 517 518 def testOneParameterSum(self): 519 c = self._NewComputation() 520 ops.Add( 521 ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))), 522 ops.Constant(c, np.float32(3.14))) 523 self._ExecuteAndCompareClose( 524 c, arguments=[NumpyArrayF32(1.11)], expected=[4.25]) 525 526 def testTwoParameterSum(self): 527 c = self._NewComputation() 528 ops.Add( 529 ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))), 530 ops.Parameter(c, 1, xla_client.shape_from_pyval(NumpyArrayF32(0.)))) 531 self._ExecuteAndCompareClose( 532 c, 533 arguments=[NumpyArrayF32(1.11), 534 NumpyArrayF32(3.14)], 535 expected=[4.25]) 536 537 @unittest.skipIf(cloud_tpu, "not implemented") 538 def testCannotCallWithDeletedBuffers(self): 539 c = self._NewComputation() 540 ops.Add( 541 ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))), 542 ops.Constant(c, np.float32(3.14))) 543 arg 
= NumpyArrayF32(1.11) 544 compiled_c = self.backend.compile(c.build()) 545 arg_buffer = self.backend.buffer_from_pyval(arg) 546 arg_buffer.delete() 547 with self.assertRaises(RuntimeError): 548 compiled_c.execute([arg_buffer]) 549 550 def testXlaShape(self): 551 pyval = np.array([[1., 2.]], np.float32) 552 local_buffer = self.backend.buffer_from_pyval(pyval) 553 xla_shape = local_buffer.xla_shape() 554 self.assertEqual(xla_shape.dimensions(), (1, 2)) 555 self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32)) 556 557 def testXlaShapeIndex(self): 558 a = xla_client.ShapeIndex((1, 2)) 559 b = xla_client.ShapeIndex((1, 2)) 560 c = xla_client.ShapeIndex((2, 3)) 561 self.assertEqual(a, b) 562 self.assertNotEqual(b, c) 563 564 def testBlockHostUntilReadyWorks(self): 565 arg = np.array([[1., 2.]], np.float32) 566 arg_buffer = self.backend.buffer_from_pyval(arg) 567 arg_buffer.block_host_until_ready() 568 # This test merely checks that nothing goes awry when we call 569 # block_host_until_ready(); it's difficult to test anything else. 570 571 def testBlockHostUntilReadyRaisesOnDeletedBuffer(self): 572 arg = np.array([[1., 2.]], np.float32) 573 buffer = self.backend.buffer_from_pyval(arg) 574 buffer.delete() 575 with self.assertRaisesRegex( 576 RuntimeError, 577 re.escape( 578 "BlockHostUntilReady() called on deleted or donated buffer")): 579 buffer.block_host_until_ready() 580 581 def testDeviceArrayBaseSignatures(self): 582 # When extending `DeviceArrayBase`, the object behaves as a `DeviceArray` 583 # and thus needs to correctly implement the following methods. 
584 arg = np.array([[1., 2., 3.]], np.float32) 585 buffer = self.backend.buffer_from_pyval(arg) 586 if not isinstance(buffer, xla_client.DeviceArrayBase): 587 raise unittest.SkipTest( 588 "The objectof type {} do not extend DeviceArrayBase".format( 589 type(buffer))) 590 591 self.assertEqual(buffer.__array_priority__, 100) 592 self.assertEqual(buffer.shape, (1, 3)) 593 self.assertEqual(buffer.dtype, np.float32) 594 self.assertEqual(buffer.size, 3) 595 self.assertEqual(buffer.ndim, 2) 596 597 self.assertIs(buffer, buffer.block_until_ready()) 598 buffer.delete() 599 with self.assertRaises(RuntimeError): 600 buffer.block_until_ready() 601 602 def testOnDeviceSizeInBytes(self): 603 if not isinstance(self.backend, xla_client.Client): 604 self.skipTest("TPU Driver doesn't support OnDeviceSizeInBytes.") 605 arg0 = np.array([]) 606 arg1 = np.array([[0., 1., 2.]], np.float32) 607 arg2 = np.array([[3., 4., 5.]], bfloat16) 608 arg0_buffer = self.backend.buffer_from_pyval(arg0) 609 arg1_buffer = self.backend.buffer_from_pyval(arg1) 610 arg2_buffer = self.backend.buffer_from_pyval(arg2) 611 self.assertEqual(arg0_buffer.on_device_size_in_bytes(), 0) 612 # OnDeviceSizeInBytes varies depending on the platform. Confirm there's 613 # a reasonable value. 
614 self.assertGreater(arg1_buffer.on_device_size_in_bytes(), 0) 615 self.assertGreater(arg2_buffer.on_device_size_in_bytes(), 0) 616 617 def testLiveBuffers(self): 618 if not isinstance(self.backend, xla_client.Client): 619 self.skipTest("TPU Driver doesn't support LiveBuffers().") 620 self.assertEmpty(self.backend.live_buffers()) 621 arg0 = np.array([]) 622 arg1 = np.array([[0., 1., 2.]], np.float32) 623 arg2 = np.array([[3., 4., 5.]], bfloat16) 624 arg0_buffer = self.backend.buffer_from_pyval(arg0) 625 arg1_buffer = self.backend.buffer_from_pyval(arg1) 626 arg2_buffer = self.backend.buffer_from_pyval(arg2) 627 self.assertLen(self.backend.live_buffers(), 3) 628 self.assertIs(self.backend.live_buffers()[0], arg2_buffer) 629 self.assertIs(self.backend.live_buffers()[1], arg1_buffer) 630 self.assertIs(self.backend.live_buffers()[2], arg0_buffer) 631 self.assertEqual(self.backend.devices()[0].live_buffers(), 632 self.backend.live_buffers()) 633 634 arg1_buffer.delete() 635 self.assertLen(self.backend.live_buffers(), 2) 636 self.assertIs(self.backend.live_buffers()[0], arg2_buffer) 637 self.assertIs(self.backend.live_buffers()[1], arg0_buffer) 638 639 arg0_buffer.delete() 640 arg2_buffer.delete() 641 self.assertEmpty(self.backend.live_buffers()) 642 643 def testCopyToHost(self): 644 arg0 = np.array([[1., 2.]], np.float32) 645 arg1 = np.array([[3., 4.]], np.float32) 646 arg0_buffer = self.backend.buffer_from_pyval(arg0) 647 arg1_buffer = self.backend.buffer_from_pyval(arg1) 648 # Prefetch two buffers using copy_to_host_async, and then retrieve their 649 # values using to_py. 650 arg0_buffer.copy_to_host_async() 651 arg0_buffer.copy_to_host_async() # Duplicate calls don't do anything. 652 arg1_buffer.copy_to_host_async() 653 np.testing.assert_equal(arg0, arg0_buffer.to_py()) 654 np.testing.assert_equal(arg1, arg1_buffer.to_py()) 655 # copy_to_host_async does nothing after to_py is called. 
656 arg0_buffer.copy_to_host_async() 657 np.testing.assert_equal(arg0, arg0_buffer.to_py()) 658 659 def testDevice(self): 660 x = np.arange(8, dtype=np.int32) 661 for device in self.backend.local_devices(): 662 buf = self.backend.buffer_from_pyval(x, device=device) 663 self.assertEqual(buf.device(), device) 664 np.testing.assert_equal(x, buf.to_py()) 665 666 def testStandardTypes(self): 667 for dtype in standard_dtypes: 668 if dtype == bfloat16 or dtype == np.complex128: 669 continue 670 arr = self.backend.buffer_from_pyval(np.array([0, 1], dtype)) 671 arr = arr.to_py() 672 self.assertEqual(dtype, type(arr[0])) 673 674 def testUnsafeBufferPointer(self): 675 if not isinstance(self.backend, xla_client.Client): 676 self.skipTest("TPU Driver doesn't support UnsafeBufferPointer().") 677 arg0 = np.array([]) 678 arg1 = np.array([[0., 1., 2.]], np.float32) 679 arg2 = np.array([[3., 4., 5.]], bfloat16) 680 arg0_buffer = self.backend.buffer_from_pyval(arg0) 681 arg1_buffer = self.backend.buffer_from_pyval(arg1) 682 arg2_buffer = self.backend.buffer_from_pyval(arg2) 683 self.assertGreaterEqual(arg0_buffer.unsafe_buffer_pointer(), 0) 684 self.assertGreaterEqual(arg1_buffer.unsafe_buffer_pointer(), 0) 685 self.assertGreaterEqual(arg2_buffer.unsafe_buffer_pointer(), 0) 686 687 @unittest.skipIf(cloud_tpu, "not implemented") 688 def testClone(self): 689 x = np.array([[3., 4., 5.]], np.float32) 690 y = self.backend.buffer_from_pyval(x) 691 z = y.clone() 692 self.assertNotEqual(id(x), id(y)) 693 np.testing.assert_array_equal(y.to_py(), z.to_py()) 694 self.assertEqual(y.unsafe_buffer_pointer(), z.unsafe_buffer_pointer()) 695 696 @unittest.skipIf(cloud_tpu, "not implemented") 697 def testJaxAttributesHaveCorrectDefaults(self): 698 x = np.array([[3., 4., 5.]], np.float32) 699 y = self.backend.buffer_from_pyval(x) 700 self.assertIsNone(y.aval) 701 self.assertIsNone(y._device) 702 703 tests.append(BufferTest) 704 705 class SingleOpTest(ComputationTest): 706 """Tests for single ops. 
707 708 The goal here is smoke testing - to exercise the most basic functionality of 709 single XLA ops. As minimal as possible number of additional ops are added 710 around the op being tested. 711 """ 712 713 @parameterized.named_parameters({ 714 "testcase_name": "_{}".format(dtype.__name__), 715 "dtype": dtype, 716 } for dtype in float_dtypes) 717 def testConcatenate(self, dtype): 718 c = self._NewComputation() 719 args = ( 720 ops.Constant(c, np.array([1.0, 2.0, 3.0], dtype=dtype)), 721 ops.Constant(c, np.array([4.0, 5.0, 6.0], dtype=dtype)), 722 ) 723 ops.ConcatInDim(c, args, dimension=0) 724 self._ExecuteAndCompareExact( 725 c, expected=[np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=dtype)]) 726 727 # pyformat: disable 728 @parameterized.named_parameters({ 729 "testcase_name": "_{}_{}".format(src_dtype.__name__, 730 dst_dtype.__name__), 731 "src_dtype": src_dtype, 732 "dst_dtype": dst_dtype, 733 } for src_dtype, dst_dtype in itertools.permutations( 734 [np.bool_, np.int32, np.int64, np.float32, np.float64], 2)) 735 # pyformat: enable 736 def testConvertElementType(self, src_dtype, dst_dtype): 737 if ((src_dtype in [np.int64, np.float64] or 738 dst_dtype in [np.int64, np.float64]) and 739 self.backend.platform == "tpu"): 740 self.skipTest("TPU doesn't support float64") 741 c = self._NewComputation() 742 x = np.array([0, 1, 0, 0, 1], dtype=src_dtype) 743 ops.ConvertElementType( 744 ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype)) 745 746 result = xla_client.execute_with_python_values( 747 self.backend.compile(c.build()), (), backend=self.backend) 748 self.assertLen(result, 1) 749 expected = np.array(x, dtype=dst_dtype) 750 751 self.assertEqual(result[0].shape, expected.shape) 752 self.assertEqual(result[0].dtype, expected.dtype) 753 np.testing.assert_equal(result[0], expected) 754 755 # pyformat: disable 756 @parameterized.named_parameters( 757 { 758 "testcase_name": "_{}_{}".format(src_dtype.__name__, 759 dst_dtype.__name__), 760 "src_dtype": 
src_dtype, 761 "dst_dtype": dst_dtype, 762 } 763 for dtypes in [[np.int32, np.float32], [np.int64, np.float64]] 764 for src_dtype, dst_dtype in itertools.permutations(dtypes, 2)) 765 # pyformat: enable 766 def testBitcastConvertType(self, src_dtype, dst_dtype): 767 if (np.float64 in (src_dtype, dst_dtype) and 768 self.backend.platform == "tpu"): 769 self.skipTest("TPU doesn't support float64") 770 c = self._NewComputation() 771 x = np.array([0, 1, 0, 0, 1], dtype=src_dtype) 772 ops.BitcastConvertType( 773 ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype)) 774 775 result = xla_client.execute_with_python_values( 776 self.backend.compile(c.build()), (), backend=self.backend) 777 self.assertLen(result, 1) 778 expected = x.view(dst_dtype) 779 780 self.assertEqual(result[0].shape, expected.shape) 781 self.assertEqual(result[0].dtype, expected.dtype) 782 np.testing.assert_equal(result[0], expected) 783 784 # TODO(b/123523486) implement AllToAll on CPU 785 def DISABLED_testAllToAllOneReplica(self): 786 samples = [ 787 NumpyArrayF32([97.0]), 788 NumpyArrayF32([64.0, 117.0]), 789 NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), 790 ] 791 for lhs in samples[:1]: 792 c = self._NewComputation() 793 ops.AllToAll(ops.Constant(c, lhs), 0, 0) 794 self._ExecuteAndCompareExact(c, expected=[lhs]) 795 796 def testCrossReplicaSumOneReplica(self): 797 samples = [ 798 NumpyArrayF32(42.0), 799 NumpyArrayF32([97.0]), 800 NumpyArrayF32([64.0, 117.0]), 801 NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), 802 ] 803 for lhs in samples: 804 c = self._NewComputation() 805 ops.CrossReplicaSum(ops.Constant(c, lhs)) 806 self._ExecuteAndCompareExact(c, expected=[lhs]) 807 808 def testReplicaId(self): 809 c = self._NewComputation() 810 _ = ops.ReplicaId(c) 811 self._ExecuteAndCompareExact(c, expected=[0]) 812 813 def testCrossReplicaSumOneReplicaWithSingletonGroup(self): 814 samples = [ 815 NumpyArrayF32(42.0), 816 NumpyArrayF32([97.0]), 817 NumpyArrayF32([64.0, 117.0]), 818 NumpyArrayF32([[2.0, 3.0], [4.0, 
5.0]]), 819 ] 820 for lhs in samples: 821 c = self._NewComputation() 822 ops.CrossReplicaSum( 823 ops.Constant(c, lhs), xla_client.make_replica_groups([[0]])) 824 self._ExecuteAndCompareExact(c, expected=[lhs]) 825 826 # TODO(phawkins): np.dot implementation doesn't support bfloat16 827 @parameterized.named_parameters({ 828 "testcase_name": "_{}".format(dtype.__name__), 829 "dtype": dtype, 830 } for dtype in float_dtypes if dtype != bfloat16) 831 def testDotMatrixVector(self, dtype): 832 c = self._NewComputation() 833 lhs = np.array([[2.0, 3.0], [4.0, 5.0]], dtype=dtype) 834 rhs = np.array([[10.0], [20.0]], dtype=dtype) 835 ops.Dot(ops.Constant(c, lhs), ops.Constant(c, rhs)) 836 self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)]) 837 838 # TODO(phawkins): np.dot implementation doesn't support bfloat16 839 @parameterized.named_parameters({ 840 "testcase_name": "_{}".format(dtype.__name__), 841 "dtype": dtype, 842 } for dtype in float_dtypes if dtype != bfloat16) 843 def testDotMatrixMatrix(self, dtype): 844 c = self._NewComputation() 845 lhs = np.array([[2.0, 3.0], [4.0, 5.0]], dtype=dtype) 846 rhs = np.array([[10.0, 20.0], [100.0, 200.0]], dtype=dtype) 847 ops.Dot(ops.Constant(c, lhs), ops.Constant(c, rhs)) 848 self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)]) 849 850 def testDotGeneral(self): 851 c = self._NewComputation() 852 rng = np.random.RandomState(0) 853 lhs = NumpyArrayF32(rng.randn(10, 3, 4)) 854 rhs = NumpyArrayF32(rng.randn(10, 4, 5)) 855 dimension_numbers = xla_client.make_dot_dimension_numbers( 856 (([2], [1]), ([0], [0]))) 857 ops.DotGeneral( 858 ops.Constant(c, lhs), ops.Constant(c, rhs), dimension_numbers) 859 self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6) 860 861 def testDotGeneralWithDotDimensionNumbersProto(self): 862 c = self._NewComputation() 863 rng = np.random.RandomState(0) 864 lhs = NumpyArrayF32(rng.randn(10, 3, 4)) 865 rhs = NumpyArrayF32(rng.randn(10, 4, 5)) 866 867 dimension_numbers = 
xla_client.DotDimensionNumbers() 868 dimension_numbers.lhs_contracting_dimensions.append(2) 869 dimension_numbers.rhs_contracting_dimensions.append(1) 870 dimension_numbers.lhs_batch_dimensions.append(0) 871 dimension_numbers.rhs_batch_dimensions.append(0) 872 873 ops.DotGeneral( 874 ops.Constant(c, lhs), ops.Constant(c, rhs), dimension_numbers) 875 self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6) 876 877 def testDotGeneralWithPrecisionConfig(self): 878 c = self._NewComputation() 879 rng = np.random.RandomState(0) 880 lhs = NumpyArrayF32(rng.randn(10, 3, 4)) 881 rhs = NumpyArrayF32(rng.randn(10, 4, 5)) 882 dimension_numbers = xla_client.make_dot_dimension_numbers( 883 (([2], [1]), ([0], [0]))) 884 config = xla_client.PrecisionConfig() 885 config.operand_precision.append(config.Precision.HIGH) 886 config.operand_precision.append(config.Precision.HIGHEST) 887 ops.DotGeneral( 888 ops.Constant(c, lhs), 889 ops.Constant(c, rhs), 890 dimension_numbers, 891 precision_config=config) 892 self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6) 893 894 def testConvGeneralDilatedF32(self): 895 c = self._NewComputation() 896 a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") 897 lhs = a(1, 1, 2, 3) 898 rhs = a(1, 1, 1, 2) * 10 899 strides = [1, 1] 900 pads = [(1, 0), (0, 1)] 901 lhs_dilation = (2, 1) 902 rhs_dilation = (1, 1) 903 dimension_numbers = xla_client.make_convolution_dimension_numbers( 904 ("NCHW", "OIHW", "NCHW"), 2) 905 ops.ConvGeneralDilated( 906 ops.Constant(c, lhs), ops.Constant(c, rhs), strides, pads, 907 lhs_dilation, rhs_dilation, dimension_numbers) 908 result = np.array([[[ 909 [0., 0., 0.], 910 [10., 20., 0.], 911 [0., 0., 0.], 912 [40., 50., 0.], 913 ]]]) 914 self._ExecuteAndCompareClose(c, expected=[result]) 915 916 def testConvGeneralDilatedF32WithPrecisionConfig(self): 917 c = self._NewComputation() 918 a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") 919 lhs = 
a(1, 1, 2, 3) 920 rhs = a(1, 1, 1, 2) * 10 921 strides = [1, 1] 922 pads = [(1, 0), (0, 1)] 923 lhs_dilation = (2, 1) 924 rhs_dilation = (1, 1) 925 dimension_numbers = xla_client.make_convolution_dimension_numbers( 926 ("NCHW", "OIHW", "NCHW"), 2) 927 config = xla_client.PrecisionConfig() 928 config.operand_precision.append(config.Precision.HIGHEST) 929 config.operand_precision.append(config.Precision.DEFAULT) 930 ops.ConvGeneralDilated( 931 ops.Constant(c, lhs), 932 ops.Constant(c, rhs), 933 strides, 934 pads, 935 lhs_dilation, 936 rhs_dilation, 937 dimension_numbers, 938 precision_config=config) 939 result = np.array([[[ 940 [0., 0., 0.], 941 [10., 20., 0.], 942 [0., 0., 0.], 943 [40., 50., 0.], 944 ]]]) 945 self._ExecuteAndCompareClose(c, expected=[result]) 946 947 def testConvGeneralDilatedPermutedF32(self): 948 c = self._NewComputation() 949 a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") 950 lhs = a(1, 1, 2, 3) 951 rhs = a(1, 1, 1, 2) * 10 952 strides = [1, 1] 953 pads = [(1, 0), (0, 1)] 954 lhs_dilation = (2, 1) 955 rhs_dilation = (1, 1) 956 957 dimension_numbers = xla_client.make_convolution_dimension_numbers( 958 ("NHWC", "OIHW", "CWNH"), 2) 959 ops.ConvGeneralDilated( 960 ops.Constant(c, np.transpose(lhs, 961 (0, 2, 3, 1))), ops.Constant(c, rhs), 962 strides, pads, lhs_dilation, rhs_dilation, dimension_numbers) 963 result = np.array([[[[0., 0., 0.], [10., 20., 0.], [0., 0., 0.], 964 [40., 50., 0.]]]]) 965 self._ExecuteAndCompareClose( 966 c, expected=[np.transpose(result, (1, 3, 0, 2))]) 967 968 def testConvGeneralDilatedGroupedConvolutionF32(self): 969 c = self._NewComputation() 970 a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") 971 lhs = a(1, 2, 2, 3) 972 rhs = a(2, 1, 1, 2) * 10 973 strides = [1, 1] 974 pads = [(1, 0), (0, 1)] 975 lhs_dilation = (2, 1) 976 rhs_dilation = (1, 1) 977 dimension_numbers = xla_client.make_convolution_dimension_numbers( 978 ("NCHW", "OIHW", "NCHW"), 2) 979 
feature_group_count = 2 980 ops.ConvGeneralDilated( 981 ops.Constant(c, lhs), ops.Constant(c, rhs), strides, pads, 982 lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count) 983 result = np.array([[[ 984 [0., 0., 0.], 985 [10., 20., 0.], 986 [0., 0., 0.], 987 [40., 50., 0.], 988 ], [ 989 [0., 0., 0.], 990 [330., 380., 160.], 991 [0., 0., 0.], 992 [480., 530., 220.], 993 ]]]) 994 self._ExecuteAndCompareClose(c, expected=[result]) 995 996 def testBooleanNot(self): 997 c = self._NewComputation() 998 arr = NumpyArrayBool([True, False, True]) 999 ops.Not(ops.Constant(c, arr)) 1000 self._ExecuteAndCompareClose(c, expected=[~arr]) 1001 1002 def testPopulationCount(self): 1003 c = self._NewComputation() 1004 arr = NumpyArrayS32([3, 0, 1]) 1005 ops.PopulationCount(ops.Constant(c, arr)) 1006 self._ExecuteAndCompareClose(c, expected=[np.array([2, 0, 1])]) 1007 1008 def testCountLeadingZeros(self): 1009 c = self._NewComputation() 1010 arr = NumpyArrayS32([0x7FFF, 0x12345678]) 1011 ops.Clz(ops.Constant(c, arr)) 1012 self._ExecuteAndCompareClose(c, expected=[[17, 3]]) 1013 1014 def testExp(self): 1015 c = self._NewComputation() 1016 arr = NumpyArrayF32([3.3, 12.1]) 1017 ops.Exp(ops.Constant(c, arr)) 1018 self._ExecuteAndCompareClose(c, expected=[np.exp(arr)]) 1019 1020 def testExpm1(self): 1021 c = self._NewComputation() 1022 arr = NumpyArrayF32([3.3, 12.1]) 1023 ops.Expm1(ops.Constant(c, arr)) 1024 self._ExecuteAndCompareClose(c, expected=[np.expm1(arr)]) 1025 1026 def testRound(self): 1027 c = self._NewComputation() 1028 arr = NumpyArrayF32([3.3, 12.1]) 1029 ops.Round(ops.Constant(c, arr)) 1030 self._ExecuteAndCompareClose(c, expected=[np.round(arr)]) 1031 1032 def testLog(self): 1033 c = self._NewComputation() 1034 arr = NumpyArrayF32([3.3, 12.1]) 1035 ops.Log(ops.Constant(c, arr)) 1036 self._ExecuteAndCompareClose(c, expected=[np.log(arr)]) 1037 1038 def testLog1p(self): 1039 c = self._NewComputation() 1040 arr = NumpyArrayF32([3.3, 12.1]) 1041 
ops.Log1p(ops.Constant(c, arr)) 1042 self._ExecuteAndCompareClose(c, expected=[np.log1p(arr)]) 1043 1044 def testNeg(self): 1045 c = self._NewComputation() 1046 arr = NumpyArrayF32([3.3, 12.1]) 1047 ops.Neg(ops.Constant(c, arr)) 1048 self._ExecuteAndCompareClose(c, expected=[-arr]) 1049 1050 def testFloor(self): 1051 c = self._NewComputation() 1052 arr = NumpyArrayF32([3.3, 12.1]) 1053 ops.Floor(ops.Constant(c, arr)) 1054 self._ExecuteAndCompareClose(c, expected=[np.floor(arr)]) 1055 1056 def testCeil(self): 1057 c = self._NewComputation() 1058 arr = NumpyArrayF32([3.3, 12.1]) 1059 ops.Ceil(ops.Constant(c, arr)) 1060 self._ExecuteAndCompareClose(c, expected=[np.ceil(arr)]) 1061 1062 def testAbs(self): 1063 c = self._NewComputation() 1064 arr = NumpyArrayF32([3.3, -12.1, 2.4, -1.]) 1065 ops.Abs(ops.Constant(c, arr)) 1066 self._ExecuteAndCompareClose(c, expected=[np.abs(arr)]) 1067 1068 def testTanhF32(self): 1069 c = self._NewComputation() 1070 arr = NumpyArrayF32([-0.2, 3.3, 12.1, 0.1, 0.0001]) 1071 ops.Tanh(ops.Constant(c, arr)) 1072 self._ExecuteAndCompareClose(c, expected=[np.tanh(arr)]) 1073 1074 def testTanhF64(self): 1075 if self.backend.platform == "tpu": 1076 self.skipTest("TPU doesn't support 64bit tanh") 1077 c = self._NewComputation() 1078 arr = NumpyArrayF64([-0.2, 3.3, 12.1, 0.1, 0.0001]) 1079 ops.Tanh(ops.Constant(c, arr)) 1080 self._ExecuteAndCompareClose(c, expected=[np.tanh(arr)], rtol=1e-12) 1081 1082 def testTranspose(self): 1083 1084 def _TransposeAndTest(array, permutation): 1085 c = self._NewComputation() 1086 ops.Transpose(ops.Constant(c, array), permutation) 1087 expected = np.transpose(array, permutation) 1088 self._ExecuteAndCompareClose(c, expected=[expected]) 1089 1090 _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [0, 1]) 1091 _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [1, 0]) 1092 _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [0, 1]) 1093 _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [1, 0]) 1094 1095 
arr = np.random.RandomState(0).randn(2, 3, 4).astype(np.float32) 1096 for permutation in itertools.permutations(range(arr.ndim)): 1097 _TransposeAndTest(arr, permutation) 1098 _TransposeAndTest(np.asfortranarray(arr), permutation) 1099 1100 def testEq(self): 1101 c = self._NewComputation() 1102 ops.Eq( 1103 ops.Constant(c, NumpyArrayS32([1, 2, 3, 4])), 1104 ops.Constant(c, NumpyArrayS32([4, 2, 3, 1]))) 1105 self._ExecuteAndCompareExact(c, expected=[[False, True, True, False]]) 1106 1107 def testNe(self): 1108 c = self._NewComputation() 1109 ops.Ne( 1110 ops.Constant(c, NumpyArrayS32([1, 2, 3, 4])), 1111 ops.Constant(c, NumpyArrayS32([4, 2, 3, 1]))) 1112 self._ExecuteAndCompareExact(c, expected=[[True, False, False, True]]) 1113 1114 ops.Ne( 1115 ops.Constant(c, NumpyArrayF32([-2.0, 0.0, 1116 float("nan"), 1117 float("nan")])), 1118 ops.Constant(c, NumpyArrayF32([2.0, -0.0, 1.0, 1119 float("nan")]))) 1120 self._ExecuteAndAssertWith( 1121 np.testing.assert_allclose, 1122 c, (), 1123 expected=[[True, False, True, True]]) 1124 1125 def testGt(self): 1126 c = self._NewComputation() 1127 ops.Gt( 1128 ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])), 1129 ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12]))) 1130 self._ExecuteAndCompareExact( 1131 c, expected=[[False, True, True, False, False]]) 1132 1133 def testGe(self): 1134 c = self._NewComputation() 1135 ops.Ge( 1136 ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])), 1137 ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12]))) 1138 self._ExecuteAndCompareExact( 1139 c, expected=[[True, True, True, False, False]]) 1140 1141 def testLt(self): 1142 c = self._NewComputation() 1143 ops.Lt( 1144 ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])), 1145 ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12]))) 1146 self._ExecuteAndCompareExact( 1147 c, expected=[[False, False, False, True, True]]) 1148 1149 def testLe(self): 1150 c = self._NewComputation() 1151 ops.Le( 1152 ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])), 1153 ops.Constant(c, 
NumpyArrayS32([1, 0, 2, 7, 12]))) 1154 self._ExecuteAndCompareExact( 1155 c, expected=[[True, False, False, True, True]]) 1156 1157 def testMax(self): 1158 c = self._NewComputation() 1159 ops.Max( 1160 ops.Constant(c, NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])), 1161 ops.Constant(c, NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) 1162 self._ExecuteAndCompareExact(c, expected=[[1.0, 2.0, 3.0, 7.0, 12.0]]) 1163 1164 def testMaxExplicitBroadcastDim0(self): 1165 c = self._NewComputation() 1166 ops.Max( 1167 ops.Constant(c, NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 1168 ops.Constant(c, NumpyArrayF32([3, 4, 5])), 1169 broadcast_dimensions=(0,)) 1170 self._ExecuteAndCompareExact( 1171 c, expected=[[[3, 3, 3], [4, 5, 6], [7, 8, 9]]]) 1172 1173 def testMaxExplicitBroadcastDim1(self): 1174 c = self._NewComputation() 1175 ops.Max( 1176 ops.Constant(c, NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 1177 ops.Constant(c, NumpyArrayF32([3, 4, 5])), 1178 broadcast_dimensions=(1,)) 1179 self._ExecuteAndCompareExact( 1180 c, expected=[[[3, 4, 5], [4, 5, 6], [7, 8, 9]]]) 1181 1182 def testMin(self): 1183 c = self._NewComputation() 1184 ops.Min( 1185 ops.Constant(c, NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])), 1186 ops.Constant(c, NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) 1187 self._ExecuteAndCompareExact(c, expected=[[1.0, 0.0, 2.0, 4.0, 9.0]]) 1188 1189 def testPad(self): 1190 c = self._NewComputation() 1191 ops.Pad( 1192 ops.Constant(c, NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])), 1193 ops.Constant(c, NumpyArrayF32(0.0)), 1194 xla_client.make_padding_config([(1, 2, 1), (0, 1, 0)])) 1195 self._ExecuteAndCompareClose( 1196 c, 1197 expected=[[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0], 1198 [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]) 1199 1200 def testPadWithPaddingConfig(self): 1201 c = self._NewComputation() 1202 padding_config = xla_client.PaddingConfig() 1203 for lo, hi, interior in [(1, 2, 1), (0, 1, 0)]: 1204 dimension = xla_client.PaddingConfigDimension() 
1205 dimension.edge_padding_low = lo 1206 dimension.edge_padding_high = hi 1207 dimension.interior_padding = interior 1208 padding_config.dimensions.append(dimension) 1209 ops.Pad( 1210 ops.Constant(c, NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])), 1211 ops.Constant(c, NumpyArrayF32(0.0)), padding_config) 1212 self._ExecuteAndCompareClose( 1213 c, 1214 expected=[[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0], 1215 [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]) 1216 1217 def testReshape(self): 1218 c = self._NewComputation() 1219 ops.Reshape( 1220 ops.Constant(c, NumpyArrayS32([[1, 2], [3, 4], [5, 6]])), 1221 dimensions=[0, 1], 1222 new_sizes=[2, 3]) 1223 self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3], [4, 5, 6]]]) 1224 1225 def testCollapse(self): 1226 c = self._NewComputation() 1227 ops.Collapse( 1228 ops.Constant(c, NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), 1229 dimensions=[1, 2]) 1230 self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3, 4], [5, 6, 7, 8]]]) 1231 1232 def testRev(self): 1233 c = self._NewComputation() 1234 ops.Rev( 1235 ops.Constant(c, NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), 1236 dimensions=[0, 2]) 1237 self._ExecuteAndCompareExact( 1238 c, expected=[[[[6, 5], [8, 7]], [[2, 1], [4, 3]]]]) 1239 1240 def testReducePrecision(self): 1241 c = self._NewComputation() 1242 ops.ReducePrecision( 1243 ops.Constant(c, NumpyArrayF32([float.fromhex("0x1.32fffep-3")])), 1244 exponent_bits=8, 1245 mantissa_bits=7) 1246 self._ExecuteAndCompareClose(c, expected=[[float.fromhex("0x1.32p-3")]]) 1247 1248 def testClampF32(self): 1249 c = self._NewComputation() 1250 ops.Clamp( 1251 ops.Constant(c, NumpyArrayF32(-1)), 1252 ops.Constant(c, NumpyArrayF32([-2, -1, 0, 1, 2, 3])), 1253 ops.Constant(c, NumpyArrayF32(2))) 1254 self._ExecuteAndCompareExact(c, expected=[[-1, -1, 0, 1, 2, 2]]) 1255 1256 def testClampS32(self): 1257 c = self._NewComputation() 1258 ops.Clamp( 1259 ops.Constant(c, NumpyArrayS32(-1)), 1260 ops.Constant(c, 
NumpyArrayS32([-2, -1, 0, 1, 2, 3])), 1261 ops.Constant(c, NumpyArrayS32(2))) 1262 self._ExecuteAndCompareExact(c, expected=[[-1, -1, 0, 1, 2, 2]]) 1263 1264 def testSelect(self): 1265 c = self._NewComputation() 1266 ops.Select( 1267 ops.Constant(c, NumpyArrayBool([True, False, False, True, False])), 1268 ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 5])), 1269 ops.Constant(c, NumpyArrayS32([-1, -2, -3, -4, -5]))) 1270 self._ExecuteAndCompareExact(c, expected=[[1, -2, -3, 4, -5]]) 1271 1272 def testSlice(self): 1273 c = self._NewComputation() 1274 ops.Slice( 1275 ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 1276 [1, 0], [3, 2], [1, 1]) 1277 self._ExecuteAndCompareExact(c, expected=[[[4, 5], [7, 8]]]) 1278 1279 def testSliceInDim(self): 1280 c = self._NewComputation() 1281 ops.SliceInDim( 1282 ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 1283 start_index=1, 1284 limit_index=2, 1285 stride=1, 1286 dimno=1) 1287 self._ExecuteAndCompareExact(c, expected=[[[2], [5], [8]]]) 1288 ops.SliceInDim( 1289 ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 1290 start_index=0, 1291 limit_index=3, 1292 stride=2, 1293 dimno=0) 1294 self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3], [7, 8, 9]]]) 1295 1296 def testDynamicSlice(self): 1297 c = self._NewComputation() 1298 ops.DynamicSlice( 1299 ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 1300 [ops.Constant(c, NumpyArrayS32([1, 0]))], [2, 2]) 1301 self._ExecuteAndCompareExact(c, expected=[[[4, 5], [7, 8]]]) 1302 1303 def testDynamicUpdateSlice(self): 1304 c = self._NewComputation() 1305 ops.DynamicUpdateSlice( 1306 ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 1307 ops.Constant(c, NumpyArrayS32([[1, 2], [3, 4]])), 1308 [ops.Constant(c, NumpyArrayS32([1, 1]))]) 1309 self._ExecuteAndCompareExact( 1310 c, expected=[[[1, 2, 3], [4, 1, 2], [7, 3, 4]]]) 1311 1312 def testTuple(self): 1313 c = self._NewComputation() 1314 ops.Tuple(c, [ 1315 
ops.Constant(c, np.int32(42)), 1316 ops.Constant(c, NumpyArrayF32([1.0, 2.0])), 1317 ops.Constant(c, NumpyArrayBool([True, False, False, True])) 1318 ]) 1319 result = xla_client.execute_with_python_values( 1320 self.backend.compile(c.build()), (), backend=self.backend) 1321 self.assertLen(result, 3) 1322 np.testing.assert_equal(result[0], 42) 1323 np.testing.assert_allclose(result[1], [1.0, 2.0]) 1324 np.testing.assert_equal(result[2], [True, False, False, True]) 1325 1326 def testGetTupleElement(self): 1327 c = self._NewComputation() 1328 ops.GetTupleElement( 1329 ops.Tuple(c, [ 1330 ops.Constant(c, np.int32(42)), 1331 ops.Constant(c, NumpyArrayF32([1.0, 2.0])), 1332 ops.Constant(c, NumpyArrayBool([True, False, False, True])) 1333 ]), 1) 1334 self._ExecuteAndCompareClose(c, expected=[[1.0, 2.0]]) 1335 1336 def testBroadcast(self): 1337 c = self._NewComputation() 1338 ops.Broadcast( 1339 ops.Constant(c, NumpyArrayS32([10, 20, 30, 40])), sizes=(3,)) 1340 self._ExecuteAndCompareExact( 1341 c, expected=[[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]]]) 1342 1343 def testBroadcastInDim(self): 1344 c = self._NewComputation() 1345 ops.BroadcastInDim(ops.Constant(c, NumpyArrayS32([1, 2])), [2, 2], [0]) 1346 self._ExecuteAndCompareExact(c, expected=[[[1, 1], [2, 2]]]) 1347 ops.BroadcastInDim(ops.Constant(c, NumpyArrayS32([1, 2])), [2, 2], [1]) 1348 self._ExecuteAndCompareExact(c, expected=[[[1, 2], [1, 2]]]) 1349 1350 def testRngNormal(self): 1351 shape = (2, 3) 1352 c = self._NewComputation() 1353 ops.RngNormal( 1354 ops.Constant(c, NumpyArrayF32(0.)), 1355 ops.Constant(c, NumpyArrayF32(1.)), 1356 shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, 1357 shape)) 1358 result = xla_client.execute_with_python_values( 1359 self.backend.compile(c.build()), (), backend=self.backend) 1360 # since the result is random, we just check shape and uniqueness 1361 self.assertLen(result, 1) 1362 self.assertEqual(result[0].shape, shape) 1363 
self.assertLen(np.unique(result[0]), np.prod(shape)) 1364 1365 def testRngUniformF32(self): 1366 lo, hi = 2., 4. 1367 shape = (2, 3) 1368 c = self._NewComputation() 1369 ops.RngUniform( 1370 ops.Constant(c, NumpyArrayF32(lo)), 1371 ops.Constant(c, NumpyArrayF32(hi)), 1372 shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, 1373 shape)) 1374 result = xla_client.execute_with_python_values( 1375 self.backend.compile(c.build()), (), backend=self.backend) 1376 # since the result is random, we just check shape, uniqueness, and range 1377 self.assertLen(result, 1) 1378 self.assertEqual(result[0].shape, shape) 1379 self.assertLen(np.unique(result[0]), np.prod(shape)) 1380 self.assertTrue(np.all(lo <= result[0])) 1381 self.assertTrue(np.all(result[0] < hi)) 1382 1383 def testRngUniformS32(self): 1384 lo, hi = 2, 4 1385 shape = (2, 3) 1386 c = self._NewComputation() 1387 ops.RngUniform( 1388 ops.Constant(c, NumpyArrayS32(lo)), 1389 ops.Constant(c, NumpyArrayS32(hi)), 1390 shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.S32, 1391 shape)) 1392 result = xla_client.execute_with_python_values( 1393 self.backend.compile(c.build()), (), backend=self.backend) 1394 # since the result is random, we just check shape, integrality, and range 1395 self.assertLen(result, 1) 1396 self.assertEqual(result[0].shape, shape) 1397 self.assertEqual(result[0].dtype, np.int32) 1398 self.assertTrue(np.all(lo <= result[0])) 1399 self.assertTrue(np.all(result[0] < hi)) 1400 1401 def testCholesky(self): 1402 l = np.array([[4, 0, 0, 0], [6, 5, 0, 0], [2, 14, 16, 0], [3, 6, 1, 4]], 1403 dtype=np.float32) 1404 c = self._NewComputation() 1405 ops.Cholesky(ops.Constant(c, np.tril(np.dot(l, l.T)))) 1406 self._ExecuteAndCompareClose(c, expected=[l], rtol=1e-4) 1407 1408 def testSort(self): 1409 keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32) 1410 c = self._NewComputation() 1411 ops.Sort(c, [ops.Constant(c, keys)], is_stable=True) 1412 self._ExecuteAndCompareClose( 1413 
c, 1414 expected=[np.array([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=np.float32)]) 1415 1416 def testSortKeyVal(self): 1417 keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32) 1418 values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32) 1419 c = self._NewComputation() 1420 ops.Sort(c, (ops.Constant(c, keys), ops.Constant(c, values)), dimension=0) 1421 result = xla_client.execute_with_python_values( 1422 self.backend.compile(c.build()), (), backend=self.backend) 1423 self.assertLen(result, 2) 1424 np.testing.assert_allclose(result[0], [[2, 1, 1, 2], [3, 4, 4, 3]]) 1425 np.testing.assert_equal(result[1], [[0, 5, 2, 7], [4, 1, 6, 3]]) 1426 1427 def testSortCustomComparator(self): 1428 b = self._NewComputation("comparator") 1429 p0 = ops.Parameter(b, 0, xla_client.shape_from_pyval(NumpyArrayF32(0))) 1430 q0 = ops.Parameter(b, 1, xla_client.shape_from_pyval(NumpyArrayF32(0))) 1431 p1 = ops.Parameter(b, 2, xla_client.shape_from_pyval(NumpyArrayS32(0))) 1432 q1 = ops.Parameter(b, 3, xla_client.shape_from_pyval(NumpyArrayS32(0))) 1433 ops.Or(ops.Lt(p0, q0), ops.And(ops.Eq(p0, q0), ops.Gt(p1, q1))) 1434 comparator = b.build() 1435 1436 keys = np.array([[2, 3, 1, 3], [3, 1, 2, 2]], dtype=np.float32) 1437 values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32) 1438 c = self._NewComputation() 1439 ops.Sort( 1440 c, (ops.Constant(c, keys), ops.Constant(c, values)), 1441 dimension=1, 1442 comparator=comparator) 1443 result = xla_client.execute_with_python_values( 1444 self.backend.compile(c.build()), (), backend=self.backend) 1445 self.assertLen(result, 2) 1446 np.testing.assert_allclose(result[0], [[1, 2, 3, 3], [1, 2, 2, 3]]) 1447 np.testing.assert_equal(result[1], [[2, 0, 3, 1], [5, 7, 6, 4]]) 1448 1449 def testQR(self): 1450 a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], 1451 [10, 63, 166, 310]], 1452 dtype=np.float32) 1453 c = self._NewComputation() 1454 ops.Tuple(c, ops.QR(ops.Constant(c, a), full_matrices=True)) 1455 q, r = 
self._Execute(c, ()) 1456 np.testing.assert_allclose(np.dot(q, r), a, rtol=1e-4) 1457 1458 def testEigh(self): 1459 a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], 1460 [10, 63, 166, 310]], 1461 dtype=np.float32) 1462 a = (a + a.T) / 2 1463 1464 c = self._NewComputation() 1465 ops.Tuple(c, ops.Eigh(ops.Constant(c, a), lower=True)) 1466 # TODO(b/129396575): Turn this test back on when it passes without 1467 # fastmath. 1468 # v, w = self._Execute(c, ()) 1469 # self.assertLess(np.linalg.norm(np.dot(a, v) - w * v), 1e-3) 1470 1471 def testSVD(self): 1472 a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], 1473 [10, 63, 166, 310]], 1474 dtype=np.float32) 1475 c = self._NewComputation() 1476 ops.Tuple(c, ops.SVD(ops.Constant(c, a))) 1477 u, d, v = self._Execute(c, ()) 1478 self.assertLess(np.linalg.norm(a - np.matmul(u * d, v.T)), 1e-3) 1479 1480 def testTriangularSolve(self): 1481 a_vals = np.array( 1482 [[2, 0, 0, 0], [3, 6, 0, 0], [4, 7, 9, 0], [5, 8, 10, 11]], 1483 dtype=np.float32) 1484 b_vals = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], 1485 dtype=np.float32) 1486 1487 c = self._NewComputation() 1488 ops.TriangularSolve( 1489 ops.Constant(c, a_vals), 1490 ops.Constant(c, b_vals), 1491 left_side=False, 1492 lower=True, 1493 transpose_a=ops.TriangularSolveOptions_Transpose.TRANSPOSE, 1494 unit_diagonal=False) 1495 self._ExecuteAndCompareClose( 1496 c, 1497 expected=[ 1498 np.array([ 1499 [0.5, 0.08333334, 0.04629629, 0.03367003], 1500 [2.5, -0.25, -0.1388889, -0.1010101], 1501 [4.5, -0.58333331, -0.32407406, -0.23569024], 1502 ], 1503 dtype=np.float32) 1504 ], 1505 rtol=1e-4) 1506 1507 def testIsConstant(self): 1508 c = self._NewComputation() 1509 a = ops.Constant(c, np.int32(3)) 1510 b = ops.Constant(c, np.int32(1)) 1511 x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayS32(0))) 1512 const_expr = ops.Sub(b, a) 1513 non_const_expr = ops.Mul(const_expr, x) 1514 self.assertTrue(c.is_constant(const_expr)) 1515 
self.assertFalse(c.is_constant(non_const_expr)) 1516 1517 def testGather(self): 1518 a = np.arange(9).astype(np.int32).reshape((3, 3)) 1519 indices = np.array([[[0, 2], [2, 1]], [[1, 2], [2, 0]]], dtype=np.int32) 1520 dnums = xla_client.GatherDimensionNumbers() 1521 dnums.offset_dims.append(1) 1522 dnums.offset_dims.append(2) 1523 dnums.start_index_map.append(0) 1524 dnums.start_index_map.append(1) 1525 dnums.index_vector_dim = 2 1526 c = self._NewComputation() 1527 ops.Gather( 1528 ops.Constant(c, a), 1529 ops.Constant(c, indices), 1530 dnums, 1531 slice_sizes=[1, 1]) 1532 g, = self._Execute(c, ()) 1533 expected = np.array([[[[2, 7]]], [[[5, 6]]]], dtype=np.int32) 1534 np.testing.assert_allclose(g, expected, rtol=1e-4) 1535 1536 def testFft(self): 1537 if self.backend.platform == "tpu": 1538 self.skipTest("TPU only supports 1D FFT") 1539 shape = [2, 3, 4, 5] 1540 rng = np.random.RandomState(0) 1541 a = rng.randn(*shape) + 1.0j * rng.randn(*shape) 1542 a = a.astype(np.complex64) 1543 # FFT 1544 c = self._NewComputation() 1545 ops.Fft(ops.Constant(c, a), xla_client.FftType.FFT, shape[-3:]) 1546 self._ExecuteAndCompareClose( 1547 c, expected=[np.fft.fftn(a, axes=(1, 2, 3))], rtol=1e-4) 1548 # IFFT 1549 c = self._NewComputation() 1550 ops.Fft(ops.Constant(c, a), xla_client.FftType.IFFT, shape[-3:]) 1551 self._ExecuteAndCompareClose( 1552 c, expected=[np.fft.ifftn(a, axes=(1, 2, 3))], rtol=1e-4) 1553 # RFFT 1554 b = rng.randn(*shape).astype(np.float32) 1555 c = self._NewComputation() 1556 ops.Fft(ops.Constant(c, b), xla_client.FftType.RFFT, shape[-3:]) 1557 self._ExecuteAndCompareClose( 1558 c, expected=[np.fft.rfftn(b, axes=(1, 2, 3))], rtol=1e-4) 1559 # IRFFT 1560 c = self._NewComputation() 1561 ops.Fft(ops.Constant(c, a), xla_client.FftType.IRFFT, [3, 4, 8]) 1562 self._ExecuteAndCompareClose( 1563 c, expected=[np.fft.irfftn(a, axes=(1, 2, 3))], rtol=1e-4) 1564 1565 def testNextAfter(self): 1566 c = self._NewComputation() 1567 ops.NextAfter( 1568 ops.Constant(c, 
np.array([1, 2], dtype=np.float32)), 1569 ops.Constant(c, np.array([2, 1], dtype=np.float32))) 1570 out, = self._Execute(c, ()) 1571 eps = np.finfo(np.float32).eps 1572 np.testing.assert_equal( 1573 np.array([eps + 1, 2 - eps], dtype=np.float32), out) 1574 1575 @parameterized.named_parameters({ 1576 "testcase_name": "_{}".format(dtype.__name__), 1577 "dtype": dtype, 1578 } for dtype in float_dtypes) 1579 def testRegularizedIncompleteBeta(self, dtype): 1580 x = np.array([0.53787335, 0.24015466, 0.47494545, 0.13567594, 0.95114538], 1581 dtype=dtype) 1582 a = np.array([0.00753073, 0.34813385, 0.30485708, 1.29298632, 0.51472606], 1583 dtype=dtype) 1584 b = np.array([0.55688389, 0.59794214, 0.42661022, 1.59748339, 0.95047677], 1585 dtype=dtype) 1586 c = self._NewComputation() 1587 ops.RegularizedIncompleteBeta( 1588 ops.Constant(c, a), ops.Constant(c, b), ops.Constant(c, x)) 1589 expected = np.array( 1590 [0.98923271, 0.48575411, 0.57952568, 0.12579775, 0.96989155]) 1591 self._ExecuteAndCompareClose(c, expected=[expected], rtol=2e-2) 1592 1593 tests.append(SingleOpTest) 1594 1595 class EmbeddedComputationsTest(ComputationTest): 1596 """Tests for XLA graphs with embedded computations (such as maps).""" 1597 1598 def _CreateConstantComputation(self, in_dtype, out_dtype): 1599 """Computation (A) -> B that returns a constant 1 for any input.""" 1600 c = self._NewComputation("constant_{}_{}_one".format( 1601 in_dtype.__name__, out_dtype.__name__)) 1602 ops.Parameter( 1603 c, 0, 1604 xla_client.shape_from_pyval(np.array( 1605 0, dtype=in_dtype)).with_major_to_minor_layout_if_absent()) 1606 ops.Constant(c, out_dtype(1)) 1607 return c.build() 1608 1609 def _CreateMulBy2Computation(self, dtype): 1610 """Computation (dtype) -> dtype that multiplies its parameter by 2.""" 1611 c = self._NewComputation("mul_f32_by2") 1612 ops.Mul( 1613 ops.Parameter( 1614 c, 0, 1615 xla_client.shape_from_pyval(np.array( 1616 0, dtype=dtype)).with_major_to_minor_layout_if_absent()), 1617 
ops.Constant(c, dtype(2.0))) 1618 return c.build() 1619 1620 def _CreateMulF32ByParamComputation(self): 1621 """Computation (f32) -> f32 that multiplies one parameter by the other.""" 1622 c = self._NewComputation("mul_f32_by_param") 1623 ops.Mul( 1624 ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0))), 1625 ops.Parameter(c, 1, xla_client.shape_from_pyval(NumpyArrayF32(0)))) 1626 return c.build() 1627 1628 def _CreateBinaryAddComputation(self, dtype): 1629 """Computation (dtype, dtype) -> dtype that adds its two parameters.""" 1630 c = self._NewComputation("add_param0_by_param1") 1631 shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) 1632 shape = shape.with_major_to_minor_layout_if_absent() 1633 ops.Add(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape)) 1634 return c.build() 1635 1636 def _CreateBinaryGeComputation(self, dtype): 1637 """Computation (dtype, dtype) -> bool that tests param0 >= param1.""" 1638 c = self._NewComputation("param0_lt_param1") 1639 shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) 1640 shape = shape.with_major_to_minor_layout_if_absent() 1641 ops.Ge(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape)) 1642 return c.build() 1643 1644 def _MakeSample3DArray(self, dtype): 1645 return np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]], 1646 [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]], 1647 dtype=dtype) 1648 1649 @parameterized.named_parameters({ 1650 "testcase_name": "_{}".format(dtype.__name__), 1651 "dtype": dtype, 1652 } for dtype in float_dtypes) 1653 def testCall(self, dtype): 1654 c = self._NewComputation() 1655 ops.Call( 1656 c, 1657 self._CreateMulBy2Computation(dtype), 1658 operands=(ops.Constant(c, dtype(5.0)),)) 1659 self._ExecuteAndCompareClose(c, expected=[10.0]) 1660 1661 @parameterized.named_parameters({ 1662 "testcase_name": "_{}_{}".format(in_dtype.__name__, out_dtype.__name__), 1663 "in_dtype": in_dtype, 1664 "out_dtype": out_dtype, 1665 } for in_dtype, out_dtype in 
[[np.float32, np.int32]]) 1666 def testMapEachElementToConstant(self, in_dtype, out_dtype): 1667 c = self._NewComputation() 1668 ops.Map(c, 1669 [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=in_dtype))], 1670 self._CreateConstantComputation(in_dtype, out_dtype), [0]) 1671 self._ExecuteAndCompareExact(c, expected=[[1, 1, 1, 1]]) 1672 1673 @parameterized.named_parameters({ 1674 "testcase_name": "_{}".format(dtype.__name__), 1675 "dtype": dtype, 1676 } for dtype in float_dtypes) 1677 def testMapMulBy2(self, dtype): 1678 if dtype == np.float64 and self.backend.platform == "tpu": 1679 self.skipTest("TPU doesn't support float64") 1680 c = self._NewComputation() 1681 ops.Map(c, [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype))], 1682 self._CreateMulBy2Computation(dtype), [0]) 1683 self._ExecuteAndCompareClose(c, expected=[[2.0, 4.0, 6.0, 8.0]]) 1684 1685 @parameterized.named_parameters({ 1686 "testcase_name": "_{}".format(dtype.__name__), 1687 "dtype": dtype, 1688 } for dtype in float_dtypes) 1689 def testSimpleMapChain(self, dtype): 1690 if dtype == np.float64 and self.backend.platform == "tpu": 1691 self.skipTest("TPU doesn't support float64") 1692 # Chains a map of constant-out with a map of mul-by-2 1693 c = self._NewComputation() 1694 const = ops.Map( 1695 c, [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype))], 1696 self._CreateConstantComputation(dtype, dtype), [0]) 1697 ops.Map(c, [const], self._CreateMulBy2Computation(dtype), [0]) 1698 self._ExecuteAndCompareClose(c, expected=[[2.0, 2.0, 2.0, 2.0]]) 1699 1700 # TODO(b/154752816): bfloat16 crashes in evaluator. 
@parameterized.named_parameters(
    dict(testcase_name="_" + dtype.__name__, dtype=dtype)
    for dtype in float_dtypes
    if dtype != bfloat16)
def testDivVectorsWithMap(self, dtype):
  """Maps a two-parameter Div computation element-wise over two vectors.

  bfloat16 is filtered out of the parameterization.
  """

  def _BuildDivComputation():
    # Scalar (dtype, dtype) -> dtype computation returning param0 / param1.
    div_builder = self._NewComputation("div_param0_by_param1")
    scalar_shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype))
    ops.Div(
        ops.Parameter(div_builder, 0, scalar_shape),
        ops.Parameter(div_builder, 1, scalar_shape))
    return div_builder.build()

  builder = self._NewComputation()
  numerators = ops.Constant(
      builder, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype))
  denominators = ops.Constant(
      builder, np.array([5.0, 5.0, 4.0, 4.0], dtype=dtype))
  # Built after the operands so builder creation order matches the
  # original argument-evaluation order.
  divide = _BuildDivComputation()
  ops.Map(builder, (numerators, denominators), divide, [0])
  self._ExecuteAndCompareClose(
      builder, expected=[[0.2, 0.4, 0.75, 1.0]], rtol=1e-3)

@parameterized.named_parameters(
    dict(testcase_name="_" + dtype.__name__, dtype=dtype)
    for dtype in float_dtypes)
def testSelectAndScatter(self, dtype):
  """Scatters a source array onto windows picked by a >= select function.

  Uses (2, 1) windows with strides (1, 2) and VALID padding over a 2x3
  operand. The scatter step adds into an init value of 1.
  """
  if dtype == np.float64 and self.backend.platform == "tpu":
    self.skipTest("TPU doesn't support float64")
  builder = self._NewComputation()
  operand = ops.Constant(
      builder, np.array([[1., 2., 6.], [4., 5., 3.]], dtype=dtype))
  window_dimensions = (2, 1)
  window_strides = (1, 2)
  padding = xla_client.window_padding_type_to_pad_values(
      xla_client.PaddingType.VALID,
      builder.get_shape(operand).dimensions(), window_dimensions,
      window_strides)
  # Sub-computations and constants are created in the same order as the
  # keyword arguments of the original call.
  ge_select = self._CreateBinaryGeComputation(dtype)
  source = ops.Constant(builder, np.array([[0.1, 0.2]], dtype=dtype))
  initial = ops.Constant(builder, np.array(1, dtype=dtype))
  add_scatter = self._CreateBinaryAddComputation(dtype)
  ops.SelectAndScatterWithGeneralPadding(
      operand,
      select=ge_select,
      window_dimensions=window_dimensions,
      window_strides=window_strides,
      padding=padding,
      source=source,
      init_value=initial,
      scatter=add_scatter)
  self._ExecuteAndCompareClose(
      builder, expected=[[[1., 1., 1.2], [1.1, 1., 1.]]], rtol=5e-3)

@parameterized.named_parameters(
    dict(testcase_name="_" + dtype.__name__, dtype=dtype)
    for dtype in float_dtypes)
def testReduce1DtoScalar(self, dtype):
  """Sum-reduces the vector [1, 2, 3, 4] along its only axis; expects 10."""
  builder = self._NewComputation()
  vector = ops.Constant(
      builder, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype))
  zero = ops.Constant(builder, dtype(0))
  ops.Reduce(
      builder,
      operands=[vector],
      init_values=[zero],
      computation=self._CreateBinaryAddComputation(dtype),
      dimensions_to_reduce=[0])
  self._ExecuteAndCompareClose(builder, expected=[10])

# TODO(phawkins): test comparison harness doesn't support bfloat16
@parameterized.named_parameters(
    dict(testcase_name="_%s_dim%d" % (dtype.__name__, dim),
         dtype=dtype,
         dim=dim)
    for dtype in float_dtypes
    if dtype != bfloat16
    for dim in range(2))
def testReduce2DTo1D(self, dtype, dim):
  """Sum-reduces a 2x3 array along `dim` and compares with np.sum."""
  matrix = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype)
  builder = self._NewComputation()
  ops.Reduce(
      builder,
      operands=[ops.Constant(builder, matrix)],
      init_values=[ops.Constant(builder, dtype(0))],
      computation=self._CreateBinaryAddComputation(dtype),
      dimensions_to_reduce=[dim])
  reference = np.sum(matrix, axis=dim)
  self._ExecuteAndCompareClose(builder, expected=[reference])

@parameterized.named_parameters(
    dict(testcase_name="_%s_dims[%s]" % (dtype.__name__, dims),
         dtype=dtype,
         dims=tuple(dims))
    for dtype in float_dtypes
    for dims in itertools.permutations(range(3)))
def testReduce3DAllPossibleWaysF32(self, dtype, dims):
  """Sum-reduces the sample 3-D array over `dims`; compares with np.sum.

  Despite the name, this is parameterized over every entry of
  `float_dtypes`. Each `dims` is a permutation of all three axes.
  """
  sample = self._MakeSample3DArray(dtype)
  builder = self._NewComputation()
  ops.Reduce(
      builder,
      operands=[ops.Constant(builder, sample)],
      init_values=[ops.Constant(builder, dtype(0))],
      computation=self._CreateBinaryAddComputation(dtype),
      dimensions_to_reduce=dims)
  reference = np.sum(sample, axis=dims)
  self._ExecuteAndCompareClose(builder, expected=[reference])

def _RunReduceWindowAddWithUnitStrides(self, dtype, padding_type, expected):
  """Runs an add ReduceWindow over a fixed 2x3 array and checks the result.

  The window is (2, 1) with strides (1, 1) and no base/window dilation.

  Args:
    dtype: numpy element type for the operand and the zero init value.
    padding_type: an `xla_client.PaddingType`, converted to explicit pads
      with `xla_client.window_padding_type_to_pad_values`.
    expected: the single expected output handed to `_ExecuteAndCompareClose`.
  """
  if dtype == np.float64 and self.backend.platform == "tpu":
    self.skipTest("TPU doesn't support float64")
  operand_value = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype)
  builder = self._NewComputation()
  window_dimensions = (2, 1)
  window_strides = (1, 1)
  padding = xla_client.window_padding_type_to_pad_values(
      padding_type, operand_value.shape, window_dimensions, window_strides)
  ops.ReduceWindowWithGeneralPadding(
      operand=ops.Constant(builder, operand_value),
      init_value=ops.Constant(builder, dtype(0)),
      computation=self._CreateBinaryAddComputation(dtype),
      window_dimensions=window_dimensions,
      window_strides=window_strides,
      base_dilations=[],
      window_dilations=[],
      padding=padding)
  self._ExecuteAndCompareClose(builder, expected=[expected])

@parameterized.named_parameters(
    dict(testcase_name="_" + dtype.__name__, dtype=dtype)
    for dtype in float_dtypes)
def testReduceWindowValidUnitStrides(self, dtype):
  """VALID padding: only the single full (2, 1) window row is produced."""
  self._RunReduceWindowAddWithUnitStrides(
      dtype, xla_client.PaddingType.VALID, [[5., 7., 9.]])

@parameterized.named_parameters(
    dict(testcase_name="_" + dtype.__name__, dtype=dtype)
    for dtype in float_dtypes)
def testReduceWindowSameUnitStrides(self, dtype):
  """SAME padding: output keeps the input's 2x3 shape."""
  self._RunReduceWindowAddWithUnitStrides(
      dtype, xla_client.PaddingType.SAME, [[5., 7., 9.], [4., 5., 6.]])
@parameterized.named_parameters({ 1847 "testcase_name": "_{}".format(dtype.__name__), 1848 "dtype": dtype, 1849 } for dtype in float_dtypes) 1850 def testReduceWindowValidGeneralStrides(self, dtype): 1851 if dtype == np.float64 and self.backend.platform == "tpu": 1852 self.skipTest("TPU doesn't support float64") 1853 input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype) 1854 c = self._NewComputation() 1855 window_dimensions = (2, 1) 1856 window_strides = (1, 2) 1857 padding = xla_client.window_padding_type_to_pad_values( 1858 xla_client.PaddingType.VALID, input_array.shape, window_dimensions, 1859 window_strides) 1860 ops.ReduceWindowWithGeneralPadding( 1861 operand=ops.Constant(c, input_array), 1862 init_value=ops.Constant(c, dtype(0)), 1863 computation=self._CreateBinaryAddComputation(dtype), 1864 window_dimensions=window_dimensions, 1865 window_strides=window_strides, 1866 base_dilations=[], 1867 window_dilations=[], 1868 padding=padding) 1869 self._ExecuteAndCompareClose(c, expected=[[[5., 9.]]]) 1870 1871 def testReduceWindowVariadic(self): 1872 c = self._NewComputation("reducer") 1873 shape = xla_client.shape_from_pyval(np.array(0, dtype=np.int32)) 1874 shape = shape.with_major_to_minor_layout_if_absent() 1875 ps = [ops.Parameter(c, i, shape) for i in range(4)] 1876 which = ops.Ge(ps[0], ps[2]) 1877 ops.Tuple( 1878 c, [ops.Select(which, ps[0], ps[2]), 1879 ops.Select(which, ps[1], ps[3])]) 1880 reducer = c.build() 1881 1882 key_array = np.array([[1, 5, 6], [4, 2, 3]], dtype=np.int32) 1883 val_array = np.array([[7, 8, 9], [10, 11, 12]], dtype=np.int32) 1884 c = self._NewComputation() 1885 window_dimensions = (2, 1) 1886 window_strides = (1, 1) 1887 padding = xla_client.window_padding_type_to_pad_values( 1888 xla_client.PaddingType.VALID, key_array.shape, window_dimensions, 1889 window_strides) 1890 ops.ReduceWindowWithGeneralPadding( 1891 operands=[ops.Constant(c, key_array), 1892 ops.Constant(c, val_array)], 1893 init_values=[ 1894 
ops.Constant(c, np.int32(0)), 1895 ops.Constant(c, np.int32(0)) 1896 ], 1897 computation=reducer, 1898 window_dimensions=window_dimensions, 1899 window_strides=window_strides, 1900 base_dilations=[], 1901 window_dilations=[], 1902 padding=padding) 1903 self._ExecuteAndCompareClose(c, expected=[[[4, 5, 6]], [[10, 8, 9]]]) 1904 1905 @parameterized.named_parameters({ 1906 "testcase_name": "_{}".format(dtype.__name__), 1907 "dtype": dtype, 1908 } for dtype in float_dtypes) 1909 def testWhile(self, dtype): 1910 1911 def LessThan10Cond(): 1912 c = self._NewComputation("test_lt_10") 1913 shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype)) 1914 ops.Lt(ops.Parameter(c, 0, shape), ops.Constant(c, dtype(10.))) 1915 return c.build() 1916 1917 cond = LessThan10Cond() 1918 body = self._CreateMulBy2Computation(dtype) 1919 c = self._NewComputation() 1920 init = ops.Constant(c, dtype(1.)) 1921 ops.While(cond, body, init) 1922 self._ExecuteAndCompareClose(c, expected=[16.]) 1923 1924 def testConditionalTrue(self): 1925 c = self._NewComputation() 1926 pred = ops.Constant(c, np.bool_(True)) 1927 true_operand = ops.Constant(c, np.float32(3.)) 1928 true_computation = self._CreateMulBy2Computation(np.float32) 1929 false_operand = ops.Constant(c, np.float32(2.)) 1930 false_computation = self._CreateConstantComputation( 1931 np.float32, np.float32) 1932 ops.Conditional(pred, true_operand, true_computation, false_operand, 1933 false_computation) 1934 self._ExecuteAndCompareClose(c, expected=[6.]) 1935 1936 def testConditionalFalse(self): 1937 c = self._NewComputation() 1938 pred = ops.Constant(c, np.bool_(False)) 1939 true_operand = ops.Constant(c, np.float32(3.)) 1940 true_computation = self._CreateMulBy2Computation(np.float32) 1941 false_operand = ops.Constant(c, np.float32(2.)) 1942 false_computation = self._CreateConstantComputation( 1943 np.float32, np.float32) 1944 ops.Conditional(pred, true_operand, true_computation, false_operand, 1945 false_computation) 1946 
self._ExecuteAndCompareClose(c, expected=[1.]) 1947 1948 @unittest.skipIf(cloud_tpu, "not implemented") 1949 def testInfeedS32Values(self): 1950 to_infeed = NumpyArrayS32([1, 2, 3, 4]) 1951 c = self._NewComputation() 1952 ops.GetTupleElement( 1953 ops.InfeedWithToken( 1954 ops.CreateToken(c), 1955 xla_client.shape_from_pyval( 1956 to_infeed[0]).with_major_to_minor_layout_if_absent()), 0) 1957 compiled_c = self.backend.compile(c.build()) 1958 device = self.backend.local_devices()[0] 1959 for item in to_infeed: 1960 device.transfer_to_infeed(item) 1961 1962 for item in to_infeed: 1963 result, = xla_client.execute_with_python_values( 1964 compiled_c, (), backend=self.backend) 1965 self.assertEqual(result, item) 1966 1967 @unittest.skipIf(cloud_tpu, "not implemented") 1968 def testInfeedTuple(self): 1969 to_infeed = (NumpyArrayS32([1, 2, 3, 4]), NumpyArrayS32([[7], [8]])) 1970 c = self._NewComputation() 1971 ops.GetTupleElement( 1972 ops.InfeedWithToken( 1973 ops.CreateToken(c), 1974 xla_client.shape_from_pyval( 1975 to_infeed).with_major_to_minor_layout_if_absent()), 0) 1976 compiled_c = self.backend.compile(c.build()) 1977 device = self.backend.local_devices()[0] 1978 device.transfer_to_infeed(to_infeed) 1979 1980 result = xla_client.execute_with_python_values( 1981 compiled_c, (), backend=self.backend) 1982 self.assertLen(result, 2) 1983 np.testing.assert_equal(result[0], to_infeed[0]) 1984 np.testing.assert_equal(result[1], to_infeed[1]) 1985 1986 @unittest.skipIf(cloud_tpu, "not implemented") 1987 def testInfeedThenOutfeedS32(self): 1988 to_round_trip = NumpyArrayS32([1, 2, 3, 4]) 1989 c = self._NewComputation() 1990 x_and_token = ops.InfeedWithToken( 1991 ops.CreateToken(c), 1992 xla_client.shape_from_pyval( 1993 to_round_trip[0]).with_major_to_minor_layout_if_absent()) 1994 x = ops.GetTupleElement(x_and_token, 0) 1995 token = ops.GetTupleElement(x_and_token, 1) 1996 outfeed_shape = xla_client.shape_from_pyval( 1997 
to_round_trip[0]).with_major_to_minor_layout_if_absent() 1998 ops.OutfeedWithToken(x, token, outfeed_shape) 1999 2000 compiled_c = self.backend.compile(c.build()) 2001 device = self.backend.local_devices()[0] 2002 2003 for want in to_round_trip: 2004 execution = threading.Thread(target=lambda: compiled_c.execute([])) 2005 execution.start() 2006 device.transfer_to_infeed(want) 2007 got = device.transfer_from_outfeed(outfeed_shape) 2008 execution.join() 2009 self.assertEqual(want, got) 2010 2011 def testScatter(self): 2012 a = np.arange(9).astype(np.int32).reshape((3, 3)) 2013 scatter_indices = np.array([0, 2], dtype=np.int32) 2014 updates = np.array([[10, 20, 30], [70, 80, 90]], dtype=np.int32) 2015 2016 dnums = xla_client.ScatterDimensionNumbers() 2017 dnums.update_window_dims.append(1) 2018 dnums.inserted_window_dims.append(0) 2019 dnums.scatter_dims_to_operand_dims.append(0) 2020 dnums.index_vector_dim = 1 2021 2022 c = self._NewComputation() 2023 ops.Scatter( 2024 ops.Constant(c, a), ops.Constant(c, scatter_indices), 2025 ops.Constant(c, updates), self._CreateBinaryAddComputation(np.int32), 2026 dnums) 2027 expected = np.array([[10, 21, 32], [3, 4, 5], [76, 87, 98]], 2028 dtype=np.int32) 2029 self._ExecuteAndCompareClose(c, expected=[expected]) 2030 2031 class DeviceTest(ComputationTest): 2032 2033 def testPlatform(self): 2034 for device in self.backend.local_devices(): 2035 self.assertEqual(device.platform, self.backend.platform) 2036 2037 tests.append(DeviceTest) 2038 2039 class ErrorTest(ComputationTest): 2040 2041 def setUp(self): 2042 super(ErrorTest, self).setUp() 2043 self.f32_scalar_2 = NumpyArrayF32(2.0) 2044 self.s32_scalar_2 = NumpyArrayS32(2) 2045 2046 def testCompileWithWrongElementTypeInLayout(self): 2047 c = self._NewComputation() 2048 c.set_op_metadata(xla_client.CurrentSourceInfoMetadata()) 2049 ops.Parameter(c, 0, xla_client.shape_from_pyval(self.s32_scalar_2)) 2050 c.clear_op_metadata() 2051 2052 options = xla_client.CompileOptions() 2053 
options.argument_layouts = [ 2054 xla_client.Shape.array_shape(np.dtype(np.float32), []) 2055 ] 2056 2057 def TestFun(): 2058 return self.backend.compile(c.build(), compile_options=options) 2059 2060 self.assertRaisesRegex( 2061 RuntimeError, r".*Invalid argument shape.*" 2062 r"expected s32\[\], got f32\[\].*", TestFun) 2063 2064 def testInvokeWithWrongElementType(self): 2065 c = self._NewComputation() 2066 c.set_op_metadata(xla_client.CurrentSourceInfoMetadata()) 2067 ops.Parameter(c, 0, xla_client.shape_from_pyval(self.s32_scalar_2)) 2068 c.clear_op_metadata() 2069 2070 def TestFun(): 2071 return xla_client.execute_with_python_values( 2072 self.backend.compile(c.build()), [self.f32_scalar_2], self.backend) 2073 2074 self.assertRaisesRegex( 2075 RuntimeError, r"Invalid argument: Argument does not match.*" 2076 r"want s32\[\], got f32\[\].*", TestFun) 2077 2078 tests.append(EmbeddedComputationsTest) 2079 2080 class ComputationRootTest(ComputationTest): 2081 """Tests related to setting the root of the computation.""" 2082 2083 def testComputationRootDifferentFromLastOp(self): 2084 c = self._NewComputation() 2085 x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(2.0))) 2086 result = ops.Add(x, ops.Constant(c, np.float32(3.14))) 2087 ops.Add(result, ops.Constant(c, np.float32(1.618))) 2088 2089 arg = NumpyArrayF32(1.0) 2090 compiled_c = self.backend.compile(c.build(result)) 2091 ans, = xla_client.execute_with_python_values( 2092 compiled_c, [arg], backend=self.backend) 2093 np.testing.assert_allclose(ans, 4.14) 2094 2095 tests.append(ComputationRootTest) 2096 2097 class SetShardingTest(ComputationTest): 2098 """Tests related to set OpSharding.""" 2099 2100 def testSetSharding(self): 2101 c = self._NewComputation() 2102 sharding = xla_client.OpSharding() 2103 sharding.type = sharding.type.REPLICATED 2104 sharding.tile_assignment_dimensions.extend([1]) 2105 sharding.tile_assignment_devices.extend([0]) 2106 c.set_sharding(sharding) 2107 x = 
ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(2.0))) 2108 c.clear_sharding() 2109 2110 result = ops.Add(x, ops.Constant(c, np.float32(3.14))) 2111 ops.Add(result, ops.Constant(c, np.float32(1.618))) 2112 arg = NumpyArrayF32(1.0) 2113 compiled_c = self.backend.compile(c.build(result)) 2114 ans, = xla_client.execute_with_python_values( 2115 compiled_c, [arg], backend=self.backend) 2116 np.testing.assert_allclose(ans, 4.14) 2117 2118 tests.append(SetShardingTest) 2119 2120 testcase_shapes = [ 2121 (), 2122 (1,), 2123 (2, 3), 2124 (2, 0), 2125 (0, 7), 2126 (4, 1, 2), 2127 (2, 1, 3), 2128 (2, 4, 1), 2129 (3, 1), 2130 (1, 3), 2131 ] 2132 2133 def FormatShapeAndDtype(shape, dtype): 2134 return "_{}[{}]".format(np.dtype(dtype).name, ",".join(map(str, shape))) 2135 2136 class DLPackTest(parameterized.TestCase): 2137 2138 def setUp(self): 2139 super(DLPackTest, self).setUp() 2140 self.backend = xla_backend() 2141 if self.backend.platform not in ("cpu", "gpu"): 2142 self.skipTest("DLPack requires CPU or GPU") 2143 self.cpu_backend = ( 2144 self.backend 2145 if self.backend.platform == "cpu" else xla_client.make_cpu_client()) 2146 self.gpu_backend = ( 2147 self.backend if self.backend.platform == "gpu" else None) 2148 2149 # pylint: disable=g-complex-comprehension 2150 # pyformat: disable 2151 @parameterized.named_parameters({ 2152 "testcase_name": "{}_own={}_gpu={}".format( 2153 FormatShapeAndDtype(shape, dtype), take_ownership, gpu), 2154 "dtype": dtype, 2155 "shape": shape, 2156 "take_ownership": take_ownership, 2157 "gpu": gpu 2158 } for dtype in dlpack_dtypes for shape in testcase_shapes 2159 for take_ownership in [False, True] 2160 for gpu in [False, True]) 2161 # pyformat: enable 2162 def testRoundTrip(self, dtype, shape, take_ownership, gpu): 2163 if gpu and self.gpu_backend is None: 2164 raise unittest.SkipTest("Test not running with GPU support") 2165 backend = self.gpu_backend if gpu else self.cpu_backend 2166 if dtype == np.bool_: 2167 x = 
np.random.randint(0, 2, size=shape).astype(np.bool_) 2168 else: 2169 x = np.array(np.random.rand(*shape) * 100, dtype=dtype) 2170 buffer = backend.buffer_from_pyval(x) 2171 dlt = xla_client._xla.buffer_to_dlpack_managed_tensor( 2172 buffer, take_ownership=take_ownership) 2173 del buffer # Free "buffer" to make sure dlt retains ownership. 2174 self.assertEqual(type(dlt).__name__, "PyCapsule") 2175 y = xla_client._xla.dlpack_managed_tensor_to_buffer( 2176 dlt, self.cpu_backend, self.gpu_backend) 2177 np.testing.assert_array_equal( 2178 x.astype(np.uint8) if dtype == np.bool_ else x, y.to_py()) 2179 2180 def testTensorsCanBeConsumedOnceOnly(self): 2181 x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32) 2182 buffer = self.backend.buffer_from_pyval(x) 2183 dlt = xla_client._xla.buffer_to_dlpack_managed_tensor( 2184 buffer, take_ownership=True) 2185 2186 def ConsumeDLPackTensor(): 2187 _ = xla_client._xla.dlpack_managed_tensor_to_buffer(dlt, self.backend) 2188 2189 ConsumeDLPackTensor() 2190 self.assertRaisesRegex( 2191 RuntimeError, ".*a DLPack tensor may be consumed at most once.*", 2192 ConsumeDLPackTensor) 2193 2194 def testTensorsCanBeOwnedOnceOnly(self): 2195 x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32) 2196 buffer = self.backend.buffer_from_pyval(x) 2197 _ = xla_client._xla.buffer_to_dlpack_managed_tensor( 2198 buffer, take_ownership=True) 2199 self.assertTrue(buffer.is_deleted()) 2200 with self.assertRaisesRegex( 2201 RuntimeError, 2202 "Cannot convert deleted/invalid buffer to DLPack tensor.*"): 2203 _ = xla_client._xla.buffer_to_dlpack_managed_tensor( 2204 buffer, take_ownership=True) 2205 2206 def testNonOwnedDlpackCanBeViewedTwice(self): 2207 x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32) 2208 buffer = self.backend.buffer_from_pyval(x) 2209 d1 = xla_client._xla.buffer_to_dlpack_managed_tensor( 2210 buffer, take_ownership=False) 2211 d2 = xla_client._xla.buffer_to_dlpack_managed_tensor( 2212 buffer, take_ownership=False) 2213 
2214 y = xla_client._xla.dlpack_managed_tensor_to_buffer(d1, self.backend) 2215 z = xla_client._xla.dlpack_managed_tensor_to_buffer(d2, self.backend) 2216 del d1, d2 2217 np.testing.assert_array_equal(x, buffer.to_py()) 2218 np.testing.assert_array_equal(x, y.to_py()) 2219 np.testing.assert_array_equal(x, z.to_py()) 2220 2221 tests.append(DLPackTest) 2222 2223 class BufferProtocolTest(parameterized.TestCase): 2224 2225 def setUp(self): 2226 super(BufferProtocolTest, self).setUp() 2227 self.backend = xla_backend() 2228 if self.backend.platform != "cpu": 2229 self.skipTest("Test requires CPU") 2230 2231 # pylint: disable=g-complex-comprehension 2232 @parameterized.named_parameters({ 2233 "testcase_name": FormatShapeAndDtype(shape, dtype), 2234 "dtype": dtype, 2235 "shape": shape 2236 } for dtype in standard_dtypes if dtype != bfloat16 2237 for shape in testcase_shapes) 2238 def testRoundTrip(self, dtype, shape): 2239 x = np.array(np.random.rand(*shape) * 100, dtype=dtype) 2240 x_ptr = x.__array_interface__["data"][0] 2241 buffer = self.backend.buffer_from_pyval( 2242 x, host_buffer_semantics=xla_client.HostBufferSemantics.ZERO_COPY) 2243 y = np.array(buffer, copy=False) 2244 y_ptr = y.__array_interface__["data"][0] 2245 np.testing.assert_array_equal(x, y) 2246 # If the input was sufficiently aligned, the input and output should 2247 # alias. 
2248 self.assertTrue((x_ptr & 15) != 0 or x_ptr == y_ptr) 2249 self.assertEqual(y_ptr, buffer.unsafe_buffer_pointer()) 2250 2251 during_call = xla_client.HostBufferSemantics.IMMUTABLE_ONLY_DURING_CALL 2252 buffer2 = self.backend.buffer_from_pyval( 2253 x, host_buffer_semantics=during_call) 2254 z = np.array(buffer2, copy=False) 2255 self.assertNotEqual(x.__array_interface__["data"][0], 2256 z.__array_interface__["data"][0]) 2257 2258 def testDeleteWithActiveView(self): 2259 x = np.random.randn(20, 10) 2260 buffer = self.backend.buffer_from_pyval(x) 2261 buffer_ptr = buffer.unsafe_buffer_pointer() 2262 y = np.array(buffer, copy=False) 2263 buffer.delete() 2264 # It is still legal to access `y`; the array view must keep it alive. 2265 np.testing.assert_array_equal(x, y) 2266 self.assertEqual(y.__array_interface__["data"][0], buffer_ptr) 2267 2268 tests.append(BufferProtocolTest) 2269 2270 class TracebackTest(absltest.TestCase): 2271 2272 def setUp(self): 2273 super(TracebackTest, self).setUp() 2274 self.backend = xla_backend() 2275 2276 def testNoTracebacksIfDisabled(self): 2277 with xla_client.tracebacks(enabled=False): 2278 self.assertEqual(None, xla_client.Traceback.get_traceback()) 2279 buffer = self.backend.buffer_from_pyval(np.array(7, np.int32)) 2280 self.assertEqual(None, buffer.traceback) 2281 2282 b = xla_client.XlaBuilder("computation") 2283 ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2))) 2284 e = self.backend.compile(b.build()) 2285 self.assertEqual(None, e.traceback) 2286 2287 def assertIsTracebackContaining(self, tb, function): 2288 self.assertIsInstance(tb, xla_client.Traceback) 2289 self.assertIn(function, str(tb)) 2290 self.assertTrue(any(f.function_name == function for f in tb.frames)) 2291 2292 def testTracebacks(self): 2293 with xla_client.tracebacks(enabled=True): 2294 tb = xla_client.Traceback.get_traceback() 2295 self.assertIsTracebackContaining(tb, "testTracebacks") 2296 2297 # Tracebacks are not implemented on the TPU 
driver extension's variant 2298 # of buffers and executables. 2299 if not isinstance(self.backend, xla_client.Client): 2300 return 2301 2302 buffer = self.backend.buffer_from_pyval(np.array(7, np.int32)) 2303 self.assertIsTracebackContaining(buffer.traceback, "testTracebacks") 2304 2305 b = xla_client.XlaBuilder("computation") 2306 ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2))) 2307 e = self.backend.compile(b.build()) 2308 self.assertIsTracebackContaining(e.traceback, "testTracebacks") 2309 2310 def testNestedFunction(self): 2311 2312 def AFunction(): 2313 2314 def AnotherFunction(): 2315 return xla_client.Traceback.get_traceback() 2316 2317 return AnotherFunction() 2318 2319 with xla_client.tracebacks(enabled=True): 2320 tb = AFunction() 2321 self.assertIsInstance(tb, xla_client.Traceback) 2322 frames = tb.frames 2323 i = next( 2324 i for (i, f) in enumerate(frames) if f.function_name == "AFunction") 2325 self.assertEqual(frames[i - 1].function_name, "AnotherFunction") 2326 self.assertEqual(frames[i + 1].function_name, "testNestedFunction") 2327 2328 tests.append(TracebackTest) 2329 2330 class ClientTest(ComputationTest): 2331 2332 def setUp(self): 2333 super(ClientTest, self).setUp() 2334 self.backend = xla_backend() 2335 2336 def testPlatformVersion(self): 2337 version = self.backend.platform_version 2338 if self.backend.platform == "cpu": 2339 self.assertEqual(version, "<unknown>") 2340 elif self.backend.platform == "gpu": 2341 # Following is false if not built with --config=cuda 2342 if test_util.is_gpu_available(cuda_only=True): 2343 self.assertTrue( 2344 re.match(r"^cuda \d{4,}$", version), 2345 msg=f"Expected CUDA version string; got {repr(version)}") 2346 else: 2347 self.assertEqual(version, "<unknown>") 2348 2349 @unittest.skipIf(cloud_tpu or tfrt_tpu, "not implemented") 2350 def testExecutableSerialization(self): 2351 if self.backend.platform != "tpu": 2352 self.skipTest("Test requires tpu platform") 2353 2354 c = 
self._NewComputation() 2355 ops.Add( 2356 ops.Constant(c, NumpyArrayS32([1, 2])), 2357 ops.Constant(c, NumpyArrayS32([3, 4]))) 2358 2359 options = xla_client.CompileOptions() 2360 executable = self.backend.compile(c.build(), options) 2361 self.assertLen(executable.hlo_modules(), 1) 2362 2363 serialized = self.backend.serialize_executable(executable) 2364 deserialized = self.backend.deserialize_executable( 2365 serialized, 2366 executable.hlo_modules()[0], options) 2367 2368 expected, = xla_client.execute_with_python_values(executable, (), 2369 self.backend) 2370 actual, = xla_client.execute_with_python_values(deserialized, (), 2371 self.backend) 2372 self.assertTrue(np.all(actual == expected)) 2373 2374 tests.append(ClientTest) 2375 2376 # TODO(b/182461453): Add TFRT and cloud TPU implementation of 2377 # ReadDynamicShapes 2378 class DynamicReshapeTest(ComputationTest): 2379 """Tests related to DynamicReshape.""" 2380 2381 def _CompareToPyAndBufferProtocol(self, builder, args, expected_results, 2382 test_fn): 2383 compiled = self.backend.compile(builder.build()) 2384 output_buffers = compiled.execute([ 2385 self.backend.buffer_from_pyval( 2386 arg, device=compiled.local_devices()[0]) for arg in args 2387 ]) 2388 self.assertLen(output_buffers, len(expected_results)) 2389 for buf, expected in zip(output_buffers, expected_results): 2390 to_py_result = buf.to_py() 2391 self.assertEqual(expected.shape, to_py_result.shape) 2392 test_fn(expected, to_py_result) 2393 if self.backend.platform == "cpu" and buf.dtype != bfloat16: 2394 mview = memoryview(buf) 2395 self.assertEqual(expected.shape, mview.shape) 2396 test_fn(expected, np.asarray(mview)) 2397 else: 2398 # Buffer protocol expected to fail on non-cpu platforms and bfloat16 2399 # Note that np.asarray(buf) doesn't throw an exception. To test if the 2400 # error was thrown properly we must use memoryview(buf). 
2401 with self.assertRaises(BufferError): 2402 memoryview(buf) 2403 2404 # 1D reshape of full size, half size, and size of 0. 2405 @unittest.skipIf(cloud_tpu or tfrt_tpu or external_tpu, "not implemented") 2406 @parameterized.parameters((5), (3), (0)) 2407 def testReshape1D(self, reshape_size): 2408 full_size = 5 2409 c = self._NewComputation() 2410 arg = np.array(reshape_size, dtype=np.int32) 2411 expected = np.array(range(reshape_size), dtype=np.int32) 2412 p = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg)) 2413 ops.DynamicReshape( 2414 ops.Constant(c, NumpyArrayS32(range(full_size))), [p], [full_size], 2415 [True]) 2416 self._CompareToPyAndBufferProtocol(c, [arg], [expected], 2417 np.testing.assert_equal) 2418 2419 # 2D reshape with an slice on the minor dimension. We test different types 2420 # where the strides may differ between the host and devices. The reshaped 2421 # physical memory layout is not consecutive, and we test if the program can 2422 # return the correct logical view of the data. 
2423 @unittest.skipIf(cloud_tpu or tfrt_tpu or external_tpu, "not implemented") 2424 @parameterized.named_parameters({ 2425 "testcase_name": "_{}".format(dtype.__name__), 2426 "dtype": dtype, 2427 } for dtype in int_dtypes + float_dtypes) 2428 def testReshape2D(self, dtype): 2429 arg0 = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype) 2430 arg1 = np.array(2, dtype=np.int32) 2431 expected = np.array([[1, 2], [4, 5]], dtype=np.int32) 2432 c = self._NewComputation() 2433 p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0)) 2434 p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1)) 2435 ops.DynamicReshape(p0, [p1, p1], [2, 3], [False, True]) 2436 self._CompareToPyAndBufferProtocol(c, [arg0, arg1], [expected], 2437 np.testing.assert_equal) 2438 2439 @unittest.skipIf(cloud_tpu or tfrt_tpu, "not implemented") 2440 @parameterized.named_parameters({ 2441 "testcase_name": "_{}".format(dtype.__name__), 2442 "dtype": dtype, 2443 } for dtype in int_dtypes + float_dtypes) 2444 def testDynamicShapeArgs(self, dtype): 2445 full_size = 10 2446 dynamic_shape_size = 4 2447 # subcomputation 1 2448 binary_add_builder = self._NewComputation() 2449 scalar_shape = xla_client.Shape.scalar_shape(np.dtype(dtype)) 2450 ops.Add( 2451 ops.Parameter(binary_add_builder, 0, scalar_shape), 2452 ops.Parameter(binary_add_builder, 1, scalar_shape)) 2453 # subcomputation 2 2454 reshape_reduce_builder = self._NewComputation() 2455 dshape = xla_client.Shape.array_shape( 2456 np.dtype(dtype), dims=[full_size], dynamic_dimensions=[True]) 2457 reshape_reduce_p = ops.Parameter(reshape_reduce_builder, 0, dshape) 2458 ops.Reduce( 2459 reshape_reduce_builder, 2460 operands=[reshape_reduce_p], 2461 init_values=[ops.Constant(reshape_reduce_builder, dtype(0))], 2462 computation=binary_add_builder.build(), 2463 dimensions_to_reduce=[0]) 2464 # main computation: sum(range(full_size)[:dynamic_shape_size]) 2465 c = self._NewComputation() 2466 arg = np.array(dynamic_shape_size, dtype=np.int32) 2467 p = 
ops.Parameter(c, 0, xla_client.shape_from_pyval(arg)) 2468 reshaped = ops.DynamicReshape( 2469 ops.Constant(c, np.array(range(full_size), dtype=dtype)), [p], 2470 [full_size], [True]) 2471 ops.Call(c, reshape_reduce_builder.build(), operands=(reshaped,)) 2472 self._ExecuteAndCompareClose(c, [arg], [dtype(6)]) 2473 2474 tests.append(DynamicReshapeTest) 2475 2476 class DeviceAssignmentTest(ComputationTest): 2477 2478 def testSerialize(self): 2479 shape = (3, 4) 2480 device_assignment = xla_client.DeviceAssignment.create( 2481 np.arange(np.prod(shape)).reshape(*shape)) 2482 self.assertEqual(device_assignment.replica_count(), shape[0]) 2483 self.assertEqual(device_assignment.computation_count(), shape[1]) 2484 serialized = device_assignment.serialize() 2485 self.assertIsInstance(serialized, bytes) 2486 self.assertNotEmpty(serialized) 2487 2488 tests.append(DeviceAssignmentTest) 2489 2490 return tests 2491 2492 2493def InstantiateTests(globals_dict, backend_fn, test_prefix="", **kw): 2494 # Avoid creating a new backend per test (this causes GPU OOM, and is probably 2495 # inefficient). 2496 backend_fn = functools.lru_cache(maxsize=None)(backend_fn) 2497 for klass in TestFactory(backend_fn, **kw): 2498 test = type(test_prefix + klass.__name__, (klass,), {}) 2499 # Clean up the qualified names of the tests to not include the test factory. 2500 test.__qualname__ = test.__name__ 2501 globals_dict[test.__name__] = test 2502 2503 2504backends = { 2505 "cpu": xla_client.make_cpu_client, 2506 "gpu": xla_client.make_gpu_client, 2507} 2508 2509if __name__ == "__main__": 2510 flags.DEFINE_string("backend", "cpu", "Target platform.") 2511 # pylint: disable=unnecessary-lambda 2512 InstantiateTests(globals(), lambda: backends[FLAGS.backend]()) 2513 # pylint: enable=unnecessary-lambda 2514 absltest.main() 2515