/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This proto file defines messages which represent the HLO module. This is a
// full fidelity serialization of the c++ HLO constructs.
//
// Many of the protos below are simple 1-to-1 serializations of the
// corresponding C++ classes, e.g., HloModule, HloComputation, and
// HloInstruction.
//
// FIELD NAMES ARE IMPORTANT
//
// Unlike most protos, you can't safely change the names of fields, even if you
// keep the numeric ids the same. This is because we sometimes serialize these
// protos as JSON, which includes the field names in the serialization.

syntax = "proto3";

package xla;

import "tensorflow/compiler/xla/xla_data.proto";

option cc_enable_arenas = true;

// Serialization of HloInstruction.
// Next ID: 76
message HloInstructionProto {
  reserved 10;
  reserved "parameter_name";
  reserved 12;
  reserved "fused_instructions_computation";
  reserved 4;
  reserved "operand_names";
  reserved 5;
  reserved "control_predecessor_names";
  reserved 6;
  reserved "called_computation_names";
  reserved 44;
  reserved "replica_group_ids";
  // Use backend_config instead for custom_call_opaque.
  reserved 53;
  reserved "custom_call_opaque";
  // Use backend_config instead for all_reduce_barrier.
  reserved 46;
  reserved "all_reduce_barrier";

  string name = 1;
  string opcode = 2;
  xla.ShapeProto shape = 3;

  xla.OpMetadata metadata = 7;

  // Literal, only present for kConstant.
  xla.LiteralProto literal = 8;

  // Parameter number is only present for kParameter.
  int64 parameter_number = 9;

  // Fusion state, only present for kFusion.
  string fusion_kind = 11;

  // Index for kGetTupleElement.
  int64 tuple_index = 13;

  // Dimensions present for some operations that require reshaping or
  // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse.
  repeated int64 dimensions = 14;

  // Describes the window in a windowed operation such as convolution.
  xla.Window window = 15;

  // Describes the dimension numbers used for a convolution.
  xla.ConvolutionDimensionNumbers convolution_dimension_numbers = 16;

  // The number of feature groups. Used for a convolution. Must be a divisor of
  // the input feature dimension and output feature dimension. If not specified,
  // it will use a default value of 1.
  int64 feature_group_count = 50;

  int64 batch_group_count = 58;

  // Describes the [begin, end) index range and stride for slices.
  message SliceDimensions {
    int64 start = 1;
    int64 limit = 2;
    int64 stride = 3;
  }
  repeated SliceDimensions slice_dimensions = 17;

  // The bit sizes for a reduce-precision operation.
  int32 exponent_bits = 18;
  int32 mantissa_bits = 19;

  // Describes the [start, start + size) range size for a dynamic slice
  // ('start' is specified dynamically in the second operand of the operation).
  repeated int64 dynamic_slice_sizes = 20;

  // The padding configuration that describes the edge padding and interior
  // padding of this pad instruction. Only set for pad instructions.
  xla.PaddingConfig padding_config = 21;

  // Outfeed configuration information, only present for kOutfeed.
  bytes outfeed_config = 22;

  // The distribution requested for random number generation.
  // Only present for kRng.
  xla.RandomDistribution distribution = 23;

  // A small float number added to the variance to avoid divide-by-zero error.
  // Only present for kBatchNormTraining.
  float epsilon = 24;

  // An integer value representing the index of the feature dimension.
  // Only present for kBatchNormTraining.
  int64 feature_index = 25;

  // Represents a unique identifier for each Send/Recv instruction pair or
  // optionally for collective instructions (AllReduce, CollectivePermute,
  // AllToAll). Non-positive channel_id is equivalent to no channel id.
  int64 channel_id = 26;

  // The string representation of the infeed configuration.
  bytes infeed_config = 27;

  // Name of an external target (e.g., global symbol) to call, only present for
  // kCustomCall.
  string custom_call_target = 28;

  // Shape of outfeed request.
  xla.ShapeProto outfeed_shape = 29;

  // Describes the dimension numbers used for a dot operation
  xla.DotDimensionNumbers dot_dimension_numbers = 30;

  // FFT type (FFT, IFFT, etc).
  xla.FftType fft_type = 31;

  // FFT length.
  repeated int64 fft_length = 32;

  // Comparison direction only used for kCompare.
  string comparison_direction = 63;

  // Gather dimension numbers.
  xla.GatherDimensionNumbers gather_dimension_numbers = 33;
  repeated int64 gather_slice_sizes = 34;

  // Compute Host.
  string channel_name = 41;
  int64 cost_estimate_ns = 42;

  // The id of this instruction.
  int64 id = 35;

  repeated int64 operand_ids = 36;
  repeated int64 control_predecessor_ids = 37;
  repeated int64 called_computation_ids = 38;

  xla.OpSharding sharding = 40;

  // Backend configuration for the instruction. Has backend-specific meaning.
  bytes backend_config = 43;

  // Cross replica op fields.
  repeated ReplicaGroup replica_groups = 49;
  // Deprecated, but keeping it for backward compatibility. Use channel_id.
  // Non-positive all_reduce_id is equivalent to no all_reduce_id.
  int64 all_reduce_id = 45 [deprecated = true];

  // If true, interprets ids in ReplicaGroup as global device ids, which is
  // a linearized id of `replica_id * partition_count + partition_id`.
  bool use_global_device_ids = 71;

  // Whether this Send/Recv instruction transfers data to/from the host. Only
  // present for Send and Recv instructions and their SendDone and RecvDone
  // partners.
  bool is_host_transfer = 47;

  // Whether this Sort instruction should be stable.
  bool is_stable = 60;

  xla.ScatterDimensionNumbers scatter_dimension_numbers = 48;

  // Precision configuration for the instruction. Has backend-specific meaning.
  xla.PrecisionConfig precision_config = 51;

  // Collective permute field.
  repeated SourceTarget source_target_pairs = 52;

  // Sharding for kDomain instructions.
  xla.OpSharding domain_entry_sharding = 54;
  xla.OpSharding domain_exit_sharding = 55;

  // For custom call this indicates that the layouts are constrained. If
  // constrain_layout is true then the 'shape' field must contain a layout, and
  // 'operand_shapes_with_layout' must contain a shape with layout for each
  // operand.
  bool constrain_layout = 56;
  repeated xla.ShapeProto operand_shapes_with_layout = 57;

  // Options for TriangularSolve
  xla.TriangularSolveOptions triangular_solve_options = 59;

  // Options for Cholesky
  xla.CholeskyOptions cholesky_options = 62;

  // Describes how parameters behave with regards to replicas.
  xla.ParameterReplication parameter_replication = 61;

  // If set, the given instruction is run in parallel on e.g. multiple CPU
  // cores. The outermost dimension gets split up into
  // outer_dimension_partitions[0] pieces, the next-outermost dim gets split
  // into outer_dimension_partitions[1] pieces, etc.
  //
  // It's illegal to partition a dimension into more shards than there are
  // elements in that dimension.
  repeated int64 outer_dimension_partitions = 64;

  // Whether the kCustomCall instruction has side-effects, only present for
  // kCustomCall.
  bool custom_call_has_side_effect = 65;

  // A list of CustomCallOutputOperandAliasing pairs that specifies aliasing
  // buffers between output and operands for kCustomCall.
  repeated xla.CustomCallOutputOperandAliasing
      custom_call_output_operand_aliasing = 74;

  // The delta value for kRngGetAndUpdateState.
  int64 delta = 66;

  // Specifies if the gather/scatter indices are guaranteed to be sorted by the
  // caller.
  bool indices_are_sorted = 67;

  // Frontend attributes to pass to the XLA backend.
  xla.FrontendAttributes frontend_attributes = 68;

  // Specifies if all elements updated are guaranteed to be unique by
  // the caller.
  bool unique_indices = 69;

  // RNG algorithm used by kRngBitGenerator.
  xla.RandomAlgorithm rng_algorithm = 70;

  // The comparison type used for kCompare.
  string comparison_type = 72;

  // Specifies if this is a cross-program-prefetch, used by kCopyStart.
  bool is_cross_program_prefetch = 73;

  // If a convolution is dynamic, a dynamic padding type will be specified.
  xla.PaddingType padding_type = 75;
}

