# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Backend-dependent tests for the Python XLA client."""

import functools
import itertools
import re
import threading
import unittest

from absl import flags
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

from tensorflow.compiler.xla.python import xla_client

# pylint: disable=g-import-not-at-top
try:
  from tensorflow.compiler.xla.python import custom_call_for_test
except ImportError:
  custom_call_for_test = None

bfloat16 = xla_client.bfloat16
ops = xla_client.ops

FLAGS = flags.FLAGS

# We choose to ignore pylint's complaints about complex comprehensions, which we
# use widely for parameterizing tests.
# pylint: disable=g-complex-comprehension


def TestFactory(xla_backend, cloud_tpu=False):
  tests = []

  if not cloud_tpu:
    int_dtypes = [np.int32, np.int64, np.uint32, np.uint64]
    # TODO(phawkins): test np.float16, where supported.
    float_dtypes = [bfloat16, np.float32, np.float64]
    complex_dtypes = [np.complex64, np.complex128]
    standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_]
  else:
    int_dtypes = [np.int32, np.uint32]
    float_dtypes = [np.float32]
    complex_dtypes = [np.complex64]
    standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_]
  dlpack_dtypes = int_dtypes + float_dtypes + [np.bool_]

  class ComputationTest(parameterized.TestCase):
    """Base class for running an XLA Computation through the local client."""

    def setUp(self):
      super(ComputationTest, self).setUp()
      self.backend = xla_backend()

    def _NewComputation(self, name=None):
      if name is None:
        name = self.id()
      return xla_client.XlaBuilder(name)

    def _Execute(self, c, arguments):
      compiled_c = self.backend.compile(c.build())
      return xla_client.execute_with_python_values(
          compiled_c, arguments, backend=self.backend)

    def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected):
      assert expected is not None
      results = self._Execute(c, arguments)
      self.assertLen(results, len(expected))
      for result, e in zip(results, expected):
        # Numpy's comparison methods are a bit too lenient by treating inputs as
        # "array-like", meaning that scalar 4 will be happily compared equal to
        # [[4]]. We'd like to be more strict so assert shapes as well.
        self.assertEqual(np.asanyarray(result).shape, np.asanyarray(e).shape)
        assert_func(result, e)

    def _ExecuteAndCompareExact(self, c, arguments=(), expected=None):
      self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments,
                                 expected)

    def _ExecuteAndCompareClose(self,
                                c,
                                arguments=(),
                                expected=None,
                                rtol=1e-4,
                                atol=0):
      self._ExecuteAndAssertWith(
          functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol),
          c, arguments, expected)

  def NumpyArrayF32(*args, **kwargs):
    """Convenience wrapper to create Numpy arrays with a np.float32 dtype."""
    return np.array(*args, dtype=np.float32, **kwargs)

  def NumpyArrayF64(*args, **kwargs):
    """Convenience wrapper to create Numpy arrays with a np.float64 dtype."""
    return np.array(*args, dtype=np.float64, **kwargs)

  def NumpyArrayS32(*args, **kwargs):
    """Convenience wrapper to create Numpy arrays with a np.int32 dtype."""
    return np.array(*args, dtype=np.int32, **kwargs)

  def NumpyArrayBool(*args, **kwargs):
    """Convenience wrapper to create Numpy arrays with a np.bool_ dtype."""
    return np.array(*args, dtype=np.bool_, **kwargs)

  class ComputationPrinting(absltest.TestCase):

    def setUp(self):
      super(ComputationPrinting, self).setUp()
      self.backend = xla_backend()

    def ExampleComputation(self):
      builder = xla_client.XlaBuilder("acomputation")
      p0 = ops.Parameter(builder, 0, xla_client.shape_from_pyval(np.float32(0)))
      p1 = ops.Parameter(
          builder, 1, xla_client.shape_from_pyval(np.zeros((4,), np.float32)))
      x = ops.Mul(p0, p1)
      ops.Add(x, x)
      return builder.build()

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testCompiledHloModuleToHloText(self):
      computation = self.ExampleComputation()
      executable = self.backend.compile(computation)
      hlo_modules = executable.hlo_modules()
      self.assertLen(hlo_modules, 1)
      hlo_text = hlo_modules[0].to_string()
      self.assertTrue(hlo_text.startswith("HloModule acomputation"))
      self.assertIn("fusion", hlo_text)

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testFlopEstimate(self):
      computation = self.ExampleComputation()
      properties = xla_client._xla.hlo_module_cost_analysis(
          self.backend, computation.as_hlo_module())
      self.assertEqual(properties["flops"], 8.0)

  tests.append(ComputationPrinting)

  class ComputationsWithConstantsTest(ComputationTest):
    """Tests focusing on Constant ops."""

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in int_dtypes + float_dtypes)
    def testConstantScalarSum(self, dtype):
      if dtype == np.int8 and self.backend.platform == "tpu":
        self.skipTest("TPU doesn't support int8")
      c = self._NewComputation()
      ops.Add(ops.Constant(c, dtype(1.11)), ops.Constant(c, dtype(3.14)))
      self._ExecuteAndCompareClose(c, expected=[dtype(1.11) + dtype(3.14)])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testConstantVectorMul(self, dtype):
      c = self._NewComputation()
      ops.Mul(
          ops.Constant(c, np.array([2.5, 3.3, -1.2, 0.7], dtype)),
          ops.Constant(c, np.array([-1.2, 2, -2, -3], dtype)))
      self._ExecuteAndCompareClose(
          c, expected=[[-3, 6.6, 2.4, -2.1]], rtol=3e-3)

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testConstantVectorScalarDiv(self, dtype):
      c = self._NewComputation()
      ops.Div(
          ops.Constant(c, np.array([1.5, 2.5, 3.0, -10.8], dtype=dtype)),
          ops.Constant(c, dtype(2.0)))
      self._ExecuteAndCompareClose(
          c, expected=[[0.75, 1.25, 1.5, -5.4]], rtol=2e-3)

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testConstantVectorScalarPow(self, dtype):
      c = self._NewComputation()
      ops.Pow(
          ops.Constant(c, np.array([1.5, 2.5, 3.0], dtype=dtype)),
          ops.Constant(c, dtype(2.)))
      self._ExecuteAndCompareClose(c, expected=[[2.25, 6.25, 9.]])

    def testIota(self):
      c = self._NewComputation()
      ops.Iota(c, xla_client.PrimitiveType.F32, 10)
      self._ExecuteAndCompareExact(
          c, expected=[np.arange(10, dtype=np.float32)])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in int_dtypes)
    def testBroadcastedIota(self, dtype):
      c = self._NewComputation()
      shape = xla_client.Shape.array_shape(
          xla_client.dtype_to_etype(dtype), (2, 3))
      ops.Iota(c, shape, 1)
      expected = np.array([[0, 1, 2], [0, 1, 2]], dtype=dtype)
      self._ExecuteAndCompareExact(c, expected=[expected])

    def testBooleanAnd(self):
      c = self._NewComputation()
      ops.And(
          ops.Constant(c, NumpyArrayBool([True, False, True, False])),
          ops.Constant(c, NumpyArrayBool([True, True, False, False])))
      self._ExecuteAndCompareExact(c, expected=[[True, False, False, False]])

    def testBooleanOr(self):
      c = self._NewComputation()
      ops.Or(
          ops.Constant(c, NumpyArrayBool([True, False, True, False])),
          ops.Constant(c, NumpyArrayBool([True, True, False, False])))
      self._ExecuteAndCompareExact(c, expected=[[True, True, True, False]])

    def testBooleanXor(self):
      c = self._NewComputation()
      ops.Xor(
          ops.Constant(c, NumpyArrayBool([True, False, True, False])),
          ops.Constant(c, NumpyArrayBool([True, True, False, False])))
      self._ExecuteAndCompareExact(c, expected=[[False, True, True, False]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testSum2D(self, dtype):
      c = self._NewComputation()
      ops.Add(
          ops.Constant(c, np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)),
          ops.Constant(c, np.array([[1, -1, 1], [-1, 1, -1]], dtype=dtype)))
      self._ExecuteAndCompareClose(c, expected=[[[2, 1, 4], [3, 6, 5]]])

    def testShiftLeft(self):
      c = self._NewComputation()
      ops.ShiftLeft(
          ops.Constant(c, NumpyArrayS32([3])),
          ops.Constant(c, NumpyArrayS32([2])))
      self._ExecuteAndCompareClose(c, expected=[[12]])

    def testShiftRightArithmetic(self):
      c = self._NewComputation()
      ops.ShiftRightArithmetic(
          ops.Constant(c, NumpyArrayS32([-2])),
          ops.Constant(c, NumpyArrayS32([1])))
      self._ExecuteAndCompareClose(c, expected=[[-1]])

    def testShiftRightLogical(self):
      c = self._NewComputation()
      ops.ShiftRightLogical(
          ops.Constant(c, NumpyArrayS32([-1])),
          ops.Constant(c, NumpyArrayS32([1])))
      self._ExecuteAndCompareClose(c, expected=[[2**31 - 1]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testSum2DWith1DBroadcastDim0(self, dtype):
      # sum of a 2D array with a 1D array where the latter is replicated across
      # dimension 0 to match the former's shape.
      c = self._NewComputation()
      ops.Add(
          ops.Constant(c,
                       np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                                dtype=dtype)),
          ops.Constant(c, np.array([10, 20, 30], dtype=dtype)),
          broadcast_dimensions=(0,))
      self._ExecuteAndCompareClose(
          c, expected=[[[11, 12, 13], [24, 25, 26], [37, 38, 39]]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testSum2DWith1DBroadcastDim1(self, dtype):
      # sum of a 2D array with a 1D array where the latter is replicated across
      # dimension 1 to match the former's shape.
      c = self._NewComputation()
      ops.Add(
          ops.Constant(c,
                       np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                                dtype=dtype)),
          ops.Constant(c, np.array([10, 20, 30], dtype=dtype)),
          broadcast_dimensions=(1,))
      self._ExecuteAndCompareClose(
          c, expected=[[[11, 22, 33], [14, 25, 36], [17, 28, 39]]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testConstantAxpy(self, dtype):
      c = self._NewComputation()
      ops.Add(
          ops.Mul(
              ops.Constant(c, dtype(2)),
              ops.Constant(c, np.array([2.2, 3.3, 4.4, 5.5], dtype=dtype))),
          ops.Constant(c, np.array([100, -100, 200, -200], dtype)))
      self._ExecuteAndCompareClose(
          c, expected=[[104.4, -93.4, 208.8, -189]], rtol=2e-3)

    def testCustomCall(self):
      if self.backend.platform != "cpu":
        self.skipTest("Test requires cpu platform")
      c = self._NewComputation()
      for name, fn in custom_call_for_test.cpu_custom_call_targets.items():
        xla_client.register_custom_call_target(name, fn, platform="cpu")
      ops.CustomCallWithLayout(
          c,
          b"test_subtract_f32",
          operands=[
              ops.Constant(c, np.float32(1.25)),
              ops.Constant(c, np.float32(0.5))
          ],
          shape_with_layout=xla_client.Shape.array_shape(
              np.dtype(np.float32), (), ()),
          operand_shapes_with_layout=[
              xla_client.Shape.array_shape(np.dtype(np.float32), (), ()),
              xla_client.Shape.array_shape(np.dtype(np.float32), (), ()),
          ])
      self._ExecuteAndCompareClose(c, expected=[0.75])

  tests.append(ComputationsWithConstantsTest)

  class ComputationFromProtoTest(absltest.TestCase):
    """Test computation execution from HLO proto."""

    def setUp(self):
      super(ComputationFromProtoTest, self).setUp()
      self.backend = xla_backend()

    def testExecuteFromProto(self):
      # Build the HLO proto
      b = xla_client.XlaBuilder("computation")
      ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2)))
      serialized_proto = b.build().as_serialized_hlo_module_proto()

      # Load and execute the proto
      c = xla_client.XlaComputation(serialized_proto)
      ans, = xla_client.execute_with_python_values(
          self.backend.compile(c), (), backend=self.backend)
      np.testing.assert_equal(ans, np.int32(3))

  tests.append(ComputationFromProtoTest)

  class ParametersTest(ComputationTest):
    """Tests focusing on Parameter ops and argument-passing."""

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in int_dtypes)
    def testScalarTimesVector(self, dtype):
      c = self._NewComputation()
      arg0 = np.array(3, dtype=dtype)
      arg1 = np.array([10, 15, -2, 7], dtype=dtype)
      p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0))
      p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1))
      ops.Mul(p0, p1)
      self._ExecuteAndCompareExact(
          c, arguments=[arg0, arg1], expected=[arg0 * arg1])

    # TODO(phawkins): test comparison harness doesn't support bfloat16
    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes if dtype != bfloat16)
    def testScalarMinusVectorExplicitNumbering(self, dtype):
      # Use explicit numbering and pass parameter_num first. Sub is used since
      # it's not commutative and can help catch parameter reversal within the
      # computation.
      c = self._NewComputation()
      arg0 = np.array(2.0, dtype=dtype)
      arg1 = np.array([-2.3, 3.3, -4.3, 5.3], dtype=dtype)
      p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1))
      p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0))
      ops.Sub(p1, p0)
      self._ExecuteAndCompareClose(
          c, arguments=[arg0, arg1], expected=[arg1 - arg0])

  tests.append(ParametersTest)

  class BufferTest(ComputationTest):
    """Tests focusing on execution with Buffers."""

    def testConstantSum(self):
      c = self._NewComputation()
      ops.Add(
          ops.Constant(c, np.float32(1.11)), ops.Constant(c, np.float32(3.14)))
      self._ExecuteAndCompareClose(c, expected=[4.25])

    def testOneParameterSum(self):
      c = self._NewComputation()
      ops.Add(
          ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))),
          ops.Constant(c, np.float32(3.14)))
      self._ExecuteAndCompareClose(
          c, arguments=[NumpyArrayF32(1.11)], expected=[4.25])

    def testTwoParameterSum(self):
      c = self._NewComputation()
      ops.Add(
          ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))),
          ops.Parameter(c, 1, xla_client.shape_from_pyval(NumpyArrayF32(0.))))
      self._ExecuteAndCompareClose(
          c,
          arguments=[NumpyArrayF32(1.11),
                     NumpyArrayF32(3.14)],
          expected=[4.25])

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testCannotCallWithDeletedBuffers(self):
      c = self._NewComputation()
      ops.Add(
          ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0.))),
          ops.Constant(c, np.float32(3.14)))
      arg = NumpyArrayF32(1.11)
      compiled_c = self.backend.compile(c.build())
      arg_buffer = self.backend.buffer_from_pyval(arg)
      arg_buffer.delete()
      with self.assertRaises(RuntimeError):
        compiled_c.execute([arg_buffer])

    def testXlaShape(self):
      pyval = np.array([[1., 2.]], np.float32)
      local_buffer = self.backend.buffer_from_pyval(pyval)
      xla_shape = local_buffer.xla_shape()
      self.assertEqual(xla_shape.dimensions(), (1, 2))
      self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32))

    def testBlockHostUntilReadyWorks(self):
      arg = np.array([[1., 2.]], np.float32)
      arg_buffer = self.backend.buffer_from_pyval(arg)
      arg_buffer.block_host_until_ready()
      # This test merely checks that nothing goes awry when we call
      # block_host_until_ready(); it's difficult to test anything else.

    def testBlockHostUntilReadyRaisesOnDeletedBuffer(self):
      arg = np.array([[1., 2.]], np.float32)
      buffer = self.backend.buffer_from_pyval(arg)
      buffer.delete()
      with self.assertRaisesRegex(
          RuntimeError,
          re.escape(
              "BlockHostUntilReady() called on deleted or donated buffer")):
        buffer.block_host_until_ready()

    def testDeviceArrayBaseSignatures(self):
      # When extending `DeviceArrayBase`, the object behaves as a `DeviceArray`
      # and thus needs to correctly implement the following methods.
      arg = np.array([[1., 2., 3.]], np.float32)
      buffer = self.backend.buffer_from_pyval(arg)
      if not isinstance(buffer, xla_client.DeviceArrayBase):
        raise unittest.SkipTest(
            "The object of type {} does not extend DeviceArrayBase".format(
                type(buffer)))

      self.assertEqual(buffer.__array_priority__, 100)
      self.assertEqual(buffer.shape, (1, 3))
      self.assertEqual(buffer.dtype, np.float32)
      self.assertEqual(buffer.size, 3)
      self.assertEqual(buffer.ndim, 2)

      self.assertIs(buffer, buffer.block_until_ready())
      buffer.delete()
      with self.assertRaises(RuntimeError):
        buffer.block_until_ready()

    def testOnDeviceSizeInBytes(self):
      if not isinstance(self.backend, xla_client.Client):
        self.skipTest("TPU Driver doesn't support OnDeviceSizeInBytes.")
      arg0 = np.array([])
      arg1 = np.array([[0., 1., 2.]], np.float32)
      arg2 = np.array([[3., 4., 5.]], bfloat16)
      arg0_buffer = self.backend.buffer_from_pyval(arg0)
      arg1_buffer = self.backend.buffer_from_pyval(arg1)
      arg2_buffer = self.backend.buffer_from_pyval(arg2)
      self.assertEqual(arg0_buffer.on_device_size_in_bytes(), 0)
      # OnDeviceSizeInBytes varies depending on the platform. Confirm there's
      # a reasonable value.
      self.assertGreater(arg1_buffer.on_device_size_in_bytes(), 0)
      self.assertGreater(arg2_buffer.on_device_size_in_bytes(), 0)

    def testLiveBuffers(self):
      if not isinstance(self.backend, xla_client.Client):
        self.skipTest("TPU Driver doesn't support LiveBuffers().")
      self.assertEmpty(self.backend.live_buffers())
      arg0 = np.array([])
      arg1 = np.array([[0., 1., 2.]], np.float32)
      arg2 = np.array([[3., 4., 5.]], bfloat16)
      arg0_buffer = self.backend.buffer_from_pyval(arg0)
      arg1_buffer = self.backend.buffer_from_pyval(arg1)
      arg2_buffer = self.backend.buffer_from_pyval(arg2)
      self.assertLen(self.backend.live_buffers(), 3)
      self.assertIs(self.backend.live_buffers()[0], arg2_buffer)
      self.assertIs(self.backend.live_buffers()[1], arg1_buffer)
      self.assertIs(self.backend.live_buffers()[2], arg0_buffer)

      arg1_buffer.delete()
      self.assertLen(self.backend.live_buffers(), 2)
      self.assertIs(self.backend.live_buffers()[0], arg2_buffer)
      self.assertIs(self.backend.live_buffers()[1], arg0_buffer)

      arg0_buffer.delete()
      arg2_buffer.delete()
      self.assertEmpty(self.backend.live_buffers())

    def testCopyToHost(self):
      arg0 = np.array([[1., 2.]], np.float32)
      arg1 = np.array([[3., 4.]], np.float32)
      arg0_buffer = self.backend.buffer_from_pyval(arg0)
      arg1_buffer = self.backend.buffer_from_pyval(arg1)
      # Prefetch two buffers using copy_to_host_async, and then retrieve their
      # values using to_py.
      arg0_buffer.copy_to_host_async()
      arg0_buffer.copy_to_host_async()  # Duplicate calls don't do anything.
      arg1_buffer.copy_to_host_async()
      np.testing.assert_equal(arg0, arg0_buffer.to_py())
      np.testing.assert_equal(arg1, arg1_buffer.to_py())
      # copy_to_host_async does nothing after to_py is called.
      arg0_buffer.copy_to_host_async()
      np.testing.assert_equal(arg0, arg0_buffer.to_py())

    def testDevice(self):
      x = np.arange(8, dtype=np.int32)
      for device in self.backend.local_devices():
        buf = self.backend.buffer_from_pyval(x, device=device)
        self.assertEqual(buf.device(), device)
        np.testing.assert_equal(x, buf.to_py())

    def testStandardTypes(self):
      for dtype in standard_dtypes:
        if dtype == bfloat16 or dtype == np.complex128:
          continue
        arr = self.backend.buffer_from_pyval(np.array([0, 1], dtype))
        arr = arr.to_py()
        self.assertEqual(dtype, type(arr[0]))

  tests.append(BufferTest)

  class SingleOpTest(ComputationTest):
    """Tests for single ops.

    The goal here is smoke testing - to exercise the most basic functionality
    of single XLA ops, adding as few additional ops as possible around the op
    being tested.
    """

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testConcatenate(self, dtype):
      c = self._NewComputation()
      args = (
          ops.Constant(c, np.array([1.0, 2.0, 3.0], dtype=dtype)),
          ops.Constant(c, np.array([4.0, 5.0, 6.0], dtype=dtype)),
      )
      ops.ConcatInDim(c, args, dimension=0)
      self._ExecuteAndCompareExact(
          c, expected=[np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=dtype)])

    # pyformat: disable
    @parameterized.named_parameters({
        "testcase_name": "_{}_{}".format(src_dtype.__name__,
                                         dst_dtype.__name__),
        "src_dtype": src_dtype,
        "dst_dtype": dst_dtype,
    } for src_dtype, dst_dtype in itertools.permutations(
        [np.bool_, np.int32, np.int64, np.float32, np.float64], 2))
    # pyformat: enable
    def testConvertElementType(self, src_dtype, dst_dtype):
      if ((src_dtype in [np.int64, np.float64] or
           dst_dtype in [np.int64, np.float64]) and
          self.backend.platform == "tpu"):
        self.skipTest("TPU doesn't support float64")
      c = self._NewComputation()
      x = np.array([0, 1, 0, 0, 1], dtype=src_dtype)
      ops.ConvertElementType(
          ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype))

      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      self.assertLen(result, 1)
      expected = np.array(x, dtype=dst_dtype)

      self.assertEqual(result[0].shape, expected.shape)
      self.assertEqual(result[0].dtype, expected.dtype)
      np.testing.assert_equal(result[0], expected)

    # pyformat: disable
    @parameterized.named_parameters(
        {
            "testcase_name": "_{}_{}".format(src_dtype.__name__,
                                             dst_dtype.__name__),
            "src_dtype": src_dtype,
            "dst_dtype": dst_dtype,
        }
        for dtypes in [[np.int32, np.float32], [np.int64, np.float64]]
        for src_dtype, dst_dtype in itertools.permutations(dtypes, 2))
    # pyformat: enable
    def testBitcastConvertType(self, src_dtype, dst_dtype):
      if (np.float64 in (src_dtype, dst_dtype) and
          self.backend.platform == "tpu"):
        self.skipTest("TPU doesn't support float64")
      c = self._NewComputation()
      x = np.array([0, 1, 0, 0, 1], dtype=src_dtype)
      ops.BitcastConvertType(
          ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype))

      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      self.assertLen(result, 1)
      expected = x.view(dst_dtype)

      self.assertEqual(result[0].shape, expected.shape)
      self.assertEqual(result[0].dtype, expected.dtype)
      np.testing.assert_equal(result[0], expected)

    # TODO(b/123523486) implement AllToAll on CPU
    def DISABLED_testAllToAllOneReplica(self):
      samples = [
          NumpyArrayF32([97.0]),
          NumpyArrayF32([64.0, 117.0]),
          NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]),
      ]
      for lhs in samples[:1]:
        c = self._NewComputation()
        ops.AllToAll(ops.Constant(c, lhs), 0, 0)
        self._ExecuteAndCompareExact(c, expected=[lhs])

    def testCrossReplicaSumOneReplica(self):
      samples = [
          NumpyArrayF32(42.0),
          NumpyArrayF32([97.0]),
          NumpyArrayF32([64.0, 117.0]),
          NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]),
      ]
      for lhs in samples:
        c = self._NewComputation()
        ops.CrossReplicaSum(ops.Constant(c, lhs))
        self._ExecuteAndCompareExact(c, expected=[lhs])

    def testReplicaId(self):
      c = self._NewComputation()
      _ = ops.ReplicaId(c)
      self._ExecuteAndCompareExact(c, expected=[0])

    def testCrossReplicaSumOneReplicaWithSingletonGroup(self):
      samples = [
          NumpyArrayF32(42.0),
          NumpyArrayF32([97.0]),
          NumpyArrayF32([64.0, 117.0]),
          NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]),
      ]
      for lhs in samples:
        c = self._NewComputation()
        ops.CrossReplicaSum(
            ops.Constant(c, lhs), xla_client.make_replica_groups([[0]]))
        self._ExecuteAndCompareExact(c, expected=[lhs])

    # TODO(phawkins): np.dot implementation doesn't support bfloat16
    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes if dtype != bfloat16)
    def testDotMatrixVector(self, dtype):
      c = self._NewComputation()
      lhs = np.array([[2.0, 3.0], [4.0, 5.0]], dtype=dtype)
      rhs = np.array([[10.0], [20.0]], dtype=dtype)
      ops.Dot(ops.Constant(c, lhs), ops.Constant(c, rhs))
      self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)])

    # TODO(phawkins): np.dot implementation doesn't support bfloat16
    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes if dtype != bfloat16)
    def testDotMatrixMatrix(self, dtype):
      c = self._NewComputation()
      lhs = np.array([[2.0, 3.0], [4.0, 5.0]], dtype=dtype)
      rhs = np.array([[10.0, 20.0], [100.0, 200.0]], dtype=dtype)
      ops.Dot(ops.Constant(c, lhs), ops.Constant(c, rhs))
      self._ExecuteAndCompareClose(c, expected=[np.dot(lhs, rhs)])

    def testDotGeneral(self):
      c = self._NewComputation()
      rng = np.random.RandomState(0)
      lhs = NumpyArrayF32(rng.randn(10, 3, 4))
      rhs = NumpyArrayF32(rng.randn(10, 4, 5))
      dimension_numbers = xla_client.make_dot_dimension_numbers(
          (([2], [1]), ([0], [0])))
      ops.DotGeneral(
          ops.Constant(c, lhs), ops.Constant(c, rhs), dimension_numbers)
      self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6)

    def testDotGeneralWithDotDimensionNumbersProto(self):
      c = self._NewComputation()
      rng = np.random.RandomState(0)
      lhs = NumpyArrayF32(rng.randn(10, 3, 4))
      rhs = NumpyArrayF32(rng.randn(10, 4, 5))

      dimension_numbers = xla_client.DotDimensionNumbers()
      dimension_numbers.lhs_contracting_dimensions.append(2)
      dimension_numbers.rhs_contracting_dimensions.append(1)
      dimension_numbers.lhs_batch_dimensions.append(0)
      dimension_numbers.rhs_batch_dimensions.append(0)

      ops.DotGeneral(
          ops.Constant(c, lhs), ops.Constant(c, rhs), dimension_numbers)
      self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6)

    def testDotGeneralWithPrecisionConfig(self):
      c = self._NewComputation()
      rng = np.random.RandomState(0)
      lhs = NumpyArrayF32(rng.randn(10, 3, 4))
      rhs = NumpyArrayF32(rng.randn(10, 4, 5))
      dimension_numbers = xla_client.make_dot_dimension_numbers(
          (([2], [1]), ([0], [0])))
      config = xla_client.PrecisionConfig()
      config.operand_precision.append(config.Precision.HIGH)
      config.operand_precision.append(config.Precision.HIGHEST)
      ops.DotGeneral(
          ops.Constant(c, lhs),
          ops.Constant(c, rhs),
          dimension_numbers,
          precision_config=config)
      self._ExecuteAndCompareClose(c, expected=[np.matmul(lhs, rhs)], rtol=1e-6)

    def testConvGeneralDilatedF32(self):
      c = self._NewComputation()
      a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
      lhs = a(1, 1, 2, 3)
      rhs = a(1, 1, 1, 2) * 10
      strides = [1, 1]
      pads = [(1, 0), (0, 1)]
      lhs_dilation = (2, 1)
      rhs_dilation = (1, 1)
      dimension_numbers = xla_client.make_convolution_dimension_numbers(
          ("NCHW", "OIHW", "NCHW"), 2)
      ops.ConvGeneralDilated(
          ops.Constant(c, lhs), ops.Constant(c, rhs), strides, pads,
          lhs_dilation, rhs_dilation, dimension_numbers)
      result = np.array([[[
          [0., 0., 0.],
          [10., 20., 0.],
          [0., 0., 0.],
          [40., 50., 0.],
      ]]])
      self._ExecuteAndCompareClose(c, expected=[result])

    def testConvGeneralDilatedF32WithPrecisionConfig(self):
      c = self._NewComputation()
      a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
      lhs = a(1, 1, 2, 3)
      rhs = a(1, 1, 1, 2) * 10
      strides = [1, 1]
      pads = [(1, 0), (0, 1)]
      lhs_dilation = (2, 1)
      rhs_dilation = (1, 1)
      dimension_numbers = xla_client.make_convolution_dimension_numbers(
          ("NCHW", "OIHW", "NCHW"), 2)
      config = xla_client.PrecisionConfig()
      config.operand_precision.append(config.Precision.HIGHEST)
      config.operand_precision.append(config.Precision.DEFAULT)
      ops.ConvGeneralDilated(
          ops.Constant(c, lhs),
          ops.Constant(c, rhs),
          strides,
          pads,
          lhs_dilation,
          rhs_dilation,
          dimension_numbers,
          precision_config=config)
      result = np.array([[[
          [0., 0., 0.],
          [10., 20., 0.],
          [0., 0., 0.],
          [40., 50., 0.],
      ]]])
      self._ExecuteAndCompareClose(c, expected=[result])

    def testConvGeneralDilatedPermutedF32(self):
      c = self._NewComputation()
      a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
      lhs = a(1, 1, 2, 3)
      rhs = a(1, 1, 1, 2) * 10
      strides = [1, 1]
      pads = [(1, 0), (0, 1)]
      lhs_dilation = (2, 1)
      rhs_dilation = (1, 1)

      dimension_numbers = xla_client.make_convolution_dimension_numbers(
          ("NHWC", "OIHW", "CWNH"), 2)
      ops.ConvGeneralDilated(
          ops.Constant(c, np.transpose(lhs,
                                       (0, 2, 3, 1))), ops.Constant(c, rhs),
          strides, pads, lhs_dilation, rhs_dilation, dimension_numbers)
      result = np.array([[[[0., 0., 0.], [10., 20., 0.], [0., 0., 0.],
                           [40., 50., 0.]]]])
      self._ExecuteAndCompareClose(
          c, expected=[np.transpose(result, (1, 3, 0, 2))])

    def testConvGeneralDilatedGroupedConvolutionF32(self):
      c = self._NewComputation()
      a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
      lhs = a(1, 2, 2, 3)
      rhs = a(2, 1, 1, 2) * 10
      strides = [1, 1]
      pads = [(1, 0), (0, 1)]
      lhs_dilation = (2, 1)
      rhs_dilation = (1, 1)
      dimension_numbers = xla_client.make_convolution_dimension_numbers(
          ("NCHW", "OIHW", "NCHW"), 2)
      feature_group_count = 2
      ops.ConvGeneralDilated(
          ops.Constant(c, lhs), ops.Constant(c, rhs), strides, pads,
          lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count)
      result = np.array([[[
          [0., 0., 0.],
          [10., 20., 0.],
          [0., 0., 0.],
          [40., 50., 0.],
      ], [
          [0., 0., 0.],
          [330., 380., 160.],
          [0., 0., 0.],
          [480., 530., 220.],
      ]]])
      self._ExecuteAndCompareClose(c, expected=[result])

    def testBooleanNot(self):
      c = self._NewComputation()
      arr = NumpyArrayBool([True, False, True])
      ops.Not(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[~arr])

    def testPopulationCount(self):
      c = self._NewComputation()
      arr = NumpyArrayS32([3, 0, 1])
      ops.PopulationCount(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.array([2, 0, 1])])

    def testCountLeadingZeros(self):
      c = self._NewComputation()
      arr = NumpyArrayS32([0x7FFF, 0x12345678])
      ops.Clz(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[[17, 3]])

    def testExp(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Exp(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.exp(arr)])

    def testExpm1(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Expm1(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.expm1(arr)])

    def testRound(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Round(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.round(arr)])

    def testLog(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Log(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.log(arr)])

    def testLog1p(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Log1p(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.log1p(arr)])

    def testNeg(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Neg(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[-arr])

    def testFloor(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Floor(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.floor(arr)])

    def testCeil(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, 12.1])
      ops.Ceil(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.ceil(arr)])

    def testAbs(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([3.3, -12.1, 2.4, -1.])
      ops.Abs(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.abs(arr)])

    def testTanhF32(self):
      c = self._NewComputation()
      arr = NumpyArrayF32([-0.2, 3.3, 12.1, 0.1, 0.0001])
      ops.Tanh(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.tanh(arr)])

    def testTanhF64(self):
      if self.backend.platform == "tpu":
        self.skipTest("TPU doesn't support 64bit tanh")
      c = self._NewComputation()
      arr = NumpyArrayF64([-0.2, 3.3, 12.1, 0.1, 0.0001])
      ops.Tanh(ops.Constant(c, arr))
      self._ExecuteAndCompareClose(c, expected=[np.tanh(arr)], rtol=1e-12)

    def testTranspose(self):

      def _TransposeAndTest(array, permutation):
        c = self._NewComputation()
        ops.Transpose(ops.Constant(c, array), permutation)
        expected = np.transpose(array, permutation)
        self._ExecuteAndCompareClose(c, expected=[expected])

      _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [0, 1])
      _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [1, 0])
      _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [0, 1])
      _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [1, 0])

      arr = np.random.RandomState(0).randn(2, 3, 4).astype(np.float32)
      for permutation in itertools.permutations(range(arr.ndim)):
        _TransposeAndTest(arr, permutation)
        _TransposeAndTest(np.asfortranarray(arr), permutation)

    def testEq(self):
      c = self._NewComputation()
      ops.Eq(
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4])),
          ops.Constant(c, NumpyArrayS32([4, 2, 3, 1])))
      self._ExecuteAndCompareExact(c, expected=[[False, True, True, False]])

    def testNe(self):
      c = self._NewComputation()
      ops.Ne(
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4])),
          ops.Constant(c, NumpyArrayS32([4, 2, 3, 1])))
      self._ExecuteAndCompareExact(c, expected=[[True, False, False, True]])

      ops.Ne(
          ops.Constant(c, NumpyArrayF32([-2.0, 0.0,
                                         float("nan"),
                                         float("nan")])),
          ops.Constant(c, NumpyArrayF32([2.0, -0.0, 1.0,
                                         float("nan")])))
      self._ExecuteAndAssertWith(
          np.testing.assert_allclose,
          c, (),
          expected=[[True, False, True, True]])

    def testGt(self):
      c = self._NewComputation()
      ops.Gt(
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])),
          ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12])))
      self._ExecuteAndCompareExact(
          c, expected=[[False, True, True, False, False]])

    def testGe(self):
      c = self._NewComputation()
      ops.Ge(
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])),
          ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12])))
      self._ExecuteAndCompareExact(
          c, expected=[[True, True, True, False, False]])

    def testLt(self):
      c = self._NewComputation()
      ops.Lt(
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])),
          ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12])))
      self._ExecuteAndCompareExact(
          c, expected=[[False, False, False, True, True]])

    def testLe(self):
      c = self._NewComputation()
      ops.Le(
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 9])),
          ops.Constant(c, NumpyArrayS32([1, 0, 2, 7, 12])))
      self._ExecuteAndCompareExact(
          c, expected=[[True, False, False, True, True]])

    def testMax(self):
      c = self._NewComputation()
      ops.Max(
          ops.Constant(c, NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])),
          ops.Constant(c, NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0])))
      self._ExecuteAndCompareExact(c, expected=[[1.0, 2.0, 3.0, 7.0, 12.0]])

    def testMaxExplicitBroadcastDim0(self):
      c = self._NewComputation()
      ops.Max(
          ops.Constant(c, NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
          ops.Constant(c, NumpyArrayF32([3, 4, 5])),
          broadcast_dimensions=(0,))
      self._ExecuteAndCompareExact(
          c, expected=[[[3, 3, 3], [4, 5, 6], [7, 8, 9]]])

    def testMaxExplicitBroadcastDim1(self):
      c = self._NewComputation()
      ops.Max(
          ops.Constant(c, NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
          ops.Constant(c, NumpyArrayF32([3, 4, 5])),
          broadcast_dimensions=(1,))
      self._ExecuteAndCompareExact(
          c, expected=[[[3, 4, 5], [4, 5, 6], [7, 8, 9]]])

    def testMin(self):
      c = self._NewComputation()
      ops.Min(
          ops.Constant(c, NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])),
          ops.Constant(c, NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0])))
      self._ExecuteAndCompareExact(c, expected=[[1.0, 0.0, 2.0, 4.0, 9.0]])

    def testPad(self):
      c = self._NewComputation()
      ops.Pad(
          ops.Constant(c, NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])),
          ops.Constant(c, NumpyArrayF32(0.0)),
          xla_client.make_padding_config([(1, 2, 1), (0, 1, 0)]))
      self._ExecuteAndCompareClose(
          c,
          expected=[[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0],
                     [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]])

    def testPadWithPaddingConfig(self):
      c = self._NewComputation()
      padding_config = xla_client.PaddingConfig()
      for lo, hi, interior in [(1, 2, 1), (0, 1, 0)]:
        dimension = xla_client.PaddingConfigDimension()
        dimension.edge_padding_low = lo
        dimension.edge_padding_high = hi
        dimension.interior_padding = interior
        padding_config.dimensions.append(dimension)
      ops.Pad(
          ops.Constant(c, NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])),
          ops.Constant(c, NumpyArrayF32(0.0)), padding_config)
      self._ExecuteAndCompareClose(
          c,
          expected=[[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0],
                     [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]])

    def testReshape(self):
      c = self._NewComputation()
      ops.Reshape(
          ops.Constant(c, NumpyArrayS32([[1, 2], [3, 4], [5, 6]])),
          dimensions=[0, 1],
          new_sizes=[2, 3])
      self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3], [4, 5, 6]]])

    def testCollapse(self):
      c = self._NewComputation()
      ops.Collapse(
          ops.Constant(c, NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])),
          dimensions=[1, 2])
      self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3, 4], [5, 6, 7, 8]]])

    def testRev(self):
      c = self._NewComputation()
      ops.Rev(
          ops.Constant(c, NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])),
          dimensions=[0, 2])
      self._ExecuteAndCompareExact(
          c, expected=[[[[6, 5], [8, 7]], [[2, 1], [4, 3]]]])

    def testReducePrecision(self):
      c = self._NewComputation()
      ops.ReducePrecision(
          ops.Constant(c, NumpyArrayF32([float.fromhex("0x1.32fffep-3")])),
          exponent_bits=8,
          mantissa_bits=7)
      self._ExecuteAndCompareClose(c, expected=[[float.fromhex("0x1.32p-3")]])

    def testClampF32(self):
      c = self._NewComputation()
      ops.Clamp(
          ops.Constant(c, NumpyArrayF32(-1)),
          ops.Constant(c, NumpyArrayF32([-2, -1, 0, 1, 2, 3])),
          ops.Constant(c, NumpyArrayF32(2)))
      self._ExecuteAndCompareExact(c, expected=[[-1, -1, 0, 1, 2, 2]])

    def testClampS32(self):
      c = self._NewComputation()
      ops.Clamp(
          ops.Constant(c, NumpyArrayS32(-1)),
          ops.Constant(c, NumpyArrayS32([-2, -1, 0, 1, 2, 3])),
          ops.Constant(c, NumpyArrayS32(2)))
      self._ExecuteAndCompareExact(c, expected=[[-1, -1, 0, 1, 2, 2]])

    def testSelect(self):
      c = self._NewComputation()
      ops.Select(
          ops.Constant(c, NumpyArrayBool([True, False, False, True, False])),
          ops.Constant(c, NumpyArrayS32([1, 2, 3, 4, 5])),
          ops.Constant(c, NumpyArrayS32([-1, -2, -3, -4, -5])))
      self._ExecuteAndCompareExact(c, expected=[[1, -2, -3, 4, -5]])

    def testSlice(self):
      c = self._NewComputation()
      ops.Slice(
          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
          [1, 0], [3, 2], [1, 1])
      self._ExecuteAndCompareExact(c, expected=[[[4, 5], [7, 8]]])

    def testSliceInDim(self):
      c = self._NewComputation()
      ops.SliceInDim(
          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
          start_index=1,
          limit_index=2,
          stride=1,
          dimno=1)
      self._ExecuteAndCompareExact(c, expected=[[[2], [5], [8]]])
      ops.SliceInDim(
          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
          start_index=0,
          limit_index=3,
          stride=2,
          dimno=0)
      self._ExecuteAndCompareExact(c, expected=[[[1, 2, 3], [7, 8, 9]]])

    def testDynamicSlice(self):
      c = self._NewComputation()
      ops.DynamicSlice(
          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
          [ops.Constant(c, NumpyArrayS32([1, 0]))], [2, 2])
      self._ExecuteAndCompareExact(c, expected=[[[4, 5], [7, 8]]])

    def testDynamicUpdateSlice(self):
      c = self._NewComputation()
      ops.DynamicUpdateSlice(
          ops.Constant(c, NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])),
          ops.Constant(c, NumpyArrayS32([[1, 2], [3, 4]])),
          [ops.Constant(c, NumpyArrayS32([1, 1]))])
      self._ExecuteAndCompareExact(
          c, expected=[[[1, 2, 3], [4, 1, 2], [7, 3, 4]]])

    def testTuple(self):
      c = self._NewComputation()
      ops.Tuple(c, [
          ops.Constant(c, np.int32(42)),
          ops.Constant(c, NumpyArrayF32([1.0, 2.0])),
          ops.Constant(c, NumpyArrayBool([True, False, False, True]))
      ])
      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      self.assertLen(result, 3)
      np.testing.assert_equal(result[0], 42)
      np.testing.assert_allclose(result[1], [1.0, 2.0])
      np.testing.assert_equal(result[2], [True, False, False, True])

    def testGetTupleElement(self):
      c = self._NewComputation()
      ops.GetTupleElement(
          ops.Tuple(c, [
              ops.Constant(c, np.int32(42)),
              ops.Constant(c, NumpyArrayF32([1.0, 2.0])),
              ops.Constant(c, NumpyArrayBool([True, False, False, True]))
          ]), 1)
      self._ExecuteAndCompareClose(c, expected=[[1.0, 2.0]])

    def testBroadcast(self):
      c = self._NewComputation()
      ops.Broadcast(
          ops.Constant(c, NumpyArrayS32([10, 20, 30, 40])), sizes=(3,))
      self._ExecuteAndCompareExact(
          c, expected=[[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]]])

    def testBroadcastInDim(self):
      c = self._NewComputation()
      ops.BroadcastInDim(ops.Constant(c, NumpyArrayS32([1, 2])), [2, 2], [0])
      self._ExecuteAndCompareExact(c, expected=[[[1, 1], [2, 2]]])
      ops.BroadcastInDim(ops.Constant(c, NumpyArrayS32([1, 2])), [2, 2], [1])
      self._ExecuteAndCompareExact(c, expected=[[[1, 2], [1, 2]]])

    def testRngNormal(self):
      shape = (2, 3)
      c = self._NewComputation()
      ops.RngNormal(
          ops.Constant(c, NumpyArrayF32(0.)),
          ops.Constant(c, NumpyArrayF32(1.)),
          shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32,
                                             shape))
      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      # since the result is random, we just check shape and uniqueness
      self.assertLen(result, 1)
      self.assertEqual(result[0].shape, shape)
      self.assertLen(np.unique(result[0]), np.prod(shape))

    def testRngUniformF32(self):
      lo, hi = 2., 4.
      shape = (2, 3)
      c = self._NewComputation()
      ops.RngUniform(
          ops.Constant(c, NumpyArrayF32(lo)),
          ops.Constant(c, NumpyArrayF32(hi)),
          shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32,
                                             shape))
      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      # since the result is random, we just check shape, uniqueness, and range
      self.assertLen(result, 1)
      self.assertEqual(result[0].shape, shape)
      self.assertLen(np.unique(result[0]), np.prod(shape))
      self.assertTrue(np.all(lo <= result[0]))
      self.assertTrue(np.all(result[0] < hi))

    def testRngUniformS32(self):
      lo, hi = 2, 4
      shape = (2, 3)
      c = self._NewComputation()
      ops.RngUniform(
          ops.Constant(c, NumpyArrayS32(lo)),
          ops.Constant(c, NumpyArrayS32(hi)),
          shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.S32,
                                             shape))
      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      # since the result is random, we just check shape, integrality, and range
      self.assertLen(result, 1)
      self.assertEqual(result[0].shape, shape)
      self.assertEqual(result[0].dtype, np.int32)
      self.assertTrue(np.all(lo <= result[0]))
      self.assertTrue(np.all(result[0] < hi))

    def testCholesky(self):
      l = np.array([[4, 0, 0, 0], [6, 5, 0, 0], [2, 14, 16, 0], [3, 6, 1, 4]],
                   dtype=np.float32)
      c = self._NewComputation()
      ops.Cholesky(ops.Constant(c, np.tril(np.dot(l, l.T))))
      self._ExecuteAndCompareClose(c, expected=[l], rtol=1e-4)

    def testSort(self):
      keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32)
      c = self._NewComputation()
      ops.Sort(c, [ops.Constant(c, keys)], is_stable=True)
      self._ExecuteAndCompareClose(
          c,
          expected=[np.array([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=np.float32)])

    def testSortKeyVal(self):
      keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32)
      values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32)
      c = self._NewComputation()
      ops.Sort(c, (ops.Constant(c, keys), ops.Constant(c, values)), dimension=0)
      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      self.assertLen(result, 2)
      np.testing.assert_allclose(result[0], [[2, 1, 1, 2], [3, 4, 4, 3]])
      np.testing.assert_equal(result[1], [[0, 5, 2, 7], [4, 1, 6, 3]])

    def testSortCustomComparator(self):
      b = self._NewComputation("comparator")
      p0 = ops.Parameter(b, 0, xla_client.shape_from_pyval(NumpyArrayF32(0)))
      q0 = ops.Parameter(b, 1, xla_client.shape_from_pyval(NumpyArrayF32(0)))
      p1 = ops.Parameter(b, 2, xla_client.shape_from_pyval(NumpyArrayS32(0)))
      q1 = ops.Parameter(b, 3, xla_client.shape_from_pyval(NumpyArrayS32(0)))
      ops.Or(ops.Lt(p0, q0), ops.And(ops.Eq(p0, q0), ops.Gt(p1, q1)))
      comparator = b.build()

      keys = np.array([[2, 3, 1, 3], [3, 1, 2, 2]], dtype=np.float32)
      values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32)
      c = self._NewComputation()
      ops.Sort(
          c, (ops.Constant(c, keys), ops.Constant(c, values)),
          dimension=1,
          comparator=comparator)
      result = xla_client.execute_with_python_values(
          self.backend.compile(c.build()), (), backend=self.backend)
      self.assertLen(result, 2)
      np.testing.assert_allclose(result[0], [[1, 2, 3, 3], [1, 2, 2, 3]])
      np.testing.assert_equal(result[1], [[2, 0, 3, 1], [5, 7, 6, 4]])

    def testQR(self):
      a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166],
                    [10, 63, 166, 310]],
                   dtype=np.float32)
      c = self._NewComputation()
      ops.Tuple(c, ops.QR(ops.Constant(c, a), full_matrices=True))
      q, r = self._Execute(c, ())
      np.testing.assert_allclose(np.dot(q, r), a, rtol=1e-4)

    def testEigh(self):
      a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166],
                    [10, 63, 166, 310]],
                   dtype=np.float32)
      a = (a + a.T) / 2

      c = self._NewComputation()
      ops.Tuple(c, ops.Eigh(ops.Constant(c, a), lower=True))
      # TODO(b/129396575): Turn this test back on when it passes without
      # fastmath.
      # v, w = self._Execute(c, ())
      # self.assertLess(np.linalg.norm(np.dot(a, v) - w * v), 1e-3)

    def testSVD(self):
      a = np.array([[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166],
                    [10, 63, 166, 310]],
                   dtype=np.float32)
      c = self._NewComputation()
      ops.Tuple(c, ops.SVD(ops.Constant(c, a)))
      u, d, v = self._Execute(c, ())
      self.assertLess(np.linalg.norm(a - np.matmul(u * d, v.T)), 1e-3)

    def testTriangularSolve(self):
      a_vals = np.array(
          [[2, 0, 0, 0], [3, 6, 0, 0], [4, 7, 9, 0], [5, 8, 10, 11]],
          dtype=np.float32)
      b_vals = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
                        dtype=np.float32)

      c = self._NewComputation()
      ops.TriangularSolve(
          ops.Constant(c, a_vals),
          ops.Constant(c, b_vals),
          left_side=False,
          lower=True,
          transpose_a=ops.TriangularSolveOptions_Transpose.TRANSPOSE,
          unit_diagonal=False)
      self._ExecuteAndCompareClose(
          c,
          expected=[
              np.array([
                  [0.5, 0.08333334, 0.04629629, 0.03367003],
                  [2.5, -0.25, -0.1388889, -0.1010101],
                  [4.5, -0.58333331, -0.32407406, -0.23569024],
              ],
                       dtype=np.float32)
          ],
          rtol=1e-4)

    def testIsConstant(self):
      c = self._NewComputation()
      a = ops.Constant(c, np.int32(3))
      b = ops.Constant(c, np.int32(1))
      x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayS32(0)))
      const_expr = ops.Sub(b, a)
      non_const_expr = ops.Mul(const_expr, x)
      self.assertTrue(c.is_constant(const_expr))
      self.assertFalse(c.is_constant(non_const_expr))

    def testGather(self):
      a = np.arange(9).astype(np.int32).reshape((3, 3))
      indices = np.array([[[0, 2], [2, 1]], [[1, 2], [2, 0]]], dtype=np.int32)
      dnums = xla_client.GatherDimensionNumbers()
      dnums.offset_dims.append(1)
      dnums.offset_dims.append(2)
      dnums.start_index_map.append(0)
      dnums.start_index_map.append(1)
      dnums.index_vector_dim = 2
      c = self._NewComputation()
      ops.Gather(
          ops.Constant(c, a),
          ops.Constant(c, indices),
          dnums,
          slice_sizes=[1, 1])
      g, = self._Execute(c, ())
      expected = np.array([[[[2, 7]]], [[[5, 6]]]], dtype=np.int32)
      np.testing.assert_allclose(g, expected, rtol=1e-4)

    def testFft(self):
      if self.backend.platform == "tpu":
        self.skipTest("TPU only supports 1D FFT")
      shape = [2, 3, 4, 5]
      rng = np.random.RandomState(0)
      a = rng.randn(*shape) + 1.0j * rng.randn(*shape)
      a = a.astype(np.complex64)
      # FFT
      c = self._NewComputation()
      ops.Fft(ops.Constant(c, a), xla_client.FftType.FFT, shape[-3:])
      self._ExecuteAndCompareClose(
          c, expected=[np.fft.fftn(a, axes=(1, 2, 3))], rtol=1e-4)
      # IFFT
      c = self._NewComputation()
      ops.Fft(ops.Constant(c, a), xla_client.FftType.IFFT, shape[-3:])
      self._ExecuteAndCompareClose(
          c, expected=[np.fft.ifftn(a, axes=(1, 2, 3))], rtol=1e-4)
      # RFFT
      b = rng.randn(*shape).astype(np.float32)
      c = self._NewComputation()
      ops.Fft(ops.Constant(c, b), xla_client.FftType.RFFT, shape[-3:])
      self._ExecuteAndCompareClose(
          c, expected=[np.fft.rfftn(b, axes=(1, 2, 3))], rtol=1e-4)
      # IRFFT
      c = self._NewComputation()
      ops.Fft(ops.Constant(c, a), xla_client.FftType.IRFFT, [3, 4, 8])
      self._ExecuteAndCompareClose(
          c, expected=[np.fft.irfftn(a, axes=(1, 2, 3))], rtol=1e-4)

    def testNextAfter(self):
      c = self._NewComputation()
      ops.NextAfter(
          ops.Constant(c, np.array([1, 2], dtype=np.float32)),
          ops.Constant(c, np.array([2, 1], dtype=np.float32)))
      out, = self._Execute(c, ())
      eps = np.finfo(np.float32).eps
      np.testing.assert_equal(
          np.array([eps + 1, 2 - eps], dtype=np.float32), out)

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testRegularizedIncompleteBeta(self, dtype):
      x = np.array([0.53787335, 0.24015466, 0.47494545, 0.13567594, 0.95114538],
                   dtype=dtype)
      a = np.array([0.00753073, 0.34813385, 0.30485708, 1.29298632, 0.51472606],
                   dtype=dtype)
      b = np.array([0.55688389, 0.59794214, 0.42661022, 1.59748339, 0.95047677],
                   dtype=dtype)
      c = self._NewComputation()
      ops.RegularizedIncompleteBeta(
          ops.Constant(c, a), ops.Constant(c, b), ops.Constant(c, x))
      expected = np.array(
          [0.98923271, 0.48575411, 0.57952568, 0.12579775, 0.96989155])
      self._ExecuteAndCompareClose(c, expected=[expected], rtol=2e-2)

  tests.append(SingleOpTest)

  class EmbeddedComputationsTest(ComputationTest):
    """Tests for XLA graphs with embedded computations (such as maps)."""

    def _CreateConstantComputation(self, in_dtype, out_dtype):
      """Computation (A) -> B that returns a constant 1 for any input."""
      c = self._NewComputation("constant_{}_{}_one".format(
          in_dtype.__name__, out_dtype.__name__))
      ops.Parameter(c, 0,
                    xla_client.shape_from_pyval(np.array(0, dtype=in_dtype)))
      ops.Constant(c, out_dtype(1))
      return c.build()

    def _CreateMulBy2Computation(self, dtype):
      """Computation (dtype) -> dtype that multiplies its parameter by 2."""
      c = self._NewComputation("mul_f32_by2")
      ops.Mul(
          ops.Parameter(
              c, 0,
              xla_client.shape_from_pyval(np.array(
                  0, dtype=dtype)).with_major_to_minor_layout_if_absent()),
          ops.Constant(c, dtype(2.0)))
      return c.build()

    def _CreateMulF32ByParamComputation(self):
      """Computation (f32) -> f32 that multiplies one parameter by the other."""
      c = self._NewComputation("mul_f32_by_param")
      ops.Mul(
          ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(0))),
          ops.Parameter(c, 1, xla_client.shape_from_pyval(NumpyArrayF32(0))))
      return c.build()

    def _CreateBinaryAddComputation(self, dtype):
      """Computation (dtype, dtype) -> dtype that adds its two parameters."""
      c = self._NewComputation("add_param0_by_param1")
      shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype))
      shape = shape.with_major_to_minor_layout_if_absent()
      ops.Add(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape))
      return c.build()

    def _CreateBinaryGeComputation(self, dtype):
      """Computation (dtype, dtype) -> bool that tests param0 >= param1."""
      c = self._NewComputation("param0_ge_param1")
      shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype))
      shape = shape.with_major_to_minor_layout_if_absent()
      ops.Ge(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape))
      return c.build()

    def _MakeSample3DArray(self, dtype):
      return np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]],
                       [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]],
                      dtype=dtype)

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testCall(self, dtype):
      c = self._NewComputation()
      ops.Call(
          c,
          self._CreateMulBy2Computation(dtype),
          operands=(ops.Constant(c, dtype(5.0)),))
      self._ExecuteAndCompareClose(c, expected=[10.0])

    @parameterized.named_parameters({
        "testcase_name": "_{}_{}".format(in_dtype.__name__, out_dtype.__name__),
        "in_dtype": in_dtype,
        "out_dtype": out_dtype,
    } for in_dtype, out_dtype in [[np.float32, np.int32]])
    def testMapEachElementToConstant(self, in_dtype, out_dtype):
      c = self._NewComputation()
      ops.Map(c,
              [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=in_dtype))],
              self._CreateConstantComputation(in_dtype, out_dtype), [0])
      self._ExecuteAndCompareExact(c, expected=[[1, 1, 1, 1]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testMapMulBy2(self, dtype):
      if dtype == np.float64 and self.backend.platform == "tpu":
        self.skipTest("TPU doesn't support float64")
      c = self._NewComputation()
      ops.Map(c, [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype))],
              self._CreateMulBy2Computation(dtype), [0])
      self._ExecuteAndCompareClose(c, expected=[[2.0, 4.0, 6.0, 8.0]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testSimpleMapChain(self, dtype):
      if dtype == np.float64 and self.backend.platform == "tpu":
        self.skipTest("TPU doesn't support float64")
      # Chains a map of constant-out with a map of mul-by-2
      c = self._NewComputation()
      const = ops.Map(
          c, [ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype))],
          self._CreateConstantComputation(dtype, dtype), [0])
      ops.Map(c, [const], self._CreateMulBy2Computation(dtype), [0])
      self._ExecuteAndCompareClose(c, expected=[[2.0, 2.0, 2.0, 2.0]])

    # TODO(b/154752816): bfloat16 crashes in evaluator.
    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes if dtype != bfloat16)
    def testDivVectorsWithMap(self, dtype):

      def DivComputation():
        c = self._NewComputation("div_param0_by_param1")
        shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype))
        ops.Div(ops.Parameter(c, 0, shape), ops.Parameter(c, 1, shape))
        return c.build()

      c = self._NewComputation()
      ops.Map(c, (ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype)),
                  ops.Constant(c, np.array([5.0, 5.0, 4.0, 4.0], dtype=dtype))),
              DivComputation(), [0])
      self._ExecuteAndCompareClose(
          c, expected=[[0.2, 0.4, 0.75, 1.0]], rtol=1e-3)

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testSelectAndScatter(self, dtype):
      if dtype == np.float64 and self.backend.platform == "tpu":
        self.skipTest("TPU doesn't support float64")
      c = self._NewComputation()
      operand = ops.Constant(
          c, np.array([[1., 2., 6.], [4., 5., 3.]], dtype=dtype))
      window_dimensions = (2, 1)
      window_strides = (1, 2)
      padding = xla_client.window_padding_type_to_pad_values(
          xla_client.PaddingType.VALID,
          c.get_shape(operand).dimensions(), window_dimensions, window_strides)
      ops.SelectAndScatterWithGeneralPadding(
          operand,
          select=self._CreateBinaryGeComputation(dtype),
          window_dimensions=window_dimensions,
          window_strides=window_strides,
          padding=padding,
          source=ops.Constant(c, np.array([[0.1, 0.2]], dtype=dtype)),
          init_value=ops.Constant(c, np.array(1, dtype=dtype)),
          scatter=self._CreateBinaryAddComputation(dtype))
      self._ExecuteAndCompareClose(
          c, expected=[[[1., 1., 1.2], [1.1, 1., 1.]]], rtol=5e-3)

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testReduce1DtoScalar(self, dtype):
      c = self._NewComputation()
      ops.Reduce(
          c,
          operands=[
              ops.Constant(c, np.array([1.0, 2.0, 3.0, 4.0], dtype=dtype))
          ],
          init_values=[ops.Constant(c, dtype(0))],
          computation=self._CreateBinaryAddComputation(dtype),
          dimensions_to_reduce=[0])
      self._ExecuteAndCompareClose(c, expected=[10])

    # TODO(phawkins): test comparison harness doesn't support bfloat16
    @parameterized.named_parameters({
        "testcase_name": "_{}_dim{}".format(dtype.__name__, dim),
        "dtype": dtype,
        "dim": dim,
    } for dtype in float_dtypes if dtype != bfloat16 for dim in range(2))
    def testReduce2DTo1D(self, dtype, dim):
      input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype)
      c = self._NewComputation()
      ops.Reduce(
          c,
          operands=[ops.Constant(c, input_array)],
          init_values=[ops.Constant(c, dtype(0))],
          computation=self._CreateBinaryAddComputation(dtype),
          dimensions_to_reduce=[dim])
      self._ExecuteAndCompareClose(c, expected=[np.sum(input_array, axis=dim)])

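    # The test below reduces a sample 4x2x3 array over every permutation of
    # its three dimensions and checks the result against np.sum over the same
    # axes.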
    @parameterized.named_parameters({
        "testcase_name": "_{}_dims[{}]".format(dtype.__name__, dims),
        "dtype": dtype,
        "dims": tuple(dims)
    } for dtype in float_dtypes for dims in itertools.permutations(range(3)))
    def testReduce3DAllPossibleWaysF32(self, dtype, dims):
      input_array = self._MakeSample3DArray(dtype)
      c = self._NewComputation()
      ops.Reduce(
          c,
          operands=[ops.Constant(c, input_array)],
          init_values=[ops.Constant(c, dtype(0))],
          computation=self._CreateBinaryAddComputation(dtype),
          dimensions_to_reduce=dims)
      self._ExecuteAndCompareClose(c, expected=[np.sum(input_array, axis=dims)])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testReduceWindowValidUnitStrides(self, dtype):
      if dtype == np.float64 and self.backend.platform == "tpu":
        self.skipTest("TPU doesn't support float64")
      input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype)
      c = self._NewComputation()
      window_dimensions = (2, 1)
      window_strides = (1, 1)
      padding = xla_client.window_padding_type_to_pad_values(
          xla_client.PaddingType.VALID, input_array.shape, window_dimensions,
          window_strides)
      ops.ReduceWindowWithGeneralPadding(
          operand=ops.Constant(c, input_array),
          init_value=ops.Constant(c, dtype(0)),
          computation=self._CreateBinaryAddComputation(dtype),
          window_dimensions=window_dimensions,
          window_strides=window_strides,
          base_dilations=[],
          window_dilations=[],
          padding=padding)
      self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.]]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testReduceWindowSameUnitStrides(self, dtype):
      if dtype == np.float64 and self.backend.platform == "tpu":
        self.skipTest("TPU doesn't support float64")
      input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype)
      c = self._NewComputation()
      window_dimensions = (2, 1)
      window_strides = (1, 1)
      padding = xla_client.window_padding_type_to_pad_values(
          xla_client.PaddingType.SAME, input_array.shape, window_dimensions,
          window_strides)
      ops.ReduceWindowWithGeneralPadding(
          operand=ops.Constant(c, input_array),
          init_value=ops.Constant(c, dtype(0)),
          computation=self._CreateBinaryAddComputation(dtype),
          window_dimensions=window_dimensions,
          window_strides=window_strides,
          base_dilations=[],
          window_dilations=[],
          padding=padding)
      self._ExecuteAndCompareClose(c, expected=[[[5., 7., 9.], [4., 5., 6.]]])

    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testReduceWindowValidGeneralStrides(self, dtype):
      if dtype == np.float64 and self.backend.platform == "tpu":
        self.skipTest("TPU doesn't support float64")
      input_array = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype)
      c = self._NewComputation()
      window_dimensions = (2, 1)
      window_strides = (1, 2)
      padding = xla_client.window_padding_type_to_pad_values(
          xla_client.PaddingType.VALID, input_array.shape, window_dimensions,
          window_strides)
      ops.ReduceWindowWithGeneralPadding(
          operand=ops.Constant(c, input_array),
          init_value=ops.Constant(c, dtype(0)),
          computation=self._CreateBinaryAddComputation(dtype),
          window_dimensions=window_dimensions,
          window_strides=window_strides,
          base_dilations=[],
          window_dilations=[],
          padding=padding)
      self._ExecuteAndCompareClose(c, expected=[[[5., 9.]]])

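    # The While test below repeatedly doubles its argument while x < 10 holds:
    # starting from 1, the loop visits 2, 4, and 8, then exits with 16.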
    @parameterized.named_parameters({
        "testcase_name": "_{}".format(dtype.__name__),
        "dtype": dtype,
    } for dtype in float_dtypes)
    def testWhile(self, dtype):

      def LessThan10Cond():
        c = self._NewComputation("test_lt_10")
        shape = xla_client.shape_from_pyval(np.array(0, dtype=dtype))
        ops.Lt(ops.Parameter(c, 0, shape), ops.Constant(c, dtype(10.)))
        return c.build()

      cond = LessThan10Cond()
      body = self._CreateMulBy2Computation(dtype)
      c = self._NewComputation()
      init = ops.Constant(c, dtype(1.))
      ops.While(cond, body, init)
      self._ExecuteAndCompareClose(c, expected=[16.])

    def testConditionalTrue(self):
      c = self._NewComputation()
      pred = ops.Constant(c, np.bool_(True))
      true_operand = ops.Constant(c, np.float32(3.))
      true_computation = self._CreateMulBy2Computation(np.float32)
      false_operand = ops.Constant(c, np.float32(2.))
      false_computation = self._CreateConstantComputation(
          np.float32, np.float32)
      ops.Conditional(pred, true_operand, true_computation, false_operand,
                      false_computation)
      self._ExecuteAndCompareClose(c, expected=[6.])

    def testConditionalFalse(self):
      c = self._NewComputation()
      pred = ops.Constant(c, np.bool_(False))
      true_operand = ops.Constant(c, np.float32(3.))
      true_computation = self._CreateMulBy2Computation(np.float32)
      false_operand = ops.Constant(c, np.float32(2.))
      false_computation = self._CreateConstantComputation(
          np.float32, np.float32)
      ops.Conditional(pred, true_operand, true_computation, false_operand,
                      false_computation)
      self._ExecuteAndCompareClose(c, expected=[1.])

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testInfeedS32Values(self):
      to_infeed = NumpyArrayS32([1, 2, 3, 4])
      c = self._NewComputation()
      ops.GetTupleElement(
          ops.InfeedWithToken(
              ops.CreateToken(c),
              xla_client.shape_from_pyval(
                  to_infeed[0]).with_major_to_minor_layout_if_absent()), 0)
      compiled_c = self.backend.compile(c.build())
      device = self.backend.local_devices()[0]
      for item in to_infeed:
        device.transfer_to_infeed(item)

      for item in to_infeed:
        result, = xla_client.execute_with_python_values(
            compiled_c, (), backend=self.backend)
        self.assertEqual(result, item)

    @unittest.skipIf(cloud_tpu, "not implemented")
    def testInfeedTuple(self):
      to_infeed = (NumpyArrayS32([1, 2, 3, 4]), NumpyArrayS32([[7], [8]]))
      c = self._NewComputation()
      ops.GetTupleElement(
          ops.InfeedWithToken(
              ops.CreateToken(c),
              xla_client.shape_from_pyval(
                  to_infeed).with_major_to_minor_layout_if_absent()), 0)
      compiled_c = self.backend.compile(c.build())
      device = self.backend.local_devices()[0]
      device.transfer_to_infeed(to_infeed)

      result = xla_client.execute_with_python_values(
          compiled_c, (), backend=self.backend)
      self.assertLen(result, 2)
      np.testing.assert_equal(result[0], to_infeed[0])
      np.testing.assert_equal(result[1], to_infeed[1])

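    # The round-trip test below runs the compiled computation on a separate
    # thread: the host must service the infeed and outfeed transfers while the
    # executable is running, so doing both on a single thread would deadlock.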
    @unittest.skipIf(cloud_tpu, "not implemented")
    def testInfeedThenOutfeedS32(self):
      to_round_trip = NumpyArrayS32([1, 2, 3, 4])
      c = self._NewComputation()
      x_and_token = ops.InfeedWithToken(
          ops.CreateToken(c),
          xla_client.shape_from_pyval(
              to_round_trip[0]).with_major_to_minor_layout_if_absent())
      x = ops.GetTupleElement(x_and_token, 0)
      token = ops.GetTupleElement(x_and_token, 1)
      outfeed_shape = xla_client.shape_from_pyval(
          to_round_trip[0]).with_major_to_minor_layout_if_absent()
      ops.OutfeedWithToken(x, token, outfeed_shape)

      compiled_c = self.backend.compile(c.build())
      device = self.backend.local_devices()[0]

      for want in to_round_trip:
        execution = threading.Thread(target=lambda: compiled_c.execute([]))
        execution.start()
        device.transfer_to_infeed(want)
        got = device.transfer_from_outfeed(outfeed_shape)
        execution.join()
        self.assertEqual(want, got)

    def testScatter(self):
      a = np.arange(9).astype(np.int32).reshape((3, 3))
      scatter_indices = np.array([0, 2], dtype=np.int32)
      updates = np.array([[10, 20, 30], [70, 80, 90]], dtype=np.int32)

      dnums = xla_client.ScatterDimensionNumbers()
      dnums.update_window_dims.append(1)
      dnums.inserted_window_dims.append(0)
      dnums.scatter_dims_to_operand_dims.append(0)
      dnums.index_vector_dim = 1

      c = self._NewComputation()
      ops.Scatter(
          ops.Constant(c, a), ops.Constant(c, scatter_indices),
          ops.Constant(c, updates), self._CreateBinaryAddComputation(np.int32),
          dnums)
      expected = np.array([[10, 21, 32], [3, 4, 5], [76, 87, 98]],
                          dtype=np.int32)
      self._ExecuteAndCompareClose(c, expected=[expected])

  class DeviceTest(ComputationTest):

    def testPlatform(self):
      for device in self.backend.local_devices():
        self.assertEqual(device.platform, self.backend.platform)

  tests.append(DeviceTest)

  class ErrorTest(ComputationTest):

    def setUp(self):
      super(ErrorTest, self).setUp()
      self.f32_scalar_2 = NumpyArrayF32(2.0)
      self.s32_scalar_2 = NumpyArrayS32(2)

    def testCompileWithWrongElementTypeInLayout(self):
      c = self._NewComputation()
      c.set_op_metadata(xla_client.CurrentSourceInfoMetadata())
      ops.Parameter(c, 0, xla_client.shape_from_pyval(self.s32_scalar_2))
      c.clear_op_metadata()

      options = xla_client.CompileOptions()
      options.argument_layouts = [
          xla_client.Shape.array_shape(np.dtype(np.float32), [])
      ]

      def TestFun():
        return self.backend.compile(c.build(), compile_options=options)

      self.assertRaisesRegex(
          RuntimeError, r".*Invalid argument shape.*"
          r"expected s32\[\], got f32\[\].*", TestFun)

    def testInvokeWithWrongElementType(self):
      c = self._NewComputation()
      c.set_op_metadata(xla_client.CurrentSourceInfoMetadata())
      ops.Parameter(c, 0, xla_client.shape_from_pyval(self.s32_scalar_2))
      c.clear_op_metadata()

      def TestFun():
        return xla_client.execute_with_python_values(
            self.backend.compile(c.build()), [self.f32_scalar_2], self.backend)

      self.assertRaisesRegex(
          RuntimeError, r"Invalid argument: Argument does not match.*"
          r"want s32\[\], got f32\[\].*", TestFun)

  tests.append(ErrorTest)
  tests.append(EmbeddedComputationsTest)

  class ComputationRootTest(ComputationTest):
    """Tests related to setting the root of the computation."""

    def testComputationRootDifferentFromLastOp(self):
      c = self._NewComputation()
      x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(2.0)))
      result = ops.Add(x, ops.Constant(c, np.float32(3.14)))
      ops.Add(result, ops.Constant(c, np.float32(1.618)))

      arg = NumpyArrayF32(1.0)
      compiled_c = self.backend.compile(c.build(result))
      ans, = xla_client.execute_with_python_values(
          compiled_c, [arg], backend=self.backend)
      np.testing.assert_allclose(ans, 4.14)

  tests.append(ComputationRootTest)

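  # SetShardingTest attaches an explicit OpSharding annotation to one op and
  # checks that the numerical result is unchanged; the sharding used below
  # simply replicates across a single device.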
  class SetShardingTest(ComputationTest):
    """Tests related to setting OpSharding."""

    def testSetSharding(self):
      c = self._NewComputation()
      sharding = xla_client.OpSharding()
      sharding.type = sharding.type.REPLICATED
      sharding.tile_assignment_dimensions.extend([1])
      sharding.tile_assignment_devices.extend([0])
      c.set_sharding(sharding)
      x = ops.Parameter(c, 0, xla_client.shape_from_pyval(NumpyArrayF32(2.0)))
      c.clear_sharding()

      result = ops.Add(x, ops.Constant(c, np.float32(3.14)))
      ops.Add(result, ops.Constant(c, np.float32(1.618)))
      arg = NumpyArrayF32(1.0)
      compiled_c = self.backend.compile(c.build(result))
      ans, = xla_client.execute_with_python_values(
          compiled_c, [arg], backend=self.backend)
      np.testing.assert_allclose(ans, 4.14)

  tests.append(SetShardingTest)

  testcase_shapes = [
      (),
      (1,),
      (2, 3),
      (2, 0),
      (0, 7),
      (4, 1, 2),
      (2, 1, 3),
      (2, 4, 1),
      (3, 1),
      (1, 3),
  ]

  def FormatShapeAndDtype(shape, dtype):
    return "_{}[{}]".format(np.dtype(dtype).name, ",".join(map(str, shape)))

  class DLPackTest(parameterized.TestCase):

    def setUp(self):
      super(DLPackTest, self).setUp()
      self.backend = xla_backend()
      if self.backend.platform not in ("cpu", "gpu"):
        self.skipTest("DLPack requires CPU or GPU")

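    # take_ownership=True hands the device buffer over to the DLPack capsule
    # (the XLA buffer is deleted and can be exported only once), while
    # take_ownership=False produces a non-owning view that can be taken
    # repeatedly; both behaviors are exercised by the tests below.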
    # pylint: disable=g-complex-comprehension
    # pyformat: disable
    @parameterized.named_parameters({
        "testcase_name": "{}_own={}".format(FormatShapeAndDtype(shape, dtype),
                                            take_ownership),
        "dtype": dtype,
        "shape": shape,
        "take_ownership": take_ownership
    } for dtype in dlpack_dtypes for shape in testcase_shapes
      for take_ownership in [False, True])
    # pyformat: enable
    def testRoundTrip(self, dtype, shape, take_ownership):
      if dtype == np.bool_:
        x = np.random.randint(0, 2, size=shape).astype(np.bool_)
      else:
        x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
      buffer = self.backend.buffer_from_pyval(x)
      dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(
          buffer, take_ownership=take_ownership)
      del buffer  # Free "buffer" to make sure dlt retains ownership.
      self.assertEqual(type(dlt).__name__, "PyCapsule")
      y = xla_client._xla.dlpack_managed_tensor_to_buffer(dlt, self.backend)
      np.testing.assert_array_equal(
          x.astype(np.uint8) if dtype == np.bool_ else x, y.to_py())

    def testTensorsCanBeConsumedOnceOnly(self):
      x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32)
      buffer = self.backend.buffer_from_pyval(x)
      dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(
          buffer, take_ownership=True)

      def ConsumeDLPackTensor():
        _ = xla_client._xla.dlpack_managed_tensor_to_buffer(dlt, self.backend)

      ConsumeDLPackTensor()
      self.assertRaisesRegex(
          RuntimeError, ".*a DLPack tensor may be consumed at most once.*",
          ConsumeDLPackTensor)

    def testTensorsCanBeOwnedOnceOnly(self):
      x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32)
      buffer = self.backend.buffer_from_pyval(x)
      _ = xla_client._xla.buffer_to_dlpack_managed_tensor(
          buffer, take_ownership=True)
      self.assertTrue(buffer.is_deleted())
      with self.assertRaisesRegex(
          RuntimeError,
          "Cannot convert deleted/invalid buffer to DLPack tensor.*"):
        _ = xla_client._xla.buffer_to_dlpack_managed_tensor(
            buffer, take_ownership=True)

    def testNonOwnedDlpackCanBeViewedTwice(self):
      x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32)
      buffer = self.backend.buffer_from_pyval(x)
      d1 = xla_client._xla.buffer_to_dlpack_managed_tensor(
          buffer, take_ownership=False)
      d2 = xla_client._xla.buffer_to_dlpack_managed_tensor(
          buffer, take_ownership=False)

      y = xla_client._xla.dlpack_managed_tensor_to_buffer(d1, self.backend)
      z = xla_client._xla.dlpack_managed_tensor_to_buffer(d2, self.backend)
      del d1, d2
      np.testing.assert_array_equal(x, buffer.to_py())
      np.testing.assert_array_equal(x, y.to_py())
      np.testing.assert_array_equal(x, z.to_py())

  tests.append(DLPackTest)

  class BufferProtocolTest(parameterized.TestCase):

    def setUp(self):
      super(BufferProtocolTest, self).setUp()
      self.backend = xla_backend()
      if self.backend.platform != "cpu":
        self.skipTest("Test requires CPU")

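    # With HostBufferSemantics.ZERO_COPY the backend may alias the NumPy
    # array's memory rather than copying it, provided the array is
    # sufficiently aligned (16 bytes here); the round-trip test below checks
    # the pointers to verify this.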
    # pylint: disable=g-complex-comprehension
    @parameterized.named_parameters({
        "testcase_name": FormatShapeAndDtype(shape, dtype),
        "dtype": dtype,
        "shape": shape
    } for dtype in standard_dtypes if dtype != bfloat16
      for shape in testcase_shapes)
    def testRoundTrip(self, dtype, shape):
      x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
      x_ptr = x.__array_interface__["data"][0]
      buffer = self.backend.buffer_from_pyval(
          x, host_buffer_semantics=xla_client.HostBufferSemantics.ZERO_COPY)
      y = np.array(buffer, copy=False)
      y_ptr = y.__array_interface__["data"][0]
      np.testing.assert_array_equal(x, y)
      # If the input was sufficiently aligned, the input and output should
      # alias.
      self.assertTrue((x_ptr & 15) != 0 or x_ptr == y_ptr)
      self.assertEqual(y_ptr, buffer.unsafe_buffer_pointer())

      during_call = xla_client.HostBufferSemantics.IMMUTABLE_ONLY_DURING_CALL
      buffer2 = self.backend.buffer_from_pyval(
          x, host_buffer_semantics=during_call)
      z = np.array(buffer2, copy=False)
      self.assertNotEqual(x.__array_interface__["data"][0],
                          z.__array_interface__["data"][0])

    def testDeleteWithActiveView(self):
      x = np.random.randn(20, 10)
      buffer = self.backend.buffer_from_pyval(x)
      buffer_ptr = buffer.unsafe_buffer_pointer()
      y = np.array(buffer, copy=False)
      buffer.delete()
      # It is still legal to access `y`; the array view must keep it alive.
      np.testing.assert_array_equal(x, y)
      self.assertEqual(y.__array_interface__["data"][0], buffer_ptr)

  tests.append(BufferProtocolTest)

  class TracebackTest(absltest.TestCase):

    def setUp(self):
      super(TracebackTest, self).setUp()
      self.backend = xla_backend()

    def testNoTracebacksIfDisabled(self):
      with xla_client.tracebacks(enabled=False):
        self.assertEqual(None, xla_client.Traceback.get_traceback())
        buffer = self.backend.buffer_from_pyval(np.array(7, np.int32))
        self.assertEqual(None, buffer.traceback)

        b = xla_client.XlaBuilder("computation")
        ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2)))
        e = self.backend.compile(b.build())
        self.assertEqual(None, e.traceback)

    def assertIsTracebackContaining(self, tb, function):
      self.assertIsInstance(tb, xla_client.Traceback)
      self.assertIn(function, str(tb))
      self.assertTrue(any(f.function_name == function for f in tb.frames))

    def testTracebacks(self):
      with xla_client.tracebacks(enabled=True):
        tb = xla_client.Traceback.get_traceback()
        self.assertIsTracebackContaining(tb, "testTracebacks")

        # Tracebacks are not implemented on the TPU driver extension's variant
        # of buffers and executables.
        if not isinstance(self.backend, xla_client.Client):
          return

        buffer = self.backend.buffer_from_pyval(np.array(7, np.int32))
        self.assertIsTracebackContaining(buffer.traceback, "testTracebacks")

        b = xla_client.XlaBuilder("computation")
        ops.Add(ops.Constant(b, np.int32(1)), ops.Constant(b, np.int32(2)))
        e = self.backend.compile(b.build())
        self.assertIsTracebackContaining(e.traceback, "testTracebacks")

    def testNestedFunction(self):

      def AFunction():

        def AnotherFunction():
          return xla_client.Traceback.get_traceback()

        return AnotherFunction()

      with xla_client.tracebacks(enabled=True):
        tb = AFunction()
        self.assertIsInstance(tb, xla_client.Traceback)
        frames = tb.frames
        i = next(
            i for (i, f) in enumerate(frames) if f.function_name == "AFunction")
        self.assertEqual(frames[i - 1].function_name, "AnotherFunction")
        self.assertEqual(frames[i + 1].function_name, "testNestedFunction")

  tests.append(TracebackTest)

  return tests


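# TestFactory returns test classes rather than instantiating them, so each
# backend-specific test module can register them under its own names via
# InstantiateTests below. For example, a hypothetical CPU-only runner might
# call, at import time:
#
#   InstantiateTests(globals(), lambda: xla_client.get_local_backend("cpu"),
#                    test_prefix="Cpu")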
def InstantiateTests(globals_dict, backend_fn, test_prefix="", **kw):
  # Avoid creating a new backend per test (this causes GPU OOM, and is probably
  # inefficient).
  backend_fn = functools.lru_cache(maxsize=None)(backend_fn)
  for klass in TestFactory(backend_fn, **kw):
    test = type(test_prefix + klass.__name__, (klass,), {})
    # Clean up the qualified names of the tests to not include the test factory.
    test.__qualname__ = test.__name__
    globals_dict[test.__name__] = test


if __name__ == "__main__":
  flags.DEFINE_string("backend", "cpu", "Target backend.")
  InstantiateTests(globals(),
                   lambda: xla_client.get_local_backend(FLAGS.backend))
  absltest.main()