/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package xla;

option cc_enable_arenas = true;

// Primitive types are the individual values that can be held in rectangular
// multidimensional arrays. A description of the rectangular multidimensional
// array dimensions / primitive type is given by Shape, below.
//
// LINT.IfChange
enum PrimitiveType {
  // Invalid primitive type to serve as default.
  PRIMITIVE_TYPE_INVALID = 0;

  // Predicates are two-state booleans.
  PRED = 1;

  // Signed integral values of fixed width.
  S8 = 2;
  S16 = 3;
  S32 = 4;
  S64 = 5;

  // Unsigned integral values of fixed width.
  U8 = 6;
  U16 = 7;
  U32 = 8;
  U64 = 9;

  // Floating-point values of fixed width.
  //
  // Note: if f16s are not natively supported on the device, they will be
  // converted to f16 from f32 at arbitrary points in the computation.
  F16 = 10;
  F32 = 11;

  // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit
  // floating-point format, but uses 1 bit for the sign, 8 bits for the exponent
  // and 7 bits for the mantissa.
  BF16 = 16;

  F64 = 12;

  // Complex values of fixed width.
  C64 = 15;   // Paired F32 (real, imag), as in std::complex<float>.
  C128 = 18;  // Paired F64 (real, imag), as in std::complex<double>.

  // A tuple is a polymorphic sequence; e.g. a shape that holds different
  // sub-shapes. Tuples are used for things like returning multiple values from
  // a computation; e.g. a computation that returns weights and biases may have
  // a signature that results in a tuple like (f32[784x2000], f32[2000]).
  //
  // If a shape proto has the tuple element type, it may not have any entries
  // in the dimensions field.
  TUPLE = 13;

  // An opaque type used for passing context-specific data to a custom
  // operation. Shapes of this primitive type will have empty dimensions and
  // tuple_shapes fields.
  //
  // (OPAQUE would be a better name for this identifier, but that conflicts with
  // a macro defined in windows.h.)
  OPAQUE_TYPE = 14;

  // A token type threaded between side-effecting operations. Shapes of this
  // primitive type will have empty dimensions and tuple_shapes fields.
  TOKEN = 17;

  // Next = 19
}
// LINT.ThenChange(
//   https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc,
//   https://www.tensorflow.org/code/tensorflow/compiler/xla/tools/driver.cc
// )

// Describes the padding configuration for the Pad operation. The padding
// amount on both edges as well as between the elements is specified for each
// dimension.
message PaddingConfig {
  // Describes the padding configuration for a dimension.
  message PaddingConfigDimension {
    // Padding amount on the low-end (next to the index 0). May be negative.
    int64 edge_padding_low = 1;

    // Padding amount on the high-end (next to the highest index). May be
    // negative.
    int64 edge_padding_high = 2;

    // Padding amount between the elements. May not be negative.
    int64 interior_padding = 3;
  }

  // The padding configuration for all dimensions.
  repeated PaddingConfigDimension dimensions = 1;
}
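
// As a rough illustration (values chosen here only for exposition): a
// text-format PaddingConfig that pads a rank-1 f32[3] operand with one
// low-edge element, two high-edge elements, and one interior element between
// each pair of inputs, yielding f32[8]:
//
//   dimensions {
//     edge_padding_low: 1
//     edge_padding_high: 2
//     interior_padding: 1
//   }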

// A format specifies the method used by a layout to store an array in memory.
enum Format {
  // TODO(b/120869032): Rename this to FORMAT_NONE or something else which
  // better corresponds to its meaning.
  INVALID_FORMAT = 0;
  // The default layout, with exactly one storage location per element.
  DENSE = 1;
  reserved 2;
  reserved "SPARSE";
}

// Describes a tile used in tiling-based layout. Refer to
// g3doc/third_party/tensorflow/compiler/xla/g3doc/tiled_layout.md for
// details about tiling-based layout.
message TileProto {
  // Number of elements in each dimension of the tile. It's ordered from the
  // most major dimension of the tile to the most minor dimension of the tile.
  // The dimensions correspond to a suffix of the dimensions of the shape being
  // tiled.
  repeated int64 dimensions = 1;
}

// A layout describes how the array is placed in (1D) memory space.  This
// includes the minor-to-major ordering of dimensions within a shape.
//
// Clients must specify the layouts of input Literals to the
// computation. Layouts specified in interior operations which take Shapes (for
// example, Convert) are ignored.
//
// See the XLA documentation for more information on shapes and layouts.
//
// LINT.IfChange
message LayoutProto {
  // The method used to store the data in memory. The format determines which of
  // the other fields are used by the layout.
  Format format = 4;

  // Sequence of dimension numbers, from minor (fastest varying index) to major
  // (slowest varying index). This field is required.
  repeated int64 minor_to_major = 1;

  reserved 2;
  reserved "padded_dimensions";

  reserved 3;
  reserved "padding_value";

  reserved 5;
  reserved "max_sparse_elements";

  // A sequence of tiles, starting from the tile that's applied first to the
  // Shape.
  //
  // TODO(b/119839262): implement tiling in each backend or add Unimplemented
  // error.
  repeated TileProto tiles = 6;

  // Bit size of each element. If the size is bigger than what the element
  // type requires, the value is stored in the least significant
  // bits and the additional most significant bits are filled with 0's.
  //
  // TODO(b/119839262): implement in each backend or add Unimplemented error.
  int64 element_size_in_bits = 7;

  // Memory space where this array resides. The integer field is interpreted in
  // a backend-specific manner.
  int64 memory_space = 8;

  // Important: if any field is added, be sure to modify ShapeUtil::Equal() and
  // LayoutUtil::Hash appropriately to account for the new field.
}
// LINT.ThenChange( \
//     https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc,      \
//     https://www.tensorflow.org/code/tensorflow/compiler/xla/layout_util.cc)
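
// As a rough illustration (a hypothetical example, not taken from any real
// module): a text-format LayoutProto describing a row-major rank-2 array,
// i.e. the last dimension varies fastest, with no tiling:
//
//   format: DENSE
//   minor_to_major: 1
//   minor_to_major: 0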

// A shape describes the number of dimensions in the array, the size of each
// dimension, and the primitive component type.
//
// Tuples are a special case in that they have rank zero and have tuple_shapes
// defined.
//
// See the XLA documentation for more information on shapes and layouts.
//
// LINT.IfChange
message ShapeProto {
  reserved 1;
  reserved "rank";

  // The element type for this shape.
  PrimitiveType element_type = 2;

  // The size (number of elements) for each dimension, or an upper bound on the
  // size if the dimension is dynamic.  In XLA, dimensions are numbered from 0
  // to N-1 for an N-dimensional array. The first element of 'dimensions' is the
  // size of dimension 0, the second element is the size of dimension 1, and so
  // forth.  An empty list indicates a scalar.
  //
  // If the respective element in 'is_dynamic_dimension' is true then the value
  // in this field represents an upper bound on the size of the dimension.
  repeated int64 dimensions = 3;

  // For tuples only, the shapes of the constituent elements in the tuple
  // sequence.
  repeated ShapeProto tuple_shapes = 4;

  // The layout used to back this shape.
  LayoutProto layout = 5;

  // For arrays, this indicates whether or not each dimension is
  // dynamically-sized. The number of elements in this repeated field should be
  // zero (indicating that no dimensions are dynamic) or equal to the number of
  // elements in the 'dimensions' field.
  repeated bool is_dynamic_dimension = 6;

  // Important: if any field is added, be sure to modify ShapeUtil::Equal(),
  // ShapeUtil::Compatible() and ShapeUtil::Hash() appropriately to account for
  // the new field.
}
// LINT.ThenChange( \
//     https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc)
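
// As a rough illustration (hypothetical values): a text-format ShapeProto for
// an f32 array of shape [2, <=3], where dimension 1 is dynamic with an upper
// bound of 3 and the layout is row major:
//
//   element_type: F32
//   dimensions: 2
//   dimensions: 3
//   layout { format: DENSE minor_to_major: 1 minor_to_major: 0 }
//   is_dynamic_dimension: false
//   is_dynamic_dimension: true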

// Shape of the parameters and output of a computation (like a traditional
// function signature).
message ProgramShapeProto {
  repeated ShapeProto parameters = 1;
  ShapeProto result = 2;
  repeated string parameter_names = 3;
}

// Statistics of a computation.
message ComputationStats {
  // The number of floating point operations in the computation.
  double flop_count = 1;

  // The number of transcendental operations (e.g., exp) in the computation.
  double transcendental_count = 2;
}

// The type of optimization profiles in use.
enum ProfileType {
  INVALID = 0;
  WINDOW = 1;
  FLAG = 2;
  INTEGER = 3;
}

// Symbolization metadata for HLO Instructions.
//
// This metadata is used for debugging XLA code generation, as well as
// performance profiling of XLA-generated executables.
message OpMetadata {
  // The framework op name that generated this XLA op.
  //
  // Frameworks that build on top of XLA should mirror the names of their ops
  // back to users by specifying the op_type. In this way, even if the
  // framework's "ops" are implemented as multiple XLA HLO Ops, they can be
  // grouped appropriately. (e.g. if a SoftMax layer is emitted into XLA as
  // multiple ops, then each op should have the op_type be "SoftMax".)
  string op_type = 1;
  // The user-specified name of the op.
  //
  // This name is often unique within a computation. Note: some frameworks
  // add auto-generated names if the user does not provide one.
  string op_name = 2;
  // Indicates a file and line that this op is associated with in a user's
  // program.
  //
  // e.g. it could be the file and line of user code that generated the op.
  string source_file = 3;
  int32 source_line = 4;

  repeated ProfileType profile_type = 5;

  // HloPassMetadata.pass_id of the pass that created this HLO instruction
  // object. Should never be copied between HLO instructions. Zero if unset and
  // -1 if the instruction was created before HLO passes began.
  int64 creation_pass_id = 6;

  // HloPassMetadata.pass_id of the pass that created the logical functionality
  // that this HLO instruction represents. Should be copied between HLO
  // instructions that correspond across compilation passes. Zero if unset and
  // -1 if the instruction was created before HLO passes began.
  int64 logical_creation_pass_id = 7;

  // The footprint of the generated code for the instruction.
  int64 size_of_generated_code_in_bytes = 8;
  // The size of the working set, i.e., the amount of memory, used by the
  // instruction in a compiler-managed fast device memory.
  int64 size_of_memory_working_set_in_bytes = 9;
}
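
// As a rough illustration (all strings below are hypothetical): a text-format
// OpMetadata that a framework might attach to one of the several HLO ops
// emitted for a single "SoftMax" layer:
//
//   op_type: "SoftMax"
//   op_name: "model/dense_1/SoftMax"
//   source_file: "model.py"
//   source_line: 42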

// Profile data from the execution of a computation.
message ExecutionProfile {
  // Whether the executable was read from the compilation cache.
  bool compilation_cache_hit = 1;

  // The time in milliseconds spent to compile the computation. This is only
  // set if the executable was not read from the compilation cache
  // (compilation_cache_hit == false).
  int64 compile_time_ms = 2;

  // The number of cycles spent for the computation. This does not include the
  // time taken for the data transfers between the host and the device. This is
  // a target-dependent field and only used for debugging purposes.
  int64 compute_cycle_count = 3;

  // The time in nanoseconds spent for the computation, without data transfer.
  int64 compute_time_ns = 4;

  // The time in nanoseconds spent for the entire computation, including the
  // result data transfer time. Current implementation does not spend any cycles
  // for the input data transfer since the memory is initialized with the proper
  // values before the execution.
  int64 compute_and_transfer_time_ns = 5;

  // The size of the binary code in the executable.
  int64 executable_size_in_bytes = 6;

  // Whether this profile was drawn from a cache of profiles instead of from
  // execution on the hardware.
  bool profile_cache_hit = 7;
}

// Handle given to a user that represents an execution that the user launched
// asynchronously on the device.
message ExecutionHandle {
  int64 handle = 1;
}

// Handle given to a user that represents a globally accessible allocation.
// Contrast this against a ComputationDataHandle, which is not globally
// accessible, since it only exists within a specific computation.
message GlobalDataHandle {
  int64 handle = 1;
}

// Handle given to a user that represents a replicated virtual device. Each
// replicated device represents N physical devices for execution where N is the
// number of replicas.
message DeviceHandle {
  int64 handle = 1;

  // The number of model-parallel virtual devices that communicate via XLA
  // Send/Recv instructions.
  int64 device_count = 2;
}

// Handle given to a user to represent a channel between two computations
// via a Send and Recv instruction pair. Channels are unbuffered, so Send
// instructions will be blocked until the data is transferred.
message ChannelHandle {
  int64 handle = 1;
  enum ChannelType {
    // Invalid primitive type to serve as default.
    CHANNEL_TYPE_INVALID = 0;

    // A channel for sending data between devices.
    DEVICE_TO_DEVICE = 1;

    // A channel for sending data from the device to the host. Can only be used
    // with a Send operation.
    DEVICE_TO_HOST = 2;

    // A channel for sending data from the host to the device. Can only be used
    // with a Recv operation.
    HOST_TO_DEVICE = 3;
  }
  ChannelType type = 2;
}

// DeviceAssignmentProto is a serialized form of the DeviceAssignment class,
// which represents the device ids assigned to a set of replicated
// computations. See the xla::DeviceAssignment class comment for more details.
message DeviceAssignmentProto {
  int32 replica_count = 1;
  int32 computation_count = 2;

  // Each logical computation runs on replica_count physical devices.
  // ComputationDevice represents the device ids assigned to the replicas.
  message ComputationDevice {
    repeated int32 replica_device_ids = 1;
  }
  repeated ComputationDevice computation_devices = 3;
}
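
// As a rough illustration (device ids are hypothetical): a text-format
// DeviceAssignmentProto for 2 computations, each replicated across 2 devices;
// computation 0 runs on devices {0, 1} and computation 1 on devices {2, 3}:
//
//   replica_count: 2
//   computation_count: 2
//   computation_devices { replica_device_ids: 0 replica_device_ids: 1 }
//   computation_devices { replica_device_ids: 2 replica_device_ids: 3 }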

// Literals are used when the server and client need to exchange materialized
// data / results. Literals are also used to describe constants used in
// computations.
//
// Transfers to/from the client are encoded in literal form, and the structure
// of the repeated fields is implied by the shape.
message LiteralProto {
  ShapeProto shape = 1;
  repeated bool preds = 2;
  bytes s8s = 15;
  bytes u8s = 3;
  repeated int32 s32s = 4;
  repeated int64 s64s = 5;
  repeated uint32 u32s = 6;
  repeated uint64 u64s = 7;
  repeated float f32s = 8;
  repeated double f64s = 9;
  repeated float c64s = 12;    // Stored as interleaved real, imag floats.
  repeated double c128s = 18;  // Stored as interleaved real, imag doubles.
  repeated LiteralProto tuple_literals = 10;
  // The F16s, BF16s, U16s and S16s are encoded in little endian byte order
  bytes f16s = 11;
  bytes bf16s = 13;
  bytes u16s = 16;
  bytes s16s = 17;
  repeated int64 sparse_indices = 14;
  // Next = 19
}
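
// As a rough illustration (values are hypothetical): a text-format
// LiteralProto for a row-major f32[2, 2] constant; the element data is
// carried in the repeated f32s field:
//
//   shape {
//     element_type: F32
//     dimensions: 2
//     dimensions: 2
//     layout { format: DENSE minor_to_major: 1 minor_to_major: 0 }
//   }
//   f32s: 1
//   f32s: 2
//   f32s: 3
//   f32s: 4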

message WindowDimension {
  // The size of the window in this dimension. For a rectangle, this would be
  // the width or height.
  int64 size = 1;

  // The stride at which the window moves across the base area in this
  // dimension. In other words, this is the spacing between different
  // positions of the window in this dimension.
  int64 stride = 2;

  // If positive, the amount of padding to add to the base area at the low end
  // of this dimension; if negative, the absolute value is the number of
  // elements removed from the low end of this dimension. For example, in the
  // horizontal dimension of a rectangle, this would be the number of padding
  // values to pad on the left, given that indices increase when going right.
  // The actual padding value depends upon the context. Convolution pads with
  // zeros. ReduceWindow and SelectAndScatter pad with the reduce function's
  // init value.
  int64 padding_low = 3;

  // As padding_low, but on the high end of this dimension. For example, in the
  // horizontal dimension of a rectangle, this would be the number of values to
  // pad on the right, given that indices increase when going right.
  int64 padding_high = 4;

  // Dilation factor of the sliding window in this dimension. A dilation factor
  // of 1 means no dilation. window_dilation - 1 no-op entries ("holes") are
  // implicitly placed between each kernel element. This value may not be less
  // than 1. See documentation for convolution.
  int64 window_dilation = 5;

  // Dilation factor of the base area in this dimension. A dilation factor of 1
  // means no dilation. base_dilation - 1 no-op entries ("holes") are implicitly
  // placed between each base area element. This value may not be less than 1.
  // See documentation for convolution.
  int64 base_dilation = 6;

  // Window reversal means that this dimension was logically reversed before the
  // operation.
  bool window_reversal = 7;
}

// Describes the windowing in an operation such as convolution.
//
// The window is moved across a base area and for each position of the
// window a computation is performed. The field below describes the
// window and the movement of the window across a base area.
message Window {
  repeated WindowDimension dimensions = 1;
}
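
// As a rough illustration (hypothetical values): a text-format Window for a
// 2D convolution with a 3x3 kernel, stride 2, one element of padding on each
// edge, and no dilation (one WindowDimension per spatial dimension):
//
//   dimensions {
//     size: 3 stride: 2 padding_low: 1 padding_high: 1
//     window_dilation: 1 base_dilation: 1
//   }
//   dimensions {
//     size: 3 stride: 2 padding_low: 1 padding_high: 1
//     window_dilation: 1 base_dilation: 1
//   }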

// Describes the dimension numbers for a gather operation.
//
// See https://www.tensorflow.org/performance/xla/operation_semantics#gather for
// more details.
message GatherDimensionNumbers {
  // "Window indices" is a term for a set of indices that index into the
  // interior of a dynamic-slice from the input tensor, the starting indices for
  // which were computed from output_gather_dims (see the operation semantic for
  // how this is defined) and the start_indices tensor.
  //
  // The window indices for a specific output index Out are computed as:
  //
  //  i = 0
  //  for (k : [0, input_tensor_shape.rank))
  //    window_indices[k] =
  //      if k in collapsed_slice_dims
  //      then 0
  //      else Out[offset_dims[i++]]
  repeated int64 offset_dims = 1;
  repeated int64 collapsed_slice_dims = 2;

  // This is interpreted as a map from i to start_index_map[i]. It
  // transforms the gather index looked up from the start_indices tensor into
  // the starting index in the input space.
  repeated int64 start_index_map = 3;

  // The dimension in the start_indices input that contains the starting
  // indices.
  int64 index_vector_dim = 4;
}
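
// As a rough illustration (shapes are hypothetical): gathering whole rows of
// an operand f32[N, D] with scalar row indices s32[B] (slice sizes [1, D]),
// producing an output f32[B, D], could use dimension numbers like the
// following; index_vector_dim equals the rank of start_indices, which marks
// the indices as scalars, and the size-1 sliced dimension 0 is collapsed away:
//
//   offset_dims: 1
//   collapsed_slice_dims: 0
//   start_index_map: 0
//   index_vector_dim: 1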

// Describes the dimension numbers for a scatter operation.
//
// All the fields are similar to the corresponding fields in
// GatherDimensionNumbers. Differences are noted below.
message ScatterDimensionNumbers {
  // The set of dimensions in the updates shape that are window dimensions.
  repeated int64 update_window_dims = 1;
  // The set of window dimensions that must be inserted into the updates shape.
  repeated int64 inserted_window_dims = 2;

  repeated int64 scatter_dims_to_operand_dims = 3;
  int64 index_vector_dim = 4;
}
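
// As a rough illustration (shapes are hypothetical), mirroring the gather
// sketch above: scattering updates f32[B, D] into an operand f32[N, D] using
// scalar row indices s32[B] could use dimension numbers like:
//
//   update_window_dims: 1
//   inserted_window_dims: 0
//   scatter_dims_to_operand_dims: 0
//   index_vector_dim: 1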

message ConvolutionDimensionNumbers {
  // The number of the dimension that represents batch in the input.
  int64 input_batch_dimension = 7;

  // The number of the dimension that represents features in the input.
  int64 input_feature_dimension = 8;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the input.
  repeated int64 input_spatial_dimensions = 11;

  // The number of the dimension that represents input features in the
  // convolutional kernel (rhs).
  int64 kernel_input_feature_dimension = 3;

  // The number of the dimension that represents output features in
  // the convolutional kernel (rhs).
  int64 kernel_output_feature_dimension = 4;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the kernel (rhs). window.strides(0) is the
  // stride in the kernel_spatial_dimensions(0) dimension.
  repeated int64 kernel_spatial_dimensions = 6;

  // The number of the dimension that represents batch in the output.
  int64 output_batch_dimension = 9;

  // The number of the dimension that represents features in the output.
  int64 output_feature_dimension = 10;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the output.
  repeated int64 output_spatial_dimensions = 12;

  // Next = 13
}
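
// As a rough illustration: a text-format ConvolutionDimensionNumbers for the
// common NHWC input / HWIO kernel / NHWC output convention (these orderings
// are assumed for the example, not mandated by XLA):
//
//   input_batch_dimension: 0
//   input_feature_dimension: 3
//   input_spatial_dimensions: 1
//   input_spatial_dimensions: 2
//   kernel_input_feature_dimension: 2
//   kernel_output_feature_dimension: 3
//   kernel_spatial_dimensions: 0
//   kernel_spatial_dimensions: 1
//   output_batch_dimension: 0
//   output_feature_dimension: 3
//   output_spatial_dimensions: 1
//   output_spatial_dimensions: 2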

enum PaddingType {
  PADDING_INVALID = 0;
  PADDING_VALID = 1;  // Only the valid portion of the base is covered.
  PADDING_SAME = 2;   // Padding is added to produce the same output size as
                      // the input.
}

enum FftType {
  FFT = 0;    // Forward FFT; complex in, complex out.
  IFFT = 1;   // Inverse FFT; complex in, complex out.
  RFFT = 2;   // Forward real FFT; real in, fft_length / 2 + 1 complex out
  IRFFT = 3;  // Inverse real FFT; fft_length / 2 + 1 complex in,
              //                   fft_length real out
}

message DotDimensionNumbers {
  // The dimension numbers that represent the 'lhs' contracting dimensions.
  repeated int64 lhs_contracting_dimensions = 1;
  // The dimension numbers that represent the 'rhs' contracting dimensions.
  repeated int64 rhs_contracting_dimensions = 2;
  // The dimension numbers that represent the 'lhs' batch dimensions.
  repeated int64 lhs_batch_dimensions = 3;
  // The dimension numbers that represent the 'rhs' batch dimensions.
  repeated int64 rhs_batch_dimensions = 4;
}
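
// As a rough illustration (shapes are hypothetical): a batched matrix
// multiplication of lhs f32[B, M, K] with rhs f32[B, K, N] contracts the K
// dimensions and treats dimension 0 of both operands as a batch dimension:
//
//   lhs_contracting_dimensions: 2
//   rhs_contracting_dimensions: 1
//   lhs_batch_dimensions: 0
//   rhs_batch_dimensions: 0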

enum RandomDistribution {
  RNG_INVALID = 0;

  // Creates a uniform-distribution-generated random number on the semi-open
  // interval [parameter[0], parameter[1]).
  RNG_UNIFORM = 1;

  // Creates a normal-distribution-generated random number with mean
  // parameter[0] and standard deviation parameter[1].
  RNG_NORMAL = 2;

  // Next: 4
}

enum RandomAlgorithm {
  RNG_DEFAULT = 0;  // Backend dependent default algorithm.
  RNG_THREE_FRY = 1;
  RNG_PHILOX = 2;
  // Next: 3
}

message TriangularSolveOptions {
  // If true, solves ax = b. If false, solves xa = b.
  bool left_side = 1;

  // If true, 'a' is lower triangular. If false, 'a' is upper triangular.
  bool lower = 2;

  // If true, the diagonal elements of 'a' are assumed to be 1 and not accessed.
  bool unit_diagonal = 3;

  // Should we transpose or use the adjoint of 'a'?
  enum Transpose {
    TRANSPOSE_INVALID = 0;
    NO_TRANSPOSE = 1;  // Don't transpose 'a'.
    TRANSPOSE = 2;     // Transpose 'a'.
    ADJOINT = 3;       // Complex conjugate and transpose 'a'.
  }
  Transpose transpose_a = 4;
}

message CholeskyOptions {
  // If true, uses the lower triangle of `a`. If false, uses the upper triangle
  // of `a`.
  bool lower = 1;
}

// Generic map of attributes used to pass hints / configuration options from
// the Python frontend to the XLA backend.
message FrontendAttributes {
  map<string, string> map = 1;
}

message OpSharding {
  enum Type {
    // This sharding is replicated across all devices (implies maximal,
    // all other fields are unused).
    REPLICATED = 0;
    // This sharding is maximal - one device runs the entire operation.
    MAXIMAL = 1;
    // This sharding is a tuple - only the tuple_shardings field is valid.
    TUPLE = 2;
    // None of the above; tile_shape and tile_assignment are both used.
    OTHER = 3;
    // This op is manually sharded: the shapes are already partitioned and the
    // partitioner should not change this op.
    MANUAL = 4;
  }
  Type type = 1;
  // The shape of the sharded tile.
  ShapeProto tile_shape = 2;
  // The shape of the tile assignment tensor - this must be the same rank as
  // tile_shape and the product of its dimensions must equal
  // tile_assignment_devices.size().
  repeated int64 tile_assignment_dimensions = 3;
  // Flattened list of device IDs. The order of flattening is the same as used
  // by IndexUtil::MultiToLinearIndex(tile_assignment_shape).
  repeated int64 tile_assignment_devices = 4;
  // If type == TUPLE, the sub-shardings, one per leaf node in the tuple shape,
  // in pre-order. The tuple shape could be nested; here we store just a
  // flattened list of all leaves in the tuple shape. Note that the tuple shape
  // is not stored here; shardings do not store the shapes to which they are
  // applied; this is inferred from the instruction this sharding gets attached
  // to.
  repeated OpSharding tuple_shardings = 5;

  // Only used for OTHER type. If true, data is sharded according to other
  // dimensions of tile_assignment(), but replicated across devices along the
  // last dimension. (Experimental)
  bool replicate_on_last_tile_dim = 6;
  // This field is used to track the source of this sharding, usually derived
  // from instructions. Multiple metadata may be populated if sharding is
  // combined with other shardings.  Metadata should not be populated when
  // type == TUPLE; instead, metadata should be set on the individual tuple
  // elements.
  repeated OpMetadata metadata = 7;

  // This field is used to represent the sharding type of each subgroup.
  // For example, sharding={devices=[2,2,2,2]0,1,2,...,15 last_tile_dims={
  // replicate, manual, unreduced}} means that each of the last 3 dimensions
  // in [2,2,2,2] represents a subgrouping in replicate, manual,
  // unreduced sharding type respectively.
  repeated Type last_tile_dims = 8;
}
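
// As a rough illustration (device ids are hypothetical): a text-format
// OpSharding that splits an array across two devices along dimension 0,
// keeping dimension 1 unsplit (device 0 gets the first half of dim 0,
// device 1 the second half):
//
//   type: OTHER
//   tile_assignment_dimensions: 2
//   tile_assignment_dimensions: 1
//   tile_assignment_devices: 0
//   tile_assignment_devices: 1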

// Describes the replica groups in a cross replica op (e.g., all-reduce and
// all-to-all).
message ReplicaGroup {
  // The ids of the replicas that belong to the same group. The ordering of the
  // ids matters in some ops (e.g., all-to-all).
  repeated int64 replica_ids = 1;
}
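
// As a rough illustration (replica ids are hypothetical): an all-reduce over
// four replicas split into two independent groups, {0, 1} and {2, 3}, would
// be described by two ReplicaGroup messages, in text format:
//
//   { replica_ids: 0 replica_ids: 1 }
//   { replica_ids: 2 replica_ids: 3 }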

// Describes the source-target pair in the collective permute op.
message SourceTarget {
  int64 source = 1;
  int64 target = 2;
}

// Used to indicate the precision configuration. It has backend-specific
// meaning.
message PrecisionConfig {
  enum Precision {
    DEFAULT = 0;
    HIGH = 1;
    HIGHEST = 2;

    // Next: 3
  }
  repeated Precision operand_precision = 1;

  // Next: 2
}

// Describes whether all data-parallelism replicas will receive the same
// parameter data at each buffer.
message ParameterReplication {
  // A list of boolean values for the flattened leaf buffers. Each value
  // indicates whether the corresponding leaf buffer is replicated.
  //
  // If this field is empty, it means no buffer is replicated. Otherwise, the
  // number of elements in this field must match the number of leaf buffers in
  // the HLO instruction's shape.
  repeated bool replicated_at_leaf_buffers = 1;
}

// A backend-config for kWhile loops that stores the loop's trip count, if it is
// known.
//
// This is useful for backends that can implement a `for i in 0..N` loop more
// efficiently than a `while` loop.  For example, on GPUs, we can implement a
// `for i in 0..N` loop by enqueueing the kernels for the loop body N times,
// whereas implementing a `while` loop requires a host-device sync on each
// iteration.
message WhileLoopBackendConfig {
  message KnownTripCount {
    int64 n = 1;
  }
  // This indirection lets us distinguish between known-trip-count == 0 and
  // unknown-trip-count.
  KnownTripCount known_trip_count = 1;
}

// Specifies a pair of output/operand buffers for kCustomCall that alias each
// other.
message CustomCallOutputOperandAliasing {
  repeated int64 output_shape_index = 1;
  int64 operand_index = 2;
  repeated int64 operand_shape_index = 3;
}
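
// As a rough illustration (indices are hypothetical): an aliasing entry
// stating that element {1} of a kCustomCall's tuple output shares its buffer
// with operand 0 at the operand's root (i.e. an empty operand_shape_index):
//
//   output_shape_index: 1
//   operand_index: 0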