/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This proto file defines messages which represent the HLO module. This is a
// full fidelity serialization of the c++ HLO constructs.
//
// Many of the protos below are simple 1-to-1 serializations of the
// corresponding C++ classes, e.g., HloModule, HloComputation, and
// HloInstruction.
//
// FIELD NAMES ARE IMPORTANT
//
// Unlike most protos, you can't safely change the names of fields, even if you
// keep the numeric ids the same. This is because we sometimes serialize these
// protos as JSON, which includes the field names in the serialization.

syntax = "proto3";

package xla;

import "tensorflow/compiler/xla/xla_data.proto";

option cc_enable_arenas = true;

// Serialization of HloInstruction.
// Next ID: 70
message HloInstructionProto {
  // Field numbers and names of removed fields. They stay reserved so that
  // neither the number nor the (JSON-visible) name can be reused.
  reserved 10;
  reserved "parameter_name";
  reserved 12;
  reserved "fused_instructions_computation";
  reserved 4;
  reserved "operand_names";
  reserved 5;
  reserved "control_predecessor_names";
  reserved 6;
  reserved "called_computation_names";
  reserved 44;
  reserved "replica_group_ids";
  // Use backend_config instead for custom_call_opaque.
  reserved 53;
  reserved "custom_call_opaque";
  // Use backend_config instead for all_reduce_barrier.
  reserved 46;
  reserved "all_reduce_barrier";

  string name = 1;
  string opcode = 2;
  xla.ShapeProto shape = 3;

  xla.OpMetadata metadata = 7;

  // Literal, only present for kConstant.
  xla.LiteralProto literal = 8;

  // Parameter number is only present for kParameter.
  int64 parameter_number = 9;

  // Fusion state, only present for kFusion.
  string fusion_kind = 11;

  // Index for kGetTupleElement.
  int64 tuple_index = 13;

  // Dimensions present for some operations that require reshaping or
  // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse.
  repeated int64 dimensions = 14;

  // Describes the window in a windowed operation such as convolution.
  xla.Window window = 15;

  // Describes the dimension numbers used for a convolution.
  xla.ConvolutionDimensionNumbers convolution_dimension_numbers = 16;

  // The number of feature groups. Used for a convolution. Must be a divisor of
  // the input feature dimension and output feature dimension. If not specified,
  // it will use a default value of 1.
  int64 feature_group_count = 50;

  int64 batch_group_count = 58;

  // Describes the [begin, end) index range and stride for slices.
  message SliceDimensions {
    int64 start = 1;
    int64 limit = 2;
    int64 stride = 3;
  }
  repeated SliceDimensions slice_dimensions = 17;

  // The bit sizes for a reduce-precision operation.
  int32 exponent_bits = 18;
  int32 mantissa_bits = 19;

  // Describes the [start, start + size) range size for a dynamic slice
  // ('start' is specified dynamically in the second operand of the operation).
  repeated int64 dynamic_slice_sizes = 20;

  // The padding configuration that describes the edge padding and interior
  // padding of this pad instruction. Only set for pad instructions.
  xla.PaddingConfig padding_config = 21;

  // Outfeed configuration information, only present for kOutfeed.
  bytes outfeed_config = 22;

  // The distribution requested for random number generation.
  // Only present for kRng.
  xla.RandomDistribution distribution = 23;

  // A small float number added to the variance to avoid divide-by-zero error.
  // Only present for kBatchNormTraining.
  float epsilon = 24;

  // An integer value representing the index of the feature dimension.
  // Only present for kBatchNormTraining.
  int64 feature_index = 25;

  // Represents a unique identifier for each Send/Recv instruction pair or
  // optionally for collective instructions (AllReduce, CollectivePermute,
  // AllToAll). Non-positive channel_id is equivalent to no channel id.
  int64 channel_id = 26;

  // The string representation of the infeed configuration.
  bytes infeed_config = 27;

  // Name of an external target (e.g., global symbol) to call, only present for
  // kCustomCall.
  string custom_call_target = 28;

  // Shape of outfeed request.
  xla.ShapeProto outfeed_shape = 29;

  // Describes the dimension numbers used for a dot operation.
  xla.DotDimensionNumbers dot_dimension_numbers = 30;

  // FFT type (FFT, IFFT, etc).
  xla.FftType fft_type = 31;

  // FFT length.
  repeated int64 fft_length = 32;

  // Comparison direction only used for kCompare.
  string comparison_direction = 63;

  // Gather dimension numbers.
  xla.GatherDimensionNumbers gather_dimension_numbers = 33;
  repeated int64 gather_slice_sizes = 34;

  // Compute Host.
  string channel_name = 41;
  int64 cost_estimate_ns = 42;

  // The id of this instruction.
  int64 id = 35;

  repeated int64 operand_ids = 36;
  repeated int64 control_predecessor_ids = 37;
  repeated int64 called_computation_ids = 38;

  xla.OpSharding sharding = 40;

  // Backend configuration for the instruction. Has backend-specific meaning.
  string backend_config = 43;

  // Cross replica op fields.
  repeated ReplicaGroup replica_groups = 49;
  // Deprecated, but keeping it for backward compatibility. Use channel_id.
  // Non-positive all_reduce_id is equivalent to no all_reduce_id.
  int64 all_reduce_id = 45 [deprecated = true];

  // Whether this Send/Recv instruction transfers data to/from the host. Only
  // present for Send and Recv instructions and their SendDone and RecvDone
  // partners.
  bool is_host_transfer = 47;

  // Whether this Sort instruction should be stable.
  bool is_stable = 60;

  xla.ScatterDimensionNumbers scatter_dimension_numbers = 48;

  // Precision configuration for the instruction. Has backend-specific meaning.
  xla.PrecisionConfig precision_config = 51;

  // Collective permute field.
  repeated SourceTarget source_target_pairs = 52;

  // Sharding for kDomain instructions.
  xla.OpSharding domain_entry_sharding = 54;
  xla.OpSharding domain_exit_sharding = 55;

  // For custom call this indicates that the layouts are constrained. If
  // constrain_layout is true then the 'shape' field must contain a layout, and
  // 'operand_shapes_with_layout' must contain a shape with layout for each
  // operand.
  bool constrain_layout = 56;
  repeated xla.ShapeProto operand_shapes_with_layout = 57;

  // Options for TriangularSolve.
  xla.TriangularSolveOptions triangular_solve_options = 59;

  // Options for Cholesky.
  xla.CholeskyOptions cholesky_options = 62;

  // Describes how parameters behave with regards to replicas.
  xla.ParameterReplication parameter_replication = 61;

  // If set, the given instruction is run in parallel on e.g. multiple CPU
  // cores. The outermost dimension gets split up into
  // outer_dimension_partitions[0] pieces, the next-outermost dim gets split
  // into outer_dimension_partitions[1] pieces, etc.
  //
  // It's illegal to partition a dimension into more shards than there are
  // elements in that dimension.
  repeated int64 outer_dimension_partitions = 64;

  // Whether the kCustomCall instruction has side-effects, only present for
  // kCustomCall.
  bool custom_call_has_side_effect = 65;

  // The delta value for kRngGetAndUpdateState.
  int64 delta = 66;

  // Specifies if the gather/scatter indices are guaranteed to be sorted by the
  // caller.
  bool indices_are_sorted = 67;

  // Frontend attributes to pass to the XLA backend.
  xla.FrontendAttributes frontend_attributes = 68;

  // Specifies if all elements updated are guaranteed to be unique by
  // the caller.
  bool unique_indices = 69;
}

