# Lint as: python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An XLA client in Python."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import atexit
import collections
import contextlib
import enum  # pylint: disable=g-bad-import-order
import gzip
import inspect
import os
from typing import List, Sequence, Tuple, Union

from . import xla_extension as _xla

from absl import logging
import numpy as np

# Note this module does *not* depend on any Python protocol buffers. The XLA
# Python bindings are currently packaged both as part of jaxlib and as part
# of TensorFlow. If we use protocol buffers here, then importing both jaxlib
# and TensorFlow may fail with duplicate protocol buffer message definitions.

# Most functions are snake_case for consistency with other modules; some
# method names are CamelCase for consistency with XLA.
# pylint: disable=invalid-name

# Pylint has false positives for type annotations.
# pylint: disable=invalid-sequence-index

ops = _xla.ops
profiler = _xla.profiler

# Just an internal arbitrary increasing number to help with backward-compatible
# changes.
_version = 3

xla_platform_names = {
    'cpu': 'Host',
    'gpu': 'CUDA',
}

def _interpreter_backend_factory():
  return _xla.get_interpreter_client()


def _cpu_backend_factory():
  return _xla.get_cpu_client(asynchronous=True)


def _gpu_backend_factory(distributed_client=None, node_id=0):
  """Returns a GPU backend. BFC allocator is used by default."""
  allocator = os.getenv('XLA_PYTHON_CLIENT_ALLOCATOR', 'default').lower()
  memory_fraction = os.getenv('XLA_PYTHON_CLIENT_MEM_FRACTION')
  preallocate = os.getenv('XLA_PYTHON_CLIENT_PREALLOCATE')
  if allocator not in ('default', 'platform', 'bfc'):
    raise ValueError(
        'XLA_PYTHON_CLIENT_ALLOCATOR env var must be "default", "platform", or '
        '"bfc", got "%s"' % allocator)
  config = _xla.GpuAllocatorConfig()
  if allocator == 'default':
    config.kind = _xla.GpuAllocatorConfig.Kind.DEFAULT
  if allocator == 'platform':
    config.kind = _xla.GpuAllocatorConfig.Kind.PLATFORM
  if allocator == 'bfc':
    config.kind = _xla.GpuAllocatorConfig.Kind.BFC
  if memory_fraction:
    config.memory_fraction = float(memory_fraction)
  config.preallocate = preallocate not in ('0', 'false', 'False')

  return _xla.get_gpu_client(
      asynchronous=True,
      allocator_config=config,
      distributed_client=distributed_client,
      node_id=node_id)
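# Example (informal sketch): the GPU allocator can be steered through the
# environment variables read above; they would typically be set before the GPU
# client is created, e.g.:
#   os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'bfc'
#   os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.5'
#   os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'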


def _tpu_backend_factory():
  return _xla.get_tpu_client(asynchronous=True)


# Backend factories, keyed by user-visible name, in increasing priority order.
_local_backend_factories = collections.OrderedDict([
    ('interpreter', _interpreter_backend_factory),
    ('cpu', _cpu_backend_factory),
    ('gpu', _gpu_backend_factory),
    ('tpu', _tpu_backend_factory),
])


def register_local_backend_factory(name, factory):
  _local_backend_factories[name] = factory


_local_backends = None


def _get_local_backends():
  """Instantiates all known local backends."""
  global _local_backends
  if _local_backends is not None:
    return _local_backends

  _local_backends = collections.OrderedDict()
  for name, factory in _local_backend_factories.items():
    logging.vlog(1, "Initializing backend '%s'" % name)
    try:
      backend = factory()
    except RuntimeError as err:
      if name == 'cpu':
        # We always expect CPU to initialize successfully.
        raise
      else:
        # If the backend isn't built into the binary, or if it has no devices,
        # we expect a RuntimeError.
        logging.vlog(1, "Error initializing backend '%s': %s" % (name, err))
        continue
    _local_backends[name] = backend
  return _local_backends


def get_local_backend(name=None):
  """Returns a local backend.

  Args:
    name: the backend name. If `None`, the highest-priority backend that
      initialized successfully is returned (typically `tpu` or `gpu` if
      present, otherwise `cpu`). If a string, the named backend is returned or
      an exception raised.

  Returns:
    A LocalBackend object.
  """
  backends = _get_local_backends()
  if name is not None:
    try:
      return backends[name]
    except KeyError:
      raise RuntimeError(
          'Unknown backend %s. Available: %s' % (name, list(backends.keys())))

  return list(backends.values())[-1]
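# Example usage (sketch):
#   cpu_backend = get_local_backend('cpu')   # named backend, or RuntimeError
#   default_backend = get_local_backend()    # highest-priority available backend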


class OpMetadata(object):
  """Python representation of a xla.OpMetadata protobuf."""
  __slots__ = ('op_type', 'op_name', 'source_file', 'source_line')

  def __init__(self, op_type='', op_name='', source_file='', source_line=0):
    self.op_type = op_type
    self.op_name = op_name
    self.source_file = source_file
    self.source_line = source_line


def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1):
  """Helper for use in source mapping that returns an OpMetadata object."""
  full_filename, lineno = inspect.stack()[skip_frames][1:3]
  filename = os.path.basename(full_filename)
  return OpMetadata(
      op_type=op_type,
      op_name=op_name,
      source_file=filename,
      source_line=lineno)
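# Example (sketch): tag an op with the caller's source location.
#   metadata = CurrentSourceInfoMetadata(op_type='Add', op_name='my_add')
#   # metadata.source_file and metadata.source_line describe the call site.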


PrimitiveType = _xla.PrimitiveType

bfloat16 = _xla.bfloat16_dtype()

XLA_ELEMENT_TYPE_TO_DTYPE = {
    PrimitiveType.PRED: np.dtype('bool'),
    PrimitiveType.S8: np.dtype('int8'),
    PrimitiveType.S16: np.dtype('int16'),
    PrimitiveType.S32: np.dtype('int32'),
    PrimitiveType.S64: np.dtype('int64'),
    PrimitiveType.U8: np.dtype('uint8'),
    PrimitiveType.U16: np.dtype('uint16'),
    PrimitiveType.U32: np.dtype('uint32'),
    PrimitiveType.U64: np.dtype('uint64'),
    PrimitiveType.BF16: np.dtype(bfloat16),
    PrimitiveType.F16: np.dtype('float16'),
    PrimitiveType.F32: np.dtype('float32'),
    PrimitiveType.F64: np.dtype('float64'),
    PrimitiveType.C64: np.dtype('complex64'),
    PrimitiveType.C128: np.dtype('complex128'),
    PrimitiveType.TUPLE: np.dtype(np.object_),
    PrimitiveType.TOKEN: np.dtype(np.object_),
}

# Note the conversion on the key. Numpy has a known issue wherein dtype hashing
# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus,
# when keying by dtype in this dict, we use the string form of dtypes.
DTYPE_TO_XLA_ELEMENT_TYPE = {
    str(dt): et for et, dt in XLA_ELEMENT_TYPE_TO_DTYPE.items()
}


def dtype_to_etype(dtype):
  """Convenience function for reading DTYPE_TO_XLA_ELEMENT_TYPE."""
  return DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))]
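# Example:
#   dtype_to_etype(np.float32)   # -> PrimitiveType.F32
#   dtype_to_etype('int32')      # -> PrimitiveType.S32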


Shape = _xla.Shape
Shape.__doc__ = """
A Shape is an object defined in C++ that duck types like the following class:

class Shape(object):
  '''Represents an XLA shape.

  A shape is either an array shape, having rank-many integer
  dimensions and an element type (represented by a Numpy dtype), or it
  is a tuple shape, having a shape for every tuple component:

    type shape =
        TupleShape of shape list
      | ArrayShape of { dimensions: int list; element_type: dtype }
  '''

  @staticmethod
  def tuple_shape(tuple_shapes) -> Shape:
    "Construct a tuple shape."

  @staticmethod
  def array_shape(element_type, dimensions, minor_to_major=None) -> Shape:

  @staticmethod
  def from_pyval(pyval) -> Shape:
    "Returns a Shape that describes a tuple-tree of Numpy arrays."

  def __init__(self, str) -> Shape:
    "Parses a shape string."
  def __eq__(self, other: Shape) -> bool:
  def __ne__(self, other: Shape) -> bool:
  def __hash__(self):
  def __repr__(self):
  def is_tuple(self) -> bool:
  def is_array(self) -> bool:
  def tuple_shapes(self) -> [Shape]:
  def numpy_dtype(self) -> np.dtype:
    "Like element_type(), but returns dtype('O') for a tuple shape."
  def xla_element_type(self) -> PrimitiveType:
  def element_type(self) -> np.dtype:
  def dimensions(self) -> (int, int, ...):
  def rank(self) -> int:
  def with_major_to_minor_layout_if_absent(self) -> Shape:
    "Returns a copy with missing layouts set to major-to-minor."

  def to_serialized_proto(self) -> bytes:
    "Returns 'shape' as a serialized proto."
"""

ProgramShape = _xla.ProgramShape
ProgramShape.__doc__ = """
A ProgramShape is a C++ object that duck types like the following class.

class ProgramShape(object):
  def __init__(self, parameter_shapes, result_shape):
  def parameter_shapes(self) -> [Shape]:
  def result_shape(self) -> Shape:
  def __repr__(self):
"""


def shape_from_pyval(pyval):
  """Returns a Shape that describes a tuple-tree of Numpy arrays."""

  def convert(pyval):
    if isinstance(pyval, tuple):
      return Shape.tuple_shape(tuple(convert(elt) for elt in pyval))
    else:
      return Shape.array_shape(pyval.dtype, np.shape(pyval))

  return convert(pyval)
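# Example:
#   shape_from_pyval(np.zeros((2, 3), np.float32))
#   # -> an array Shape with element type F32 and dimensions (2, 3)
#   shape_from_pyval((np.zeros((2,), np.int32), np.float32(0)))
#   # -> a tuple Shape with an S32[2] component and a scalar F32 component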


DeviceAssignment = _xla.DeviceAssignment
DeviceAssignment.__doc__ = """
A DeviceAssignment is a C++ object with the following signature.

def create(assignment):
  '''Builds a device assignment.

   Args:
     assignment: a 2D numpy array of device ordinal integers, indexed by
       [replica][computation_in_replica].
   Returns:
     A device assignment.
  '''

def replica_count():
  '''Returns the number of replicas.'''
def computation_count():
  '''Returns the number of computations per replica.'''
"""

Device = _xla.Device
CompileOptions = _xla.CompileOptions

HostBufferSemantics = _xla.HostBufferSemantics

# An Executable is a C++ class that duck types with the following API:
# class Executable(object):
#   def local_devices(self) -> [Device]:
#   def execute(self, arguments: [Buffer]) -> Buffer:
#     """Execute on one replica with Buffer arguments and return value."""
#
#   def size_of_generated_code_in_bytes(self) -> int:
#     """Return generated binary size, or -1 if not known."""
#
#   def execute_on_local_devices(self, arguments: [[Buffer]]) -> [Buffer]:
#     """Execute on many replicas with Buffer arguments and return value.
#
#     Args:
#       arguments: A sequence of sequences of Buffers. The i'th inner sequence
#         comprises the arguments for execution on the i'th local device.
#
#     Returns:
#       A list of the computation's outputs for each local device, as a Buffer.
#       If a shallow sequence of arguments was passed in for `arguments`, then
#       the sole, zeroth device's output is returned instead, as a Buffer.
#     """
#
# There are different implementations of Executable for different backends.


def execute_with_python_values(executable, arguments, backend):
  """Execute on one replica with Python values as arguments and output."""

  def put(arg):
    return backend.buffer_from_pyval(arg, device=executable.local_devices()[0])

  arguments = [put(arg) for arg in arguments]
  outputs = executable.execute(arguments)
  return [x.to_py() for x in outputs]
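# Example (informal sketch): `compiled` is assumed to be an Executable already
# built for this backend; it is not constructed here.
#   backend = get_local_backend('cpu')
#   outputs = execute_with_python_values(
#       compiled, [np.arange(4, dtype=np.float32)], backend)
#   # `outputs` is a list of numpy values, one per result of the computation.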


def execute_with_python_values_replicated(executable, arguments, backend):
  """Execute on many replicas with Python values as arguments and output.

  Args:
    executable: the program to run.
    arguments: a list of lists of Python values indexed by `[replica][arg_num]`
      to pass as inputs.
    backend: the backend we are targeting.

  Returns:
    A list of lists of Python values, one inner list of outputs per replica.
  """
  devices = executable.local_devices()
  # pylint: disable=g-complex-comprehension
  flat_args = [(arg, devices[replica])
               for replica, replica_args in enumerate(arguments)
               for arg in replica_args]
  flat_arg_buffers = [
      backend.buffer_from_pyval(pyval, device) for pyval, device in flat_args
  ]
  arg_buffers = []
  for replica_args in arguments:
    arg_buffers.append(flat_arg_buffers[:len(replica_args)])
    flat_arg_buffers = flat_arg_buffers[len(replica_args):]
  return [[x.to_py()
           for x in xs]
          for xs in executable.execute_on_local_devices(arg_buffers)]
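# Example (sketch, with `compiled` and `backend` as in the sketch above): two
# replicas, one argument each.
#   args = [[np.float32(1.0)], [np.float32(2.0)]]
#   per_replica_outputs = execute_with_python_values_replicated(
#       compiled, args, backend)
#   # per_replica_outputs[r] holds the outputs computed on replica r.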


class PaddingType(enum.Enum):
  VALID = 1
  SAME = 2


def window_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims,
                                      window_strides):
  """Maps PaddingType or string to pad values (list of pairs of ints)."""
  if not isinstance(padding_type, (str, PaddingType)):
    msg = 'padding_type must be str or PaddingType, got {}.'
    raise TypeError(msg.format(type(padding_type)))

  if isinstance(padding_type, str):
    if padding_type.upper() == 'VALID':
      padding_type = PaddingType.VALID
    elif padding_type.upper() == 'SAME':
      padding_type = PaddingType.SAME
    else:
      msg = 'Unknown padding type string: expected "VALID" or "SAME", got {}.'
      raise ValueError(msg.format(padding_type))

  if padding_type == PaddingType.VALID:
    return [(0, 0)] * len(window_strides)
  elif padding_type == PaddingType.SAME:
    out_shape = np.ceil(np.true_divide(lhs_dims, window_strides)).astype(int)
    pad_sizes = [
        max((out_size - 1) * stride + filter_size - in_size, 0)
        for out_size, stride, filter_size, in_size in zip(
            out_shape, window_strides, rhs_dims, lhs_dims)
    ]
    return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes]
  else:
    msg = 'Unexpected PaddingType value: {}'
    raise ValueError(msg.format(padding_type))
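# Worked example for 'SAME' padding in one spatial dimension:
#   window_padding_type_to_pad_values('SAME', lhs_dims=[10], rhs_dims=[3],
#                                     window_strides=[2])
#   # out_shape = ceil(10 / 2) = 5
#   # pad = max((5 - 1) * 2 + 3 - 10, 0) = 1, split as (0, 1)
#   # -> [(0, 1)]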


XlaBuilder = _xla.XlaBuilder
XlaComputation = _xla.XlaComputation
FftType = _xla.FftType
Client = _xla.Client
Buffer = _xla.Buffer
DeviceArrayBase = _xla.DeviceArrayBase
Executable = _xla.Executable


def register_custom_call_target(name, fn, platform='cpu'):
  """Registers a custom call target.

  Args:
    name: bytes containing the name of the function.
    fn: a PyCapsule object containing the function pointer.
    platform: the target platform.
  """
  # To support AMD GPUs, we would need xla_platform_names["gpu"] == "ROCM".
  # Since that entry is hardcoded to CUDA, we use the following lookup as a
  # workaround.
  _xla.register_custom_call_target(name, fn,
                                   xla_platform_names.get(platform, platform))


# Deprecated. Use register_custom_call_target instead.
register_cpu_custom_call_target = register_custom_call_target


class PaddingConfigDimension(object):
  """Python representation of a xla.PaddingConfigDimension protobuf."""
  __slots__ = ('edge_padding_low', 'edge_padding_high', 'interior_padding')

  def __init__(self):
    self.edge_padding_low = 0
    self.edge_padding_high = 0
    self.interior_padding = 0


class PaddingConfig(object):
  """Python representation of a xla.PaddingConfig protobuf."""
  __slots__ = ('dimensions',)

  def __init__(self):
    self.dimensions = []


def make_padding_config(
    padding_config: Union[PaddingConfig, Sequence[Tuple[int, int, int]]]
) -> PaddingConfig:
  """Creates a PaddingConfig from a list of triples of integers.

  Args:
    padding_config: either a PaddingConfig or a list of integer triples
      (edge_padding_low, edge_padding_high, interior_padding) representing the
      configuration of the padding operation.

  Returns:
    A `PaddingConfig` object.
  """
  if isinstance(padding_config, (list, tuple)):
    triples = padding_config
    padding_config = PaddingConfig()
    for lo, hi, interior in triples:
      dimension = PaddingConfigDimension()
      dimension.edge_padding_low = lo
      dimension.edge_padding_high = hi
      dimension.interior_padding = interior
      padding_config.dimensions.append(dimension)
  return padding_config
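# Example: pad dimension 0 by 1 low / 2 high, and add interior padding of 1
# between elements along dimension 1:
#   config = make_padding_config([(1, 2, 0), (0, 0, 1)])
#   # config.dimensions[0].edge_padding_low == 1
#   # config.dimensions[1].interior_padding == 1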


class DotDimensionNumbers(object):
  """Python representation of a xla.DotDimensionNumbers protobuf."""
  __slots__ = ('lhs_contracting_dimensions', 'rhs_contracting_dimensions',
               'lhs_batch_dimensions', 'rhs_batch_dimensions')

  def __init__(self):
    self.lhs_contracting_dimensions = []
    self.rhs_contracting_dimensions = []
    self.lhs_batch_dimensions = []
    self.rhs_batch_dimensions = []


def make_dot_dimension_numbers(
    dimension_numbers: Union[DotDimensionNumbers,
                             Tuple[Tuple[List[int], List[int]],
                                   Tuple[List[int], List[int]]]]
) -> DotDimensionNumbers:
  """Builds a DotDimensionNumbers object from a specification.

  Args:
    dimension_numbers: either a `DotDimensionNumbers` or a nested tuple
      `((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))` of lists of
      integers representing the dimensions to treat as contracting dimensions
      and batch dimensions on each input operand.

  Returns:
    A `DotDimensionNumbers` object.
  """
  if isinstance(dimension_numbers, (list, tuple)):
    (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
    dot_dims_proto = DotDimensionNumbers()
    dot_dims_proto.lhs_contracting_dimensions.extend(lhs_contract)
    dot_dims_proto.rhs_contracting_dimensions.extend(rhs_contract)
    dot_dims_proto.lhs_batch_dimensions.extend(lhs_batch)
    dot_dims_proto.rhs_batch_dimensions.extend(rhs_batch)
    return dot_dims_proto
  else:
    return dimension_numbers
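# Example: a plain matrix multiply contracts lhs dimension 1 with rhs
# dimension 0 and has no batch dimensions:
#   dnums = make_dot_dimension_numbers((([1], [0]), ([], [])))
# A batched matmul over a shared leading batch dimension would instead use:
#   make_dot_dimension_numbers((([2], [1]), ([0], [0])))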


class ConvolutionDimensionNumbers(object):
  """Python representation of a xla.ConvolutionDimensionNumbers protobuf."""
  __slots__ = ('input_batch_dimension', 'input_feature_dimension',
               'input_spatial_dimensions', 'kernel_input_feature_dimension',
               'kernel_output_feature_dimension', 'kernel_spatial_dimensions',
               'output_batch_dimension', 'output_feature_dimension',
               'output_spatial_dimensions')

  def __init__(self):
    self.input_batch_dimension = 0
    self.input_feature_dimension = 0
    self.input_spatial_dimensions = []
    self.kernel_input_feature_dimension = 0
    self.kernel_output_feature_dimension = 0
    self.kernel_spatial_dimensions = []
    self.output_batch_dimension = 0
    self.output_feature_dimension = 0
    self.output_spatial_dimensions = []


def make_convolution_dimension_numbers(
    dimension_numbers: Union[None, ConvolutionDimensionNumbers, Tuple[str, str,
                                                                      str]],
    num_spatial_dimensions: int) -> ConvolutionDimensionNumbers:
  """Builds a ConvolutionDimensionNumbers object from a specification.

  Args:
    dimension_numbers: optional, either a ConvolutionDimensionNumbers object or
      a tuple (lhs_spec, rhs_spec, out_spec). Each element is a string of
      length N+2 identifying by position: (1) batch dimensions in lhs, rhs, and
        the output with the character 'N', (2) feature dimensions in lhs and the
        output with the character 'C', (3) input and output feature dimensions
        in rhs with the characters 'I' and 'O' respectively, and (4) spatial
        dimension correspondences between lhs, rhs, and the output using any
        distinct characters. For example, to indicate dimension numbers
        consistent with the Conv operation with two spatial dimensions, one
        could use ('NCHW', 'OIHW', 'NCHW'). As another example, to indicate
        dimension numbers consistent with the TensorFlow Conv2D operation, one
        could use ('NHWC', 'HWIO', 'NHWC'). When using the latter form of
        convolution dimension specification, window strides are associated with
        spatial dimension character labels according to the order in which the
        labels appear in the rhs_spec string, so that window_strides[0] is
        matched with the dimension corresponding to the first character
        appearing in rhs_spec that is not 'I' or 'O'. By default, use the same
        dimension numbering as Conv and ConvWithGeneralPadding.
    num_spatial_dimensions: the number of spatial dimensions.

  Returns:
    A `ConvolutionDimensionNumbers` object.
  """
  if dimension_numbers is None:
    nd = num_spatial_dimensions
    dimension_numbers = ConvolutionDimensionNumbers()
    dimension_numbers.input_batch_dimension = 0
    dimension_numbers.input_feature_dimension = 1
    dimension_numbers.output_batch_dimension = 0
    dimension_numbers.output_feature_dimension = 1
    dimension_numbers.kernel_output_feature_dimension = 0
    dimension_numbers.kernel_input_feature_dimension = 1
    dimension_numbers.input_spatial_dimensions.extend(range(2, 2 + nd))
    dimension_numbers.kernel_spatial_dimensions.extend(range(2, 2 + nd))
    dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd))
  elif isinstance(dimension_numbers, tuple):
    lhs_spec, rhs_spec, out_spec = dimension_numbers
    dimension_numbers = ConvolutionDimensionNumbers()

    dimension_numbers.input_batch_dimension = lhs_spec.index('N')
    dimension_numbers.input_feature_dimension = lhs_spec.index('C')
    dimension_numbers.output_batch_dimension = out_spec.index('N')
    dimension_numbers.output_feature_dimension = out_spec.index('C')
    dimension_numbers.kernel_output_feature_dimension = rhs_spec.index('O')
    dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I')

    dimension_numbers.kernel_spatial_dimensions.extend(
        i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'})
    dimension_numbers.input_spatial_dimensions.extend(
        sorted((i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}),
               key=lambda i: rhs_spec.index(lhs_spec[i])))
    dimension_numbers.output_spatial_dimensions.extend(
        sorted((i for i, c in enumerate(out_spec) if c not in {'N', 'C'}),
               key=lambda i: rhs_spec.index(out_spec[i])))
  return dimension_numbers
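# Example: TensorFlow-style Conv2D dimension numbers.
#   dnums = make_convolution_dimension_numbers(('NHWC', 'HWIO', 'NHWC'), 2)
#   # dnums.input_batch_dimension == 0, dnums.input_feature_dimension == 3,
#   # dnums.kernel_spatial_dimensions == [0, 1]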


class OpSharding(object):
  """Python representation of a xla.OpSharding protobuf."""
  __slots__ = ('type', 'tile_assignment_dimensions', 'tile_assignment_devices',
               'tuple_shardings', 'replicate_on_last_tile_dim')

  Type = _xla.OpSharding_Type

  def __init__(self):
    self.type = self.Type.REPLICATED
    self.tile_assignment_dimensions = []
    self.tile_assignment_devices = []
    self.tuple_shardings = []
    self.replicate_on_last_tile_dim = False


class PrecisionConfig(object):
  """Python representation of a xla.PrecisionConfig protobuf."""
  __slots__ = ('operand_precision',)

  Precision = _xla.PrecisionConfig_Precision

  def __init__(self):
    self.operand_precision = []


class GatherDimensionNumbers(object):
  """Python representation of a xla.GatherDimensionNumbers protobuf."""
  __slots__ = ('offset_dims', 'collapsed_slice_dims', 'start_index_map',
               'index_vector_dim')

  def __init__(self):
    self.offset_dims = []
    self.collapsed_slice_dims = []
    self.start_index_map = []
    self.index_vector_dim = 0


class ScatterDimensionNumbers(object):
  """Python representation of a xla.ScatterDimensionNumbers protobuf."""
  __slots__ = ('update_window_dims', 'inserted_window_dims',
               'scatter_dims_to_operand_dims', 'index_vector_dim')

  def __init__(self):
    self.update_window_dims = []
    self.inserted_window_dims = []
    self.scatter_dims_to_operand_dims = []
    self.index_vector_dim = 0


class ReplicaGroup(object):
  """Python representation of a xla.ReplicaGroup protobuf."""
  __slots__ = ('replica_ids',)

  def __init__(self):
    self.replica_ids = []


def _make_replica_group_proto(replica_group):
  replica_group_proto = ReplicaGroup()
  replica_group_proto.replica_ids.extend(replica_group)
  return replica_group_proto


def make_replica_groups(replica_groups):
  """Converts a sequence of sequences of replica ids to ReplicaGroup objects."""
  if replica_groups is None:
    replica_groups_protos = []  # special value for XLA API
  else:
    replica_groups = list(replica_groups)
    replica_groups_protos = [
        _make_replica_group_proto(group) for group in replica_groups
    ]
  return replica_groups_protos
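# Example: four replicas arranged as two groups of two:
#   groups = make_replica_groups([[0, 1], [2, 3]])
#   # groups is a list of two ReplicaGroup objects; passing None instead yields
#   # [], the special value the XLA API accepts when no grouping is given.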


Traceback = _xla.Traceback


@contextlib.contextmanager
def tracebacks(enabled=True):
  """Context manager that enables or disables traceback collection."""
  saved = Traceback.enabled
  Traceback.enabled = enabled
  try:
    yield
  finally:
    Traceback.enabled = saved
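# Example: temporarily disable traceback collection, restoring the previous
# setting on exit:
#   with tracebacks(enabled=False):
#     ...  # work done here runs with traceback collection turned off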


def heap_profile(client: Client) -> bytes:
  """Returns a gzipped pprof protocol buffer containing a heap profile."""
  return gzip.compress(client.heap_profile())
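# Example (sketch, assuming the local backend object is a `Client`): write a
# heap profile that pprof can read.
#   client = get_local_backend('cpu')
#   with open('/tmp/xla_heap.pb.gz', 'wb') as f:
#     f.write(heap_profile(client))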


# Perform one last garbage collection of deferred Python references. This is
# mostly to keep ASAN happy.
atexit.register(_xla.collect_garbage)