// Serialization of HloComputation.
message HloComputationProto {
  reserved 3;
  reserved "root_name";

  string name = 1;

  // The array of instructions is always in a valid dependency order, where
  // operands appear before their users.
  repeated HloInstructionProto instructions = 2;

  // The program shape (with layout) of this computation.
  xla.ProgramShapeProto program_shape = 4;

  // The id of this computation.
  int64 id = 5;

  // The id of the root of the computation.
  int64 root_id = 6;
}

// Serialization of an HLO schedule. An HLO schedule contains a total order of
// instructions for each non-fusion computation in the module.
message HloScheduleProto {
  message InstructionSequence {
    repeated int64 instruction_ids = 1;
  }

  // Map from computation id to sequence.
  map<int64, InstructionSequence> sequences = 1;
}

enum Kind {
  // Define a UNDEFINED_ALIAS equal to zero to get around the default-0 proto3
  // behavior and missing has_*() APIs.
  UNDEFINED_ALIAS = 0;
  // The buffers may or may not alias at runtime.
  MAY_ALIAS = 1;
  // The buffers must alias at runtime.
  MUST_ALIAS = 2;
}

message HloInputOutputAliasProto {
  // The following proto describes a pair of an aliased input
  // (described by parameter number and a ShapeIndex of the parameter)
  // and an output (described by a ShapeIndex of the root
  // instruction). For example:
  //
  // entry = {
  //   output_shape_index={1},
  //   parameter_number=0,
  //   parameter_shape_index={1, 2},
  // }
  //
  // This entry indicates that the first parameter's {1, 2} element is
  // aliased with the {1} element of the root instruction.
  message AliasEntryProto {
    // ShapeIndex of the root hlo.
    repeated int64 output_shape_index = 1;
    // Number of the parameter in entry computation.
    int64 parameter_number = 2;
    // ShapeIndex of the parameter instruction.
    repeated int64 parameter_shape_index = 3;
    // The kind of alias to be setup.
    Kind kind = 4;
  }

  repeated AliasEntryProto entries = 1;
}

