/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package xla;

option cc_enable_arenas = true;

// Primitive types are the individual values that can be held in rectangular
// multidimensional arrays. A description of the rectangular multidimensional
// array dimensions / primitive type is given by Shape, below.
//
// LINT.IfChange
enum PrimitiveType {
  // Invalid primitive type to serve as default.
  PRIMITIVE_TYPE_INVALID = 0;

  // Predicates are two-state booleans.
  PRED = 1;

  // Signed integral values of fixed width.
  S8 = 2;
  S16 = 3;
  S32 = 4;
  S64 = 5;

  // Unsigned integral values of fixed width.
  U8 = 6;
  U16 = 7;
  U32 = 8;
  U64 = 9;

  // Floating-point values of fixed width.
  //
  // Note: if f16s are not natively supported on the device, they will be
  // converted to f16 from f32 at arbitrary points in the computation.
  F16 = 10;
  F32 = 11;

  // Truncated 16-bit floating-point format. This is similar to IEEE's 16-bit
  // floating-point format, but uses 1 bit for the sign, 8 bits for the
  // exponent, and 7 bits for the mantissa.
  BF16 = 16;

  F64 = 12;

  // Complex values of fixed width.
  C64 = 15;   // Paired F32 (real, imag), as in std::complex<float>.
  C128 = 18;  // Paired F64 (real, imag), as in std::complex<double>.

  // A tuple is a polymorphic sequence; e.g. a shape that holds different
  // sub-shapes. They are used for things like returning multiple values from a
  // computation; e.g. a computation that returns weights and biases may have a
  // signature that results in a tuple like (f32[784x2000], f32[2000]).
  //
  // If a shape proto has the tuple element type, it may not have any entries
  // in the dimensions field.
  TUPLE = 13;

  // An opaque type used for passing context-specific data to a custom
  // operation. Shapes of this primitive type will have empty dimensions and
  // tuple_shapes fields.
  //
  // (OPAQUE would be a better name for this identifier, but that conflicts
  // with a macro defined in windows.h.)
  OPAQUE_TYPE = 14;

  // A token type threaded between side-effecting operations. Shapes of this
  // primitive type will have empty dimensions and tuple_shapes fields.
  TOKEN = 17;

  // Next = 19
}
// LINT.ThenChange(
// https://www.tensorflow.org/code/tensorflow/compiler/xla/tools/driver.cc
// )

// Describes the padding configuration for a Pad operation. The padding amount
// on both edges, as well as between the elements, is specified for each
// dimension.
message PaddingConfig {
  // Describes the padding configuration for a dimension.
  message PaddingConfigDimension {
    // Padding amount on the low end (next to index 0). May be negative.
    int64 edge_padding_low = 1;

    // Padding amount on the high end (next to the highest index). May be
    // negative.
    int64 edge_padding_high = 2;

    // Padding amount between the elements. May not be negative.
    int64 interior_padding = 3;
  }

  // The padding configuration for all dimensions.
  repeated PaddingConfigDimension dimensions = 1;
}
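
// Illustrative sketch (not part of the schema): a PaddingConfig for a 1-D
// array, written in proto text format. Assuming the usual Pad semantics,
// applying it to a dimension of size 3 yields a dimension of size
// edge_padding_low + 3 + (3 - 1) * interior_padding + edge_padding_high
// = 1 + 3 + 2 + 2 = 8.
//
//   dimensions {
//     edge_padding_low: 1
//     edge_padding_high: 2
//     interior_padding: 1
//   }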

// A format specifies the method used by a layout to store an array in memory.
enum Format {
  // TODO(b/120869032): Rename this to FORMAT_NONE or something else which
  // better corresponds to its meaning.
  INVALID_FORMAT = 0;
  // The default layout, with exactly one storage location per element.
  DENSE = 1;
  reserved 2;
  reserved "SPARSE";
}

// Describes a tile used in tiling-based layout. Refer to
// g3doc/third_party/tensorflow/compiler/xla/g3doc/layout_with_tiling.md for
// details about tiling-based layout.
message TileProto {
  // Number of elements in each dimension of the tile. It's ordered from the
  // most major dimension of the tile to the most minor dimension of the tile.
  // The dimensions correspond to a suffix of the dimensions of the shape being
  // tiled.
  repeated int64 dimensions = 1;
}

// A layout describes how the array is placed in (1D) memory space. This
// includes the minor-to-major ordering of dimensions within a shape.
//
// Clients must specify the layouts of input Literals to the
// computation. Layouts specified in interior operations which take Shapes (for
// example, Convert) are ignored.
//
// See the XLA documentation for more information on shapes and layouts.
//
// LINT.IfChange
message LayoutProto {
  // The method used to store the data in memory. The format determines which
  // of the other fields are used by the layout.
  Format format = 4;

  // Sequence of dimension numbers, from minor (fastest varying index) to major
  // (slowest varying index). This field is required.
  repeated int64 minor_to_major = 1;

  reserved 2;
  reserved "padded_dimensions";

  reserved 3;
  reserved "padding_value";

  reserved 5;
  reserved "max_sparse_elements";

  // A sequence of tiles, starting from the tile that's applied first to the
  // Shape.
  //
  // TODO(b/119839262): implement tiling in each backend or add Unimplemented
  // error.
  repeated TileProto tiles = 6;

  // Bit size of each element. If the size is bigger than what the element
  // type requires, the value is stored in the least significant
  // bits and the additional most significant bits are filled with 0's.
  //
  // TODO(b/119839262): implement in each backend or add Unimplemented error.
  int64 element_size_in_bits = 7;

  // Memory space where this array resides. The integer field is interpreted in
  // a backend-specific manner.
  int64 memory_space = 8;

  // Important: if any field is added, be sure to modify ShapeUtil::Equal() and
  // LayoutUtil::Hash appropriately to account for the new field.
}
// LINT.ThenChange( \
// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc, \
// https://www.tensorflow.org/code/tensorflow/compiler/xla/layout_util.cc)
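
// Illustrative sketch (not part of the schema): a LayoutProto in proto text
// format describing a dense, row-major layout for a rank-2 array. Dimension 1
// varies fastest in memory, so it appears first in minor_to_major.
//
//   format: DENSE
//   minor_to_major: [1, 0]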

// A shape describes the number of dimensions in the array, the size of each
// dimension, and the primitive component type.
//
// Tuples are a special case in that they have rank zero and have tuple_shapes
// defined.
//
// See the XLA documentation for more information on shapes and layouts.
//
// LINT.IfChange
message ShapeProto {
  reserved 1;
  reserved "rank";

  // The element type for this shape.
  PrimitiveType element_type = 2;

  // The size (number of elements) for each dimension, or an upper bound on the
  // size if the dimension is dynamic. In XLA, dimensions are numbered from 0
  // to N-1 for an N-dimensional array. The first element of 'dimensions' is
  // the size of dimension 0, the second element is the size of dimension 1,
  // and so forth. An empty list indicates a scalar.
  //
  // If the respective element in 'is_dynamic_dimension' is true then the value
  // in this field represents an upper bound on the size of the dimension.
  repeated int64 dimensions = 3;

  // For tuples only, the shapes of the constituent shapes in the tuple
  // sequence.
  repeated ShapeProto tuple_shapes = 4;

  // The layout used to back this shape.
  LayoutProto layout = 5;

  // For arrays, this indicates whether or not each dimension is
  // dynamically-sized. The number of elements in this repeated field should be
  // zero (indicating that no dimensions are dynamic) or equal to the number of
  // elements in the 'dimensions' field.
  repeated bool is_dynamic_dimension = 6;

  // Important: if any field is added, be sure to modify ShapeUtil::Equal(),
  // ShapeUtil::Compatible() and ShapeUtil::Hash() appropriately to account for
  // the new field.
}
// LINT.ThenChange( \
// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc)
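
// Illustrative sketch (not part of the schema): a ShapeProto in proto text
// format for a static f32[2,3] array with a dense, row-major layout.
//
//   element_type: F32
//   dimensions: [2, 3]
//   layout { format: DENSE minor_to_major: [1, 0] }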

// Shape of the parameters and output of a computation (like a traditional
// function signature).
message ProgramShapeProto {
  repeated ShapeProto parameters = 1;
  ShapeProto result = 2;
  repeated string parameter_names = 3;
}

// Statistics of a computation.
message ComputationStats {
  // The number of floating point operations in the computation.
  double flop_count = 1;

  // The number of transcendental operations (e.g., exp) in the computation.
  double transcendental_count = 2;
}

// Symbolization metadata for HLO Instructions.
//
// This metadata is used for debugging XLA code generation, as well as
// performance profiling of XLA-generated executables.
message OpMetadata {
  // The framework op name that generated this XLA op.
  //
  // Frameworks that build on top of XLA should mirror the names of their ops
  // back to users by specifying the op_type. In this way, even if the
  // framework's "ops" are implemented as multiple XLA HLO Ops, they can be
  // grouped appropriately. (e.g. if a SoftMax layer is emitted into XLA as
  // multiple ops, then each op should have the op_type be "SoftMax".)
  string op_type = 1;
  // The user-specified name of the op.
  //
  // This name is often unique within a computation. Note: some frameworks
  // add auto-generated names if the user does not provide one.
  string op_name = 2;
  // Indicates a file and line that this op is associated with in a user's
  // program.
  //
  // e.g. it could be the file and line of user code that generated the op.
  string source_file = 3;
  int32 source_line = 4;
}

// Profile data from the execution of a computation.
message ExecutionProfile {
  // Whether the executable was read from the compilation cache.
  bool compilation_cache_hit = 1;

  // The time in milliseconds spent to compile the computation. This is only
  // set if the executable was not read from the compilation cache
  // (compilation_cache_hit == false).
  int64 compile_time_ms = 2;

  // The number of cycles spent for the computation. This does not include the
  // time taken for the data transfers between the host and the device. This is
  // a target-dependent field and only used for debugging purposes.
  int64 compute_cycle_count = 3;

  // The time in nanoseconds spent for the computation, without data transfer.
  int64 compute_time_ns = 4;

  // The time in nanoseconds spent for the entire computation, including the
  // result data transfer time. The current implementation does not spend any
  // cycles for the input data transfer since the memory is initialized with
  // the proper values before the execution.
  int64 compute_and_transfer_time_ns = 5;

  // The size of the binary code in the executable.
  int64 executable_size_in_bytes = 6;

  // Whether this profile was drawn from a cache of profiles instead of from
  // execution on the hardware.
  bool profile_cache_hit = 7;
}

// Handle given to a user that represents an execution that the user launched
// asynchronously on the device.
message ExecutionHandle {
  int64 handle = 1;
}

// Handle given to a user that represents a globally accessible allocation.
// Contrast this against a ComputationDataHandle, which is not globally
// accessible, since it only exists within a specific computation.
message GlobalDataHandle {
  int64 handle = 1;
}

// Handle given to a user that represents a replicated virtual device. Each
// replicated device represents N physical devices for execution, where N is
// the number of replicas.
message DeviceHandle {
  int64 handle = 1;

  // The number of model-parallel virtual devices that communicate via XLA
  // Send/Recv instructions.
  int64 device_count = 2;
}

// Handle given to a user to represent a channel between two computations
// via a Send and Recv instruction pair. Channels are unbuffered, so Send
// instructions will be blocked until the data is transferred.
message ChannelHandle {
  int64 handle = 1;
  enum ChannelType {
    // Invalid channel type to serve as default.
    CHANNEL_TYPE_INVALID = 0;

    // A channel for sending data between devices.
    DEVICE_TO_DEVICE = 1;

    // A channel for sending data from the device to the host. Can only be
    // used with a Send operation.
    DEVICE_TO_HOST = 2;

    // A channel for sending data from the host to the device. Can only be
    // used with a Recv operation.
    HOST_TO_DEVICE = 3;
  }
  ChannelType type = 2;
}

// DeviceAssignmentProto is a serialized form of the DeviceAssignment class,
// which represents the device ids assigned to a set of replicated
// computations. See the xla::DeviceAssignment class comment for more details.
message DeviceAssignmentProto {
  int32 replica_count = 1;
  int32 computation_count = 2;

  // Each logical computation runs on replica_count physical devices.
  // ComputationDevice represents the device ids assigned to the replicas.
  message ComputationDevice {
    repeated int32 replica_device_ids = 1;
  }
  repeated ComputationDevice computation_devices = 3;
}
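
// Illustrative sketch (not part of the schema): a DeviceAssignmentProto in
// proto text format for a single computation replicated across two devices,
// assuming devices 0 and 1 run replicas 0 and 1 respectively.
//
//   replica_count: 2
//   computation_count: 1
//   computation_devices { replica_device_ids: [0, 1] }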

// Literals are used when the server and client need to exchange materialized
// data / results. Literals are also used to describe constants used in
// computations.
//
// Transfers to/from the client are encoded in literal form, and the structure
// of the repeated fields is implied by the shape.
message LiteralProto {
  ShapeProto shape = 1;
  repeated bool preds = 2;
  bytes s8s = 15;
  bytes u8s = 3;
  repeated int32 s32s = 4;
  repeated int64 s64s = 5;
  repeated uint32 u32s = 6;
  repeated uint64 u64s = 7;
  repeated float f32s = 8;
  repeated double f64s = 9;
  repeated float c64s = 12;    // Stored as interleaved real, imag floats.
  repeated double c128s = 18;  // Stored as interleaved real, imag doubles.
  repeated LiteralProto tuple_literals = 10;
  // The F16s, BF16s, U16s and S16s are encoded in little-endian byte order.
  bytes f16s = 11;
  bytes bf16s = 13;
  bytes u16s = 16;
  bytes s16s = 17;
  repeated int64 sparse_indices = 14;
  // Next = 19
}
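
// Illustrative sketch (not part of the schema): a LiteralProto in proto text
// format for an f32[2,2] constant, with the four element values stored in
// f32s (here assumed to follow the order given by the shape's layout).
//
//   shape {
//     element_type: F32
//     dimensions: [2, 2]
//     layout { format: DENSE minor_to_major: [1, 0] }
//   }
//   f32s: [1, 2, 3, 4]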

message WindowDimension {
  // The size of the window in this dimension. For a rectangle, this would be
  // the width or height.
  int64 size = 1;

  // The stride at which the window moves across the base area in this
  // dimension. In other words, this is the spacing between different
  // positions of the window in this dimension.
  int64 stride = 2;

  // If positive, this is the amount of padding to add to the base area at the
  // low end of this dimension; if negative, its absolute value is the number
  // of elements removed from the low end of this dimension. For example, in
  // the horizontal dimension of a rectangle, this would be the number of
  // padding values to pad on the left, given that indices increase when going
  // right. The actual padding value depends upon the context. Convolution
  // pads with zeros. ReduceWindow and SelectAndScatter pad with the reduce
  // function's init value.
  int64 padding_low = 3;

  // As padding_low, but on the high end of this dimension. For example, in
  // the horizontal dimension of a rectangle, this would be the number of
  // values to pad on the right, given that indices increase when going right.
  int64 padding_high = 4;

  // Dilation factor of the sliding window in this dimension. A dilation factor
  // of 1 means no dilation. window_dilation - 1 no-op entries ("holes") are
  // implicitly placed between each kernel element. This value may not be less
  // than 1. See documentation for convolution.
  int64 window_dilation = 5;

  // Dilation factor of the base area in this dimension. A dilation factor of 1
  // means no dilation. base_dilation - 1 no-op entries ("holes") are implicitly
  // placed between each base area element. This value may not be less than 1.
  // See documentation for convolution.
  int64 base_dilation = 6;

  // Window reversal means that this dimension was logically reversed before
  // the operation.
  bool window_reversal = 7;
}

// Describes the windowing in an operation such as convolution.
//
// The window is moved across a base area and for each position of the
// window a computation is performed. The field below describes the
// window and the movement of the window across a base area.
message Window {
  repeated WindowDimension dimensions = 1;
}

// Describes the dimension numbers for a gather operation.
//
// See https://www.tensorflow.org/performance/xla/operation_semantics#gather
// for more details.
message GatherDimensionNumbers {
  // "Window indices" is a term for a set of indices that index into the
  // interior of a dynamic-slice from the input tensor, the starting indices
  // for which were computed from output_gather_dims (see the operation
  // semantics for how this is defined) and the start_indices tensor.
  //
  // The window indices for a specific output index Out are computed as:
  //
  //  i = 0
  //  for (k : [0, input_tensor_shape.rank))
  //    window_indices[k] =
  //      if k in collapsed_slice_dims
  //      then 0
  //      else Out[offset_dims[i++]]
  repeated int64 offset_dims = 1;
  repeated int64 collapsed_slice_dims = 2;

  // This is interpreted as a map from i to start_index_map[i]. It
  // transforms the gather index looked up from the start_indices tensor into
  // the starting index in the input space.
  repeated int64 start_index_map = 3;

  // The dimension in the start_indices input that contains the starting
  // indices.
  int64 index_vector_dim = 4;
}

// Describes the dimension numbers for a scatter operation.
//
// All the fields are similar to the corresponding fields in
// GatherDimensionNumbers. Differences are noted below.
message ScatterDimensionNumbers {
  // The set of dimensions in the updates shape that are window dimensions.
  repeated int64 update_window_dims = 1;
  // The set of window dimensions that must be inserted into the updates shape.
  repeated int64 inserted_window_dims = 2;

  repeated int64 scatter_dims_to_operand_dims = 3;
  int64 index_vector_dim = 4;
}

message ConvolutionDimensionNumbers {
  // The number of the dimension that represents batch in the input.
  int64 input_batch_dimension = 7;

  // The number of the dimension that represents features in the input.
  int64 input_feature_dimension = 8;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the input.
  repeated int64 input_spatial_dimensions = 11;

  // The number of the dimension that represents input features in the
  // convolutional kernel (rhs).
  int64 kernel_input_feature_dimension = 3;

  // The number of the dimension that represents output features in
  // the convolutional kernel (rhs).
  int64 kernel_output_feature_dimension = 4;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the kernel (rhs). window.strides(0) is the
  // stride in the kernel_spatial_dimensions(0) dimension.
  repeated int64 kernel_spatial_dimensions = 6;

  // The number of the dimension that represents batch in the output.
  int64 output_batch_dimension = 9;

  // The number of the dimension that represents features in the output.
  int64 output_feature_dimension = 10;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the output.
  repeated int64 output_spatial_dimensions = 12;

  // Next = 13
}
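
// Illustrative sketch (not part of the schema): dimension numbers in proto
// text format for a 2-D convolution with an NHWC input/output layout and an
// HWIO kernel layout (an assumed, common convention; other layouts are
// equally valid).
//
//   input_batch_dimension: 0
//   input_feature_dimension: 3
//   input_spatial_dimensions: [1, 2]
//   kernel_spatial_dimensions: [0, 1]
//   kernel_input_feature_dimension: 2
//   kernel_output_feature_dimension: 3
//   output_batch_dimension: 0
//   output_feature_dimension: 3
//   output_spatial_dimensions: [1, 2]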

enum FftType {
  FFT = 0;    // Forward FFT; complex in, complex out.
  IFFT = 1;   // Inverse FFT; complex in, complex out.
  RFFT = 2;   // Forward real FFT; real in, fft_length / 2 + 1 complex out.
  IRFFT = 3;  // Inverse real FFT; fft_length / 2 + 1 complex in,
              // fft_length real out.
}

message DotDimensionNumbers {
  // The dimension numbers that represent the 'lhs' contracting dimensions.
  repeated int64 lhs_contracting_dimensions = 1;
  // The dimension numbers that represent the 'rhs' contracting dimensions.
  repeated int64 rhs_contracting_dimensions = 2;
  // The dimension numbers that represent the 'lhs' batch dimensions.
  repeated int64 lhs_batch_dimensions = 3;
  // The dimension numbers that represent the 'rhs' batch dimensions.
  repeated int64 rhs_batch_dimensions = 4;
}

enum RandomDistribution {
  RNG_INVALID = 0;

  // Creates a uniform-distribution-generated random number on the semi-open
  // interval [parameter[0], parameter[1]).
  RNG_UNIFORM = 1;

  // Creates a normal-distribution-generated random number with mean
  // parameter[0] and standard deviation parameter[1].
  RNG_NORMAL = 2;

  // Next: 4
}

message TriangularSolveOptions {
  // If true, solves ax = b. If false, solves xa = b.
  bool left_side = 1;

  // If true, 'a' is lower triangular. If false, 'a' is upper triangular.
  bool lower = 2;

  // If true, the diagonal elements of 'a' are assumed to be 1 and not
  // accessed.
  bool unit_diagonal = 3;

  // Should we transpose or use the adjoint of 'a'?
  enum Transpose {
    TRANSPOSE_INVALID = 0;
    NO_TRANSPOSE = 1;  // Don't transpose 'a'.
    TRANSPOSE = 2;     // Transpose 'a'.
    ADJOINT = 3;       // Complex conjugate and transpose 'a'.
  }
  Transpose transpose_a = 4;
}

message CholeskyOptions {
  // If true, uses the lower triangle of `a`. If false, uses the upper triangle
  // of `a`.
  bool lower = 1;
}

// Generic map of attributes used to pass hints / configuration options from
// the Python frontend to the XLA backend.
message FrontendAttributes {
  map<string, string> map = 1;
}

message OpSharding {
  enum Type {
    // This sharding is replicated across all devices (implies maximal;
    // all other fields are unused).
    REPLICATED = 0;
    // This sharding is maximal - one device runs the entire operation.
    MAXIMAL = 1;
    // This sharding is a tuple - only the tuple_shardings field is valid.
    TUPLE = 2;
    // None of the above; tile_shape and tile_assignment are both used.
    OTHER = 3;
  }
  Type type = 1;
  // The shape of the sharded tile.
  ShapeProto tile_shape = 2;
  // The shape of the tile assignment tensor - this must be the same rank as
  // tile_shape and the product of its dimensions must equal
  // tile_assignment_devices.size().
  repeated int64 tile_assignment_dimensions = 3;
  // Flattened list of device IDs. The order of flattening is the same as used
  // by IndexUtil::MultiToLinearIndex(tile_assignment_shape).
  repeated int64 tile_assignment_devices = 4;
  // If type == TUPLE, the sub-shardings, one per leaf node in the tuple shape,
  // in pre-order. The tuple shape could be nested; here we store just a
  // flattened list of all leaves in the tuple shape. Note that the tuple shape
  // is not stored here; shardings do not store the shapes to which they are
  // applied, as this is inferred from the instruction the sharding gets
  // attached to.
  repeated OpSharding tuple_shardings = 5;
}
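
// Illustrative sketch (not part of the schema): an OpSharding in proto text
// format that tiles a hypothetical f32[8,4] operand across two devices,
// splitting dimension 0 in half so each device holds an f32[4,4] tile.
//
//   type: OTHER
//   tile_shape { element_type: F32 dimensions: [4, 4] }
//   tile_assignment_dimensions: [2, 1]
//   tile_assignment_devices: [0, 1]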

// Describes the replica groups in a cross-replica op (e.g., all-reduce and
// all-to-all).
message ReplicaGroup {
  // The ids of the replicas that belong to the same group. The ordering of the
  // ids matters in some ops (e.g., all-to-all).
  repeated int64 replica_ids = 1;
}

// Describes the source-target pair in the collective permute op.
message SourceTarget {
  int64 source = 1;
  int64 target = 2;
}

// Used to indicate the precision configuration. It has backend-specific
// meaning.
message PrecisionConfig {
  enum Precision {
    DEFAULT = 0;
    HIGH = 1;
    HIGHEST = 2;

    // Next: 3
  }
  repeated Precision operand_precision = 1;

  // Next: 2
}

// Describes whether all data-parallelism replicas will receive the same
// parameter data at each buffer.
message ParameterReplication {
  // A list of boolean values for the flattened leaf buffers. Each value
  // indicates whether the corresponding leaf buffer is replicated.
  //
  // If this field is empty, it means no buffer is replicated. Otherwise, the
  // number of elements in this field must match the number of leaf buffers in
  // the HLO instruction's shape.
  repeated bool replicated_at_leaf_buffers = 1;
}

// A backend-config for kWhile loops that stores the loop's trip count, if it
// is known.
//
// This is useful for backends that can implement a `for i in 0..N` loop more
// efficiently than a `while` loop. For example, on GPUs, we can implement a
// `for i in 0..N` loop by enqueueing the kernels for the loop body N times,
// whereas implementing a `while` loop requires a host-device sync on each
// iteration.
message WhileLoopBackendConfig {
  message KnownTripCount {
    int64 n = 1;
  }
  // This indirection lets us distinguish between known-trip-count == 0 and
  // unknown-trip-count.
  KnownTripCount known_trip_count = 1;
}
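
// Illustrative sketch (not part of the schema): a WhileLoopBackendConfig in
// proto text format for a loop whose trip count is known to be 100; leaving
// known_trip_count unset indicates an unknown trip count.
//
//   known_trip_count { n: 100 }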