/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package xla;

option cc_enable_arenas = true;

// Primitive types are the individual values that can be held in rectangular
// multidimensional arrays. A description of the rectangular multidimensional
// array dimensions / primitive type is given by Shape, below.
//
// LINT.IfChange
enum PrimitiveType {
  // Invalid primitive type to serve as default.
  PRIMITIVE_TYPE_INVALID = 0;

  // Predicates are two-state booleans.
  PRED = 1;

  // Signed integral values of fixed width.
  S8 = 2;
  S16 = 3;
  S32 = 4;
  S64 = 5;

  // Unsigned integral values of fixed width.
  U8 = 6;
  U16 = 7;
  U32 = 8;
  U64 = 9;

  // Floating-point values of fixed width.
  //
  // Note: if f16s are not natively supported on the device, they will be
  // converted to f16 from f32 at arbitrary points in the computation.
  F16 = 10;
  F32 = 11;

  // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit
  // floating-point format, but uses 1 bit for the sign, 8 bits for the
  // exponent and 7 bits for the mantissa.
  BF16 = 16;

  F64 = 12;

  // Complex values of fixed width.
  C64 = 15;   // Paired F32 (real, imag), as in std::complex<float>.
  C128 = 18;  // Paired F64 (real, imag), as in std::complex<double>.

  // A tuple is a polymorphic sequence; e.g. a shape that holds different
  // sub-shapes. They are used for things like returning multiple values from
  // a computation; e.g. a computation that returns weights and biases may
  // have a signature that results in a tuple like (f32[784x2000], f32[2000]).
  //
  // If a shape proto has the tuple element type, it may not have any entries
  // in the dimensions field.
  TUPLE = 13;

  // An opaque type used for passing context-specific data to a custom
  // operation. Shapes of this primitive type will have empty dimensions and
  // tuple_shapes fields.
  //
  // (OPAQUE would be a better name for this identifier, but that conflicts
  // with a macro defined in windows.h.)
  OPAQUE_TYPE = 14;

  // A token type threaded between side-effecting operations. Shapes of this
  // primitive type will have empty dimensions and tuple_shapes fields.
  TOKEN = 17;

  // Next = 19
}
// LINT.ThenChange(
//   https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc,
//   https://www.tensorflow.org/code/tensorflow/compiler/xla/tools/driver.cc
// )

// Describes the padding configuration for a Pad operation. The padding amount
// on both edges as well as between the elements is specified for each
// dimension.
message PaddingConfig {
  // Describes the padding configuration for a dimension.
  message PaddingConfigDimension {
    // Padding amount on the low-end (next to the index 0). May be negative.
    int64 edge_padding_low = 1;

    // Padding amount on the high-end (next to the highest index). May be
    // negative.
    int64 edge_padding_high = 2;

    // Padding amount between the elements. May not be negative.
    int64 interior_padding = 3;
  }

  // The padding configuration for all dimensions.
  repeated PaddingConfigDimension dimensions = 1;
}
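// For example, a PaddingConfigDimension with edge_padding_low = 1,
// edge_padding_high = 2, and interior_padding = 1 applied to a dimension of
// size 3 turns [a, b, c] into [p, a, p, b, p, c, p, p], where p is the
// padding value supplied to the Pad operation
// (1 low + 3 elements + 2 interior + 2 high = 8 elements). The sizes here are
// illustrative only.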

// A DimLevelType indicates the encoding method for a dimension in an array.
// The semantics of this field are identical to those of the MLIR SparseTensor
// dialect.
// This should be kept in sync with the SparseTensor DimLevelType enum:
// https://github.com/llvm/llvm-project/blob/5674a3c88088e668b684326c2194a6282e8270ff/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td#L86
enum DimLevelType {
  // The corresponding dimension is Dense, every entry is stored.
  DIM_DENSE = 0;
  // The corresponding dimension is Compressed, only nonzeros are stored.
  DIM_COMPRESSED = 1;
  // The corresponding dimension contains a single coordinate, no sibling
  // elements for each parent.
  DIM_SINGLETON = 2;
}

// Describes a tile used in tiling-based layout. Refer to
// g3doc/third_party/tensorflow/compiler/xla/g3doc/tiled_layout.md for
// details about tiling-based layout.
message TileProto {
  // Number of elements in each dimension of the tile. It's ordered from the
  // most major dimension of the tile to the most minor dimension of the tile.
  // The dimensions correspond to a suffix of the dimensions of the shape
  // being tiled.
  repeated int64 dimensions = 1;
}

// A layout describes how the array is placed in (1D) memory space. This
// includes the minor-to-major ordering of dimensions within a shape.
//
// Clients must specify the layouts of input Literals to the
// computation. Layouts specified in interior operations which take Shapes (for
// example, Convert) are ignored.
//
// See the XLA documentation for more information on shapes and layouts.
//
// LINT.IfChange
message LayoutProto {
  // The dimension level type list for this array, specifying the way in which
  // each array dimension is represented in memory. If this list is empty, the
  // array is assumed to be dense.
  repeated DimLevelType dim_level_types = 9;

  // Sequence of dimension numbers, from minor (fastest varying index) to
  // major (slowest varying index). This field is required.
  repeated int64 minor_to_major = 1;

  // A sequence of tiles, starting from the tile that's applied first to the
  // Shape.
  //
  // TODO(b/119839262): implement tiling in each backend or add Unimplemented
  // error.
  repeated TileProto tiles = 6;

  // Bit size of each element. If the size is bigger than what the element
  // type requires, the value is stored in the least significant
  // bits and the additional most significant bits are filled with 0's.
  //
  // TODO(b/119839262): implement in each backend or add Unimplemented error.
  int64 element_size_in_bits = 7;

  // Memory space where this array resides. The integer field is interpreted
  // in a backend-specific manner.
  int64 memory_space = 8;

  // Important: if any field is added, be sure to modify ShapeUtil::Equal() and
  // LayoutUtil::Hash appropriately to account for the new field.

  reserved 2;
  reserved "padded_dimensions";
  reserved 3;
  reserved "padding_value";
  reserved 4;
  reserved "format";
  reserved 5;
  reserved "max_sparse_elements";
}
// LINT.ThenChange( \
//   https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc, \
//   https://www.tensorflow.org/code/tensorflow/compiler/xla/layout_util.cc)
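// For example, for a 2-dimensional array, minor_to_major: [1, 0] means that
// dimension 1 varies fastest in memory (row-major order), while
// minor_to_major: [0, 1] means that dimension 0 varies fastest (column-major
// order).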

// A shape describes the number of dimensions in the array, the size of each
// dimension, and the primitive component type.
//
// Tuples are a special case in that they have rank zero and have tuple_shapes
// defined.
//
// See the XLA documentation for more information on shapes and layouts.
//
// LINT.IfChange
message ShapeProto {
  reserved 1;
  reserved "rank";

  // The element type for this shape.
  PrimitiveType element_type = 2;

  // The size (number of elements) for each dimension, or an upper bound on
  // the size if the dimension is dynamic. In XLA, dimensions are numbered
  // from 0 to N-1 for an N-dimensional array. The first element of
  // 'dimensions' is the size of dimension 0, the second element is the size
  // of dimension 1, and so forth. Empty list indicates a scalar.
  //
  // If the respective element in 'is_dynamic_dimension' is true then the
  // value in this field represents an upper bound on the size of the
  // dimension.
  repeated int64 dimensions = 3;

  // For tuples only, the shapes of constituent shapes in the tuple sequence.
  repeated ShapeProto tuple_shapes = 4;

  // The layout used to back this shape.
  LayoutProto layout = 5;

  // For arrays, this indicates whether or not each dimension is
  // dynamically-sized. The number of elements in this repeated field should
  // be zero (indicating that no dimensions are dynamic) or equal to the
  // number of elements in the 'dimensions' field.
  repeated bool is_dynamic_dimension = 6;

  // Important: if any field is added, be sure to modify ShapeUtil::Equal(),
  // ShapeUtil::Compatible() and ShapeUtil::Hash() appropriately to account
  // for the new field.
}
// LINT.ThenChange( \
//   https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc)
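// For example, a 2-D f32 array with two rows, at most three columns (the
// second dimension being dynamic), and a row-major layout could be written in
// text format as:
//
//   element_type: F32
//   dimensions: [2, 3]
//   is_dynamic_dimension: [false, true]
//   layout { minor_to_major: [1, 0] }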

// Shape of the parameters and output of a computation (like a traditional
// function signature).
message ProgramShapeProto {
  repeated ShapeProto parameters = 1;
  ShapeProto result = 2;
  repeated string parameter_names = 3;
}

// Statistics of a computation.
message ComputationStats {
  // The number of floating point operations in the computation.
  double flop_count = 1;

  // The number of transcendental operations (e.g., exp) in the computation.
  double transcendental_count = 2;
}

// The type of optimization profiles in use for Op-level optimizations.
enum ProfileType {
  INVALID = 0;
  WINDOW = 1;
  FLAG = 2;
  INTEGER = 3;
}

// The source of the optimization profile.
enum ProfileSource {
  PROFILE_SOURCE_UNKNOWN_SOURCE = 0;
  PROFILE_SOURCE_EMBEDDED = 1;
  PROFILE_SOURCE_REMOTE = 2;
}

// The compilation event that triggered the use of the profile.
enum CompilationEvent {
  COMPILATION_EVENT_UNKNOWN_EVENT = 0;
  COMPILATION_EVENT_FIRST_COMPILATION = 1;
  COMPILATION_EVENT_RECOMPILATION = 2;
}

// Symbolization metadata for HLO Instructions.
//
// This metadata is used for debugging XLA code generation, as well as
// performance profiling of XLA-generated executables.
message OpMetadata {
  // The framework op name that generated this XLA op.
  //
  // Frameworks that build on top of XLA should mirror the names of their ops
  // back to users by specifying the op_type. In this way, even if the
  // framework's "ops" are implemented as multiple XLA HLO Ops, they can be
  // grouped appropriately. (e.g. if a SoftMax layer is emitted into XLA as
  // multiple ops, then each op should have the op_type be "SoftMax".)
  string op_type = 1;
  // The user-specified name of the op.
  //
  // This name is often unique within a computation. Note: some frameworks
  // add auto-generated names if the user does not provide one.
  string op_name = 2;
  // Indicates the file and line that this op is associated with in a user's
  // program.
  //
  // e.g. it could be the file and line of user code that generated the op.
  string source_file = 3;
  int32 source_line = 4;

  // Deprecated, use [ProfileInfo][profile_type] instead.
  repeated ProfileType profile_type = 5 [deprecated = true];

  // HloPassMetadata.pass_id of the pass that created this HLO instruction
  // object. Should never be copied between HLO instructions. Zero if unset
  // and -1 if the instruction was created before HLO passes began.
  int64 creation_pass_id = 6;

  // HloPassMetadata.pass_id of the pass that created the logical
  // functionality that this HLO instruction represents. Should be copied
  // between HLO instructions that correspond across compilation passes. Zero
  // if unset and -1 if the instruction was created before HLO passes began.
  int64 logical_creation_pass_id = 7;

  // The footprint of the generated code for the instruction.
  int64 size_of_generated_code_in_bytes = 8;
  // The size of the working set, i.e., the amount of memory, used by the
  // instruction in a compiler-managed fast device memory.
  int64 size_of_memory_working_set_in_bytes = 9;

  // Information about the optimization profile that this operation contains.
  message ProfileInfo {
    // The type of optimization profiles that this operation contains.
    repeated ProfileType profile_type = 1;
    // Speedup of tuned config compared to default config.
    // TODO(b/203817882) Set the relative_speedup.
    double relative_speedup = 2;
    // The source of the optimization profiles that this operation contains.
    ProfileSource profile_source = 3;
    // The compilation event that triggered the use of the profiles.
    CompilationEvent compilation_event = 4;
  }

  // Profile information for the Op.
  ProfileInfo profile_info = 10;

  // Information about the replaced operation. We store the HLO canonical
  // text, which contains the canonicalized instruction name, operand shapes,
  // and full bodies of the subcomputations, besides the HLO itself. This is
  // set independently of other fields.
  string replaced_op = 11;
}
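// For example, a framework that lowers a single softmax layer into several
// HLO instructions might tag each of them with op_type: "Softmax" and
// op_name: "model/dense_1/softmax" (names here are illustrative), with
// source_file / source_line pointing at the user code that built the layer.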

// Profile data from the execution of a computation.
message ExecutionProfile {
  // Whether the executable was read from the compilation cache.
  bool compilation_cache_hit = 1;

  // The time in milliseconds spent to compile the computation. This is only
  // set if the executable was not read from the compilation cache
  // (compilation_cache_hit == false).
  int64 compile_time_ms = 2;

  // The number of cycles spent for the computation. This does not include the
  // time taken for the data transfers between the host and the device. This
  // is a target-dependent field and only used for debugging purposes.
  int64 compute_cycle_count = 3;

  // The time in nanoseconds spent for the computation, without data transfer.
  int64 compute_time_ns = 4;

  // The time in nanoseconds spent for the entire computation, including the
  // result data transfer time. The current implementation does not spend any
  // cycles for the input data transfer since the memory is initialized with
  // the proper values before the execution.
  int64 compute_and_transfer_time_ns = 5;

  // The size of the binary code in the executable.
  int64 executable_size_in_bytes = 6;

  // Whether this profile was drawn from a cache of profiles instead of from
  // execution on the hardware.
  bool profile_cache_hit = 7;
}

// Handle given to a user that represents an execution that the user launched
// asynchronously on the device.
message ExecutionHandle {
  int64 handle = 1;
}

// Handle given to a user that represents a globally accessible allocation.
// Contrast this against a ComputationDataHandle, which is not globally
// accessible, since it only exists within a specific computation.
message GlobalDataHandle {
  int64 handle = 1;
}

// Handle given to a user that represents a replicated virtual device. Each
// replicated device represents N physical devices for execution where N is
// the number of replicas.
message DeviceHandle {
  int64 handle = 1;

  // The number of model-parallel virtual devices that communicate via XLA
  // Send/Recv instructions.
  int64 device_count = 2;
}

// Handle given to a user to represent a channel between two computations
// via a Send and Recv instruction pair. Channels are unbuffered, so Send
// instructions will be blocked until the data is transferred.
message ChannelHandle {
  int64 handle = 1;
  enum ChannelType {
    // Invalid primitive type to serve as default.
    CHANNEL_TYPE_INVALID = 0;

    // A channel for sending data between devices.
    DEVICE_TO_DEVICE = 1;

    // A channel for sending data from the device to the host. Can only be
    // used with a Send operation.
    DEVICE_TO_HOST = 2;

    // A channel for sending data from the host to the device. Can only be
    // used with a Recv operation.
    HOST_TO_DEVICE = 3;
  }
  ChannelType type = 2;
}

// DeviceAssignmentProto is a serialized form of the DeviceAssignment class,
// which represents the device ids assigned to a set of replicated
// computations. See the xla::DeviceAssignment class comment for more details.
message DeviceAssignmentProto {
  int32 replica_count = 1;
  int32 computation_count = 2;

  // Each logical computation runs on replica_count physical devices.
  // ComputationDevice represents the device ids assigned to the replicas.
  message ComputationDevice {
    repeated int32 replica_device_ids = 1;
  }
  repeated ComputationDevice computation_devices = 3;
}
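// For example, with replica_count: 2 and computation_count: 2, the
// (illustrative) assignment
//
//   computation_devices { replica_device_ids: [0, 1] }
//   computation_devices { replica_device_ids: [2, 3] }
//
// runs the two replicas of computation 0 on devices 0 and 1, and the two
// replicas of computation 1 on devices 2 and 3.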

// Literals are used when the server and client need to exchange materialized
// data / results. Literals are also used to describe constants used in
// computations.
//
// Transfers to/from the client are encoded in literal form, and the structure
// of the repeated fields is implied by the shape.
message LiteralProto {
  ShapeProto shape = 1;
  repeated bool preds = 2;
  bytes s8s = 15;
  bytes u8s = 3;
  repeated int32 s32s = 4;
  repeated int64 s64s = 5;
  repeated uint32 u32s = 6;
  repeated uint64 u64s = 7;
  repeated float f32s = 8;
  repeated double f64s = 9;
  repeated float c64s = 12;    // Stored as interleaved real, imag floats.
  repeated double c128s = 18;  // Stored as interleaved real, imag doubles.
  repeated LiteralProto tuple_literals = 10;
  // The F16s, BF16s, U16s and S16s are encoded in little endian byte order.
  bytes f16s = 11;
  bytes bf16s = 13;
  bytes u16s = 16;
  bytes s16s = 17;
  repeated int64 sparse_indices = 14;
  // Next = 19
}
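// For example, a 2x2 f32 literal holding [[1, 2], [3, 4]] with a row-major
// layout could be written in text format as:
//
//   shape {
//     element_type: F32
//     dimensions: [2, 2]
//     layout { minor_to_major: [1, 0] }
//   }
//   f32s: [1, 2, 3, 4]
//
// where the f32s values are listed in the linear order implied by the shape's
// layout.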

message WindowDimension {
  // The size of the window in this dimension. For a rectangle, this would be
  // the width or height.
  int64 size = 1;

  // The stride at which the window moves across the base area in this
  // dimension. In other words, this is the spacing between different
  // positions of the window in this dimension.
  int64 stride = 2;

  // If positive, means the amount of padding to add to the base area at the
  // low end of this dimension; if negative, its negative means the number of
  // elements removed from the low end of this dimension. For example, in the
  // horizontal dimension of a rectangle, this would be the number of padding
  // values to pad on the left, given that indices increase when going right.
  // The actual padding value depends upon the context. Convolution pads with
  // zeros. ReduceWindow and SelectAndScatter pad with the reduce function's
  // init value.
  int64 padding_low = 3;

  // As padding_low, but on the high end of this dimension. For example, in
  // the horizontal dimension of a rectangle, this would be the number of
  // values to pad on the right, given that indices increase when going right.
  int64 padding_high = 4;

  // Dilation factor of the sliding window in this dimension. A dilation
  // factor of 1 means no dilation. window_dilation - 1 no-op entries
  // ("holes") are implicitly placed between each kernel element. This value
  // may not be less than 1. See documentation for convolution.
  int64 window_dilation = 5;

  // Dilation factor of the base area in this dimension. A dilation factor of
  // 1 means no dilation. base_dilation - 1 no-op entries ("holes") are
  // implicitly placed between each base area element. This value may not be
  // less than 1. See documentation for convolution.
  int64 base_dilation = 6;

  // Window reversal means that this dimension was logically reversed before
  // the operation.
  bool window_reversal = 7;
}

// Describes the windowing in an operation such as convolution.
//
// The window is moved across a base area and for each position of the
// window a computation is performed. The field below describes the
// window and the movement of the window across a base area.
message Window {
  repeated WindowDimension dimensions = 1;
}
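// For example, a 3x3 convolution window applied with stride 1 and SAME
// padding (so the output has the same spatial size as the input) could be
// described as:
//
//   dimensions { size: 3 stride: 1 padding_low: 1 padding_high: 1
//                window_dilation: 1 base_dilation: 1 }
//   dimensions { size: 3 stride: 1 padding_low: 1 padding_high: 1
//                window_dilation: 1 base_dilation: 1 }
//
// with one WindowDimension per spatial dimension.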

// Describes the dimension numbers for a gather operation.
//
// See https://www.tensorflow.org/performance/xla/operation_semantics#gather
// for more details.
message GatherDimensionNumbers {
  // "Window indices" is a term for a set of indices that index into the
  // interior of a dynamic-slice from the input tensor, the starting indices
  // for which were computed from output_gather_dims (see the operation
  // semantic for how this is defined) and the start_indices tensor.
  //
  // The window indices for a specific output index Out are computed as:
  //
  //  i = 0
  //  for (k : [0, input_tensor_shape.rank))
  //    window_indices[k] =
  //      if k in collapsed_slice_dims
  //      then 0
  //      else Out[offset_dims[i++]]
  repeated int64 offset_dims = 1;
  repeated int64 collapsed_slice_dims = 2;

  // This is interpreted as a map from i to start_index_map[i]. It
  // transforms the gather index looked up from the start_indices tensor into
  // the starting index in the input space.
  repeated int64 start_index_map = 3;

  // The dimension in the start_indices input that contains the starting
  // indices.
  int64 index_vector_dim = 4;
}

// Describes the dimension numbers for a scatter operation.
//
// All the fields are similar to the corresponding fields in
// GatherDimensionNumbers. Differences are noted below.
message ScatterDimensionNumbers {
  // The set of dimensions in the updates shape that are window dimensions.
  repeated int64 update_window_dims = 1;
  // The set of window dimensions that must be inserted into the updates
  // shape.
  repeated int64 inserted_window_dims = 2;

  repeated int64 scatter_dims_to_operand_dims = 3;
  int64 index_vector_dim = 4;
}

message ConvolutionDimensionNumbers {
  // The number of the dimension that represents batch in the input.
  int64 input_batch_dimension = 7;

  // The number of the dimension that represents features in the input.
  int64 input_feature_dimension = 8;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the input.
  repeated int64 input_spatial_dimensions = 11;

  // The number of the dimension that represents input features in the
  // convolutional kernel (rhs).
  int64 kernel_input_feature_dimension = 3;

  // The number of the dimension that represents output features in
  // the convolutional kernel (rhs).
  int64 kernel_output_feature_dimension = 4;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the kernel (rhs). window.strides(0) is the
  // stride in the kernel_spatial_dimensions(0) dimension.
  repeated int64 kernel_spatial_dimensions = 6;

  // The number of the dimension that represents batch in the output.
  int64 output_batch_dimension = 9;

  // The number of the dimension that represents features in the output.
  int64 output_feature_dimension = 10;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the output.
  repeated int64 output_spatial_dimensions = 12;

  // Next = 13
}
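// For example, a 2-D convolution with an NHWC input, an HWIO kernel, and an
// NHWC output corresponds to:
//
//   input_batch_dimension: 0, input_feature_dimension: 3,
//   input_spatial_dimensions: [1, 2],
//   kernel_spatial_dimensions: [0, 1], kernel_input_feature_dimension: 2,
//   kernel_output_feature_dimension: 3,
//   output_batch_dimension: 0, output_feature_dimension: 3,
//   output_spatial_dimensions: [1, 2]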

enum PaddingType {
  PADDING_INVALID = 0;
  PADDING_VALID = 1;  // Only the valid portion of the base is covered.
  PADDING_SAME = 2;   // Extra padding is added to produce the same output
                      // size as the input.
}

enum FftType {
  FFT = 0;    // Forward FFT; complex in, complex out.
  IFFT = 1;   // Inverse FFT; complex in, complex out.
  RFFT = 2;   // Forward real FFT; real in, fft_length / 2 + 1 complex out.
  IRFFT = 3;  // Inverse real FFT; fft_length / 2 + 1 complex in,
              // fft_length real out.
}

message DotDimensionNumbers {
  // The dimension numbers that represent the 'lhs' contracting dimensions.
  repeated int64 lhs_contracting_dimensions = 1;
  // The dimension numbers that represent the 'rhs' contracting dimensions.
  repeated int64 rhs_contracting_dimensions = 2;
  // The dimension numbers that represent the 'lhs' batch dimensions.
  repeated int64 lhs_batch_dimensions = 3;
  // The dimension numbers that represent the 'rhs' batch dimensions.
  repeated int64 rhs_batch_dimensions = 4;
}

enum RandomDistribution {
  RNG_INVALID = 0;

  // Creates a uniform-distribution-generated random number on the semi-open
  // interval [parameter[0], parameter[1]).
  RNG_UNIFORM = 1;

  // Creates a normal-distribution-generated random number with mean
  // parameter[0] and standard deviation parameter[1].
  RNG_NORMAL = 2;

  // Next: 4
}

enum RandomAlgorithm {
  RNG_DEFAULT = 0;  // Backend dependent default algorithm.
  RNG_THREE_FRY = 1;
  RNG_PHILOX = 2;
  // Next: 3
}

message TriangularSolveOptions {
  // If true, solves ax = b. If false, solves xa = b.
  bool left_side = 1;

  // If true, 'a' is lower triangular. If false, 'a' is upper triangular.
  bool lower = 2;

  // If true, the diagonal elements of 'a' are assumed to be 1 and not
  // accessed.
  bool unit_diagonal = 3;

  // Should we transpose or use the adjoint of 'a'?
  enum Transpose {
    TRANSPOSE_INVALID = 0;
    NO_TRANSPOSE = 1;  // Don't transpose 'a'.
    TRANSPOSE = 2;     // Transpose 'a'.
    ADJOINT = 3;       // Complex conjugate and transpose 'a'.
  }
  Transpose transpose_a = 4;
}

message CholeskyOptions {
  // If true, uses the lower triangle of `a`. If false, uses the upper
  // triangle of `a`.
  bool lower = 1;
}

// Generic map of attributes used to pass hints / configuration options from
// the Python frontend to the XLA backend.
message FrontendAttributes {
  map<string, string> map = 1;
}

// LINT.IfChange
message OpSharding {
  enum Type {
    // This sharding is replicated across all devices (implies maximal,
    // all other fields are unused).
    REPLICATED = 0;
    // This sharding is maximal - one device runs the entire operation.
    MAXIMAL = 1;
    // This sharding is a tuple - only the tuple_shardings field is valid.
    TUPLE = 2;
    // None of the above; tile_shape and tile_assignment are both used.
    OTHER = 3;
    // This op is manually sharded: the shapes are already partitioned and the
    // partitioner should not change this op.
    MANUAL = 4;
  }
  Type type = 1;
  // The shape of the sharded tile.
  ShapeProto tile_shape = 2;
  // The shape of the tile assignment tensor - this must be the same rank as
  // tile_shape and the product of its dimensions must equal
  // tile_assignment_devices.size().
  repeated int64 tile_assignment_dimensions = 3;
  // Flattened list of device IDs. The order of flattening is the same as used
  // by IndexUtil::MultiToLinearIndex(tile_assignment_shape).
  repeated int64 tile_assignment_devices = 4;
  // If type == TUPLE, the sub-shardings, one per leaf node in the tuple
  // shape, in pre-order. The tuple shape could be nested; here we store just
  // a flattened list of all leaves in the tuple shape. Note that the tuple
  // shape is not stored here; shardings do not store the shapes to which they
  // are applied, this is inferred from the instruction this sharding gets
  // attached to.
  repeated OpSharding tuple_shardings = 5;

  // Only used for OTHER type. If true, data is sharded according to other
  // dimensions of tile_assignment(), but replicated across devices along the
  // last dimension. (Experimental)
  bool replicate_on_last_tile_dim = 6;
  // This field is used to track the source of this sharding, usually derived
  // from instructions. Multiple metadata may be populated if the sharding is
  // combined with other shardings. Metadata should not be populated when
  // type == TUPLE; instead, metadata should be set on the individual tuple
  // elements.
  repeated OpMetadata metadata = 7;

  // This field is used to represent the sharding type of each subgroup.
  // For example, sharding={devices=[2,2,2,2]0,1,2,...,15 last_tile_dims={
  // replicate, manual, unreduced}} means that each of the last 3 dimensions
  // in [2,2,2,2] represents a subgrouping in replicate, manual,
  // unreduced sharding type respectively.
  repeated Type last_tile_dims = 8;
}
// LINT.ThenChange()

// Describes the replica groups in a cross replica op (e.g., all-reduce and
// all-to-all).
message ReplicaGroup {
  // The ids of the replicas that belong to the same group. The ordering of
  // the ids matters in some ops (e.g., all-to-all).
  repeated int64 replica_ids = 1;
}

// Describes the source-target pair in the collective permute op.
message SourceTarget {
  int64 source = 1;
  int64 target = 2;
}

// Used to indicate the precision configuration. It has backend specific
// meaning.
message PrecisionConfig {
  enum Precision {
    DEFAULT = 0;
    HIGH = 1;
    HIGHEST = 2;

    // Next: 3
  }
  repeated Precision operand_precision = 1;

  // Next: 2
}

// Describes whether all data-parallelism replicas will receive the same
// parameter data at each buffer.
message ParameterReplication {
  // A list of boolean values for the flattened leaf buffers. Each value
  // indicates whether the corresponding leaf buffer is replicated.
  //
  // If this field is empty, it means no buffer is replicated. Otherwise, the
  // number of elements in this field must match the number of leaf buffers
  // in the HLO instruction's shape.
  repeated bool replicated_at_leaf_buffers = 1;
}

// A backend-config for kWhile loops that stores the loop's trip count, if it
// is known.
//
// This is useful for backends that can implement a `for i in 0..N` loop more
// efficiently than a `while` loop. For example, on GPUs, we can implement a
// `for i in 0..N` loop by enqueueing the kernels for the loop body N times,
// whereas implementing a `while` loop requires a host-device sync on each
// iteration.
message WhileLoopBackendConfig {
  message KnownTripCount {
    int64 n = 1;
  }
  // This indirection lets us distinguish between known-trip-count == 0 and
  // unknown-trip-count.
  KnownTripCount known_trip_count = 1;
}

// Specifies a pair of output/operand buffers for kCustomCall that alias each
// other.
message CustomCallOutputOperandAliasing {
  repeated int64 output_shape_index = 1;
  int64 operand_index = 2;
  repeated int64 operand_shape_index = 3;
}
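// For example, output_shape_index: [1] with operand_index: 0 and an empty
// operand_shape_index indicates that element {1} of the custom call's tuple
// output shares its buffer with the whole of operand 0 (indices here are
// illustrative).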