# Lint as: python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An XLA client in Python."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import atexit
import collections
import contextlib
import enum  # pylint: disable=g-bad-import-order
import gzip
import inspect
import os
from typing import List, Sequence, Tuple, Union

from . import xla_extension as _xla

from absl import logging
import numpy as np

# Note this module does *not* depend on any Python protocol buffers. The XLA
# Python bindings are currently packaged both as part of jaxlib and as part
# of TensorFlow. If we use protocol buffers here, then importing both jaxlib
# and TensorFlow may fail with duplicate protocol buffer message definitions.

# Most functions are snake_case for consistency with other modules, some
# method names are CamelCase for consistency with XLA.
# pylint: disable=invalid-name

# Pylint has false positives for type annotations.
# pylint: disable=invalid-sequence-index

ops = _xla.ops
profiler = _xla.profiler

# Just an internal arbitrary increasing number to help with backward-compatible
# changes.
_version = 3

xla_platform_names = {
    'cpu': 'Host',
    'gpu': 'CUDA',
}


def _interpreter_backend_factory():
  return _xla.get_interpreter_client()


def _cpu_backend_factory():
  return _xla.get_cpu_client(asynchronous=True)


def _gpu_backend_factory(distributed_client=None, node_id=0):
  """Returns a GPU backend. BFC allocator is used by default."""
  allocator = os.getenv('XLA_PYTHON_CLIENT_ALLOCATOR', 'default').lower()
  memory_fraction = os.getenv('XLA_PYTHON_CLIENT_MEM_FRACTION')
  preallocate = os.getenv('XLA_PYTHON_CLIENT_PREALLOCATE')
  if allocator not in ('default', 'platform', 'bfc'):
    raise ValueError(
        'XLA_PYTHON_CLIENT_ALLOCATOR env var must be "default", "platform", or '
        '"bfc", got "%s"' % allocator)
  config = _xla.GpuAllocatorConfig()
  if allocator == 'default':
    config.kind = _xla.GpuAllocatorConfig.Kind.DEFAULT
  if allocator == 'platform':
    config.kind = _xla.GpuAllocatorConfig.Kind.PLATFORM
  if allocator == 'bfc':
    config.kind = _xla.GpuAllocatorConfig.Kind.BFC
  if memory_fraction:
    config.memory_fraction = float(memory_fraction)
  config.preallocate = preallocate not in ('0', 'false', 'False')

  return _xla.get_gpu_client(
      asynchronous=True,
      allocator_config=config,
      distributed_client=distributed_client,
      node_id=node_id)


def _tpu_backend_factory():
  return _xla.get_tpu_client(asynchronous=True)
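
# A hedged usage sketch (not part of this module's API): the GPU factory above
# reads its allocator configuration from environment variables, so overrides
# must be set before the backend is first created, e.g.
#
#   import os
#   os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'bfc'      # best-fit allocator
#   os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.5'   # cap at 50% of GPU memory
#   os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'  # allocate on demand
#   backend = get_local_backend('gpu')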

# Backend factories, keyed by user-visible name, in increasing priority order.
_local_backend_factories = collections.OrderedDict([
    ('interpreter', _interpreter_backend_factory),
    ('cpu', _cpu_backend_factory),
    ('gpu', _gpu_backend_factory),
    ('tpu', _tpu_backend_factory),
])


def register_local_backend_factory(name, factory):
  _local_backend_factories[name] = factory


_local_backends = None


def _get_local_backends():
  """Instantiates all known local backends."""
  global _local_backends
  if _local_backends is not None:
    return _local_backends

  _local_backends = collections.OrderedDict()
  for name, factory in _local_backend_factories.items():
    logging.vlog(1, "Initializing backend '%s'" % name)
    try:
      backend = factory()
    except RuntimeError as err:
      if name == 'cpu':
        # We always expect CPU to initialize successfully.
        raise
      else:
        # If the backend isn't built into the binary, or if it has no devices,
        # we expect a RuntimeError.
        logging.vlog(1, "Error initializing backend '%s': %s" % (name, err))
        continue
    _local_backends[name] = backend
  return _local_backends


def get_local_backend(name=None):
  """Returns a local backend.

  Args:
    name: the backend name. If `None`, a default local backend is returned,
      typically `gpu` if one is present, or `cpu` if not. If a string, the
      named backend is returned or an exception raised.

  Returns:
    A LocalBackend object.
  """
  backends = _get_local_backends()
  if name is not None:
    try:
      return backends[name]
    except KeyError:
      raise RuntimeError(
          'Unknown backend %s. Available: %s' % (name, list(backends.keys())))

  return list(backends.values())[-1]


class OpMetadata(object):
  """Python representation of a xla.OpMetadata protobuf."""
  __slots__ = ('op_type', 'op_name', 'source_file', 'source_line')

  def __init__(self, op_type='', op_name='', source_file='', source_line=0):
    self.op_type = op_type
    self.op_name = op_name
    self.source_file = source_file
    self.source_line = source_line


def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1):
  """Helper for use in source mapping that returns an OpMetadata object."""
  full_filename, lineno = inspect.stack()[skip_frames][1:3]
  filename = os.path.basename(full_filename)
  return OpMetadata(
      op_type=op_type,
      op_name=op_name,
      source_file=filename,
      source_line=lineno)


PrimitiveType = _xla.PrimitiveType

bfloat16 = _xla.bfloat16_dtype()

XLA_ELEMENT_TYPE_TO_DTYPE = {
    PrimitiveType.PRED: np.dtype('bool'),
    PrimitiveType.S8: np.dtype('int8'),
    PrimitiveType.S16: np.dtype('int16'),
    PrimitiveType.S32: np.dtype('int32'),
    PrimitiveType.S64: np.dtype('int64'),
    PrimitiveType.U8: np.dtype('uint8'),
    PrimitiveType.U16: np.dtype('uint16'),
    PrimitiveType.U32: np.dtype('uint32'),
    PrimitiveType.U64: np.dtype('uint64'),
    PrimitiveType.BF16: np.dtype(bfloat16),
    PrimitiveType.F16: np.dtype('float16'),
    PrimitiveType.F32: np.dtype('float32'),
    PrimitiveType.F64: np.dtype('float64'),
    PrimitiveType.C64: np.dtype('complex64'),
    PrimitiveType.C128: np.dtype('complex128'),
    PrimitiveType.TUPLE: np.dtype(np.object_),
    PrimitiveType.TOKEN: np.dtype(np.object_),
}

# Note the conversion on the key. Numpy has a known issue wherein dtype hashing
# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus,
# when keying by dtype in this dict, we use the string form of dtypes.
DTYPE_TO_XLA_ELEMENT_TYPE = {
    str(dt): et for et, dt in XLA_ELEMENT_TYPE_TO_DTYPE.items()
}


def dtype_to_etype(dtype):
  """Convenience function for reading DTYPE_TO_XLA_ELEMENT_TYPE."""
  return DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))]
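
# A small illustrative sketch: dtype_to_etype accepts anything np.dtype()
# accepts, e.g.
#
#   dtype_to_etype(np.float32)  # == PrimitiveType.F32
#   dtype_to_etype('int32')     # == PrimitiveType.S32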

Shape = _xla.Shape
Shape.__doc__ = """
A Shape is an object defined in C++ that duck types like the following class:

class Shape(object):
  '''Represents an XLA shape.

  A shape is either an array shape, having rank-many integer
  dimensions and an element type (represented by a Numpy dtype), or it
  is a tuple shape, having a shape for every tuple component:

    type shape =
        TupleShape of shape list
      | ArrayShape of { dimensions: int list; element_type: dtype }
  '''

  @staticmethod
  def tuple_shape(tuple_shapes) -> Shape:
    "Construct a tuple shape."

  @staticmethod
  def array_shape(element_type, dimensions, minor_to_major=None) -> Shape:

  @staticmethod
  def from_pyval(pyval) -> Shape:
    "Returns a Shape that describes a tuple-tree of Numpy arrays."

  def __init__(self, str) -> Shape:
    "Parses a shape string."
  def __eq__(self, other: Shape) -> bool:
  def __ne__(self, other: Shape) -> bool:
  def __hash__(self):
  def __repr__(self):
  def is_tuple(self) -> bool:
  def is_array(self) -> bool:
  def tuple_shapes(self) -> [Shape]:
  def numpy_dtype(self) -> np.dtype:
    "Like element_type(), but returns dtype('O') for a tuple shape."
  def xla_element_type(self) -> PrimitiveType:
  def element_type(self) -> np.dtype:
  def dimensions(self) -> (int, int, ...):
  def rank(self) -> int:
  def with_major_to_minor_layout_if_absent(self) -> Shape:
    "Returns a copy with missing layouts set to major-to-minor."

  def to_serialized_proto(self) -> bytes:
    "Returns 'shape' as a serialized proto."
"""

ProgramShape = _xla.ProgramShape
ProgramShape.__doc__ = """
A ProgramShape is a C++ object that duck types like the following class.

class ProgramShape(object):
  def __init__(self, parameter_shapes, result_shape):
  def parameter_shapes(self) -> [Shape]:
  def result_shape(self) -> Shape:
  def __repr__(self):
"""


def shape_from_pyval(pyval):
  """Returns a Shape that describes a tuple-tree of Numpy arrays."""

  def convert(pyval):
    if isinstance(pyval, tuple):
      return Shape.tuple_shape(tuple(convert(elt) for elt in pyval))
    else:
      return Shape.array_shape(pyval.dtype, np.shape(pyval))

  return convert(pyval)
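
# A small illustrative sketch of shape_from_pyval on a tuple-tree of Numpy
# values (the shape notation in the comments is approximate):
#
#   shape_from_pyval(np.zeros((2, 3), np.float32))   # f32[2,3]
#   shape_from_pyval((np.zeros((2, 3), np.float32),
#                     np.int32(1)))                  # (f32[2,3], s32[])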

DeviceAssignment = _xla.DeviceAssignment
DeviceAssignment.__doc__ = """
A DeviceAssignment is a C++ object with the following signature.

def create(assignment):
  '''Builds a device assignment.

  Args:
    assignment: a 2D numpy array of device ordinal integers, indexed by
      [replica][computation_in_replica].
  Returns:
    A device assignment.
  '''

def replica_count():
  '''Returns the number of replicas.'''
def computation_count():
  '''Returns the number of computations per replica.'''
"""

Device = _xla.Device
CompileOptions = _xla.CompileOptions

HostBufferSemantics = _xla.HostBufferSemantics

# An Executable is a C++ class that duck types with the following API:
# class Executable(object):
#   def local_devices(self) -> [Device]:
#   def execute(self, arguments: [Buffer]) -> Buffer:
#     """Execute on one replica with Buffer arguments and return value."""
#
#   def size_of_generated_code_in_bytes(self) -> int:
#     """Return generated binary size, or -1 if not known."""
#
#   def execute_on_local_devices(self, arguments: [[Buffer]]) -> [Buffer]:
#     """Execute on many replicas with Buffer arguments and return value.
#
#     Args:
#       arguments: A sequence of sequences of Buffers. The i'th inner sequence
#         comprises the arguments for execution on the i'th local device.
#
#     Returns:
#       A list of the computation's outputs for each local device, as a Buffer.
#       If a shallow sequence of arguments was passed in for `arguments`, then
#       the sole, zero'th device's output is returned instead, as a Buffer.
#     """
#
# There are different implementations of Executable for different backends.


def execute_with_python_values(executable, arguments, backend):
  """Execute on one replica with Python values as arguments and output."""

  def put(arg):
    return backend.buffer_from_pyval(arg, device=executable.local_devices()[0])

  arguments = [put(arg) for arg in arguments]
  outputs = executable.execute(arguments)
  return [x.to_py() for x in outputs]


def execute_with_python_values_replicated(executable, arguments, backend):
  """Execute on many replicas with Python values as arguments and output.

  Args:
    executable: the program to run.
    arguments: a list of lists of Python values indexed by `[replica][arg_num]`
      to pass as inputs.
    backend: the backend we are targeting.

  Returns:
    A list of python values, one per replica.
  """
  devices = executable.local_devices()
  # pylint: disable=g-complex-comprehension
  flat_args = [(arg, devices[replica])
               for replica, replica_args in enumerate(arguments)
               for arg in replica_args]
  flat_arg_buffers = [
      backend.buffer_from_pyval(pyval, device) for pyval, device in flat_args
  ]
  arg_buffers = []
  for replica_args in arguments:
    arg_buffers.append(flat_arg_buffers[:len(replica_args)])
    flat_arg_buffers = flat_arg_buffers[len(replica_args):]
  return [[x.to_py()
           for x in xs]
          for xs in executable.execute_on_local_devices(arg_buffers)]
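
# A hedged end-to-end sketch of building, compiling, and running a computation
# with the helpers above. It assumes `backend.compile` accepts an
# XlaComputation, as the current bindings do; the exact output formatting may
# differ:
#
#   backend = get_local_backend('cpu')
#   builder = XlaBuilder('add')
#   x = ops.Parameter(builder, 0, shape_from_pyval(np.float32(0)))
#   y = ops.Parameter(builder, 1, shape_from_pyval(np.float32(0)))
#   ops.Add(x, y)
#   compiled = backend.compile(builder.build())
#   execute_with_python_values(compiled, [np.float32(1), np.float32(2)], backend)
#   # -> [array(3., dtype=float32)]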

class PaddingType(enum.Enum):
  VALID = 1
  SAME = 2


def window_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims,
                                      window_strides):
  """Maps PaddingType or string to pad values (list of pairs of ints)."""
  if not isinstance(padding_type, (str, PaddingType)):
    msg = 'padding_type must be str or PaddingType, got {}.'
    raise TypeError(msg.format(type(padding_type)))

  if isinstance(padding_type, str):
    if padding_type.upper() == 'VALID':
      padding_type = PaddingType.VALID
    elif padding_type.upper() == 'SAME':
      padding_type = PaddingType.SAME
    else:
      msg = 'Unknown padding type string: expected "VALID" or "SAME", got {}.'
      raise ValueError(msg.format(padding_type))

  if padding_type == PaddingType.VALID:
    return [(0, 0)] * len(window_strides)
  elif padding_type == PaddingType.SAME:
    out_shape = np.ceil(np.true_divide(lhs_dims, window_strides)).astype(int)
    pad_sizes = [
        max((out_size - 1) * stride + filter_size - in_size, 0)
        for out_size, stride, filter_size, in_size in zip(
            out_shape, window_strides, rhs_dims, lhs_dims)
    ]
    return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes]
  else:
    msg = 'Unexpected PaddingType value: {}'
    raise ValueError(msg.format(padding_type))
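
# A worked sketch of the 'SAME' case: for a length-10 input, a size-3 window,
# and stride 2, the output size is ceil(10 / 2) = 5, so the total padding is
# max((5 - 1) * 2 + 3 - 10, 0) = 1, split low/high as (0, 1):
#
#   window_padding_type_to_pad_values('SAME', [10], [3], [2])  # -> [(0, 1)]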

XlaBuilder = _xla.XlaBuilder
XlaComputation = _xla.XlaComputation
FftType = _xla.FftType
Client = _xla.Client
Buffer = _xla.Buffer
DeviceArrayBase = _xla.DeviceArrayBase
Executable = _xla.Executable


def register_custom_call_target(name, fn, platform='cpu'):
  """Registers a custom call target.

  Args:
    name: bytes containing the name of the function.
    fn: a PyCapsule object containing the function pointer.
    platform: the target platform.
  """
  # To support AMD GPUs, we need to have xla_platform_names["gpu"] == "ROCM".
  # Since that is hardcoded to CUDA, we are using the following as a
  # workaround.
  _xla.register_custom_call_target(name, fn,
                                   xla_platform_names.get(platform, platform))


# Deprecated. Use register_custom_call_target instead.
register_cpu_custom_call_target = register_custom_call_target


class PaddingConfigDimension(object):
  """Python representation of a xla.PaddingConfigDimension protobuf."""
  __slots__ = ('edge_padding_low', 'edge_padding_high', 'interior_padding')

  def __init__(self):
    self.edge_padding_low = 0
    self.edge_padding_high = 0
    self.interior_padding = 0


class PaddingConfig(object):
  """Python representation of a xla.PaddingConfig protobuf."""
  __slots__ = ('dimensions',)

  def __init__(self):
    self.dimensions = []


def make_padding_config(
    padding_config: Union[PaddingConfig, Sequence[Tuple[int, int, int]]]
) -> PaddingConfig:
  """Create PaddingConfig proto from list of triples of integers.

  Args:
    padding_config: either a PaddingConfig or a list of integer triples
      (edge_padding_low, edge_padding_high, interior_padding) representing the
      configuration of the padding operation.

  Returns:
    A `PaddingConfig` object.
  """
  if isinstance(padding_config, tuple) or isinstance(padding_config, list):
    triples = padding_config
    padding_config = PaddingConfig()
    for lo, hi, interior in triples:
      dimension = PaddingConfigDimension()
      dimension.edge_padding_low = lo
      dimension.edge_padding_high = hi
      dimension.interior_padding = interior
      padding_config.dimensions.append(dimension)
  return padding_config


class DotDimensionNumbers(object):
  """Python representation of a xla.DotDimensionNumbers protobuf."""
  __slots__ = ('lhs_contracting_dimensions', 'rhs_contracting_dimensions',
               'lhs_batch_dimensions', 'rhs_batch_dimensions')

  def __init__(self):
    self.lhs_contracting_dimensions = []
    self.rhs_contracting_dimensions = []
    self.lhs_batch_dimensions = []
    self.rhs_batch_dimensions = []


def make_dot_dimension_numbers(
    dimension_numbers: Union[DotDimensionNumbers,
                             Tuple[Tuple[List[int], List[int]],
                                   Tuple[List[int], List[int]]]]
) -> DotDimensionNumbers:
  """Builds a DotDimensionNumbers object from a specification.

  Args:
    dimension_numbers: either a `DotDimensionNumbers` or a nested tuple
      `((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))` of lists of
      integers representing the dimensions to treat as contracting dimensions
      and batch dimensions on each input operand.

  Returns:
    A `DotDimensionNumbers` object.
  """
  if isinstance(dimension_numbers, (list, tuple)):
    (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
    dot_dims_proto = DotDimensionNumbers()
    dot_dims_proto.lhs_contracting_dimensions.extend(lhs_contract)
    dot_dims_proto.rhs_contracting_dimensions.extend(rhs_contract)
    dot_dims_proto.lhs_batch_dimensions.extend(lhs_batch)
    dot_dims_proto.rhs_batch_dimensions.extend(rhs_batch)
    return dot_dims_proto
  else:
    return dimension_numbers
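
# A small illustrative sketch: dimension numbers for a batched matmul of
# lhs[b, m, k] with rhs[b, k, n], contracting dim 2 of lhs against dim 1 of
# rhs and treating dim 0 of both operands as a batch dimension:
#
#   make_dot_dimension_numbers((([2], [1]), ([0], [0])))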

class ConvolutionDimensionNumbers(object):
  """Python representation of a xla.ConvolutionDimensionNumbers protobuf."""
  __slots__ = ('input_batch_dimension', 'input_feature_dimension',
               'input_spatial_dimensions', 'kernel_input_feature_dimension',
               'kernel_output_feature_dimension', 'kernel_spatial_dimensions',
               'output_batch_dimension', 'output_feature_dimension',
               'output_spatial_dimensions')

  def __init__(self):
    self.input_batch_dimension = 0
    self.input_feature_dimension = 0
    self.input_spatial_dimensions = []
    self.kernel_input_feature_dimension = 0
    self.kernel_output_feature_dimension = 0
    self.kernel_spatial_dimensions = []
    self.output_batch_dimension = 0
    self.output_feature_dimension = 0
    self.output_spatial_dimensions = []


def make_convolution_dimension_numbers(
    dimension_numbers: Union[None, ConvolutionDimensionNumbers, Tuple[str, str,
                                                                      str]],
    num_spatial_dimensions: int) -> ConvolutionDimensionNumbers:
  """Builds a ConvolutionDimensionNumbers object from a specification.

  Args:
    dimension_numbers: optional, either a ConvolutionDimensionNumbers object
      or a tuple (lhs_spec, rhs_spec, out_spec). Each element is a string of
      length N+2 identifying by position: (1) batch dimensions in lhs, rhs,
      and the output with the character 'N', (2) feature dimensions in lhs and
      the output with the character 'C', (3) input and output feature
      dimensions in rhs with the characters 'I' and 'O' respectively, and (4)
      spatial dimension correspondences between lhs, rhs, and the output using
      any distinct characters. For example, to indicate dimension numbers
      consistent with the Conv operation with two spatial dimensions, one
      could use ('NCHW', 'OIHW', 'NCHW'). As another example, to indicate
      dimension numbers consistent with the TensorFlow Conv2D operation, one
      could use ('NHWC', 'HWIO', 'NHWC'). When using the latter form of
      convolution dimension specification, window strides are associated with
      spatial dimension character labels according to the order in which the
      labels appear in the rhs_spec string, so that window_strides[0] is
      matched with the dimension corresponding to the first character
      appearing in rhs_spec that is not 'I' or 'O'. By default, use the same
      dimension numbering as Conv and ConvWithGeneralPadding.
    num_spatial_dimensions: the number of spatial dimensions.

  Returns:
    A `ConvolutionDimensionNumbers` object.
  """
  if dimension_numbers is None:
    nd = num_spatial_dimensions
    dimension_numbers = ConvolutionDimensionNumbers()
    dimension_numbers.input_batch_dimension = 0
    dimension_numbers.input_feature_dimension = 1
    dimension_numbers.output_batch_dimension = 0
    dimension_numbers.output_feature_dimension = 1
    dimension_numbers.kernel_output_feature_dimension = 0
    dimension_numbers.kernel_input_feature_dimension = 1
    dimension_numbers.input_spatial_dimensions.extend(range(2, 2 + nd))
    dimension_numbers.kernel_spatial_dimensions.extend(range(2, 2 + nd))
    dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd))
  elif isinstance(dimension_numbers, tuple):
    lhs_spec, rhs_spec, out_spec = dimension_numbers
    dimension_numbers = ConvolutionDimensionNumbers()

    dimension_numbers.input_batch_dimension = lhs_spec.index('N')
    dimension_numbers.input_feature_dimension = lhs_spec.index('C')
    dimension_numbers.output_batch_dimension = out_spec.index('N')
    dimension_numbers.output_feature_dimension = out_spec.index('C')
    dimension_numbers.kernel_output_feature_dimension = rhs_spec.index('O')
    dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I')

    dimension_numbers.kernel_spatial_dimensions.extend(
        i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'})
    dimension_numbers.input_spatial_dimensions.extend(
        sorted((i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}),
               key=lambda i: rhs_spec.index(lhs_spec[i])))
    dimension_numbers.output_spatial_dimensions.extend(
        sorted((i for i, c in enumerate(out_spec) if c not in {'N', 'C'}),
               key=lambda i: rhs_spec.index(out_spec[i])))
  return dimension_numbers
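
# A small illustrative sketch: TensorFlow-style Conv2D dimension numbers.
#
#   dnums = make_convolution_dimension_numbers(('NHWC', 'HWIO', 'NHWC'), 2)
#   dnums.input_spatial_dimensions   # [1, 2]  (H and W in the input)
#   dnums.kernel_spatial_dimensions  # [0, 1]  (H and W in the kernel)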

class OpSharding(object):
  """Python representation of a xla.OpSharding protobuf."""
  __slots__ = ('type', 'tile_assignment_dimensions', 'tile_assignment_devices',
               'tuple_shardings', 'replicate_on_last_tile_dim')

  Type = _xla.OpSharding_Type

  def __init__(self):
    self.type = self.Type.REPLICATED
    self.tile_assignment_dimensions = []
    self.tile_assignment_devices = []
    self.tuple_shardings = []
    self.replicate_on_last_tile_dim = False


class PrecisionConfig(object):
  """Python representation of a xla.PrecisionConfig protobuf."""
  __slots__ = ('operand_precision',)

  Precision = _xla.PrecisionConfig_Precision

  def __init__(self):
    self.operand_precision = []


class GatherDimensionNumbers(object):
  """Python representation of a xla.GatherDimensionNumbers protobuf."""
  __slots__ = ('offset_dims', 'collapsed_slice_dims', 'start_index_map',
               'index_vector_dim')

  def __init__(self):
    self.offset_dims = []
    self.collapsed_slice_dims = []
    self.start_index_map = []
    self.index_vector_dim = 0


class ScatterDimensionNumbers(object):
  """Python representation of a xla.ScatterDimensionNumbers protobuf."""
  __slots__ = ('update_window_dims', 'inserted_window_dims',
               'scatter_dims_to_operand_dims', 'index_vector_dim')

  def __init__(self):
    self.update_window_dims = []
    self.inserted_window_dims = []
    self.scatter_dims_to_operand_dims = []
    self.index_vector_dim = 0


class ReplicaGroup(object):
  """Python representation of a xla.ReplicaGroup protobuf."""
  __slots__ = ('replica_ids',)

  def __init__(self):
    self.replica_ids = []


def _make_replica_group_proto(replica_group):
  replica_group_proto = ReplicaGroup()
  replica_group_proto.replica_ids.extend(replica_group)
  return replica_group_proto


def make_replica_groups(replica_groups):
  if replica_groups is None:
    replica_groups_protos = []  # special value for XLA API
  else:
    replica_groups = list(replica_groups)
    replica_groups_protos = [
        _make_replica_group_proto(group) for group in replica_groups
    ]
  return replica_groups_protos


Traceback = _xla.Traceback


@contextlib.contextmanager
def tracebacks(enabled=True):
  """Context manager that enables or disables traceback collection."""
  saved = Traceback.enabled
  Traceback.enabled = enabled
  try:
    yield
  finally:
    Traceback.enabled = saved


def heap_profile(client: Client) -> bytes:
  """Returns a gzipped pprof protocol buffer containing a heap profile."""
  return gzip.compress(client.heap_profile())


# Perform one last garbage collection of deferred Python references. This is
# mostly to keep ASAN happy.
atexit.register(_xla.collect_garbage)
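
# A hedged usage sketch for the profiling helpers above (the output path is
# purely illustrative):
#
#   client = get_local_backend()
#   with tracebacks(enabled=True):
#     ...  # run some computations
#   with open('/tmp/xla_heap.pb.gz', 'wb') as f:
#     f.write(heap_profile(client))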