# Lint as: python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An XLA client in Python."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections
import enum  # pylint: disable=g-bad-import-order
import inspect
import itertools
import os

from absl import logging
import numpy as np

# Note this module does *not* depend on any Python protocol buffers. The XLA
# Python bindings are currently packaged both as part of jaxlib and as part
# of TensorFlow. If we use protocol buffers here, then importing both jaxlib
# and TensorFlow may fail with duplicate protocol buffer message definitions.

from tensorflow.compiler.xla.python import xla_extension as _xla
from tensorflow.compiler.xla.python.xla_extension import ops

# Most functions are snake_case for consistency with other modules, whereas
# method names of ComputationBuilder and Computation are CamelCase for
# consistency with XLA.
# pylint: disable=invalid-name


class Backend(object, metaclass=abc.ABCMeta):
  """Abstract base class for XLA backends."""

  def __init__(self, platform):
    """Creates a new Backend.

    Args:
      platform: A string naming the platform; for example 'gpu'.
    """
    self.platform = platform

  @abc.abstractmethod
  def device_count(self):
    """Returns the number of devices known to the backend."""

  @abc.abstractmethod
  def local_device_count(self):
    """Returns the number of devices local to this host."""

  @abc.abstractmethod
  def devices(self):
    """Returns a list of `device_count()` Device subclasses."""

  @abc.abstractmethod
  def host_id(self):
    """Returns the integer ID of this host."""

  @abc.abstractmethod
  def buffer_from_pyval(self, pyval, device=None, force_copy=False):
    """Allocates a fresh buffer and populates it with `pyval`."""

  @abc.abstractmethod
  def make_tuple(self, c_buffers, device):
    """Makes a tuple from a sequence of backend buffer objects."""

  @abc.abstractmethod
  def compile(self, computation, compile_options):
    """Compiles a computation. Returns an executable."""

  @abc.abstractmethod
  def get_default_device_assignment(self, num_replicas, num_partitions):
    """Returns the default device assignment that `compile` would use.

    If `compile_options.device_assignment` isn't set, `compile` will pick a
    deterministic device assignment based on the number of replicas and
    partitions, possibly optimizing for device locality. This method returns
    that assignment, which is useful for e.g. manually replicating a value
    before passing it to a compiled executable.

    Args:
      num_replicas: the number of replicas needed.
      num_partitions: the number of partitions needed.

    Returns:
      A list of lists of Devices of size `(num_replicas, num_partitions)`.
    """


class LocalBackend(Backend):
  """XLA backend implemented using the in-process xla::LocalClient API."""

  def __init__(self, platform, client):
    """Creates a new LocalBackend.

    Args:
      platform: A string; the user-visible platform name, e.g. 'gpu'.
      client: An _xla.PyLocalClient object.
    """
    super(LocalBackend, self).__init__(platform)
    self.client = client

  def device_count(self):
    return self.client.device_count()

  def local_device_count(self):
    return self.client.local_device_count()

  def devices(self):
    return self.client.devices()

  def local_devices(self):
    return self.client.local_devices()

  def host_id(self):
    return self.client.host_id()

  def buffer_from_pyval(self, pyval, device=None, force_copy=False):
    if device is None:
      device = self.local_devices()[0]
    return _xla.PyLocalBuffer.from_python(pyval, self.client, device,
                                          force_copy)

  def make_tuple(self, c_buffers, device):
    return _xla.PyLocalBuffer.make_tuple(c_buffers, self.client, device)

  def compile(self, c_computation, compile_options):
    options = _xla.ExecutableBuildOptions()
    options.num_replicas = compile_options.num_replicas
    options.num_partitions = compile_options.num_partitions
    if compile_options.result_layout:
      options.result_layout = compile_options.result_layout
    options.debug_options.xla_cpu_fast_math_honor_infs = True
    options.debug_options.xla_cpu_fast_math_honor_nans = True
    options.debug_options.xla_cpu_fast_math_honor_division = True
    options.debug_options.xla_cpu_fast_math_honor_functions = True
    options.debug_options.xla_gpu_enable_fast_min_max = False
    return _xla.LocalExecutable.Compile(c_computation,
                                        compile_options.argument_layouts,
                                        options, self.client,
                                        compile_options.device_assignment)

  def get_default_device_assignment(self, num_replicas, num_partitions=None):
    if num_partitions is not None:
      return self.client.GetDefaultDeviceAssignment(num_replicas,
                                                    num_partitions)
    else:
      # TODO(skye): delete this case after all callers can handle 2D output
      return self.client.GetDefaultDeviceAssignment(num_replicas)


xla_platform_names = {
    'cpu': 'Host',
    'gpu': 'CUDA',
}


def _cpu_backend_factory():
  client = _xla.LocalClient.Get(
      platform='cpu',
      xla_platform_id=xla_platform_names['cpu'],
      asynchronous=True)
  return LocalBackend(platform='cpu', client=client)


def _gpu_backend_factory():
  """Returns a GPU backend. BFC allocator is used by default."""
  allocator = os.getenv('XLA_PYTHON_CLIENT_ALLOCATOR', 'default').lower()
  memory_fraction = os.getenv('XLA_PYTHON_CLIENT_MEM_FRACTION')
  preallocate = os.getenv('XLA_PYTHON_CLIENT_PREALLOCATE')
  if allocator not in ('default', 'platform', 'bfc'):
    raise ValueError(
        'XLA_PYTHON_CLIENT_ALLOCATOR env var must be "default", "platform", '
        'or "bfc", got "%s"' % allocator)
  config = _xla.AllocatorConfig()
  if allocator == 'default':
    config.kind = _xla.AllocatorConfig.Kind.DEFAULT
  if allocator == 'platform':
    config.kind = _xla.AllocatorConfig.Kind.PLATFORM
  if allocator == 'bfc':
    config.kind = _xla.AllocatorConfig.Kind.BFC
  if memory_fraction:
    config.memory_fraction = float(memory_fraction)
  config.preallocate = preallocate not in ('0', 'false', 'False')

  client = _xla.LocalClient.Get(
      platform='gpu',
      xla_platform_id=xla_platform_names['gpu'],
      asynchronous=True,
      allocator_config=config)
  return LocalBackend(platform='gpu', client=client)
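

# Illustrative sketch (not part of the original file): the GPU allocator is
# configured entirely through the environment variables read above, which must
# be set before the GPU backend is first instantiated. For example, to use the
# BFC allocator capped at roughly 75% of GPU memory without preallocation:
#
#   import os
#   os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'bfc'
#   os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.75'
#   os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
#   backend = get_local_backend('gpu')   # defined below; reads the variables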


# Backend factories, keyed by user-visible name, in increasing priority order.
_local_backend_factories = collections.OrderedDict([
    ('cpu', _cpu_backend_factory),
    ('gpu', _gpu_backend_factory),
])


def register_local_backend_factory(name, factory):
  _local_backend_factories[name] = factory


_local_backends = None


def _get_local_backends():
  """Instantiates all known local backends."""
  global _local_backends
  if _local_backends is not None:
    return _local_backends

  _local_backends = collections.OrderedDict()
  for name, factory in _local_backend_factories.items():
    logging.vlog(2, "Initializing backend '%s'" % name)
    try:
      backend = factory()
    except RuntimeError:
      if name == 'cpu':
        # We always expect CPU to initialize successfully.
        raise
      else:
        # If the backend isn't built into the binary, or if it has no devices,
        # we expect a RuntimeError.
        continue
    _local_backends[name] = backend
  return _local_backends


def get_local_backend(name=None):
  """Returns a local backend.

  Args:
    name: the backend name. If `None`, a default local backend is returned,
      typically `gpu` if one is present, or `cpu` if not. If a string, the
      named backend is returned or an exception raised.

  Returns:
    A LocalBackend object.
  """
  backends = _get_local_backends()
  if name is not None:
    try:
      return backends[name]
    except KeyError:
      raise RuntimeError('Unknown backend {}'.format(name))

  return list(backends.values())[-1]
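

# Illustrative sketch (not part of the original file): typical use of the
# backend registry. With no argument, get_local_backend() returns the
# highest-priority backend that initialized successfully (GPU over CPU).
#
#   backend = get_local_backend()        # or get_local_backend('cpu')
#   print(backend.platform, backend.device_count())
#   for device in backend.local_devices():
#     print(device)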


class OpMetadata(object):
  """Python representation of a xla.OpMetadata protobuf."""
  __slots__ = ('op_type', 'op_name', 'source_file', 'source_line')

  def __init__(self, op_type='', op_name='', source_file='', source_line=0):
    self.op_type = op_type
    self.op_name = op_name
    self.source_file = source_file
    self.source_line = source_line


def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1):
  """Helper for use in source mapping that returns an OpMetadata object."""
  full_filename, lineno = inspect.stack()[skip_frames][1:3]
  filename = os.path.basename(full_filename)
  return OpMetadata(
      op_type=op_type,
      op_name=op_name,
      source_file=filename,
      source_line=lineno)


PrimitiveType = _xla.PrimitiveType

bfloat16 = _xla.bfloat16_dtype()

XLA_ELEMENT_TYPE_TO_DTYPE = {
    PrimitiveType.PRED: np.dtype('bool'),
    PrimitiveType.S8: np.dtype('int8'),
    PrimitiveType.S16: np.dtype('int16'),
    PrimitiveType.S32: np.dtype('int32'),
    PrimitiveType.S64: np.dtype('int64'),
    PrimitiveType.U8: np.dtype('uint8'),
    PrimitiveType.U16: np.dtype('uint16'),
    PrimitiveType.U32: np.dtype('uint32'),
    PrimitiveType.U64: np.dtype('uint64'),
    PrimitiveType.BF16: np.dtype(bfloat16),
    PrimitiveType.F16: np.dtype('float16'),
    PrimitiveType.F32: np.dtype('float32'),
    PrimitiveType.F64: np.dtype('float64'),
    PrimitiveType.C64: np.dtype('complex64'),
    PrimitiveType.C128: np.dtype('complex128'),
    PrimitiveType.TUPLE: np.dtype(np.object),
    PrimitiveType.TOKEN: np.dtype(np.object),
}

# Note the conversion on the key. Numpy has a known issue wherein dtype hashing
# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus,
# when keying by dtype in this dict, we use the string form of dtypes.
DTYPE_TO_XLA_ELEMENT_TYPE = {
    str(dt): et for et, dt in XLA_ELEMENT_TYPE_TO_DTYPE.items()
}


def dtype_to_etype(dtype):
  """Convenience function for reading DTYPE_TO_XLA_ELEMENT_TYPE."""
  return DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))]
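

# Illustrative sketch (not part of the original file): DTYPE_TO_XLA_ELEMENT_TYPE
# is keyed by the *string* form of a dtype, so any dtype-like value works once
# it is normalized through np.dtype:
#
#   dtype_to_etype(np.float32)   # -> PrimitiveType.F32
#   dtype_to_etype('int32')      # -> PrimitiveType.S32
#   XLA_ELEMENT_TYPE_TO_DTYPE[PrimitiveType.BF16]   # -> the bfloat16 dtype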


Shape = _xla.Shape
Shape.__doc__ = """
A Shape is an object defined in C++ that duck types like the following class:

class Shape(object):
  '''Represents an XLA shape.

  A shape is either an array shape, having rank-many integer
  dimensions and an element type (represented by a Numpy dtype), or it
  is a tuple shape, having a shape for every tuple component:

    type shape =
        TupleShape of shape list
      | ArrayShape of { dimensions: int list; element_type: dtype }
  '''

  @staticmethod
  def tuple_shape(tuple_shapes) -> Shape:
    "Construct a tuple shape."

  @staticmethod
  def array_shape(element_type, dimensions, minor_to_major=None) -> Shape:

  @staticmethod
  def from_pyval(pyval) -> Shape:
    "Returns a Shape that describes a tuple-tree of Numpy arrays."

  def __init__(self, str) -> Shape:
    "Parses a shape string."
  def __eq__(self, other: Shape) -> bool:
  def __ne__(self, other: Shape) -> bool:
  def __hash__(self):
  def __repr__(self):
  def is_tuple(self) -> bool:
  def is_array(self) -> bool:
  def tuple_shapes(self) -> [Shape]:
  def numpy_dtype(self) -> np.dtype:
    "Like element_type(), but returns dtype('O') for a tuple shape."
  def xla_element_type(self) -> PrimitiveType:
  def element_type(self) -> np.dtype:
  def dimensions(self) -> (int, int, ...):
  def rank(self) -> int:
  def with_major_to_minor_layout_if_absent(self) -> Shape:
    "Returns a copy with missing layouts set to major-to-minor."

  def to_serialized_proto(self) -> bytes:
    "Returns 'shape' as a serialized proto."
"""

ProgramShape = _xla.ProgramShape
ProgramShape.__doc__ = """
A ProgramShape is a C++ object that duck types like the following class.

class ProgramShape(object):
  def __init__(self, parameter_shapes, result_shape):
  def parameter_shapes(self) -> [Shape]:
  def result_shape(self) -> Shape:
  def __repr__(self):
"""


class Buffer(object):
  """Represents a handle to data owned by XLA.

  The referent is ready for use in executing a local, compiled
  Computation. On XLA platforms involving a device (e.g. GPU), this
  means the referent is in device memory.
  """

  @staticmethod
  def from_pyval(pyval, device=None, backend=None, force_copy=False):
    """Copies the `pyval` to a freshly allocated on-device buffer."""
    backend = backend or get_local_backend()
    return backend.buffer_from_pyval(pyval, device, force_copy=force_copy)

  @staticmethod
  def make_tuple(buffers, device, backend=None):
    backend = backend or get_local_backend()
    return backend.make_tuple(buffers, device)

  # Buffer is not an instantiable type and exists only for its static methods.
  # The underlying buffer objects are C++ objects with the following
  # API:
  # def shape(self) -> Shape:
  # def device(self) -> int:
  # def delete(self):
  # def destructure(self) -> [Buffer]
  # def is_deleted(self) -> bool:
  # def block_host_until_ready(self):
  #   """Blocks the calling thread until the buffer is ready on device."""
  # def copy_to_host_async(self):
  #   """Requests a copy of the buffer to the host.
  #
  #   Does not block waiting for the copy. Values fetched are available via
  #   `to_py()`; the purpose of `copy_to_host_async` is to prefetch values
  #   for subsequent `to_py()` calls, especially when requesting many values
  #   at once.
  #   """
  # def to_py(self):
  #   """Returns the value of the buffer as a Python tuple tree of ndarrays."""
  #
  # TODO(phawkins): remove Buffer and its static methods completely, have
  # clients call methods on Backend to create buffers.


# TODO(phawkins): Alias for backward compatibility. Remove after JAX drops
# compatibility with Jaxlib versions older than 0.1.13.
LocalBuffer = Buffer


def shape_from_pyval(pyval):
  """Returns a Shape that describes a tuple-tree of Numpy arrays."""

  def convert(pyval):
    if isinstance(pyval, tuple):
      return Shape.tuple_shape(tuple(convert(elt) for elt in pyval))
    else:
      return Shape.array_shape(pyval.dtype, np.shape(pyval))

  return convert(pyval)
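

# Illustrative sketch (not part of the original file): moving a host ndarray
# onto the default device and reading it back.
#
#   x = np.arange(12, dtype=np.float32).reshape((3, 4))
#   shape_from_pyval(x)          # an array Shape: element type F32, dims (3, 4)
#   buf = Buffer.from_pyval(x)   # device-resident copy of `x`
#   buf.to_py()                  # ndarray equal to `x`
#   buf.delete()                 # eagerly releases the device memory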


def transfer_to_infeed(value, device=None):
  """Transfers the given value into the XLA infeed queue.

  XLA's infeed queue is a single queue that feeds the "XLA virtual machine"
  with a totally ordered stream of values. This is dequeued from XLA
  computations via the Infeed() operation.

  Args:
    value: the value that the caller would like to enqueue into the XLA infeed
      queue
    device: the device to infeed the value to. Each device has a distinct
      infeed queue.
  """
  # TODO(phawkins): support non-default backends.
  backend = get_local_backend()
  device = device or backend.local_devices()[0]
  device.TransferToInfeed(value)


def transfer_from_outfeed(shape, device=None):
  """Transfers a literal of the given shape from `device`'s outfeed.

  Args:
    shape: The shape of the value to transfer from outfeed.
    device: The device from which to transfer the outfeed value. Each device
      has a distinct outfeed queue.

  Returns:
    The literal value that is produced from the outfeed queue.
  """
  # TODO(phawkins): support non-default backends.
  backend = get_local_backend()
  device = device or backend.local_devices()[0]
  return device.TransferFromOutfeed(
      shape.with_major_to_minor_layout_if_absent())


DeviceAssignment = _xla.DeviceAssignment
DeviceAssignment.__doc__ = """
A DeviceAssignment is a C++ object with the following signature.

def create(assignment):
  '''Builds a device assignment.

  Args:
    assignment: a 2D numpy array of device ordinal integers, indexed by
      [replica][computation_in_replica].

  Returns:
    A device assignment.
  '''

def replica_count():
  '''Returns the number of replicas.'''

def computation_count():
  '''Returns the number of computations per replica.'''
"""


Device = _xla.Device


class CompileOptions(object):
  """Python object for XLA compile options.

  These options can be passed to the 'compile' step when using a local XLA
  client.
  """

  def __init__(self):
    self.xla_dump_to = None
    self.dump_hlo_pass_re = None
    self.dump_hlo_module_re = None
    self.dump_hlo_as_text = None
    self.dump_hlo_as_proto = None
    self.hlo_profile = None
    self.num_replicas = 1
    self.num_partitions = 1
    self.argument_layouts = None
    self.result_layout = None
    self.device_assignment = None
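

# Illustrative sketch (not part of the original file): requesting a replicated
# compilation. `computation` stands in for a Computation built with the
# ComputationBuilder defined below, and at least two local devices are assumed.
#
#   options = CompileOptions()
#   options.num_replicas = 2
#   executable = computation.Compile(compile_options=options)
#   execute_with_python_values_replicated(
#       executable, [[np.float32(1)], [np.float32(2)]])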


class Computation(object):
  """Python wrapper for an XLA Computation.

  A Computation can be compiled to form an Executable, or used as a
  subcomputation in ComputationBuilder methods.
  """

  def __init__(self, c_computation, backend=None):
    self._c_computation = c_computation
    # The backend argument is deprecated. Pass a backend to Compile() instead.
    self._backend = backend

  @property
  def computation(self):
    return self._c_computation

  def GetSerializedProto(self):
    """Gets the serialized HloModuleProto proto object in this computation.

    Returns:
      A string containing a serialized HloModuleProto proto containing the
      computation and its dependencies.
    """
    return self.computation.GetSerializedProto()

  def GetHloText(self):
    """Get the textual HLO representation of this computation.

    Returns:
      A string containing the textual HLO.
    """
    return self.computation.GetHloText()

  def GetHloDotGraph(self):
    """Get a Graphviz Dot representation of this computation.

    Returns:
      A string containing the graphviz dot graph.
    """
    return self.computation.GetHloDotGraph()

  def Compile(self, argument_shapes=None, compile_options=None, backend=None):
    """Compiles a computation.

    Computations are the result of a "ComputationBuild'ing" process.

    Arguments:
      argument_shapes: Deprecated. Use compile_options.argument_layouts
        instead.
      compile_options: options to use for compilation, includes an optional
        laid out result shape for the computation.
      backend: a `Backend` for which an executable should be generated.

    Returns:
      An Executable instance.
    """
    backend = backend or self._backend or get_local_backend()

    compile_options = compile_options or CompileOptions()
    if argument_shapes:
      compile_options.argument_layouts = argument_shapes
    return backend.compile(self.computation, compile_options)

  def GetProgramShape(self):
    return self._c_computation.GetProgramShape()

  def GetReturnValueShape(self):
    return self._c_computation.GetProgramShape().result_shape()

  def Hash(self):
    return self._c_computation.Hash()


# An Executable is a C++ class that duck types with the following API:
# class Executable(object):
#   def local_devices(self) -> [Device]:
#   def Execute(self, arguments: [Buffer]) -> Buffer:
#     """Execute on one replica with Buffer arguments and return value."""
#
#   def SizeOfGeneratedCodeInBytes(self) -> int:
#     """Return generated binary size, or -1 if not known."""
#
#   def ExecutePerReplica(self, arguments: [[Buffer]]) -> [Buffer]:
#     """Execute on many replicas with Buffer arguments and return value.
#
#     Args:
#       arguments: A sequence of sequences of Buffers. The i'th inner sequence
#         comprises the arguments for execution on the i'th replica.
#
#     Returns:
#       A list of the computation's outputs for each replica, as a Buffer. If
#       a shallow sequence of arguments was passed in for `arguments`, then
#       the sole, zero'th replica's output is returned instead, as a Buffer.
#     """
#
# There are different implementations of Executable for different backends.


def execute_with_python_values(executable, arguments=(), backend=None):
  """Execute on one replica with Python values as arguments and output."""

  backend = backend or get_local_backend()

  def put(arg):
    return Buffer.from_pyval(
        arg, device=executable.local_devices()[0], backend=backend)

  arguments = [put(arg) for arg in arguments]
  return executable.Execute(arguments).to_py()


def execute_with_python_values_replicated(executable, arguments, backend=None):
  """Execute on many replicas with Python values as arguments and output.

  Arguments:
    executable: the program to run.
    arguments: a list of lists of Python values indexed by `[replica][arg_num]`
      to pass as inputs.
    backend: the backend we are targeting.

  Returns:
    A list of python values, one per replica.
  """
  backend = backend or get_local_backend()
  devices = executable.local_devices()
  # pylint: disable=g-complex-comprehension
  flat_args = [(arg, devices[replica])
               for replica, replica_args in enumerate(arguments)
               for arg in replica_args]
  flat_arg_buffers = [
      backend.buffer_from_pyval(pyval, device) for pyval, device in flat_args
  ]
  arg_buffers = []
  for replica_args in arguments:
    arg_buffers.append(flat_arg_buffers[:len(replica_args)])
    flat_arg_buffers = flat_arg_buffers[len(replica_args):]
  return [out.to_py() for out in executable.ExecutePerReplica(arg_buffers)]


class PaddingType(enum.Enum):
  VALID = 1
  SAME = 2


def _convert_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims,
                                        window_strides):
  """Maps PaddingType or string to pad values (list of pairs of ints)."""
  if not isinstance(padding_type, (str, PaddingType)):
    msg = 'padding_type must be str or PaddingType, got {}.'
    raise TypeError(msg.format(type(padding_type)))

  if isinstance(padding_type, str):
    if padding_type.upper() == 'VALID':
      padding_type = PaddingType.VALID
    elif padding_type.upper() == 'SAME':
      padding_type = PaddingType.SAME
    else:
      msg = 'Unknown padding type string: expected "VALID" or "SAME", got {}.'
      raise ValueError(msg.format(padding_type))

  if padding_type == PaddingType.VALID:
    return [(0, 0)] * len(window_strides)
  elif padding_type == PaddingType.SAME:
    out_shape = np.ceil(np.true_divide(lhs_dims, window_strides)).astype(int)
    pad_sizes = [
        max((out_size - 1) * stride + filter_size - in_size, 0)
        for out_size, stride, filter_size, in_size in zip(
            out_shape, window_strides, rhs_dims, lhs_dims)
    ]
    return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes]
  else:
    msg = 'Unexpected PaddingType value: {}'
    raise ValueError(msg.format(padding_type))
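

# Illustrative sketch (not part of the original file): for SAME padding the
# output size along a dimension is ceil(input / stride), and the total padding
# is max((out - 1) * stride + filter - input, 0), split between low and high.
# With an input dimension of 10, a filter size of 3, and a stride of 2:
#
#   _convert_padding_type_to_pad_values(PaddingType.SAME, [10], [3], [2])
#   # out = ceil(10 / 2) = 5; pad = (5 - 1) * 2 + 3 - 10 = 1  ->  [(0, 1)]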


class ComputationBuilder(object):
  """XLA computation builder.

  Enqueues XLA ops in sequence and in order to build a
  Computation, which in turn can be compiled into a
  LocalExecutable, which in turn can be locally executed.
  """

  # The methods of this class map 1-to-1 onto the XLA C++
  # computation builder API. Therefore, there's no need to laboriously list
  # arguments and return values for every method, especially where it's
  # obvious.
  #
  # pylint: disable=g-doc-return-or-yield
  # pylint: disable=g-doc-args

  def __init__(self, name):
    self._builder = _xla.XlaBuilder(name)
    self._parameter_numbering = itertools.count()

  def Build(self, root=None, backend=None):
    """Builds a `Computation` from the contents of the builder.

    Args:
      root: if not None, the operator containing the return value of the
        computation.

    Returns:
      A `Computation`.
    """
    if root is not None:
      return Computation(self._builder.Build(root), backend=backend)
    else:
      return Computation(self._builder.Build(), backend=backend)

  def GetShape(self, operand):
    return self._builder.GetShape(operand)

  def SetOpMetadata(self, op_metadata):
    """Set metadata for operations that are about to be enqueued."""
    self._builder.SetOpMetadata(op_metadata)

  def ClearOpMetadata(self):
    """Clear metadata for operations that are about to be enqueued."""
    self._builder.ClearOpMetadata()

  def SetSharding(self, sharding):
    """Set sharding that will be attached to all instructions until cleared."""
    self._builder.SetSharding(sharding)

  def ClearSharding(self):
    """Clears the sharding.

    Ops will be sharded according to the default placement policy.
    """
    self._builder.ClearSharding()

  def CreateToken(self):
    """Enqueues a CreateToken op onto the computation.

    Returns:
      An XlaOp, representing a fresh token.
    """
    return ops.CreateToken(self._builder)

  def AfterAll(self, tokens):
    """Enqueues an after-all op onto the computation.

    `AfterAll` takes a variadic number of tokens and produces a single token.

    Args:
      tokens: a list of `XlaOp` values representing predecessor tokens.

    Returns:
      An `XlaOp`.
    """
    return ops.AfterAll(self._builder, tokens)

  def Infeed(self, shape, token=None):
    """Enqueues an infeed op onto the computation.

    Infeed operations dequeue data of the given shape from the device's infeed
    queue for subsequent use in the computation.

    Args:
      shape: a `Shape` describing the shape of the infeed value.
      token: an optional `XlaOp` representing a token after which the infeed
        effect should be sequenced.

    Returns:
      An XlaOp, representing a (value, token) pair.
    """
    if token is None:
      token = ops.CreateToken(self._builder)
    return ops.InfeedWithToken(token,
                               shape.with_major_to_minor_layout_if_absent())

  def Outfeed(self, operand, token=None):
    """Enqueues an outfeed op onto the computation.

    Outfeed operations enqueue data, using the given operand, onto the XLA
    outfeed queue for subsequent dequeue via the client API.

    Args:
      operand: an `XlaOp` representing the data to outfeed.
      token: an `XlaOp` representing a token after which the outfeed should be
        sequenced.

    Returns:
      An `XlaOp` representing a token.
    """
    if token is None:
      token = ops.CreateToken(self._builder)
    return ops.OutfeedWithToken(operand, token, self._builder.GetShape(operand),
                                '')

  def Constant(self, value):
    """Enqueues a constant op onto the computation.

    Args:
      value: value for the constant, as a np.array with an explicit dtype set
        to one of the supported types.

    Returns:
      An XlaOp.
    """
    return ops.ConstantLiteral(self._builder, value)

  def ConstantF32Scalar(self, value):
    """Convenience method to enqueue a scalar F32 constant op.

    Args:
      value: a floating-point number.

    Returns:
      An XlaOp.
    """
    return self.Constant(np.array(value, dtype=np.float32))

  def ConstantF64Scalar(self, value):
    """Convenience method to enqueue a scalar F64 constant op.

    Args:
      value: a floating-point number.

    Returns:
      An XlaOp.
    """
    return self.Constant(np.array(value, dtype=np.float64))

  def ConstantS32Scalar(self, value):
    """Convenience method to enqueue a scalar S32 constant op.

    Args:
      value: an integer.

    Returns:
      An XlaOp.
    """
    return self.Constant(np.array(value, dtype=np.int32))

  def ConstantS64Scalar(self, value):
    """Convenience method to enqueue a scalar S64 constant op.

    Args:
      value: an integer.

    Returns:
      An XlaOp.
    """
    return self.Constant(np.array(value, dtype=np.int64))

  def ConstantPredScalar(self, value):
    """Convenience method to enqueue a scalar PRED constant op.

    Args:
      value: a boolean value.

    Returns:
      An XlaOp.
    """
    return self.Constant(np.array(value, dtype=np.bool))

  def ParameterWithShape(self, shape, name=None, parameter_num=None,
                         replicated=False):
    """Enqueues a Parameter op onto the computation, given a shape.

    Args:
      shape: the parameter's shape as a Shape object.
      name: optional string name for the parameter.
      parameter_num: parameter number in the computation function. If None, the
        next linear parameter number is used. The default value capability can
        be used for auto-numbering. If you're using auto-numbering for some
        parameters, use it for *all* parameters to avoid clashes.
      replicated: whether to mark the parameter's leaves as replicated. May be
        a bool, in which case it applies to all leaves, or an iterable of
        bools.

    Returns:
      An XlaOp.
    """
    if name is None:
      name = ''
    if parameter_num is None:
      parameter_num = next(self._parameter_numbering)
    if isinstance(replicated, bool):
      replicated = [replicated] * shape.leaf_count()

    return ops.Parameter(self._builder, parameter_num,
                         shape.with_major_to_minor_layout_if_absent(),
                         name.encode('utf8'), replicated)

  def ParameterFromNumpy(self, value, name=None, parameter_num=None):
    """Enqueues a Parameter op onto the computation.

    Args:
      value: a Numpy array, or a nested tuple thereof, from which the shape is
        inferred.
      name: as in ParameterWithShape.
      parameter_num: as in ParameterWithShape.

    Returns:
      An XlaOp.
    """
    return self.ParameterWithShape(
        shape_from_pyval(value), name=name, parameter_num=parameter_num)
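
  # Illustrative sketch (not part of the original file): a minimal end-to-end
  # use of the builder, assuming a default local backend is available.
  #
  #   builder = ComputationBuilder('add_one')
  #   x = builder.ParameterFromNumpy(np.float32(0))    # scalar f32 parameter
  #   builder.Add(x, builder.ConstantF32Scalar(1.0))
  #   computation = builder.Build()
  #   executable = computation.Compile()
  #   execute_with_python_values(executable, [np.float32(41)])   # -> 42.0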

  def Iota(self, dtype, size):
    """Enqueues an iota constant onto the computation.

    Args:
      dtype: expected numpy dtype of the output.
      size: integer, the number of elements in the array.

    Returns:
      An XlaOp representing the added iota constant.
    """
    element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))]
    return ops.Iota(self._builder, element_type, size)

  def BroadcastedIota(self, dtype, shape, dimension):
    """Enqueues a broadcasted iota constant onto the computation.

    Args:
      dtype: expected numpy dtype of the output.
      shape: tuple of integers, the expected output shape (dimensions).
      dimension: positive integer, dimension along which to increment values.

    Returns:
      An XlaOp representing the added broadcasted iota constant.
    """
    element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))]
    xla_shape = _xla.Shape.array_shape(element_type, shape, None)
    return ops.Iota(self._builder, xla_shape, dimension)

  def Concatenate(self, operands, dimension):
    """Enqueues a concatenate operation onto the computation.

    Args:
      operands: the operands to concatenate.
      dimension: the dimension in which to perform the concatenation.

    Returns:
      An XlaOp representing the added concatenate op.
    """
    return ops.ConcatInDim(self._builder, list(operands), dimension)

  def ReplicaId(self):
    """Enqueues a ReplicaId operation onto the computation.

    Returns:
      An XlaOp representing the replica id.
    """
    return _xla.ops.ReplicaId(self._builder)

  def Pad(self, operand, padding_value, padding_config):
    """Enqueues a Pad operation onto the computation.

    Args:
      operand: XlaOp representing the array to pad.
      padding_value: XlaOp representing the scalar pad value.
      padding_config: either a PaddingConfig or a list of integer triples
        (edge_padding_low, edge_padding_high, interior_padding) representing
        the configuration of the padding operation.

    Returns:
      An XlaOp representing the added Pad op.
    """
    if isinstance(padding_config, tuple) or isinstance(padding_config, list):
      padding_config = GetPaddingConfigFromTriples(padding_config)
    return ops.Pad(operand, padding_value, padding_config)

  def Reshape(self, operand, dimensions, new_sizes):
    """Enqueues a reshape op onto the computation.

    Args:
      operand: XlaOp representing the array to be reshaped.
      dimensions: sequence of integers encoding the order in which dimensions
        are collapsed or None, in which case dimensions are flattened in order.
      new_sizes: sequence of integers encoding the new dimension sizes (shape).

    Returns:
      An XlaOp representing the added Reshape op.
    """
    if dimensions is None:
      ndim = len(self.GetShape(operand).dimensions())
      dimensions = tuple(range(ndim))
    return ops.Reshape(operand, dimensions, new_sizes)
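
  # Illustrative sketch (not part of the original file): Pad accepts the
  # padding configuration as plain (low, high, interior) triples, one per
  # dimension.
  #
  #   builder = ComputationBuilder('pad_and_reshape')
  #   x = builder.ParameterFromNumpy(np.zeros((2, 3), np.float32))
  #   zero = builder.ConstantF32Scalar(0.0)
  #   padded = builder.Pad(x, zero, [(1, 1, 0), (0, 0, 0)])   # -> f32[4, 3]
  #   builder.Reshape(padded, dimensions=None, new_sizes=[12])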

  def AllReduce(self, operand, computation, replica_groups=None):
    """AllReduce op.

    Args:
      operand: XlaOp representing the input array
      computation: a Computation object - binary reduction function.
      replica_groups: optional, list of lists of ints encoding a partition of
        the set {0, 1, ..., num_replicas} into equally-sized replica groups
        within which the all-reduce is performed. If not supplied or None (the
        default), all replicas belong to the same group.

    Returns:
      An XlaOp that represents the all-reduced result.
    """
    replica_groups_protos = _get_replica_groups_protos(replica_groups)
    return ops.AllReduce(operand, computation.computation,
                         replica_groups_protos, None, None)

  def AllToAll(self,
               operand,
               split_dimension,
               concat_dimension,
               replica_groups=None):
    """AllToAll op.

    Args:
      operand: XlaOp representing the input array
      split_dimension: the dimension along which the operand is split
      concat_dimension: the dimension along which the split blocks are
        concatenated
      replica_groups: optional, list of lists of ints encoding a partition of
        the set {0, 1, ..., num_replicas} into equally-sized replica groups
        within which the all-to-all is performed. If not supplied or None (the
        default), all replicas belong to the same group.

    Returns:
      An XlaOp that represents the all-to-all concatenation.
    """
    replica_groups_protos = _get_replica_groups_protos(replica_groups)
    if not replica_groups:
      split_count = 1
    else:
      split_count = len(replica_groups[0])
      if not all(split_count == len(g) for g in replica_groups):
        raise ValueError('Replica groups must be equally sized')
    return ops.AllToAll(operand, split_dimension, concat_dimension, split_count,
                        replica_groups_protos)

  def CrossReplicaSum(self, operand, replica_groups=None):
    """CrossReplicaSum op.

    Args:
      operand: the operand to sum across replica instances.
      replica_groups: optional, list of lists of ints encoding a partition of
        the set {0, 1, ..., num_replicas} into equally-sized replica groups
        within which the cross-replica sum is performed. If not supplied or
        None (the default), all replicas belong to the same group.

    Returns:
      An XlaOp that represents on each replica the sum of its group's values.
    """
    replica_groups_protos = _get_replica_groups_protos(replica_groups)
    return ops.CrossReplicaSum(operand, replica_groups_protos)

  def Trans(self, operand):
    """Specialized matrix transpose op."""
    return ops.Transpose(operand, [1, 0])

  def Transpose(self, operand, permutation):
    """Transpose op."""
    return ops.Transpose(operand, permutation)

  def SelectAndScatter(self, operand, select, window_dimensions, window_strides,
                       padding, source, init_value, scatter):
    """Select and scatter op, used by the gradient of ReduceWindow.

    Args:
      operand: XlaOp for array of dimension N and type T over which the windows
        slide.
      select: Computation of type (T, T) -> Pred to apply to the elements of
        each window to indicate which element is selected.
      window_dimensions: sequence of N integers for dimensions of the window.
      window_strides: sequence of N integers for the strides of the window.
      padding: PaddingType representing either 'SAME' or 'VALID' padding.
      source: XlaOp for array of type T with values to scatter.
      init_value: XlaOp of scalar type T for initial out value.
      scatter: Computation of type (T, T) -> T to apply to each scatter source
        element with its destination element.

    Returns:
      An XlaOp representing the added SelectAndScatter op.
    """
    pads = _convert_padding_type_to_pad_values(
        padding,
        self.GetShape(operand).dimensions(), window_dimensions, window_strides)
    return ops.SelectAndScatterWithGeneralPadding(operand, select.computation,
                                                  window_dimensions,
                                                  window_strides, pads, source,
                                                  init_value,
                                                  scatter.computation)

  def Slice(self, operand, start_indices, limit_indices, strides=None):
    """Enqueues a slice operation onto the computation.

    Args:
      operand: XlaOp for the N dimensional array to be sliced.
      start_indices: iterable of N integers containing the starting indices of
        the slice for each dimension.
      limit_indices: iterable of N integers containing the ending indices
        (exclusive) of the slice for each dimension.
      strides: optional iterable of N integers containing the stride sizes for
        each dimension.

    Returns:
      An XlaOp representing the added Slice op.
    """
    if strides is None:
      start_indices = list(start_indices)
      strides = [1] * len(start_indices)
    return ops.Slice(operand, start_indices, limit_indices, strides)

  def DynamicSlice(self, operand, start_indices, slice_sizes):
    """Enqueues a slice op with dynamic start indices onto the computation.

    Args:
      operand: XlaOp for the N dimensional array to be sliced.
      start_indices: XlaOp for the 1D array of N integers containing the
        starting indices of the slice.
      slice_sizes: iterable of N integers containing the slice sizes in each
        dimension.

    Returns:
      An XlaOp representing the added DynamicSlice op.
    """
    slice_sizes = list(slice_sizes)
    if isinstance(start_indices, _xla.XlaOp):
      start_indices = [
          ops.Reshape(ops.Slice(start_indices, [i], [i + 1], [1]), [])
          for i in range(len(slice_sizes))
      ]
    return ops.DynamicSlice(operand, list(start_indices), slice_sizes)

  def DynamicUpdateSlice(self, operand, update, start_indices):
    """Enqueues a dynamic update slice operation onto the computation.

    Args:
      operand: XlaOp for the N dimensional array to be updated.
      update: N dimensional array comprising the slice update.
      start_indices: Rank-1 array of N integers comprising the starting indices
        of the slice along each dimension.

    Returns:
      An XlaOp representing the added DynamicUpdateSlice op.
    """
    if isinstance(start_indices, _xla.XlaOp):
      ndims = self._builder.GetShape(start_indices).dimensions()[0]
      start_indices = [
          ops.Reshape(ops.Slice(start_indices, [i], [i + 1], [1]), [])
          for i in range(ndims)
      ]
    return ops.DynamicUpdateSlice(operand, update, list(start_indices))
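
  # Illustrative sketch (not part of the original file): static vs. dynamic
  # slicing of a rank-1 array.
  #
  #   builder = ComputationBuilder('slices')
  #   x = builder.ParameterFromNumpy(np.arange(8, dtype=np.float32))
  #   builder.Slice(x, [2], [6])                         # elements 2..5
  #   starts = builder.Constant(np.array([3], dtype=np.int32))
  #   builder.DynamicSlice(x, starts, [4])               # 4 elements from index 3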

  def Tuple(self, *elems):
    """Enqueues a tuple operation onto the computation.

    Args:
      elems: a sequence of tuple operands (each a XlaOp).

    Returns:
      An XlaOp representing the added Tuple op.
    """
    return ops.Tuple(self._builder, list(elems))

  def Call(self, computation_to_apply, operands):
    """Enqueues a call operation onto the computation.

    Args:
      computation_to_apply: a Computation object.
      operands: an iterable of XlaOp. The number and types of operands must
        match the arity of computation_to_apply.

    Returns:
      An XlaOp representing the added call op.
    """
    return ops.Call(self._builder, computation_to_apply.computation,
                    list(operands))

  def CustomCall(self,
                 call_target_name,
                 operands,
                 shape_with_layout,
                 operand_shapes_with_layout,
                 opaque=None):
    """Enqueues a custom call operation onto the computation.

    Args:
      call_target_name: the name of the function to call.
      operands: an iterable of XlaOp. The number and types of operands must
        match the arity of `operand_shapes_with_layout`.
      shape_with_layout: the shape of the operator's output, with layout.
      operand_shapes_with_layout: the shapes of `operands`, including the
        expected layouts.
      opaque: an opaque string passed to the backend.

    Returns:
      An XlaOp representing the added custom call op.
    """
    opaque = opaque or b''
    return ops.CustomCall(self._builder, call_target_name,
                          list(operands), shape_with_layout,
                          list(operand_shapes_with_layout), opaque)

  def Map(self, operands, computation_to_apply, dimensions):
    """Enqueues a map operation onto the computation.

    Args:
      operands: an iterable of XlaOp.
      computation_to_apply: a Computation object.
      dimensions: dimensions over which to apply the map function.

    Returns:
      An XlaOp representing the added Map op.
    """
    return ops.Map(self._builder, list(operands),
                   computation_to_apply.computation, dimensions, [])

  def Reduce(self, operand, init_value, computation_to_apply, dimensions):
    """Enqueues a reduction operation onto the computation.

    Args:
      operand: reduction operand (XlaOp).
      init_value: reduction initial value (XlaOp).
      computation_to_apply: a Computation object - binary reduction function.
      dimensions: sequence of dimensions (integers) to reduce on.

    Returns:
      An XlaOp representing the added Reduce op.
    """
    return ops.Reduce(self._builder, [operand], [init_value],
                      computation_to_apply.computation, dimensions)

  def ReduceWindow(self, operand, init_value, computation_to_apply,
                   window_dimensions, window_strides, padding):
    """Enqueues a windowed reduction operation onto the computation.

    Args:
      operand: reduction operand (XlaOp).
      init_value: reduction initial value (XlaOp).
      computation_to_apply: a binary reduction function (Computation).
      window_dimensions: dimensions of window (sequence of integers).
      window_strides: strides for window (sequence of integers).
      padding: PaddingType representing either 'SAME' or 'VALID' padding.

    Returns:
      An XlaOp representing the added ReduceWindow op.
    """
    pads = _convert_padding_type_to_pad_values(
        padding,
        self.GetShape(operand).dimensions(), window_dimensions, window_strides)
    return ops.ReduceWindowWithGeneralPadding(operand, init_value,
                                              computation_to_apply.computation,
                                              window_dimensions, window_strides,
                                              (), (), pads)

  def ReduceWindowWithGeneralPadding(self, operand, init_value,
                                     computation_to_apply, window_dimensions,
                                     window_strides, base_dilations,
                                     window_dilations, padding):
    """Enqueues a windowed reduction operation onto the computation.

    Args:
      operand: reduction operand (XlaOp).
      init_value: reduction initial value (XlaOp).
      computation_to_apply: a binary reduction function (Computation).
      window_dimensions: dimensions of window (sequence of integers).
      window_strides: strides for window (sequence of integers).
      base_dilations: dilations for the base (sequence of integers).
      window_dilations: dilations for window (sequence of integers).
      padding: length-N array-like of pairs of integers of (low, high) padding.

    Returns:
      An XlaOp representing the added ReduceWindow op.
    """
    return ops.ReduceWindowWithGeneralPadding(operand, init_value,
                                              computation_to_apply.computation,
                                              window_dimensions, window_strides,
                                              base_dilations, window_dilations,
                                              padding)
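
  # Illustrative sketch (not part of the original file): reductions take a
  # separately built binary Computation as the combiner.
  #
  #   reducer = ComputationBuilder('add_f32')
  #   reducer.Add(reducer.ParameterFromNumpy(np.float32(0)),
  #               reducer.ParameterFromNumpy(np.float32(0)))
  #   add = reducer.Build()
  #
  #   builder = ComputationBuilder('sum_rows')
  #   x = builder.ParameterFromNumpy(np.zeros((4, 5), np.float32))
  #   builder.Reduce(x, builder.ConstantF32Scalar(0.0), add, dimensions=[1])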

  def RngNormal(self, mu, sigma, dims):
    """Enqueues an RngNormal operation onto the computation.

    Args:
      mu: An XlaOp to an F32 scalar specifying the mean.
      sigma: An XlaOp to an F32 scalar specifying the standard deviation.
      dims: A 1D array-like of nonnegative integers specifying the dimensions.

    Returns: a XlaOp to the generated array of F32 values.
    """
    shape = _xla.Shape.array_shape(self.GetShape(mu).xla_element_type(), dims)
    return ops.RngNormal(mu, sigma, shape)

  def RngUniform(self, a, b, dims):
    """Enqueues an RngUniform operation onto the computation.

    Args:
      a: a XlaOp to an F32, S32, or U32 scalar (consistent with the type of b)
        specifying the low end of the interval [a, b) over which values are
        generated.
      b: a XlaOp to an F32, S32, or U32 scalar (consistent with the type of a)
        specifying the high end of the interval [a, b) over which values are
        generated.
      dims: A 1D array-like of nonnegative integers specifying the dimensions.

    Returns: a XlaOp to the generated array of values with the same numeric
      type (F32, S32, or U32) as the arguments a and b.
    """
    shape = _xla.Shape.array_shape(self.GetShape(a).xla_element_type(), dims)
    return ops.RngUniform(a, b, shape)

  def While(self, cond, body, init):
    """Enqueues a While operation onto the computation.

    Args:
      cond: a Computation for the loop condition, which has type T -> PRED
      body: a Computation for the loop body, which has type T -> T
      init: a XlaOp for the initial parameter, which has type T

    Returns: a XlaOp representing the While operation.
    """
    return ops.While(cond.computation, body.computation, init)

  def Conditional(self, pred, true_operand, true_computation, false_operand,
                  false_computation):
    """Enqueues a Conditional operation onto the computation.

    Args:
      pred: a XlaOp to test, which has scalar type PRED
      true_operand: a XlaOp of type T_0
      true_computation: a Computation to apply to true_operand, type T_0 -> S
      false_operand: a XlaOp of type T_1
      false_computation: a Computation to apply to false_operand, type T_1 -> S

    Returns: a XlaOp representing the Conditional operation.
    """
    return ops.Conditional(pred, true_operand, true_computation.computation,
                           false_operand, false_computation.computation)

  def IsConstant(self, operand):
    """Checks whether the given operand is a compile-time constant.

    Args:
      operand: a XlaOp to test.

    Returns: bool indicating whether `operand` is a compile-time constant,
      meaning its value does not depend on any parameters, or on stateful
      operators such as `RngNormal` or `Infeed`.
    """
    return self._builder.IsConstant(operand)

  def BuildConstantSubGraph(self, operand):
    """Builds a constant sub graph.

    Args:
      operand: a XlaOp to test.

    Returns: a Computation that is rooted on the given `operand` which is a
      compile-time constant.
    """
    return ops.BuildConstantSubGraph(operand)

  def DotGeneral(self, lhs, rhs, dimension_numbers, precision_config=None):
    """Enqueues a general dot operation onto the computation.

    Args:
      lhs: XlaOp for the left-hand-side array.
      rhs: XlaOp for the right-hand-side array.
      dimension_numbers: either a DotDimensionNumbers or a nested tuple
        ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) of lists of
        integers representing the dimensions to treat as contracting dimensions
        and batch dimensions on each input operand.

    Returns: a XlaOp representing the DotGeneral operation.
    """
    if isinstance(dimension_numbers, tuple):
      dimension_numbers = GetDotDimensionsFromLists(dimension_numbers)
    return ops.DotGeneral(
        lhs, rhs, dimension_numbers, precision_config=precision_config)
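
  # Illustrative sketch (not part of the original file): the nested-tuple form
  # of `dimension_numbers`. For a batched matrix multiply of f32[b, m, k] by
  # f32[b, k, n], dimension 2 of lhs contracts with dimension 1 of rhs, and
  # dimension 0 of each operand is a batch dimension (`lhs` and `rhs` stand in
  # for previously enqueued XlaOps):
  #
  #   builder.DotGeneral(lhs, rhs, (([2], [1]), ([0], [0])))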

  def Conv(self,
           lhs,
           rhs,
           window_strides,
           padding,
           feature_group_count=1,
           batch_group_count=1,
           precision_config=None):
    """Enqueues a Conv operation onto the computation.

    Args:
      lhs: XlaOp for the rank N+2 array of inputs.
      rhs: XlaOp for the rank N+2 array of kernel weights.
      window_strides: length-N array-like of integer kernel strides.
      padding: PaddingType representing either 'SAME' or 'VALID' padding.
      feature_group_count: number of feature groups for grouped convolution.
      batch_group_count: number of batch groups for grouped convolution.

    Returns: a XlaOp representing the Conv operation.
    """
    pads = _convert_padding_type_to_pad_values(
        padding,
        self.GetShape(lhs).dimensions()[2:],
        self.GetShape(rhs).dimensions()[2:], window_strides)
    return self.ConvGeneralDilated(
        lhs,
        rhs,
        window_strides,
        pads, [], [],
        dimension_numbers=None,
        feature_group_count=feature_group_count,
        batch_group_count=batch_group_count,
        precision_config=precision_config)

  def ConvWithGeneralPadding(self,
                             lhs,
                             rhs,
                             window_strides,
                             padding,
                             lhs_dilation,
                             rhs_dilation,
                             feature_group_count=1,
                             batch_group_count=1,
                             precision_config=None):
    """Enqueues a ConvWithGeneralPadding operation onto the computation.

    Args:
      lhs: XlaOp for the rank N+2 array of inputs.
      rhs: XlaOp for the rank N+2 array of kernel weights.
      window_strides: length-N array-like of kernel strides.
      padding: length-N array-like of pairs of integers of (low, high) padding.
      lhs_dilation: length-N array-like of dilation factors.
      rhs_dilation: length-N array-like of dilation factors.
      feature_group_count: number of feature groups for grouped convolution.
      batch_group_count: number of batch groups for grouped convolution.

    Returns:
      An XlaOp representing the added ConvWithGeneralPadding op.
    """
    return self.ConvGeneralDilated(
        lhs,
        rhs,
        list(window_strides),
        list(padding),
        list(lhs_dilation),
        list(rhs_dilation),
        dimension_numbers=None,
        feature_group_count=feature_group_count,
        batch_group_count=batch_group_count,
        precision_config=precision_config)

  def _GetConvDimensionNumbers(self, num_spatial_dims):
    """Create ConvolutionDimensionNumbers proto for convolutions."""
    nd = num_spatial_dims
    dimension_numbers = ConvolutionDimensionNumbers()
    dimension_numbers.input_batch_dimension = 0
    dimension_numbers.input_feature_dimension = 1
    dimension_numbers.output_batch_dimension = 0
    dimension_numbers.output_feature_dimension = 1
    dimension_numbers.kernel_output_feature_dimension = 0
    dimension_numbers.kernel_input_feature_dimension = 1
    dimension_numbers.input_spatial_dimensions.extend(range(2, 2 + nd))
    dimension_numbers.kernel_spatial_dimensions.extend(range(2, 2 + nd))
    dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd))
    return dimension_numbers

  def ConvGeneralDilated(self,
                         lhs,
                         rhs,
                         window_strides,
                         padding,
                         lhs_dilation,
                         rhs_dilation,
                         dimension_numbers=None,
                         feature_group_count=1,
                         batch_group_count=1,
                         precision_config=None):
    """Enqueues a ConvGeneralDilated operation onto the computation.

    Args:
      lhs: XlaOp for the rank N+2 array of inputs.
      rhs: XlaOp for the rank N+2 array of kernel weights.
      window_strides: length-N array-like of integer kernel strides.
      padding: length-N array-like of pairs of integers of (low, high) padding.
      lhs_dilation: length-N array-like of integer dilation factors.
      rhs_dilation: length-N array-like of integer dilation factors.
      dimension_numbers: optional, either a ConvolutionDimensionNumbers object
        or a tuple (lhs_spec, rhs_spec, out_spec). Each element is a string of
        length N+2 identifying by position: (1) batch dimensions in lhs, rhs,
        and the output with the character 'N', (2) feature dimensions in lhs
        and the output with the character 'C', (3) input and output feature
        dimensions in rhs with the characters 'I' and 'O' respectively, and
        (4) spatial dimension correspondences between lhs, rhs, and the output
        using any distinct characters. For example, to indicate dimension
        numbers consistent with the Conv operation with two spatial
        dimensions, one could use ('NCHW', 'OIHW', 'NCHW'). As another
        example, to indicate dimension numbers consistent with the TensorFlow
        Conv2D operation, one could use ('NHWC', 'HWIO', 'NHWC'). When using
        the latter form of convolution dimension specification, window strides
        are associated with spatial dimension character labels according to
        the order in which the labels appear in the rhs_spec string, so that
        window_strides[0] is matched with the dimension corresponding to the
        first character appearing in rhs_spec that is not 'I' or 'O'. By
        default, use the same dimension numbering as Conv and
        ConvWithGeneralPadding.
      feature_group_count: number of feature groups for grouped convolution.
      batch_group_count: number of batch groups for grouped convolution.

    Returns: a XlaOp representing the ConvGeneralDilated operation.
    """
    if dimension_numbers is None:
      dimension_numbers = self._GetConvDimensionNumbers(len(window_strides))
    elif isinstance(dimension_numbers, tuple):
      lhs_spec, rhs_spec, out_spec = dimension_numbers
      dimension_numbers = ConvolutionDimensionNumbers()

      dimension_numbers.input_batch_dimension = lhs_spec.index('N')
      dimension_numbers.input_feature_dimension = lhs_spec.index('C')
      dimension_numbers.output_batch_dimension = out_spec.index('N')
      dimension_numbers.output_feature_dimension = out_spec.index('C')
      dimension_numbers.kernel_output_feature_dimension = rhs_spec.index('O')
      dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I')

      dimension_numbers.kernel_spatial_dimensions.extend(
          i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'})
      dimension_numbers.input_spatial_dimensions.extend(
          sorted((i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}),
                 key=lambda i: rhs_spec.index(lhs_spec[i])))
      dimension_numbers.output_spatial_dimensions.extend(
          sorted((i for i, c in enumerate(out_spec) if c not in {'N', 'C'}),
                 key=lambda i: rhs_spec.index(out_spec[i])))
    return ops.ConvGeneralDilated(
        lhs,
        rhs,
        window_strides,
        padding,
        lhs_dilation,
        rhs_dilation,
        dimension_numbers,
        feature_group_count,
        batch_group_count,
        precision_config=precision_config)

  def Sort(self, operands, dimension=-1, comparator=None):
    """Enqueues a sort operation onto the computation.

    Args:
      operands: either an XlaOp or a sequence of XlaOps to sort. All operands
        must be arrays with the same dimensions.
      dimension: the array dimension over which to sort.
      comparator: a comparator XlaComputation. See the XLA operation semantics
        for details.

    Returns:
      Either an XlaOp or a tuple of XlaOps (if `operands` was an XlaOp or
      a tuple of XlaOps, respectively.)
    """
    operands = (
        list(operands)
        if isinstance(operands, collections.Sequence) else [operands])
    return ops.Sort(self._builder, operands, dimension,
                    comparator.computation if comparator else None)

  def SortKeyVal(self, keys, values, dimension=-1):
    """Enqueues a key-value sort operation onto the computation.

    Deprecated. Use `Sort` instead.
    """
    return ops.Sort(self._builder, [keys, values], dimension)

  def QR(self, a, full_matrices=True):
    """Enqueues a QR decomposition onto the computation."""
    return self.Tuple(*ops.QR(a, full_matrices))

  def TriangularSolve(self,
                      a,
                      b,
                      left_side=False,
                      lower=False,
                      transpose_a=False,
                      conjugate_a=False,
                      unit_diagonal=False):
    """Enqueues a triangular-solve operation onto the computation."""
    if not transpose_a:
      transpose = _xla.TriangularSolveOptions_Transpose.NO_TRANSPOSE
      if conjugate_a:
        a = self.Conj(a)
    else:
      transpose = (
          _xla.TriangularSolveOptions_Transpose.ADJOINT
          if conjugate_a else _xla.TriangularSolveOptions_Transpose.TRANSPOSE)
    return ops.TriangularSolve(a, b, left_side, lower, unit_diagonal, transpose)

  def Eigh(self, a, full_matrices=True):
    """Enqueues a symmetric/Hermitian eigendecomposition."""
    return self.Tuple(*ops.Eigh(a, full_matrices))

  def SVD(self, a):
    """Enqueues a singular value decomposition."""
    return self.Tuple(*ops.SVD(a))

  def Gather(self,
             a,
             start_indices,
             dimension_numbers,
             slice_sizes,
             indices_are_sorted=False):
    """Enqueues a Gather operation onto the computation."""
    return ops.Gather(a, start_indices, dimension_numbers, slice_sizes,
                      indices_are_sorted)

  def Scatter(self,
              a,
              scatter_indices,
              updates,
              update_computation,
              dimension_numbers,
              indices_are_sorted=False,
              unique_indices=False):
    """Enqueues a Scatter operation onto the computation."""
    return ops.Scatter(a, scatter_indices, updates,
                       update_computation.computation, dimension_numbers,
                       indices_are_sorted, unique_indices)

  def Fft(self, operand, fft_type, fft_lengths):
    """Enqueues an FFT operation onto the computation."""
    return ops.Fft(operand, fft_type, fft_lengths)


FftType = _xla.FftType

_UNARY_OPS = [
    'Not',
    'Clz',
    'Abs',
    'Exp',
    'Expm1',
    'Floor',
    'Round',
    'Ceil',
    'Log',
    'Log1p',
    'Sign',
    'Cos',
    'Sin',
    'Tanh',
    'IsFinite',
    'Sqrt',
    'Rsqrt',
    'Square',
    'Reciprocal',
    'Neg',
    'Erf',
    'Erfc',
    'ErfInv',
    'Lgamma',
    'Digamma',
    'BesselI0e',
    'BesselI1e',
    'Acos',
    'Asin',
    'Atan',
    'Tan',
    'Acosh',
    'Asinh',
    'Atanh',
    'Cosh',
    'Sinh',
    'Real',
    'Imag',
    'Conj',
]

_BINARY_OPS = [
    'Eq',
    'Ne',
    'Ge',
    'Gt',
    'Lt',
    'Le',
    'Add',
    'Sub',
    'Mul',
    'Div',
    'Rem',
    'Max',
    'Min',
    'And',
    'Or',
    'Xor',
    'Pow',
    'ShiftLeft',
    'ShiftRightArithmetic',
    'ShiftRightLogical',
    'Atan2',
    'Igamma',
    'Igammac',
    'Complex',
    'NextAfter',
]

_OTHER_OPS = [
    'BitcastConvertType',
    'Broadcast',
    'BroadcastInDim',
    'Cholesky',
    'Clamp',
    'Collapse',
    'CollectivePermute',
    'ConvertElementType',
    'Dot',
    'GetTupleElement',
    'ReducePrecision',
    'RegularizedIncompleteBeta',
    'Rev',
    'Select',
    'SliceInDim',
]
def _forward_methods_to_local_builder():
  """Forwards remaining ComputationBuilder methods to the C API.

  Sets up methods, corresponding to XLA operations, whose calls are forwarded
  in a boilerplate manner to the underlying _xla.ops API.
  """

  def forward_op(target_method):

    def forward(builder, *args, **kwargs):
      del builder
      return target_method(*args, **kwargs)

    return forward

  for method_name in itertools.chain(_UNARY_OPS, _BINARY_OPS, _OTHER_OPS):
    forward = forward_op(getattr(ops, method_name))
    forward.__name__ = method_name
    setattr(ComputationBuilder, method_name, forward)


_forward_methods_to_local_builder()


def register_custom_call_target(name, fn, platform='cpu'):
  """Registers a custom call target.

  Args:
    name: bytes containing the name of the function.
    fn: a PyCapsule object containing the function pointer.
    platform: the target platform.
  """
  _xla.RegisterCustomCallTarget(name, fn, xla_platform_names[platform])

# Deprecated. Use register_custom_call_target instead.
register_cpu_custom_call_target = register_custom_call_target


class PaddingConfigDimension(object):
  """Python representation of an xla.PaddingConfigDimension protobuf."""
  __slots__ = ('edge_padding_low', 'edge_padding_high', 'interior_padding')

  def __init__(self):
    self.edge_padding_low = 0
    self.edge_padding_high = 0
    self.interior_padding = 0


class PaddingConfig(object):
  """Python representation of an xla.PaddingConfig protobuf."""
  __slots__ = ('dimensions',)

  def __init__(self):
    self.dimensions = []


def GetPaddingConfigFromTriples(triples):
  """Creates a PaddingConfig proto from a list of triples of integers."""
  padding_config = PaddingConfig()
  for lo, hi, interior in triples:
    dimension = PaddingConfigDimension()
    dimension.edge_padding_low = lo
    dimension.edge_padding_high = hi
    dimension.interior_padding = interior
    padding_config.dimensions.append(dimension)
  return padding_config


class DotDimensionNumbers(object):
  """Python representation of an xla.DotDimensionNumbers protobuf."""
  __slots__ = ('lhs_contracting_dimensions', 'rhs_contracting_dimensions',
               'lhs_batch_dimensions', 'rhs_batch_dimensions')

  def __init__(self):
    self.lhs_contracting_dimensions = []
    self.rhs_contracting_dimensions = []
    self.lhs_batch_dimensions = []
    self.rhs_batch_dimensions = []


def GetDotDimensionsFromLists(dimension_numbers):
  """Builds a DotDimensionNumbers from contracting/batch dimension lists."""
  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
  dot_dims_proto = DotDimensionNumbers()
  dot_dims_proto.lhs_contracting_dimensions.extend(lhs_contract)
  dot_dims_proto.rhs_contracting_dimensions.extend(rhs_contract)
  dot_dims_proto.lhs_batch_dimensions.extend(lhs_batch)
  dot_dims_proto.rhs_batch_dimensions.extend(rhs_batch)
  return dot_dims_proto
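

# Illustrative sketch (assumed shapes, not part of the original module):
# building the DotDimensionNumbers for a batched matrix multiply of a
# [b, m, k] operand with a [b, k, n] operand. GetDotDimensionsFromLists takes
# a pair of pairs: ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)).
#
#   dnums = GetDotDimensionsFromLists((([2], [1]), ([0], [0])))
#   # dnums.lhs_contracting_dimensions == [2]  (the k dimension of lhs)
#   # dnums.rhs_contracting_dimensions == [1]  (the k dimension of rhs)
#   # dnums.lhs_batch_dimensions == [0]
#   # dnums.rhs_batch_dimensions == [0]
#
# The resulting object can then be passed wherever a DotDimensionNumbers is
# expected, e.g. to a general-dot operation (assumed here).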


class ConvolutionDimensionNumbers(object):
  """Python representation of an xla.ConvolutionDimensionNumbers protobuf."""
  __slots__ = ('input_batch_dimension', 'input_feature_dimension',
               'input_spatial_dimensions', 'kernel_input_feature_dimension',
               'kernel_output_feature_dimension', 'kernel_spatial_dimensions',
               'output_batch_dimension', 'output_feature_dimension',
               'output_spatial_dimensions')

  def __init__(self):
    self.input_batch_dimension = 0
    self.input_feature_dimension = 0
    self.input_spatial_dimensions = []
    self.kernel_input_feature_dimension = 0
    self.kernel_output_feature_dimension = 0
    self.kernel_spatial_dimensions = []
    self.output_batch_dimension = 0
    self.output_feature_dimension = 0
    self.output_spatial_dimensions = []


class OpSharding(object):
  """Python representation of an xla.OpSharding protobuf."""
  __slots__ = ('type', 'tile_assignment_dimensions', 'tile_assignment_devices',
               'tuple_shardings')

  Type = _xla.OpSharding_Type

  def __init__(self):
    self.type = self.Type.REPLICATED
    self.tile_assignment_dimensions = []
    self.tile_assignment_devices = []
    self.tuple_shardings = []


class PrecisionConfig(object):
  """Python representation of an xla.PrecisionConfig protobuf."""
  __slots__ = ('operand_precision',)

  Precision = _xla.PrecisionConfig_Precision

  def __init__(self):
    self.operand_precision = []


class GatherDimensionNumbers(object):
  """Python representation of an xla.GatherDimensionNumbers protobuf."""
  __slots__ = ('offset_dims', 'collapsed_slice_dims', 'start_index_map',
               'index_vector_dim')

  def __init__(self):
    self.offset_dims = []
    self.collapsed_slice_dims = []
    self.start_index_map = []
    self.index_vector_dim = 0


class ScatterDimensionNumbers(object):
  """Python representation of an xla.ScatterDimensionNumbers protobuf."""
  __slots__ = ('update_window_dims', 'inserted_window_dims',
               'scatter_dims_to_operand_dims', 'index_vector_dim')

  def __init__(self):
    self.update_window_dims = []
    self.inserted_window_dims = []
    self.scatter_dims_to_operand_dims = []
    self.index_vector_dim = 0


class ReplicaGroup(object):
  """Python representation of an xla.ReplicaGroup protobuf."""
  __slots__ = ('replica_ids',)

  def __init__(self):
    self.replica_ids = []


def _make_replica_group_proto(replica_group):
  replica_group_proto = ReplicaGroup()
  replica_group_proto.replica_ids.extend(replica_group)
  return replica_group_proto


def _get_replica_groups_protos(replica_groups):
  if replica_groups is None:
    replica_groups_protos = []  # special value for XLA API
  else:
    replica_groups = list(replica_groups)
    replica_groups_protos = [
        _make_replica_group_proto(group) for group in replica_groups
    ]
  return replica_groups_protos
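

# Illustrative sketch (not part of the original module): how the replica-group
# helpers above translate a Python grouping into ReplicaGroup objects. Passing
# None yields the empty list, the special default value the XLA API expects.
#
#   groups = _get_replica_groups_protos([[0, 1], [2, 3]])
#   # len(groups) == 2
#   # groups[0].replica_ids == [0, 1]
#   # groups[1].replica_ids == [2, 3]
#
#   _get_replica_groups_protos(None)  # == []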