/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This proto file defines messages which represent the HLO module. This is a
// full-fidelity serialization of the C++ HLO constructs.
//
// Many of the protos below are simple 1-to-1 serializations of the
// corresponding C++ classes, e.g., HloModule, HloComputation, and
// HloInstruction.
//
// FIELD NAMES ARE IMPORTANT
//
// Unlike most protos, you can't safely change the names of fields, even if you
// keep the numeric ids the same. This is because we sometimes serialize these
// protos as JSON, which includes the field names in the serialization.

syntax = "proto3";

package xla;

import "tensorflow/compiler/xla/xla_data.proto";

option cc_enable_arenas = true;

enum CustomCallSchedule {
  SCHEDULE_NONE = 0;
  SCHEDULE_LATEST = 1;
  SCHEDULE_EARLIEST = 2;
}

// The version of the API used by the custom call function. The signatures for
// each version are given below.
// TODO(b/189822916): Remove this enum when all clients are migrated to the
// status-returning API.
enum CustomCallApiVersion {
  API_VERSION_UNSPECIFIED = 0;

  // The first version of the API, with the following signatures:
  //
  // CPU:
  //   void do_custom_call(void* out, const void** in);
  //
  // GPU:
  //   void do_custom_call(CUstream stream, void** buffers,
  //                       const char* opaque, size_t opaque_len);
  API_VERSION_ORIGINAL = 1;

  // When the ability to return success/failure status was added:
  //
  // CPU: Unimplemented
  //
  // GPU:
  //   void do_custom_call(CUstream stream, void** buffers,
  //                       const char* opaque, size_t opaque_len,
  //                       XlaCustomCallStatus* status);
  API_VERSION_STATUS_RETURNING = 2;
}
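
// As an illustration (a minimal sketch, not part of this schema), a
// status-returning GPU custom call might be implemented as follows, assuming
// the XlaCustomCallStatusSetFailure helper from custom_call_status.h; the
// function body and failure message here are placeholders:
//
//   void do_custom_call(CUstream stream, void** buffers,
//                       const char* opaque, size_t opaque_len,
//                       XlaCustomCallStatus* status) {
//     if (opaque == nullptr) {
//       // Report failure to the runtime instead of crashing.
//       XlaCustomCallStatusSetFailure(status, "null opaque", 11);
//       return;
//     }
//     // ... launch kernels on `stream` that read/write `buffers` ...
//   }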

// Serialization of HloInstruction.
// Next ID: 78
message HloInstructionProto {
  reserved 10;
  reserved "parameter_name";
  reserved 12;
  reserved "fused_instructions_computation";
  reserved 4;
  reserved "operand_names";
  reserved 5;
  reserved "control_predecessor_names";
  reserved 6;
  reserved "called_computation_names";
  reserved 44;
  reserved "replica_group_ids";
  // Use backend_config instead for custom_call_opaque.
  reserved 53;
  reserved "custom_call_opaque";
  // Use backend_config instead for all_reduce_barrier.
  reserved 46;
  reserved "all_reduce_barrier";

  string name = 1;
  string opcode = 2;
  xla.ShapeProto shape = 3;

  xla.OpMetadata metadata = 7;

  // Literal, only present for kConstant.
  xla.LiteralProto literal = 8;

  // Parameter number is only present for kParameter.
  int64 parameter_number = 9;

  // Fusion state, only present for kFusion.
  string fusion_kind = 11;

  // Index for kGetTupleElement.
  int64 tuple_index = 13;

  // Dimensions present for some operations that require reshaping or
  // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse.
  repeated int64 dimensions = 14;

  // Describes the window in a windowed operation such as convolution.
  xla.Window window = 15;

  // Describes the dimension numbers used for a convolution.
  xla.ConvolutionDimensionNumbers convolution_dimension_numbers = 16;

  // The number of feature groups. Used for a convolution. Must be a divisor
  // of the input feature dimension and output feature dimension. If not
  // specified, it will use a default value of 1.
  int64 feature_group_count = 50;

  int64 batch_group_count = 58;

  // Describes the [begin, end) index range and stride for slices.
  message SliceDimensions {
    int64 start = 1;
    int64 limit = 2;
    int64 stride = 3;
  }
  repeated SliceDimensions slice_dimensions = 17;
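
  // As a worked illustration (not part of the schema): slicing a f32[10]
  // operand with start=2, limit=8, stride=2 selects indices {2, 4, 6} and
  // produces a f32[3] result.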

  // The bit sizes for a reduce-precision operation.
  int32 exponent_bits = 18;
  int32 mantissa_bits = 19;

  // Describes the [start, start + size) range size for a dynamic slice
  // ('start' is specified dynamically in the second operand of the
  // operation).
  repeated int64 dynamic_slice_sizes = 20;

  // The padding configuration that describes the edge padding and interior
  // padding of this pad instruction. Only set for pad instructions.
  xla.PaddingConfig padding_config = 21;

  // Outfeed configuration information, only present for kOutfeed.
  bytes outfeed_config = 22;

  // The distribution requested for random number generation.
  // Only present for kRng.
  xla.RandomDistribution distribution = 23;

  // A small float number added to the variance to avoid divide-by-zero error.
  // Only present for kBatchNormTraining.
  float epsilon = 24;

  // An integer value representing the index of the feature dimension.
  // Only present for kBatchNormTraining.
  int64 feature_index = 25;

  // Represents a unique identifier for each Send/Recv instruction pair or
  // optionally for collective instructions (AllReduce, CollectivePermute,
  // AllToAll). A non-positive channel_id is equivalent to no channel id.
  int64 channel_id = 26;

  // The string representation of the infeed configuration.
  bytes infeed_config = 27;

  // Name of an external target (e.g., a global symbol) to call; only present
  // for kCustomCall.
  string custom_call_target = 28;

  // Shape of outfeed request.
  xla.ShapeProto outfeed_shape = 29;

  // Describes the dimension numbers used for a dot operation.
  xla.DotDimensionNumbers dot_dimension_numbers = 30;

  // FFT type (FFT, IFFT, etc).
  xla.FftType fft_type = 31;

  // FFT length.
  repeated int64 fft_length = 32;

  // Comparison direction, only used for kCompare.
  string comparison_direction = 63;

  // Gather dimension numbers.
  xla.GatherDimensionNumbers gather_dimension_numbers = 33;
  repeated int64 gather_slice_sizes = 34;

  // Compute Host.
  string channel_name = 41;
  int64 cost_estimate_ns = 42;

  // The id of this instruction.
  int64 id = 35;

  repeated int64 operand_ids = 36;
  repeated int64 control_predecessor_ids = 37;
  repeated int64 called_computation_ids = 38;

  xla.OpSharding sharding = 40;

  // Backend configuration for the instruction. Has backend-specific meaning.
  bytes backend_config = 43;

  // Cross replica op fields.
  repeated ReplicaGroup replica_groups = 49;
  // Deprecated, but kept for backward compatibility. Use channel_id instead.
  // A non-positive all_reduce_id is equivalent to no all_reduce_id.
  int64 all_reduce_id = 45 [deprecated = true];

  // If true, interprets ids in ReplicaGroup as global device ids, which is
  // a linearized id of `replica_id * partition_count + partition_id`.
  bool use_global_device_ids = 71;
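
  // As a worked illustration of the linearization above: with
  // partition_count=2, the devices (replica, partition) = (0,0), (0,1),
  // (1,0), (1,1) map to global ids 0, 1, 2, 3, e.g. (1,0) -> 1 * 2 + 0 = 2.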

  // Whether this Send/Recv instruction transfers data to/from the host. Only
  // present for Send and Recv instructions and their SendDone and RecvDone
  // partners.
  bool is_host_transfer = 47;

  // Whether this Sort instruction should be stable.
  bool is_stable = 60;

  xla.ScatterDimensionNumbers scatter_dimension_numbers = 48;

  // Precision configuration for the instruction. Has backend-specific
  // meaning.
  xla.PrecisionConfig precision_config = 51;

  // Collective permute field.
  repeated SourceTarget source_target_pairs = 52;

  // Sharding for kDomain instructions.
  xla.OpSharding domain_entry_sharding = 54;
  xla.OpSharding domain_exit_sharding = 55;

  // For custom call this indicates that the layouts are constrained. If
  // constrain_layout is true then the 'shape' field must contain a layout,
  // and 'operand_shapes_with_layout' must contain a shape with layout for
  // each operand.
  bool constrain_layout = 56;
  repeated xla.ShapeProto operand_shapes_with_layout = 57;

  // Options for TriangularSolve.
  xla.TriangularSolveOptions triangular_solve_options = 59;

  // Options for Cholesky.
  xla.CholeskyOptions cholesky_options = 62;

  // Describes how parameters behave with regard to replicas.
  xla.ParameterReplication parameter_replication = 61;

  // If set, the given instruction is run in parallel on e.g. multiple CPU
  // cores. The outermost dimension gets split up into
  // outer_dimension_partitions[0] pieces, the next-outermost dim gets split
  // into outer_dimension_partitions[1] pieces, etc.
  //
  // It's illegal to partition a dimension into more shards than there are
  // elements in that dimension.
  repeated int64 outer_dimension_partitions = 64;

  // Whether the kCustomCall instruction has side effects, only present for
  // kCustomCall.
  bool custom_call_has_side_effect = 65;

  // A list of CustomCallOutputOperandAliasing pairs that specifies aliasing
  // buffers between output and operands for kCustomCall.
  repeated xla.CustomCallOutputOperandAliasing
      custom_call_output_operand_aliasing = 74;

  // Specifies the desired schedule for the custom-call. The field is only
  // present for custom-call.
  CustomCallSchedule custom_call_schedule = 76;

  // The delta value for kRngGetAndUpdateState.
  int64 delta = 66;

  // Specifies if the gather/scatter indices are guaranteed to be sorted by
  // the caller.
  bool indices_are_sorted = 67;

  // Frontend attributes to pass to the XLA backend.
  xla.FrontendAttributes frontend_attributes = 68;

  // Specifies if all elements updated are guaranteed to be unique by
  // the caller.
  bool unique_indices = 69;

  // RNG algorithm used by kRngBitGenerator.
  xla.RandomAlgorithm rng_algorithm = 70;

  // The comparison type used for kCompare.
  string comparison_type = 72;

  // Specifies if this is a cross-program prefetch; used by kCopyStart.
  bool is_cross_program_prefetch = 73;

  // If a convolution is dynamic, a dynamic padding type will be specified.
  xla.PaddingType padding_type = 75;

  // The API version used by the custom call function. This field is only
  // present for custom-call.
  // TODO(b/189822916): Remove this field when all clients are migrated to the
  // status-returning API.
  CustomCallApiVersion custom_call_api_version = 77;
}

// Serialization of HloComputation.
message HloComputationProto {
  reserved 3;
  reserved "root_name";

  string name = 1;

  // The array of instructions is always in a valid dependency order, where
  // operands appear before their users.
  repeated HloInstructionProto instructions = 2;

  // The program shape (with layout) of this computation.
  xla.ProgramShapeProto program_shape = 4;

  // The id of this computation.
  int64 id = 5;

  // The id of the root of the computation.
  int64 root_id = 6;
}

// Serialization of an HLO schedule. An HLO schedule contains a total order of
// instructions for each non-fusion computation in the module.
message HloScheduleProto {
  message InstructionSequence {
    repeated int64 instruction_ids = 1;
  }

  // Map from computation id to sequence.
  map<int64, InstructionSequence> sequences = 1;
}
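
// As an illustration (a text-format sketch with made-up ids), a module whose
// computation with id 0 should run its instructions in the order 2, 3, 5
// would carry a schedule like:
//
//   sequences {
//     key: 0
//     value { instruction_ids: [2, 3, 5] }
//   }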

enum Kind {
  // Define an UNDEFINED_ALIAS equal to zero to get around the default-0
  // proto3 behavior and missing has_*() APIs.
  UNDEFINED_ALIAS = 0;
  // The buffers may or may not alias at runtime.
  MAY_ALIAS = 1;
  // The buffers must alias at runtime.
  MUST_ALIAS = 2;
}

message HloInputOutputAliasProto {
  // The following proto describes an aliased pair of an input (described by
  // parameter number and a ShapeIndex of the parameter) and an output
  // (described by a ShapeIndex of the root instruction). For example:
  //
  // entry = {
  //   output_shape_index={1},
  //   parameter_number=0,
  //   parameter_shape_index={1, 2},
  // }
  //
  // This entry indicates that the first parameter's {1, 2} element is
  // aliased with the {1} element of the root instruction.
  message AliasEntryProto {
    // ShapeIndex of the root HLO.
    repeated int64 output_shape_index = 1;
    // Number of the parameter in the entry computation.
    int64 parameter_number = 2;
    // ShapeIndex of the parameter instruction.
    repeated int64 parameter_shape_index = 3;
    // The kind of alias to be set up.
    Kind kind = 4;
  }

  repeated AliasEntryProto entries = 1;
}
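
// In proto text format, the example entry documented above would read as
// follows (the alias kind here is chosen arbitrarily for illustration):
//
//   entries {
//     output_shape_index: [1]
//     parameter_number: 0
//     parameter_shape_index: [1, 2]
//     kind: MAY_ALIAS
//   }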

message DynamicParameterBindingProto {
  // A list of bindings which indicates that the dimension
  // `target_param_dim_num` in the subshape `target_param_index` of parameter
  // `target_param_num` is a dynamic dimension and its real dynamic size is
  // represented by `dynamic_param_index` in parameter `dynamic_param_num`.
  //
  // As an example, imagine we have a program:
  //
  // ENTRY main {
  //   a = f32[] parameter(0)
  //   b = f32[10] parameter(1)
  //   ROOT root = (f32[], f32[10]) tuple(%a, %b)
  // }
  //
  // Let's say 'b' (param index 1) is a dynamic shape whose input has an
  // upper bound of 10 and whose real size is determined at runtime, and 'a'
  // represents the real size of b's first dimension.
  //
  // In this case, the fields are set in the following way:
  // dynamic_param_num = 0
  // dynamic_param_index = {}
  // target_param_num = 1
  // target_param_index = {}
  // target_param_dim_num = 0
  message Binding {
    int64 dynamic_param_num = 1;
    repeated int64 dynamic_param_index = 2;
    int64 target_param_num = 3;
    repeated int64 target_param_index = 4;
    int64 target_param_dim_num = 5;
  }

  repeated Binding entries = 1;
}

message CrossProgramPrefetch {
  int64 parameter = 1;
  repeated int64 index = 2;
}

// Serialization of HloModule.
message HloModuleProto {
  string name = 1;
  string entry_computation_name = 2;
  int64 entry_computation_id = 6;

  // The array of computations is always in a valid dependency order, where
  // callees appear before their callers.
  repeated HloComputationProto computations = 3;

  // The host program shape (with layout) of the entry computation.
  xla.ProgramShapeProto host_program_shape = 4;

  // The id of this module.
  int64 id = 5;

  // The schedule for this module.
  HloScheduleProto schedule = 7;

  // Describes alias information between inputs and outputs.
  HloInputOutputAliasProto input_output_alias = 8;

  DynamicParameterBindingProto dynamic_parameter_binding = 9;

  repeated CrossProgramPrefetch cross_program_prefetches = 10;

  // True if the module contains dynamic computation.
  bool is_dynamic = 11;
}

// Serialization of LogicalBuffer.
message LogicalBufferProto {
  // Location represents an instruction and its shape index, which uniquely
  // identifies a point where a buffer is needed.
  message Location {
    // NOTE: module_name isn't necessary, since all LogicalBuffers are
    // associated with a single HloModule.
    string computation_name = 1;
    string instruction_name = 2;
    repeated int64 shape_index = 3;
  }

  int64 id = 1;
  int64 size = 2;

  // The location where the buffer is defined.
  Location defined_at = 3;

  int64 color = 4;
}

// Serialization of BufferAllocation.
message BufferAllocationProto {
  // Assigned represents a single LogicalBuffer that is assigned to this
  // BufferAllocation.
  message Assigned {
    int64 logical_buffer_id = 1;
    int64 offset = 2;
    int64 size = 3;
  }

  int64 index = 1;
  int64 size = 2;
  bool is_thread_local = 3;
  bool is_tuple = 11;
  bool is_entry_computation_parameter = 5;
  bool is_constant = 12;
  int64 parameter_number = 6;
  repeated int64 parameter_shape_index = 10;
  bool maybe_live_out = 7;
  int64 color = 8;
  repeated Assigned assigned = 9;
}

// A trace of a HeapSimulator run.
message HeapSimulatorTrace {
  // The trace includes a list of events, where each event describes one
  // action performed by the heap simulator.
  message Event {
    enum Kind {
      ALLOC = 0;  // A memory region was allocated for the buffer.
      FREE = 1;   // A memory region was freed for the buffer.

      // A buffer was shared with another (canonical) buffer. This is similar
      // to ALLOC, except that instead of allocating a new region of memory,
      // the memory region of the canonical buffer is directly re-used.
      // Multiple buffers may share with the same canonical buffer. The
      // lifetime of the canonical buffer is extended to the union of all
      // lifetimes.
      SHARE_WITH = 2;
    }
    Kind kind = 1;

    // The id of the LogicalBuffer that the event applies to.
    int64 buffer_id = 2;

    // The HloInstruction that the simulation was processing that caused this
    // event to occur, identified by its computation and instruction name.
    // E.g. buffers defined by instruction A are allocated when processing A.
    string computation_name = 3;
    string instruction_name = 4;

    // The id of the canonical LogicalBuffer that the buffer shares with.
    // Only set for SHARE_WITH events.
    int64 share_with_canonical_id = 5;
  }
  repeated Event events = 1;
  bool whole_module_simulation = 2;
  int64 buffer_allocation_index = 3;
}
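
// As an illustration (a text-format sketch with made-up ids, omitting the
// computation/instruction names), a trace in which buffer 1 is allocated,
// buffer 2 then re-uses buffer 1's memory, and the region is finally freed
// might look like:
//
//   events { kind: ALLOC buffer_id: 1 }
//   events { kind: SHARE_WITH buffer_id: 2 share_with_canonical_id: 1 }
//   events { kind: FREE buffer_id: 1 }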

// An abstraction representing a set of HLO modules built to run concurrently
// across different devices.
message HloModuleGroupProto {
  string name = 1;
  repeated HloModuleProto hlo_modules = 2;
}

// Serialization of BufferAssignment.
message BufferAssignmentProto {
  // Alias represents a source LogicalBuffer, and the buffer location that
  // aliases it.
  message BufferAlias {
    int64 source_buffer_id = 1;
    LogicalBufferProto.Location location = 2;
  }

  repeated LogicalBufferProto logical_buffers = 1;
  repeated BufferAlias buffer_aliases = 2;
  repeated BufferAllocationProto buffer_allocations = 3;
  repeated HeapSimulatorTrace heap_simulator_traces = 4;
}

// Grouping message that contains all of the information above.
message HloProto {
  reserved 2;
  reserved "hlo_ordering";

  HloModuleProto hlo_module = 1;
  BufferAssignmentProto buffer_assignment = 3;
}

// Encapsulates HloProto together with the arguments, result, and
// execution_platform. This message is used for purposes such as
// analysis/replay/file-storage.
message HloSnapshot {
  // The HLO graph.
  HloProto hlo = 1;

  // The arguments passed to the graph.
  repeated LiteralProto arguments = 2;

  // The result of the graph.
  LiteralProto result = 3;

  // The name of the platform used to run the graph.
  string execution_platform = 4;
}

// Metadata for an HLO module. Dumped after HLO passes and before LLO lowering
// with filename module_####.metadata.textproto, where #### is
// canonical_module_id.
message HloModuleMetadataProto {
  // Uniquely identifies an HloModuleMetadata. Equal to the first unique_id
  // of the module (a module may go through multiple unique_ids). If a module
  // is partitioned into multiple modules, those modules will each have a new
  // HloModuleMetadata with a different canonical_module_id.
  int64 canonical_module_id = 1;

  // Name of the module group that the module is part of.
  string module_group_name = 2;

  // The canonical module id of the module that this one is partitioned from,
  // if applicable.
  int64 original_module_id = 3;

  // The canonical module ids of the modules that this one is partitioned
  // into, if applicable.
  repeated int64 partitioned_module_ids = 4;

  // Metadata for the HLO passes that are run on the module.
  repeated HloPassMetadata pass_metadata = 5;
}

// Metadata for one run of an HLO pass on a module. Provides more information
// about the order of HLO passes and various other stats, like duration, when
// processing debug dumps of HloProtos. `pass_id` may also be used to identify
// a particular run of a pass in debug info that propagates through stages of
// compilation.
message HloPassMetadata {
  // For a given module, pass_id uniquely identifies a run of an HLO pass on
  // that module. Note that a pass_id may not always refer to the same pass
  // because the order of passes during compilation may change. For finding
  // metadata for a particular pass, pass_name and pipeline_name would be more
  // reliable, although note that they may not be unique.
  int64 pass_id = 1;
  string pass_name = 2;
  string pipeline_name = 3;

  // Filenames of the dumps of the module after this pass ran. The module may
  // be dumped in multiple formats, and the order of formats in this field
  // will stay consistent across passes.
  repeated string dump_filenames = 4;

  // Return value of pass.Run(). True if this pass changed the module, or, in
  // the case where the module was run through this pass as part of a module
  // group, true if this pass changed any module in the same module group.
  bool module_changed = 5;

  // The unique_id of the module that this pass is run on. May be different
  // from the canonical_module_id of the HloModuleMetadata that this
  // HloPassMetadata is inside.
  int64 module_id = 6;

  // If the module went through this pass as part of a module group, this is
  // set as the ids of all the modules in the module group. Empty otherwise.
  repeated int64 module_group_module_ids = 7;

  // Timestamps before and after the pass is run. Note they may be equal.
  int64 start_timestamp_usec = 8;
  int64 end_timestamp_usec = 9;
}
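
// As an illustration (a text-format sketch; every value below is a
// placeholder), one HloPassMetadata entry might serialize as:
//
//   pass_id: 1
//   pass_name: "some-pass"
//   pipeline_name: "some-pipeline"
//   module_changed: true
//   module_id: 42
//   start_timestamp_usec: 1000
//   end_timestamp_usec: 1005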