message DynamicParameterBindingProto {
  // A list of bindings which indicates that the `target_dim_num` in
  // the subshape `target_param_index` of parameter `target_param_num`
  // is a dynamic dimension and its real dynamic size is represented
  // by `dynamic_param_index` in parameter `dynamic_param_num`.
  //
  // As an example, imagine we have a program:
  //
  // ENTRY main {
  //   a = f32[] parameter(0)
  //   b = f32[10] parameter(1)
  //   ROOT root = (f32[], f32[10]) tuple(%a, %b)
  // }
  //
  // Let's say 'b' (param index 1) is a dynamic shape whose input has
  // an upper bound of 10 and real size is determined at runtime. 'a'
  // represents the real size of b's first dimension.
  //
  // In this case, the fields are set in the following way:
  // dynamic_param_num = 1
  // dynamic_param_index = {}
  // target_param_num = 0
  // target_param_index = {}
  // target_param_dim = 0
  message Binding {
    int64 dynamic_param_num = 1;
    repeated int64 dynamic_param_index = 2;
    int64 target_param_num = 3;
    repeated int64 target_param_index = 4;
    int64 target_param_dim_num = 5;
  }

  repeated Binding entries = 1;
}

message CrossProgramPrefetch {
  int64 parameter = 1;
  repeated int64 index = 2;
}

// Serialization of HloModule.
message HloModuleProto {
  string name = 1;
  string entry_computation_name = 2;
  int64 entry_computation_id = 6;

  // The array of computations is always in a valid dependency order, where
  // callees appear before their callers.
  repeated HloComputationProto computations = 3;

  // The host program shape (with layout) of the entry computation.
  xla.ProgramShapeProto host_program_shape = 4;

  // The id of this module.
  int64 id = 5;

  // The schedule for this module.
  HloScheduleProto schedule = 7;

  // Describes alias information between inputs and outputs.
  HloInputOutputAliasProto input_output_alias = 8;

  DynamicParameterBindingProto dynamic_parameter_binding = 9;

  repeated CrossProgramPrefetch cross_program_prefetches = 10;
}

// Serialization of LogicalBuffer.
message LogicalBufferProto {
  // Location represents an instruction and its shape index, which uniquely
  // identifies a point where a buffer is needed.
  message Location {
    // NOTE: module_name isn't necessary, since all LogicalBuffers are
    // associated with a single HloModule.
    string computation_name = 1;
    string instruction_name = 2;
    repeated int64 shape_index = 3;
  }

  int64 id = 1;
  int64 size = 2;

  // The location where the buffer is defined.
  Location defined_at = 3;

  int64 color = 4;
}

// Serialization of BufferAllocation.
message BufferAllocationProto {
  // Assigned represents a single LogicalBuffer that is assigned to this
  // BufferAllocation.
  message Assigned {
    int64 logical_buffer_id = 1;
    int64 offset = 2;
    int64 size = 3;
  }

  int64 index = 1;
  int64 size = 2;
  bool is_thread_local = 3;
  bool is_tuple = 11;
  bool is_entry_computation_parameter = 5;
  bool is_constant = 12;
  int64 parameter_number = 6;
  repeated int64 parameter_shape_index = 10;
  bool maybe_live_out = 7;
  int64 color = 8;
  repeated Assigned assigned = 9;
}