// Serialization of HloComputation.
message HloComputationProto {
  reserved 3;
  reserved "root_name";

  string name = 1;

  // The array of instructions is always in a valid dependency order, where
  // operands appear before their users.
  repeated HloInstructionProto instructions = 2;

  // The program shape (with layout) of this computation.
  xla.ProgramShapeProto program_shape = 4;

  // The id of this computation.
  int64 id = 5;

  // The id of the root of the computation.
  int64 root_id = 6;
}

// Serialization of an HLO schedule. An HLO schedule contains a total order of
// instructions for each non-fusion computation in the module.
message HloScheduleProto {
  message InstructionSequence {
    repeated int64 instruction_ids = 1;
  }

  // Map from computation id to sequence.
  map<int64, InstructionSequence> sequences = 1;
}

// Describes alias information between inputs and outputs, as a list of
// (output, parameter) alias entries.
message HloInputOutputAliasProto {
  enum Kind {
    // Define a UNDEFINED_ALIAS equal to zero to get around the default-0 proto3
    // behavior and missing has_*() APIs.
    UNDEFINED_ALIAS = 0;
    // An alias set up by the user as a must-alias. A user setting USER_ALIAS
    // expects the designated output to be dropped over the given input
    // parameter number+index.
    USER_ALIAS = 1;
    // An alias set up by the compiler as part of its optimizations.
    SYSTEM_ALIAS = 2;
  }

  // The following proto describes an aliased pair of an input
  // (described by parameter number and a ShapeIndex of the parameter)
  // and an output (described by a ShapeIndex of the root
  // instruction). For example:
  //
  // entry = {
  //  output_shape_index={1},
  //  parameter_number=0,
  //  parameter_shape_index={1, 2},
  // }
  //
  // This entry indicates that the first parameter's {1, 2} element is
  // aliased with the {1} element of the root instruction.
  message AliasEntryProto {
    // ShapeIndex of the root hlo.
    repeated int64 output_shape_index = 1;
    // Number of the parameter in entry computation.
    int64 parameter_number = 2;
    // ShapeIndex of the parameter instruction.
    repeated int64 parameter_shape_index = 3;
    // The kind of alias to be set up.
    Kind kind = 4;
  }

  repeated AliasEntryProto entries = 1;
}

