/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package xla;

option cc_enable_arenas = true;

// Primitive types are the individual values that can be held in rectangular
// multidimensional arrays. A description of the rectangular multidimensional
// array dimensions / primitive type is given by Shape, below.
//
// LINT.IfChange
enum PrimitiveType {
  // Invalid primitive type to serve as default.
  PRIMITIVE_TYPE_INVALID = 0;

  // Predicates are two-state booleans.
  PRED = 1;

  // Signed integral values of fixed width.
  S8 = 2;
  S16 = 3;
  S32 = 4;
  S64 = 5;

  // Unsigned integral values of fixed width.
  U8 = 6;
  U16 = 7;
  U32 = 8;
  U64 = 9;

  // Floating-point values of fixed width.
  //
  // Note: if f16s are not natively supported on the device, they will be
  // converted to f16 from f32 at arbitrary points in the computation.
  F16 = 10;
  F32 = 11;

  // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit
  // floating-point format, but uses 1 bit for the sign, 8 bits for the exponent
  // and 7 bits for the mantissa.
  BF16 = 16;

  F64 = 12;

  // Complex values of fixed width.
  C64 = 15;   // Paired F32 (real, imag), as in std::complex<float>.
  C128 = 18;  // Paired F64 (real, imag), as in std::complex<double>.

  // A tuple is a polymorphic sequence; e.g. a shape that holds different
  // sub-shapes. They are used for things like returning multiple values from a
  // computation; e.g. a computation that returns weights and biases may have a
  // signature that results in a tuple like (f32[784x2000], f32[2000])
  //
  // If a shape proto has the tuple element type, it may not have any entries
  // in the dimensions field.
  TUPLE = 13;

  // An opaque type used for passing context-specific data to a custom
  // operation. Shapes of this primitive type will have empty dimensions and
  // tuple_shapes fields.
  //
  // (OPAQUE would be a better name for this identifier, but that conflicts with
  // a macro defined in windows.h.)
  OPAQUE_TYPE = 14;

  // A token type threaded between side-effecting operations. Shapes of this
  // primitive type will have empty dimensions and tuple_shapes fields.
  TOKEN = 17;

  // Next = 19
}
// LINT.ThenChange(
//   https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc,
//   https://www.tensorflow.org/code/tensorflow/compiler/xla/tools/driver.cc
// )

// Describes the padding configuration for the Pad operation. The padding
// amount on both edges as well as between the elements is specified for each
// dimension.
message PaddingConfig {
  // Describes the padding configuration for a dimension.
  message PaddingConfigDimension {
    // Padding amount on the low-end (next to the index 0). May be negative.
    int64 edge_padding_low = 1;

    // Padding amount on the high-end (next to the highest index). May be
    // negative.
    int64 edge_padding_high = 2;

    // Padding amount between the elements. May not be negative.
    int64 interior_padding = 3;
  }

  // The padding configuration for all dimensions.
  repeated PaddingConfigDimension dimensions = 1;
}
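
// Illustrative example: applied to a 1-D array [a, b, c, d], a PaddingConfig
// with a single dimension
//   { edge_padding_low: 1, edge_padding_high: 2, interior_padding: 1 }
// yields [p, a, p, b, p, c, p, d, p, p], where p is the padding value supplied
// to the Pad operation.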

// A DimLevelType indicates the encoding method for a dimension in an array.
// The semantics of this field are identical to those of the MLIR SparseTensor
// dialect.
// This should be kept in sync with the SparseTensor DimLevelType enum:
// https://github.com/llvm/llvm-project/blob/5674a3c88088e668b684326c2194a6282e8270ff/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td#L86
enum DimLevelType {
  // The corresponding dimension is Dense, every entry is stored.
  DIM_DENSE = 0;
  // The corresponding dimension is Compressed, only nonzeros are stored.
  DIM_COMPRESSED = 1;
  // The corresponding dimension contains a single coordinate, no sibling
  // elements for each parent.
  DIM_SINGLETON = 2;
}

// Describes a tile used in tiling-based layout. Refer to
// g3doc/third_party/tensorflow/compiler/xla/g3doc/tiled_layout.md for
// details about tiling-based layout.
message TileProto {
  // Number of elements in each dimension of the tile. It's ordered from the
  // most major dimension of the tile to the most minor dimension of the tile.
  // The dimensions correspond to a suffix of the dimensions of the shape being
  // tiled.
  repeated int64 dimensions = 1;
}
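
// Illustrative example: a TileProto with dimensions [8, 128] describes 8x128
// tiles applied to the two most minor dimensions of the shape being tiled.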

// A layout describes how the array is placed in (1D) memory space.  This
// includes the minor-to-major ordering of dimensions within a shape.
//
// Clients must specify the layouts of input Literals to the
// computation. Layouts specified in interior operations which take Shapes (for
// example, Convert) are ignored.
//
// See the XLA documentation for more information on shapes and layouts.
//
// LINT.IfChange
message LayoutProto {
  // The dimension level type list for this array, specifying the way in which
  // each array dimension is represented in memory. If this list is empty, the
  // array is assumed to be dense.
  repeated DimLevelType dim_level_types = 9;

  // Sequence of dimension numbers, from minor (fastest varying index) to major
  // (slowest varying index). This field is required.
  repeated int64 minor_to_major = 1;

  // A sequence of tiles, starting from the tile that's applied first to the
  // Shape.
  //
  // TODO(b/119839262): implement tiling in each backend or add Unimplemented
  // error.
  repeated TileProto tiles = 6;

  // Bit size of each element. If the size is bigger than what the element
  // type requires, the value is stored in the least significant
  // bits and the additional most significant bits are filled with 0's.
  //
  // TODO(b/119839262): implement in each backend or add Unimplemented error.
  int64 element_size_in_bits = 7;

  // Memory space where this array resides. The integer field is interpreted in
  // a backend-specific manner.
  int64 memory_space = 8;

  // Important: if any field is added, be sure to modify ShapeUtil::Equal() and
  // LayoutUtil::Hash appropriately to account for the new field.

  reserved 2;
  reserved "padded_dimensions";
  reserved 3;
  reserved "padding_value";
  reserved 4;
  reserved "format";
  reserved 5;
  reserved "max_sparse_elements";
}
// LINT.ThenChange( \
//     https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc,      \
//     https://www.tensorflow.org/code/tensorflow/compiler/xla/layout_util.cc)
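
// Illustrative example: for a 2-D array, minor_to_major: [1, 0] means that
// dimension 1 varies fastest in memory (row-major storage), while
// minor_to_major: [0, 1] means dimension 0 varies fastest (column-major).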

// A shape describes the number of dimensions in the array, the size of each
// dimension, and the primitive component type.
//
// Tuples are a special case in that they have rank zero and have tuple_shapes
// defined.
//
// See the XLA documentation for more information on shapes and layouts.
//
// LINT.IfChange
message ShapeProto {
  reserved 1;
  reserved "rank";

  // The element type for this shape.
  PrimitiveType element_type = 2;

  // The size (number of elements) for each dimension, or an upper bound on the
  // size if the dimension is dynamic.  In XLA, dimensions are numbered from 0
  // to N-1 for an N-dimensional array. The first element of 'dimensions' is the
  // size of dimension 0, the second element is the size of dimension 1, and so
  // forth.  Empty list indicates a scalar.
  //
  // If the respective element in 'is_dynamic_dimension' is true then the value
  // in this field represents an upper bound on the size of the dimension.
  repeated int64 dimensions = 3;

  // For tuples only, the shapes of the constituent elements in the tuple
  // sequence.
  repeated ShapeProto tuple_shapes = 4;

  // The layout used to back this shape.
  LayoutProto layout = 5;

  // For arrays, this indicates whether or not each dimension is
  // dynamically-sized. The number of elements in this repeated field should be
  // zero (indicating that no dimensions are dynamic) or equal to the number of
  // elements in the 'dimensions' field.
  repeated bool is_dynamic_dimension = 6;

  // Important: if any field is added, be sure to modify ShapeUtil::Equal(),
  // ShapeUtil::Compatible() and ShapeUtil::Hash() appropriately to account for
  // the new field.
}
// LINT.ThenChange( \
//     https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc)
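
// Illustrative example: an f32[2,3] array with the default (row-major-style)
// layout could be written in text format roughly as:
//   element_type: F32
//   dimensions: [2, 3]
//   layout { minor_to_major: [1, 0] }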

// Shape of the parameters and output of a computation (like a traditional
// function signature).
message ProgramShapeProto {
  repeated ShapeProto parameters = 1;
  ShapeProto result = 2;
  repeated string parameter_names = 3;
}
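
// Illustrative example: a computation taking one f32[10] parameter named "x"
// and returning an f32 scalar would have one entry in 'parameters', the
// corresponding name "x" in 'parameter_names', and a scalar F32 'result'.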

// Statistics of a computation.
message ComputationStats {
  // The number of floating point operations in the computation.
  double flop_count = 1;

  // The number of transcendental operations (e.g., exp) in the computation.
  double transcendental_count = 2;
}

// The type of optimization profiles in use for Op-level optimizations.
enum ProfileType {
  INVALID = 0;
  WINDOW = 1;
  FLAG = 2;
  INTEGER = 3;
}

// The source of the optimization profile.
enum ProfileSource {
  PROFILE_SOURCE_UNKNOWN_SOURCE = 0;
  PROFILE_SOURCE_EMBEDDED = 1;
  PROFILE_SOURCE_REMOTE = 2;
}

// The compilation event that triggered the use of the profile.
enum CompilationEvent {
  COMPILATION_EVENT_UNKNOWN_EVENT = 0;
  COMPILATION_EVENT_FIRST_COMPILATION = 1;
  COMPILATION_EVENT_RECOMPILATION = 2;
}

// Symbolization metadata for HLO Instructions.
//
// This metadata is used for debugging XLA code generation, as well as
// performance profiling of XLA-generated executables.
message OpMetadata {
  // The framework op name that generated this XLA op.
  //
  // Frameworks that build on top of XLA should mirror the names of their ops
  // back to users by specifying the op_type. In this way, even if the
  // framework's "ops" are implemented as multiple XLA HLO Ops, they can be
  // grouped appropriately. (e.g. if a SoftMax layer is emitted into XLA as
  // multiple ops, then each op should have the op_type be "SoftMax".)
  string op_type = 1;
  // The user-specified name of the op.
  //
  // This name is often unique within a computation. Note: some frameworks
  // add auto-generated names if the user does not provide one.
  string op_name = 2;
  // Indicates the file and line that this op is associated with in a user's
  // program.
  //
  // e.g. it could be the file and line of user code that generated the op.
  string source_file = 3;
  int32 source_line = 4;

  // Deprecated, use [ProfileInfo][profile_type] instead.
  repeated ProfileType profile_type = 5 [deprecated = true];

  // HloPassMetadata.pass_id of the pass that created this HLO instruction
  // object. Should never be copied between HLO instructions. Zero if unset and
  // -1 if the instruction was created before HLO passes began.
  int64 creation_pass_id = 6;

  // HloPassMetadata.pass_id of the pass that created the logical functionality
  // that this HLO instruction represents. Should be copied between HLO
  // instructions that correspond across compilation passes. Zero if unset and
  // -1 if the instruction was created before HLO passes began.
  int64 logical_creation_pass_id = 7;

  // The footprint of the generated code for the instruction.
  int64 size_of_generated_code_in_bytes = 8;
  // The size of the working set, i.e., the amount of memory, used by the
  // instruction in a compiler-managed fast device memory.
  int64 size_of_memory_working_set_in_bytes = 9;

  // Information about the optimization profile that this operation contains.
  message ProfileInfo {
    // The type of optimization profiles that this operation contains.
    repeated ProfileType profile_type = 1;
    // Speedup of tuned config compared to default config.
    // TODO(b/203817882) Set the relative_speedup.
    double relative_speedup = 2;
    // The source of the optimization profiles that this operation contains.
    ProfileSource profile_source = 3;
    // The compilation event that triggered the use of the profiles.
    CompilationEvent compilation_event = 4;
  }

  // Profile information for the Op.
  ProfileInfo profile_info = 10;

  // Information about the replaced operation. We store the HLO canonical text,
  // which contains the canonicalized instruction name, operand shapes, and full
  // bodies of the subcomputations besides the HLO itself. This is set
  // independently of other fields.
  string replaced_op = 11;
}

// Profile data from the execution of a computation.
message ExecutionProfile {
  // Whether the executable was read from the compilation cache.
  bool compilation_cache_hit = 1;

  // The time in milliseconds spent to compile the computation. This is only set
  // if the executable was not read from the compilation cache
  // (compilation_cache_hit == false).
  int64 compile_time_ms = 2;

  // The number of cycles spent for the computation. This does not include the
  // time taken for the data transfers between the host and the device. This is
  // a target-dependent field and only used for debugging purposes.
  int64 compute_cycle_count = 3;

  // The time in nanoseconds spent for the computation, without data transfer.
  int64 compute_time_ns = 4;

  // The time in nanoseconds spent for the entire computation, including the
  // result data transfer time. Current implementation does not spend any cycles
  // for the input data transfer since the memory is initialized with the proper
  // values before the execution.
  int64 compute_and_transfer_time_ns = 5;

  // The size of the binary code in the executable.
  int64 executable_size_in_bytes = 6;

  // Whether this profile was drawn from a cache of profiles instead of from
  // execution on the hardware.
  bool profile_cache_hit = 7;
}

// Handle given to a user that represents an execution that the user launched
// asynchronously on the device.
message ExecutionHandle {
  int64 handle = 1;
}

// Handle given to a user that represents a globally accessible allocation.
// Contrast this against a ComputationDataHandle, which is not globally
// accessible, since it only exists within a specific computation.
message GlobalDataHandle {
  int64 handle = 1;
}

// Handle given to a user that represents a replicated virtual device. Each
// replicated device represents N physical devices for execution where N is the
// number of replicas.
message DeviceHandle {
  int64 handle = 1;

  // The number of model-parallel virtual devices that communicate via XLA
  // Send/Recv instructions.
  int64 device_count = 2;
}

// Handle given to a user to represent a channel between two computations
// via a Send and Recv instruction pair. Channels are unbuffered, so Send
// instructions will be blocked until the data is transferred.
message ChannelHandle {
  int64 handle = 1;
  enum ChannelType {
    // Invalid channel type to serve as default.
    CHANNEL_TYPE_INVALID = 0;

    // A channel for sending data between devices.
    DEVICE_TO_DEVICE = 1;

    // A channel for sending data from the device to the host. Can only be used
    // with a Send operation.
    DEVICE_TO_HOST = 2;

    // A channel for sending data from the host to the device. Can only be used
    // with a Recv operation.
    HOST_TO_DEVICE = 3;
  }
  ChannelType type = 2;
}

// DeviceAssignmentProto is a serialized form of the DeviceAssignment class,
// which represents the device ids assigned to a set of replicated computations.
// See the xla::DeviceAssignment class comment for more details.
message DeviceAssignmentProto {
  int32 replica_count = 1;
  int32 computation_count = 2;

  // Each logical computation runs on replica_count physical devices.
  // ComputationDevice represents the device ids assigned to the replicas.
  message ComputationDevice {
    repeated int32 replica_device_ids = 1;
  }
  repeated ComputationDevice computation_devices = 3;
}
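
// Illustrative example: with replica_count = 2 and computation_count = 1, a
// single ComputationDevice entry such as { replica_device_ids: [0, 1] }
// assigns replica 0 of the computation to device 0 and replica 1 to device 1.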

// Literals are used when the server and client need to exchange materialized
// data / results. Literals are also used to describe constants used in
// computations.
//
// Transfers to/from the client are encoded in literal form, and the structure
// of the repeated fields is implied by the shape.
message LiteralProto {
  ShapeProto shape = 1;
  repeated bool preds = 2;
  bytes s8s = 15;
  bytes u8s = 3;
  repeated int32 s32s = 4;
  repeated int64 s64s = 5;
  repeated uint32 u32s = 6;
  repeated uint64 u64s = 7;
  repeated float f32s = 8;
  repeated double f64s = 9;
  repeated float c64s = 12;    // Stored as interleaved real, imag floats.
  repeated double c128s = 18;  // Stored as interleaved real, imag doubles.
  repeated LiteralProto tuple_literals = 10;
  // The F16s, BF16s, U16s and S16s are encoded in little endian byte order
  bytes f16s = 11;
  bytes bf16s = 13;
  bytes u16s = 16;
  bytes s16s = 17;
  repeated int64 sparse_indices = 14;
  // Next = 19
}
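
// Illustrative example: a constant f32[3] literal holding [1, 2, 3] could be
// written in text format roughly as:
//   shape { element_type: F32 dimensions: 3 layout { minor_to_major: 0 } }
//   f32s: [1, 2, 3]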

message WindowDimension {
  // The size of the window in this dimension. For a rectangle, this would be
  // the width or height.
  int64 size = 1;

  // The stride at which the window moves across the base area in this
  // dimension. In other words, this is the spacing between different
  // positions of the window in this dimension.
  int64 stride = 2;

  // If positive, means the amount of padding to add to the base area at the low
  // end of this dimension; if negative, its absolute value is the number of
  // elements removed from the low end of this dimension. For example, in the
  // horizontal dimension of a rectangle, this would be the number of padding
  // values to pad on the left, given that indices increase when going right.
  // The actual padding value depends upon the context. Convolution pads with
  // zeros. ReduceWindow and SelectAndScatter pad with the reduce function's
  // init value.
  int64 padding_low = 3;

  // As padding_low, but on the high end of this dimension. For example, in the
  // horizontal dimension of a rectangle, this would be the number of values to
  // pad on the right, given that indices increase when going right.
  int64 padding_high = 4;

  // Dilation factor of the sliding window in this dimension. A dilation factor
  // of 1 means no dilation. window_dilation - 1 no-op entries ("holes") are
  // implicitly placed between each kernel element. This value may not be less
  // than 1. See documentation for convolution.
  int64 window_dilation = 5;

  // Dilation factor of the base area in this dimension. A dilation factor of 1
  // means no dilation. base_dilation - 1 no-op entries ("holes") are implicitly
  // placed between each base area element. This value may not be less than 1.
  // See documentation for convolution.
  int64 base_dilation = 6;

  // Window reversal means that this dimension was logically reversed before the
  // operation.
  bool window_reversal = 7;
}

// Describes the windowing in an operation such as convolution.
//
// The window is moved across a base area and for each position of the
// window a computation is performed. The field below describes the
// window and the movement of the window across a base area.
message Window {
  repeated WindowDimension dimensions = 1;
}
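
// Illustrative example: one spatial dimension of a 2x2 pooling window with
// stride 2 and no padding could be described by
//   dimensions { size: 2 stride: 2 window_dilation: 1 base_dilation: 1 }
// with a second, identical WindowDimension for the other spatial dimension.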

// Describes the dimension numbers for a gather operation.
//
// See https://www.tensorflow.org/performance/xla/operation_semantics#gather for
// more details.
message GatherDimensionNumbers {
  // "Window indices" is a term for a set of indices that index into the
  // interior of a dynamic-slice from the input tensor, the starting indices for
  // which were computed from output_gather_dims (see the operation semantics
  // for how this is defined) and the start_indices tensor.
  //
  // The window indices for a specific output index Out are computed as:
  //
  //  i = 0
  //  for (k : [0, input_tensor_shape.rank))
  //    window_indices[k] =
  //      if k in collapsed_slice_dims
  //      then 0
  //      else Out[offset_dims[i++]]
  repeated int64 offset_dims = 1;
  repeated int64 collapsed_slice_dims = 2;

  // This is interpreted as a map from i to start_index_map[i]. It
  // transforms the gather index looked up from the start_indices tensor into
  // the starting index in the input space.
  repeated int64 start_index_map = 3;

  // The dimension in the start_indices input that contains the starting
  // indices.
  int64 index_vector_dim = 4;
}
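
// Illustrative example: gathering whole rows from an operand of shape [N, D]
// using start_indices of shape [B] (one row index per batch element)
// corresponds roughly to
//   offset_dims: [1]
//   collapsed_slice_dims: [0]
//   start_index_map: [0]
//   index_vector_dim: 1
// with slice sizes [1, D] supplied on the gather instruction itself.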

// Describes the dimension numbers for a scatter operation.
//
// All the fields are similar to the corresponding fields in
// GatherDimensionNumbers. Differences are noted below.
message ScatterDimensionNumbers {
  // The set of dimensions in the updates shape that are window dimensions.
  repeated int64 update_window_dims = 1;
  // The set of window dimensions that must be inserted into the updates shape.
  repeated int64 inserted_window_dims = 2;

  repeated int64 scatter_dims_to_operand_dims = 3;
  int64 index_vector_dim = 4;
}

message ConvolutionDimensionNumbers {
  // The number of the dimension that represents batch in the input.
  int64 input_batch_dimension = 7;

  // The number of the dimension that represents features in the input.
  int64 input_feature_dimension = 8;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the input.
  repeated int64 input_spatial_dimensions = 11;

  // The number of the dimension that represents input features in the
  // convolutional kernel (rhs).
  int64 kernel_input_feature_dimension = 3;

  // The number of the dimension that represents output features in
  // the convolutional kernel (rhs).
  int64 kernel_output_feature_dimension = 4;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the kernel (rhs). window.strides(0) is the
  // stride in the kernel_spatial_dimensions(0) dimension.
  repeated int64 kernel_spatial_dimensions = 6;

  // The number of the dimension that represents batch in the output.
  int64 output_batch_dimension = 9;

  // The number of the dimension that represents features in the output.
  int64 output_feature_dimension = 10;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the output.
  repeated int64 output_spatial_dimensions = 12;

  // Next = 13
}
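
// Illustrative example: a 2-D convolution with NHWC inputs/outputs and an HWIO
// kernel corresponds to
//   input_batch_dimension: 0, input_feature_dimension: 3,
//   input_spatial_dimensions: [1, 2],
//   kernel_spatial_dimensions: [0, 1], kernel_input_feature_dimension: 2,
//   kernel_output_feature_dimension: 3,
//   output_batch_dimension: 0, output_feature_dimension: 3,
//   output_spatial_dimensions: [1, 2].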

enum PaddingType {
  PADDING_INVALID = 0;
  PADDING_VALID = 1;  // Only the valid portion of the base area is covered.
  PADDING_SAME = 2;   // Extra padding is added to produce the same output size
                      // as the input.
}

enum FftType {
  FFT = 0;    // Forward FFT; complex in, complex out.
  IFFT = 1;   // Inverse FFT; complex in, complex out.
  RFFT = 2;   // Forward real FFT; real in, fft_length / 2 + 1 complex out
  IRFFT = 3;  // Inverse real FFT; fft_length / 2 + 1 complex in,
              //                   fft_length real out
}

message DotDimensionNumbers {
  // The dimension numbers that represent the 'lhs' contracting dimensions.
  repeated int64 lhs_contracting_dimensions = 1;
  // The dimension numbers that represent the 'rhs' contracting dimensions.
  repeated int64 rhs_contracting_dimensions = 2;
  // The dimension numbers that represent the 'lhs' batch dimensions.
  repeated int64 lhs_batch_dimensions = 3;
  // The dimension numbers that represent the 'rhs' batch dimensions.
  repeated int64 rhs_batch_dimensions = 4;
}
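
// Illustrative example: a plain matrix multiplication of an [m, k] lhs with a
// [k, n] rhs uses
//   lhs_contracting_dimensions: [1]
//   rhs_contracting_dimensions: [0]
// and no batch dimensions; a batched matmul over [b, m, k] x [b, k, n] adds
//   lhs_batch_dimensions: [0]
//   rhs_batch_dimensions: [0]
// and shifts the contracting dimensions to [2] and [1] respectively.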

enum RandomDistribution {
  RNG_INVALID = 0;

  // Creates a uniform-distribution-generated random number on the semi-open
  // interval [parameter[0], parameter[1]).
  RNG_UNIFORM = 1;

  // Creates a normal-distribution-generated random number with mean
  // parameter[0] and standard deviation parameter[1].
  RNG_NORMAL = 2;

  // Next: 4
}

enum RandomAlgorithm {
  RNG_DEFAULT = 0;  // Backend dependent default algorithm.
  RNG_THREE_FRY = 1;
  RNG_PHILOX = 2;
  // Next: 3
}

message TriangularSolveOptions {
  // If true, solves ax = b. If false, solves xa = b.
  bool left_side = 1;

  // If true, 'a' is lower triangular. If false, 'a' is upper triangular.
  bool lower = 2;

  // If true, the diagonal elements of 'a' are assumed to be 1 and not accessed.
  bool unit_diagonal = 3;

  // Should we transpose or use the adjoint of 'a'?
  enum Transpose {
    TRANSPOSE_INVALID = 0;
    NO_TRANSPOSE = 1;  // Don't transpose 'a'.
    TRANSPOSE = 2;     // Transpose 'a'.
    ADJOINT = 3;       // Complex conjugate and transpose 'a'.
  }
  Transpose transpose_a = 4;
}
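
// Illustrative example: solving a * x = b for x, where 'a' is lower triangular
// with a unit diagonal and is used as-is, corresponds to
//   left_side: true, lower: true, unit_diagonal: true,
//   transpose_a: NO_TRANSPOSE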

message CholeskyOptions {
  // If true, uses the lower triangle of `a`. If false, uses the upper triangle
  // of `a`.
  bool lower = 1;
}

// Generic map of attributes used to pass hints / configuration options from
// the Python frontend to the XLA backend.
message FrontendAttributes {
  map<string, string> map = 1;
}

// LINT.IfChange
message OpSharding {
  enum Type {
    // This sharding is replicated across all devices (implies maximal,
    // all other fields are unused).
    REPLICATED = 0;
    // This sharding is maximal - one device runs the entire operation.
    MAXIMAL = 1;
    // This sharding is a tuple - only the tuple_shardings field is valid.
    TUPLE = 2;
    // None of the above; tile_shape and tile_assignment are both used.
    OTHER = 3;
    // This op is manually sharded: the shapes are already partitioned and the
    // partitioner should not change this op.
    MANUAL = 4;
  }
  Type type = 1;
  // The shape of the sharded tile.
  ShapeProto tile_shape = 2;
  // The shape of the tile assignment tensor - this must be the same rank as
  // tile_shape and the product of its dimensions must equal
  // tile_assignment_devices.size().
  repeated int64 tile_assignment_dimensions = 3;
  // Flattened list of device IDs. The order of flattening is the same as used
  // by IndexUtil::MultiToLinearIndex(tile_assignment_shape).
  repeated int64 tile_assignment_devices = 4;
  // If type == TUPLE, the sub-shardings, one per leaf node in the tuple shape,
  // in pre-order. The tuple shape could be nested; here we store just a
  // flattened list of all leaves in the tuple shape. Note that the tuple shape
  // is not stored here; shardings do not store the shapes to which they are
  // applied, this is inferred from the instruction this sharding gets attached
  // to.
  repeated OpSharding tuple_shardings = 5;

  // Only used for OTHER type. If true, data is sharded according to other
  // dimensions of tile_assignment(), but replicated across devices along the
  // last dimension. (Experimental)
  bool replicate_on_last_tile_dim = 6;
  // This field is used to track the source of this sharding, usually derived
  // from instructions. Multiple metadata may be populated if this sharding is
  // combined with other shardings. Metadata should not be populated when
  // type == TUPLE; instead, metadata should be set on the individual tuple
  // elements.
  repeated OpMetadata metadata = 7;

  // This field is used to represent the sharding type of each subgroup.
  // For example, sharding={devices=[2,2,2,2]0,1,2,...,15 last_tile_dims={
  // replicate, manual, unreduced}} means that each of the last 3 dimensions
  // in [2,2,2,2] represents a subgrouping in replicate, manual,
  // unreduced sharding type respectively.
  repeated Type last_tile_dims = 8;
}
// LINT.ThenChange()
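
// Illustrative example: tiling a 2-D tensor across 4 devices arranged in a
// 2x2 grid corresponds to
//   type: OTHER
//   tile_assignment_dimensions: [2, 2]
//   tile_assignment_devices: [0, 1, 2, 3]
// so device 0 holds the top-left tile, device 1 the top-right tile, device 2
// the bottom-left tile, and device 3 the bottom-right tile.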

// Describes the replica groups in a cross replica op (e.g., all-reduce and
// all-to-all).
message ReplicaGroup {
  // The ids of the replicas that belong to the same group. The ordering of the
  // ids matters in some ops (e.g., all-to-all).
  repeated int64 replica_ids = 1;
}
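
// Illustrative example: with 4 replicas, the two groups { replica_ids: [0, 1] }
// and { replica_ids: [2, 3] } make an all-reduce combine data only within each
// pair of replicas.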

// Describes the source target pair in the collective permute op.
message SourceTarget {
  int64 source = 1;
  int64 target = 2;
}

// Used to indicate the precision configuration. It has backend specific
// meaning.
message PrecisionConfig {
  enum Precision {
    DEFAULT = 0;
    HIGH = 1;
    HIGHEST = 2;

    // Next: 3
  }
  repeated Precision operand_precision = 1;

  // Next: 2
}

// Describes whether all data-parallelism replicas will receive the same
// parameter data at each buffer.
message ParameterReplication {
  // A list of boolean values for the flattened leaf buffers. Each value
  // indicates whether the corresponding leaf buffer is replicated.
  //
  // If this field is empty, it means no buffer is replicated. Otherwise, the
  // number of elements in this field must match the number of leaf buffers in
  // the HLO instruction's shape.
  repeated bool replicated_at_leaf_buffers = 1;
}

// A backend-config for kWhile loops that stores the loop's trip count, if it is
// known.
//
// This is useful for backends that can implement a `for i in 0..N` loop more
// efficiently than a `while` loop.  For example, on GPUs, we can implement a
// `for i in 0..N` loop by enqueueing the kernels for the loop body N times,
// whereas implementing a `while` loop requires a host-device sync on each
// iteration.
message WhileLoopBackendConfig {
  message KnownTripCount {
    int64 n = 1;
  }
  // This indirection lets us distinguish between known-trip-count == 0 and
  // unknown-trip-count.
  KnownTripCount known_trip_count = 1;
}
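
// Illustrative example: a while loop whose trip count is statically known to
// be 100 would carry the backend config
//   known_trip_count { n: 100 }
// while a loop with an unknown trip count simply leaves known_trip_count unset.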

// Specifies a pair of output/operand buffers for kCustomCall that alias each
// other.
message CustomCallOutputOperandAliasing {
  repeated int64 output_shape_index = 1;
  int64 operand_index = 2;
  repeated int64 operand_shape_index = 3;
}