/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package xla;

option cc_enable_arenas = true;

// Primitive types are the individual values that can be held in rectangular
// multidimensional arrays. A description of the rectangular multidimensional
// array dimensions / primitive type is given by Shape, below.
//
// LINT.IfChange
enum PrimitiveType {
  // Invalid primitive type to serve as default.
  PRIMITIVE_TYPE_INVALID = 0;

  // Predicates are two-state booleans.
  PRED = 1;

  // Signed integral values of fixed width.
  S8 = 2;
  S16 = 3;
  S32 = 4;
  S64 = 5;

  // Unsigned integral values of fixed width.
  U8 = 6;
  U16 = 7;
  U32 = 8;
  U64 = 9;

  // Floating-point values of fixed width.
  //
  // Note: if f16s are not natively supported on the device, they will be
  // converted to f16 from f32 at arbitrary points in the computation.
  F16 = 10;
  F32 = 11;

  // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit
  // floating-point format, but uses 1 bit for the sign, 8 bits for the
  // exponent and 7 bits for the mantissa.
  BF16 = 16;

  F64 = 12;

  // Complex values of fixed width.
  C64 = 15;   // Paired F32 (real, imag), as in std::complex<float>.
  C128 = 18;  // Paired F64 (real, imag), as in std::complex<double>.

  // A tuple is a polymorphic sequence; e.g. a shape that holds different
  // sub-shapes. They are used for things like returning multiple values from a
  // computation; e.g. a computation that returns weights and biases may have a
  // signature that results in a tuple like (f32[784x2000], f32[2000]).
  //
  // If a shape proto has the tuple element type, it may not have any entries
  // in the dimensions field.
  TUPLE = 13;

  // An opaque type used for passing context-specific data to a custom
  // operation. Shapes of this primitive type will have empty dimensions and
  // tuple_shapes fields.
  //
  // (OPAQUE would be a better name for this identifier, but that conflicts with
  // a macro defined in windows.h.)
  OPAQUE_TYPE = 14;

  // A token type threaded between side-effecting operations. Shapes of this
  // primitive type will have empty dimensions and tuple_shapes fields.
  TOKEN = 17;

  // Next = 19
}
// LINT.ThenChange(
// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc,
// https://www.tensorflow.org/code/tensorflow/compiler/xla/tools/driver.cc
// )

// Describes the padding configuration for a Pad operation. The padding amount
// on both edges, as well as between the elements, is specified for each
// dimension.
message PaddingConfig {
  // Describes the padding configuration for a dimension.
  message PaddingConfigDimension {
    // Padding amount on the low-end (next to the index 0). May be negative.
    int64 edge_padding_low = 1;

    // Padding amount on the high-end (next to the highest index). May be
    // negative.
    int64 edge_padding_high = 2;

    // Padding amount between the elements. May not be negative.
    int64 interior_padding = 3;
  }

  // The padding configuration for all dimensions.
  repeated PaddingConfigDimension dimensions = 1;
}
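
// As an illustrative sketch (values are hypothetical, not drawn from any
// existing computation), a PaddingConfig that pads a rank-1 array with two
// elements of low-edge padding and one element between each pair of
// neighbors could be written in text format as:
//
//   dimensions {
//     edge_padding_low: 2
//     edge_padding_high: 0
//     interior_padding: 1
//   }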

// A format specifies the method used by a layout to store an array in memory.
enum Format {
  // TODO(b/120869032): Rename this to FORMAT_NONE or something else which
  // better corresponds to its meaning.
  INVALID_FORMAT = 0;
  // The default layout, with exactly one storage location per element.
  DENSE = 1;
  reserved 2;
  reserved "SPARSE";
}

// Describes a tile used in tiling-based layout. Refer to
// g3doc/third_party/tensorflow/compiler/xla/g3doc/tiled_layout.md for
// details about tiling-based layout.
message TileProto {
  // Number of elements in each dimension of the tile. It's ordered from the
  // most major dimension of the tile to the most minor dimension of the tile.
  // The dimensions correspond to a suffix of the dimensions of the shape being
  // tiled.
  repeated int64 dimensions = 1;
}

// A layout describes how the array is placed in (1D) memory space. This
// includes the minor-to-major ordering of dimensions within a shape.
//
// Clients must specify the layouts of input Literals to the
// computation. Layouts specified in interior operations which take Shapes (for
// example, Convert) are ignored.
//
// See the XLA documentation for more information on shapes and layouts.
//
// LINT.IfChange
message LayoutProto {
  // The method used to store the data in memory. The format determines which
  // of the other fields are used by the layout.
  Format format = 4;

  // Sequence of dimension numbers, from minor (fastest varying index) to major
  // (slowest varying index). This field is required.
  repeated int64 minor_to_major = 1;

  reserved 2;
  reserved "padded_dimensions";

  reserved 3;
  reserved "padding_value";

  reserved 5;
  reserved "max_sparse_elements";

  // A sequence of tiles, starting from the tile that's applied first to the
  // Shape.
  //
  // TODO(b/119839262): implement tiling in each backend or add Unimplemented
  // error.
  repeated TileProto tiles = 6;

  // Bit size of each element. If the size is bigger than what the element
  // type requires, the value is stored in the least significant
  // bits and the additional most significant bits are filled with 0's.
  //
  // TODO(b/119839262): implement in each backend or add Unimplemented error.
  int64 element_size_in_bits = 7;

  // Memory space where this array resides. The integer field is interpreted in
  // a backend-specific manner.
  int64 memory_space = 8;

  // Important: if any field is added, be sure to modify ShapeUtil::Equal() and
  // LayoutUtil::Hash appropriately to account for the new field.
}
// LINT.ThenChange( \
// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc, \
// https://www.tensorflow.org/code/tensorflow/compiler/xla/layout_util.cc)
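
// As an illustrative sketch (values are hypothetical), the layout of a
// row-major two-dimensional array stores its elements densely with dimension
// 1 varying fastest, i.e. in text format:
//
//   format: DENSE
//   minor_to_major: [1, 0]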

// A shape describes the number of dimensions in the array, the size of each
// dimension, and the primitive component type.
//
// Tuples are a special case in that they have rank zero and have tuple_shapes
// defined.
//
// See the XLA documentation for more information on shapes and layouts.
//
// LINT.IfChange
message ShapeProto {
  reserved 1;
  reserved "rank";

  // The element type for this shape.
  PrimitiveType element_type = 2;

  // The size (number of elements) for each dimension, or an upper bound on the
  // size if the dimension is dynamic. In XLA, dimensions are numbered from 0
  // to N-1 for an N-dimensional array. The first element of 'dimensions' is
  // the size of dimension 0, the second element is the size of dimension 1,
  // and so forth. An empty list indicates a scalar.
  //
  // If the respective element in 'is_dynamic_dimension' is true then the value
  // in this field represents an upper bound on the size of the dimension.
  repeated int64 dimensions = 3;

  // For tuples only, the shapes of the constituent elements of the tuple, in
  // sequence.
  repeated ShapeProto tuple_shapes = 4;

  // The layout used to back this shape.
  LayoutProto layout = 5;

  // For arrays, this indicates whether or not each dimension is
  // dynamically-sized. The number of elements in this repeated field should be
  // zero (indicating that no dimensions are dynamic) or equal to the number of
  // elements in the 'dimensions' field.
  repeated bool is_dynamic_dimension = 6;

  // Important: if any field is added, be sure to modify ShapeUtil::Equal(),
  // ShapeUtil::Compatible() and ShapeUtil::Hash() appropriately to account for
  // the new field.
}
// LINT.ThenChange( \
// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc)
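
// As an illustrative sketch (values are hypothetical), a 2x3 array of 32-bit
// floats with a row-major dense layout, often written f32[2,3], corresponds
// to the text-format shape:
//
//   element_type: F32
//   dimensions: [2, 3]
//   layout {
//     format: DENSE
//     minor_to_major: [1, 0]
//   }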

// Shape of the parameters and output of a computation (like a traditional
// function signature).
message ProgramShapeProto {
  repeated ShapeProto parameters = 1;
  ShapeProto result = 2;
  repeated string parameter_names = 3;
}

// Statistics of a computation.
message ComputationStats {
  // The number of floating point operations in the computation.
  double flop_count = 1;

  // The number of transcendental operations (e.g., exp) in the computation.
  double transcendental_count = 2;
}

// The type of optimization profiles in use.
enum ProfileType {
  INVALID = 0;
  WINDOW = 1;
  FLAG = 2;
  INTEGER = 3;
}

// Symbolization metadata for HLO Instructions.
//
// This metadata is used for debugging XLA code generation, as well as
// performance profiling of XLA-generated executables.
message OpMetadata {
  // The framework op name that generated this XLA op.
  //
  // Frameworks that build on top of XLA should mirror the names of their ops
  // back to users by specifying the op_type. In this way, even if the
  // framework's "ops" are implemented as multiple XLA HLO Ops, they can be
  // grouped appropriately. (e.g. if a SoftMax layer is emitted into XLA as
  // multiple ops, then each op should have the op_type be "SoftMax".)
  string op_type = 1;
  // The user-specified name of the op.
  //
  // This name is often unique within a computation. Note: some frameworks
  // add auto-generated names if the user does not provide one.
  string op_name = 2;
  // Indicates the file and line that this op is associated with in a user's
  // program.
  //
  // e.g. it could be the file and line of user code that generated the op.
  string source_file = 3;
  int32 source_line = 4;

  repeated ProfileType profile_type = 5;

  // HloPassMetadata.pass_id of the pass that created this HLO instruction
  // object. Should never be copied between HLO instructions. Zero if unset and
  // -1 if the instruction was created before HLO passes began.
  int64 creation_pass_id = 6;

  // HloPassMetadata.pass_id of the pass that created the logical functionality
  // that this HLO instruction represents. Should be copied between HLO
  // instructions that correspond across compilation passes. Zero if unset and
  // -1 if the instruction was created before HLO passes began.
  int64 logical_creation_pass_id = 7;
}
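
// As an illustrative sketch (values are hypothetical), a framework lowering a
// softmax layer into several HLO ops might attach the following OpMetadata to
// each of them:
//
//   op_type: "SoftMax"
//   op_name: "model/layer_1/softmax"
//   source_file: "model.py"
//   source_line: 42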

// Profile data from the execution of a computation.
message ExecutionProfile {
  // Whether the executable was read from the compilation cache.
  bool compilation_cache_hit = 1;

  // The time in milliseconds spent to compile the computation. This is only
  // set if the executable was not read from the compilation cache
  // (compilation_cache_hit == false).
  int64 compile_time_ms = 2;

  // The number of cycles spent for the computation. This does not include the
  // time taken for the data transfers between the host and the device. This is
  // a target-dependent field and only used for debugging purposes.
  int64 compute_cycle_count = 3;

  // The time in nanoseconds spent for the computation, without data transfer.
  int64 compute_time_ns = 4;

  // The time in nanoseconds spent for the entire computation, including the
  // result data transfer time. Current implementation does not spend any
  // cycles for the input data transfer since the memory is initialized with
  // the proper values before the execution.
  int64 compute_and_transfer_time_ns = 5;

  // The size of the binary code in the executable.
  int64 executable_size_in_bytes = 6;

  // Whether this profile was drawn from a cache of profiles instead of from
  // execution on the hardware.
  bool profile_cache_hit = 7;
}

// Handle given to a user that represents an execution that the user launched
// asynchronously on the device.
message ExecutionHandle {
  int64 handle = 1;
}

// Handle given to a user that represents a globally accessible allocation.
// Contrast this against a ComputationDataHandle, which is not globally
// accessible, since it only exists within a specific computation.
message GlobalDataHandle {
  int64 handle = 1;
}

// Handle given to a user that represents a replicated virtual device. Each
// replicated device represents N physical devices for execution where N is the
// number of replicas.
message DeviceHandle {
  int64 handle = 1;

  // The number of model-parallel virtual devices that communicate via XLA
  // Send/Recv instructions.
  int64 device_count = 2;
}

// Handle given to a user to represent a channel between two computations
// via a Send and Recv instruction pair. Channels are unbuffered, so Send
// instructions will be blocked until the data is transferred.
message ChannelHandle {
  int64 handle = 1;
  enum ChannelType {
    // Invalid primitive type to serve as default.
    CHANNEL_TYPE_INVALID = 0;

    // A channel for sending data between devices.
    DEVICE_TO_DEVICE = 1;

    // A channel for sending data from the device to the host. Can only be used
    // with a Send operation.
    DEVICE_TO_HOST = 2;

    // A channel for sending data from the host to the device. Can only be used
    // with a Recv operation.
    HOST_TO_DEVICE = 3;
  }
  ChannelType type = 2;
}

// DeviceAssignmentProto is a serialized form of DeviceAssignment class, which
// represents the device ids assigned to a set of replicated computations.
// See xla::DeviceAssignment class comment for more details.
message DeviceAssignmentProto {
  int32 replica_count = 1;
  int32 computation_count = 2;

  // Each logical computation runs on replica_count physical devices.
  // ComputationDevice represents the device ids assigned to the replicas.
  message ComputationDevice {
    repeated int32 replica_device_ids = 1;
  }
  repeated ComputationDevice computation_devices = 3;
}

// Literals are used when the server and client need to exchange materialized
// data / results. Literals are also used to describe constants used in
// computations.
//
// Transfers to/from the client are encoded in literal form, and the structure
// of the repeated fields is implied by the shape.
message LiteralProto {
  ShapeProto shape = 1;
  repeated bool preds = 2;
  bytes s8s = 15;
  bytes u8s = 3;
  repeated int32 s32s = 4;
  repeated int64 s64s = 5;
  repeated uint32 u32s = 6;
  repeated uint64 u64s = 7;
  repeated float f32s = 8;
  repeated double f64s = 9;
  repeated float c64s = 12;    // Stored as interleaved real, imag floats.
  repeated double c128s = 18;  // Stored as interleaved real, imag doubles.
  repeated LiteralProto tuple_literals = 10;
  // The F16s, BF16s, U16s and S16s are encoded in little endian byte order.
  bytes f16s = 11;
  bytes bf16s = 13;
  bytes u16s = 16;
  bytes s16s = 17;
  repeated int64 sparse_indices = 14;
  // Next = 19
}
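
// As an illustrative sketch (values are hypothetical), the constant f32[3]
// {1.5, 2.5, 3.5} could be materialized as the literal:
//
//   shape {
//     element_type: F32
//     dimensions: [3]
//     layout { format: DENSE minor_to_major: [0] }
//   }
//   f32s: [1.5, 2.5, 3.5]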

message WindowDimension {
  // The size of the window in this dimension. For a rectangle, this would be
  // the width or height.
  int64 size = 1;

  // The stride at which the window moves across the base area in this
  // dimension. In other words, this is the spacing between different
  // positions of the window in this dimension.
  int64 stride = 2;

  // If positive, means the amount of padding to add to the base area at the
  // low end of this dimension; if negative, its absolute value is the number
  // of elements removed from the low end of this dimension. For example, in
  // the horizontal dimension of a rectangle, this would be the number of
  // padding values to pad on the left, given that indices increase when going
  // right. The actual padding value depends upon the context. Convolution pads
  // with zeros. ReduceWindow and SelectAndScatter pad with the reduce
  // function's init value.
  int64 padding_low = 3;

  // As padding_low, but on the high end of this dimension. For example, in the
  // horizontal dimension of a rectangle, this would be the number of values to
  // pad on the right, given that indices increase when going right.
  int64 padding_high = 4;

  // Dilation factor of the sliding window in this dimension. A dilation factor
  // of 1 means no dilation. window_dilation - 1 no-op entries ("holes") are
  // implicitly placed between each kernel element. This value may not be less
  // than 1. See documentation for convolution.
  int64 window_dilation = 5;

  // Dilation factor of the base area in this dimension. A dilation factor of 1
  // means no dilation. base_dilation - 1 no-op entries ("holes") are implicitly
  // placed between each base area element. This value may not be less than 1.
  // See documentation for convolution.
  int64 base_dilation = 6;

  // Window reversal means that this dimension was logically reversed before
  // the operation.
  bool window_reversal = 7;
}

// Describes the windowing in an operation such as convolution.
//
// The window is moved across a base area and for each position of the
// window a computation is performed. The field below describes the
// window and the movement of the window across a base area.
message Window {
  repeated WindowDimension dimensions = 1;
}

// Describes the dimension numbers for a gather operation.
//
// See https://www.tensorflow.org/performance/xla/operation_semantics#gather
// for more details.
message GatherDimensionNumbers {
  // "Window indices" is a term for a set of indices that index into the
  // interior of a dynamic-slice from the input tensor, the starting indices
  // for which were computed from output_gather_dims (see the operation
  // semantics for how this is defined) and the start_indices tensor.
  //
  // The window indices for a specific output index Out are computed as:
  //
  //  i = 0
  //  for (k : [0, input_tensor_shape.rank))
  //    window_indices[k] =
  //      if k in collapsed_slice_dims
  //      then 0
  //      else Out[offset_dims[i++]]
  repeated int64 offset_dims = 1;
  repeated int64 collapsed_slice_dims = 2;

  // This is interpreted as a map from i to start_index_map[i]. It
  // transforms the gather index looked up from the start_indices tensor into
  // the starting index in the input space.
  repeated int64 start_index_map = 3;

  // The dimension in the start_indices input that contains the starting
  // indices.
  int64 index_vector_dim = 4;
}

// Describes the dimension numbers for a scatter operation.
//
// All the fields are similar to the corresponding fields in
// GatherDimensionNumbers. Differences are noted below.
message ScatterDimensionNumbers {
  // The set of dimensions in the updates shape that are window dimensions.
  repeated int64 update_window_dims = 1;
  // The set of window dimensions that must be inserted into the updates shape.
  repeated int64 inserted_window_dims = 2;

  repeated int64 scatter_dims_to_operand_dims = 3;
  int64 index_vector_dim = 4;
}
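
// As an illustrative sketch (shapes are hypothetical): gathering K whole rows
// from an f32[N,D] operand using an s32[K] vector of row indices (with gather
// slice sizes [1, D], which are specified on the instruction rather than in
// this message) produces an f32[K,D] result and could use the dimension
// numbers:
//
//   offset_dims: [1]
//   collapsed_slice_dims: [0]
//   start_index_map: [0]
//   index_vector_dim: 1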

message ConvolutionDimensionNumbers {
  // The number of the dimension that represents batch in the input.
  int64 input_batch_dimension = 7;

  // The number of the dimension that represents features in the input.
  int64 input_feature_dimension = 8;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the input.
  repeated int64 input_spatial_dimensions = 11;

  // The number of the dimension that represents input features in the
  // convolutional kernel (rhs).
  int64 kernel_input_feature_dimension = 3;

  // The number of the dimension that represents output features in
  // the convolutional kernel (rhs).
  int64 kernel_output_feature_dimension = 4;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the kernel (rhs). window.strides(0) is the
  // stride in the kernel_spatial_dimensions(0) dimension.
  repeated int64 kernel_spatial_dimensions = 6;

  // The number of the dimension that represents batch in the output.
  int64 output_batch_dimension = 9;

  // The number of the dimension that represents features in the output.
  int64 output_feature_dimension = 10;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the output.
  repeated int64 output_spatial_dimensions = 12;

  // Next = 13
}

enum PaddingType {
  PADDING_INVALID = 0;
  PADDING_VALID = 1;  // Only the valid portion of the base area is covered.
  PADDING_SAME = 2;   // Extra padding is added to produce the same output size
                      // as the input.
}

enum FftType {
  FFT = 0;    // Forward FFT; complex in, complex out.
  IFFT = 1;   // Inverse FFT; complex in, complex out.
  RFFT = 2;   // Forward real FFT; real in, fft_length / 2 + 1 complex out.
  IRFFT = 3;  // Inverse real FFT; fft_length / 2 + 1 complex in,
              // fft_length real out.
}

message DotDimensionNumbers {
  // The dimension numbers that represent the 'lhs' contracting dimensions.
  repeated int64 lhs_contracting_dimensions = 1;
  // The dimension numbers that represent the 'rhs' contracting dimensions.
  repeated int64 rhs_contracting_dimensions = 2;
  // The dimension numbers that represent the 'lhs' batch dimensions.
  repeated int64 lhs_batch_dimensions = 3;
  // The dimension numbers that represent the 'rhs' batch dimensions.
  repeated int64 rhs_batch_dimensions = 4;
}
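
// As an illustrative sketch (shapes are hypothetical), a batched matrix
// multiplication of an f32[B,M,K] lhs with an f32[B,K,N] rhs contracts the K
// dimensions and batches over dimension 0 of both operands:
//
//   lhs_contracting_dimensions: [2]
//   rhs_contracting_dimensions: [1]
//   lhs_batch_dimensions: [0]
//   rhs_batch_dimensions: [0]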

enum RandomDistribution {
  RNG_INVALID = 0;

  // Creates a uniform-distribution-generated random number on the semi-open
  // interval [parameter[0], parameter[1]).
  RNG_UNIFORM = 1;

  // Creates a normal-distribution-generated random number with mean
  // parameter[0] and standard deviation parameter[1].
  RNG_NORMAL = 2;

  // Next: 4
}

enum RandomAlgorithm {
  RNG_DEFAULT = 0;  // Backend dependent default algorithm.
  RNG_THREE_FRY = 1;
  RNG_PHILOX = 2;
  // Next: 3
}

message TriangularSolveOptions {
  // If true, solves ax = b. If false, solves xa = b.
  bool left_side = 1;

  // If true, 'a' is lower triangular. If false, 'a' is upper triangular.
  bool lower = 2;

  // If true, the diagonal elements of 'a' are assumed to be 1 and not accessed.
  bool unit_diagonal = 3;

  // Should we transpose or use the adjoint of 'a'?
  enum Transpose {
    TRANSPOSE_INVALID = 0;
    NO_TRANSPOSE = 1;  // Don't transpose 'a'.
    TRANSPOSE = 2;     // Transpose 'a'.
    ADJOINT = 3;       // Complex conjugate and transpose 'a'.
  }
  Transpose transpose_a = 4;
}

message CholeskyOptions {
  // If true, uses the lower triangle of `a`. If false, uses the upper triangle
  // of `a`.
  bool lower = 1;
}

// Generic map of attributes used to pass hints / configuration options from
// the Python frontend to the XLA backend.
message FrontendAttributes {
  map<string, string> map = 1;
}

message OpSharding {
  enum Type {
    // This sharding is replicated across all devices (implies maximal,
    // all other fields are unused).
    REPLICATED = 0;
    // This sharding is maximal - one device runs the entire operation.
    MAXIMAL = 1;
    // This sharding is a tuple - only the tuple_shardings field is valid.
    TUPLE = 2;
    // None of the above; tile_shape and tile_assignment are both used.
    OTHER = 3;
    // This op is manually sharded: the shapes are already partitioned and the
    // partitioner should not change this op.
    MANUAL = 4;
  }
  Type type = 1;
  // The shape of the sharded tile.
  ShapeProto tile_shape = 2;
  // The shape of the tile assignment tensor - this must be the same rank as
  // tile_shape and the product of its dimensions must equal
  // tile_assignment_devices.size().
  repeated int64 tile_assignment_dimensions = 3;
  // Flattened list of device IDs. The order of flattening is the same as used
  // by IndexUtil::MultiToLinearIndex(tile_assignment_shape).
  repeated int64 tile_assignment_devices = 4;
  // If type == TUPLE, the sub-shardings, one per leaf node in the tuple shape,
  // in pre-order. The tuple shape could be nested; here we store just a
  // flattened list of all leaves in the tuple shape. Note that the tuple shape
  // is not stored here; shardings do not store the shapes to which they are
  // applied, this is inferred from the instruction this sharding gets attached
  // to.
  repeated OpSharding tuple_shardings = 5;

  // Only used for OTHER type. If true, data is sharded according to other
  // dimensions of tile_assignment(), but replicated across devices along the
  // last dimension. (Experimental)
  bool replicate_on_last_tile_dim = 6;
  // This field is used to track the source of this sharding, usually derived
  // from instructions. Multiple metadata may be populated if sharding is
  // combined with other shardings. Metadata is not to be populated when
  // type == TUPLE; instead, metadata should be set on the individual tuple
  // elements.
  repeated OpMetadata metadata = 7;
}

// Describes the replica groups in a cross replica op (e.g., all-reduce and
// all-to-all).
message ReplicaGroup {
  // The ids of the replicas that belong to the same group. The ordering of the
  // ids matters in some ops (e.g., all-to-all).
  repeated int64 replica_ids = 1;
}

// Describes the source-target pair in the collective permute op.
message SourceTarget {
  int64 source = 1;
  int64 target = 2;
}

// Used to indicate the precision configuration. It has backend specific
// meaning.
message PrecisionConfig {
  enum Precision {
    DEFAULT = 0;
    HIGH = 1;
    HIGHEST = 2;

    // Next: 3
  }
  repeated Precision operand_precision = 1;

  // Next: 2
}

// Describes whether all data-parallelism replicas will receive the same
// parameter data at each buffer.
message ParameterReplication {
  // A list of boolean values for the flattened leaf buffers. Each value
  // indicates whether the corresponding leaf buffer is replicated.
  //
  // If this field is empty, it means no buffer is replicated. Otherwise, the
  // number of elements in this field must match the number of leaf buffers in
  // the HLO instruction's shape.
  repeated bool replicated_at_leaf_buffers = 1;
}
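
// As an illustrative sketch of the OpSharding message above (devices are
// hypothetical), splitting an f32[8,4] array into two 4x4 tiles along
// dimension 0, placed on devices 0 and 1, could be expressed as:
//
//   type: OTHER
//   tile_assignment_dimensions: [2, 1]
//   tile_assignment_devices: [0, 1]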

// A backend-config for kWhile loops that stores the loop's trip count, if it
// is known.
//
// This is useful for backends that can implement a `for i in 0..N` loop more
// efficiently than a `while` loop. For example, on GPUs, we can implement a
// `for i in 0..N` loop by enqueueing the kernels for the loop body N times,
// whereas implementing a `while` loop requires a host-device sync on each
// iteration.
message WhileLoopBackendConfig {
  message KnownTripCount {
    int64 n = 1;
  }
  // This indirection lets us distinguish between known-trip-count == 0 and
  // unknown-trip-count.
  KnownTripCount known_trip_count = 1;
}

// Specifies a pair of output/operand buffers for kCustomCall that alias each
// other.
message CustomCallOutputOperandAliasing {
  repeated int64 output_shape_index = 1;
  int64 operand_index = 2;
  repeated int64 operand_shape_index = 3;
}
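
// As an illustrative sketch (indices are hypothetical), an aliasing entry
// stating that tuple element {1} of a custom call's output shares its buffer
// with the entirety of operand 0 could be written as:
//
//   output_shape_index: [1]
//   operand_index: 0
//   # operand_shape_index left empty: the whole operand is aliased.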