# Lint as: python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An XLA client in Python."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import atexit
import collections
import contextlib
import enum  # pylint: disable=g-bad-import-order
import gzip
import inspect
import os
from typing import List, Sequence, Tuple, Union

from . import xla_extension as _xla

from absl import logging
import numpy as np

# Note this module does *not* depend on any Python protocol buffers. The XLA
# Python bindings are currently packaged both as part of jaxlib and as part
# of TensorFlow. If we use protocol buffers here, then importing both jaxlib
# and TensorFlow may fail with duplicate protocol buffer message definitions.

# Most functions are snake_case for consistency with other modules; some
# method names are CamelCase for consistency with XLA.
# pylint: disable=invalid-name

# Pylint has false positives for type annotations.
# pylint: disable=invalid-sequence-index

ops = _xla.ops
profiler = _xla.profiler

# An internal, arbitrary, increasing number to help with backward-compatible
# changes.
_version = 32

xla_platform_names = {
    'cpu': 'Host',
    'gpu': 'CUDA',
}


def make_interpreter_client():
  return _xla.get_interpreter_client()


def make_cpu_client(*, use_tfrt=False):
  if use_tfrt:
    return _xla.get_tfrt_cpu_client(asynchronous=True)
  else:
    return _xla.get_cpu_client(asynchronous=True)


def make_gpu_client(distributed_client=None, node_id=0):
  """Returns a GPU client. BFC allocator is used by default."""
  allocator = os.getenv('XLA_PYTHON_CLIENT_ALLOCATOR', 'default').lower()
  memory_fraction = os.getenv('XLA_PYTHON_CLIENT_MEM_FRACTION')
  preallocate = os.getenv('XLA_PYTHON_CLIENT_PREALLOCATE')
  if allocator not in ('default', 'platform', 'bfc', 'cuda_async'):
    raise ValueError(
        'XLA_PYTHON_CLIENT_ALLOCATOR env var must be "default", "platform", '
        '"bfc", or "cuda_async", got "%s"' % allocator)
  config = _xla.GpuAllocatorConfig()
  if allocator == 'default':
    config.kind = _xla.GpuAllocatorConfig.Kind.DEFAULT
  if allocator == 'platform':
    config.kind = _xla.GpuAllocatorConfig.Kind.PLATFORM
  if allocator == 'bfc':
    config.kind = _xla.GpuAllocatorConfig.Kind.BFC
  if allocator == 'cuda_async':
    config.kind = _xla.GpuAllocatorConfig.Kind.CUDA_ASYNC
  if memory_fraction:
    config.memory_fraction = float(memory_fraction)
  config.preallocate = preallocate not in ('0', 'false', 'False')

  return _xla.get_gpu_client(
      asynchronous=True,
      allocator_config=config,
      distributed_client=distributed_client,
      node_id=node_id)


def make_tpu_client():
  return _xla.get_tpu_client(max_inflight_computations=32)

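# Example (illustrative sketch only, not executed at import time): the factory
# functions above are typically called directly to obtain a backend client.
# The attributes shown follow the Client/Device bindings exposed by _xla.
#
#   client = make_cpu_client()
#   print(client.platform, client.device_count())
#   for device in client.local_devices():
#     print(device.id, device.platform)
#
# GPU behavior can be tuned via the environment variables read by
# make_gpu_client (e.g. XLA_PYTHON_CLIENT_ALLOCATOR=platform,
# XLA_PYTHON_CLIENT_MEM_FRACTION=0.5), set before the client is created.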

# Deprecated client factory API.

# Backend factories, keyed by user-visible name, in increasing priority order.
_local_backend_factories = collections.OrderedDict([
    ('interpreter', make_interpreter_client),
    ('cpu', make_cpu_client),
    ('gpu', make_gpu_client),
    ('tpu', make_tpu_client),
])


def register_local_backend_factory(name, factory):
  _local_backend_factories[name] = factory


_local_backends = None


def _get_local_backends():
  """Instantiates all known local backends."""
  global _local_backends
  if _local_backends is not None:
    return _local_backends

  _local_backends = collections.OrderedDict()
  for name, factory in _local_backend_factories.items():
    logging.vlog(1, "Initializing backend '%s'" % name)
    try:
      backend = factory()
    except RuntimeError as err:
      if name == 'cpu':
        # We always expect CPU to initialize successfully.
        raise
      else:
        # If the backend isn't built into the binary, or if it has no devices,
        # we expect a RuntimeError.
        logging.vlog(1, "Error initializing backend '%s': %s" % (name, err))
        continue
    _local_backends[name] = backend
  return _local_backends


def get_local_backend(name=None):
  """Returns a local backend.

  Args:
    name: the backend name. If `None`, a default local backend is returned,
      typically `gpu` if one is present, or `cpu` if not. If a string, the named
      backend is returned or an exception raised.

  Returns:
    A LocalBackend object.
  """
  backends = _get_local_backends()
  if name is not None:
    try:
      return backends[name]
    except KeyError:
      raise RuntimeError('Unknown backend %s. Available: %s' %
                         (name, list(backends.keys())))

  return list(backends.values())[-1]

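# Example (hedged sketch of the deprecated factory API above): looking up a
# backend by name, falling back to the highest-priority available backend
# when no name is given.
#
#   backend = get_local_backend()        # e.g. the GPU backend if present
#   cpu = get_local_backend('cpu')
#
# An unknown name raises a RuntimeError listing the available backends.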

class OpMetadata(object):
  """Python representation of a xla.OpMetadata protobuf."""
  __slots__ = ('op_type', 'op_name', 'source_file', 'source_line')

  def __init__(self, op_type='', op_name='', source_file='', source_line=0):
    self.op_type = op_type
    self.op_name = op_name
    self.source_file = source_file
    self.source_line = source_line


def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1):
  """Helper for use in source mapping that returns an OpMetadata object."""
  full_filename, lineno = inspect.stack()[skip_frames][1:3]
  filename = os.path.basename(full_filename)
  return OpMetadata(
      op_type=op_type,
      op_name=op_name,
      source_file=filename,
      source_line=lineno)


PrimitiveType = _xla.PrimitiveType

bfloat16 = _xla.bfloat16_dtype()

XLA_ELEMENT_TYPE_TO_DTYPE = {
    PrimitiveType.PRED: np.dtype('bool'),
    PrimitiveType.S8: np.dtype('int8'),
    PrimitiveType.S16: np.dtype('int16'),
    PrimitiveType.S32: np.dtype('int32'),
    PrimitiveType.S64: np.dtype('int64'),
    PrimitiveType.U8: np.dtype('uint8'),
    PrimitiveType.U16: np.dtype('uint16'),
    PrimitiveType.U32: np.dtype('uint32'),
    PrimitiveType.U64: np.dtype('uint64'),
    PrimitiveType.BF16: np.dtype(bfloat16),
    PrimitiveType.F16: np.dtype('float16'),
    PrimitiveType.F32: np.dtype('float32'),
    PrimitiveType.F64: np.dtype('float64'),
    PrimitiveType.C64: np.dtype('complex64'),
    PrimitiveType.C128: np.dtype('complex128'),
    PrimitiveType.TUPLE: np.dtype(np.object_),
    PrimitiveType.TOKEN: np.dtype(np.object_),
}

# Note the conversion on the key. Numpy has a known issue wherein dtype hashing
# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus,
# when keying by dtype in this dict, we use the string form of dtypes.
DTYPE_TO_XLA_ELEMENT_TYPE = {
    str(dt): et for et, dt in XLA_ELEMENT_TYPE_TO_DTYPE.items()
}


def dtype_to_etype(dtype):
  """Convenience function for reading DTYPE_TO_XLA_ELEMENT_TYPE."""
  return DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))]

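# Example (hedged sketch): converting between Numpy dtypes and XLA element
# types with the tables above. Keys are stringified dtypes, so any dtype-like
# value works once passed through np.dtype().
#
#   assert dtype_to_etype(np.float32) == PrimitiveType.F32
#   assert dtype_to_etype('int8') == PrimitiveType.S8
#   assert XLA_ELEMENT_TYPE_TO_DTYPE[PrimitiveType.S32] == np.dtype('int32')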

Shape = _xla.Shape
Shape.__doc__ = """
A Shape is an object defined in C++ that duck types like the following class:

class Shape(object):
  '''Represents an XLA shape.

  A shape is either an array shape, having rank-many integer
  dimensions and an element type (represented by a Numpy dtype), or it
  is a tuple shape, having a shape for every tuple component:

    type shape =
        TupleShape of shape list
      | ArrayShape of { dimensions: int list; element_type: dtype }
  '''

  @staticmethod
  def tuple_shape(tuple_shapes) -> Shape:
    "Construct a tuple shape."

  @staticmethod
  def array_shape(element_type, dimensions, minor_to_major=None) -> Shape:
    "Construct an array shape."

  @staticmethod
  def from_pyval(pyval) -> Shape:
    "Returns a Shape that describes a tuple-tree of Numpy arrays."

  def __init__(self, str) -> Shape:
    "Parses a shape string."
  def __eq__(self, other: Shape) -> bool:
  def __ne__(self, other: Shape) -> bool:
  def __hash__(self):
  def __repr__(self):
  def is_tuple(self) -> bool:
  def is_array(self) -> bool:
  def tuple_shapes(self) -> [Shape]:
  def numpy_dtype(self) -> np.dtype:
    "Like element_type(), but returns dtype('O') for a tuple shape."
  def xla_element_type(self) -> PrimitiveType:
  def element_type(self) -> np.dtype:
  def dimensions(self) -> (int, int, ...):
  def rank(self) -> int:
  def with_major_to_minor_layout_if_absent(self) -> Shape:
    "Returns a copy with missing layouts set to major-to-minor."

  def to_serialized_proto(self) -> bytes:
    "Returns 'shape' as a serialized proto."
"""

ProgramShape = _xla.ProgramShape
ProgramShape.__doc__ = """
A ProgramShape is a C++ object that duck types like the following class.

class ProgramShape(object):
  def __init__(self, parameter_shapes, result_shape):
  def parameter_shapes(self) -> [Shape]:
  def result_shape(self) -> Shape:
  def __repr__(self):
"""

ShapeIndex = _xla.ShapeIndex
ShapeIndex.__doc__ = """
A ShapeIndex is an object defined in C++ that duck types like the following
class:

class ShapeIndex(object):
  '''Represents an XLA ShapeIndex.

  An index for specifying a particular nested subshape within a shape. Used in
  ShapeUtil::GetSubshape and other interfaces. ShapeIndex defines a path through
  the Shape tree where each element of ShapeIndex indexes into a tuple (or
  nested tuple) within the shape. For a non-nested tuple, an index has a single
  element.
  '''

  def __init__(self, indices: List[int]) -> ShapeIndex:
  def __eq__(self, other: ShapeIndex) -> bool:
  def __ne__(self, other: ShapeIndex) -> bool:
  def __hash__(self):
  def __repr__(self):
"""


def shape_from_pyval(pyval):
  """Returns a Shape that describes a tuple-tree of Numpy arrays."""

  def convert(pyval):
    if isinstance(pyval, tuple):
      return Shape.tuple_shape(tuple(convert(elt) for elt in pyval))
    else:
      return Shape.array_shape(pyval.dtype, np.shape(pyval))

  return convert(pyval)

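# Example (hedged sketch): constructing Shapes directly and from Python
# values.
#
#   s = shape_from_pyval((np.zeros((2, 3), np.float32), np.int32(0)))
#   # s is the tuple shape (f32[2,3], s32[]).
#   assert s.is_tuple()
#
#   a = Shape.array_shape(np.dtype('float32'), (2, 3))
#   assert a.element_type() == np.dtype('float32')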

DeviceAssignment = _xla.DeviceAssignment
DeviceAssignment.__doc__ = """
A DeviceAssignment is a C++ object with the following signature.

def create(assignment):
  '''Builds a device assignment.

   Args:
     assignment: a 2D numpy array of device ordinal integers, indexed by
       [replica][computation_in_replica].
   Returns:
     A device assignment.
  '''

def replica_count():
  '''Returns the number of replicas.'''
def computation_count():
  '''Returns the number of computations per replica.'''
"""

Device = _xla.Device
CompileOptions = _xla.CompileOptions

HostBufferSemantics = _xla.HostBufferSemantics

# An Executable is a C++ class that duck types with the following API:
# class Executable(object):
#   def local_devices(self) -> [Device]:
#   def execute(self, arguments: [Buffer]) -> [Buffer]:
#     """Execute on one replica with Buffer arguments and return value."""
#
#   def size_of_generated_code_in_bytes(self) -> int:
#     """Return generated binary size, or -1 if not known."""
#
#   def execute_sharded_on_local_devices(self, arguments: [[Buffer]])
#       -> [[Buffer]]:
#     """Execute on many replicas with Buffer arguments and return value.
#
#     Args:
#       arguments: A sequence of sequences of Buffers. The i'th element of each
#         sequence comprises the arguments for execution on the i'th local
#         device.
#
#     Returns:
#       A list of the computation's outputs as a list of Buffers for each
#       device.
#     """
#
# There are different implementations of Executable for different backends.


def execute_with_python_values(executable, arguments, backend):
  """Execute on one replica with Python values as arguments and output."""

  def put(arg):
    return backend.buffer_from_pyval(arg, device=executable.local_devices()[0])

  arguments = [put(arg) for arg in arguments]
  outputs = executable.execute(arguments)
  return [x.to_py() for x in outputs]


def execute_with_python_values_replicated(executable, arguments, backend):
  """Execute on many replicas with Python values as arguments and output.

  Args:
    executable: the program to run.
    arguments: a list of lists of Python values indexed by `[replica][arg_num]`
      to pass as inputs.
    backend: the backend we are targeting.

  Returns:
    A list of lists of Python values, one list of output values per replica.
  """
  devices = executable.local_devices()

  # pylint: disable=g-complex-comprehension
  def copy_to_devices(pyvals):
    return [backend.buffer_from_pyval(v, d) for v, d in zip(pyvals, devices)]

  inputs = [copy_to_devices(pyvals) for pyvals in zip(*arguments)]
  outputs = executable.execute_sharded_on_local_devices(inputs)
  return [[x.to_py() for x in xs] for xs in zip(*outputs)]

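# Example (hedged sketch using the XlaBuilder/ops bindings aliased elsewhere in
# this module; exact builder/compile method names may differ across versions):
# building, compiling, and running a trivial computation on one replica.
#
#   builder = XlaBuilder('add_one')
#   param = ops.Parameter(builder, 0, shape_from_pyval(np.float32(0.)))
#   ops.Add(param, ops.Constant(builder, np.float32(1.)))
#   computation = builder.build()
#   backend = get_local_backend('cpu')
#   compiled = backend.compile(computation)
#   print(execute_with_python_values(compiled, [np.float32(41.)], backend))
#   # => [array(42., dtype=float32)]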

class PaddingType(enum.Enum):
  VALID = 1
  SAME = 2


def window_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims,
                                      window_strides):
  """Maps PaddingType or string to pad values (list of pairs of ints)."""
  if not isinstance(padding_type, (str, PaddingType)):
    msg = 'padding_type must be str or PaddingType, got {}.'
    raise TypeError(msg.format(type(padding_type)))

  if isinstance(padding_type, str):
    if padding_type.upper() == 'VALID':
      padding_type = PaddingType.VALID
    elif padding_type.upper() == 'SAME':
      padding_type = PaddingType.SAME
    else:
      msg = 'Unknown padding type string: expected "VALID" or "SAME", got {}.'
      raise ValueError(msg.format(padding_type))

  if padding_type == PaddingType.VALID:
    return [(0, 0)] * len(window_strides)
  elif padding_type == PaddingType.SAME:
    out_shape = np.ceil(np.true_divide(lhs_dims, window_strides)).astype(int)
    pad_sizes = [
        max((out_size - 1) * stride + filter_size - in_size, 0)
        for out_size, stride, filter_size, in_size in zip(
            out_shape, window_strides, rhs_dims, lhs_dims)
    ]
    return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes]
  else:
    msg = 'Unexpected PaddingType value: {}'
    raise ValueError(msg.format(padding_type))

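# Example (hedged sketch): computing explicit padding for a 1D window. For an
# input of length 10, a window of size 3, and stride 2, SAME padding yields an
# output of ceil(10 / 2) = 5 positions and a total pad of
# (5 - 1) * 2 + 3 - 10 = 1, split as (0, 1).
#
#   assert window_padding_type_to_pad_values('SAME', [10], [3], [2]) == [(0, 1)]
#   assert window_padding_type_to_pad_values(PaddingType.VALID, [10], [3],
#                                            [2]) == [(0, 0)]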

XlaBuilder = _xla.XlaBuilder
XlaComputation = _xla.XlaComputation
XlaOp = _xla.XlaOp
FftType = _xla.FftType
Client = _xla.Client
Buffer = _xla.Buffer
DeviceArrayBase = _xla.DeviceArrayBase
Executable = _xla.Executable


def register_custom_call_target(name, fn, platform='cpu'):
  """Registers a custom call target.

  Args:
    name: bytes containing the name of the function.
    fn: a PyCapsule object containing the function pointer.
    platform: the target platform.
  """
  # To support AMD GPUs, we need to have xla_platform_names["gpu"] == "ROCM".
  # Since that is hardcoded to CUDA, we use the following as a workaround.
  _xla.register_custom_call_target(name, fn,
                                   xla_platform_names.get(platform, platform))


# Deprecated. Use register_custom_call_target instead.
register_cpu_custom_call_target = register_custom_call_target


class PaddingConfigDimension(object):
  """Python representation of a xla.PaddingConfigDimension protobuf."""
  __slots__ = ('edge_padding_low', 'edge_padding_high', 'interior_padding')

  def __init__(self):
    self.edge_padding_low = 0
    self.edge_padding_high = 0
    self.interior_padding = 0


class PaddingConfig(object):
  """Python representation of a xla.PaddingConfig protobuf."""
  __slots__ = ('dimensions',)

  def __init__(self):
    self.dimensions = []


def make_padding_config(
    padding_config: Union[PaddingConfig, Sequence[Tuple[int, int, int]]]
) -> PaddingConfig:
  """Creates a PaddingConfig from a list of integer triples.

  Args:
    padding_config: either a PaddingConfig or a list of integer triples
      (edge_padding_low, edge_padding_high, interior_padding) representing the
      configuration of the padding operation.

  Returns:
    A `PaddingConfig` object.
  """
  if not isinstance(padding_config, PaddingConfig):
    triples = padding_config
    padding_config = PaddingConfig()
    for lo, hi, interior in triples:
      dimension = PaddingConfigDimension()
      dimension.edge_padding_low = lo
      dimension.edge_padding_high = hi
      dimension.interior_padding = interior
      padding_config.dimensions.append(dimension)
  return padding_config

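# Example (hedged sketch): building a PaddingConfig for a rank-2 pad: one low
# pad element and two high pad elements on dimension 0, and one element of
# interior padding on dimension 1.
#
#   config = make_padding_config([(1, 2, 0), (0, 0, 1)])
#   assert len(config.dimensions) == 2
#   assert config.dimensions[0].edge_padding_high == 2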

class DotDimensionNumbers(object):
  """Python representation of a xla.DotDimensionNumbers protobuf."""
  __slots__ = ('lhs_contracting_dimensions', 'rhs_contracting_dimensions',
               'lhs_batch_dimensions', 'rhs_batch_dimensions')

  def __init__(self):
    self.lhs_contracting_dimensions = []
    self.rhs_contracting_dimensions = []
    self.lhs_batch_dimensions = []
    self.rhs_batch_dimensions = []


def make_dot_dimension_numbers(
    dimension_numbers: Union[DotDimensionNumbers,
                             Tuple[Tuple[List[int], List[int]],
                                   Tuple[List[int], List[int]]]]
) -> DotDimensionNumbers:
  """Builds a DotDimensionNumbers object from a specification.

  Args:
    dimension_numbers: either a `DotDimensionNumbers` or a nested tuple
      `((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))` of lists of
      integers representing the dimensions to treat as contracting dimensions
      and batch dimensions on each input operand.

  Returns:
    A `DotDimensionNumbers` object.
  """
  if isinstance(dimension_numbers, (list, tuple)):
    (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
    dot_dims_proto = DotDimensionNumbers()
    dot_dims_proto.lhs_contracting_dimensions.extend(lhs_contract)
    dot_dims_proto.rhs_contracting_dimensions.extend(rhs_contract)
    dot_dims_proto.lhs_batch_dimensions.extend(lhs_batch)
    dot_dims_proto.rhs_batch_dimensions.extend(rhs_batch)
    return dot_dims_proto
  else:
    return dimension_numbers

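# Example (hedged sketch): dimension numbers for a batched matmul of
# lhs[b, m, k] with rhs[b, k, n], contracting over k (axis 2 of lhs, axis 1 of
# rhs) and batching over b (axis 0 of both operands).
#
#   dnums = make_dot_dimension_numbers((([2], [1]), ([0], [0])))
#   assert dnums.lhs_contracting_dimensions == [2]
#   assert dnums.rhs_batch_dimensions == [0]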

class ConvolutionDimensionNumbers(object):
  """Python representation of a xla.ConvolutionDimensionNumbers protobuf."""
  __slots__ = ('input_batch_dimension', 'input_feature_dimension',
               'input_spatial_dimensions', 'kernel_input_feature_dimension',
               'kernel_output_feature_dimension', 'kernel_spatial_dimensions',
               'output_batch_dimension', 'output_feature_dimension',
               'output_spatial_dimensions')

  def __init__(self):
    self.input_batch_dimension = 0
    self.input_feature_dimension = 0
    self.input_spatial_dimensions = []
    self.kernel_input_feature_dimension = 0
    self.kernel_output_feature_dimension = 0
    self.kernel_spatial_dimensions = []
    self.output_batch_dimension = 0
    self.output_feature_dimension = 0
    self.output_spatial_dimensions = []


def make_convolution_dimension_numbers(
    dimension_numbers: Union[None, ConvolutionDimensionNumbers, Tuple[str, str,
                                                                      str]],
    num_spatial_dimensions: int) -> ConvolutionDimensionNumbers:
  """Builds a ConvolutionDimensionNumbers object from a specification.

  Args:
    dimension_numbers: optional, either a ConvolutionDimensionNumbers object or
      a tuple (lhs_spec, rhs_spec, out_spec). Each element is a string of
      length N+2 identifying by position: (1) batch dimensions in lhs, rhs, and
        the output with the character 'N', (2) feature dimensions in lhs and the
        output with the character 'C', (3) input and output feature dimensions
        in rhs with the characters 'I' and 'O' respectively, and (4) spatial
        dimension correspondences between lhs, rhs, and the output using any
        distinct characters. For example, to indicate dimension numbers
        consistent with the Conv operation with two spatial dimensions, one
        could use ('NCHW', 'OIHW', 'NCHW'). As another example, to indicate
        dimension numbers consistent with the TensorFlow Conv2D operation, one
        could use ('NHWC', 'HWIO', 'NHWC'). When using the latter form of
        convolution dimension specification, window strides are associated with
        spatial dimension character labels according to the order in which the
        labels appear in the rhs_spec string, so that window_strides[0] is
        matched with the dimension corresponding to the first character
        appearing in rhs_spec that is not 'I' or 'O'. By default, use the same
        dimension numbering as Conv and ConvWithGeneralPadding.
    num_spatial_dimensions: the number of spatial dimensions.

  Returns:
    A `ConvolutionDimensionNumbers` object.
  """
  if dimension_numbers is None:
    nd = num_spatial_dimensions
    dimension_numbers = ConvolutionDimensionNumbers()
    dimension_numbers.input_batch_dimension = 0
    dimension_numbers.input_feature_dimension = 1
    dimension_numbers.output_batch_dimension = 0
    dimension_numbers.output_feature_dimension = 1
    dimension_numbers.kernel_output_feature_dimension = 0
    dimension_numbers.kernel_input_feature_dimension = 1
    dimension_numbers.input_spatial_dimensions.extend(range(2, 2 + nd))
    dimension_numbers.kernel_spatial_dimensions.extend(range(2, 2 + nd))
    dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd))
  elif isinstance(dimension_numbers, tuple):
    lhs_spec, rhs_spec, out_spec = dimension_numbers
    dimension_numbers = ConvolutionDimensionNumbers()

    dimension_numbers.input_batch_dimension = lhs_spec.index('N')
    dimension_numbers.input_feature_dimension = lhs_spec.index('C')
    dimension_numbers.output_batch_dimension = out_spec.index('N')
    dimension_numbers.output_feature_dimension = out_spec.index('C')
    dimension_numbers.kernel_output_feature_dimension = rhs_spec.index('O')
    dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I')

    dimension_numbers.kernel_spatial_dimensions.extend(
        i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'})
    dimension_numbers.input_spatial_dimensions.extend(
        sorted((i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}),
               key=lambda i: rhs_spec.index(lhs_spec[i])))
    dimension_numbers.output_spatial_dimensions.extend(
        sorted((i for i, c in enumerate(out_spec) if c not in {'N', 'C'}),
               key=lambda i: rhs_spec.index(out_spec[i])))
  return dimension_numbers

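# Example (hedged sketch): dimension numbers matching TensorFlow's Conv2D
# layout, i.e. NHWC inputs/outputs and an HWIO kernel, for two spatial
# dimensions.
#
#   dnums = make_convolution_dimension_numbers(('NHWC', 'HWIO', 'NHWC'), 2)
#   assert dnums.input_batch_dimension == 0
#   assert dnums.input_feature_dimension == 3
#   assert dnums.kernel_spatial_dimensions == [0, 1]
#   assert dnums.output_spatial_dimensions == [1, 2]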

class OpSharding(object):
  """Python representation of a xla.OpSharding protobuf."""
  __slots__ = ('type', 'tile_assignment_dimensions', 'tile_assignment_devices',
               'tuple_shardings', 'replicate_on_last_tile_dim')

  Type = _xla.OpSharding_Type

  def __init__(self):
    self.type = self.Type.REPLICATED
    self.tile_assignment_dimensions = []
    self.tile_assignment_devices = []
    self.tuple_shardings = []
    self.replicate_on_last_tile_dim = False


class PrecisionConfig(object):
  """Python representation of a xla.PrecisionConfig protobuf."""
  __slots__ = ('operand_precision',)

  Precision = _xla.PrecisionConfig_Precision

  def __init__(self):
    self.operand_precision = []


class GatherDimensionNumbers(object):
  """Python representation of a xla.GatherDimensionNumbers protobuf."""
  __slots__ = ('offset_dims', 'collapsed_slice_dims', 'start_index_map',
               'index_vector_dim')

  def __init__(self):
    self.offset_dims = []
    self.collapsed_slice_dims = []
    self.start_index_map = []
    self.index_vector_dim = 0


class ScatterDimensionNumbers(object):
  """Python representation of a xla.ScatterDimensionNumbers protobuf."""
  __slots__ = ('update_window_dims', 'inserted_window_dims',
               'scatter_dims_to_operand_dims', 'index_vector_dim')

  def __init__(self):
    self.update_window_dims = []
    self.inserted_window_dims = []
    self.scatter_dims_to_operand_dims = []
    self.index_vector_dim = 0


class ReplicaGroup(object):
  """Python representation of a xla.ReplicaGroup protobuf."""
  __slots__ = ('replica_ids',)

  def __init__(self):
    self.replica_ids = []


def _make_replica_group_proto(replica_group):
  replica_group_proto = ReplicaGroup()
  replica_group_proto.replica_ids.extend(replica_group)
  return replica_group_proto


def make_replica_groups(replica_groups):
  """Converts an iterable of iterables of replica ids to ReplicaGroups."""
  if replica_groups is None:
    replica_groups_protos = []  # special value for XLA API
  else:
    replica_groups = list(replica_groups)
    replica_groups_protos = [
        _make_replica_group_proto(group) for group in replica_groups
    ]
  return replica_groups_protos

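# Example (hedged sketch): grouping four replicas into two ReplicaGroups, as
# used by collective ops; passing None yields the empty list, which XLA treats
# as "all replicas in one group".
#
#   groups = make_replica_groups([[0, 1], [2, 3]])
#   assert [g.replica_ids for g in groups] == [[0, 1], [2, 3]]
#   assert make_replica_groups(None) == []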

Traceback = _xla.Traceback
Frame = _xla.Frame


@contextlib.contextmanager
def tracebacks(enabled=True):
  """Context manager that enables or disables traceback collection."""
  saved = Traceback.enabled
  Traceback.enabled = enabled
  try:
    yield
  finally:
    Traceback.enabled = saved

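# Example (hedged sketch): temporarily enabling traceback collection for a
# block of code; the previous setting is restored on exit, even if the block
# raises.
#
#   with tracebacks(enabled=True):
#     run_some_xla_work()          # hypothetical workload
#
#   with tracebacks(enabled=False):
#     run_perf_sensitive_work()    # hypothetical workload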

def heap_profile(client: Client) -> bytes:
  """Returns a gzipped pprof protocol buffer containing a heap profile."""
  return gzip.compress(client.heap_profile())

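# Example (hedged sketch): dumping a heap profile to disk so it can be viewed
# with the pprof tool (e.g. `pprof -http=:8080 /tmp/xla_heap.pb.gz`).
#
#   with open('/tmp/xla_heap.pb.gz', 'wb') as f:
#     f.write(heap_profile(get_local_backend('cpu')))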

# Perform one last garbage collection of deferred Python references. This is
# mostly to keep ASAN happy.
atexit.register(_xla.collect_garbage)