/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package xla;
option cc_enable_arenas = true;

// Primitive types are the individual values that can be held in rectangular
// multidimensional arrays. A description of the rectangular multidimensional
// array dimensions / primitive type is given by Shape, below.
enum PrimitiveType {
  // Invalid primitive type to serve as default.
  PRIMITIVE_TYPE_INVALID = 0;

  // Predicates are two-state booleans.
  PRED = 1;

  // Signed integral values of fixed width.
  S8 = 2;
  S16 = 3;
  S32 = 4;
  S64 = 5;

  // Unsigned integral values of fixed width.
  U8 = 6;
  U16 = 7;
  U32 = 8;
  U64 = 9;

  // Floating-point values of fixed width.
  //
  // Note: if f16s are not natively supported on the device, they will be
  // converted to f16 from f32 at arbitrary points in the computation.
  F16 = 10;
  F32 = 11;

  // Truncated 16-bit floating-point format. This is similar to IEEE's 16-bit
  // floating-point format, but uses 1 bit for the sign, 8 bits for the
  // exponent, and 7 bits for the mantissa.
  BF16 = 16;

  F64 = 12;

  // Complex values of fixed width.
  C64 = 15;  // Paired F32 (real, imag), as in std::complex<float>.

  // A tuple is a polymorphic sequence; e.g. a shape that holds different
  // sub-shapes. They are used for things like returning multiple values from a
  // computation; e.g. a computation that returns weights and biases may have a
  // signature that results in a tuple like (f32[784x2000], f32[2000]).
  //
  // If a shape proto has the tuple element type, it may not have any entries
  // in the dimensions field.
  TUPLE = 13;

  // An opaque type used for passing context-specific data to a custom
  // operation.
  OPAQUE = 14;

  // Next = 17
}

// Describes the value held inside padding elements.
enum PaddingValue {
  INVALID_PAD = 0;

  // Zero padding must be 0-values that correspond to the shape's element type.
  ZERO_PAD = 1;

  // One padding must be 1-values that correspond to the shape's element type.
  ONE_PAD = 2;

  // "Lowest" padding must be the lowest values in the shape's element type,
  // used as padding for operations like max-accumulation.
  LOWEST_PAD = 3;

  // "Highest" padding must be the largest values in the shape's element type,
  // used as padding for operations like min-accumulation.
  HIGHEST_PAD = 4;

  // Unknown padding could be anything; e.g. floating NaNs!
  UNKNOWN_PAD = 5;
}

// Describes the padding configuration for a Pad operation. The padding amount
// on both edges, as well as between the elements, is specified for each
// dimension.
message PaddingConfig {
  // Describes the padding configuration for a dimension.
  message PaddingConfigDimension {
    // Padding amount on the low end (next to index 0).
    int64 edge_padding_low = 1;

    // Padding amount on the high end (next to the highest index).
    int64 edge_padding_high = 2;

    // Padding amount between the elements.
    int64 interior_padding = 3;
  }

  // The padding configuration for all dimensions.
  repeated PaddingConfigDimension dimensions = 1;
}
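// For illustration (not part of the wire format): a dimension config of
// {edge_padding_low: 1, edge_padding_high: 2, interior_padding: 1} applied to
// the 1-D array [1, 2, 3] with a zero padding value yields
// [0, 1, 0, 2, 0, 3, 0, 0] -- interior padding is inserted between the
// original elements, and edge padding is added at the low and high ends, for a
// total size of low + high + size + (size - 1) * interior = 1 + 2 + 3 + 2 = 8.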
// A format specifies the method used by a layout to store an array in memory.
enum Format {
  INVALID_FORMAT = 0;
  // The default layout, with exactly one storage location per element
  // (ignoring padding).
  DENSE = 1;
  // A sparsely encoded layout, providing only the index/value pairs of
  // non-zero elements.
  SPARSE = 2;
}

// A layout describes how the array is placed in (1D) memory space. This
// includes the minor-to-major ordering of dimensions within a shape, as well
// as any padding present in those dimensions.
//
// Clients must specify the layouts of input Literals to the computation.
// Layouts specified in interior operations which take Shapes (for example,
// Convert) are ignored.
//
// See the XLA documentation for more information on shapes and layouts.
message Layout {
  // The method used to store the data in memory. The format determines which
  // of the other fields are used by the layout.
  Format format = 4;

  // Sequence of dimension numbers, from minor (fastest varying index) to major
  // (slowest varying index). This field is required.
  repeated int64 minor_to_major = 1;

  // The width to which each dimension is padded. If present, the size of
  // padded_dimensions must equal the rank of the shape. The padding appears at
  // the end of a dimension, not at the beginning. This kind of padding, unlike
  // padding in e.g. convolution, is not part of the shape. This field must be
  // unset unless the format is DENSE.
  repeated int64 padded_dimensions = 2;

  // Describes the values in the padding specified by padded_dimensions. This
  // field must be unset unless the format is DENSE.
  PaddingValue padding_value = 3;

  // The maximum number of elements that can be stored for SPARSE formats. This
  // can be used to determine the maximum size in bytes of arrays stored in
  // memory. This field must be unset unless the format is SPARSE.
  int64 max_sparse_elements = 5;

  // Important: if any field is added, be sure to modify ShapeUtil::Equal()
  // appropriately to account for the new field.
}

// A shape describes the number of dimensions in the array, the size of each
// dimension, and the primitive component type.
//
// Tuples are a special case in that they have rank zero and have tuple_shapes
// defined.
//
// See the XLA documentation for more information on shapes and layouts.
message Shape {
  reserved 1;
  reserved "rank";

  // The element type for this shape.
  PrimitiveType element_type = 2;

  // The size (number of elements) for each dimension.
  // In XLA, dimensions are numbered from 0 to N-1 for an
  // N-dimensional array. The first element of 'dimensions' is the size of
  // dimension 0, the second element is the size of dimension 1, and so forth.
  // An empty list indicates a scalar.
  repeated int64 dimensions = 3;

  // For tuples only, the shapes of the constituent shapes in the tuple
  // sequence.
  repeated Shape tuple_shapes = 4;

  // The layout used to back this shape.
  Layout layout = 5;

  // Important: if any field is added, be sure to modify ShapeUtil::Equal() and
  // ShapeUtil::Compatible() appropriately to account for the new field.
}
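// For illustration only, a 2x3 array of f32 stored in row-major order (the
// column index, dimension 1, varies fastest) could be written in text format
// as:
//
//   element_type: F32
//   dimensions: [2, 3]
//   layout { format: DENSE minor_to_major: [1, 0] }
//
// A minor_to_major of [0, 1] would instead describe a column-major layout for
// the same logical shape.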
// Shape of the parameters and output of a computation (like a traditional
// function signature).
message ProgramShape {
  repeated Shape parameters = 1;
  Shape result = 2;
  repeated string parameter_names = 3;
}

// Statistics of a computation.
message ComputationStats {
  // The number of floating point operations in the computation.
  double flop_count = 1;

  // The number of transcendental operations (e.g., exp) in the computation.
  double transcendental_count = 2;
}

// Symbolization metadata for HLO Instructions.
//
// This metadata is used for debugging XLA code generation, as well as
// performance profiling of XLA-generated executables.
message OpMetadata {
  // The framework op name that generated this XLA op.
  //
  // Frameworks that build on top of XLA should mirror the names of their ops
  // back to users by specifying the op_type. In this way, even if the
  // framework's "ops" are implemented as multiple XLA HLO Ops, they can be
  // grouped appropriately. (e.g. if a SoftMax layer is emitted into XLA as
  // multiple ops, then each op should have the op_type be "SoftMax".)
  string op_type = 1;
  // The user-specified name of the op.
  //
  // This name is often unique within a computation. Note: some frameworks
  // add auto-generated names if the user does not provide one.
  string op_name = 2;
  // Indicates the file and line that this op is associated with in a user's
  // program.
  //
  // e.g. it could be the file and line of user code that generated the op.
  string source_file = 3;
  int32 source_line = 4;
}

// Profile data from the execution of a computation.
message ExecutionProfile {
  // Whether the executable was read from the compilation cache.
  bool compilation_cache_hit = 1;

  // The time in milliseconds spent to compile the computation. This is only
  // set if the executable was not read from the compilation cache
  // (compilation_cache_hit == false).
  int64 compile_time_ms = 2;

  // The number of cycles spent for the computation. This does not include the
  // time taken for the data transfers between the host and the device. This is
  // a target-dependent field and only used for debugging purposes.
  int64 compute_cycle_count = 3;

  // The time in nanoseconds spent for the computation, without data transfer.
  int64 compute_time_ns = 4;

  // The time in nanoseconds spent for the entire computation, including the
  // result data transfer time. The current implementation does not spend any
  // cycles for the input data transfer since the memory is initialized with
  // the proper values before the execution.
  int64 compute_and_transfer_time_ns = 5;
}

// Handle given to a user that represents a computation that the user builds up
// before execution.
message ComputationHandle {
  int64 handle = 1;
}

// Handle given to a user that represents an execution that the user launched
// asynchronously on the device.
message ExecutionHandle {
  int64 handle = 1;
}

// Handle given to a user that represents a globally accessible allocation.
// Contrast this against a ComputationDataHandle, which is not globally
// accessible, since it only exists within a specific computation.
message GlobalDataHandle {
  int64 handle = 1;
}

// Handle given to a user that represents a data result in a computation.
// This is used to pass to subsequent computations that depend upon the data as
// an operand.
message ComputationDataHandle {
  int64 handle = 1;
}

// Handle given to a user that represents a replicated virtual device. Each
// replicated device represents N physical devices for execution, where N is
// the number of replicas.
message DeviceHandle {
  int64 handle = 1;

  // The number of model-parallel virtual devices that communicate via XLA
  // Send/Recv instructions.
  int64 device_count = 2;
}

// Handle given to a user to represent a channel between two computations
// via a Send and Recv instruction pair. Channels are unbuffered, so Send
// instructions will be blocked until the data is transferred.
message ChannelHandle {
  int64 handle = 1;
}

// DeviceAssignmentProto is a serialized form of the DeviceAssignment class,
// which represents the device ids assigned to a set of replicated
// computations. See the xla::DeviceAssignment class comment for more details.
message DeviceAssignmentProto {
  int32 replica_count = 1;
  int32 computation_count = 2;

  // Each logical computation runs on replica_count physical devices.
  // ComputationDevice represents the device ids assigned to the replicas.
  message ComputationDevice {
    repeated int32 replica_device_ids = 1;
  }
  repeated ComputationDevice computation_devices = 3;
}
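// For illustration only, a single computation replicated across two devices
// (device ids 0 and 1) could be written in text format as:
//
//   replica_count: 2
//   computation_count: 1
//   computation_devices { replica_device_ids: [0, 1] }
//
// i.e. replica_device_ids[i] is the device id used by replica i of that
// computation.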
// Literals are used when the server and client need to exchange materialized
// data / results. Literals are also used to describe constants used in
// computations.
//
// Transfers to/from the client are encoded in literal form, and the structure
// of the repeated fields is implied by the shape.
message LiteralProto {
  Shape shape = 1;
  repeated bool preds = 2;
  bytes u8s = 3;
  repeated int32 s32s = 4;
  repeated int64 s64s = 5;
  repeated uint32 u32s = 6;
  repeated uint64 u64s = 7;
  repeated float f32s = 8;
  repeated double f64s = 9;
  repeated float c64s = 12;  // Stored as interleaved real, imag floats.
  repeated LiteralProto tuple_literals = 10;
  // The F16s and BF16s are encoded in little endian byte order.
  bytes f16s = 11;
  bytes bf16s = 13;
  repeated int64 sparse_indices = 14;
  // Next = 15
}
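// For illustration only (a sketch, not a normative example), a 2x2 f32
// constant [[1, 2], [3, 4]] with a row-major layout could be encoded as:
//
//   shape {
//     element_type: F32
//     dimensions: [2, 2]
//     layout { format: DENSE minor_to_major: [1, 0] }
//   }
//   f32s: [1, 2, 3, 4]
//
// The repeated value fields are interpreted according to the shape (including
// its layout), which is why the shape field is needed to decode a literal.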
message WindowDimension {
  // The size of the window in this dimension. For a rectangle, this would be
  // the width or height.
  int64 size = 1;

  // The stride at which the window moves across the base area in this
  // dimension. In other words, this is the spacing between different
  // positions of the window in this dimension.
  int64 stride = 2;

  // If positive, the amount of padding with zeroes to add to the base area at
  // the low end of this dimension; if negative, its absolute value is the
  // number of elements removed from the low end of this dimension. For
  // example, in the horizontal dimension of a rectangle, this would be the
  // number of zeroes to pad on the left, given that indices increase when
  // going right.
  int64 padding_low = 3;

  // As padding_low, but on the high end of this dimension. For example, in the
  // horizontal dimension of a rectangle, this would be the number of zeroes to
  // pad on the right, given that indices increase when going right.
  int64 padding_high = 4;

  // Dilation factor of the sliding window in this dimension. A dilation factor
  // of 1 means no dilation. window_dilation - 1 no-op entries ("holes") are
  // implicitly placed between each kernel element. See the documentation for
  // convolution.
  int64 window_dilation = 5;

  // Dilation factor of the base area in this dimension. A dilation factor of 1
  // means no dilation. base_dilation - 1 no-op entries ("holes") are
  // implicitly placed between each base area element. See the documentation
  // for convolution.
  int64 base_dilation = 6;

  // Window reversal means that this dimension was logically reversed before
  // the operation.
  bool window_reversal = 7;
}

// Describes the windowing in an operation such as convolution.
//
// The window is moved across a base area and for each position of the
// window a computation is performed. The field below describes the
// window and the movement of the window across a base area.
message Window {
  repeated WindowDimension dimensions = 1;
}
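// For illustration only, a 3x3 window moving with stride 2 over a 2-D operand,
// with no padding or dilation (e.g. as might be used in a pooling-style
// ReduceWindowRequest), could be expressed as:
//
//   dimensions { size: 3 stride: 2 padding_low: 0 padding_high: 0
//                window_dilation: 1 base_dilation: 1 }
//   dimensions { size: 3 stride: 2 padding_low: 0 padding_high: 0
//                window_dilation: 1 base_dilation: 1 }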
// Describes the dimension numbers for a gather operation.
//
// See https://www.tensorflow.org/performance/xla/operation_semantics#gather
// for more details.
message GatherDimensionNumbers {
  // "Window indices" is a term for a set of indices that index into the
  // interior of a dynamic-slice from the input tensor, the starting indices
  // for which were computed from output_gather_dims (see the operation
  // semantics for how this is defined) and the gather_indices tensor.
  //
  // The window indices for a specific output index Out are computed as:
  //
  //  i = 0
  //  for (k : [0, input_tensor_shape.rank))
  //    window_indices[k] =
  //      if k in elided_window_dims
  //      then 0
  //      else Out[output_window_dims[i++]]
  repeated int64 output_window_dims = 1;
  repeated int64 elided_window_dims = 2;

  // This is interpreted as a map from i to gather_dims_to_operand_dims[i]. It
  // transforms the gather index looked up from the gather_indices tensor into
  // the starting index in the input space.
  repeated int64 gather_dims_to_operand_dims = 3;
}

// Operation requests that are all collected as a tagged union with a oneof
// field in OpRequest.

message ConstantRequest {
  LiteralProto literal = 2;
}

message GetTupleElementRequest {
  ComputationDataHandle operand = 2;
  int64 index = 3;
}

message SliceRequest {
  ComputationDataHandle operand = 2;
  repeated int64 start_indices = 3;
  repeated int64 limit_indices = 4;
  repeated int64 strides = 5;
}

message DynamicSliceRequest {
  // Operand from which to slice at dynamic 'start_indices'.
  ComputationDataHandle operand = 2;
  // Dynamically computed 'start_indices' for the slice operation.
  ComputationDataHandle start_indices = 3;
  // Slice sizes for each dimension (note that index calculations are computed
  // modulo dimension sizes to avoid out-of-bound array accesses).
  repeated int64 slice_sizes = 4;
}

message DynamicUpdateSliceRequest {
  // Operand on which the slice 'update' is to be applied.
  ComputationDataHandle operand = 2;
  // The slice update to apply to 'operand'.
  ComputationDataHandle update = 3;
  // Dynamically computed start indices for the update slice operation.
  ComputationDataHandle start_indices = 4;
}

message ConvolutionDimensionNumbers {
  // The number of the dimension that represents batch in the input.
  int64 input_batch_dimension = 7;

  // The number of the dimension that represents features in the input.
  int64 input_feature_dimension = 8;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the input.
  repeated int64 input_spatial_dimensions = 11;

  // The number of the dimension that represents input features in the
  // convolutional kernel (rhs).
  int64 kernel_input_feature_dimension = 3;

  // The number of the dimension that represents output features in
  // the convolutional kernel (rhs).
  int64 kernel_output_feature_dimension = 4;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the kernel (rhs). window.strides(0) is the
  // stride in the kernel_spatial_dimensions(0) dimension.
  repeated int64 kernel_spatial_dimensions = 6;

  // The number of the dimension that represents batch in the output.
  int64 output_batch_dimension = 9;

  // The number of the dimension that represents features in the output.
  int64 output_feature_dimension = 10;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the output.
  repeated int64 output_spatial_dimensions = 12;

  // Next = 13
};

message ConvolveRequest {
  ComputationDataHandle lhs = 2;
  ComputationDataHandle rhs = 3;  // This is the filter/kernel.
  Window window = 4;              // Describes the filter/kernel.
  ConvolutionDimensionNumbers dimension_numbers = 5;
}
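// For illustration only: a ConvolveRequest whose input uses the conventional
// NHWC layout (batch, height, width, feature) and whose kernel uses HWIO
// (height, width, input feature, output feature), producing an NHWC output,
// would set dimension numbers along the lines of:
//
//   input_batch_dimension: 0
//   input_feature_dimension: 3
//   input_spatial_dimensions: [1, 2]
//   kernel_spatial_dimensions: [0, 1]
//   kernel_input_feature_dimension: 2
//   kernel_output_feature_dimension: 3
//   output_batch_dimension: 0
//   output_feature_dimension: 3
//   output_spatial_dimensions: [1, 2]
//
// "NHWC" and "HWIO" are just conventional names for these layouts; the
// dimension numbers themselves are what the request carries.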
enum FftType {
  FFT = 0;    // Forward FFT; complex in, complex out.
  IFFT = 1;   // Inverse FFT; complex in, complex out.
  RFFT = 2;   // Forward real FFT; real in, fft_length / 2 + 1 complex out.
  IRFFT = 3;  // Inverse real FFT; fft_length / 2 + 1 complex in,
              // fft_length real out.
}

message FftRequest {
  FftType fft_type = 1;
  repeated int64 fft_length = 2;  // Multivalent for higher-order FFT.
  ComputationDataHandle operand = 3;
}

message InfeedRequest {
  // The shape of the data returned by reading the device's infeed buffer.
  Shape shape = 2;

  // Additional infeed configuration for the backend.
  bytes config = 3;
}

message OutfeedRequest {
  // The shape of the data returned by reading the device's outfeed buffer.
  Shape shape = 1;

  // Operand to the Outfeed. Supports tuple.
  ComputationDataHandle operand = 2;

  // Backend-specific information for how to perform the outfeed.
  bytes outfeed_config = 3;
}

message CallRequest {
  ComputationHandle to_apply = 2;
  repeated ComputationDataHandle operands = 3;
}

message CustomCallRequest {
  string call_target_name = 2;
  repeated ComputationDataHandle operands = 3;
  Shape shape = 4;
}

message HostComputeRequest {
  // Operand to the HostCompute. Supports tuple.
  repeated ComputationDataHandle operands = 1;

  // Name used to identify HostSend/Recv channels.
  string channel_name = 2;

  // Cost estimate in nanoseconds.
  int64 cost_estimate_ns = 3;

  // The shape of any data returned by the host.
  Shape shape = 4;
}

message DotDimensionNumbers {
  // The dimension numbers that represent the 'lhs' contracting dimensions.
  repeated int64 lhs_contracting_dimensions = 1;
  // The dimension numbers that represent the 'rhs' contracting dimensions.
  repeated int64 rhs_contracting_dimensions = 2;
  // The dimension numbers that represent the 'lhs' batch dimensions.
  repeated int64 lhs_batch_dimensions = 3;
  // The dimension numbers that represent the 'rhs' batch dimensions.
  repeated int64 rhs_batch_dimensions = 4;
};

message DotRequest {
  ComputationDataHandle lhs = 2;
  ComputationDataHandle rhs = 3;
  DotDimensionNumbers dimension_numbers = 4;
}
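// For illustration only: an ordinary matrix multiplication of an [m, k] lhs
// with a [k, n] rhs contracts lhs dimension 1 against rhs dimension 0:
//
//   lhs_contracting_dimensions: [1]
//   rhs_contracting_dimensions: [0]
//
// A batched matrix multiplication over [b, m, k] and [b, k, n] operands would
// additionally set lhs_batch_dimensions: [0] and rhs_batch_dimensions: [0],
// and contract lhs dimension 2 against rhs dimension 1.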
message MapRequest {
  repeated ComputationDataHandle operands = 2;
  ComputationHandle to_apply = 3;
  repeated ComputationDataHandle static_operands = 4;
  // The dimensions over which to map.
  // Example mapping a Dot operation along the batch dimension 0:
  //   operand0.shape = [2, 2, 2], operand1.shape = [2, 2, 3]
  //   Map({operand0, operand1}, Dot, {0})
  repeated int64 dimensions = 5;
}

message ReduceRequest {
  // Operand to the reduction.
  ComputationDataHandle operand = 2;

  // Initial value for the reduction. This must be consistent with the result
  // shape of to_apply.
  ComputationDataHandle init_value = 3;

  // The dimensions to reduce over.
  repeated int64 dimensions = 4;

  // The computation to apply in the reduction.
  ComputationHandle to_apply = 5;
}

message ReduceWindowRequest {
  ComputationDataHandle operand = 2;
  ComputationDataHandle init_value = 3;
  Window window = 4;
  ComputationHandle to_apply = 5;
}

message BatchNormTrainingRequest {
  ComputationDataHandle operand = 1;
  ComputationDataHandle scale = 2;
  ComputationDataHandle offset = 3;
  float epsilon = 4;
  int64 feature_index = 5;
}

message BatchNormInferenceRequest {
  ComputationDataHandle operand = 1;
  ComputationDataHandle scale = 2;
  ComputationDataHandle offset = 3;
  ComputationDataHandle mean = 4;
  ComputationDataHandle variance = 5;
  float epsilon = 6;
  int64 feature_index = 7;
}

message BatchNormGradRequest {
  ComputationDataHandle operand = 1;
  ComputationDataHandle scale = 2;
  ComputationDataHandle mean = 3;
  ComputationDataHandle variance = 4;
  ComputationDataHandle grad_output = 5;
  float epsilon = 6;
  int64 feature_index = 7;
}

message CrossReplicaSumRequest {
  ComputationDataHandle operand = 2;
}

message SelectAndScatterRequest {
  // Operand array on which the windows slide.
  ComputationDataHandle operand = 2;

  // Source array for the data to scatter.
  ComputationDataHandle source = 3;

  // Initial scalar value for each element in the output.
  ComputationDataHandle init_value = 4;

  // Window configuration.
  Window window = 5;

  // Binary function used to select an element from each window.
  ComputationHandle select = 6;

  // Binary function used to combine each scattered value from source with the
  // current output value at the selected location.
  ComputationHandle scatter = 7;
}

message ReverseRequest {
  ComputationDataHandle operand = 2;
  repeated int64 dimensions = 3;
}

message BroadcastRequest {
  ComputationDataHandle operand = 2;
  repeated int64 broadcast_sizes = 3;
}

message PadRequest {
  ComputationDataHandle operand = 2;
  ComputationDataHandle padding_value = 3;
  PaddingConfig padding_config = 4;
}

message ReshapeRequest {
  ComputationDataHandle operand = 2;

  // The dimension order for collapse (from fastest-changing to slowest).
  repeated int64 dimensions = 3;

  // The new dimension sizes (from dimension 0 to n-1).
  repeated int64 new_sizes = 4;
}

message TransposeRequest {
  ComputationDataHandle operand = 2;

  // The permutation of the operand's dimensions (in the range 0 to n-1).
  repeated int64 dimensions = 3;
}

message ParameterRequest {
  Shape shape = 2;
  int64 parameter = 3;
  string name = 4;
}

message GetLocalShapeRequest {
  ComputationHandle computation = 1;
  ComputationDataHandle operand = 2;
}

message GetLocalShapeResponse {
  Shape shape = 1;
}

message TraceRequest {
  string tag = 2;
  ComputationDataHandle operand = 3;
}

message ConvertRequest {
  ComputationDataHandle operand = 2;
  PrimitiveType new_element_type = 3;
}

message ConcatenateRequest {
  repeated ComputationDataHandle operands = 2;
  // The dimension in which we concatenate; e.g. if you had dimension arrays of
  // [4, 1] and [5, 1], you'd concatenate in dimension 0 to produce a [9, 1].
  // Attempting to concatenate those in dimension 1 would produce an error, as
  // 4 != 5 (and there is no ragged array support).
  int64 dimension = 3;
}

message ConditionalRequest {
  ComputationDataHandle predicate = 2;
  ComputationDataHandle true_operand = 3;
  ComputationHandle true_computation = 4;
  ComputationDataHandle false_operand = 5;
  ComputationHandle false_computation = 6;
}

message WhileRequest {
  ComputationHandle condition = 2;
  ComputationHandle body = 3;
  ComputationDataHandle init = 4;
}
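// For illustration only (a sketch of the intended usage), a loop that counts
// from 0 up to 10 could be expressed with:
//   - init:      a ComputationDataHandle for the s32[] constant 0
//   - condition: a computation (s32[]) -> pred[] that returns x < 10
//   - body:      a computation (s32[]) -> s32[] that returns x + 1
// The condition and body each take the loop state as their single parameter,
// and the body returns a value of that same loop-state shape.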
enum UnaryOperation {
  UNOP_INVALID = 0;

  // Elementwise, logical negation on booleans and bitwise negation on ints.
  UNOP_NOT = 1;

  // Elementwise, computes e^x.
  UNOP_EXP = 2;

  // Elementwise, computes -x.
  UNOP_NEGATE = 3;

  // Puts the elements in the operand into sorted order.
  UNOP_SORT = 4;

  // Elementwise, computes tanh(x).
  UNOP_TANH = 5;

  // Elementwise, computes the natural logarithm of x.
  UNOP_LOG = 6;

  // Elementwise, computes the floor of x.
  UNOP_FLOOR = 7;

  // Elementwise, computes the ceil of x.
  UNOP_CEIL = 8;

  // Elementwise, computes the abs of x.
  UNOP_ABS = 9;

  // Elementwise, computes the sign of x.
  UNOP_SIGN = 10;

  // Elementwise, tests if values are finite (not NaN or inf).
  UNOP_IS_FINITE = 11;

  // Elementwise, computes the cosine of x.
  UNOP_COS = 12;

  // Elementwise, computes the sine of x.
  UNOP_SIN = 13;

  // Elementwise, rounds x to nearest integral value, rounding half-way cases
  // away from zero.
  UNOP_ROUND_NEAREST_AFZ = 14;

  // Elementwise, extracts the real component of complex x.
  UNOP_REAL = 15;

  // Elementwise, extracts the imaginary component of complex x.
  UNOP_IMAG = 16;
}

message UnaryOpRequest {
  UnaryOperation unop = 2;
  ComputationDataHandle operand = 3;
}

enum BinaryOperation {
  BINOP_INVALID = 0;

  // Arithmetic operations.
  BINOP_ADD = 1;
  BINOP_DIV = 2;
  BINOP_MUL = 3;
  BINOP_SUB = 4;

  // Comparison operators.
  BINOP_EQ = 5;
  BINOP_GE = 6;
  BINOP_GT = 7;
  BINOP_LE = 8;
  BINOP_LT = 9;
  BINOP_NE = 10;

  // Element-wise maximum.
  BINOP_MAX = 14;

  // Element-wise minimum.
  BINOP_MIN = 15;

  // Raises the left-hand side to the right-hand-side power.
  BINOP_POW = 16;

  // Remainder operation.
  BINOP_REM = 17;

  // Element-wise, logical operators on booleans and bitwise operators on ints.
  BINOP_AND = 18;
  BINOP_OR = 19;

  BINOP_SHIFT_LEFT = 20;
  BINOP_SHIFT_RIGHT_ARITHMETIC = 21;
  BINOP_SHIFT_RIGHT_LOGICAL = 22;

  // Complex from real, imag.
  BINOP_COMPLEX = 23;

  // Computes the 4-quadrant arctangent of the y, x input arguments.
  BINOP_ATAN2 = 24;
}

message BinaryOpRequest {
  BinaryOperation binop = 2;
  ComputationDataHandle lhs = 3;
  ComputationDataHandle rhs = 4;
  repeated int64 broadcast_dimensions = 5;
}

enum RandomDistribution {
  RNG_INVALID = 0;

  // Creates a uniform-distribution-generated random number on the semi-open
  // interval [parameter[0], parameter[1]).
  RNG_UNIFORM = 1;

  // Creates a normal-distribution-generated random number with mean
  // parameter[0] and standard deviation parameter[1].
  RNG_NORMAL = 2;

  // Next: 4
}

message RngRequest {
  RandomDistribution distribution = 2;
  repeated ComputationDataHandle parameter = 3;
  Shape shape = 4;
}

enum TernaryOperation {
  TRIOP_INVALID = 0;

  // Given a predicate and two operands, selects operand0 if the predicate is
  // true and operand1 if the predicate is false.
  TRIOP_SELECT = 1;

  // Given min, max, and an operand, returns the operand if it is between min
  // and max; otherwise returns min if the operand is less than min, or max if
  // the operand is greater than max.
  TRIOP_CLAMP = 3;
}

message TernaryOpRequest {
  TernaryOperation triop = 2;
  ComputationDataHandle lhs = 3;
  ComputationDataHandle rhs = 4;
  ComputationDataHandle ehs = 5;
}

enum VariadicOperation {
  VAROP_INVALID = 0;

  // Creates a tuple from its operands.
  VAROP_TUPLE = 1;
}

message VariadicOpRequest {
  VariadicOperation varop = 2;
  repeated ComputationDataHandle operands = 3;
}

message ReducePrecisionRequest {
  ComputationDataHandle operand = 1;
  int32 exponent_bits = 2;
  int32 mantissa_bits = 3;
}

message SendRequest {
  ComputationDataHandle operand = 1;
  ChannelHandle channel_handle = 2;
}

message RecvRequest {
  Shape shape = 1;
  ChannelHandle channel_handle = 2;
}

message GatherRequest {
  ComputationDataHandle input = 1;
  ComputationDataHandle gather_indices = 2;
  GatherDimensionNumbers dimension_numbers = 3;
  repeated int64 window_bounds = 4;
}

message OpSharding {
  enum Type {
    // This sharding is replicated across all devices (implies maximal;
    // all other fields are unused).
    REPLICATED = 0;
    // This sharding is maximal - one device runs the entire operation.
    MAXIMAL = 1;
    // This sharding is a tuple - only the tuple_shardings field is valid.
    TUPLE = 2;
    // None of the above; tile_shape and tile_assignment are both used.
    OTHER = 3;
  }
  Type type = 1;
  // The shape of the sharded tile.
  Shape tile_shape = 2;
  // The shape of the tile assignment tensor - this must have the same rank as
  // tile_shape and the product of its dimensions must equal
  // tile_assignment_devices.size().
  repeated int64 tile_assignment_dimensions = 3;
  // Flattened list of device IDs. The order of flattening is the same as used
  // by IndexUtil::MultiToLinearIndex(tile_assignment_shape).
  repeated int64 tile_assignment_devices = 4;
  // If type == TUPLE, the sub-shardings, one per leaf node in the tuple shape,
  // in pre-order. The tuple shape could be nested; here we store just a
  // flattened list of all leaves in the tuple shape. Note that the tuple shape
  // is not stored here; shardings do not store the shapes to which they are
  // applied, and this is inferred from the instruction to which this sharding
  // gets attached.
  repeated OpSharding tuple_shardings = 5;
}
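// For illustration only (layout omitted for brevity), an f32[8] array split
// into two f32[4] tiles, one per device, could be described as:
//
//   type: OTHER
//   tile_shape { element_type: F32 dimensions: [4] }
//   tile_assignment_dimensions: [2]
//   tile_assignment_devices: [0, 1]
//
// i.e. device 0 holds the first tile (elements [0, 4)) and device 1 holds the
// second tile (elements [4, 8)).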
message OpRequest {
  ComputationHandle computation = 1;
  OpMetadata metadata = 33;
  OpSharding sharding = 40;

  oneof op {
    BinaryOpRequest binary_op_request = 2;
    BroadcastRequest broadcast_request = 3;
    CallRequest call_request = 4;
    ConcatenateRequest concatenate_request = 5;
    ConstantRequest constant_request = 6;
    ConvertRequest convert_request = 7;
    ConvolveRequest convolve_request = 8;
    CrossReplicaSumRequest cross_replica_sum_request = 9;
    CustomCallRequest custom_call_request = 10;
    DotRequest dot_request = 43;
    DynamicSliceRequest dynamic_slice_request = 11;
    DynamicUpdateSliceRequest dynamic_update_slice_request = 12;
    GetTupleElementRequest get_tuple_element_request = 13;
    InfeedRequest infeed_request = 14;
    MapRequest map_request = 15;
    PadRequest pad_request = 16;
    ParameterRequest parameter_request = 17;
    ReducePrecisionRequest reduce_precision_request = 36;
    ReduceRequest reduce_request = 18;
    ReduceWindowRequest reduce_window_request = 19;
    ReshapeRequest reshape_request = 20;
    ReverseRequest reverse_request = 21;
    RngRequest rng_request = 22;
    SelectAndScatterRequest select_and_scatter_request = 23;
    SliceRequest slice_request = 24;
    TernaryOpRequest ternary_op_request = 25;
    TraceRequest trace_request = 26;
    TransposeRequest transpose_request = 34;
    UnaryOpRequest unary_op_request = 27;
    VariadicOpRequest variadic_op_request = 28;
    WhileRequest while_request = 29;
    SendRequest send_request = 30;
    RecvRequest recv_request = 31;
    OutfeedRequest outfeed_request = 32;
    BatchNormTrainingRequest batch_norm_training_request = 35;
    BatchNormGradRequest batch_norm_grad_request = 37;
    BatchNormInferenceRequest batch_norm_inference_request = 38;
    FftRequest fft_request = 41;
    ConvertRequest bitcast_convert_request = 42;
    ConditionalRequest conditional_request = 44;
    HostComputeRequest host_compute_request = 45;
    GatherRequest gather_request = 46;
    // Next: 47
  }
}

message OpResponse {
  ComputationDataHandle output = 1;
}