// Describes which parameter dimensions are dynamic and which parameters carry
// their real (runtime) sizes, as a list of bindings.
message DynamicParameterBindingProto {
  // A list of bindings which indicates that the `target_param_dim_num` in
  // the subshape `target_param_index` of parameter `target_param_num`
  // is a dynamic dimension and its real dynamic size is represented
  // by `dynamic_param_index` in parameter `dynamic_param_num`.
  //
  // As an example, imagine we have a program:
  //
  // ENTRY main {
  //   a = f32[] parameter(0)
  //   b = f32[10] parameter(1)
  //   ROOT root = (f32[], f32[10]) tuple(%a, %b)
  // }
  //
  // Let's say 'b' (param index 1) is a dynamic shape whose input has
  // an upperbound of 10 and real size is determined at runtime. 'a'
  // (param index 0) represents the real size of b's first dimension.
  //
  // In this case, the fields are set in the following way:
  // dynamic_param_num = 0
  // dynamic_param_index = {}
  // target_param_num = 1
  // target_param_index = {}
  // target_param_dim_num = 0
  message Binding {
    int64 dynamic_param_num = 1;
    repeated int64 dynamic_param_index = 2;
    int64 target_param_num = 3;
    repeated int64 target_param_index = 4;
    int64 target_param_dim_num = 5;
  }

  repeated Binding entries = 1;
}

// Serialization of HloModule.
message HloModuleProto {
  string name = 1;
  string entry_computation_name = 2;
  int64 entry_computation_id = 6;

  // The array of computations is always in a valid dependency order, where
  // callees appear before their callers.
  repeated HloComputationProto computations = 3;

  // The host program shape (with layout) of the entry computation.
  xla.ProgramShapeProto host_program_shape = 4;

  // The id of this module.
  int64 id = 5;

  // The schedule for this module.
  HloScheduleProto schedule = 7;

  // Describes alias information between inputs and outputs.
  HloInputOutputAliasProto input_output_alias = 8;

  DynamicParameterBindingProto dynamic_parameter_binding = 9;
}

