# Lint as: python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An XLA client in Python."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections
import enum  # pylint: disable=g-bad-import-order
import inspect
import itertools
import os

from absl import logging
import numpy as np

# Note this module does *not* depend on any Python protocol buffers. The XLA
# Python bindings are currently packaged both as part of jaxlib and as part
# of TensorFlow. If we use protocol buffers here, then importing both jaxlib
# and TensorFlow may fail with duplicate protocol buffer message definitions.

from tensorflow.compiler.xla.python import xla_extension as _xla
from tensorflow.compiler.xla.python.xla_extension import ops

# Most functions are snake_case for consistency with other modules, whereas
# method names of ComputationBuilder and Computation are CamelCase for
# consistency with XLA.
# pylint: disable=invalid-name
44
45
46class Backend(object, metaclass=abc.ABCMeta):
47  """Abstract base class for XLA backends."""
48
49  def __init__(self, platform):
50    """Creates a new Backend.
51
52    Args:
53      platform: A string naming the platform; for example 'gpu'.
54    """
55    self.platform = platform
56
57  @abc.abstractmethod
58  def device_count(self):
59    """Returns the number of devices known to the backend."""
60
61  @abc.abstractmethod
62  def local_device_count(self):
63    """Returns the number of devices local to this host."""
64
65  @abc.abstractmethod
66  def devices(self):
67    """Returns a list of `device_count()` Device subclasses."""
68
69  @abc.abstractmethod
70  def host_id(self):
71    """Returns the integer ID of this host."""
72
73  @abc.abstractmethod
74  def buffer_from_pyval(self, pyval, device=None, force_copy=False):
75    """Allocates a fresh buffer and populates it with `pyval`."""
76
77  @abc.abstractmethod
78  def make_tuple(self, c_buffers, device):
79    """Makes a tuple from a sequence of backend buffer objects."""
80
81  @abc.abstractmethod
82  def compile(self, computation, compile_options):
83    """Compiles a computation. Returns an executable."""
84
85  @abc.abstractmethod
86  def get_default_device_assignment(self, num_replicas, num_partitions):
87    """Returns the default device assignment that `compile` would use.
88
89    If `compile_options.device_assignment` isn't set, `compile` will pick a
90    deterministic device assignment based on the number of replicas and
91    partitions, possibly optimizing for device locality. This method returns
92    that assignment, which is useful for e.g. manually replicating a value
93    before passing it to a compiled executable.
94
95    Args:
96      num_replicas: the number of replicas needed.
97      num_partitions: the number of partitions needed.
98
99    Returns:
      A list of lists of Devices of size `(num_replicas, num_partitions)`.
101    """
102
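# Illustrative sketch (comment only, not part of the API): querying the
# assignment that compile() would use. The replica/partition counts are example
# values; get_local_backend() is defined later in this module.
#
#   backend = get_local_backend()
#   assignment = backend.get_default_device_assignment(2, 1)
#   # assignment[replica][partition] is a Device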
103
104class LocalBackend(Backend):
105  """XLA backend implemented using the in-process xla::LocalClient API."""
106
107  def __init__(self, platform, client):
108    """Creates a new LocalBackend.
109
110    Args:
111      platform: A string; the user-visible platform name, e.g. 'gpu'.
112      client: An _xla.PyLocalClient object.
113    """
114    super(LocalBackend, self).__init__(platform)
115    self.client = client
116
117  def device_count(self):
118    return self.client.device_count()
119
120  def local_device_count(self):
121    return self.client.local_device_count()
122
123  def devices(self):
124    return self.client.devices()
125
126  def local_devices(self):
127    return self.client.local_devices()
128
129  def host_id(self):
130    return self.client.host_id()
131
132  def buffer_from_pyval(self, pyval, device=None, force_copy=False):
133    if device is None:
134      device = self.local_devices()[0]
135    return _xla.PyLocalBuffer.from_python(pyval, self.client, device,
136                                          force_copy)
137
138  def make_tuple(self, c_buffers, device):
139    return _xla.PyLocalBuffer.make_tuple(c_buffers, self.client, device)
140
141  def compile(self, c_computation, compile_options):
142    options = _xla.ExecutableBuildOptions()
143    options.num_replicas = compile_options.num_replicas
144    options.num_partitions = compile_options.num_partitions
145    if compile_options.result_layout:
146      options.result_layout = compile_options.result_layout
147    options.debug_options.xla_cpu_fast_math_honor_infs = True
148    options.debug_options.xla_cpu_fast_math_honor_nans = True
149    options.debug_options.xla_cpu_fast_math_honor_division = True
150    options.debug_options.xla_cpu_fast_math_honor_functions = True
151    options.debug_options.xla_gpu_enable_fast_min_max = False
152    return _xla.LocalExecutable.Compile(c_computation,
153                                        compile_options.argument_layouts,
154                                        options, self.client,
155                                        compile_options.device_assignment)
156
157  def get_default_device_assignment(self, num_replicas, num_partitions=None):
158    if num_partitions is not None:
159      return self.client.GetDefaultDeviceAssignment(num_replicas,
160                                                    num_partitions)
161    else:
162      # TODO(skye): delete this case after all callers can handle 2D output
163      return self.client.GetDefaultDeviceAssignment(num_replicas)
164
165
166xla_platform_names = {
167    'cpu': 'Host',
168    'gpu': 'CUDA',
169}
170
171
172def _cpu_backend_factory():
173  client = _xla.LocalClient.Get(
174      platform='cpu',
175      xla_platform_id=xla_platform_names['cpu'],
176      asynchronous=True)
177  return LocalBackend(platform='cpu', client=client)
178
179
180def _gpu_backend_factory():
181  """Returns a GPU backend. BFC allocator is used by default."""
182  allocator = os.getenv('XLA_PYTHON_CLIENT_ALLOCATOR', 'default').lower()
183  memory_fraction = os.getenv('XLA_PYTHON_CLIENT_MEM_FRACTION')
184  preallocate = os.getenv('XLA_PYTHON_CLIENT_PREALLOCATE')
185  if allocator not in ('default', 'platform', 'bfc'):
186    raise ValueError(
187        'XLA_PYTHON_CLIENT_ALLOCATOR env var must be "default", "platform", or '
188        '"bfc", got "%s"' % allocator)
189  config = _xla.AllocatorConfig()
190  if allocator == 'default':
191    config.kind = _xla.AllocatorConfig.Kind.DEFAULT
192  if allocator == 'platform':
193    config.kind = _xla.AllocatorConfig.Kind.PLATFORM
194  if allocator == 'bfc':
195    config.kind = _xla.AllocatorConfig.Kind.BFC
196  if memory_fraction:
197    config.memory_fraction = float(memory_fraction)
198  config.preallocate = preallocate not in ('0', 'false', 'False')
199
200  client = _xla.LocalClient.Get(
201      platform='gpu',
202      xla_platform_id=xla_platform_names['gpu'],
203      asynchronous=True,
204      allocator_config=config)
205  return LocalBackend(platform='gpu', client=client)
206
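# Illustrative sketch: the GPU allocator read by _gpu_backend_factory above can
# be tuned through environment variables. The variable names come from the code
# above; the values shown are examples only and must be set before the GPU
# backend is first instantiated.
#
#   import os
#   os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'bfc'       # 'default', 'platform', or 'bfc'
#   os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.75'    # fraction of GPU memory to reserve
#   os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'    # disable up-front allocation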
207
208# Backend factories, keyed by user-visible name, in increasing priority order.
209_local_backend_factories = collections.OrderedDict([
210    ('cpu', _cpu_backend_factory),
211    ('gpu', _gpu_backend_factory),
212])
213
214
215def register_local_backend_factory(name, factory):
216  _local_backend_factories[name] = factory
217
218
219_local_backends = None
220
221
222def _get_local_backends():
223  """Instantiates all known local backends."""
224  global _local_backends
225  if _local_backends is not None:
226    return _local_backends
227
228  _local_backends = collections.OrderedDict()
229  for name, factory in _local_backend_factories.items():
230    logging.vlog(2, "Initializing backend '%s'" % name)
231    try:
232      backend = factory()
233    except RuntimeError:
234      if name == 'cpu':
235        # We always expect CPU to initialize successfully.
236        raise
237      else:
238        # If the backend isn't built into the binary, or if it has no devices,
239        # we expect a RuntimeError.
240        continue
241    _local_backends[name] = backend
242  return _local_backends
243
244
245def get_local_backend(name=None):
246  """Returns a local backend.
247
248  Args:
249    name: the backend name. If `None`, a default local backend is returned,
250      typically `gpu` if one is present, or `cpu` if not. If a string, the named
251      backend is returned or an exception raised.
252
253  Returns:
254    A LocalBackend object.
255  """
256  backends = _get_local_backends()
257  if name is not None:
258    try:
259      return backends[name]
260    except KeyError:
261      raise RuntimeError('Unknown backend {}'.format(name))
262
263  return list(backends.values())[-1]
264
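# Illustrative sketch: obtaining a backend and inspecting its devices. Only the
# CPU backend is assumed to exist (it is always expected to initialize above).
#
#   backend = get_local_backend()           # highest-priority backend, e.g. 'gpu' if present
#   print(backend.platform, backend.device_count())
#   cpu_backend = get_local_backend('cpu')  # unknown names raise RuntimeError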
265
266class OpMetadata(object):
  """Python representation of an xla.OpMetadata protobuf."""
268  __slots__ = ('op_type', 'op_name', 'source_file', 'source_line')
269
270  def __init__(self, op_type='', op_name='', source_file='', source_line=0):
271    self.op_type = op_type
272    self.op_name = op_name
273    self.source_file = source_file
274    self.source_line = source_line
275
276
277def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1):
278  """Helper for use in source mapping that returns an OpMetadata object."""
279  full_filename, lineno = inspect.stack()[skip_frames][1:3]
280  filename = os.path.basename(full_filename)
281  return OpMetadata(
282      op_type=op_type,
283      op_name=op_name,
284      source_file=filename,
285      source_line=lineno)
286
287
288PrimitiveType = _xla.PrimitiveType
289
290bfloat16 = _xla.bfloat16_dtype()
291
292XLA_ELEMENT_TYPE_TO_DTYPE = {
293    PrimitiveType.PRED: np.dtype('bool'),
294    PrimitiveType.S8: np.dtype('int8'),
295    PrimitiveType.S16: np.dtype('int16'),
296    PrimitiveType.S32: np.dtype('int32'),
297    PrimitiveType.S64: np.dtype('int64'),
298    PrimitiveType.U8: np.dtype('uint8'),
299    PrimitiveType.U16: np.dtype('uint16'),
300    PrimitiveType.U32: np.dtype('uint32'),
301    PrimitiveType.U64: np.dtype('uint64'),
302    PrimitiveType.BF16: np.dtype(bfloat16),
303    PrimitiveType.F16: np.dtype('float16'),
304    PrimitiveType.F32: np.dtype('float32'),
305    PrimitiveType.F64: np.dtype('float64'),
306    PrimitiveType.C64: np.dtype('complex64'),
307    PrimitiveType.C128: np.dtype('complex128'),
308    PrimitiveType.TUPLE: np.dtype(np.object),
309    PrimitiveType.TOKEN: np.dtype(np.object),
310}
311
312# Note the conversion on the key. Numpy has a known issue wherein dtype hashing
313# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus,
314# when keying by dtype in this dict, we use the string form of dtypes.
315DTYPE_TO_XLA_ELEMENT_TYPE = {
316    str(dt): et for et, dt in XLA_ELEMENT_TYPE_TO_DTYPE.items()
317}
318
319
320def dtype_to_etype(dtype):
321  """Convenience function for reading DTYPE_TO_XLA_ELEMENT_TYPE."""
322  return DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))]
323
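# Illustrative sketch: converting between Numpy dtypes and XLA element types
# with the tables above.
#
#   assert dtype_to_etype(np.float32) == PrimitiveType.F32
#   assert dtype_to_etype('int32') == PrimitiveType.S32
#   assert XLA_ELEMENT_TYPE_TO_DTYPE[PrimitiveType.F64] == np.dtype('float64')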
324
325Shape = _xla.Shape
326Shape.__doc__ = """
327A Shape is an object defined in C++ that duck types like the following class:
328
329class Shape(object):
330  '''Represents an XLA shape.
331
332  A shape is either an array shape, having rank-many integer
333  dimensions and an element type (represented by a Numpy dtype), or it
334  is a tuple shape, having a shape for every tuple component:
335
336    type shape =
337        TupleShape of shape list
338      | ArrayShape of { dimensions: int list; element_type: dtype }
339  '''
340
341  @staticmethod
342  def tuple_shape(tuple_shapes) -> Shape:
343    "Construct a tuple shape."
344
345  @staticmethod
346  def array_shape(element_type, dimensions, minor_to_major=None) -> Shape:
347
348  @staticmethod
349  def from_pyval(pyval) -> Shape:
350    "Returns a Shape that describes a tuple-tree of Numpy arrays."
351
352  def __init__(self, str) -> Shape:
353    "Parses a shape string."
354  def __eq__(self, other: Shape) -> bool:
355  def __ne__(self, other: Shape) -> bool:
356  def __hash__(self):
357  def __repr__(self):
358  def is_tuple(self) -> bool:
359  def is_array(self) -> bool:
360  def tuple_shapes(self) -> [Shape]:
361  def numpy_dtype(self) -> np.dtype:
362    "Like element_type(), but returns dtype('O') for a tuple shape."
363  def xla_element_type(self) -> PrimitiveType:
364  def element_type(self) -> np.dtype:
365  def dimensions(self) -> (int, int, ...):
366  def rank(self) -> int:
367  def with_major_to_minor_layout_if_absent(self) -> Shape:
368    "Returns a copy with missing layouts set to major-to-minor."
369
370  def to_serialized_proto(self) -> bytes:
371    "Returns 'shape' as a serialized proto."
372"""
373
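# Illustrative sketch: constructing Shapes through the duck-typed API described
# above. The dtypes and dimensions are arbitrary example values.
#
#   s = Shape.array_shape(np.dtype('float32'), (2, 3))
#   t = Shape.tuple_shape((s, Shape.array_shape(np.dtype('int32'), ())))
#   assert s.is_array() and t.is_tuple()
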
374ProgramShape = _xla.ProgramShape
375ProgramShape.__doc__ = """
376A ProgramShape is a C++ object that duck types like the following class.
377
378class ProgramShape(object):
379  def __init__(self, parameter_shapes, result_shape):
380  def parameter_shapes(self) -> [Shape]:
381  def result_shape(self) -> Shape:
382  def __repr__(self):
383"""
384
385
386class Buffer(object):
387  """Represents a handle to data owned by XLA.
388
389  The referent is ready for use in executing a local, compiled
390  Computation. On XLA platforms involving a device (e.g. GPU), this
391  means the referent is in device memory.
392  """
393
394  @staticmethod
395  def from_pyval(pyval, device=None, backend=None, force_copy=False):
396    """Copies the `pyval` to a freshly allocated on-device buffer."""
397    backend = backend or get_local_backend()
398    return backend.buffer_from_pyval(pyval, device, force_copy=force_copy)
399
400  @staticmethod
401  def make_tuple(buffers, device, backend=None):
402    backend = backend or get_local_backend()
403    return backend.make_tuple(buffers, device)
404
405  # Buffer is not an instantiable type and exists only for its static methods.
  # The underlying buffer objects are C++ objects with the following
407  # API:
408  # def shape(self) -> Shape:
409  # def device(self) -> int:
410  # def delete(self):
411  # def destructure(self) -> [Buffer]
412  # def is_deleted(self) -> bool:
413  # def block_host_until_ready(self):
414  #    """Blocks the calling thread until the buffer is ready on device."""
415  # def copy_to_host_async(self):
416  #    """Requests a copy of the buffer to the host.
417  #
418  #       Does not block waiting for the copy. Values fetched are available via
419  #       `to_py()`; the purpose of `copy_to_host_async` is to prefetch values
420  #       for subsequent `to_py()` calls, especially when requesting many values
421  #       at once.
422  #    """
423  # def to_py(self):
424  #    """Returns the value of the buffer as a Python tuple tree of ndarrays."""
425  #
426  # TODO(phawkins): remove Buffer and its static methods completely, have
427  # clients call methods on Backend to create buffers.
428
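# Illustrative sketch: round-tripping a Numpy array through device memory with
# the static helpers above. Assumes a default local backend is available.
#
#   x = np.arange(6, dtype=np.float32).reshape(2, 3)
#   buf = Buffer.from_pyval(x)        # copies to the default device
#   assert (buf.to_py() == x).all()   # copies back to a host ndarray
#   buf.delete()                      # releases the device memory early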
429
430# TODO(phawkins): Alias for backward compatibility. Remove after JAX drops
431# compatibility with Jaxlib versions older than 0.1.13.
432LocalBuffer = Buffer
433
434
435def shape_from_pyval(pyval):
436  """Returns a Shape that describes a tuple-tree of Numpy arrays."""
437
438  def convert(pyval):
439    if isinstance(pyval, tuple):
440      return Shape.tuple_shape(tuple(convert(elt) for elt in pyval))
441    else:
442      return Shape.array_shape(pyval.dtype, np.shape(pyval))
443
444  return convert(pyval)
445
446
447def transfer_to_infeed(value, device=None):
448  """Transfers the given value into the XLA infeed queue.
449
  XLA's infeed queue is a single queue that feeds the "XLA virtual machine"
  with a totally ordered stream of values. These values are dequeued by XLA
  computations via the Infeed() operation.

  Args:
    value: the value that the caller would like to enqueue into the XLA infeed
      queue.
457    device: the device to infeed the value to. Each device has a
458      distinct infeed queue.
459  """
460  # TODO(phawkins): support non-default backends.
461  backend = get_local_backend()
462  device = device or backend.local_devices()[0]
463  device.TransferToInfeed(value)
464
465
466def transfer_from_outfeed(shape, device=None):
467  """Transfers a literal of the given shape from `device`'s outfeed.
468
469  Args:
470    shape: The shape of the value to transfer from outfeed.
    device: The device from which to transfer the outfeed value. Each device
      has a distinct outfeed queue.
473
474  Returns:
475    The literal value that is produced from the outfeed queue.
476  """
477  # TODO(phawkins): support non-default backends.
478  backend = get_local_backend()
479  device = device or backend.local_devices()[0]
480  return device.TransferFromOutfeed(
481      shape.with_major_to_minor_layout_if_absent())
482
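# Illustrative sketch: pairing the two helpers above with a computation that
# contains Infeed()/Outfeed() ops. The value and shape are example inputs; the
# computation itself is assumed to be executing concurrently.
#
#   x = np.float32(42.0)
#   transfer_to_infeed(x)                            # host -> device infeed queue
#   y = transfer_from_outfeed(shape_from_pyval(x))   # device outfeed queue -> host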
483
484DeviceAssignment = _xla.DeviceAssignment
485DeviceAssignment.__doc__ = """
486A DeviceAssignment is a C++ object with the following signature.
487
488def create(assignment):
489  '''Builds a device assignment.
490
491   Args:
492     assignment: a 2D numpy array of device ordinal integers, indexed by
493       [replica][computation_in_replica].
494   Returns:
495     A device assignment.
496  '''
497
498def replica_count():
499  '''Returns the number of replicas.'''
500def computation_count():
501  '''Returns the number of computations per replica.'''
502"""
503
504
505Device = _xla.Device
506
507
508class CompileOptions(object):
509  """Python object for XLA compile options.
510
511  These options can be passed to the 'compile' step when using a local XLA
512  client.
513  """
514
515  def __init__(self):
516    self.xla_dump_to = None
517    self.dump_hlo_pass_re = None
518    self.dump_hlo_module_re = None
519    self.dump_hlo_as_text = None
520    self.dump_hlo_as_proto = None
521    self.hlo_profile = None
522    self.num_replicas = 1
523    self.num_partitions = 1
524    self.argument_layouts = None
525    self.result_layout = None
526    self.device_assignment = None
527
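# Illustrative sketch: requesting a replicated compilation. The field values
# are example settings; 'computation' is a hypothetical Computation object.
#
#   options = CompileOptions()
#   options.num_replicas = 2
#   executable = computation.Compile(compile_options=options)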
528
529class Computation(object):
530  """Python wrapper for an XLA Computation.
531
532  A Computation can be compiled to form an Executable, or used as a
533  subcomputation in ComputationBuilder methods.
534  """
535
536  def __init__(self, c_computation, backend=None):
537    self._c_computation = c_computation
538    # The backend argument is deprecated. Pass a backend to Compile() instead.
539    self._backend = backend
540
541  @property
542  def computation(self):
543    return self._c_computation
544
545  def GetSerializedProto(self):
546    """Gets the serialized HloModuleProto proto object in this computation.
547
548    Returns:
549       A string containing a serialized HloModuleProto proto containing the
550       computation and its dependencies.
551    """
552    return self.computation.GetSerializedProto()
553
554  def GetHloText(self):
555    """Get the textual HLO representation of this computation.
556
557    Returns:
558       A string containing the textual HLO.
559    """
560    return self.computation.GetHloText()
561
562  def GetHloDotGraph(self):
563    """Get a Graphviz Dot representation of this computation.
564
565    Returns:
566       A string containing the graphviz dot graph.
567    """
568    return self.computation.GetHloDotGraph()
569
570  def Compile(self, argument_shapes=None, compile_options=None, backend=None):
571    """Compiles a computation.
572
    Computations are produced by a ComputationBuilder's Build() method.

    Arguments:
      argument_shapes: Deprecated. Use compile_options.argument_layouts instead.
      compile_options: options to use for compilation, including an optional
        layout for the computation's result.
      backend: a `Backend` for which an executable should be generated.

    Returns:
      An Executable instance.
583    """
584    backend = backend or self._backend or get_local_backend()
585
586    compile_options = compile_options or CompileOptions()
587    if argument_shapes:
588      compile_options.argument_layouts = argument_shapes
589    return backend.compile(self.computation, compile_options)
590
591  def GetProgramShape(self):
592    return self._c_computation.GetProgramShape()
593
594  def GetReturnValueShape(self):
595    return self._c_computation.GetProgramShape().result_shape()
596
597  def Hash(self):
598    return self._c_computation.Hash()
599
600
601# An Executable is a C++ class that duck types with the following API:
602# class Executable(object):
603#   def local_devices(self) -> [Device]:
604#   def Execute(self, arguments : [Buffer]) -> Buffer:
605#     """Execute on one replica with Buffer arguments and return value."""
606#
607#   def SizeOfGeneratedCodeInBytes(self) -> int:
608#     """Return generated binary size, or -1 if not known."""
609#
610#   def ExecutePerReplica(self, arguments: [[Buffer]]) -> [Buffer]:
611#     """Execute on many replicas with Buffer arguments and return value.
612#
613#     Args:
614#       arguments: A sequence of sequences of Buffers. The i'th inner sequence
615#         comprises the arguments for execution on the i'th replica.
616#
617#     Returns:
618#       A list of the computation's outputs for each replica, as a Buffer. If
619#       a shallow sequence of arguments was passed in for `arguments`, then the
620#       sole, zero'th replica's output is returned instead, as a Buffer.
621#     """
622#
623# There are different implementations of Executable for different backends.
624
625
626def execute_with_python_values(executable, arguments=(), backend=None):
627  """Execute on one replica with Python values as arguments and output."""
628
629  backend = backend or get_local_backend()
630
631  def put(arg):
632    return Buffer.from_pyval(
633        arg, device=executable.local_devices()[0], backend=backend)
634
635  arguments = [put(arg) for arg in arguments]
636  return executable.Execute(arguments).to_py()
637
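# Illustrative sketch: building, compiling, and running a computation end to
# end with the helper above, using only APIs defined in this module.
#
#   c = ComputationBuilder('transpose_example')
#   p = c.ParameterFromNumpy(np.zeros((2, 3), np.float32))
#   c.Trans(p)                          # the last enqueued op becomes the root
#   executable = c.Build().Compile()
#   result = execute_with_python_values(executable, [np.ones((2, 3), np.float32)])
#   # result is a (3, 2) float32 ndarray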
638
639def execute_with_python_values_replicated(executable, arguments, backend=None):
640  """Execute on many replicas with Python values as arguments and output.
641
642  Arguments:
643    executable: the program to run.
644    arguments: a list of lists of Python values indexed by
645      `[replica][arg_num]` to pass as inputs.
646    backend: the backend we are targeting.
647
648  Returns:
649    A list of python values, one per replica.
650  """
651  backend = backend or get_local_backend()
652  devices = executable.local_devices()
653  # pylint: disable=g-complex-comprehension
654  flat_args = [(arg, devices[replica])
655               for replica, replica_args in enumerate(arguments)
656               for arg in replica_args]
657  flat_arg_buffers = [
658      backend.buffer_from_pyval(pyval, device) for pyval, device in flat_args
659  ]
660  arg_buffers = []
661  for replica_args in arguments:
662    arg_buffers.append(flat_arg_buffers[:len(replica_args)])
663    flat_arg_buffers = flat_arg_buffers[len(replica_args):]
664  return [out.to_py() for out in executable.ExecutePerReplica(arg_buffers)]
665
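# Illustrative sketch: the replicated variant above takes per-replica argument
# lists. With an executable compiled for two replicas (argument values are
# examples only):
#
#   outputs = execute_with_python_values_replicated(
#       executable, [[np.float32(1.0)], [np.float32(2.0)]])
#   # outputs is a list with one entry per replica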
666
667class PaddingType(enum.Enum):
668  VALID = 1
669  SAME = 2
670
671
672def _convert_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims,
673                                        window_strides):
674  """Maps PaddingType or string to pad values (list of pairs of ints)."""
675  if not isinstance(padding_type, (str, PaddingType)):
676    msg = 'padding_type must be str or PaddingType, got {}.'
677    raise TypeError(msg.format(type(padding_type)))
678
679  if isinstance(padding_type, str):
680    if padding_type.upper() == 'VALID':
681      padding_type = PaddingType.VALID
682    elif padding_type.upper() == 'SAME':
683      padding_type = PaddingType.SAME
684    else:
685      msg = 'Unknown padding type string: expected "VALID" or "SAME", got {}.'
686      raise ValueError(msg.format(padding_type))
687
688  if padding_type == PaddingType.VALID:
689    return [(0, 0)] * len(window_strides)
690  elif padding_type == PaddingType.SAME:
691    out_shape = np.ceil(np.true_divide(lhs_dims, window_strides)).astype(int)
692    pad_sizes = [
693        max((out_size - 1) * stride + filter_size - in_size, 0)
694        for out_size, stride, filter_size, in_size in zip(
695            out_shape, window_strides, rhs_dims, lhs_dims)
696    ]
697    return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes]
698  else:
699    msg = 'Unexpected PaddingType value: {}'
700    raise ValueError(msg.format(padding_type))
701
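# Illustrative sketch: the 'SAME' branch above reproduces TensorFlow-style
# padding. For an input dimension of 5, window 3, and stride 2:
# out = ceil(5 / 2) = 3, pad = max((3 - 1) * 2 + 3 - 5, 0) = 2, split as (1, 1).
#
#   assert _convert_padding_type_to_pad_values('SAME', [5], [3], [2]) == [(1, 1)]
#   assert _convert_padding_type_to_pad_values('VALID', [5], [3], [2]) == [(0, 0)]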
702
703class ComputationBuilder(object):
704  """XLA computation builder.
705
706  Enqueues XLA ops in sequence and in order to build a
707  Computation, which in turn can be compiled into a
708  LocalExecutable, which in turn can be locally executed.
709  """
710
711  # The methods of this class map 1-to-1 onto the XLA C++
712  # computation builder API. Therefore, there's no need to laboriously list
713  # arguments and return values for every method, especially where it's obvious.
714  #
715  # pylint: disable=g-doc-return-or-yield
716  # pylint: disable=g-doc-args
717
718  def __init__(self, name):
719    self._builder = _xla.XlaBuilder(name)
720    self._parameter_numbering = itertools.count()
721
722  def Build(self, root=None, backend=None):
723    """Builds a `Computation` from the contents of the builder.
724
725    Args:
726      root: if not None, the operator containing the return value of the
727        computation.
728
729    Returns:
730      A `Computation`.
731    """
732    if root is not None:
733      return Computation(self._builder.Build(root), backend=backend)
734    else:
735      return Computation(self._builder.Build(), backend=backend)
736
737  def GetShape(self, operand):
738    return self._builder.GetShape(operand)
739
740  def SetOpMetadata(self, op_metadata):
741    """Set metadata for operations that are about to be enqueued."""
742    self._builder.SetOpMetadata(op_metadata)
743
744  def ClearOpMetadata(self):
745    """Clear metadata for operations that are about to be enqueued."""
746    self._builder.ClearOpMetadata()
747
748  def SetSharding(self, sharding):
749    """Set sharding that will be attached to all instructions until cleared."""
750    self._builder.SetSharding(sharding)
751
752  def ClearSharding(self):
753    """Clears the sharding.
754
755    Ops will be sharded according to the default placement policy.
756    """
757    self._builder.ClearSharding()
758
759  def CreateToken(self):
760    """Enqueues a CreateToken op onto the computation.
761
762    Returns:
763      An XlaOp, representing a fresh token.
764    """
765    return ops.CreateToken(self._builder)
766
767  def AfterAll(self, tokens):
    """Enqueues an after-all op onto the computation.
769
770    `AfterAll` takes a variadic number of tokens and produces a single token.
771
772    Args:
773      tokens: a list of `XlaOp` values representing predecessor tokens.
774
775    Returns:
776      An `XlaOp`.
777    """
778    return ops.AfterAll(self._builder, tokens)
779
780  def Infeed(self, shape, token=None):
781    """Enqueues an infeed op onto the computation.
782
783    Infeed operations dequeue data of the given shape from the device's infeed
784    queue for subsequent use in the computation.
785
786    Args:
      shape: a `Shape` describing the shape of the infeed value.
788      token: an optional `XlaOp` representing a token after which the infeed
789        effect should be sequenced.
790    Returns:
791      An XlaOp, representing a (value, token) pair.
792    """
793    if token is None:
794      token = ops.CreateToken(self._builder)
795    return ops.InfeedWithToken(token,
796                               shape.with_major_to_minor_layout_if_absent())
797
798  def Outfeed(self, operand, token=None):
799    """Enqueues an outfeed op onto the computation.
800
801    Outfeed operations enqueue data, using the given operand, onto the XLA
802    outfeed queue for subsequent dequeue via the client API.
803
804    Args:
805      operand: an `XlaOp` representing the data to outfeed.
806      token: an `XlaOp` representing a token after which the outfeed should be
807        sequenced.
808    Returns:
809      An `XlaOp` representing a token.
810    """
811    if token is None:
812      token = ops.CreateToken(self._builder)
813    return ops.OutfeedWithToken(operand, token, self._builder.GetShape(operand),
814                                '')
815
816  def Constant(self, value):
817    """Enqueues a constant op onto the computation.
818
819    Args:
820      value: value for the constant, as a np.array with an explicit dtype set to
821        one of the supported types.
822
823    Returns:
824      An XlaOp.
825    """
826    return ops.ConstantLiteral(self._builder, value)
827
828  def ConstantF32Scalar(self, value):
829    """Convenience method to enqueue a scalar F32 constant op.
830
831    Args:
832      value: a floating-point number.
833
834    Returns:
835      An XlaOp.
836    """
837    return self.Constant(np.array(value, dtype=np.float32))
838
839  def ConstantF64Scalar(self, value):
    """Convenience method to enqueue a scalar F64 constant op.
841
842    Args:
843      value: a floating-point number.
844
845    Returns:
846      An XlaOp.
847    """
848    return self.Constant(np.array(value, dtype=np.float64))
849
850  def ConstantS32Scalar(self, value):
851    """Convenience method to enqueue a scalar S32 constant op.
852
853    Args:
      value: an integer.
855
856    Returns:
857      An XlaOp.
858    """
859    return self.Constant(np.array(value, dtype=np.int32))
860
861  def ConstantS64Scalar(self, value):
862    """Convenience method to enqueue a scalar S64 constant op.
863
864    Args:
      value: an integer.
866
867    Returns:
868      An XlaOp.
869    """
870    return self.Constant(np.array(value, dtype=np.int64))
871
872  def ConstantPredScalar(self, value):
873    """Convenience method to enqueue a scalar PRED constant op.
874
875    Args:
876      value: a boolean value.
877
878    Returns:
879      An XlaOp.
880    """
881    return self.Constant(np.array(value, dtype=np.bool))
882
883  def ParameterWithShape(self, shape, name=None, parameter_num=None,
884                         replicated=False):
885    """Enqueues a Parameter op onto the computation, given a shape.
886
887    Args:
888      shape: the parameter's shape as a Shape object.
889      name: optional string name for the parameter.
890      parameter_num: parameter number in the computation function. If None, the
891        next linear parameter number is used. The default value capability can
892        be used for auto-numbering. If you're using auto-numbering for some
893        parameters, use it for *all* parameters to avoid clashes.
894      replicated: whether to mark the parameter's leaves as replicated. May be
895        a bool, in which case it applies to all leaves, or an iterable of bools.
896
897    Returns:
898      An XlaOp.
899    """
900    if name is None:
901      name = ''
902    if parameter_num is None:
903      parameter_num = next(self._parameter_numbering)
904    if isinstance(replicated, bool):
905      replicated = [replicated] * shape.leaf_count()
906
907    return ops.Parameter(self._builder, parameter_num,
908                         shape.with_major_to_minor_layout_if_absent(),
909                         name.encode('utf8'), replicated)
910
911  def ParameterFromNumpy(self, value, name=None, parameter_num=None):
912    """Enqueues a Parameter op onto the computation.
913
914    Args:
915      value: a Numpy array, or a nested tuple thereof, from which the shape is
916        inferred.
917      name: as in ParameterWithShape.
918      parameter_num: as in ParameterWithShape.
919
920    Returns:
921      An XlaOp.
922    """
923    return self.ParameterWithShape(
924        shape_from_pyval(value), name=name, parameter_num=parameter_num)
925
926  def Iota(self, dtype, size):
927    """Enqueues an iota constant onto the computation.
928
929    Args:
930      dtype: expected numpy dtype of the output.
931      size: integer, the number of elements in the array.
932
933    Returns:
934      An XlaOp representing the added iota constant.
935    """
936    element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))]
937    return ops.Iota(self._builder, element_type, size)
938
939  def BroadcastedIota(self, dtype, shape, dimension):
940    """Enqueues a broadcasted iota constant onto the computation.
941
942    Args:
943      dtype: expected numpy dtype of the output.
944      shape: tuple of integers, the expected output shape (dimensions).
945      dimension: positive integer, dimension along which to increment values.
946
947    Returns:
948      An XlaOp representing the added broadcasted iota constant.
949    """
950    element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))]
951    xla_shape = _xla.Shape.array_shape(element_type, shape, None)
952    return ops.Iota(self._builder, xla_shape, dimension)
953
954  def Concatenate(self, operands, dimension):
955    """Enqueues a concatenate operation onto the computation.
956
957    Args:
958      operands: the operands to concatenate.
959      dimension: the dimension in which to perform the concatenation.
960
961    Returns:
962      An XlaOp representing the added concatenate op.
963    """
964    return ops.ConcatInDim(self._builder, list(operands), dimension)
965
966  def ReplicaId(self):
967    """Enqueues a ReplicaId operation onto the computation.
968
969    Returns:
      An XlaOp representing the replica id.
    """
    return ops.ReplicaId(self._builder)
973
974  def Pad(self, operand, padding_value, padding_config):
975    """Enqueues a Pad operation onto the computation.
976
977    Args:
978      operand: XlaOp representing the array to pad.
979      padding_value: XlaOp representing the scalar pad value.
980      padding_config: either a PaddingConfig or a list of integer triples
981        (edge_padding_low, edge_padding_high, interior_padding) representing the
982        configuration of the padding operation.
983
984    Returns:
985      An XlaOp representing the added Pad op.
986    """
987    if isinstance(padding_config, tuple) or isinstance(padding_config, list):
988      padding_config = GetPaddingConfigFromTriples(padding_config)
989    return ops.Pad(operand, padding_value, padding_config)
990
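  # Illustrative sketch: padding_config triples above are
  # (edge_padding_low, edge_padding_high, interior_padding) per dimension.
  # 'c' and 'operand' below are hypothetical builder and operand handles; this
  # pads one zero on each side of dim 0 and two zeros at the end of dim 1:
  #
  #   padded = c.Pad(operand, c.ConstantF32Scalar(0.0), [(1, 1, 0), (0, 2, 0)])
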
991  def Reshape(self, operand, dimensions, new_sizes):
992    """Enqueues a reshape op onto the computation.
993
994    Args:
995      operand: XlaOp representing the array to be reshaped.
996      dimensions: sequence of integers encoding the order in which dimensions
997        are collapsed or None, in which case dimensions are flattened in order.
998      new_sizes: sequence of integers encoding the new dimension sizes (shape).
999
1000    Returns:
1001      An XlaOp representing the added Reshape op.
1002    """
1003    if dimensions is None:
1004      ndim = len(self.GetShape(operand).dimensions())
1005      dimensions = tuple(range(ndim))
1006    return ops.Reshape(operand, dimensions, new_sizes)
1007
1008  def AllReduce(self, operand, computation, replica_groups=None):
1009    """AllReduce op.
1010
1011    Args:
1012      operand: XlaOp representing the input array
1013      computation: a Computation object - binary reduction function.
1014      replica_groups: optional, list of lists of ints encoding a partition of
        the set {0, 1, ..., num_replicas - 1} into equally-sized replica groups
        within which the all-reduce is performed. If not supplied or None (the
1017        default), all replicas belong to the same group.
1018
1019    Returns:
1020      An XlaOp that represents the all-reduced result.
1021    """
1022    replica_groups_protos = _get_replica_groups_protos(replica_groups)
1023    return ops.AllReduce(operand, computation.computation,
1024                         replica_groups_protos, None, None)
1025
1026  def AllToAll(self,
1027               operand,
1028               split_dimension,
1029               concat_dimension,
1030               replica_groups=None):
1031    """AllToAll op.
1032
1033    Args:
1034      operand: XlaOp representing the input array
1035      split_dimension: the dimension along which the operand is split
1036      concat_dimension: the dimension along which the split blocks are
1037        concatenated
1038      replica_groups: optional, list of lists of ints encoding a partition of
        the set {0, 1, ..., num_replicas - 1} into equally-sized replica groups
1040        within which the all-to-all is performed. If not supplied or None (the
1041        default), all replicas belong to the same group.
1042
1043    Returns:
1044      An XlaOp that represents the all-to-all concatenation.
1045    """
1046    replica_groups_protos = _get_replica_groups_protos(replica_groups)
1047    if not replica_groups:
1048      split_count = 1
1049    else:
1050      split_count = len(replica_groups[0])
1051      if not all(split_count == len(g) for g in replica_groups):
1052        raise ValueError('Replica groups must be equally sized')
1053    return ops.AllToAll(operand, split_dimension, concat_dimension, split_count,
1054                        replica_groups_protos)
1055
1056  def CrossReplicaSum(self, operand, replica_groups=None):
1057    """CrossReplicaSum op.
1058
1059    Args:
1060      operand: the operand to sum across replica instances.
1061      replica_groups: optional, list of lists of ints encoding a partition of
        the set {0, 1, ..., num_replicas - 1} into equally-sized replica groups
1063        within which the cross-replica sum is performed. If not supplied or None
1064        (the default), all replicas belong to the same group.
1065
1066    Returns:
1067      An XlaOp that represents on each replica the sum of its group's values.
1068    """
1069    replica_groups_protos = _get_replica_groups_protos(replica_groups)
1070    return ops.CrossReplicaSum(operand, replica_groups_protos)
1071
1072  def Trans(self, operand):
1073    """Specialized matrix transpose op."""
1074    return ops.Transpose(operand, [1, 0])
1075
1076  def Transpose(self, operand, permutation):
1077    """Transpose op."""
1078    return ops.Transpose(operand, permutation)
1079
1080  def SelectAndScatter(self, operand, select, window_dimensions, window_strides,
1081                       padding, source, init_value, scatter):
1082    """Select and scatter op, used by the gradient of ReduceWindow.
1083
1084    Args:
1085      operand: XlaOp for array of dimension N and type T over which the windows
1086        slide.
1087      select: Computation of type (T, T) -> Pred to apply to the elements of
1088        each window to indicate which element is selected.
1089      window_dimensions: sequence of N integers for dimensions of the window.
1090      window_strides: sequence of N integers for the strides of the window.
      padding: PaddingType representing either 'SAME' or 'VALID' padding.
1092      source: XlaOp for array of type T with values to scatter.
1093      init_value: XlaOp of scalar type T for initial out value.
1094      scatter: Computation of type (T, T) -> T to apply to each scatter source
1095        element with its destination element.
1096
1097    Returns:
1098      An XlaOp representing the added SelectAndScatter op.
1099    """
1100    pads = _convert_padding_type_to_pad_values(
1101        padding,
1102        self.GetShape(operand).dimensions(), window_dimensions, window_strides)
1103    return ops.SelectAndScatterWithGeneralPadding(operand, select.computation,
1104                                                  window_dimensions,
1105                                                  window_strides, pads, source,
1106                                                  init_value,
1107                                                  scatter.computation)
1108
1109  def Slice(self, operand, start_indices, limit_indices, strides=None):
1110    """Enqueues a slice operation onto the computation.
1111
1112    Args:
1113      operand: XlaOp for the N dimensional array to be sliced.
1114      start_indices: iterable of N integers containing the starting indices of
1115        the slice for each dimension.
1116      limit_indices: iterable of N integers containing the ending indices
1117        (exclusive) of the slice for each dimension.
1118      strides: optional iterable of N integers containing the stride sizes for
1119        each dimension.
1120
1121    Returns:
1122      An XlaOp representing the added Slice op.
1123    """
1124    if strides is None:
1125      start_indices = list(start_indices)
1126      strides = [1] * len(start_indices)
1127    return ops.Slice(operand, start_indices, limit_indices, strides)
1128
1129  def DynamicSlice(self, operand, start_indices, slice_sizes):
1130    """Enqueues a slice op with dynamic start indices onto the computation.
1131
1132    Args:
1133      operand: XlaOp for the N dimensional array to be sliced.
1134      start_indices: XlaOp for the 1D array of N integers containing the
1135        starting indices of the slice.
1136      slice_sizes: iterable of N integers containing the slice sizes in each
1137        dimension.
1138
1139    Returns:
1140      An XlaOp representing the added DynamicSlice op.
1141    """
1142    slice_sizes = list(slice_sizes)
1143    if isinstance(start_indices, _xla.XlaOp):
1144      start_indices = [
1145          ops.Reshape(ops.Slice(start_indices, [i], [i + 1], [1]), [])
1146          for i in range(len(slice_sizes))
1147      ]
1148    return ops.DynamicSlice(operand, list(start_indices), slice_sizes)
1149
1150  def DynamicUpdateSlice(self, operand, update, start_indices):
1151    """Enqueues a dynamic update slice operation onto the computation.
1152
1153    Args:
1154      operand: XlaOp for the N dimensional array to be updated.
1155      update: N dimensional array comprising the slice update.
1156      start_indices: Rank-1 array of N integers comprising the starting indices
1157        of the slice along each dimension.
1158
1159    Returns:
1160      An XlaOp representing the added DynamicUpdateSlice op.
1161    """
1162    if isinstance(start_indices, _xla.XlaOp):
1163      ndims = self._builder.GetShape(start_indices).dimensions()[0]
1164      start_indices = [
1165          ops.Reshape(ops.Slice(start_indices, [i], [i + 1], [1]), [])
1166          for i in range(ndims)
1167      ]
1168    return ops.DynamicUpdateSlice(operand, update, list(start_indices))
1169
1170  def Tuple(self, *elems):
1171    """Enqueues a tuple operation onto the computation.
1172
1173    Args:
1174      elems: a sequence of tuple operands (each a XlaOp).
1175
1176    Returns:
1177      An XlaOp representing the added Tuple op.
1178    """
1179    return ops.Tuple(self._builder, list(elems))
1180
1181  def Call(self, computation_to_apply, operands):
1182    """Enqueues a call operation onto the computation.
1183
1184    Args:
1185      computation_to_apply: a Computation object.
1186      operands: an iterable of XlaOp. The number and types of operands must
1187        match the arity of computation_to_apply.
1188
1189    Returns:
1190      An XlaOp representing the added call op.
1191    """
1192    return ops.Call(self._builder, computation_to_apply.computation,
1193                    list(operands))
1194
1195  def CustomCall(self,
1196                 call_target_name,
1197                 operands,
1198                 shape_with_layout,
1199                 operand_shapes_with_layout,
1200                 opaque=None):
1201    """Enqueues a custom call operation onto the computation.
1202
1203    Args:
1204      call_target_name: the name of the function to call.
1205      operands: an iterable of XlaOp. The number and types of operands must
1206        match the arity of `operand_shapes_with_layout`.
1207      shape_with_layout: the shape of the operator's output, with layout.
1208      operand_shapes_with_layout: the shapes of `operands`, including the
1209        expected layouts.
1210      opaque: an opaque string passed to the backend.
1211
1212    Returns:
1213      An XlaOp representing the added custom call op.
1214    """
1215    opaque = opaque or b''
1216    return ops.CustomCall(self._builder, call_target_name,
1217                          list(operands), shape_with_layout,
1218                          list(operand_shapes_with_layout), opaque)
1219
1220  def Map(self, operands, computation_to_apply, dimensions):
1221    """Enqueues a map operation onto the computation.
1222
1223    Args:
1224      operands: an iterable of XlaOp.
1225      computation_to_apply: a Computation object.
      dimensions: dimensions over which to map the function.
1227
1228    Returns:
1229      An XlaOp representing the added Map op.
1230    """
1231    return ops.Map(self._builder, list(operands),
1232                   computation_to_apply.computation, dimensions, [])
1233
1234  def Reduce(self, operand, init_value, computation_to_apply, dimensions):
1235    """Enqueues a reduction operation onto the computation.
1236
1237    Args:
1238      operand: reduction operand (XlaOp).
1239      init_value: reduction initial value (XlaOp).
1240      computation_to_apply: a Computation object - binary reduction function.
1241      dimensions: sequence of dimensions (integers) to reduce on.
1242
1243    Returns:
1244      An XlaOp representing the added Reduce op.
1245    """
1246    return ops.Reduce(self._builder, [operand], [init_value],
1247                      computation_to_apply.computation, dimensions)
1248
1249  def ReduceWindow(self, operand, init_value, computation_to_apply,
1250                   window_dimensions, window_strides, padding):
1251    """Enqueues a windowed reduction operation onto the computation.
1252
1253    Args:
1254      operand: reduction operand (XlaOp).
1255      init_value: reduction initial value (XlaOp).
1256      computation_to_apply: a binary reduction function (Computation).
1257      window_dimensions: dimensions of window (sequence of integers).
1258      window_strides: strides for window (sequence of integers).
1259      padding: PaddingType representing either 'SAME' or 'VALID' padding.
1260
1261    Returns:
1262      An XlaOp representing the added ReduceWindow op.
1263    """
1264    pads = _convert_padding_type_to_pad_values(
1265        padding,
1266        self.GetShape(operand).dimensions(), window_dimensions, window_strides)
1267    return ops.ReduceWindowWithGeneralPadding(operand, init_value,
1268                                              computation_to_apply.computation,
1269                                              window_dimensions, window_strides,
1270                                              (), (), pads)
1271
1272  def ReduceWindowWithGeneralPadding(self, operand, init_value,
1273                                     computation_to_apply, window_dimensions,
1274                                     window_strides, base_dilations,
1275                                     window_dilations, padding):
1276    """Enqueues a windowed reduction operation onto the computation.
1277
1278    Args:
1279      operand: reduction operand (XlaOp).
1280      init_value: reduction initial value (XlaOp).
1281      computation_to_apply: a binary reduction function (Computation).
1282      window_dimensions: dimensions of window (sequence of integers).
1283      window_strides: strides for window (sequence of integers).
1284      base_dilations: dilations for the base (sequence of integers).
1285      window_dilations: dilations for window (sequence of integers).
1286      padding: length-N array-like of pairs of integers of (low, high) padding.
1287
1288    Returns:
1289      An XlaOp representing the added ReduceWindow op.
1290    """
1291    return ops.ReduceWindowWithGeneralPadding(operand, init_value,
1292                                              computation_to_apply.computation,
1293                                              window_dimensions, window_strides,
1294                                              base_dilations, window_dilations,
1295                                              padding)
1296
1297  def RngNormal(self, mu, sigma, dims):
1298    """Enqueues an RngNormal operation onto the computation.
1299
1300    Args:
1301      mu: An XlaOp to an F32 scalar specifying the mean.
1302      sigma: An XlaOp to an F32 scalar specifying the standard deviation.
1303      dims: A 1D array-like of nonnegative integers specifying the dimensions.
1304    Returns: a XlaOp to the generated array of F32 values.
1305    """
1306    shape = _xla.Shape.array_shape(self.GetShape(mu).xla_element_type(), dims)
1307    return ops.RngNormal(mu, sigma, shape)
1308
1309  def RngUniform(self, a, b, dims):
1310    """Enqueues an RngUniform operation onto the computation.
1311
1312    Args:
1313      a: a XlaOp to an F32, S32, or U32 scalar (consistent with the type of b)
1314        specifying the low end of the interval [a, b) over which values are
1315        generated.
1316      b: a XlaOp to an F32, S32, or U32 scalar (consistent with the type of a)
1317        specifying the high end of the interval [a, b) over which values are
1318        generated.
1319      dims: A 1D array-like of nonnegative integers specifying the dimensions.
1320    Returns: a XlaOp to the generated array of values with the same numeric type
1321      (F32, S32, or U32) as the arguments a and b.
1322    """
1323    shape = _xla.Shape.array_shape(self.GetShape(a).xla_element_type(), dims)
1324    return ops.RngUniform(a, b, shape)
1325
1326  def While(self, cond, body, init):
1327    """Enqueues a While operation onto the computation.
1328
1329    Args:
1330      cond: a Computation for the loop condition, which has type T -> PRED
1331      body: a Computation for the loop body, which has type T -> T
1332      init: a XlaOp for the initial parameter, which has type T
1333    Returns: a XlaOp representing the While operation.
1334    """
1335    return ops.While(cond.computation, body.computation, init)
1336
1337  def Conditional(self, pred, true_operand, true_computation, false_operand,
1338                  false_computation):
1339    """Enqueues a Conditional operation onto the computation.
1340
1341    Args:
      pred: an XlaOp to test, which has scalar type PRED
      true_operand: an XlaOp of type T_0
      true_computation: a Computation to apply to true_operand, type T_0 -> S
      false_operand: an XlaOp of type T_1
1346      false_computation: a Computation to apply to false_operand, type T_1 -> S
1347    Returns: a XlaOp representing the Conditional operation.
1348    """
1349    return ops.Conditional(pred, true_operand, true_computation.computation,
1350                           false_operand, false_computation.computation)
1351
1352  def IsConstant(self, operand):
1353    """Checks whether the given operand is a compile-time constant.
1354
1355    Args:
      operand: an XlaOp to test.
    Returns: bool indicating whether `operand` is a compile-time constant,
      meaning its value does not depend on any parameters, or on stateful
1359      operators such as `RngNormal` or `Infeed`.
1360    """
1361    return self._builder.IsConstant(operand)
1362
1363  def BuildConstantSubGraph(self, operand):
1364    """Builds a constant sub graph.
1365
1366    Args:
1367      operand: a XlaOp to test.
1368    Returns: a Computation that is rooted on the given `operand` which is a
1369      compile-time constant.
1370    """
1371    return ops.BuildConstantSubGraph(operand)
1372
1373  def DotGeneral(self, lhs, rhs, dimension_numbers, precision_config=None):
1374    """Enqueues a general dot operation onto the computation.
1375
1376    Args:
1377      lhs: XlaOp for the left-hand-side array.
1378      rhs: XlaOp for the right-hand-side array.
1379      dimension_numbers: either a DotDimensionNumbers or a nested tuple
1380        ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) of lists of
1381        integers representing the dimensions to treat as contracting dimensions
1382        and batch dimensions on each input operand.
1383    Returns: a XlaOp representing the DotGeneral operation.
1384    """
1385    if isinstance(dimension_numbers, tuple):
1386      dimension_numbers = GetDotDimensionsFromLists(dimension_numbers)
1387    return ops.DotGeneral(
1388        lhs, rhs, dimension_numbers, precision_config=precision_config)
1389
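  # Illustrative sketch: the nested-tuple form of dimension_numbers above. For
  # a batched matmul over operands of shape [B, M, K] and [B, K, N], contract
  # the K dimensions and treat dimension 0 as the batch dimension ('c', 'lhs',
  # and 'rhs' are hypothetical handles):
  #
  #   out = c.DotGeneral(lhs, rhs, (([2], [1]), ([0], [0])))
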
1390  def Conv(self,
1391           lhs,
1392           rhs,
1393           window_strides,
1394           padding,
1395           feature_group_count=1,
1396           batch_group_count=1,
1397           precision_config=None):
1398    """Enqueues a Conv operation onto the computation.
1399
1400    Args:
1401      lhs: XlaOp for the rank N+2 array of inputs.
1402      rhs: XlaOp for the rank N+2 array of kernel weights.
1403      window_strides: length-N array-like of integer kernel strides.
1404      padding: PaddingType representing either 'SAME' or 'VALID' padding.
1405      feature_group_count: number of feature groups for grouped convolution.
1406      batch_group_count: number of batch groups for grouped convolution.
1407    Returns: a XlaOp representing the Conv operation.
1408    """
1409    pads = _convert_padding_type_to_pad_values(
1410        padding,
1411        self.GetShape(lhs).dimensions()[2:],
1412        self.GetShape(rhs).dimensions()[2:], window_strides)
1413    return self.ConvGeneralDilated(
1414        lhs,
1415        rhs,
1416        window_strides,
1417        pads, [], [],
1418        dimension_numbers=None,
1419        feature_group_count=feature_group_count,
1420        batch_group_count=batch_group_count,
1421        precision_config=precision_config)
1422
1423  def ConvWithGeneralPadding(self,
1424                             lhs,
1425                             rhs,
1426                             window_strides,
1427                             padding,
1428                             lhs_dilation,
1429                             rhs_dilation,
1430                             feature_group_count=1,
1431                             batch_group_count=1,
1432                             precision_config=None):
1433    """Enqueues a ConvWithGeneralPadding operation onto the computation.
1434
1435    Args:
1436      lhs: XlaOp for the rank N+2 array of inputs.
1437      rhs: XlaOp for the rank N+2 array of kernel weights.
1438      window_strides: length-N array-like of kernel strides.
1439      padding: length-N array-like of pairs of integers of (low, high) padding.
1440      lhs_dilation: length-N array-like of dilation factors.
1441      rhs_dilation: length-N array-like of dilation factors.
1442      feature_group_count: number of feature groups for grouped convolution.
1443      batch_group_count: number of batch groups for grouped convolution.
1444
1445    Returns:
      An XlaOp representing the added ConvWithGeneralPadding op.
1447    """
1448    return self.ConvGeneralDilated(
1449        lhs,
1450        rhs,
1451        list(window_strides),
1452        list(padding),
1453        list(lhs_dilation),
1454        list(rhs_dilation),
1455        dimension_numbers=None,
1456        feature_group_count=feature_group_count,
1457        batch_group_count=batch_group_count,
1458        precision_config=precision_config)

  def _GetConvDimensionNumbers(self, num_spatial_dims):
    """Create ConvolutionDimensionNumbers proto for convolutions."""
    nd = num_spatial_dims
    dimension_numbers = ConvolutionDimensionNumbers()
    dimension_numbers.input_batch_dimension = 0
    dimension_numbers.input_feature_dimension = 1
    dimension_numbers.output_batch_dimension = 0
    dimension_numbers.output_feature_dimension = 1
    dimension_numbers.kernel_output_feature_dimension = 0
    dimension_numbers.kernel_input_feature_dimension = 1
    dimension_numbers.input_spatial_dimensions.extend(range(2, 2 + nd))
    dimension_numbers.kernel_spatial_dimensions.extend(range(2, 2 + nd))
    dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd))
    return dimension_numbers

  def ConvGeneralDilated(self,
                         lhs,
                         rhs,
                         window_strides,
                         padding,
                         lhs_dilation,
                         rhs_dilation,
                         dimension_numbers=None,
                         feature_group_count=1,
                         batch_group_count=1,
                         precision_config=None):
    """Enqueues a ConvGeneralDilated operation onto the computation.

    Args:
      lhs: XlaOp for the rank N+2 array of inputs.
      rhs: XlaOp for the rank N+2 array of kernel weights.
      window_strides: length-N array-like of integer kernel strides.
      padding: length-N array-like of pairs of integers of (low, high) padding.
      lhs_dilation: length-N array-like of integer dilation factors.
      rhs_dilation: length-N array-like of integer dilation factors.
      dimension_numbers: optional, either a ConvolutionDimensionNumbers object
        or a tuple (lhs_spec, rhs_spec, out_spec). Each element is a string of
        length N+2 identifying by position: (1) batch dimensions in lhs, rhs,
        and the output with the character 'N', (2) feature dimensions in lhs
        and the output with the character 'C', (3) input and output feature
        dimensions in rhs with the characters 'I' and 'O' respectively, and
        (4) spatial dimension correspondences between lhs, rhs, and the output
        using any distinct characters. For example, to indicate dimension
        numbers consistent with the Conv operation with two spatial
        dimensions, one could use ('NCHW', 'OIHW', 'NCHW'). As another
        example, to indicate dimension numbers consistent with the TensorFlow
        Conv2D operation, one could use ('NHWC', 'HWIO', 'NHWC'). When using
        the latter form of convolution dimension specification, window strides
        are associated with spatial dimension character labels according to
        the order in which the labels appear in the rhs_spec string, so that
        window_strides[0] is matched with the dimension corresponding to the
        first character appearing in rhs_spec that is not 'I' or 'O'. By
        default, use the same dimension numbering as Conv and
        ConvWithGeneralPadding.
      feature_group_count: number of feature groups for grouped convolution.
      batch_group_count: number of batch groups for grouped convolution.
      precision_config: an optional PrecisionConfig for the operation.

    Returns:
      An XlaOp representing the ConvGeneralDilated operation.
    """
    if dimension_numbers is None:
      dimension_numbers = self._GetConvDimensionNumbers(len(window_strides))
    elif isinstance(dimension_numbers, tuple):
      lhs_spec, rhs_spec, out_spec = dimension_numbers
      dimension_numbers = ConvolutionDimensionNumbers()

      dimension_numbers.input_batch_dimension = lhs_spec.index('N')
      dimension_numbers.input_feature_dimension = lhs_spec.index('C')
      dimension_numbers.output_batch_dimension = out_spec.index('N')
      dimension_numbers.output_feature_dimension = out_spec.index('C')
      dimension_numbers.kernel_output_feature_dimension = rhs_spec.index('O')
      dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I')

      dimension_numbers.kernel_spatial_dimensions.extend(
          i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'})
      dimension_numbers.input_spatial_dimensions.extend(
          sorted((i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}),
                 key=lambda i: rhs_spec.index(lhs_spec[i])))
      dimension_numbers.output_spatial_dimensions.extend(
          sorted((i for i, c in enumerate(out_spec) if c not in {'N', 'C'}),
                 key=lambda i: rhs_spec.index(out_spec[i])))
    return ops.ConvGeneralDilated(
        lhs,
        rhs,
        window_strides,
        padding,
        lhs_dilation,
        rhs_dilation,
        dimension_numbers,
        feature_group_count,
        batch_group_count,
        precision_config=precision_config)
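  # A minimal usage sketch (hypothetical shapes): TensorFlow-style NHWC inputs
  # and HWIO filters expressed through the tuple form of dimension_numbers;
  # the builder `c` and ParameterFromNumpy are assumptions for illustration.
  #
  #   x = c.ParameterFromNumpy(np.zeros((1, 8, 8, 3), dtype=np.float32))
  #   w = c.ParameterFromNumpy(np.zeros((3, 3, 3, 16), dtype=np.float32))
  #   out = c.ConvGeneralDilated(
  #       x, w, window_strides=(1, 1), padding=[(1, 1), (1, 1)],
  #       lhs_dilation=(1, 1), rhs_dilation=(1, 1),
  #       dimension_numbers=('NHWC', 'HWIO', 'NHWC'))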

  def Sort(self, operands, dimension=-1, comparator=None):
    """Enqueues a sort operation onto the computation.

    Args:
      operands: either an XlaOp or a sequence of XlaOps to sort. All operands
        must be arrays with the same dimensions.
      dimension: the array dimension over which to sort.
      comparator: a comparator XlaComputation. See the XLA operation semantics
        for details.

    Returns:
      Either an XlaOp or a tuple of XlaOps (if `operands` was an XlaOp or a
      sequence of XlaOps, respectively).
    """
    operands = (
        list(operands)
        if isinstance(operands, collections.abc.Sequence) else [operands])
    return ops.Sort(self._builder, operands, dimension,
                    comparator.computation if comparator else None)
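  # A minimal usage sketch (hypothetical builder `c` and values): sorting a
  # single operand along its last dimension with the default comparator.
  #
  #   v = c.Constant(np.array([3, 1, 2], dtype=np.int32))
  #   sorted_v = c.Sort(v)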

  def SortKeyVal(self, keys, values, dimension=-1):
    """Enqueues a key-value sort operation onto the computation.

    Deprecated. Use `Sort` instead.
    """
    return ops.Sort(self._builder, [keys, values], dimension)

  def QR(self, a, full_matrices=True):
    """Enqueues a QR decomposition onto the computation."""
    return self.Tuple(*ops.QR(a, full_matrices))

  def TriangularSolve(self,
                      a,
                      b,
                      left_side=False,
                      lower=False,
                      transpose_a=False,
                      conjugate_a=False,
                      unit_diagonal=False):
    """Enqueues a triangular-solve operation onto the computation."""
    if not transpose_a:
      transpose = _xla.TriangularSolveOptions_Transpose.NO_TRANSPOSE
      if conjugate_a:
        a = self.Conj(a)
    else:
      transpose = (
          _xla.TriangularSolveOptions_Transpose.ADJOINT
          if conjugate_a else _xla.TriangularSolveOptions_Transpose.TRANSPOSE)
    return ops.TriangularSolve(a, b, left_side, lower, unit_diagonal, transpose)

  def Eigh(self, a, full_matrices=True):
    """Enqueues a symmetric/Hermitian eigendecomposition."""
    return self.Tuple(*ops.Eigh(a, full_matrices))

  def SVD(self, a):
    """Enqueues a singular value decomposition."""
    return self.Tuple(*ops.SVD(a))

  def Gather(self,
             a,
             start_indices,
             dimension_numbers,
             slice_sizes,
             indices_are_sorted=False):
    """Enqueues a Gather operation onto the computation."""
    return ops.Gather(a, start_indices, dimension_numbers, slice_sizes,
                      indices_are_sorted)

  def Scatter(self,
              a,
              scatter_indices,
              updates,
              update_computation,
              dimension_numbers,
              indices_are_sorted=False,
              unique_indices=False):
    """Enqueues a Scatter operation onto the computation."""
    return ops.Scatter(a, scatter_indices, updates,
                       update_computation.computation, dimension_numbers,
                       indices_are_sorted, unique_indices)

  def Fft(self, operand, fft_type, fft_lengths):
    """Enqueues an FFT operation onto the computation."""
    return ops.Fft(operand, fft_type, fft_lengths)


FftType = _xla.FftType

_UNARY_OPS = [
    'Not',
    'Clz',
    'Abs',
    'Exp',
    'Expm1',
    'Floor',
    'Round',
    'Ceil',
    'Log',
    'Log1p',
    'Sign',
    'Cos',
    'Sin',
    'Tanh',
    'IsFinite',
    'Sqrt',
    'Rsqrt',
    'Square',
    'Reciprocal',
    'Neg',
    'Erf',
    'Erfc',
    'ErfInv',
    'Lgamma',
    'Digamma',
    'BesselI0e',
    'BesselI1e',
    'Acos',
    'Asin',
    'Atan',
    'Tan',
    'Acosh',
    'Asinh',
    'Atanh',
    'Cosh',
    'Sinh',
    'Real',
    'Imag',
    'Conj',
]

_BINARY_OPS = [
    'Eq',
    'Ne',
    'Ge',
    'Gt',
    'Lt',
    'Le',
    'Add',
    'Sub',
    'Mul',
    'Div',
    'Rem',
    'Max',
    'Min',
    'And',
    'Or',
    'Xor',
    'Pow',
    'ShiftLeft',
    'ShiftRightArithmetic',
    'ShiftRightLogical',
    'Atan2',
    'Igamma',
    'Igammac',
    'Complex',
    'NextAfter',
]

_OTHER_OPS = [
    'BitcastConvertType',
    'Broadcast',
    'BroadcastInDim',
    'Cholesky',
    'Clamp',
    'Collapse',
    'CollectivePermute',
    'ConvertElementType',
    'Dot',
    'GetTupleElement',
    'ReducePrecision',
    'RegularizedIncompleteBeta',
    'Rev',
    'Select',
    'SliceInDim',
]


def _forward_methods_to_local_builder():
  """Forwards remaining ComputationBuilder methods to the C API.

  Adds methods to ComputationBuilder, each corresponding to an XLA operation,
  whose calls are forwarded in a boilerplate manner to the underlying
  _xla.ops API.
  """

  def forward_op(target_method):

    def forward(builder, *args, **kwargs):
      del builder
      return target_method(*args, **kwargs)

    return forward

  for method_name in itertools.chain(_UNARY_OPS, _BINARY_OPS, _OTHER_OPS):
    forward = forward_op(getattr(ops, method_name))
    forward.__name__ = method_name
    setattr(ComputationBuilder, method_name, forward)


_forward_methods_to_local_builder()
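# After the forwarding above, every name in the three lists is callable as a
# ComputationBuilder method. A minimal sketch (hypothetical builder and
# constants); the builder argument is dropped before calling into _xla.ops:
#
#   c = ComputationBuilder("forwarded_ops_example")
#   x = c.Constant(np.float32(1.0))
#   y = c.Constant(np.float32(2.0))
#   z = c.Add(x, y)  # equivalent to ops.Add(x, y)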


def register_custom_call_target(name, fn, platform='cpu'):
  """Registers a custom call target.

  Args:
    name: bytes containing the name of the function.
    fn: a PyCapsule object containing the function pointer.
    platform: the target platform.
  """
  _xla.RegisterCustomCallTarget(name, fn, xla_platform_names[platform])

# Deprecated. Use register_custom_call_target instead.
register_cpu_custom_call_target = register_custom_call_target


class PaddingConfigDimension(object):
  """Python representation of a xla.PaddingConfigDimension protobuf."""
  __slots__ = ('edge_padding_low', 'edge_padding_high', 'interior_padding')

  def __init__(self):
    self.edge_padding_low = 0
    self.edge_padding_high = 0
    self.interior_padding = 0


class PaddingConfig(object):
  """Python representation of a xla.PaddingConfig protobuf."""
  __slots__ = ('dimensions',)

  def __init__(self):
    self.dimensions = []


def GetPaddingConfigFromTriples(triples):
  """Create PaddingConfig proto from list of triples of integers."""
  padding_config = PaddingConfig()
  for lo, hi, interior in triples:
    dimension = PaddingConfigDimension()
    dimension.edge_padding_low = lo
    dimension.edge_padding_high = hi
    dimension.interior_padding = interior
    padding_config.dimensions.append(dimension)
  return padding_config
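# A minimal sketch (illustrative values): converting (low, high, interior)
# triples for a rank-2 array, e.g. one element of edge padding on each side of
# dimension 0 and interior padding of 1 within dimension 1.
#
#   config = GetPaddingConfigFromTriples([(1, 1, 0), (0, 0, 1)])
#   assert len(config.dimensions) == 2
#   assert config.dimensions[1].interior_padding == 1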


class DotDimensionNumbers(object):
  """Python representation of a xla.DotDimensionNumbers protobuf."""
  __slots__ = ('lhs_contracting_dimensions', 'rhs_contracting_dimensions',
               'lhs_batch_dimensions', 'rhs_batch_dimensions')

  def __init__(self):
    self.lhs_contracting_dimensions = []
    self.rhs_contracting_dimensions = []
    self.lhs_batch_dimensions = []
    self.rhs_batch_dimensions = []


def GetDotDimensionsFromLists(dimension_numbers):
  """Builds a DotDimensionNumbers from contracting/batch dimension lists.

  dimension_numbers is ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)).
  """
  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
  dot_dims_proto = DotDimensionNumbers()
  dot_dims_proto.lhs_contracting_dimensions.extend(lhs_contract)
  dot_dims_proto.rhs_contracting_dimensions.extend(rhs_contract)
  dot_dims_proto.lhs_batch_dimensions.extend(lhs_batch)
  dot_dims_proto.rhs_batch_dimensions.extend(rhs_batch)
  return dot_dims_proto
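# A minimal sketch (illustrative values): a batched matrix multiply in which
# dimension 0 of both operands is a batch dimension and lhs dimension 2
# contracts against rhs dimension 1, matching the tuple form accepted by
# ComputationBuilder.DotGeneral above.
#
#   dnums = GetDotDimensionsFromLists((([2], [1]), ([0], [0])))
#   assert dnums.lhs_contracting_dimensions == [2]
#   assert dnums.rhs_batch_dimensions == [0]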


class ConvolutionDimensionNumbers(object):
  """Python representation of a xla.ConvolutionDimensionNumbers protobuf."""
  __slots__ = ('input_batch_dimension', 'input_feature_dimension',
               'input_spatial_dimensions', 'kernel_input_feature_dimension',
               'kernel_output_feature_dimension', 'kernel_spatial_dimensions',
               'output_batch_dimension', 'output_feature_dimension',
               'output_spatial_dimensions')

  def __init__(self):
    self.input_batch_dimension = 0
    self.input_feature_dimension = 0
    self.input_spatial_dimensions = []
    self.kernel_input_feature_dimension = 0
    self.kernel_output_feature_dimension = 0
    self.kernel_spatial_dimensions = []
    self.output_batch_dimension = 0
    self.output_feature_dimension = 0
    self.output_spatial_dimensions = []


class OpSharding(object):
  """Python representation of a xla.OpSharding protobuf."""
  __slots__ = ('type', 'tile_assignment_dimensions', 'tile_assignment_devices',
               'tuple_shardings')

  Type = _xla.OpSharding_Type

  def __init__(self):
    self.type = self.Type.REPLICATED
    self.tile_assignment_dimensions = []
    self.tile_assignment_devices = []
    self.tuple_shardings = []


class PrecisionConfig(object):
  """Python representation of a xla.PrecisionConfig protobuf."""
  __slots__ = ('operand_precision',)

  Precision = _xla.PrecisionConfig_Precision

  def __init__(self):
    self.operand_precision = []

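# A minimal sketch (illustrative): requesting highest-precision operands for a
# two-operand op such as Dot; the resulting object is what the
# precision_config parameters of the builder methods above expect.
#
#   config = PrecisionConfig()
#   config.operand_precision.extend([PrecisionConfig.Precision.HIGHEST,
#                                    PrecisionConfig.Precision.HIGHEST])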

class GatherDimensionNumbers(object):
  """Python representation of a xla.GatherDimensionNumbers protobuf."""
  __slots__ = ('offset_dims', 'collapsed_slice_dims', 'start_index_map',
               'index_vector_dim')

  def __init__(self):
    self.offset_dims = []
    self.collapsed_slice_dims = []
    self.start_index_map = []
    self.index_vector_dim = 0
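
# A minimal sketch (illustrative values): the dimension numbers a typical
# row-gather over a 2-D operand would use with ComputationBuilder.Gather
# above, selecting whole rows by index.
#
#   gnums = GatherDimensionNumbers()
#   gnums.offset_dims.append(1)           # output dim 1 holds the row slice
#   gnums.collapsed_slice_dims.append(0)  # collapse the indexed dimension
#   gnums.start_index_map.append(0)       # indices pick along operand dim 0
#   gnums.index_vector_dim = 1            # index vectors live in dim 1
#   # slice_sizes passed to Gather would then be (1, num_cols).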

class ScatterDimensionNumbers(object):
  """Python representation of a xla.ScatterDimensionNumbers protobuf."""
  __slots__ = ('update_window_dims', 'inserted_window_dims',
               'scatter_dims_to_operand_dims', 'index_vector_dim')

  def __init__(self):
    self.update_window_dims = []
    self.inserted_window_dims = []
    self.scatter_dims_to_operand_dims = []
    self.index_vector_dim = 0


class ReplicaGroup(object):
  """Python representation of a xla.ReplicaGroup protobuf."""
  __slots__ = ('replica_ids',)

  def __init__(self):
    self.replica_ids = []


def _make_replica_group_proto(replica_group):
  """Converts a sequence of replica ids into a ReplicaGroup proto."""
  replica_group_proto = ReplicaGroup()
  replica_group_proto.replica_ids.extend(replica_group)
  return replica_group_proto


def _get_replica_groups_protos(replica_groups):
  if replica_groups is None:
    replica_groups_protos = []  # special value for XLA API
  else:
    replica_groups = list(replica_groups)
    replica_groups_protos = [
        _make_replica_group_proto(group) for group in replica_groups
    ]
  return replica_groups_protos
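
# A minimal sketch (illustrative values): four replicas split into two groups,
# in the form collective builder methods that accept `replica_groups` consume.
#
#   protos = _get_replica_groups_protos([[0, 1], [2, 3]])
#   assert [g.replica_ids for g in protos] == [[0, 1], [2, 3]]
#
#   # Passing None requests the XLA default: a single group containing all
#   # replicas.
#   assert _get_replica_groups_protos(None) == []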