# Lint as: python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An XLA client in Python."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import atexit
import collections
import contextlib
import enum  # pylint: disable=g-bad-import-order
import gzip
import inspect
import os
from typing import List, Sequence, Tuple, Union

from . import xla_extension as _xla

from absl import logging
import numpy as np

# Note this module does *not* depend on any Python protocol buffers. The XLA
# Python bindings are currently packaged both as part of jaxlib and as part
# of TensorFlow. If we use protocol buffers here, then importing both jaxlib
# and TensorFlow may fail with duplicate protocol buffer message definitions.

# Most functions are snake_case for consistency with other modules, some
# method names are CamelCase for consistency with XLA.
# pylint: disable=invalid-name

# Pylint has false positives for type annotations.
# pylint: disable=invalid-sequence-index

ops = _xla.ops
profiler = _xla.profiler

# Just an internal arbitrary increasing number to help with backward-compatible
# changes.
_version = 32

xla_platform_names = {
    'cpu': 'Host',
    'gpu': 'CUDA',
}


def make_interpreter_client():
  return _xla.get_interpreter_client()


def make_cpu_client(*, use_tfrt=False):
  if use_tfrt:
    return _xla.get_tfrt_cpu_client(asynchronous=True)
  else:
    return _xla.get_cpu_client(asynchronous=True)


def make_gpu_client(distributed_client=None, node_id=0):
  """Returns a GPU client. BFC allocator is used by default."""
  allocator = os.getenv('XLA_PYTHON_CLIENT_ALLOCATOR', 'default').lower()
  memory_fraction = os.getenv('XLA_PYTHON_CLIENT_MEM_FRACTION')
  preallocate = os.getenv('XLA_PYTHON_CLIENT_PREALLOCATE')
  if allocator not in ('default', 'platform', 'bfc', 'cuda_async'):
    raise ValueError(
        'XLA_PYTHON_CLIENT_ALLOCATOR env var must be "default", "platform", '
        '"bfc", or "cuda_async", got "%s"' % allocator)
  config = _xla.GpuAllocatorConfig()
  if allocator == 'default':
    config.kind = _xla.GpuAllocatorConfig.Kind.DEFAULT
  if allocator == 'platform':
    config.kind = _xla.GpuAllocatorConfig.Kind.PLATFORM
  if allocator == 'bfc':
    config.kind = _xla.GpuAllocatorConfig.Kind.BFC
  if allocator == 'cuda_async':
    config.kind = _xla.GpuAllocatorConfig.Kind.CUDA_ASYNC
  if memory_fraction:
    config.memory_fraction = float(memory_fraction)
  config.preallocate = preallocate not in ('0', 'false', 'False')

  return _xla.get_gpu_client(
      asynchronous=True,
      allocator_config=config,
      distributed_client=distributed_client,
      node_id=node_id)


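# Example (illustrative values only): the allocator above is configured through
# environment variables that must be set before make_gpu_client() is called,
# e.g. to use the BFC allocator with at most half of GPU memory and no
# preallocation:
#
#   os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'bfc'
#   os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.5'
#   os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
#   client = make_gpu_client()

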
def make_tpu_client():
  return _xla.get_tpu_client(max_inflight_computations=32)


# Deprecated client factory API.

# Backend factories, keyed by user-visible name, in increasing priority order.
_local_backend_factories = collections.OrderedDict([
    ('interpreter', make_interpreter_client),
    ('cpu', make_cpu_client),
    ('gpu', make_gpu_client),
    ('tpu', make_tpu_client),
])


def register_local_backend_factory(name, factory):
  _local_backend_factories[name] = factory


_local_backends = None


def _get_local_backends():
  """Instantiates all known local backends."""
  global _local_backends
  if _local_backends is not None:
    return _local_backends

  _local_backends = collections.OrderedDict()
  for name, factory in _local_backend_factories.items():
    logging.vlog(1, "Initializing backend '%s'" % name)
    try:
      backend = factory()
    except RuntimeError as err:
      if name == 'cpu':
        # We always expect CPU to initialize successfully.
        raise
      else:
        # If the backend isn't built into the binary, or if it has no devices,
        # we expect a RuntimeError.
        logging.vlog(1, "Error initializing backend '%s': %s" % (name, err))
        continue
    _local_backends[name] = backend
  return _local_backends


def get_local_backend(name=None):
  """Returns a local backend.

  Args:
    name: the backend name. If `None`, a default local backend is returned,
      typically `gpu` if one is present, or `cpu` if not. If a string, the
      named backend is returned or an exception raised.

  Returns:
    A LocalBackend object.
  """
  backends = _get_local_backends()
  if name is not None:
    try:
      return backends[name]
    except KeyError:
      raise RuntimeError('Unknown backend %s. Available: %s' %
                         (name, list(backends.keys())))

  return list(backends.values())[-1]


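# Example (hypothetical usage of the deprecated factory API above):
#
#   cpu_backend = get_local_backend('cpu')  # named backend, or RuntimeError
#   best_backend = get_local_backend()      # highest-priority backend available

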
class OpMetadata(object):
  """Python representation of a xla.OpMetadata protobuf."""
  __slots__ = ('op_type', 'op_name', 'source_file', 'source_line')

  def __init__(self, op_type='', op_name='', source_file='', source_line=0):
    self.op_type = op_type
    self.op_name = op_name
    self.source_file = source_file
    self.source_line = source_line


def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1):
  """Helper for use in source mapping that returns an OpMetadata object."""
  full_filename, lineno = inspect.stack()[skip_frames][1:3]
  filename = os.path.basename(full_filename)
  return OpMetadata(
      op_type=op_type,
      op_name=op_name,
      source_file=filename,
      source_line=lineno)


PrimitiveType = _xla.PrimitiveType

bfloat16 = _xla.bfloat16_dtype()

XLA_ELEMENT_TYPE_TO_DTYPE = {
    PrimitiveType.PRED: np.dtype('bool'),
    PrimitiveType.S8: np.dtype('int8'),
    PrimitiveType.S16: np.dtype('int16'),
    PrimitiveType.S32: np.dtype('int32'),
    PrimitiveType.S64: np.dtype('int64'),
    PrimitiveType.U8: np.dtype('uint8'),
    PrimitiveType.U16: np.dtype('uint16'),
    PrimitiveType.U32: np.dtype('uint32'),
    PrimitiveType.U64: np.dtype('uint64'),
    PrimitiveType.BF16: np.dtype(bfloat16),
    PrimitiveType.F16: np.dtype('float16'),
    PrimitiveType.F32: np.dtype('float32'),
    PrimitiveType.F64: np.dtype('float64'),
    PrimitiveType.C64: np.dtype('complex64'),
    PrimitiveType.C128: np.dtype('complex128'),
    PrimitiveType.TUPLE: np.dtype(np.object_),
    PrimitiveType.TOKEN: np.dtype(np.object_),
}

# Note the conversion on the key. Numpy has a known issue wherein dtype hashing
# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus,
# when keying by dtype in this dict, we use the string form of dtypes.
DTYPE_TO_XLA_ELEMENT_TYPE = {
    str(dt): et for et, dt in XLA_ELEMENT_TYPE_TO_DTYPE.items()
}


def dtype_to_etype(dtype):
  """Convenience function for reading DTYPE_TO_XLA_ELEMENT_TYPE."""
  return DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))]


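# Example: the two tables above are inverses of each other up to dtype
# spelling, e.g. dtype_to_etype(np.float32) is PrimitiveType.F32 and
# XLA_ELEMENT_TYPE_TO_DTYPE[PrimitiveType.F32] is np.dtype('float32').

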
Shape = _xla.Shape
Shape.__doc__ = """
A Shape is an object defined in C++ that duck types like the following class:

class Shape(object):
  '''Represents an XLA shape.

  A shape is either an array shape, having rank-many integer
  dimensions and an element type (represented by a Numpy dtype), or it
  is a tuple shape, having a shape for every tuple component:

    type shape =
        TupleShape of shape list
      | ArrayShape of { dimensions: int list; element_type: dtype }
  '''

  @staticmethod
  def tuple_shape(tuple_shapes) -> Shape:
    "Construct a tuple shape."

  @staticmethod
  def array_shape(element_type, dimensions, minor_to_major=None) -> Shape:

  @staticmethod
  def from_pyval(pyval) -> Shape:
    "Returns a Shape that describes a tuple-tree of Numpy arrays."

  def __init__(self, str) -> Shape:
    "Parses a shape string."
  def __eq__(self, other: Shape) -> bool:
  def __ne__(self, other: Shape) -> bool:
  def __hash__(self):
  def __repr__(self):
  def is_tuple(self) -> bool:
  def is_array(self) -> bool:
  def tuple_shapes(self) -> [Shape]:
  def numpy_dtype(self) -> np.dtype:
    "Like element_type(), but returns dtype('O') for a tuple shape."
  def xla_element_type(self) -> PrimitiveType:
  def element_type(self) -> np.dtype:
  def dimensions(self) -> (int, int, ...):
  def rank(self) -> int:
  def with_major_to_minor_layout_if_absent(self) -> Shape:
    "Returns a copy with missing layouts set to major-to-minor."

  def to_serialized_proto(self) -> bytes:
    "Returns 'shape' as a serialized proto."
"""

ProgramShape = _xla.ProgramShape
ProgramShape.__doc__ = """
A ProgramShape is a C++ object that duck types like the following class.

class ProgramShape(object):
  def __init__(self, parameter_shapes, result_shape):
  def parameter_shapes(self) -> [Shape]:
  def result_shape(self) -> Shape:
  def __repr__(self):
"""

ShapeIndex = _xla.ShapeIndex
ShapeIndex.__doc__ = """
A ShapeIndex is an object defined in C++ that duck types like the following
class:

class ShapeIndex(object):
  '''Represents an XLA ShapeIndex.

  An index for specifying a particular nested subshape within a shape. Used in
  ShapeUtil::GetSubshape and other interfaces. ShapeIndex defines a path through
  the Shape tree where each element of ShapeIndex indexes into a tuple (or
  nested tuple) within the shape. For a non-nested tuple, an index has a single
  element.
  '''

  def __init__(self, List[int]) -> ShapeIndex:
  def __eq__(self, other: ShapeIndex) -> bool:
  def __ne__(self, other: ShapeIndex) -> bool:
  def __hash__(self):
  def __repr__(self):
"""


def shape_from_pyval(pyval):
  """Returns a Shape that describes a tuple-tree of Numpy arrays."""

  def convert(pyval):
    if isinstance(pyval, tuple):
      return Shape.tuple_shape(tuple(convert(elt) for elt in pyval))
    else:
      return Shape.array_shape(pyval.dtype, np.shape(pyval))

  return convert(pyval)


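# Example (illustrative): a tuple of Numpy values maps to a tuple shape, e.g.
#
#   shape_from_pyval((np.zeros((2, 3), np.float32), np.int32(1)))
#
# describes a tuple shape whose components are a float32 array shape with
# dimensions (2, 3) and a scalar int32 array shape.

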
DeviceAssignment = _xla.DeviceAssignment
DeviceAssignment.__doc__ = """
A DeviceAssignment is a C++ object with the following signature.

def create(assignment):
  '''Builds a device assignment.

  Args:
    assignment: a 2D numpy array of device ordinal integers, indexed by
      [replica][computation_in_replica].

  Returns:
    A device assignment.
  '''

def replica_count():
  '''Returns the number of replicas.'''

def computation_count():
  '''Returns the number of computations per replica.'''
"""

Device = _xla.Device
CompileOptions = _xla.CompileOptions

HostBufferSemantics = _xla.HostBufferSemantics

# An Executable is a C++ class that duck types with the following API:
# class Executable(object):
#   def local_devices(self) -> [Device]:
#   def execute(self, arguments: [Buffer]) -> Buffer:
#     """Execute on one replica with Buffer arguments and return value."""
#
#   def size_of_generated_code_in_bytes(self) -> int:
#     """Return generated binary size, or -1 if not known."""
#
#   def execute_sharded_on_local_devices(self, arguments: [[Buffer]])
#       -> [Buffer]:
#     """Execute on many replicas with Buffer arguments and return value.
#
#     Args:
#       arguments: A sequence of sequences of Buffers. The i'th element of each
#         sequence comprises the arguments for execution on the i'th local
#         device.
#
#     Returns:
#       A list of the computation's outputs as a list of Buffers for each
#       device.
#     """
#
# There are different implementations of Executable for different backends.


def execute_with_python_values(executable, arguments, backend):
  """Execute on one replica with Python values as arguments and output."""

  def put(arg):
    return backend.buffer_from_pyval(arg, device=executable.local_devices()[0])

  arguments = [put(arg) for arg in arguments]
  outputs = executable.execute(arguments)
  return [x.to_py() for x in outputs]


def execute_with_python_values_replicated(executable, arguments, backend):
  """Execute on many replicas with Python values as arguments and output.

  Args:
    executable: the program to run.
    arguments: a list of lists of Python values indexed by `[replica][arg_num]`
      to pass as inputs.
    backend: the backend we are targeting.

  Returns:
    A list of python values, one per replica.
  """
  devices = executable.local_devices()

  # pylint: disable=g-complex-comprehension
  def copy_to_devices(pyvals):
    return [backend.buffer_from_pyval(v, d) for v, d in zip(pyvals, devices)]

  inputs = [copy_to_devices(pyvals) for pyvals in zip(*arguments)]
  outputs = executable.execute_sharded_on_local_devices(inputs)
  return [[x.to_py() for x in xs] for xs in zip(*outputs)]


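# Example (hypothetical argument layout): for an executable compiled for two
# replicas that takes two parameters, `arguments` would be laid out as
#
#   [[r0_arg0, r0_arg1],   # Python values for replica 0
#    [r1_arg0, r1_arg1]]   # Python values for replica 1
#
# and the result is indexed the same way: one list of output values per
# replica.

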
class PaddingType(enum.Enum):
  VALID = 1
  SAME = 2


def window_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims,
                                      window_strides):
  """Maps PaddingType or string to pad values (list of pairs of ints)."""
  if not isinstance(padding_type, (str, PaddingType)):
    msg = 'padding_type must be str or PaddingType, got {}.'
    raise TypeError(msg.format(type(padding_type)))

  if isinstance(padding_type, str):
    if padding_type.upper() == 'VALID':
      padding_type = PaddingType.VALID
    elif padding_type.upper() == 'SAME':
      padding_type = PaddingType.SAME
    else:
      msg = 'Unknown padding type string: expected "VALID" or "SAME", got {}.'
      raise ValueError(msg.format(padding_type))

  if padding_type == PaddingType.VALID:
    return [(0, 0)] * len(window_strides)
  elif padding_type == PaddingType.SAME:
    out_shape = np.ceil(np.true_divide(lhs_dims, window_strides)).astype(int)
    pad_sizes = [
        max((out_size - 1) * stride + filter_size - in_size, 0)
        for out_size, stride, filter_size, in_size in zip(
            out_shape, window_strides, rhs_dims, lhs_dims)
    ]
    return [(pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes]
  else:
    msg = 'Unexpected PaddingType value: {}'
    raise ValueError(msg.format(padding_type))


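# Worked example (values chosen for illustration): for 'SAME' padding with
# lhs_dims=(10,), rhs_dims=(3,) and window_strides=(2,), the output size is
# ceil(10 / 2) = 5, the total pad is max((5 - 1) * 2 + 3 - 10, 0) = 1, and the
# function returns [(0, 1)]; 'VALID' padding would return [(0, 0)].

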
XlaBuilder = _xla.XlaBuilder
XlaComputation = _xla.XlaComputation
XlaOp = _xla.XlaOp
FftType = _xla.FftType
Client = _xla.Client
Buffer = _xla.Buffer
DeviceArrayBase = _xla.DeviceArrayBase
Executable = _xla.Executable


def register_custom_call_target(name, fn, platform='cpu'):
  """Registers a custom call target.

  Args:
    name: bytes containing the name of the function.
    fn: a PyCapsule object containing the function pointer.
    platform: the target platform.
  """
  # To support AMD GPUs, we need to have xla_platform_names["gpu"] == "ROCM".
  # Since that is hardcoded to CUDA, we use the following as a workaround.
  _xla.register_custom_call_target(name, fn,
                                   xla_platform_names.get(platform, platform))


# Deprecated. Use register_custom_call_target instead.
register_cpu_custom_call_target = register_custom_call_target


class PaddingConfigDimension(object):
  """Python representation of a xla.PaddingConfigDimension protobuf."""
  __slots__ = ('edge_padding_low', 'edge_padding_high', 'interior_padding')

  def __init__(self):
    self.edge_padding_low = 0
    self.edge_padding_high = 0
    self.interior_padding = 0


class PaddingConfig(object):
  """Python representation of a xla.PaddingConfig protobuf."""
  __slots__ = ('dimensions',)

  def __init__(self):
    self.dimensions = []


def make_padding_config(
    padding_config: Union[PaddingConfig, Sequence[Tuple[int, int, int]]]
) -> PaddingConfig:
  """Create PaddingConfig proto from list of triples of integers.

  Args:
    padding_config: either a PaddingConfig or a list of integer triples
      (edge_padding_low, edge_padding_high, interior_padding) representing the
      configuration of the padding operation.

  Returns:
    A `PaddingConfig` object.
  """
  if not isinstance(padding_config, PaddingConfig):
    triples = padding_config
    padding_config = PaddingConfig()
    for lo, hi, interior in triples:
      dimension = PaddingConfigDimension()
      dimension.edge_padding_low = lo
      dimension.edge_padding_high = hi
      dimension.interior_padding = interior
      padding_config.dimensions.append(dimension)
  return padding_config


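# Example (illustrative): make_padding_config([(0, 0, 0), (1, 2, 0)]) builds a
# PaddingConfig with two dimensions; the second dimension gets one element of
# low padding, two elements of high padding and no interior padding.

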
class DotDimensionNumbers(object):
  """Python representation of a xla.DotDimensionNumbers protobuf."""
  __slots__ = ('lhs_contracting_dimensions', 'rhs_contracting_dimensions',
               'lhs_batch_dimensions', 'rhs_batch_dimensions')

  def __init__(self):
    self.lhs_contracting_dimensions = []
    self.rhs_contracting_dimensions = []
    self.lhs_batch_dimensions = []
    self.rhs_batch_dimensions = []


def make_dot_dimension_numbers(
    dimension_numbers: Union[DotDimensionNumbers,
                             Tuple[Tuple[List[int], List[int]],
                                   Tuple[List[int], List[int]]]]
) -> DotDimensionNumbers:
  """Builds a DotDimensionNumbers object from a specification.

  Args:
    dimension_numbers: either a `DotDimensionNumbers` or a nested tuple
      `((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))` of lists of
      integers representing the dimensions to treat as contracting dimensions
      and batch dimensions on each input operand.

  Returns:
    A `DotDimensionNumbers` object.
  """
  if isinstance(dimension_numbers, (list, tuple)):
    (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
    dot_dims_proto = DotDimensionNumbers()
    dot_dims_proto.lhs_contracting_dimensions.extend(lhs_contract)
    dot_dims_proto.rhs_contracting_dimensions.extend(rhs_contract)
    dot_dims_proto.lhs_batch_dimensions.extend(lhs_batch)
    dot_dims_proto.rhs_batch_dimensions.extend(rhs_batch)
    return dot_dims_proto
  else:
    return dimension_numbers


class ConvolutionDimensionNumbers(object):
  """Python representation of a xla.ConvolutionDimensionNumbers protobuf."""
  __slots__ = ('input_batch_dimension', 'input_feature_dimension',
               'input_spatial_dimensions', 'kernel_input_feature_dimension',
               'kernel_output_feature_dimension', 'kernel_spatial_dimensions',
               'output_batch_dimension', 'output_feature_dimension',
               'output_spatial_dimensions')

  def __init__(self):
    self.input_batch_dimension = 0
    self.input_feature_dimension = 0
    self.input_spatial_dimensions = []
    self.kernel_input_feature_dimension = 0
    self.kernel_output_feature_dimension = 0
    self.kernel_spatial_dimensions = []
    self.output_batch_dimension = 0
    self.output_feature_dimension = 0
    self.output_spatial_dimensions = []


def make_convolution_dimension_numbers(
    dimension_numbers: Union[None, ConvolutionDimensionNumbers, Tuple[str, str,
                                                                      str]],
    num_spatial_dimensions: int) -> ConvolutionDimensionNumbers:
  """Builds a ConvolutionDimensionNumbers object from a specification.

  Args:
    dimension_numbers: optional, either a ConvolutionDimensionNumbers object or
      a tuple (lhs_spec, rhs_spec, out_spec). Each element is a string of
      length N+2 identifying by position: (1) batch dimensions in lhs, rhs, and
      the output with the character 'N', (2) feature dimensions in lhs and the
      output with the character 'C', (3) input and output feature dimensions
      in rhs with the characters 'I' and 'O' respectively, and (4) spatial
      dimension correspondences between lhs, rhs, and the output using any
      distinct characters. For example, to indicate dimension numbers
      consistent with the Conv operation with two spatial dimensions, one
      could use ('NCHW', 'OIHW', 'NCHW'). As another example, to indicate
      dimension numbers consistent with the TensorFlow Conv2D operation, one
      could use ('NHWC', 'HWIO', 'NHWC'). When using the latter form of
      convolution dimension specification, window strides are associated with
      spatial dimension character labels according to the order in which the
      labels appear in the rhs_spec string, so that window_strides[0] is
      matched with the dimension corresponding to the first character
      appearing in rhs_spec that is not 'I' or 'O'. By default, use the same
      dimension numbering as Conv and ConvWithGeneralPadding.
    num_spatial_dimensions: the number of spatial dimensions.

  Returns:
    A `ConvolutionDimensionNumbers` object.
  """
  if dimension_numbers is None:
    nd = num_spatial_dimensions
    dimension_numbers = ConvolutionDimensionNumbers()
    dimension_numbers.input_batch_dimension = 0
    dimension_numbers.input_feature_dimension = 1
    dimension_numbers.output_batch_dimension = 0
    dimension_numbers.output_feature_dimension = 1
    dimension_numbers.kernel_output_feature_dimension = 0
    dimension_numbers.kernel_input_feature_dimension = 1
    dimension_numbers.input_spatial_dimensions.extend(range(2, 2 + nd))
    dimension_numbers.kernel_spatial_dimensions.extend(range(2, 2 + nd))
    dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd))
  elif isinstance(dimension_numbers, tuple):
    lhs_spec, rhs_spec, out_spec = dimension_numbers
    dimension_numbers = ConvolutionDimensionNumbers()

    dimension_numbers.input_batch_dimension = lhs_spec.index('N')
    dimension_numbers.input_feature_dimension = lhs_spec.index('C')
    dimension_numbers.output_batch_dimension = out_spec.index('N')
    dimension_numbers.output_feature_dimension = out_spec.index('C')
    dimension_numbers.kernel_output_feature_dimension = rhs_spec.index('O')
    dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I')

    dimension_numbers.kernel_spatial_dimensions.extend(
        i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'})
    dimension_numbers.input_spatial_dimensions.extend(
        sorted((i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}),
               key=lambda i: rhs_spec.index(lhs_spec[i])))
    dimension_numbers.output_spatial_dimensions.extend(
        sorted((i for i, c in enumerate(out_spec) if c not in {'N', 'C'}),
               key=lambda i: rhs_spec.index(out_spec[i])))
  return dimension_numbers


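# Example (illustrative): dimension numbers for a TensorFlow-style Conv2D, as
# described in the docstring above, can be built with
#
#   make_convolution_dimension_numbers(('NHWC', 'HWIO', 'NHWC'), 2)
#
# which places the batch dimension at index 0, the feature dimension at index 3
# and the spatial (H, W) dimensions at indices 1 and 2 of both input and
# output.

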
class OpSharding(object):
  """Python representation of a xla.OpSharding protobuf."""
  __slots__ = ('type', 'tile_assignment_dimensions', 'tile_assignment_devices',
               'tuple_shardings', 'replicate_on_last_tile_dim')

  Type = _xla.OpSharding_Type

  def __init__(self):
    self.type = self.Type.REPLICATED
    self.tile_assignment_dimensions = []
    self.tile_assignment_devices = []
    self.tuple_shardings = []
    self.replicate_on_last_tile_dim = False


class PrecisionConfig(object):
  """Python representation of a xla.PrecisionConfig protobuf."""
  __slots__ = ('operand_precision',)

  Precision = _xla.PrecisionConfig_Precision

  def __init__(self):
    self.operand_precision = []


class GatherDimensionNumbers(object):
  """Python representation of a xla.GatherDimensionNumbers protobuf."""
  __slots__ = ('offset_dims', 'collapsed_slice_dims', 'start_index_map',
               'index_vector_dim')

  def __init__(self):
    self.offset_dims = []
    self.collapsed_slice_dims = []
    self.start_index_map = []
    self.index_vector_dim = 0


class ScatterDimensionNumbers(object):
  """Python representation of a xla.ScatterDimensionNumbers protobuf."""
  __slots__ = ('update_window_dims', 'inserted_window_dims',
               'scatter_dims_to_operand_dims', 'index_vector_dim')

  def __init__(self):
    self.update_window_dims = []
    self.inserted_window_dims = []
    self.scatter_dims_to_operand_dims = []
    self.index_vector_dim = 0


class ReplicaGroup(object):
  """Python representation of a xla.ReplicaGroup protobuf."""
  __slots__ = ('replica_ids',)

  def __init__(self):
    self.replica_ids = []


def _make_replica_group_proto(replica_group):
  replica_group_proto = ReplicaGroup()
  replica_group_proto.replica_ids.extend(replica_group)
  return replica_group_proto


def make_replica_groups(replica_groups):
  if replica_groups is None:
    replica_groups_protos = []  # special value for XLA API
  else:
    replica_groups = list(replica_groups)
    replica_groups_protos = [
        _make_replica_group_proto(group) for group in replica_groups
    ]
  return replica_groups_protos


Traceback = _xla.Traceback
Frame = _xla.Frame


@contextlib.contextmanager
def tracebacks(enabled=True):
  """Context manager that enables or disables traceback collection."""
  saved = Traceback.enabled
  Traceback.enabled = enabled
  try:
    yield
  finally:
    Traceback.enabled = saved


def heap_profile(client: Client) -> bytes:
  """Returns a gzipped pprof protocol buffer containing a heap profile."""
  return gzip.compress(client.heap_profile())


# Perform one last garbage collection of deferred Python references. This is
# mostly to keep ASAN happy.
atexit.register(_xla.collect_garbage)