/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package xla;

option cc_enable_arenas = true;

// Primitive types are the individual values that can be held in rectangular
// multidimensional arrays. A description of the rectangular multidimensional
// array dimensions / primitive type is given by Shape, below.
//
// LINT.IfChange
enum PrimitiveType {
  // Invalid primitive type to serve as default.
  PRIMITIVE_TYPE_INVALID = 0;

  // Predicates are two-state booleans.
  PRED = 1;

  // Signed integral values of fixed width.
  S8 = 2;
  S16 = 3;
  S32 = 4;
  S64 = 5;

  // Unsigned integral values of fixed width.
  U8 = 6;
  U16 = 7;
  U32 = 8;
  U64 = 9;

  // Floating-point values of fixed width.
  //
  // Note: if f16s are not natively supported on the device, they will be
  // converted to f16 from f32 at arbitrary points in the computation.
  F16 = 10;
  F32 = 11;

  // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit
  // floating-point format, but uses 1 bit for the sign, 8 bits for the
  // exponent and 7 bits for the mantissa.
  BF16 = 16;

  F64 = 12;

  // Complex values of fixed width.
  C64 = 15;   // Paired F32 (real, imag), as in std::complex<float>.
  C128 = 18;  // Paired F64 (real, imag), as in std::complex<double>.

  // A tuple is a polymorphic sequence; e.g. a shape that holds different
  // sub-shapes. They are used for things like returning multiple values from a
  // computation; e.g. a computation that returns weights and biases may have a
  // signature that results in a tuple like (f32[784x2000], f32[2000]).
  //
  // If a shape proto has the tuple element type, it may not have any entries
  // in the dimensions field.
  TUPLE = 13;

  // An opaque type used for passing context-specific data to a custom
  // operation. Shapes of this primitive type will have empty dimensions and
  // tuple_shapes fields.
  //
  // (OPAQUE would be a better name for this identifier, but that conflicts
  // with a macro defined in windows.h.)
  OPAQUE_TYPE = 14;

  // A token type threaded between side-effecting operations. Shapes of this
  // primitive type will have empty dimensions and tuple_shapes fields.
  TOKEN = 17;

  // Next = 19
}
// LINT.ThenChange(
//   https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc,
//   https://www.tensorflow.org/code/tensorflow/compiler/xla/tools/driver.cc
// )

// Describes the padding configuration for the Pad operation. The padding
// amount on both edges as well as between the elements is specified for each
// dimension.
message PaddingConfig {
  // Describes the padding configuration for a dimension.
  message PaddingConfigDimension {
    // Padding amount on the low-end (next to the index 0). May be negative.
    int64 edge_padding_low = 1;

    // Padding amount on the high-end (next to the highest index). May be
    // negative.
    int64 edge_padding_high = 2;

    // Padding amount between the elements. May not be negative.
    int64 interior_padding = 3;
  }

  // The padding configuration for all dimensions.
  repeated PaddingConfigDimension dimensions = 1;
}
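
// Illustrative example (not part of the schema): in textproto form, padding a
// one-dimensional array with two elements on the low edge, three on the high
// edge, and one element between neighbors would be a sketch like:
//
//   dimensions {
//     edge_padding_low: 2
//     edge_padding_high: 3
//     interior_padding: 1
//   }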

// A format specifies the method used by a layout to store an array in memory.
enum Format {
  // TODO(b/120869032): Rename this to FORMAT_NONE or something else which
  // better corresponds to its meaning.
  INVALID_FORMAT = 0;
  // The default layout, with exactly one storage location per element.
  DENSE = 1;
  reserved 2;
  reserved "SPARSE";
}

// Describes a tile used in tiling-based layout. Refer to
// g3doc/third_party/tensorflow/compiler/xla/g3doc/tiled_layout.md for
// details about tiling-based layout.
message TileProto {
  // Number of elements in each dimension of the tile. It's ordered from the
  // most major dimension of the tile to the most minor dimension of the tile.
  // The dimensions correspond to a suffix of the dimensions of the shape being
  // tiled.
  repeated int64 dimensions = 1;
}

// A layout describes how the array is placed in (1D) memory space. This
// includes the minor-to-major ordering of dimensions within a shape.
//
// Clients must specify the layouts of input Literals to the
// computation. Layouts specified in interior operations which take Shapes (for
// example, Convert) are ignored.
//
// See the XLA documentation for more information on shapes and layouts.
//
// LINT.IfChange
message LayoutProto {
  // The method used to store the data in memory. The format determines which
  // of the other fields are used by the layout.
  Format format = 4;

  // Sequence of dimension numbers, from minor (fastest varying index) to major
  // (slowest varying index). This field is required.
  repeated int64 minor_to_major = 1;

  reserved 2;
  reserved "padded_dimensions";

  reserved 3;
  reserved "padding_value";

  reserved 5;
  reserved "max_sparse_elements";

  // A sequence of tiles, starting from the tile that's applied first to the
  // Shape.
  //
  // TODO(b/119839262): implement tiling in each backend or add Unimplemented
  // error.
  repeated TileProto tiles = 6;

  // Bit size of each element. If the size is bigger than what the element
  // type requires, the value is stored in the least significant
  // bits and the additional most significant bits are filled with 0's.
  //
  // TODO(b/119839262): implement in each backend or add Unimplemented error.
  int64 element_size_in_bits = 7;

  // Memory space where this array resides. The integer field is interpreted in
  // a backend-specific manner.
  int64 memory_space = 8;

  // Important: if any field is added, be sure to modify ShapeUtil::Equal() and
  // LayoutUtil::Hash appropriately to account for the new field.
}
// LINT.ThenChange( \
//   https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc, \
//   https://www.tensorflow.org/code/tensorflow/compiler/xla/layout_util.cc)
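
// Illustrative example (not part of the schema): a dense, row-major (C order)
// layout for a rank-2 array lists dimension 1 as the fastest-varying index; a
// textproto sketch would be:
//
//   format: DENSE
//   minor_to_major: [1, 0]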

// A shape describes the number of dimensions in the array, the size of each
// dimension, and the primitive component type.
//
// Tuples are a special case in that they have rank zero and have tuple_shapes
// defined.
//
// See the XLA documentation for more information on shapes and layouts.
//
// LINT.IfChange
message ShapeProto {
  reserved 1;
  reserved "rank";

  // The element type for this shape.
  PrimitiveType element_type = 2;

  // The size (number of elements) for each dimension, or an upper bound on the
  // size if the dimension is dynamic. In XLA, dimensions are numbered from 0
  // to N-1 for an N-dimensional array. The first element of 'dimensions' is
  // the size of dimension 0, the second element is the size of dimension 1,
  // and so forth. An empty list indicates a scalar.
  //
  // If the respective element in 'is_dynamic_dimension' is true then the value
  // in this field represents an upper bound on the size of the dimension.
  repeated int64 dimensions = 3;

  // For tuples only, the shapes of the constituent elements in the tuple
  // sequence.
  repeated ShapeProto tuple_shapes = 4;

  // The layout used to back this shape.
  LayoutProto layout = 5;

  // For arrays, this indicates whether or not each dimension is
  // dynamically-sized. The number of elements in this repeated field should be
  // zero (indicating that no dimensions are dynamic) or equal to the number of
  // elements in the 'dimensions' field.
  repeated bool is_dynamic_dimension = 6;

  // Important: if any field is added, be sure to modify ShapeUtil::Equal(),
  // ShapeUtil::Compatible() and ShapeUtil::Hash() appropriately to account for
  // the new field.
}
// LINT.ThenChange( \
//   https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc)
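
// Illustrative example (not part of the schema): an f32[2,3] array with the
// row-major layout shown earlier could be written in textproto as:
//
//   element_type: F32
//   dimensions: [2, 3]
//   layout { format: DENSE minor_to_major: [1, 0] }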

// Shape of the parameters and output of a computation (like a traditional
// function signature).
message ProgramShapeProto {
  repeated ShapeProto parameters = 1;
  ShapeProto result = 2;
  repeated string parameter_names = 3;
}

// Statistics of a computation.
message ComputationStats {
  // The number of floating point operations in the computation.
  double flop_count = 1;

  // The number of transcendental operations (e.g., exp) in the computation.
  double transcendental_count = 2;
}

// The type of optimization profiles in use.
enum ProfileType {
  INVALID = 0;
  WINDOW = 1;
  FLAG = 2;
  INTEGER = 3;
}

// Symbolization metadata for HLO Instructions.
//
// This metadata is used for debugging XLA code generation, as well as
// performance profiling of XLA-generated executables.
message OpMetadata {
  // The framework op name that generated this XLA op.
  //
  // Frameworks that build on top of XLA should mirror the names of their ops
  // back to users by specifying the op_type. In this way, even if the
  // framework's "ops" are implemented as multiple XLA HLO Ops, they can be
  // grouped appropriately. (e.g. if a SoftMax layer is emitted into XLA as
  // multiple ops, then each op should have the op_type be "SoftMax".)
  string op_type = 1;
  // The user-specified name of the op.
  //
  // This name is often unique within a computation. Note: some frameworks
  // add auto-generated names if the user does not provide one.
  string op_name = 2;
  // Indicates the file and line that this op is associated with in a user's
  // program.
  //
  // e.g. it could be the file and line of user code that generated the op.
  string source_file = 3;
  int32 source_line = 4;

  repeated ProfileType profile_type = 5;

  // HloPassMetadata.pass_id of the pass that created this HLO instruction
  // object. Should never be copied between HLO instructions. Zero if unset and
  // -1 if the instruction was created before HLO passes began.
  int64 creation_pass_id = 6;

  // HloPassMetadata.pass_id of the pass that created the logical functionality
  // that this HLO instruction represents. Should be copied between HLO
  // instructions that correspond across compilation passes. Zero if unset and
  // -1 if the instruction was created before HLO passes began.
  int64 logical_creation_pass_id = 7;

  // The footprint of the generated code for the instruction.
  int64 size_of_generated_code_in_bytes = 8;
  // The size of the working set, i.e., the amount of memory, used by the
  // instruction in a compiler-managed fast device memory.
  int64 size_of_memory_working_set_in_bytes = 9;
}
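
// Illustrative example (not part of the schema): the op_name, source_file and
// source_line values below are hypothetical, but sketch how a framework might
// attribute one of several HLO ops emitted for a softmax layer:
//
//   op_type: "SoftMax"
//   op_name: "model/softmax_1"
//   source_file: "model.py"
//   source_line: 42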

// Profile data from the execution of a computation.
message ExecutionProfile {
  // Whether the executable was read from the compilation cache.
  bool compilation_cache_hit = 1;

  // The time in milliseconds spent to compile the computation. This is only
  // set if the executable was not read from the compilation cache
  // (compilation_cache_hit == false).
  int64 compile_time_ms = 2;

  // The number of cycles spent for the computation. This does not include the
  // time taken for the data transfers between the host and the device. This is
  // a target-dependent field and only used for debugging purposes.
  int64 compute_cycle_count = 3;

  // The time in nanoseconds spent for the computation, without data transfer.
  int64 compute_time_ns = 4;

  // The time in nanoseconds spent for the entire computation, including the
  // result data transfer time. The current implementation does not spend any
  // cycles for the input data transfer since the memory is initialized with
  // the proper values before the execution.
  int64 compute_and_transfer_time_ns = 5;

  // The size of the binary code in the executable.
  int64 executable_size_in_bytes = 6;

  // Whether this profile was drawn from a cache of profiles instead of from
  // execution on the hardware.
  bool profile_cache_hit = 7;
}

// Handle given to a user that represents an execution that the user launched
// asynchronously on the device.
message ExecutionHandle {
  int64 handle = 1;
}

// Handle given to a user that represents a globally accessible allocation.
// Contrast this against a ComputationDataHandle, which is not globally
// accessible, since it only exists within a specific computation.
message GlobalDataHandle {
  int64 handle = 1;
}

// Handle given to a user that represents a replicated virtual device. Each
// replicated device represents N physical devices for execution where N is the
// number of replicas.
message DeviceHandle {
  int64 handle = 1;

  // The number of model-parallel virtual devices that communicate via XLA
  // Send/Recv instructions.
  int64 device_count = 2;
}

// Handle given to a user to represent a channel between two computations
// via a Send and Recv instruction pair. Channels are unbuffered, so Send
// instructions will be blocked until the data is transferred.
message ChannelHandle {
  int64 handle = 1;
  enum ChannelType {
    // Invalid primitive type to serve as default.
    CHANNEL_TYPE_INVALID = 0;

    // A channel for sending data between devices.
    DEVICE_TO_DEVICE = 1;

    // A channel for sending data from the device to the host. Can only be used
    // with a Send operation.
    DEVICE_TO_HOST = 2;

    // A channel for sending data from the host to the device. Can only be used
    // with a Recv operation.
    HOST_TO_DEVICE = 3;
  }
  ChannelType type = 2;
}

// DeviceAssignmentProto is a serialized form of the DeviceAssignment class,
// which represents the device ids assigned to a set of replicated
// computations. See the xla::DeviceAssignment class comment for more details.
message DeviceAssignmentProto {
  int32 replica_count = 1;
  int32 computation_count = 2;

  // Each logical computation runs on replica_count physical devices.
  // ComputationDevice represents the device ids assigned to the replicas.
  message ComputationDevice {
    repeated int32 replica_device_ids = 1;
  }
  repeated ComputationDevice computation_devices = 3;
}

// Literals are used when the server and client need to exchange materialized
// data / results. Literals are also used to describe constants used in
// computations.
//
// Transfers to/from the client are encoded in literal form, and the structure
// of the repeated fields is implied by the shape.
message LiteralProto {
  ShapeProto shape = 1;
  repeated bool preds = 2;
  bytes s8s = 15;
  bytes u8s = 3;
  repeated int32 s32s = 4;
  repeated int64 s64s = 5;
  repeated uint32 u32s = 6;
  repeated uint64 u64s = 7;
  repeated float f32s = 8;
  repeated double f64s = 9;
  repeated float c64s = 12;    // Stored as interleaved real, imag floats.
  repeated double c128s = 18;  // Stored as interleaved real, imag doubles.
  repeated LiteralProto tuple_literals = 10;
  // The F16s, BF16s, U16s and S16s are encoded in little endian byte order.
  bytes f16s = 11;
  bytes bf16s = 13;
  bytes u16s = 16;
  bytes s16s = 17;
  repeated int64 sparse_indices = 14;
  // Next = 19
}
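
// Illustrative example (not part of the schema): a textproto sketch of a
// literal holding the f32[2] array {1.0, 2.0}:
//
//   shape {
//     element_type: F32
//     dimensions: 2
//     layout { format: DENSE minor_to_major: 0 }
//   }
//   f32s: [1.0, 2.0]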

message WindowDimension {
  // The size of the window in this dimension. For a rectangle, this would be
  // the width or height.
  int64 size = 1;

  // The stride at which the window moves across the base area in this
  // dimension. In other words, this is the spacing between different
  // positions of the window in this dimension.
  int64 stride = 2;

  // If positive, means the amount of padding to add to the base area at the
  // low end of this dimension; if negative, its negative means the number of
  // elements removed from the low end of this dimension. For example, in the
  // horizontal dimension of a rectangle, this would be the number of padding
  // values to pad on the left, given that indices increase when going right.
  // The actual padding value depends upon the context. Convolution pads with
  // zeros. ReduceWindow and SelectAndScatter pad with the reduce function's
  // init value.
  int64 padding_low = 3;

  // As padding_low, but on the high end of this dimension. For example, in the
  // horizontal dimension of a rectangle, this would be the number of values to
  // pad on the right, given that indices increase when going right.
  int64 padding_high = 4;

  // Dilation factor of the sliding window in this dimension. A dilation factor
  // of 1 means no dilation. window_dilation - 1 no-op entries ("holes") are
  // implicitly placed between each kernel element. This value may not be less
  // than 1. See documentation for convolution.
  int64 window_dilation = 5;

  // Dilation factor of the base area in this dimension. A dilation factor of 1
  // means no dilation. base_dilation - 1 no-op entries ("holes") are implicitly
  // placed between each base area element. This value may not be less than 1.
  // See documentation for convolution.
  int64 base_dilation = 6;

  // Window reversal means that this dimension was logically reversed before
  // the operation.
  bool window_reversal = 7;
}

// Describes the windowing in an operation such as convolution.
//
// The window is moved across a base area and for each position of the
// window a computation is performed. The field below describes the
// window and the movement of the window across a base area.
message Window {
  repeated WindowDimension dimensions = 1;
}

// Describes the dimension numbers for a gather operation.
//
// See https://www.tensorflow.org/performance/xla/operation_semantics#gather
// for more details.
message GatherDimensionNumbers {
  // "Window indices" is a term for a set of indices that index into the
  // interior of a dynamic-slice from the input tensor, the starting indices
  // for which were computed from output_gather_dims (see the operation
  // semantics for how this is defined) and the start_indices tensor.
  //
  // The window indices for a specific output index Out are computed as:
  //
  //   i = 0
  //   for (k : [0, input_tensor_shape.rank))
  //     window_indices[k] =
  //       if k in collapsed_slice_dims
  //       then 0
  //       else Out[offset_dims[i++]]
  repeated int64 offset_dims = 1;
  repeated int64 collapsed_slice_dims = 2;

  // This is interpreted as a map from i to start_index_map[i]. It
  // transforms the gather index looked up from the start_indices tensor into
  // the starting index in the input space.
  repeated int64 start_index_map = 3;

  // The dimension in the start_indices input that contains the starting
  // indices.
  int64 index_vector_dim = 4;
}
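
// Illustrative example (not part of the schema): gathering whole rows of an
// [N, D] operand using a rank-1 tensor of B row indices (with the gather's
// slice_sizes, a separate attribute, set to [1, D]) would typically use:
//
//   offset_dims: 1
//   collapsed_slice_dims: 0
//   start_index_map: 0
//   index_vector_dim: 1
//
// producing a [B, D] result. See the operation semantics page linked above.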

// Describes the dimension numbers for a scatter operation.
//
// All the fields are similar to the corresponding fields in
// GatherDimensionNumbers. Differences are noted below.
message ScatterDimensionNumbers {
  // The set of dimensions in the updates shape that are window dimensions.
  repeated int64 update_window_dims = 1;
  // The set of window dimensions that must be inserted into the updates shape.
  repeated int64 inserted_window_dims = 2;

  repeated int64 scatter_dims_to_operand_dims = 3;
  int64 index_vector_dim = 4;
}

message ConvolutionDimensionNumbers {
  // The number of the dimension that represents batch in the input.
  int64 input_batch_dimension = 7;

  // The number of the dimension that represents features in the input.
  int64 input_feature_dimension = 8;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the input.
  repeated int64 input_spatial_dimensions = 11;

  // The number of the dimension that represents input features in the
  // convolutional kernel (rhs).
  int64 kernel_input_feature_dimension = 3;

  // The number of the dimension that represents output features in
  // the convolutional kernel (rhs).
  int64 kernel_output_feature_dimension = 4;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the kernel (rhs). window.strides(0) is the
  // stride in the kernel_spatial_dimensions(0) dimension.
  repeated int64 kernel_spatial_dimensions = 6;

  // The number of the dimension that represents batch in the output.
  int64 output_batch_dimension = 9;

  // The number of the dimension that represents features in the output.
  int64 output_feature_dimension = 10;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the output.
  repeated int64 output_spatial_dimensions = 12;

  // Next = 13
}

enum PaddingType {
  PADDING_INVALID = 0;
  PADDING_VALID = 1;  // Only the valid portion of the base is covered.
  PADDING_SAME = 2;   // Extra padding is added to produce the same output size
                      // as the input.
}

enum FftType {
  FFT = 0;    // Forward FFT; complex in, complex out.
  IFFT = 1;   // Inverse FFT; complex in, complex out.
  RFFT = 2;   // Forward real FFT; real in, fft_length / 2 + 1 complex out.
  IRFFT = 3;  // Inverse real FFT; fft_length / 2 + 1 complex in,
              // fft_length real out.
}

message DotDimensionNumbers {
  // The dimension numbers that represent the 'lhs' contracting dimensions.
  repeated int64 lhs_contracting_dimensions = 1;
  // The dimension numbers that represent the 'rhs' contracting dimensions.
  repeated int64 rhs_contracting_dimensions = 2;
  // The dimension numbers that represent the 'lhs' batch dimensions.
  repeated int64 lhs_batch_dimensions = 3;
  // The dimension numbers that represent the 'rhs' batch dimensions.
  repeated int64 rhs_batch_dimensions = 4;
}
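
// Illustrative example (not part of the schema): a plain matrix multiplication
// of an [m, k] lhs with a [k, n] rhs contracts lhs dimension 1 against rhs
// dimension 0 and has no batch dimensions:
//
//   lhs_contracting_dimensions: 1
//   rhs_contracting_dimensions: 0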

enum RandomDistribution {
  RNG_INVALID = 0;

  // Creates a uniform-distribution-generated random number on the semi-open
  // interval [parameter[0], parameter[1]).
  RNG_UNIFORM = 1;

  // Creates a normal-distribution-generated random number with mean
  // parameter[0] and standard deviation parameter[1].
  RNG_NORMAL = 2;

  // Next: 4
}

enum RandomAlgorithm {
  RNG_DEFAULT = 0;  // Backend dependent default algorithm.
  RNG_THREE_FRY = 1;
  RNG_PHILOX = 2;
  // Next: 3
}

message TriangularSolveOptions {
  // If true, solves ax = b. If false, solves xa = b.
  bool left_side = 1;

  // If true, 'a' is lower triangular. If false, 'a' is upper triangular.
  bool lower = 2;

  // If true, the diagonal elements of 'a' are assumed to be 1 and not
  // accessed.
  bool unit_diagonal = 3;

  // Should we transpose or use the adjoint of 'a'?
  enum Transpose {
    TRANSPOSE_INVALID = 0;
    NO_TRANSPOSE = 1;  // Don't transpose 'a'.
    TRANSPOSE = 2;     // Transpose 'a'.
    ADJOINT = 3;       // Complex conjugate and transpose 'a'.
  }
  Transpose transpose_a = 4;
}

message CholeskyOptions {
  // If true, uses the lower triangle of `a`. If false, uses the upper triangle
  // of `a`.
  bool lower = 1;
}

// Generic map of attributes used to pass hints / configuration options from
// the Python frontend to the XLA backend.
message FrontendAttributes {
  map<string, string> map = 1;
}

message OpSharding {
  enum Type {
    // This sharding is replicated across all devices (implies maximal,
    // all other fields are unused).
    REPLICATED = 0;
    // This sharding is maximal - one device runs the entire operation.
    MAXIMAL = 1;
    // This sharding is a tuple - only the tuple_shardings field is valid.
    TUPLE = 2;
    // None of the above; tile_shape and tile_assignment are both used.
    OTHER = 3;
    // This op is manually sharded: the shapes are already partitioned and the
    // partitioner should not change this op.
    MANUAL = 4;
  }
  Type type = 1;
  // The shape of the sharded tile.
  ShapeProto tile_shape = 2;
  // The shape of the tile assignment tensor - this must have the same rank as
  // tile_shape and the product of its dimensions must equal
  // tile_assignment_devices.size().
  repeated int64 tile_assignment_dimensions = 3;
  // Flattened list of device IDs. The order of flattening is the same as used
  // by IndexUtil::MultiToLinearIndex(tile_assignment_shape).
  repeated int64 tile_assignment_devices = 4;
  // If type == TUPLE, the sub-shardings, one per leaf node in the tuple shape,
  // in pre-order. The tuple shape could be nested; here we store just a
  // flattened list of all leaves in the tuple shape. Note that the tuple shape
  // is not stored here; shardings do not store the shapes to which they are
  // applied, this is inferred from the instruction this sharding gets attached
  // to.
  repeated OpSharding tuple_shardings = 5;

  // Only used for OTHER type. If true, data is sharded according to other
  // dimensions of tile_assignment(), but replicated across devices along the
  // last dimension. (Experimental)
  bool replicate_on_last_tile_dim = 6;
  // This field is used to track the source of this sharding, usually derived
  // from instructions. Multiple metadata may be populated if sharding is
  // combined with other shardings. Metadata are not to be populated when
  // type == TUPLE; instead metadata should be set on individual tuple
  // elements.
  repeated OpMetadata metadata = 7;

  // This field is used to represent the sharding type of each subgroup.
  // For example, sharding={devices=[2,2,2,2]0,1,2,...,15 last_tile_dims={
  // replicate, manual, unreduced}} means that each of the last 3 dimensions
  // in [2,2,2,2] represents a subgrouping in replicate, manual,
  // unreduced sharding type respectively.
  repeated Type last_tile_dims = 8;
}
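
// Illustrative example (not part of the schema): a textproto sketch of a
// sharding that splits dimension 0 of a rank-2 tensor across devices 0 and 1
// while keeping dimension 1 unsplit:
//
//   type: OTHER
//   tile_assignment_dimensions: [2, 1]
//   tile_assignment_devices: [0, 1]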

// Describes the replica groups in a cross replica op (e.g., all-reduce and
// all-to-all).
message ReplicaGroup {
  // The ids of the replicas that belong to the same group. The ordering of the
  // ids matters in some ops (e.g., all-to-all).
  repeated int64 replica_ids = 1;
}

// Describes the source-target pair in the collective permute op.
message SourceTarget {
  int64 source = 1;
  int64 target = 2;
}

// Used to indicate the precision configuration. It has backend specific
// meaning.
message PrecisionConfig {
  enum Precision {
    DEFAULT = 0;
    HIGH = 1;
    HIGHEST = 2;

    // Next: 3
  }
  repeated Precision operand_precision = 1;

  // Next: 2
}

// Describes whether all data-parallelism replicas will receive the same
// parameter data at each buffer.
message ParameterReplication {
  // A list of boolean values for the flattened leaf buffers. Each value
  // indicates whether the corresponding leaf buffer is replicated.
  //
  // If this field is empty, it means no buffer is replicated. Otherwise, the
  // number of elements in this field must match the number of leaf buffers in
  // the HLO instruction's shape.
  repeated bool replicated_at_leaf_buffers = 1;
}

// A backend-config for kWhile loops that stores the loop's trip count, if it
// is known.
//
// This is useful for backends that can implement a `for i in 0..N` loop more
// efficiently than a `while` loop. For example, on GPUs, we can implement a
// `for i in 0..N` loop by enqueueing the kernels for the loop body N times,
// whereas implementing a `while` loop requires a host-device sync on each
// iteration.
message WhileLoopBackendConfig {
  message KnownTripCount {
    int64 n = 1;
  }
  // This indirection lets us distinguish between known-trip-count == 0 and
  // unknown-trip-count.
  KnownTripCount known_trip_count = 1;
}

// Specifies a pair of output/operand buffers for kCustomCall that alias each
// other.
message CustomCallOutputOperandAliasing {
  repeated int64 output_shape_index = 1;
  int64 operand_index = 2;
  repeated int64 operand_shape_index = 3;
}
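
// Illustrative example (not part of the schema): declaring that tuple element
// {0} of a custom call's output aliases the whole of operand 1 (shape index
// {}) could be written in textproto as:
//
//   output_shape_index: 0
//   operand_index: 1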