// Serialization of LogicalBuffer.
message LogicalBufferProto {
  // Location represents an instruction and its shape index, which uniquely
  // identifies a point where a buffer is needed.
  message Location {
    // NOTE: module_name isn't necessary, since all LogicalBuffers are
    // associated with a single HloModule.
    string computation_name = 1;
    string instruction_name = 2;
    repeated int64 shape_index = 3;
  }

  int64 id = 1;
  int64 size = 2;

  // The location where the buffer is defined.
  Location defined_at = 3;

  int64 color = 4;
}

// Serialization of BufferAllocation.
message BufferAllocationProto {
  // Assigned represents a single LogicalBuffer that is assigned to this
  // BufferAllocation.
  message Assigned {
    int64 logical_buffer_id = 1;
    int64 offset = 2;
    int64 size = 3;
  }

  int64 index = 1;
  int64 size = 2;
  bool is_thread_local = 3;
  bool is_tuple = 11;
  bool is_entry_computation_parameter = 5;
  bool is_constant = 12;
  int64 parameter_number = 6;
  repeated int64 parameter_shape_index = 10;
  bool maybe_live_out = 7;
  int64 color = 8;
  repeated Assigned assigned = 9;
}

// A trace of a HeapSimulator run.
message HeapSimulatorTrace {
  // The trace includes a list of events, where each event describes one action
  // performed by the heap simulator.
  message Event {
    enum Kind {
      ALLOC = 0;  // A memory region was allocated for the buffer.
      FREE = 1;   // A memory region was freed for the buffer.

      // A buffer was shared with another (canonical) buffer. This is similar to
      // ALLOC, except that instead of allocating a new region of memory, the
      // memory region of the canonical buffer is directly re-used. Multiple
      // buffers may share with the same canonical buffer. The lifetime of the
      // canonical buffer is extended to the union of all lifetimes.
      SHARE_WITH = 2;
    }
    Kind kind = 1;

    // The id of the LogicalBuffer that the event applies to.
    int64 buffer_id = 2;

    // The HloInstruction that the simulation was processing that caused this
    // event to occur, identified by its computation and instruction name. E.g.
    // buffers defined by instruction A are allocated when processing A.
    string computation_name = 3;
    string instruction_name = 4;

    // The id of the canonical LogicalBuffer that the buffer shares with. Only
    // set for SHARE_WITH events.
    int64 share_with_canonical_id = 5;
  }
  repeated Event events = 1;
  bool whole_module_simulation = 2;
}

// An abstraction representing a set of HLO modules built to run concurrently
// across different devices.
message HloModuleGroupProto {
  string name = 1;
  repeated HloModuleProto hlo_modules = 2;
}

// Serialization of BufferAssignment.
message BufferAssignmentProto {
  // Alias represents a source LogicalBuffer, and the buffer location that
  // aliases it.
  message BufferAlias {
    int64 source_buffer_id = 1;
    LogicalBufferProto.Location location = 2;
  }

  repeated LogicalBufferProto logical_buffers = 1;
  repeated BufferAlias buffer_aliases = 2;
  repeated BufferAllocationProto buffer_allocations = 3;
  repeated HeapSimulatorTrace heap_simulator_traces = 4;
}

// Grouping message that contains all of the information above.
message HloProto {
  // Removed field; number and name stay reserved so they cannot be reused.
  reserved 2;
  reserved "hlo_ordering";

  HloModuleProto hlo_module = 1;
  BufferAssignmentProto buffer_assignment = 3;
}

// Encapsulates HloProto together with the arguments, result, and
// execution_platform. This message is used for purposes such as
// analysis/replay/file-storage.
message HloSnapshot {
  // The hlo graph.
  HloProto hlo = 1;

  // The arguments passed to the graph.
  repeated xla.LiteralProto arguments = 2;

  // The result of the graph.
  xla.LiteralProto result = 3;

  // The name of the platform used to run the graph.
  string execution_platform = 4;
}