1# Lint as: python3 2# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""Tests for the Python extension-based XLA client.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import functools 23import itertools 24import threading 25 26from absl.testing import absltest 27from absl.testing import parameterized 28import numpy as np 29 30from tensorflow.compiler.xla.python import custom_call_for_test 31from tensorflow.compiler.xla.python import xla_client 32 33bfloat16 = xla_client.bfloat16 34 35 36class ComputationTest(absltest.TestCase): 37 """Base class for running an XLA Computation through the local client.""" 38 39 def _NewComputation(self, name=None): 40 if name is None: 41 name = self.id() 42 return xla_client.ComputationBuilder(name) 43 44 def _Execute(self, c, arguments): 45 compiled_c = c.Build().Compile() 46 return xla_client.execute_with_python_values(compiled_c, arguments) 47 48 def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected): 49 assert expected is not None 50 result = self._Execute(c, arguments) 51 # Numpy's comparison methods are a bit too lenient by treating inputs as 52 # "array-like", meaning that scalar 4 will be happily compared equal to 53 # [[4]]. We'd like to be more strict so assert shapes as well. 54 self.assertEqual(np.asanyarray(result).shape, np.asanyarray(expected).shape) 55 assert_func(result, expected) 56 57 def _ExecuteAndCompareExact(self, c, arguments=(), expected=None): 58 self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, expected) 59 60 def _ExecuteAndCompareClose(self, 61 c, 62 arguments=(), 63 expected=None, 64 rtol=1e-7, 65 atol=0): 66 self._ExecuteAndAssertWith( 67 functools.partial(np.testing.assert_allclose, rtol=rtol, atol=atol), c, 68 arguments, expected) 69 70 71def NumpyArrayF32(*args, **kwargs): 72 """Convenience wrapper to create Numpy arrays with a np.float32 dtype.""" 73 return np.array(*args, dtype=np.float32, **kwargs) 74 75 76def NumpyArrayF64(*args, **kwargs): 77 """Convenience wrapper to create Numpy arrays with a np.float64 dtype.""" 78 return np.array(*args, dtype=np.float64, **kwargs) 79 80 81def NumpyArrayS32(*args, **kwargs): 82 """Convenience wrapper to create Numpy arrays with a np.int32 dtype.""" 83 return np.array(*args, dtype=np.int32, **kwargs) 84 85 86def NumpyArrayS64(*args, **kwargs): 87 """Convenience wrapper to create Numpy arrays with a np.int64 dtype.""" 88 return np.array(*args, dtype=np.int64, **kwargs) 89 90 91def NumpyArrayBool(*args, **kwargs): 92 """Convenience wrapper to create Numpy arrays with a np.bool dtype.""" 93 return np.array(*args, dtype=np.bool, **kwargs) 94 95 96class ComputationPrinting(absltest.TestCase): 97 98 def ExampleComputation(self): 99 builder = xla_client.ComputationBuilder("acomputation") 100 p0 = builder.ParameterFromNumpy(np.float32(0)) 101 p1 = builder.ParameterFromNumpy(np.zeros((4,), np.float32)) 102 builder.Mul(p0, p1) 103 return builder.Build() 104 105 def testComputationToHloText(self): 106 computation = self.ExampleComputation() 107 hlo_text = computation.GetHloText() 108 self.assertTrue(hlo_text.startswith("HloModule acomputation")) 109 110 def testComputationToHloGraph(self): 111 computation = self.ExampleComputation() 112 hlo_dot_graph = computation.GetHloDotGraph() 113 self.assertTrue(hlo_dot_graph.startswith("digraph ")) 114 115 116class ComputationHashTest(absltest.TestCase): 117 118 def testHash(self): 119 builder0 = xla_client.ComputationBuilder("computation0") 120 p0 = builder0.ParameterFromNumpy(np.float32(0)) 121 p1 = builder0.ParameterFromNumpy(np.zeros((4,), np.float32)) 122 builder0.Mul(p0, p1) 123 computation0 = builder0.Build() 124 125 builder1 = xla_client.ComputationBuilder("computation1") 126 p0 = builder1.ParameterFromNumpy(np.float32(0)) 127 p1 = builder1.ParameterFromNumpy(np.zeros((4,), np.float32)) 128 builder1.Mul(p0, p1) 129 computation1 = builder1.Build() 130 131 self.assertEqual(computation0.Hash(), computation1.Hash()) 132 133 134class ComputationsWithConstantsTest(ComputationTest): 135 """Tests focusing on Constant ops.""" 136 137 def testConstantScalarSumS8(self): 138 c = self._NewComputation() 139 c.Add(c.Constant(np.int8(1)), c.Constant(np.int8(2))) 140 self._ExecuteAndCompareExact(c, expected=np.int8(3)) 141 142 def testConstantScalarSumBF16(self): 143 c = self._NewComputation() 144 c.Add(c.Constant(bfloat16(1.11)), c.Constant(bfloat16(3.14))) 145 self._ExecuteAndCompareClose(c, expected=bfloat16(4.25)) 146 147 def testConstantScalarSumF32(self): 148 c = self._NewComputation() 149 c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14)) 150 self._ExecuteAndCompareClose(c, expected=4.25) 151 152 def testConstantScalarSumF64(self): 153 c = self._NewComputation() 154 c.Add(c.ConstantF64Scalar(1.11), c.ConstantF64Scalar(3.14)) 155 self._ExecuteAndCompareClose(c, expected=4.25) 156 157 def testConstantScalarSumS32(self): 158 c = self._NewComputation() 159 c.Add(c.ConstantS32Scalar(1), c.ConstantS32Scalar(2)) 160 self._ExecuteAndCompareClose(c, expected=3) 161 162 def testConstantScalarSumS64(self): 163 c = self._NewComputation() 164 c.Add(c.ConstantS64Scalar(1), c.ConstantS64Scalar(2)) 165 self._ExecuteAndCompareClose(c, expected=3) 166 167 def testConstantVectorMulF16(self): 168 c = self._NewComputation() 169 c.Mul( 170 c.Constant(np.array([2.5, 3.3, -1.2, 0.7], np.float16)), 171 c.Constant(np.array([-1.2, 2, -2, -3], np.float16))) 172 self._ExecuteAndCompareClose( 173 c, expected=np.array([-3, 6.6, 2.4, -2.1], np.float16), rtol=2e-3) 174 175 def testConstantVectorMulF32(self): 176 c = self._NewComputation() 177 c.Mul( 178 c.Constant(NumpyArrayF32([2.5, 3.3, -1.2, 0.7])), 179 c.Constant(NumpyArrayF32([-1.2, 2, -2, -3]))) 180 self._ExecuteAndCompareClose(c, expected=[-3, 6.6, 2.4, -2.1]) 181 182 def testConstantVectorMulF64(self): 183 c = self._NewComputation() 184 c.Mul( 185 c.Constant(NumpyArrayF64([2.5, 3.3, -1.2, 0.7])), 186 c.Constant(NumpyArrayF64([-1.2, 2, -2, -3]))) 187 self._ExecuteAndCompareClose(c, expected=[-3, 6.6, 2.4, -2.1]) 188 189 def testConstantVectorScalarDivF32(self): 190 c = self._NewComputation() 191 c.Div( 192 c.Constant(NumpyArrayF32([1.5, 2.5, 3.0, -10.8])), 193 c.ConstantF32Scalar(2.0)) 194 self._ExecuteAndCompareClose(c, expected=[0.75, 1.25, 1.5, -5.4]) 195 196 def testConstantVectorScalarDivF64(self): 197 c = self._NewComputation() 198 c.Div( 199 c.Constant(NumpyArrayF64([1.5, 2.5, 3.0, -10.8])), 200 c.ConstantF64Scalar(2.0)) 201 self._ExecuteAndCompareClose(c, expected=[0.75, 1.25, 1.5, -5.4]) 202 203 def testConstantVectorScalarPowF32(self): 204 c = self._NewComputation() 205 c.Pow(c.Constant(NumpyArrayF32([1.5, 2.5, 3.0])), c.ConstantF32Scalar(2.)) 206 self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.]) 207 208 def testConstantVectorScalarPowF64(self): 209 c = self._NewComputation() 210 c.Pow(c.Constant(NumpyArrayF64([1.5, 2.5, 3.0])), c.ConstantF64Scalar(2.)) 211 self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.]) 212 213 def testIota(self): 214 c = self._NewComputation() 215 c.Iota(np.float32, 10) 216 self._ExecuteAndCompareExact(c, expected=np.arange(10, dtype=np.float32)) 217 218 def testBroadcastedIota(self): 219 c = self._NewComputation() 220 c.BroadcastedIota(np.int64, (2, 3), 1) 221 expected = np.array([[0, 1, 2], [0, 1, 2]], dtype=np.int64) 222 self._ExecuteAndCompareExact(c, expected=expected) 223 224 def testBooleanAnd(self): 225 c = self._NewComputation() 226 c.And( 227 c.Constant(NumpyArrayBool([True, False, True, False])), 228 c.Constant(NumpyArrayBool([True, True, False, False]))) 229 self._ExecuteAndCompareExact(c, expected=[True, False, False, False]) 230 231 def testBooleanOr(self): 232 c = self._NewComputation() 233 c.Or( 234 c.Constant(NumpyArrayBool([True, False, True, False])), 235 c.Constant(NumpyArrayBool([True, True, False, False]))) 236 self._ExecuteAndCompareExact(c, expected=[True, True, True, False]) 237 238 def testBooleanXor(self): 239 c = self._NewComputation() 240 c.Xor( 241 c.Constant(NumpyArrayBool([True, False, True, False])), 242 c.Constant(NumpyArrayBool([True, True, False, False]))) 243 self._ExecuteAndCompareExact(c, expected=[False, True, True, False]) 244 245 def testSum2DF32(self): 246 c = self._NewComputation() 247 c.Add( 248 c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6]])), 249 c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]]))) 250 self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]]) 251 252 def testShiftLeft(self): 253 c = self._NewComputation() 254 c.ShiftLeft(c.Constant(NumpyArrayS32([3])), c.Constant(NumpyArrayS32([2]))) 255 self._ExecuteAndCompareClose(c, expected=[12]) 256 257 def testShiftRightArithmetic(self): 258 c = self._NewComputation() 259 c.ShiftRightArithmetic( 260 c.Constant(NumpyArrayS32([-2])), c.Constant(NumpyArrayS32([1]))) 261 self._ExecuteAndCompareClose(c, expected=[-1]) 262 263 def testShiftRightLogical(self): 264 c = self._NewComputation() 265 c.ShiftRightLogical( 266 c.Constant(NumpyArrayS32([-1])), c.Constant(NumpyArrayS32([1]))) 267 self._ExecuteAndCompareClose(c, expected=[2**31 - 1]) 268 269 def testSum2DF64(self): 270 c = self._NewComputation() 271 c.Add( 272 c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6]])), 273 c.Constant(NumpyArrayF64([[1, -1, 1], [-1, 1, -1]]))) 274 self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]]) 275 276 def testSum2DWith1DBroadcastDim0F32(self): 277 # sum of a 2D array with a 1D array where the latter is replicated across 278 # dimension 0 to match the former's shape. 279 c = self._NewComputation() 280 c.Add( 281 c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 282 c.Constant(NumpyArrayF32([10, 20, 30])), 283 broadcast_dimensions=(0,)) 284 self._ExecuteAndCompareClose( 285 c, expected=[[11, 12, 13], [24, 25, 26], [37, 38, 39]]) 286 287 def testSum2DWith1DBroadcastDim0F64(self): 288 # sum of a 2D array with a 1D array where the latter is replicated across 289 # dimension 0 to match the former's shape. 290 c = self._NewComputation() 291 c.Add( 292 c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 293 c.Constant(NumpyArrayF64([10, 20, 30])), 294 broadcast_dimensions=(0,)) 295 self._ExecuteAndCompareClose( 296 c, expected=[[11, 12, 13], [24, 25, 26], [37, 38, 39]]) 297 298 def testSum2DWith1DBroadcastDim1F32(self): 299 # sum of a 2D array with a 1D array where the latter is replicated across 300 # dimension 1 to match the former's shape. 301 c = self._NewComputation() 302 c.Add( 303 c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 304 c.Constant(NumpyArrayF32([10, 20, 30])), 305 broadcast_dimensions=(1,)) 306 self._ExecuteAndCompareClose( 307 c, expected=[[11, 22, 33], [14, 25, 36], [17, 28, 39]]) 308 309 def testSum2DWith1DBroadcastDim1F64(self): 310 # sum of a 2D array with a 1D array where the latter is replicated across 311 # dimension 1 to match the former's shape. 312 c = self._NewComputation() 313 c.Add( 314 c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 315 c.Constant(NumpyArrayF64([10, 20, 30])), 316 broadcast_dimensions=(1,)) 317 self._ExecuteAndCompareClose( 318 c, expected=[[11, 22, 33], [14, 25, 36], [17, 28, 39]]) 319 320 def testConstantAxpyF32(self): 321 c = self._NewComputation() 322 c.Add( 323 c.Mul( 324 c.ConstantF32Scalar(2), 325 c.Constant(NumpyArrayF32([2.2, 3.3, 4.4, 5.5]))), 326 c.Constant(NumpyArrayF32([100, -100, 200, -200]))) 327 self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189]) 328 329 def testConstantAxpyF64(self): 330 c = self._NewComputation() 331 c.Add( 332 c.Mul( 333 c.ConstantF64Scalar(2), 334 c.Constant(NumpyArrayF64([2.2, 3.3, 4.4, 5.5]))), 335 c.Constant(NumpyArrayF64([100, -100, 200, -200]))) 336 self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189]) 337 338 def testCustomCall(self): 339 c = self._NewComputation() 340 for name, fn in custom_call_for_test.cpu_custom_call_targets.items(): 341 xla_client.register_custom_call_target(name, fn, platform="cpu") 342 c.CustomCall( 343 b"test_subtract_f32", 344 operands=(c.ConstantF32Scalar(1.25), c.ConstantF32Scalar(0.5)), 345 shape_with_layout=xla_client.Shape.array_shape( 346 np.dtype(np.float32), (), ()), 347 operand_shapes_with_layout=( 348 xla_client.Shape.array_shape(np.dtype(np.float32), (), ()), 349 xla_client.Shape.array_shape(np.dtype(np.float32), (), ()), 350 )) 351 self._ExecuteAndCompareClose(c, expected=0.75) 352 353 354class ParametersTest(ComputationTest): 355 """Tests focusing on Parameter ops and argument-passing.""" 356 357 def setUp(self): 358 self.f32_scalar_2 = NumpyArrayF32(2.0) 359 self.f32_4vector = NumpyArrayF32([-2.3, 3.3, -4.3, 5.3]) 360 self.f64_scalar_2 = NumpyArrayF64(2.0) 361 self.f64_4vector = NumpyArrayF64([-2.3, 3.3, -4.3, 5.3]) 362 self.s32_scalar_3 = NumpyArrayS32(3) 363 self.s32_4vector = NumpyArrayS32([10, 15, -2, 7]) 364 self.s64_scalar_3 = NumpyArrayS64(3) 365 self.s64_4vector = NumpyArrayS64([10, 15, -2, 7]) 366 367 def testScalarTimesVectorAutonumberF32(self): 368 c = self._NewComputation() 369 p0 = c.ParameterFromNumpy(self.f32_scalar_2) 370 p1 = c.ParameterFromNumpy(self.f32_4vector) 371 c.Mul(p0, p1) 372 self._ExecuteAndCompareClose( 373 c, 374 arguments=[self.f32_scalar_2, self.f32_4vector], 375 expected=[-4.6, 6.6, -8.6, 10.6]) 376 377 def testScalarTimesVectorAutonumberF64(self): 378 c = self._NewComputation() 379 p0 = c.ParameterFromNumpy(self.f64_scalar_2) 380 p1 = c.ParameterFromNumpy(self.f64_4vector) 381 c.Mul(p0, p1) 382 self._ExecuteAndCompareClose( 383 c, 384 arguments=[self.f64_scalar_2, self.f64_4vector], 385 expected=[-4.6, 6.6, -8.6, 10.6]) 386 387 def testScalarTimesVectorS32(self): 388 c = self._NewComputation() 389 p0 = c.ParameterFromNumpy(self.s32_scalar_3) 390 p1 = c.ParameterFromNumpy(self.s32_4vector) 391 c.Mul(p0, p1) 392 self._ExecuteAndCompareExact( 393 c, 394 arguments=[self.s32_scalar_3, self.s32_4vector], 395 expected=[30, 45, -6, 21]) 396 397 def testScalarTimesVectorS64(self): 398 c = self._NewComputation() 399 p0 = c.ParameterFromNumpy(self.s64_scalar_3) 400 p1 = c.ParameterFromNumpy(self.s64_4vector) 401 c.Mul(p0, p1) 402 self._ExecuteAndCompareExact( 403 c, 404 arguments=[self.s64_scalar_3, self.s64_4vector], 405 expected=[30, 45, -6, 21]) 406 407 def testScalarMinusVectorExplicitNumberingF32(self): 408 # Use explicit numbering and pass parameter_num first. Sub is used since 409 # it's not commutative and can help catch parameter reversal within the 410 # computation. 411 c = self._NewComputation() 412 p1 = c.ParameterFromNumpy(self.f32_4vector, parameter_num=1) 413 p0 = c.ParameterFromNumpy(self.f32_scalar_2, parameter_num=0) 414 c.Sub(p1, p0) 415 self._ExecuteAndCompareClose( 416 c, 417 arguments=[self.f32_scalar_2, self.f32_4vector], 418 expected=[-4.3, 1.3, -6.3, 3.3]) 419 420 def testScalarMinusVectorExplicitNumberingF64(self): 421 # Use explicit numbering and pass parameter_num first. Sub is used since 422 # it's not commutative and can help catch parameter reversal within the 423 # computation. 424 c = self._NewComputation() 425 p1 = c.ParameterFromNumpy(self.f64_4vector, parameter_num=1) 426 p0 = c.ParameterFromNumpy(self.f64_scalar_2, parameter_num=0) 427 c.Sub(p1, p0) 428 self._ExecuteAndCompareClose( 429 c, 430 arguments=[self.f64_scalar_2, self.f64_4vector], 431 expected=[-4.3, 1.3, -6.3, 3.3]) 432 433 434class BufferTest(ComputationTest): 435 """Tests focusing on execution with Buffers.""" 436 437 def _Execute(self, c, arguments): 438 compiled_c = c.Build().Compile() 439 arg_buffers = [xla_client.Buffer.from_pyval(arg) for arg in arguments] 440 result_buffer = compiled_c.Execute(arg_buffers) 441 return result_buffer.to_py() 442 443 def testConstantSum(self): 444 c = self._NewComputation() 445 c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14)) 446 self._ExecuteAndCompareClose(c, expected=4.25) 447 448 def testOneParameterSum(self): 449 c = self._NewComputation() 450 c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), c.ConstantF32Scalar(3.14)) 451 self._ExecuteAndCompareClose( 452 c, arguments=[NumpyArrayF32(1.11)], expected=4.25) 453 454 def testTwoParameterSum(self): 455 c = self._NewComputation() 456 c.Add( 457 c.ParameterFromNumpy(NumpyArrayF32(0.)), 458 c.ParameterFromNumpy(NumpyArrayF32(0.))) 459 self._ExecuteAndCompareClose( 460 c, arguments=[NumpyArrayF32(1.11), 461 NumpyArrayF32(3.14)], expected=4.25) 462 463 def testCannotCallWithDeletedBuffers(self): 464 c = self._NewComputation() 465 c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), c.ConstantF32Scalar(3.14)) 466 arg = NumpyArrayF32(1.11) 467 compiled_c = c.Build().Compile() 468 arg_buffer = xla_client.Buffer.from_pyval(arg) 469 arg_buffer.delete() 470 with self.assertRaises(RuntimeError): 471 compiled_c.Execute([arg_buffer]) 472 473 def testDestructureTupleEmpty(self): 474 device = xla_client.get_local_backend().devices()[0] 475 local_buffer = xla_client.Buffer.make_tuple((), device=device) 476 pieces = local_buffer.destructure() 477 self.assertFalse(local_buffer.is_deleted()) 478 self.assertEmpty(pieces) 479 480 def testDestructureTupleOneArrayElement(self): 481 device = xla_client.get_local_backend().devices()[0] 482 t = xla_client.Buffer.from_pyval(np.array([1, 2, 3, 4], dtype=np.int32)) 483 local_buffer = xla_client.Buffer.make_tuple((t,), device) 484 pieces = local_buffer.destructure() 485 self.assertFalse(local_buffer.is_deleted()) 486 self.assertLen(pieces, 1) 487 array = pieces[0] 488 got = array.to_py() 489 want = NumpyArrayS32([1, 2, 3, 4]) 490 np.testing.assert_equal(want, got) 491 492 def testDestructureTupleTwoArrayElementDifferentType(self): 493 device = xla_client.get_local_backend().devices()[0] 494 t = ( 495 xla_client.Buffer.from_pyval( 496 np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)), 497 xla_client.Buffer.from_pyval(np.array([2, 3, 4, 5], dtype=np.int32)), 498 ) 499 local_buffer = xla_client.Buffer.make_tuple(t, device) 500 # Run the test twice to verify that the original tuple buffer remains valid 501 # even after destructuring. 502 for _ in range(2): 503 pieces = local_buffer.destructure() 504 self.assertFalse(local_buffer.is_deleted()) 505 self.assertLen(pieces, 2) 506 array0, array1 = pieces 507 got = array0.to_py() 508 want = NumpyArrayF32([1.0, 2.0, 3.0, 4.0]) 509 np.testing.assert_equal(want, got) 510 got = array1.to_py() 511 want = NumpyArrayS32([2, 3, 4, 5]) 512 np.testing.assert_equal(want, got) 513 514 def testDestructureTupleNested(self): 515 device = xla_client.get_local_backend().devices()[0] 516 t = xla_client.Buffer.make_tuple( 517 (xla_client.Buffer.from_pyval(NumpyArrayF32([1.0, 2.0])), 518 xla_client.Buffer.from_pyval(NumpyArrayS32([3, 4]))), device) 519 local_buffer = xla_client.Buffer.make_tuple( 520 (t, xla_client.Buffer.from_pyval(NumpyArrayS32([5]))), device) 521 pieces = local_buffer.destructure() 522 self.assertFalse(local_buffer.is_deleted()) 523 self.assertLen(pieces, 2) 524 tuple0, array1 = pieces 525 got = array1.to_py() 526 want = NumpyArrayS32([5]) 527 np.testing.assert_equal(want, got) 528 got = tuple0.to_py() 529 self.assertEqual(type(got), tuple) 530 self.assertLen(got, 2) 531 np.testing.assert_equal(NumpyArrayF32([1.0, 2.0]), got[0]) 532 np.testing.assert_equal(NumpyArrayS32([3, 4]), got[1]) 533 534 def testMakeTuple(self): 535 t = ( 536 np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32), 537 np.array([2, 3, 4, 5], dtype=np.int32), 538 ) 539 b0 = xla_client.Buffer.from_pyval(t[0]) 540 b1 = xla_client.Buffer.from_pyval(t[1]) 541 device = xla_client.get_local_backend().local_devices()[0] 542 btup = xla_client.Buffer.make_tuple([b0, b1], device=device) 543 pieces = btup.destructure() 544 self.assertLen(pieces, 2) 545 array0, array1 = pieces 546 np.testing.assert_equal( 547 np.array([1, 2, 3, 4], dtype=np.float32), array0.to_py()) 548 np.testing.assert_equal( 549 np.array([2, 3, 4, 5], dtype=np.int32), array1.to_py()) 550 551 def testShape(self): 552 pyval = np.array([[1., 2.]], np.float32) 553 local_buffer = xla_client.Buffer.from_pyval(pyval) 554 xla_shape = local_buffer.shape() 555 self.assertEqual(xla_shape.dimensions(), (1, 2)) 556 self.assertEqual(np.dtype(xla_shape.element_type()), np.dtype(np.float32)) 557 558 def testTupleShape(self): 559 t = ( 560 np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32), 561 np.array([2, 3, 4, 5], dtype=np.int32), 562 ) 563 b0 = xla_client.Buffer.from_pyval(t[0]) 564 b1 = xla_client.Buffer.from_pyval(t[1]) 565 device = xla_client.get_local_backend().local_devices()[0] 566 tuple_buffer = xla_client.Buffer.make_tuple([b0, b1], device=device) 567 tuple_shape = tuple_buffer.shape() 568 self.assertEqual(tuple_shape.leaf_count(), 2) 569 shapes = tuple_shape.tuple_shapes() 570 self.assertLen(shapes, 2) 571 shape1, shape2 = shapes 572 self.assertEqual(shape1.dimensions(), (1, 4)) 573 self.assertEqual(shape2.dimensions(), (4,)) 574 575 def testBlockHostUntilReadyWorks(self): 576 arg = np.array([[1., 2.]], np.float32) 577 arg_buffer = xla_client.Buffer.from_pyval(arg) 578 arg_buffer.block_host_until_ready() 579 # This test merely checks that nothing goes awry when we call 580 # block_host_until_ready(); it's difficult to test anything else. 581 582 def testCopyToHost(self): 583 arg0 = np.array([[1., 2.]], np.float32) 584 arg1 = np.array([[3., 4.]], np.float32) 585 arg0_buffer = xla_client.Buffer.from_pyval(arg0) 586 arg1_buffer = xla_client.Buffer.from_pyval(arg1) 587 # Prefetch two buffers using copy_to_host_async, and then retrieve their 588 # values using to_py. 589 arg0_buffer.copy_to_host_async() 590 arg0_buffer.copy_to_host_async() # Duplicate calls don't do anything. 591 arg1_buffer.copy_to_host_async() 592 np.testing.assert_equal(arg0, arg0_buffer.to_py()) 593 np.testing.assert_equal(arg1, arg1_buffer.to_py()) 594 # copy_to_host_async does nothing after to_py is called. 595 arg0_buffer.copy_to_host_async() 596 np.testing.assert_equal(arg0, arg0_buffer.to_py()) 597 598 def testDevice(self): 599 x = np.arange(8) 600 for device in xla_client.get_local_backend().local_devices(): 601 buf = xla_client.Buffer.from_pyval(x, device=device) 602 self.assertEqual(buf.device(), device) 603 np.testing.assert_equal(x, buf.to_py()) 604 605 606class SingleOpTest(ComputationTest): 607 """Tests for single ops. 608 609 The goal here is smoke testing - to exercise the most basic functionality of 610 single XLA ops. As minimal as possible number of additional ops are added 611 around the op being tested. 612 """ 613 614 def testConcatenateF32(self): 615 c = self._NewComputation() 616 args = ( 617 c.Constant(NumpyArrayF32([1.0, 2.0, 3.0])), 618 c.Constant(NumpyArrayF32([4.0, 5.0, 6.0])), 619 ) 620 c.Concatenate(args, dimension=0) 621 self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) 622 623 def testConcatenateF64(self): 624 c = self._NewComputation() 625 args = ( 626 c.Constant(NumpyArrayF64([1.0, 2.0, 3.0])), 627 c.Constant(NumpyArrayF64([4.0, 5.0, 6.0])), 628 ) 629 c.Concatenate(args, dimension=0) 630 self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) 631 632 def testConvertElementType(self): 633 xla_types = { 634 np.bool: xla_client.PrimitiveType.PRED, 635 np.int32: xla_client.PrimitiveType.S32, 636 np.int64: xla_client.PrimitiveType.S64, 637 np.float32: xla_client.PrimitiveType.F32, 638 np.float64: xla_client.PrimitiveType.F64, 639 } 640 641 def _ConvertAndTest(template, src_dtype, dst_dtype): 642 c = self._NewComputation() 643 x = c.Constant(np.array(template, dtype=src_dtype)) 644 c.ConvertElementType(x, xla_types[dst_dtype]) 645 646 result = xla_client.execute_with_python_values(c.Build().Compile()) 647 expected = np.array(template, dtype=dst_dtype) 648 649 self.assertEqual(result.shape, expected.shape) 650 self.assertEqual(result.dtype, expected.dtype) 651 np.testing.assert_equal(result, expected) 652 653 x = [0, 1, 0, 0, 1] 654 for src_dtype, dst_dtype in itertools.product(xla_types, xla_types): 655 _ConvertAndTest(x, src_dtype, dst_dtype) 656 657 def testBitcastConvertType(self): 658 xla_x32_types = { 659 np.int32: xla_client.PrimitiveType.S32, 660 np.float32: xla_client.PrimitiveType.F32, 661 } 662 663 xla_x64_types = { 664 np.int64: xla_client.PrimitiveType.S64, 665 np.float64: xla_client.PrimitiveType.F64, 666 } 667 668 def _ConvertAndTest(template, src_dtype, dst_dtype, dst_etype): 669 c = self._NewComputation() 670 x = c.Constant(np.array(template, dtype=src_dtype)) 671 c.BitcastConvertType(x, dst_etype) 672 673 result = xla_client.execute_with_python_values(c.Build().Compile()) 674 expected = np.array(template, src_dtype).view(dst_dtype) 675 676 self.assertEqual(result.shape, expected.shape) 677 self.assertEqual(result.dtype, expected.dtype) 678 np.testing.assert_equal(result, expected) 679 680 x = [0, 1, 0, 0, 1] 681 for xla_types in [xla_x32_types, xla_x64_types]: 682 for src_dtype, dst_dtype in itertools.product(xla_types, xla_types): 683 _ConvertAndTest(x, src_dtype, dst_dtype, xla_types[dst_dtype]) 684 685 # TODO(b/123523486) implement AllToAll on CPU 686 def DISABLED_testAllToAllOneReplica(self): 687 samples = [ 688 NumpyArrayF32([97.0]), 689 NumpyArrayF32([64.0, 117.0]), 690 NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), 691 ] 692 for lhs in samples[:1]: 693 c = self._NewComputation() 694 c.AllToAll(c.Constant(lhs), 0, 0) 695 self._ExecuteAndCompareExact(c, expected=lhs) 696 697 def testCrossReplicaSumOneReplica(self): 698 samples = [ 699 NumpyArrayF32(42.0), 700 NumpyArrayF32([97.0]), 701 NumpyArrayF32([64.0, 117.0]), 702 NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), 703 ] 704 for lhs in samples: 705 c = self._NewComputation() 706 c.CrossReplicaSum(c.Constant(lhs)) 707 self._ExecuteAndCompareExact(c, expected=lhs) 708 709 def testReplicaId(self): 710 c = self._NewComputation() 711 _ = c.ReplicaId() 712 self._ExecuteAndCompareExact(c, expected=0) 713 714 def testCrossReplicaSumOneReplicaWithSingletonGroup(self): 715 samples = [ 716 NumpyArrayF32(42.0), 717 NumpyArrayF32([97.0]), 718 NumpyArrayF32([64.0, 117.0]), 719 NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), 720 ] 721 for lhs in samples: 722 c = self._NewComputation() 723 c.CrossReplicaSum(c.Constant(lhs), [[0]]) 724 self._ExecuteAndCompareExact(c, expected=lhs) 725 726 def testDotMatrixVectorF32(self): 727 c = self._NewComputation() 728 lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]) 729 rhs = NumpyArrayF32([[10.0], [20.0]]) 730 c.Dot(c.Constant(lhs), c.Constant(rhs)) 731 self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) 732 733 def testDotMatrixVectorF64(self): 734 c = self._NewComputation() 735 lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]]) 736 rhs = NumpyArrayF64([[10.0], [20.0]]) 737 c.Dot(c.Constant(lhs), c.Constant(rhs)) 738 self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) 739 740 def testDotMatrixMatrixF32(self): 741 c = self._NewComputation() 742 lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]) 743 rhs = NumpyArrayF32([[10.0, 20.0], [100.0, 200.0]]) 744 c.Dot(c.Constant(lhs), c.Constant(rhs)) 745 self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) 746 747 def testDotMatrixMatrixF64(self): 748 c = self._NewComputation() 749 lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]]) 750 rhs = NumpyArrayF64([[10.0, 20.0], [100.0, 200.0]]) 751 c.Dot(c.Constant(lhs), c.Constant(rhs)) 752 self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) 753 754 def testDotGeneral(self): 755 c = self._NewComputation() 756 rng = np.random.RandomState(0) 757 lhs = NumpyArrayF32(rng.randn(10, 3, 4)) 758 rhs = NumpyArrayF32(rng.randn(10, 4, 5)) 759 dimension_numbers = (([2], [1]), ([0], [0])) 760 c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers) 761 self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs), rtol=1e-6) 762 763 def testDotGeneralWithDotDimensionNumbersProto(self): 764 c = self._NewComputation() 765 rng = np.random.RandomState(0) 766 lhs = NumpyArrayF32(rng.randn(10, 3, 4)) 767 rhs = NumpyArrayF32(rng.randn(10, 4, 5)) 768 769 dimension_numbers = xla_client.DotDimensionNumbers() 770 dimension_numbers.lhs_contracting_dimensions.append(2) 771 dimension_numbers.rhs_contracting_dimensions.append(1) 772 dimension_numbers.lhs_batch_dimensions.append(0) 773 dimension_numbers.rhs_batch_dimensions.append(0) 774 775 c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers) 776 self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs), rtol=1e-6) 777 778 def testDotGeneralWithPrecisionConfig(self): 779 c = self._NewComputation() 780 rng = np.random.RandomState(0) 781 lhs = NumpyArrayF32(rng.randn(10, 3, 4)) 782 rhs = NumpyArrayF32(rng.randn(10, 4, 5)) 783 dimension_numbers = (([2], [1]), ([0], [0])) 784 config = xla_client.PrecisionConfig() 785 config.operand_precision.append(config.Precision.HIGH) 786 config.operand_precision.append(config.Precision.HIGHEST) 787 c.DotGeneral( 788 c.Constant(lhs), 789 c.Constant(rhs), 790 dimension_numbers, 791 precision_config=config) 792 self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs), rtol=1e-6) 793 794 def testConvF32Same(self): 795 c = self._NewComputation() 796 a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") 797 lhs = a(1, 2, 3, 4) 798 rhs = a(1, 2, 1, 2) * 10 799 c.Conv( 800 c.Constant(lhs), c.Constant(rhs), [1, 1], xla_client.PaddingType.SAME) 801 result = np.array([[[ 802 [640., 700., 760., 300.], 803 [880., 940., 1000., 380.], 804 [1120., 1180., 1240., 460.], 805 ]]]) 806 self._ExecuteAndCompareClose(c, expected=result) 807 808 def testConvF32Valid(self): 809 c = self._NewComputation() 810 a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") 811 lhs = a(1, 2, 3, 4) 812 rhs = a(1, 2, 1, 2) * 10 813 c.Conv( 814 c.Constant(lhs), c.Constant(rhs), [2, 1], xla_client.PaddingType.VALID) 815 result = np.array([[[ 816 [640., 700., 760.], 817 [1120., 1180., 1240.], 818 ]]]) 819 self._ExecuteAndCompareClose(c, expected=result) 820 821 def testConvWithGeneralPaddingF32(self): 822 c = self._NewComputation() 823 a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") 824 lhs = a(1, 1, 2, 3) 825 rhs = a(1, 1, 1, 2) * 10 826 strides = [1, 1] 827 pads = [(1, 0), (0, 1)] 828 lhs_dilation = (2, 1) 829 rhs_dilation = (1, 1) 830 c.ConvWithGeneralPadding( 831 c.Constant(lhs), c.Constant(rhs), strides, pads, lhs_dilation, 832 rhs_dilation) 833 result = np.array([[[ 834 [0., 0., 0.], 835 [10., 20., 0.], 836 [0., 0., 0.], 837 [40., 50., 0.], 838 ]]]) 839 self._ExecuteAndCompareClose(c, expected=result) 840 841 def testConvGeneralDilatedF32(self): 842 c = self._NewComputation() 843 a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") 844 lhs = a(1, 1, 2, 3) 845 rhs = a(1, 1, 1, 2) * 10 846 strides = [1, 1] 847 pads = [(1, 0), (0, 1)] 848 lhs_dilation = (2, 1) 849 rhs_dilation = (1, 1) 850 dimension_numbers = ("NCHW", "OIHW", "NCHW") 851 c.ConvGeneralDilated( 852 c.Constant(lhs), c.Constant(rhs), strides, pads, lhs_dilation, 853 rhs_dilation, dimension_numbers) 854 result = np.array([[[ 855 [0., 0., 0.], 856 [10., 20., 0.], 857 [0., 0., 0.], 858 [40., 50., 0.], 859 ]]]) 860 self._ExecuteAndCompareClose(c, expected=result) 861 862 def testConvGeneralDilatedF32WithPrecisionConfig(self): 863 c = self._NewComputation() 864 a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") 865 lhs = a(1, 1, 2, 3) 866 rhs = a(1, 1, 1, 2) * 10 867 strides = [1, 1] 868 pads = [(1, 0), (0, 1)] 869 lhs_dilation = (2, 1) 870 rhs_dilation = (1, 1) 871 dimension_numbers = ("NCHW", "OIHW", "NCHW") 872 config = xla_client.PrecisionConfig() 873 config.operand_precision.append(config.Precision.HIGHEST) 874 config.operand_precision.append(config.Precision.DEFAULT) 875 c.ConvGeneralDilated( 876 c.Constant(lhs), 877 c.Constant(rhs), 878 strides, 879 pads, 880 lhs_dilation, 881 rhs_dilation, 882 dimension_numbers, 883 precision_config=config) 884 result = np.array([[[ 885 [0., 0., 0.], 886 [10., 20., 0.], 887 [0., 0., 0.], 888 [40., 50., 0.], 889 ]]]) 890 self._ExecuteAndCompareClose(c, expected=result) 891 892 def testConvGeneralDilatedPermutedF32(self): 893 c = self._NewComputation() 894 a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") 895 lhs = a(1, 1, 2, 3) 896 rhs = a(1, 1, 1, 2) * 10 897 strides = [1, 1] 898 pads = [(1, 0), (0, 1)] 899 lhs_dilation = (2, 1) 900 rhs_dilation = (1, 1) 901 902 dimension_numbers = ("NHWC", "OIHW", "CWNH") 903 c.ConvGeneralDilated( 904 c.Constant(np.transpose(lhs, (0, 2, 3, 1))), c.Constant(rhs), strides, 905 pads, lhs_dilation, rhs_dilation, dimension_numbers) 906 result = np.array([[[[0., 0., 0.], [10., 20., 0.], [0., 0., 0.], 907 [40., 50., 0.]]]]) 908 self._ExecuteAndCompareClose(c, expected=np.transpose(result, (1, 3, 0, 2))) 909 910 def testConvGeneralDilatedGroupedConvolutionF32(self): 911 c = self._NewComputation() 912 a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") 913 lhs = a(1, 2, 2, 3) 914 rhs = a(2, 1, 1, 2) * 10 915 strides = [1, 1] 916 pads = [(1, 0), (0, 1)] 917 lhs_dilation = (2, 1) 918 rhs_dilation = (1, 1) 919 dimension_numbers = ("NCHW", "OIHW", "NCHW") 920 feature_group_count = 2 921 c.ConvGeneralDilated( 922 c.Constant(lhs), c.Constant(rhs), strides, pads, lhs_dilation, 923 rhs_dilation, dimension_numbers, feature_group_count) 924 result = np.array([[[ 925 [0., 0., 0.], 926 [10., 20., 0.], 927 [0., 0., 0.], 928 [40., 50., 0.], 929 ], [ 930 [0., 0., 0.], 931 [330., 380., 160.], 932 [0., 0., 0.], 933 [480., 530., 220.], 934 ]]]) 935 self._ExecuteAndCompareClose(c, expected=result) 936 937 def testBooleanNot(self): 938 c = self._NewComputation() 939 arr = NumpyArrayBool([True, False, True]) 940 c.Not(c.Constant(arr)) 941 self._ExecuteAndCompareClose(c, expected=~arr) 942 943 def testCountLeadingZeros(self): 944 c = self._NewComputation() 945 arr = NumpyArrayS32([0x7FFF, 0x12345678]) 946 c.Clz(c.Constant(arr)) 947 self._ExecuteAndCompareClose(c, expected=[17, 3]) 948 949 def testExp(self): 950 c = self._NewComputation() 951 arr = NumpyArrayF32([3.3, 12.1]) 952 c.Exp(c.Constant(arr)) 953 self._ExecuteAndCompareClose(c, expected=np.exp(arr)) 954 955 def testExpm1(self): 956 c = self._NewComputation() 957 arr = NumpyArrayF32([3.3, 12.1]) 958 c.Expm1(c.Constant(arr)) 959 self._ExecuteAndCompareClose(c, expected=np.expm1(arr)) 960 961 def testRound(self): 962 c = self._NewComputation() 963 arr = NumpyArrayF32([3.3, 12.1]) 964 c.Round(c.Constant(arr)) 965 self._ExecuteAndCompareClose(c, expected=np.round(arr)) 966 967 def testLog(self): 968 c = self._NewComputation() 969 arr = NumpyArrayF32([3.3, 12.1]) 970 c.Log(c.Constant(arr)) 971 self._ExecuteAndCompareClose(c, expected=np.log(arr)) 972 973 def testLog1p(self): 974 c = self._NewComputation() 975 arr = NumpyArrayF32([3.3, 12.1]) 976 c.Log1p(c.Constant(arr)) 977 self._ExecuteAndCompareClose(c, expected=np.log1p(arr)) 978 979 def testNeg(self): 980 c = self._NewComputation() 981 arr = NumpyArrayF32([3.3, 12.1]) 982 c.Neg(c.Constant(arr)) 983 self._ExecuteAndCompareClose(c, expected=-arr) 984 985 def testFloor(self): 986 c = self._NewComputation() 987 arr = NumpyArrayF32([3.3, 12.1]) 988 c.Floor(c.Constant(arr)) 989 self._ExecuteAndCompareClose(c, expected=np.floor(arr)) 990 991 def testCeil(self): 992 c = self._NewComputation() 993 arr = NumpyArrayF32([3.3, 12.1]) 994 c.Ceil(c.Constant(arr)) 995 self._ExecuteAndCompareClose(c, expected=np.ceil(arr)) 996 997 def testAbs(self): 998 c = self._NewComputation() 999 arr = NumpyArrayF32([3.3, -12.1, 2.4, -1.]) 1000 c.Abs(c.Constant(arr)) 1001 self._ExecuteAndCompareClose(c, expected=np.abs(arr)) 1002 1003 def testTanh(self): 1004 c = self._NewComputation() 1005 arr = NumpyArrayF32([3.3, 12.1]) 1006 c.Tanh(c.Constant(arr)) 1007 self._ExecuteAndCompareClose(c, expected=np.tanh(arr)) 1008 1009 def testTrans(self): 1010 1011 def _TransposeAndTest(array): 1012 c = self._NewComputation() 1013 c.Trans(c.Constant(array)) 1014 self._ExecuteAndCompareClose(c, expected=array.T) 1015 1016 # Test square and non-square matrices in both default (C) and F orders. 1017 for array_fun in [NumpyArrayF32, NumpyArrayF64]: 1018 _TransposeAndTest(array_fun([[1, 2, 3], [4, 5, 6]])) 1019 _TransposeAndTest(array_fun([[1, 2, 3], [4, 5, 6]], order="F")) 1020 _TransposeAndTest(array_fun([[1, 2], [4, 5]])) 1021 _TransposeAndTest(array_fun([[1, 2], [4, 5]], order="F")) 1022 1023 def testTranspose(self): 1024 1025 def _TransposeAndTest(array, permutation): 1026 c = self._NewComputation() 1027 c.Transpose(c.Constant(array), permutation) 1028 expected = np.transpose(array, permutation) 1029 self._ExecuteAndCompareClose(c, expected=expected) 1030 1031 _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [0, 1]) 1032 _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [1, 0]) 1033 _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [0, 1]) 1034 _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [1, 0]) 1035 1036 arr = np.random.RandomState(0).randn(2, 3, 4).astype(np.float32) 1037 for permutation in itertools.permutations(range(arr.ndim)): 1038 _TransposeAndTest(arr, permutation) 1039 _TransposeAndTest(np.asfortranarray(arr), permutation) 1040 1041 def testEq(self): 1042 c = self._NewComputation() 1043 c.Eq( 1044 c.Constant(NumpyArrayS32([1, 2, 3, 4])), 1045 c.Constant(NumpyArrayS32([4, 2, 3, 1]))) 1046 self._ExecuteAndCompareExact(c, expected=[False, True, True, False]) 1047 1048 def testNe(self): 1049 c = self._NewComputation() 1050 c.Ne( 1051 c.Constant(NumpyArrayS32([1, 2, 3, 4])), 1052 c.Constant(NumpyArrayS32([4, 2, 3, 1]))) 1053 self._ExecuteAndCompareExact(c, expected=[True, False, False, True]) 1054 1055 c.Ne( 1056 c.Constant(NumpyArrayF32([-2.0, 0.0, 1057 float("nan"), 1058 float("nan")])), 1059 c.Constant(NumpyArrayF32([2.0, -0.0, 1.0, float("nan")]))) 1060 self._ExecuteAndAssertWith( 1061 np.testing.assert_allclose, c, (), expected=[True, False, True, True]) 1062 1063 def testGt(self): 1064 c = self._NewComputation() 1065 c.Gt( 1066 c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), 1067 c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) 1068 self._ExecuteAndCompareExact(c, expected=[False, True, True, False, False]) 1069 1070 def testGe(self): 1071 c = self._NewComputation() 1072 c.Ge( 1073 c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), 1074 c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) 1075 self._ExecuteAndCompareExact(c, expected=[True, True, True, False, False]) 1076 1077 def testLt(self): 1078 c = self._NewComputation() 1079 c.Lt( 1080 c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), 1081 c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) 1082 self._ExecuteAndCompareExact(c, expected=[False, False, False, True, True]) 1083 1084 def testLe(self): 1085 c = self._NewComputation() 1086 c.Le( 1087 c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), 1088 c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) 1089 self._ExecuteAndCompareExact(c, expected=[True, False, False, True, True]) 1090 1091 def testMax(self): 1092 c = self._NewComputation() 1093 c.Max( 1094 c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])), 1095 c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) 1096 self._ExecuteAndCompareExact(c, expected=[1.0, 2.0, 3.0, 7.0, 12.0]) 1097 1098 def testMaxExplicitBroadcastDim0(self): 1099 c = self._NewComputation() 1100 c.Max( 1101 c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 1102 c.Constant(NumpyArrayF32([3, 4, 5])), 1103 broadcast_dimensions=(0,)) 1104 self._ExecuteAndCompareExact(c, expected=[[3, 3, 3], [4, 5, 6], [7, 8, 9]]) 1105 1106 def testMaxExplicitBroadcastDim1(self): 1107 c = self._NewComputation() 1108 c.Max( 1109 c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 1110 c.Constant(NumpyArrayF32([3, 4, 5])), 1111 broadcast_dimensions=(1,)) 1112 self._ExecuteAndCompareExact(c, expected=[[3, 4, 5], [4, 5, 6], [7, 8, 9]]) 1113 1114 def testMin(self): 1115 c = self._NewComputation() 1116 c.Min( 1117 c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])), 1118 c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) 1119 self._ExecuteAndCompareExact(c, expected=[1.0, 0.0, 2.0, 4.0, 9.0]) 1120 1121 def testPad(self): 1122 c = self._NewComputation() 1123 c.Pad( 1124 c.Constant(NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])), 1125 c.Constant(NumpyArrayF32(0.0)), [(1, 2, 1), (0, 1, 0)]) 1126 self._ExecuteAndCompareClose( 1127 c, 1128 expected=[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0], 1129 [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) 1130 1131 def testPadWithPaddingConfig(self): 1132 c = self._NewComputation() 1133 padding_config = xla_client.PaddingConfig() 1134 for lo, hi, interior in [(1, 2, 1), (0, 1, 0)]: 1135 dimension = xla_client.PaddingConfigDimension() 1136 dimension.edge_padding_low = lo 1137 dimension.edge_padding_high = hi 1138 dimension.interior_padding = interior 1139 padding_config.dimensions.append(dimension) 1140 c.Pad( 1141 c.Constant(NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])), 1142 c.Constant(NumpyArrayF32(0.0)), padding_config) 1143 self._ExecuteAndCompareClose( 1144 c, 1145 expected=[[0.0, 0.0, 0.0], [1.0, 2.0, 0.0], [0.0, 0.0, 0.0], 1146 [3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) 1147 1148 def testReshape(self): 1149 c = self._NewComputation() 1150 c.Reshape( 1151 c.Constant(NumpyArrayS32([[1, 2], [3, 4], [5, 6]])), 1152 dimensions=[0, 1], 1153 new_sizes=[2, 3]) 1154 self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 5, 6]]) 1155 1156 def testCollapse(self): 1157 c = self._NewComputation() 1158 c.Collapse( 1159 c.Constant(NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), 1160 dimensions=[1, 2]) 1161 self._ExecuteAndCompareExact(c, expected=[[1, 2, 3, 4], [5, 6, 7, 8]]) 1162 1163 def testRev(self): 1164 c = self._NewComputation() 1165 c.Rev( 1166 c.Constant(NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), 1167 dimensions=[0, 2]) 1168 self._ExecuteAndCompareExact( 1169 c, expected=[[[6, 5], [8, 7]], [[2, 1], [4, 3]]]) 1170 1171 def testReducePrecision(self): 1172 c = self._NewComputation() 1173 c.ReducePrecision( 1174 c.Constant(NumpyArrayF32([float.fromhex("0x1.32fffep-3")])), 1175 exponent_bits=8, 1176 mantissa_bits=7) 1177 self._ExecuteAndCompareClose(c, expected=[float.fromhex("0x1.32p-3")]) 1178 1179 def testClampF32(self): 1180 c = self._NewComputation() 1181 c.Clamp( 1182 c.Constant(NumpyArrayF32(-1)), 1183 c.Constant(NumpyArrayF32([-2, -1, 0, 1, 2, 3])), 1184 c.Constant(NumpyArrayF32(2))) 1185 self._ExecuteAndCompareExact(c, expected=[-1, -1, 0, 1, 2, 2]) 1186 1187 def testClampS32(self): 1188 c = self._NewComputation() 1189 c.Clamp( 1190 c.Constant(NumpyArrayS32(-1)), 1191 c.Constant(NumpyArrayS32([-2, -1, 0, 1, 2, 3])), 1192 c.Constant(NumpyArrayS32(2))) 1193 self._ExecuteAndCompareExact(c, expected=[-1, -1, 0, 1, 2, 2]) 1194 1195 def testSelect(self): 1196 c = self._NewComputation() 1197 c.Select( 1198 c.Constant(NumpyArrayBool([True, False, False, True, False])), 1199 c.Constant(NumpyArrayS32([1, 2, 3, 4, 5])), 1200 c.Constant(NumpyArrayS32([-1, -2, -3, -4, -5]))) 1201 self._ExecuteAndCompareExact(c, expected=[1, -2, -3, 4, -5]) 1202 1203 def testSlice(self): 1204 c = self._NewComputation() 1205 c.Slice( 1206 c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), [1, 0], 1207 [3, 2]) 1208 self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]]) 1209 1210 def testSliceInDim(self): 1211 c = self._NewComputation() 1212 c.SliceInDim( 1213 c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 1214 start_index=1, 1215 limit_index=2, 1216 stride=1, 1217 dimno=1) 1218 self._ExecuteAndCompareExact(c, expected=[[2], [5], [8]]) 1219 c.SliceInDim( 1220 c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 1221 start_index=0, 1222 limit_index=3, 1223 stride=2, 1224 dimno=0) 1225 self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [7, 8, 9]]) 1226 1227 def testDynamicSlice(self): 1228 c = self._NewComputation() 1229 c.DynamicSlice( 1230 c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 1231 c.Constant(NumpyArrayS32([1, 0])), [2, 2]) 1232 self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]]) 1233 1234 def testDynamicUpdateSlice(self): 1235 c = self._NewComputation() 1236 c.DynamicUpdateSlice( 1237 c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), 1238 c.Constant(NumpyArrayS32([[1, 2], [3, 4]])), 1239 c.Constant(NumpyArrayS32([1, 1]))) 1240 self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 1, 2], [7, 3, 4]]) 1241 1242 def testTuple(self): 1243 c = self._NewComputation() 1244 c.Tuple( 1245 c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])), 1246 c.Constant(NumpyArrayBool([True, False, False, True]))) 1247 result = xla_client.execute_with_python_values(c.Build().Compile()) 1248 self.assertIsInstance(result, tuple) 1249 np.testing.assert_equal(result[0], 42) 1250 np.testing.assert_allclose(result[1], [1.0, 2.0]) 1251 np.testing.assert_equal(result[2], [True, False, False, True]) 1252 1253 def testGetTupleElement(self): 1254 c = self._NewComputation() 1255 c.GetTupleElement( 1256 c.Tuple( 1257 c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])), 1258 c.Constant(NumpyArrayBool([True, False, False, True]))), 1) 1259 self._ExecuteAndCompareClose(c, expected=[1.0, 2.0]) 1260 1261 def testBroadcast(self): 1262 c = self._NewComputation() 1263 c.Broadcast(c.Constant(NumpyArrayS32([10, 20, 30, 40])), sizes=(3,)) 1264 self._ExecuteAndCompareExact( 1265 c, expected=[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]]) 1266 1267 def testBroadcastInDim(self): 1268 c = self._NewComputation() 1269 c.BroadcastInDim(c.Constant(NumpyArrayS32([1, 2])), [2, 2], [0]) 1270 self._ExecuteAndCompareExact(c, expected=[[1, 1], [2, 2]]) 1271 c.BroadcastInDim(c.Constant(NumpyArrayS32([1, 2])), [2, 2], [1]) 1272 self._ExecuteAndCompareExact(c, expected=[[1, 2], [1, 2]]) 1273 1274 def testRngNormal(self): 1275 shape = (2, 3) 1276 c = self._NewComputation() 1277 c.RngNormal( 1278 c.Constant(NumpyArrayF32(0.)), 1279 c.Constant(NumpyArrayF32(1.)), 1280 dims=shape) 1281 result = xla_client.execute_with_python_values(c.Build().Compile()) 1282 # since the result is random, we just check shape and uniqueness 1283 self.assertEqual(result.shape, shape) 1284 self.assertLen(np.unique(result), np.prod(shape)) 1285 1286 def testRngUniformF32(self): 1287 lo, hi = 2., 4. 1288 shape = (2, 3) 1289 c = self._NewComputation() 1290 c.RngUniform( 1291 c.Constant(NumpyArrayF32(lo)), 1292 c.Constant(NumpyArrayF32(hi)), 1293 dims=shape) 1294 result = xla_client.execute_with_python_values(c.Build().Compile()) 1295 # since the result is random, we just check shape, uniqueness, and range 1296 self.assertEqual(result.shape, shape) 1297 self.assertLen(np.unique(result), np.prod(shape)) 1298 self.assertTrue(np.all(lo <= result)) 1299 self.assertTrue(np.all(result < hi)) 1300 1301 def testRngUniformS32(self): 1302 lo, hi = 2, 4 1303 shape = (2, 3) 1304 c = self._NewComputation() 1305 c.RngUniform( 1306 c.Constant(NumpyArrayS32(lo)), 1307 c.Constant(NumpyArrayS32(hi)), 1308 dims=shape) 1309 result = xla_client.execute_with_python_values(c.Build().Compile()) 1310 # since the result is random, we just check shape, integrality, and range 1311 self.assertEqual(result.shape, shape) 1312 self.assertEqual(result.dtype, np.int32) 1313 self.assertTrue(np.all(lo <= result)) 1314 self.assertTrue(np.all(result < hi)) 1315 1316 def testCholesky(self): 1317 l = np.array([[4, 0, 0, 0], [6, 5, 0, 0], [2, 14, 16, 0], [3, 6, 1, 4]], 1318 dtype=np.float32) 1319 c = self._NewComputation() 1320 c.Cholesky(c.Constant(np.dot(l, l.T))) 1321 self._ExecuteAndCompareClose(c, expected=l, rtol=1e-4) 1322 1323 def testSort(self): 1324 keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32) 1325 c = self._NewComputation() 1326 c.Sort(c.Constant(keys)) 1327 self._ExecuteAndCompareClose( 1328 c, expected=np.array([[1, 2, 3, 4], [1, 2, 3, 4]], dtype=np.float32)) 1329 1330 def testSortKeyVal(self): 1331 keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32) 1332 values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32) 1333 c = self._NewComputation() 1334 c.Sort((c.Constant(keys), c.Constant(values)), dimension=0) 1335 result = xla_client.execute_with_python_values(c.Build().Compile()) 1336 self.assertIsInstance(result, tuple) 1337 np.testing.assert_allclose(result[0], [[2, 1, 1, 2], [3, 4, 4, 3]]) 1338 np.testing.assert_equal(result[1], [[0, 5, 2, 7], [4, 1, 6, 3]]) 1339 1340 def testSortCustomComparator(self): 1341 b = self._NewComputation("comparator") 1342 p0 = b.ParameterFromNumpy(NumpyArrayF32(0)) 1343 q0 = b.ParameterFromNumpy(NumpyArrayF32(0)) 1344 p1 = b.ParameterFromNumpy(NumpyArrayS32(0)) 1345 q1 = b.ParameterFromNumpy(NumpyArrayS32(0)) 1346 b.Or(b.Lt(p0, q0), b.And(b.Eq(p0, q0), b.Gt(p1, q1))) 1347 comparator = b.Build() 1348 1349 keys = np.array([[2, 3, 1, 3], [3, 1, 2, 2]], dtype=np.float32) 1350 values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32) 1351 c = self._NewComputation() 1352 c.Sort((c.Constant(keys), c.Constant(values)), 1353 dimension=1, 1354 comparator=comparator) 1355 result = xla_client.execute_with_python_values(c.Build().Compile()) 1356 self.assertIsInstance(result, tuple) 1357 np.testing.assert_allclose(result[0], [[1, 2, 3, 3], [1, 2, 2, 3]]) 1358 np.testing.assert_equal(result[1], [[2, 0, 3, 1], [5, 7, 6, 4]]) 1359 1360 def testQR(self): 1361 a = np.array( 1362 [[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]], 1363 dtype=np.float32) 1364 c = self._NewComputation() 1365 c.QR(c.Constant(a), full_matrices=True) 1366 q, r = self._Execute(c, ()) 1367 np.testing.assert_allclose(np.dot(q, r), a, rtol=1e-4) 1368 1369 def testEigh(self): 1370 a = np.array( 1371 [[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]], 1372 dtype=np.float32) 1373 a = (a + a.T) / 2 1374 1375 c = self._NewComputation() 1376 c.Eigh(c.Constant(a), full_matrices=True) 1377 # TODO(b/129396575): Turn this test back on when it passes without fastmath. 1378 # v, w = self._Execute(c, ()) 1379 # self.assertLess(np.linalg.norm(np.dot(a, v) - w * v), 1e-3) 1380 1381 def testSVD(self): 1382 a = np.array( 1383 [[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]], 1384 dtype=np.float32) 1385 c = self._NewComputation() 1386 c.SVD(c.Constant(a)) 1387 u, d, v = self._Execute(c, ()) 1388 self.assertLess(np.linalg.norm(a - np.matmul(u * d, v.T)), 1e-3) 1389 1390 def testTriangularSolve(self): 1391 a_vals = np.array( 1392 [[2, 0, 0, 0], [3, 6, 0, 0], [4, 7, 9, 0], [5, 8, 10, 11]], 1393 dtype=np.float32) 1394 b_vals = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], 1395 dtype=np.float32) 1396 1397 c = self._NewComputation() 1398 c.TriangularSolve( 1399 c.Constant(a_vals), 1400 c.Constant(b_vals), 1401 left_side=False, 1402 lower=True, 1403 transpose_a=True) 1404 self._ExecuteAndCompareClose( 1405 c, 1406 expected=np.array([ 1407 [0.5, 0.08333334, 0.04629629, 0.03367003], 1408 [2.5, -0.25, -0.1388889, -0.1010101], 1409 [4.5, -0.58333331, -0.32407406, -0.23569024], 1410 ], 1411 dtype=np.float32), 1412 rtol=1e-4) 1413 1414 def testIsConstant(self): 1415 c = self._NewComputation() 1416 a = c.ConstantS32Scalar(3) 1417 b = c.ConstantS32Scalar(1) 1418 x = c.ParameterFromNumpy(NumpyArrayS32(0)) 1419 const_expr = c.Sub(b, a) 1420 non_const_expr = c.Mul(const_expr, x) 1421 self.assertTrue(c.IsConstant(const_expr)) 1422 self.assertFalse(c.IsConstant(non_const_expr)) 1423 # self.assertTrue(c.IsConstant(c.Sub(c.Add(x, a), x))) # TODO(b/77245564) 1424 1425 def testGather(self): 1426 a = np.arange(9).astype(np.int32).reshape((3, 3)) 1427 indices = np.array([[[0, 2], [2, 1]], [[1, 2], [2, 0]]], dtype=np.int32) 1428 dnums = xla_client.GatherDimensionNumbers() 1429 dnums.offset_dims.append(1) 1430 dnums.offset_dims.append(2) 1431 dnums.start_index_map.append(0) 1432 dnums.start_index_map.append(1) 1433 dnums.index_vector_dim = 2 1434 c = self._NewComputation() 1435 c.Gather(c.Constant(a), c.Constant(indices), dnums, slice_sizes=[1, 1]) 1436 g = self._Execute(c, ()) 1437 expected = np.array([[[[2, 7]]], [[[5, 6]]]], dtype=np.int32) 1438 np.testing.assert_allclose(g, expected, rtol=1e-4) 1439 1440 def testFft(self): 1441 shape = [2, 3, 4, 5] 1442 rng = np.random.RandomState(0) 1443 a = rng.randn(*shape) + 1.0j * rng.randn(*shape) 1444 a = a.astype(np.complex64) 1445 # FFT 1446 c = self._NewComputation() 1447 c.Fft(c.Constant(a), xla_client.FftType.FFT, shape[-3:]) 1448 self._ExecuteAndCompareClose( 1449 c, expected=np.fft.fftn(a, axes=(1, 2, 3)), rtol=1e-4) 1450 # IFFT 1451 c = self._NewComputation() 1452 c.Fft(c.Constant(a), xla_client.FftType.IFFT, shape[-3:]) 1453 self._ExecuteAndCompareClose( 1454 c, expected=np.fft.ifftn(a, axes=(1, 2, 3)), rtol=1e-4) 1455 # RFFT 1456 b = rng.randn(*shape).astype(np.float32) 1457 c = self._NewComputation() 1458 c.Fft(c.Constant(b), xla_client.FftType.RFFT, shape[-3:]) 1459 self._ExecuteAndCompareClose( 1460 c, expected=np.fft.rfftn(b, axes=(1, 2, 3)), rtol=1e-4) 1461 # IRFFT 1462 c = self._NewComputation() 1463 c.Fft(c.Constant(a), xla_client.FftType.IRFFT, [3, 4, 8]) 1464 self._ExecuteAndCompareClose( 1465 c, expected=np.fft.irfftn(a, axes=(1, 2, 3)), rtol=1e-4) 1466 1467 def testNextAfter(self): 1468 c = self._NewComputation() 1469 c.NextAfter( 1470 c.Constant(np.array([1, 2], dtype=np.float32)), 1471 c.Constant(np.array([2, 1], dtype=np.float32))) 1472 out = self._Execute(c, ()) 1473 eps = np.finfo(np.float32).eps 1474 np.testing.assert_equal(np.array([eps + 1, 2 - eps], dtype=np.float32), out) 1475 1476 def testRegularizedIncompleteBeta(self): 1477 x = np.array([0.53787335, 0.24015466, 0.47494545, 0.13567594, 0.95114538]) 1478 a = np.array([0.00753073, 0.34813385, 0.30485708, 1.29298632, 0.51472606]) 1479 b = np.array([0.55688389, 0.59794214, 0.42661022, 1.59748339, 0.95047677]) 1480 c = self._NewComputation() 1481 c.RegularizedIncompleteBeta(c.Constant(a), c.Constant(b), c.Constant(x)) 1482 expected = np.array( 1483 [0.98923271, 0.48575411, 0.57952568, 0.12579775, 0.96989155]) 1484 self._ExecuteAndCompareClose(c, expected=expected, rtol=1e-4) 1485 1486 1487class EmbeddedComputationsTest(ComputationTest): 1488 """Tests for XLA graphs with embedded computations (such as maps).""" 1489 1490 def _CreateConstantS32Computation(self): 1491 """Computation (f32) -> s32 that returns a constant 1 for any input.""" 1492 c = self._NewComputation("constant_s32_one") 1493 # TODO(eliben): consider adding a nicer way to create new parameters without 1494 # having to create dummy Numpy arrays or populating Shape messages. Perhaps 1495 # we need our own (Python-client-own) way to represent Shapes conveniently. 1496 c.ParameterFromNumpy(NumpyArrayF32(0)) 1497 c.ConstantS32Scalar(1) 1498 return c.Build() 1499 1500 def _CreateConstantS64Computation(self): 1501 """Computation (f64) -> s64 that returns a constant 1 for any input.""" 1502 c = self._NewComputation("constant_s64_one") 1503 # TODO(eliben): consider adding a nicer way to create new parameters without 1504 # having to create dummy Numpy arrays or populating Shape messages. Perhaps 1505 # we need our own (Python-client-own) way to represent Shapes conveniently. 1506 c.ParameterFromNumpy(NumpyArrayF64(0)) 1507 c.ConstantS64Scalar(1) 1508 return c.Build() 1509 1510 def _CreateConstantF32Computation(self): 1511 """Computation (f32) -> f32 that returns a constant 1.0 for any input.""" 1512 c = self._NewComputation("constant_f32_one") 1513 c.ParameterFromNumpy(NumpyArrayF32(0)) 1514 c.ConstantF32Scalar(1.0) 1515 return c.Build() 1516 1517 def _CreateConstantF64Computation(self): 1518 """Computation (f64) -> f64 that returns a constant 1.0 for any input.""" 1519 c = self._NewComputation("constant_f64_one") 1520 c.ParameterFromNumpy(NumpyArrayF64(0)) 1521 c.ConstantF64Scalar(1.0) 1522 return c.Build() 1523 1524 def _CreateMulF32By2Computation(self): 1525 """Computation (f32) -> f32 that multiplies its parameter by 2.""" 1526 c = self._NewComputation("mul_f32_by2") 1527 c.Mul(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(2.0)) 1528 return c.Build() 1529 1530 def _CreateMulF32ByParamComputation(self): 1531 """Computation (f32) -> f32 that multiplies one parameter by the other.""" 1532 c = self._NewComputation("mul_f32_by_param") 1533 c.Mul( 1534 c.ParameterFromNumpy(NumpyArrayF32(0)), 1535 c.ParameterFromNumpy(NumpyArrayF32(0))) 1536 return c.Build() 1537 1538 def _CreateMulF64By2Computation(self): 1539 """Computation (f64) -> f64 that multiplies its parameter by 2.""" 1540 c = self._NewComputation("mul_f64_by2") 1541 c.Mul(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(2.0)) 1542 return c.Build() 1543 1544 def _CreateBinaryAddS32Computation(self): 1545 """Computation (s32, s32) -> s32 that adds its two parameters.""" 1546 c = self._NewComputation("add_param0_by_param1") 1547 c.Add( 1548 c.ParameterFromNumpy(NumpyArrayS32(0)), 1549 c.ParameterFromNumpy(NumpyArrayS32(0))) 1550 return c.Build() 1551 1552 def _CreateBinaryAddF32Computation(self): 1553 """Computation (f32, f32) -> f32 that adds its two parameters.""" 1554 c = self._NewComputation("add_param0_by_param1") 1555 c.Add( 1556 c.ParameterFromNumpy(NumpyArrayF32(0)), 1557 c.ParameterFromNumpy(NumpyArrayF32(0))) 1558 return c.Build() 1559 1560 def _CreateBinaryAddF64Computation(self): 1561 """Computation (f64, f64) -> f64 that adds its two parameters.""" 1562 c = self._NewComputation("add_param0_by_param1") 1563 c.Add( 1564 c.ParameterFromNumpy(NumpyArrayF64(0)), 1565 c.ParameterFromNumpy(NumpyArrayF64(0))) 1566 return c.Build() 1567 1568 def _CreateBinaryDivF32Computation(self): 1569 """Computation (f32, f32) -> f32 that divides its two parameters.""" 1570 c = self._NewComputation("div_param0_by_param1") 1571 c.Div( 1572 c.ParameterFromNumpy(NumpyArrayF32(0)), 1573 c.ParameterFromNumpy(NumpyArrayF32(0))) 1574 return c.Build() 1575 1576 def _CreateBinaryDivF64Computation(self): 1577 """Computation (f64, f64) -> f64 that divides its two parameters.""" 1578 c = self._NewComputation("div_param0_by_param1") 1579 c.Div( 1580 c.ParameterFromNumpy(NumpyArrayF64(0)), 1581 c.ParameterFromNumpy(NumpyArrayF64(0))) 1582 return c.Build() 1583 1584 def _CreateTestF32Lt10Computation(self): 1585 """Computation (f32) -> bool that tests if its parameter is less than 10.""" 1586 c = self._NewComputation("test_f32_lt_10") 1587 c.Lt(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(10.)) 1588 return c.Build() 1589 1590 def _CreateTestF64Lt10Computation(self): 1591 """Computation (f64) -> bool that tests if its parameter is less than 10.""" 1592 c = self._NewComputation("test_f64_lt_10") 1593 c.Lt(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(10.)) 1594 return c.Build() 1595 1596 def _CreateBinaryGeF32Computation(self): 1597 """Computation (f32, f32) -> bool that tests first_param >= second_param.""" 1598 c = self._NewComputation("param0_lt_param1") 1599 c.Ge( 1600 c.ParameterFromNumpy(NumpyArrayF32(0)), 1601 c.ParameterFromNumpy(NumpyArrayF32(0))) 1602 return c.Build() 1603 1604 def _CreateBinaryGeF64Computation(self): 1605 """Computation (f64, f64) -> bool that tests first_param >= second_param.""" 1606 c = self._NewComputation("param0_lt_param1") 1607 c.Ge( 1608 c.ParameterFromNumpy(NumpyArrayF64(0)), 1609 c.ParameterFromNumpy(NumpyArrayF64(0))) 1610 return c.Build() 1611 1612 def _MakeSample3DArrayF32(self): 1613 return NumpyArrayF32([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]], 1614 [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]]) 1615 1616 def _MakeSample3DArrayF64(self): 1617 return NumpyArrayF64([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]], 1618 [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]]) 1619 1620 def testCallF32(self): 1621 c = self._NewComputation() 1622 c.Call( 1623 self._CreateMulF32By2Computation(), 1624 operands=(c.ConstantF32Scalar(5.0),)) 1625 self._ExecuteAndCompareClose(c, expected=10.0) 1626 1627 def testCallF64(self): 1628 c = self._NewComputation() 1629 c.Call( 1630 self._CreateMulF64By2Computation(), 1631 operands=(c.ConstantF64Scalar(5.0),)) 1632 self._ExecuteAndCompareClose(c, expected=10.0) 1633 1634 def testMapEachElementToS32Constant(self): 1635 c = self._NewComputation() 1636 c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], 1637 self._CreateConstantS32Computation(), [0]) 1638 self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1]) 1639 1640 def testMapEachElementToS64Constant(self): 1641 c = self._NewComputation() 1642 c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))], 1643 self._CreateConstantS64Computation(), [0]) 1644 self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1]) 1645 1646 def testMapMulBy2F32(self): 1647 c = self._NewComputation() 1648 c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], 1649 self._CreateMulF32By2Computation(), [0]) 1650 self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0]) 1651 1652 def testMapMulBy2F64(self): 1653 c = self._NewComputation() 1654 c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))], 1655 self._CreateMulF64By2Computation(), [0]) 1656 self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0]) 1657 1658 def testSimpleMapChainF32(self): 1659 # Chains a map of constant-f32 with a map of mul-by-2 1660 c = self._NewComputation() 1661 const_f32 = c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], 1662 self._CreateConstantF32Computation(), [0]) 1663 c.Map([const_f32], self._CreateMulF32By2Computation(), [0]) 1664 self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0]) 1665 1666 def testSimpleMapChainF64(self): 1667 # Chains a map of constant-f64 with a map of mul-by-2 1668 c = self._NewComputation() 1669 const_f64 = c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))], 1670 self._CreateConstantF64Computation(), [0]) 1671 c.Map([const_f64], self._CreateMulF64By2Computation(), [0]) 1672 self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0]) 1673 1674 def testDivVectorsWithMapF32(self): 1675 c = self._NewComputation() 1676 c.Map((c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])), 1677 c.Constant(NumpyArrayF32([5.0, 5.0, 4.0, 4.0]))), 1678 self._CreateBinaryDivF32Computation(), [0]) 1679 self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0]) 1680 1681 def testDivVectorsWithMapF64(self): 1682 c = self._NewComputation() 1683 c.Map((c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])), 1684 c.Constant(NumpyArrayF64([5.0, 5.0, 4.0, 4.0]))), 1685 self._CreateBinaryDivF64Computation(), [0]) 1686 self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0]) 1687 1688 def testSelectAndScatterF32(self): 1689 c = self._NewComputation() 1690 c.SelectAndScatter( 1691 c.Constant(NumpyArrayF32([[1., 2., 6.], [4., 5., 3.]])), 1692 select=self._CreateBinaryGeF32Computation(), 1693 window_dimensions=(2, 1), 1694 window_strides=(1, 2), 1695 padding=xla_client.PaddingType.VALID, 1696 source=c.Constant(NumpyArrayF32([[0.1, 0.2]])), 1697 init_value=c.Constant(NumpyArrayF32(1)), 1698 scatter=self._CreateBinaryAddF32Computation()) 1699 self._ExecuteAndCompareClose(c, expected=[[1., 1., 1.2], [1.1, 1., 1.]]) 1700 1701 def testSelectAndScatterF64(self): 1702 c = self._NewComputation() 1703 c.SelectAndScatter( 1704 c.Constant(NumpyArrayF64([[1., 2., 6.], [4., 5., 3.]])), 1705 select=self._CreateBinaryGeF64Computation(), 1706 window_dimensions=(2, 1), 1707 window_strides=(1, 2), 1708 padding=xla_client.PaddingType.VALID, 1709 source=c.Constant(NumpyArrayF64([[0.1, 0.2]])), 1710 init_value=c.Constant(NumpyArrayF64(1)), 1711 scatter=self._CreateBinaryAddF64Computation()) 1712 self._ExecuteAndCompareClose(c, expected=[[1., 1., 1.2], [1.1, 1., 1.]]) 1713 1714 def testReduce1DtoScalarF32(self): 1715 c = self._NewComputation() 1716 c.Reduce( 1717 operand=c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])), 1718 init_value=c.ConstantF32Scalar(0), 1719 computation_to_apply=self._CreateBinaryAddF32Computation(), 1720 dimensions=[0]) 1721 self._ExecuteAndCompareClose(c, expected=10) 1722 1723 def testReduce1DtoScalarF64(self): 1724 c = self._NewComputation() 1725 c.Reduce( 1726 operand=c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])), 1727 init_value=c.ConstantF64Scalar(0), 1728 computation_to_apply=self._CreateBinaryAddF64Computation(), 1729 dimensions=[0]) 1730 self._ExecuteAndCompareClose(c, expected=10) 1731 1732 def testReduce2DTo1DDim0F32(self): 1733 input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 1734 c = self._NewComputation() 1735 c.Reduce( 1736 operand=c.Constant(input_array), 1737 init_value=c.ConstantF32Scalar(0), 1738 computation_to_apply=self._CreateBinaryAddF32Computation(), 1739 dimensions=[0]) 1740 self._ExecuteAndCompareClose(c, expected=[5, 7, 9]) 1741 1742 def testReduce2DTo1DDim0F64(self): 1743 input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 1744 c = self._NewComputation() 1745 c.Reduce( 1746 operand=c.Constant(input_array), 1747 init_value=c.ConstantF64Scalar(0), 1748 computation_to_apply=self._CreateBinaryAddF64Computation(), 1749 dimensions=[0]) 1750 self._ExecuteAndCompareClose(c, expected=[5, 7, 9]) 1751 1752 def testReduce2DTo1DDim1F32(self): 1753 input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 1754 c = self._NewComputation() 1755 c.Reduce( 1756 operand=c.Constant(input_array), 1757 init_value=c.ConstantF32Scalar(0), 1758 computation_to_apply=self._CreateBinaryAddF32Computation(), 1759 dimensions=[1]) 1760 self._ExecuteAndCompareClose(c, expected=[6, 15]) 1761 1762 def testReduce2DTo1DDim1F64(self): 1763 input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 1764 c = self._NewComputation() 1765 c.Reduce( 1766 operand=c.Constant(input_array), 1767 init_value=c.ConstantF64Scalar(0), 1768 computation_to_apply=self._CreateBinaryAddF64Computation(), 1769 dimensions=[1]) 1770 self._ExecuteAndCompareClose(c, expected=[6, 15]) 1771 1772 def testReduce3DAllPossibleWaysF32(self): 1773 input_array = self._MakeSample3DArrayF32() 1774 1775 def _ReduceAndTest(*dims): 1776 c = self._NewComputation() 1777 c.Reduce( 1778 operand=c.Constant(input_array), 1779 init_value=c.ConstantF32Scalar(0), 1780 computation_to_apply=self._CreateBinaryAddF32Computation(), 1781 dimensions=dims) 1782 self._ExecuteAndCompareClose( 1783 c, expected=np.sum(input_array, axis=tuple(dims))) 1784 1785 _ReduceAndTest(0) 1786 _ReduceAndTest(0, 1) 1787 _ReduceAndTest(0, 2) 1788 _ReduceAndTest(1, 2) 1789 _ReduceAndTest(0, 1, 2) 1790 1791 def testReduce3DAllPossibleWaysF64(self): 1792 input_array = self._MakeSample3DArrayF64() 1793 1794 def _ReduceAndTest(*dims): 1795 c = self._NewComputation() 1796 c.Reduce( 1797 operand=c.Constant(input_array), 1798 init_value=c.ConstantF64Scalar(0), 1799 computation_to_apply=self._CreateBinaryAddF64Computation(), 1800 dimensions=dims) 1801 self._ExecuteAndCompareClose( 1802 c, expected=np.sum(input_array, axis=tuple(dims))) 1803 1804 _ReduceAndTest(0) 1805 _ReduceAndTest(0) 1806 _ReduceAndTest(0, 1) 1807 _ReduceAndTest(0, 2) 1808 _ReduceAndTest(1, 2) 1809 _ReduceAndTest(0, 1, 2) 1810 1811 def testReduceWindowValidUnitStridesF32(self): 1812 input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 1813 c = self._NewComputation() 1814 c.ReduceWindow( 1815 operand=c.Constant(input_array), 1816 init_value=c.ConstantF32Scalar(0), 1817 computation_to_apply=self._CreateBinaryAddF32Computation(), 1818 window_dimensions=(2, 1), 1819 window_strides=(1, 1), 1820 padding=xla_client.PaddingType.VALID) 1821 self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.]]) 1822 1823 def testReduceWindowSameUnitStridesF32(self): 1824 input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 1825 c = self._NewComputation() 1826 c.ReduceWindow( 1827 operand=c.Constant(input_array), 1828 init_value=c.ConstantF32Scalar(0), 1829 computation_to_apply=self._CreateBinaryAddF32Computation(), 1830 window_dimensions=(2, 1), 1831 window_strides=(1, 1), 1832 padding=xla_client.PaddingType.SAME) 1833 self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.], [4., 5., 6.]]) 1834 1835 def testReduceWindowValidGeneralStridesF32(self): 1836 input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 1837 c = self._NewComputation() 1838 c.ReduceWindow( 1839 operand=c.Constant(input_array), 1840 init_value=c.ConstantF32Scalar(0), 1841 computation_to_apply=self._CreateBinaryAddF32Computation(), 1842 window_dimensions=(2, 1), 1843 window_strides=(1, 2), 1844 padding=xla_client.PaddingType.VALID) 1845 self._ExecuteAndCompareClose(c, expected=[[5., 9.]]) 1846 1847 def testReduceWindowValidUnitStridesF64(self): 1848 input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 1849 c = self._NewComputation() 1850 c.ReduceWindow( 1851 operand=c.Constant(input_array), 1852 init_value=c.ConstantF64Scalar(0), 1853 computation_to_apply=self._CreateBinaryAddF64Computation(), 1854 window_dimensions=(2, 1), 1855 window_strides=(1, 1), 1856 padding=xla_client.PaddingType.VALID) 1857 self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.]]) 1858 1859 def testReduceWindowSameUnitStridesF64(self): 1860 input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 1861 c = self._NewComputation() 1862 c.ReduceWindow( 1863 operand=c.Constant(input_array), 1864 init_value=c.ConstantF64Scalar(0), 1865 computation_to_apply=self._CreateBinaryAddF64Computation(), 1866 window_dimensions=(2, 1), 1867 window_strides=(1, 1), 1868 padding=xla_client.PaddingType.SAME) 1869 self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.], [4., 5., 6.]]) 1870 1871 def testReduceWindowValidGeneralStridesF64(self): 1872 input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 1873 c = self._NewComputation() 1874 c.ReduceWindow( 1875 operand=c.Constant(input_array), 1876 init_value=c.ConstantF64Scalar(0), 1877 computation_to_apply=self._CreateBinaryAddF64Computation(), 1878 window_dimensions=(2, 1), 1879 window_strides=(1, 2), 1880 padding=xla_client.PaddingType.VALID) 1881 self._ExecuteAndCompareClose(c, expected=[[5., 9.]]) 1882 1883 def testWhileF32(self): 1884 cond = self._CreateTestF32Lt10Computation() 1885 body = self._CreateMulF32By2Computation() 1886 c = self._NewComputation() 1887 init = c.ConstantF32Scalar(1.) 1888 c.While(cond, body, init) 1889 self._ExecuteAndCompareClose(c, expected=16.) 1890 1891 def testWhileF64(self): 1892 cond = self._CreateTestF64Lt10Computation() 1893 body = self._CreateMulF64By2Computation() 1894 c = self._NewComputation() 1895 init = c.ConstantF64Scalar(1.) 1896 c.While(cond, body, init) 1897 self._ExecuteAndCompareClose(c, expected=16.) 1898 1899 def testConditionalTrue(self): 1900 c = self._NewComputation() 1901 pred = c.ConstantPredScalar(True) 1902 true_operand = c.ConstantF32Scalar(3.) 1903 true_computation = self._CreateMulF32By2Computation() 1904 false_operand = c.ConstantF32Scalar(2.) 1905 false_computation = self._CreateConstantF32Computation() 1906 c.Conditional(pred, true_operand, true_computation, false_operand, 1907 false_computation) 1908 self._ExecuteAndCompareClose(c, expected=6.) 1909 1910 def testConditionalFalse(self): 1911 c = self._NewComputation() 1912 pred = c.ConstantPredScalar(False) 1913 true_operand = c.ConstantF32Scalar(3.) 1914 true_computation = self._CreateMulF32By2Computation() 1915 false_operand = c.ConstantF32Scalar(2.) 1916 false_computation = self._CreateConstantF32Computation() 1917 c.Conditional(pred, true_operand, true_computation, false_operand, 1918 false_computation) 1919 self._ExecuteAndCompareClose(c, expected=1.) 1920 1921 def testInfeedS32Values(self): 1922 to_infeed = NumpyArrayS32([1, 2, 3, 4]) 1923 c = self._NewComputation() 1924 c.GetTupleElement(c.Infeed(xla_client.shape_from_pyval(to_infeed[0])), 0) 1925 compiled_c = c.Build().Compile() 1926 for item in to_infeed: 1927 xla_client.transfer_to_infeed(item) 1928 1929 for item in to_infeed: 1930 result = xla_client.execute_with_python_values(compiled_c) 1931 self.assertEqual(result, item) 1932 1933 def testInfeedTuple(self): 1934 to_infeed = (NumpyArrayS32([1, 2, 3, 4]), NumpyArrayS32([[7], [8]])) 1935 c = self._NewComputation() 1936 c.GetTupleElement(c.Infeed(xla_client.shape_from_pyval(to_infeed)), 0) 1937 compiled_c = c.Build().Compile() 1938 xla_client.transfer_to_infeed(to_infeed) 1939 1940 result = xla_client.execute_with_python_values(compiled_c) 1941 np.testing.assert_equal(result[0], to_infeed[0]) 1942 np.testing.assert_equal(result[1], to_infeed[1]) 1943 1944 def testInfeedThenOutfeedS32(self): 1945 to_round_trip = NumpyArrayS32([1, 2, 3, 4]) 1946 c = self._NewComputation() 1947 x_and_token = c.Infeed(xla_client.shape_from_pyval(to_round_trip[0])) 1948 x = c.GetTupleElement(x_and_token, 0) 1949 token = c.GetTupleElement(x_and_token, 1) 1950 c.Outfeed(x, token) 1951 1952 compiled_c = c.Build().Compile() 1953 1954 for want in to_round_trip: 1955 execution = threading.Thread(target=lambda: compiled_c.Execute([])) 1956 execution.start() 1957 xla_client.transfer_to_infeed(want) 1958 got = xla_client.transfer_from_outfeed( 1959 xla_client.shape_from_pyval(to_round_trip[0])) 1960 execution.join() 1961 self.assertEqual(want, got) 1962 1963 def testScatter(self): 1964 a = np.arange(9).astype(np.int32).reshape((3, 3)) 1965 scatter_indices = np.array([0, 2], dtype=np.int32) 1966 updates = np.array([[10, 20, 30], [70, 80, 90]], dtype=np.int32) 1967 1968 dnums = xla_client.ScatterDimensionNumbers() 1969 dnums.update_window_dims.append(1) 1970 dnums.inserted_window_dims.append(0) 1971 dnums.scatter_dims_to_operand_dims.append(0) 1972 dnums.index_vector_dim = 1 1973 1974 c = self._NewComputation() 1975 c.Scatter( 1976 c.Constant(a), c.Constant(scatter_indices), c.Constant(updates), 1977 self._CreateBinaryAddS32Computation(), dnums) 1978 expected = np.array([[10, 21, 32], [3, 4, 5], [76, 87, 98]], dtype=np.int32) 1979 self._ExecuteAndCompareClose(c, expected=expected) 1980 1981 1982class ErrorTest(ComputationTest): 1983 1984 def setUp(self): 1985 self.f32_scalar_2 = NumpyArrayF32(2.0) 1986 self.s32_scalar_2 = NumpyArrayS32(2) 1987 1988 def testCompileWithWrongElementTypeInLayout(self): 1989 c = self._NewComputation() 1990 c.SetOpMetadata(xla_client.CurrentSourceInfoMetadata()) 1991 c.ParameterFromNumpy(self.s32_scalar_2) 1992 c.ClearOpMetadata() 1993 1994 options = xla_client.CompileOptions() 1995 options.argument_layouts = [ 1996 xla_client.Shape.array_shape(np.dtype(np.float32), []) 1997 ] 1998 1999 def TestFun(): 2000 return c.Build().Compile(compile_options=options) 2001 2002 self.assertRaisesRegex( 2003 RuntimeError, r".*Invalid argument shape.*" 2004 r"expected s32\[\], got f32\[\].*", TestFun) 2005 2006 def testInvokeWithWrongElementType(self): 2007 c = self._NewComputation() 2008 c.SetOpMetadata(xla_client.CurrentSourceInfoMetadata()) 2009 c.ParameterFromNumpy(self.s32_scalar_2) 2010 c.ClearOpMetadata() 2011 2012 def TestFun(): 2013 return xla_client.execute_with_python_values(c.Build().Compile(), 2014 [self.f32_scalar_2]) 2015 2016 self.assertRaisesRegex( 2017 RuntimeError, r"Invalid argument: Argument does not match.*" 2018 r"want s32\[\], got f32\[\].*", TestFun) 2019 2020 2021class ComputationRootTest(ComputationTest): 2022 """Tests related to setting the root of the computation.""" 2023 2024 def testComputationRootDifferentFromLastOp(self): 2025 c = self._NewComputation() 2026 x = c.ParameterFromNumpy(NumpyArrayF32(2.0)) 2027 result = c.Add(x, c.ConstantF32Scalar(3.14)) 2028 extra = c.Add(result, c.ConstantF32Scalar(1.618)) # pylint: disable=unused-variable 2029 2030 arg = NumpyArrayF32(1.0) 2031 compiled_c = c.Build(result).Compile() 2032 ans = xla_client.execute_with_python_values(compiled_c, [arg]) 2033 np.testing.assert_allclose(ans, 4.14) 2034 2035 2036class SetShardingTest(ComputationTest): 2037 """Tests related to set OpSharding.""" 2038 2039 def testSetSharding(self): 2040 c = self._NewComputation() 2041 sharding = xla_client.OpSharding() 2042 sharding.type = sharding.type.REPLICATED 2043 sharding.tile_assignment_dimensions.extend([1]) 2044 sharding.tile_assignment_devices.extend([0]) 2045 # Set Sharding. 2046 c.SetSharding(sharding) 2047 x = c.ParameterFromNumpy(NumpyArrayF32(2.0)) 2048 # Clear Sharding. 2049 c.ClearSharding() 2050 2051 result = c.Add(x, c.ConstantF32Scalar(3.14)) 2052 extra = c.Add(result, c.ConstantF32Scalar(1.618)) # pylint: disable=unused-variable 2053 arg = NumpyArrayF32(1.0) 2054 compiled_c = c.Build(result).Compile() 2055 ans = xla_client.execute_with_python_values(compiled_c, [arg]) 2056 np.testing.assert_allclose(ans, 4.14) 2057 2058 2059int_dtypes = [ 2060 np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, 2061 np.uint64 2062] 2063float_dtypes = [np.float16, np.float32, np.float64] 2064complex_dtypes = [np.complex64, np.complex128] 2065dlpack_dtypes = int_dtypes + float_dtypes + [bfloat16] 2066standard_dtypes = int_dtypes + float_dtypes + complex_dtypes + [np.bool_] 2067 2068testcase_shapes = [ 2069 (), 2070 (1,), 2071 (2, 3), 2072 (2, 0), 2073 (0, 7), 2074 (4, 1, 2), 2075 (2, 1, 3), 2076 (2, 4, 1), 2077 (3, 1), 2078 (1, 3), 2079] 2080 2081 2082def FormatShapeAndDtype(shape, dtype): 2083 return "_{}[{}]".format(np.dtype(dtype).name, ",".join(map(str, shape))) 2084 2085 2086class DLPackTest(parameterized.TestCase): 2087 2088 # pylint: disable=g-complex-comprehension 2089 @parameterized.named_parameters({ 2090 "testcase_name": FormatShapeAndDtype(shape, dtype), 2091 "dtype": dtype, 2092 "shape": shape 2093 } for dtype in dlpack_dtypes for shape in testcase_shapes) 2094 def testRoundTrip(self, dtype, shape): 2095 x = np.array(np.random.rand(*shape) * 100, dtype=dtype) 2096 backend = xla_client.get_local_backend() 2097 buffer = xla_client.Buffer.from_pyval(x, backend=backend) 2098 dlt = xla_client._xla.BufferToDLPackManagedTensor(buffer) 2099 del buffer # Free "buffer" to make sure dlt retains ownership. 2100 self.assertEqual(type(dlt).__name__, "PyCapsule") 2101 y = xla_client._xla.DLPackManagedTensorToBuffer(dlt, backend.client) 2102 np.testing.assert_array_equal(x, y.to_py()) 2103 2104 def testTensorsCanBeConsumedOnceOnly(self): 2105 x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32) 2106 backend = xla_client.get_local_backend() 2107 buffer = xla_client.Buffer.from_pyval(x, backend=backend) 2108 dlt = xla_client._xla.BufferToDLPackManagedTensor(buffer) 2109 2110 def ConsumeDLPackTensor(): 2111 _ = xla_client._xla.DLPackManagedTensorToBuffer(dlt, backend.client) 2112 2113 ConsumeDLPackTensor() 2114 self.assertRaisesRegex(RuntimeError, 2115 ".*a DLPack tensor may be consumed at most once.*", 2116 ConsumeDLPackTensor) 2117 2118 2119class BufferProtocolTest(parameterized.TestCase): 2120 2121 # pylint: disable=g-complex-comprehension 2122 @parameterized.named_parameters({ 2123 "testcase_name": FormatShapeAndDtype(shape, dtype), 2124 "dtype": dtype, 2125 "shape": shape 2126 } for dtype in standard_dtypes for shape in testcase_shapes) 2127 def testRoundTrip(self, dtype, shape): 2128 x = np.array(np.random.rand(*shape) * 100, dtype=dtype) 2129 x_ptr = x.__array_interface__["data"][0] 2130 backend = xla_client.get_local_backend("cpu") 2131 buffer = xla_client.Buffer.from_pyval(x, backend=backend) 2132 y = np.array(buffer, copy=False) 2133 y_ptr = y.__array_interface__["data"][0] 2134 np.testing.assert_array_equal(x, y) 2135 # If the input was sufficiently aligned, the input and output should alias. 2136 self.assertTrue((x_ptr & 63) != 0 or x_ptr == y_ptr) 2137 self.assertEqual(y_ptr, buffer.unsafe_buffer_pointer()) 2138 2139 buffer2 = xla_client.Buffer.from_pyval(x, backend=backend, force_copy=True) 2140 z = np.array(buffer2, copy=False) 2141 self.assertNotEqual(x.__array_interface__["data"][0], 2142 z.__array_interface__["data"][0]) 2143 2144 def testDeleteWithActiveView(self): 2145 x = np.random.randn(20, 10) 2146 backend = xla_client.get_local_backend("cpu") 2147 buffer = xla_client.Buffer.from_pyval(x, backend=backend) 2148 buffer_ptr = buffer.unsafe_buffer_pointer() 2149 y = np.array(buffer, copy=False) 2150 buffer.delete() 2151 # It is still legal to access `y`; the array view must keep it alive. 2152 np.testing.assert_array_equal(x, y) 2153 self.assertEqual(y.__array_interface__["data"][0], buffer_ptr) 2154 2155 2156if __name__ == "__main__": 2157 absltest.main() 2158