// A trace of a HeapSimulator run.
message HeapSimulatorTrace {
  // The trace includes a list of events, where each event describes one action
  // performed by the heap simulator.
  message Event {
    enum Kind {
      ALLOC = 0;  // A memory region was allocated for the buffer.
      FREE = 1;   // A memory region was freed for the buffer.

      // A buffer was shared with another (canonical) buffer. This is similar
      // to ALLOC, except that instead of allocating a new region of memory,
      // the memory region of the canonical buffer is directly re-used.
      // Multiple buffers may share with the same canonical buffer. The
      // lifetime of the canonical buffer is extended to the union of all
      // lifetimes.
      SHARE_WITH = 2;
    }
    Kind kind = 1;

    // The id of the LogicalBuffer that the event applies to.
    int64 buffer_id = 2;

    // The HloInstruction that the simulation was processing that caused this
    // event to occur, identified by its computation and instruction name. E.g.
    // buffers defined by instruction A are allocated when processing A.
    string computation_name = 3;
    string instruction_name = 4;

    // The id of the canonical LogicalBuffer that the buffer shares with. Only
    // set for SHARE_WITH events.
    int64 share_with_canonical_id = 5;
  }
  repeated Event events = 1;
  bool whole_module_simulation = 2;
}

// An abstraction representing a set of HLO modules built to run concurrently
// across different devices.
message HloModuleGroupProto {
  string name = 1;
  repeated HloModuleProto hlo_modules = 2;
}

// Serialization of BufferAssignment.
message BufferAssignmentProto {
  // Alias represents a source LogicalBuffer, and the buffer location that
  // aliases it.
  message BufferAlias {
    int64 source_buffer_id = 1;
    LogicalBufferProto.Location location = 2;
  }

  repeated LogicalBufferProto logical_buffers = 1;
  repeated BufferAlias buffer_aliases = 2;
  repeated BufferAllocationProto buffer_allocations = 3;
  repeated HeapSimulatorTrace heap_simulator_traces = 4;
}

// Grouping message that contains all of the information above.
message HloProto {
  reserved 2;
  reserved "hlo_ordering";

  HloModuleProto hlo_module = 1;
  BufferAssignmentProto buffer_assignment = 3;
}

// Encapsulates HloProto together with the arguments, result, and
// execution_platform. This message is used for purposes such as
// analysis/replay/file-storage.
message HloSnapshot {
  // The hlo graph.
  HloProto hlo = 1;

  // The arguments passed to the graph.
  repeated LiteralProto arguments = 2;

  // The result of the graph.
  LiteralProto result = 3;

  // The name of the platform used to run the graph.
  string execution_platform = 4;
}

// Metadata for an HLO module. Dumped after HLO passes and before LLO lowering
// with filename module_####.metadata.textproto, where #### is
// canonical_module_id.
message HloModuleMetadataProto {
  // Uniquely identifies an HloModuleMetadata. Equal to the first unique_id
  // of the module (a module may go through multiple unique_ids). If a module
  // is partitioned into multiple modules, those modules will each have a new
  // HloModuleMetadata with a different canonical_module_id.
  int64 canonical_module_id = 1;

  // Name of the module group that the module is part of.
  string module_group_name = 2;

  // The canonical module id of the module that this one is partitioned from,
  // if applicable.
  int64 original_module_id = 3;

  // The canonical module ids of the modules that this one is partitioned into,
  // if applicable.
  repeated int64 partitioned_module_ids = 4;

  // Metadata for the HLO passes that are run on the module.
  repeated HloPassMetadata pass_metadata = 5;
}

// Metadata for one run of an HLO pass on a module. Provides more information
// when processing debug dumps of HloProtos about the order of HLO passes and
// various other stats like duration. `pass_id` may also be used to identify a
// particular run of a pass in debug info that propagates through stages of
// compilation.
message HloPassMetadata {
  // For a given module, pass_id uniquely identifies a run of an HLO pass on
  // that module. Note that a pass_id may not always refer to the same pass
  // because the order of passes during compilation may change. For finding
  // metadata for a particular pass, pass_name and pipeline_name would be more
  // reliable, although note that they may not be unique.
  int64 pass_id = 1;
  string pass_name = 2;
  string pipeline_name = 3;

  // Filenames of the dumps of the module after this pass ran. Module may be
  // dumped in multiple formats, and the order of formats in this field will
  // stay consistent across passes.
  repeated string dump_filenames = 4;

  // Return value of pass.Run(). True if this pass changed the module, or, in
  // the case where the module was run through this pass as part of a module
  // group, true if this pass changed any module in the same module group.
  bool module_changed = 5;

  // The unique_id of the module that this pass is run on. May be different from
  // the canonical_module_id of the HloModuleMetadata that this HloPassMetadata
  // is inside.
  int64 module_id = 6;

  // If the module went through this pass as part of a module group, this is
  // set as the ids of all the modules in the module group. Empty otherwise.
  repeated int64 module_group_module_ids = 7;

  // Timestamp before and after the pass is run. Note they may be equal.
  int64 start_timestamp_usec = 8;
  int64 end_timestamp_usec = 9;
}