/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package xla;

option cc_enable_arenas = true;

// Primitive types are the individual values that can be held in rectangular
// multidimensional arrays. A description of the rectangular multidimensional
// array dimensions / primitive type is given by Shape, below.
//
// LINT.IfChange
enum PrimitiveType {
  // Invalid primitive type to serve as default.
  PRIMITIVE_TYPE_INVALID = 0;

  // Predicates are two-state booleans.
  PRED = 1;

  // Signed integral values of fixed width.
  S8 = 2;
  S16 = 3;
  S32 = 4;
  S64 = 5;

  // Unsigned integral values of fixed width.
  U8 = 6;
  U16 = 7;
  U32 = 8;
  U64 = 9;

  // Floating-point values of fixed width.
  //
  // Note: if f16s are not natively supported on the device, they will be
  // converted to f16 from f32 at arbitrary points in the computation.
  F16 = 10;
  F32 = 11;

  // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit
  // floating-point format, but uses 1 bit for the sign, 8 bits for the exponent
  // and 7 bits for the mantissa.
  BF16 = 16;

  F64 = 12;

  // Complex values of fixed width.
  C64 = 15;   // Paired F32 (real, imag), as in std::complex<float>.
  C128 = 18;  // Paired F64 (real, imag), as in std::complex<double>.

  // A tuple is a polymorphic sequence; e.g. a shape that holds different
  // sub-shapes. They are used for things like returning multiple values from a
  // computation; e.g. a computation that returns weights and biases may have a
  // signature that results in a tuple like (f32[784x2000], f32[2000])
  //
  // If a shape proto has the tuple element type, it may not have any entries
  // in the dimensions field.
  TUPLE = 13;

  // An opaque type used for passing context-specific data to a custom
  // operation. Shapes of this primitive type will have empty dimensions and
  // tuple_shapes fields.
  //
  // (OPAQUE would be a better name for this identifier, but that conflicts with
  // a macro defined in windows.h.)
  OPAQUE_TYPE = 14;

  // A token type threaded between side-effecting operations. Shapes of this
  // primitive type will have empty dimensions and tuple_shapes fields.
  TOKEN = 17;

  // Next = 19
}
// LINT.ThenChange(
//   https://www.tensorflow.org/code/tensorflow/compiler/xla/tools/driver.cc
// )

// Describes the padding configuration for the Pad operation. The padding
// amount on both edges, as well as between the elements, is specified for
// each dimension.
message PaddingConfig {
  // Describes the padding configuration for a dimension.
  message PaddingConfigDimension {
    // Padding amount on the low end (next to index 0). May be negative.
    int64 edge_padding_low = 1;

    // Padding amount on the high end (next to the highest index). May be
    // negative.
    int64 edge_padding_high = 2;

    // Padding amount between the elements. May not be negative.
    int64 interior_padding = 3;
  }

  // The padding configuration for all dimensions.
  repeated PaddingConfigDimension dimensions = 1;
}
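
// Illustrative sketch (not part of the original schema comments): in proto
// text format, a PaddingConfig for a rank-2 array that adds one element of
// edge padding on both sides of dimension 0 and one element of interior
// padding between the values of dimension 1 could look like:
//
//   dimensions { edge_padding_low: 1 edge_padding_high: 1 interior_padding: 0 }
//   dimensions { edge_padding_low: 0 edge_padding_high: 0 interior_padding: 1 }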

// A format specifies the method used by a layout to store an array in memory.
enum Format {
  // TODO(b/120869032): Rename this to FORMAT_NONE or something else which
  // better corresponds to its meaning.
  INVALID_FORMAT = 0;
  // The default layout, with exactly one storage location per element.
  DENSE = 1;
  reserved 2;
  reserved "SPARSE";
}

// Describes a tile used in tiling-based layout. Refer to
// g3doc/third_party/tensorflow/compiler/xla/g3doc/layout_with_tiling.md for
// details about tiling-based layout.
message TileProto {
  // Number of elements in each dimension of the tile. It's ordered from the
  // most major dimension of the tile to the most minor dimension of the tile.
  // The dimensions correspond to a suffix of the dimensions of the shape being
  // tiled.
  repeated int64 dimensions = 1;
}

// A layout describes how the array is placed in (1D) memory space.  This
// includes the minor-to-major ordering of dimensions within a shape.
//
// Clients must specify the layouts of input Literals to the
// computation. Layouts specified in interior operations which take Shapes (for
// example, Convert) are ignored.
//
// See the XLA documentation for more information on shapes and layouts.
//
// LINT.IfChange
message LayoutProto {
  // The method used to store the data in memory. The format determines which of
  // the other fields are used by the layout.
  Format format = 4;

  // Sequence of dimension numbers, from minor (fastest varying index) to major
  // (slowest varying index). This field is required.
  repeated int64 minor_to_major = 1;

  reserved 2;
  reserved "padded_dimensions";

  reserved 3;
  reserved "padding_value";

  reserved 5;
  reserved "max_sparse_elements";

  // A sequence of tiles, starting from the tile that's applied first to the
  // Shape.
  //
  // TODO(b/119839262): implement tiling in each backend or add Unimplemented
  // error.
  repeated TileProto tiles = 6;

  // Bit size of each element. If the size is bigger than what the element
  // type requires, the value is stored in the least significant
  // bits and the additional most significant bits are filled with 0's.
  //
  // TODO(b/119839262): implement in each backend or add Unimplemented error.
  int64 element_size_in_bits = 7;

  // Memory space where this array resides. The integer field is interpreted in
  // a backend-specific manner.
  int64 memory_space = 8;

  // Important: if any field is added, be sure to modify ShapeUtil::Equal() and
  // LayoutUtil::Hash appropriately to account for the new field.
}
// LINT.ThenChange( \
//     https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc,      \
//     https://www.tensorflow.org/code/tensorflow/compiler/xla/layout_util.cc)
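
// Illustrative sketch (not from the original comments): for a 2-D array, a
// row-major placement, where dimension 1 varies fastest in memory, could be
// written in proto text format as:
//
//   format: DENSE
//   minor_to_major: [1, 0]
//
// while minor_to_major: [0, 1] would describe the column-major placement of
// the same array.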

// A shape describes the number of dimensions in the array, the size of each
// dimension, and the primitive component type.
//
// Tuples are a special case in that they have rank zero and have tuple_shapes
// defined.
//
// See the XLA documentation for more information on shapes and layouts.
//
// LINT.IfChange
message ShapeProto {
  reserved 1;
  reserved "rank";

  // The element type for this shape.
  PrimitiveType element_type = 2;

  // The size (number of elements) for each dimension, or an upper bound on the
  // size if the dimension is dynamic.  In XLA, dimensions are numbered from 0
  // to N-1 for an N-dimensional array. The first element of 'dimensions' is the
  // size of dimension 0, the second element is the size of dimension 1, and so
  // forth.  An empty list indicates a scalar.
  //
  // If the respective element in 'is_dynamic_dimension' is true, then the
  // value in this field represents an upper bound on the size of the dimension.
  repeated int64 dimensions = 3;

  // For tuples only, the shapes of the constituent elements in the tuple
  // sequence.
  repeated ShapeProto tuple_shapes = 4;

  // The layout used to back this shape.
  LayoutProto layout = 5;

  // For arrays, this indicates whether or not each dimension is
  // dynamically-sized. The number of elements in this repeated field should be
  // zero (indicating that no dimensions are dynamic) or equal to the number of
  // elements in the 'dimensions' field.
  repeated bool is_dynamic_dimension = 6;

  // Important: if any field is added, be sure to modify ShapeUtil::Equal(),
  // ShapeUtil::Compatible() and ShapeUtil::Hash() appropriately to account for
  // the new field.
}
// LINT.ThenChange( \
//     https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc)
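
// Illustrative sketch (not from the original comments): the XLA shape
// f32[2,3] with a row-major layout could be written in proto text format as:
//
//   element_type: F32
//   dimensions: [2, 3]
//   layout { format: DENSE minor_to_major: [1, 0] }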

// Shape of the parameters and output of a computation (like a traditional
// function signature).
message ProgramShapeProto {
  repeated ShapeProto parameters = 1;
  ShapeProto result = 2;
  repeated string parameter_names = 3;
}
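
// Illustrative sketch (not from the original comments): a computation that
// takes an f32[10] parameter named "x" and returns a scalar f32 could have a
// program shape along the lines of:
//
//   parameters { element_type: F32 dimensions: 10 }
//   result { element_type: F32 }
//   parameter_names: "x"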

// Statistics of a computation.
message ComputationStats {
  // The number of floating point operations in the computation.
  double flop_count = 1;

  // The number of transcendental operations (e.g., exp) in the computation.
  double transcendental_count = 2;
}

// Symbolization metadata for HLO Instructions.
//
// This metadata is used for debugging XLA code generation, as well as
// performance profiling of XLA-generated executables.
message OpMetadata {
  // The framework op name that generated this XLA op.
  //
  // Frameworks that build on top of XLA should mirror the names of their ops
  // back to users by specifying the op_type. In this way, even if the
  // framework's "ops" are implemented as multiple XLA HLO Ops, they can be
  // grouped appropriately. (e.g. if a SoftMax layer is emitted into XLA as
  // multiple ops, then each op should have its op_type set to "SoftMax".)
  string op_type = 1;
  // The user-specified name of the op.
  //
  // This name is often unique within a computation. Note: some frameworks
  // add auto-generated names if the user does not provide one.
  string op_name = 2;
  // Indicates the file and line that this op is associated with in a user's
  // program.
  //
  // e.g. it could be the file and line of user code that generated the op.
  string source_file = 3;
  int32 source_line = 4;
}

// Profile data from the execution of a computation.
message ExecutionProfile {
  // Whether the executable was read from the compilation cache.
  bool compilation_cache_hit = 1;

  // The time in milliseconds spent to compile the computation. This is only
  // set if the executable was not read from the compilation cache
  // (compilation_cache_hit == false).
  int64 compile_time_ms = 2;

  // The number of cycles spent for the computation. This does not include the
  // time taken for the data transfers between the host and the device. This is
  // a target-dependent field and only used for debugging purposes.
  int64 compute_cycle_count = 3;

  // The time in nanoseconds spent for the computation, without data transfer.
  int64 compute_time_ns = 4;

  // The time in nanoseconds spent for the entire computation, including the
  // result data transfer time. The current implementation does not spend any
  // time on the input data transfer since the memory is initialized with the
  // proper values before the execution.
  int64 compute_and_transfer_time_ns = 5;

  // The size of the binary code in the executable.
  int64 executable_size_in_bytes = 6;

  // Whether this profile was drawn from a cache of profiles instead of from
  // execution on the hardware.
  bool profile_cache_hit = 7;
}

// Handle given to a user that represents an execution that the user launched
// asynchronously on the device.
message ExecutionHandle {
  int64 handle = 1;
}

// Handle given to a user that represents a globally accessible allocation.
// Contrast this against a ComputationDataHandle, which is not globally
// accessible, since it only exists within a specific computation.
message GlobalDataHandle {
  int64 handle = 1;
}

// Handle given to a user that represents a replicated virtual device. Each
// replicated device represents N physical devices for execution where N is the
// number of replicas.
message DeviceHandle {
  int64 handle = 1;

  // The number of model-parallel virtual devices that communicate via XLA
  // Send/Recv instructions.
  int64 device_count = 2;
}

// Handle given to a user to represent a channel between two computations
// via a Send and Recv instruction pair. Channels are unbuffered, so Send
// instructions will be blocked until the data is transferred.
message ChannelHandle {
  int64 handle = 1;
  enum ChannelType {
    // Invalid channel type to serve as default.
    CHANNEL_TYPE_INVALID = 0;

    // A channel for sending data between devices.
    DEVICE_TO_DEVICE = 1;

    // A channel for sending data from the device to the host. Can only be used
    // with a Send operation.
    DEVICE_TO_HOST = 2;

    // A channel for sending data from the host to the device. Can only be used
    // with a Recv operation.
    HOST_TO_DEVICE = 3;
  }
  ChannelType type = 2;
}

// DeviceAssignmentProto is a serialized form of the DeviceAssignment class,
// which represents the device ids assigned to a set of replicated
// computations. See the xla::DeviceAssignment class comment for more details.
message DeviceAssignmentProto {
  int32 replica_count = 1;
  int32 computation_count = 2;

  // Each logical computation runs on replica_count physical devices.
  // ComputationDevice represents the device ids assigned to the replicas.
  message ComputationDevice {
    repeated int32 replica_device_ids = 1;
  }
  repeated ComputationDevice computation_devices = 3;
}
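
// Illustrative sketch (not from the original comments): with two computations
// and two replicas each, an assignment that places computation 0 on devices 0
// and 1 and computation 1 on devices 2 and 3 could be written as:
//
//   replica_count: 2
//   computation_count: 2
//   computation_devices { replica_device_ids: [0, 1] }
//   computation_devices { replica_device_ids: [2, 3] }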

// Literals are used when the server and client need to exchange materialized
// data / results. Literals are also used to describe constants used in
// computations.
//
// Transfers to/from the client are encoded in literal form, and the structure
// of the repeated fields is implied by the shape.
message LiteralProto {
  ShapeProto shape = 1;
  repeated bool preds = 2;
  bytes s8s = 15;
  bytes u8s = 3;
  repeated int32 s32s = 4;
  repeated int64 s64s = 5;
  repeated uint32 u32s = 6;
  repeated uint64 u64s = 7;
  repeated float f32s = 8;
  repeated double f64s = 9;
  repeated float c64s = 12;    // Stored as interleaved real, imag floats.
  repeated double c128s = 18;  // Stored as interleaved real, imag doubles.
  repeated LiteralProto tuple_literals = 10;
  // The F16s, BF16s, U16s and S16s are encoded in little endian byte order
  bytes f16s = 11;
  bytes bf16s = 13;
  bytes u16s = 16;
  bytes s16s = 17;
  repeated int64 sparse_indices = 14;
  // Next = 19
}
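
// Illustrative sketch (not from the original comments): the literal
// f32[2,2] {{1, 2}, {3, 4}} with a row-major layout could be encoded in proto
// text format as:
//
//   shape {
//     element_type: F32
//     dimensions: [2, 2]
//     layout { format: DENSE minor_to_major: [1, 0] }
//   }
//   f32s: [1, 2, 3, 4]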

message WindowDimension {
  // The size of the window in this dimension. For a rectangle, this would be
  // the width or height.
  int64 size = 1;

  // The stride at which the window moves across the base area in this
  // dimension. In other words, this is the spacing between different
  // positions of the window in this dimension.
  int64 stride = 2;

  // If positive, the amount of padding to add to the base area at the low end
  // of this dimension; if negative, its absolute value is the number of
  // elements removed from the low end of this dimension. For example, in the
  // horizontal dimension of a rectangle, this would be the number of padding
  // values to pad on the left, given that indices increase when going right.
  // The actual padding value depends upon the context. Convolution pads with
  // zeros. ReduceWindow and SelectAndScatter pad with the reduce function's
  // init value.
  int64 padding_low = 3;

  // As padding_low, but on the high end of this dimension. For example, in the
  // horizontal dimension of a rectangle, this would be the number of values to
  // pad on the right, given that indices increase when going right.
  int64 padding_high = 4;

  // Dilation factor of the sliding window in this dimension. A dilation factor
  // of 1 means no dilation. window_dilation - 1 no-op entries ("holes") are
  // implicitly placed between each kernel element. This value may not be less
  // than 1. See documentation for convolution.
  int64 window_dilation = 5;

  // Dilation factor of the base area in this dimension. A dilation factor of 1
  // means no dilation. base_dilation - 1 no-op entries ("holes") are implicitly
  // placed between each base area element. This value may not be less than 1.
  // See documentation for convolution.
  int64 base_dilation = 6;

  // Window reversal means that this dimension was logically reversed before the
  // operation.
  bool window_reversal = 7;
}

// Describes the windowing in an operation such as convolution.
//
// The window is moved across a base area and for each position of the
// window a computation is performed. The field below describes the
// window and the movement of the window across a base area.
message Window {
  repeated WindowDimension dimensions = 1;
}
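
// Illustrative sketch (not from the original comments): a 3x3 window moved
// with stride 2 in both spatial dimensions, with one element of low and high
// padding and no dilation, could be written as:
//
//   dimensions { size: 3 stride: 2 padding_low: 1 padding_high: 1
//                window_dilation: 1 base_dilation: 1 }
//   dimensions { size: 3 stride: 2 padding_low: 1 padding_high: 1
//                window_dilation: 1 base_dilation: 1 }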

// Describes the dimension numbers for a gather operation.
//
// See https://www.tensorflow.org/performance/xla/operation_semantics#gather for
// more details.
message GatherDimensionNumbers {
  // "Window indices" is a term for a set of indices that index into the
  // interior of a dynamic-slice from the input tensor, the starting indices
  // for which were computed from output_gather_dims (see the operation
  // semantics for how this is defined) and the start_indices tensor.
  //
  // The window indices for a specific output index Out are computed as:
  //
  //  i = 0
  //  for (k : [0, input_tensor_shape.rank))
  //    window_indices[k] =
  //      if k in collapsed_slice_dims
  //      then 0
  //      else Out[offset_dims[i++]]
  repeated int64 offset_dims = 1;
  repeated int64 collapsed_slice_dims = 2;

  // This is interpreted as a map from i to start_index_map[i]. It
  // transforms the gather index looked up from the start_indices tensor into
  // the starting index in the input space.
  repeated int64 start_index_map = 3;

  // The dimension in the start_indices input that contains the starting
  // indices.
  int64 index_vector_dim = 4;
}
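
// Illustrative sketch (not from the original comments): gathering whole rows
// from an f32[N, D] operand using s32[K] start_indices (slice sizes [1, D],
// output f32[K, D]) is typically described by dimension numbers like:
//
//   offset_dims: 1
//   collapsed_slice_dims: 0
//   start_index_map: 0
//   index_vector_dim: 1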

// Describes the dimension numbers for a scatter operation.
//
// All the fields are similar to the corresponding fields in
// GatherDimensionNumbers. Differences are noted below.
message ScatterDimensionNumbers {
  // The set of dimensions in the updates shape that are window dimensions.
  repeated int64 update_window_dims = 1;
  // The set of window dimensions that must be inserted into the updates shape.
  repeated int64 inserted_window_dims = 2;

  repeated int64 scatter_dims_to_operand_dims = 3;
  int64 index_vector_dim = 4;
}

message ConvolutionDimensionNumbers {
  // The number of the dimension that represents batch in the input.
  int64 input_batch_dimension = 7;

  // The number of the dimension that represents features in the input.
  int64 input_feature_dimension = 8;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the input.
  repeated int64 input_spatial_dimensions = 11;

  // The number of the dimension that represents input features in the
  // convolutional kernel (rhs).
  int64 kernel_input_feature_dimension = 3;

  // The number of the dimension that represents output features in
  // the convolutional kernel (rhs).
  int64 kernel_output_feature_dimension = 4;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the kernel (rhs). window.strides(0) is the
  // stride in the kernel_spatial_dimensions(0) dimension.
  repeated int64 kernel_spatial_dimensions = 6;

  // The number of the dimension that represents batch in the output.
  int64 output_batch_dimension = 9;

  // The number of the dimension that represents features in the output.
  int64 output_feature_dimension = 10;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the output.
  repeated int64 output_spatial_dimensions = 12;

  // Next = 13
}
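
// Illustrative sketch (not from the original comments): for an NHWC input, an
// HWIO kernel, and an NHWC output (a common 2-D convolution layout), the
// dimension numbers could be written as:
//
//   input_batch_dimension: 0
//   input_feature_dimension: 3
//   input_spatial_dimensions: [1, 2]
//   kernel_input_feature_dimension: 2
//   kernel_output_feature_dimension: 3
//   kernel_spatial_dimensions: [0, 1]
//   output_batch_dimension: 0
//   output_feature_dimension: 3
//   output_spatial_dimensions: [1, 2]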

enum FftType {
  FFT = 0;    // Forward FFT; complex in, complex out.
  IFFT = 1;   // Inverse FFT; complex in, complex out.
  RFFT = 2;   // Forward real FFT; real in, fft_length / 2 + 1 complex out
  IRFFT = 3;  // Inverse real FFT; fft_length / 2 + 1 complex in,
              //                   fft_length real out
}

message DotDimensionNumbers {
  // The dimension numbers that represent the 'lhs' contracting dimensions.
  repeated int64 lhs_contracting_dimensions = 1;
  // The dimension numbers that represent the 'rhs' contracting dimensions.
  repeated int64 rhs_contracting_dimensions = 2;
  // The dimension numbers that represent the 'lhs' batch dimensions.
  repeated int64 lhs_batch_dimensions = 3;
  // The dimension numbers that represent the 'rhs' batch dimensions.
  repeated int64 rhs_batch_dimensions = 4;
}
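
// Illustrative sketch (not from the original comments): a batched matrix
// multiplication of f32[B, M, K] with f32[B, K, N] contracts the K dimensions
// and batches over B, which could be written as:
//
//   lhs_contracting_dimensions: 2
//   rhs_contracting_dimensions: 1
//   lhs_batch_dimensions: 0
//   rhs_batch_dimensions: 0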

enum RandomDistribution {
  RNG_INVALID = 0;

  // Creates a uniform-distribution-generated random number on the semi-open
  // interval [parameter[0], parameter[1]).
  RNG_UNIFORM = 1;

  // Creates a normal-distribution-generated random number with mean
  // parameter[0] and standard deviation parameter[1].
  RNG_NORMAL = 2;

  // Next: 4
}

message TriangularSolveOptions {
  // If true, solves ax = b. If false, solves xa = b.
  bool left_side = 1;

  // If true, 'a' is lower triangular. If false, 'a' is upper triangular.
  bool lower = 2;

  // If true, the diagonal elements of 'a' are assumed to be 1 and not accessed.
  bool unit_diagonal = 3;

  // Should we transpose or use the adjoint of 'a'?
  enum Transpose {
    TRANSPOSE_INVALID = 0;
    NO_TRANSPOSE = 1;  // Don't transpose 'a'.
    TRANSPOSE = 2;     // Transpose 'a'.
    ADJOINT = 3;       // Complex conjugate and transpose 'a'.
  }
  Transpose transpose_a = 4;
}

message CholeskyOptions {
  // If true, uses the lower triangle of `a`. If false, uses the upper triangle
  // of `a`.
  bool lower = 1;
}

// Generic map of attributes used to pass hints / configuration options from
// the Python frontend to the XLA backend.
message FrontendAttributes {
  map<string, string> map = 1;
}

message OpSharding {
  enum Type {
    // This sharding is replicated across all devices (implies maximal,
    // all other fields are unused).
    REPLICATED = 0;
    // This sharding is maximal - one device runs the entire operation.
    MAXIMAL = 1;
    // This sharding is a tuple - only the tuple_shardings field is valid.
    TUPLE = 2;
    // None of the above; tile_shape and tile_assignment are both used.
    OTHER = 3;
  }
  Type type = 1;
  // The shape of the sharded tile.
  ShapeProto tile_shape = 2;
  // The shape of the tile assignment tensor - this must be the same rank as
  // tile_shape and the product of its dimensions must equal
  // tile_assignment_devices.size().
  repeated int64 tile_assignment_dimensions = 3;
  // Flattened list of device IDs. The order of flattening is the same as used
  // by IndexUtil::MultiToLinearIndex(tile_assignment_shape).
  repeated int64 tile_assignment_devices = 4;
  // If type == TUPLE, the sub-shardings, one per leaf node in the tuple shape,
  // in pre-order. The tuple shape could be nested; here we store just a
  // flattened list of all leaves in the tuple shape. Note that the tuple shape
  // is not stored here; shardings do not store the shapes to which they are
  // applied, since this is inferred from the instruction this sharding gets
  // attached to.
  repeated OpSharding tuple_shardings = 5;
}
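
// Illustrative sketch (not from the original comments): splitting an f32[8, 4]
// array into two 4x4 tiles along dimension 0, one tile per device, could be
// expressed as:
//
//   type: OTHER
//   tile_shape { element_type: F32 dimensions: [4, 4] }
//   tile_assignment_dimensions: [2, 1]
//   tile_assignment_devices: [0, 1]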

// Describes the replica groups in a cross replica op (e.g., all-reduce and
// all-to-all).
message ReplicaGroup {
  // The ids of the replicas that belong to the same group. The ordering of the
  // ids matters in some ops (e.g., all-to-all).
  repeated int64 replica_ids = 1;
}

// Describes the source-target pair in the collective permute op.
message SourceTarget {
  int64 source = 1;
  int64 target = 2;
}

// Used to indicate the precision configuration. It has backend-specific
// meaning.
message PrecisionConfig {
  enum Precision {
    DEFAULT = 0;
    HIGH = 1;
    HIGHEST = 2;

    // Next: 3
  }
  repeated Precision operand_precision = 1;

  // Next: 2
}

// Describes whether all data-parallelism replicas will receive the same
// parameter data at each buffer.
message ParameterReplication {
  // A list of boolean values for the flattened leaf buffers. Each value
  // indicates whether the corresponding leaf buffer is replicated.
  //
  // If this field is empty, it means no buffer is replicated. Otherwise, the
  // number of elements in this field must match the number of leaf buffers in
  // the HLO instruction's shape.
  repeated bool replicated_at_leaf_buffers = 1;
}

// A backend-config for kWhile loops that stores the loop's trip count, if it is
// known.
//
// This is useful for backends that can implement a `for i in 0..N` loop more
// efficiently than a `while` loop.  For example, on GPUs, we can implement a
// `for i in 0..N` loop by enqueueing the kernels for the loop body N times,
// whereas implementing a `while` loop requires a host-device sync on each
// iteration.
message WhileLoopBackendConfig {
  message KnownTripCount {
    int64 n = 1;
  }
  // This indirection lets us distinguish between known-trip-count == 0 and
  // unknown-trip-count.
  KnownTripCount known_trip_count = 1;
}