/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This proto file defines messages which represent the HLO module. This is a
// full-fidelity serialization of the C++ HLO constructs.
//
// Many of the protos below are simple 1-to-1 serializations of the
// corresponding C++ classes, e.g., HloModule, HloComputation, and
// HloInstruction.
//
// FIELD NAMES ARE IMPORTANT
//
// Unlike most protos, you can't safely change the names of fields, even if you
// keep the numeric ids the same. This is because we sometimes serialize these
// protos as JSON, which includes the field names in the serialization.

syntax = "proto3";

package xla;

import "tensorflow/compiler/xla/xla_data.proto";

option cc_enable_arenas = true;

enum CustomCallSchedule {
  SCHEDULE_NONE = 0;
  SCHEDULE_LATEST = 1;
  SCHEDULE_EARLIEST = 2;
}

// The version of the API used by the custom call function. The signatures for
// each version are given below.
// TODO(b/189822916): Remove this enum when all clients are migrated to the
// status-returning API.
enum CustomCallApiVersion {
  API_VERSION_UNSPECIFIED = 0;

  // The first version of the API, with the following signatures:
  //
  // CPU:
  //   void do_custom_call(void* out, const void** in)
  //
  // GPU:
  //   void do_custom_call(CUstream stream, void** buffers,
  //                       const char* opaque, size_t opaque_len);
  API_VERSION_ORIGINAL = 1;

  // The second version of the API, which adds the ability to return a
  // success/failure status:
  //
  // CPU: Unimplemented
  //
  // GPU:
  //   void do_custom_call(CUstream stream, void** buffers,
  //                       const char* opaque, size_t opaque_len,
  //                       XlaCustomCallStatus* status);
  //
  API_VERSION_STATUS_RETURNING = 2;
}

// Serialization of HloInstruction.
// Next ID: 78
message HloInstructionProto {
  reserved 10;
  reserved "parameter_name";
  reserved 12;
  reserved "fused_instructions_computation";
  reserved 4;
  reserved "operand_names";
  reserved 5;
  reserved "control_predecessor_names";
  reserved 6;
  reserved "called_computation_names";
  reserved 44;
  reserved "replica_group_ids";
  // Use backend_config instead of custom_call_opaque.
  reserved 53;
  reserved "custom_call_opaque";
  // Use backend_config instead of all_reduce_barrier.
  reserved 46;
  reserved "all_reduce_barrier";

  string name = 1;
  string opcode = 2;
  xla.ShapeProto shape = 3;

  xla.OpMetadata metadata = 7;

  // Literal, only present for kConstant.
  xla.LiteralProto literal = 8;

  // Parameter number is only present for kParameter.
  int64 parameter_number = 9;

  // Fusion state, only present for kFusion.
  string fusion_kind = 11;

  // Index for kGetTupleElement.
  int64 tuple_index = 13;

  // Dimensions present for some operations that require reshaping or
  // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse.
  repeated int64 dimensions = 14;

  // Describes the window in a windowed operation such as convolution.
  xla.Window window = 15;

  // Describes the dimension numbers used for a convolution.
  xla.ConvolutionDimensionNumbers convolution_dimension_numbers = 16;

  // The number of feature groups, used for a grouped convolution. Must be a
  // divisor of both the input feature dimension and the output feature
  // dimension. Defaults to 1 if not specified.
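  // For example (illustrative), a convolution with 32 input features and
  // feature_group_count = 4 performs 4 independent convolutions over 8 input
  // features each; a feature_group_count equal to the input feature count
  // gives a depthwise convolution.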
  int64 feature_group_count = 50;

  int64 batch_group_count = 58;

  // Describes the [begin, end) index range and stride for slices.
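  // For example (illustrative), the 1-D slice a[2:8:2] is encoded as
  // start = 2, limit = 8, stride = 2, and selects the elements at
  // indices 2, 4, and 6.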
  message SliceDimensions {
    int64 start = 1;
    int64 limit = 2;
    int64 stride = 3;
  }
  repeated SliceDimensions slice_dimensions = 17;

  // The bit sizes for a reduce-precision operation.
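  // For example (illustrative), exponent_bits = 8 and mantissa_bits = 7
  // emulate bfloat16 rounding while keeping the operand in f32.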
  int32 exponent_bits = 18;
  int32 mantissa_bits = 19;

  // Describes the [start, start + size) range size for a dynamic slice
  // ('start' is specified dynamically in the second operand of the operation).
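  // For example (illustrative), dynamic_slice_sizes = [3] on an f32[10]
  // operand always produces an f32[3] result, wherever 'start' lands.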
  repeated int64 dynamic_slice_sizes = 20;

  // The padding configuration that describes the edge padding and interior
  // padding of this pad instruction. Only set for pad instructions.
  xla.PaddingConfig padding_config = 21;

  // Outfeed configuration information, only present for kOutfeed.
  bytes outfeed_config = 22;

  // The distribution requested for random number generation.
  // Only present for kRng.
  xla.RandomDistribution distribution = 23;

  // A small floating-point value added to the variance to avoid
  // division-by-zero errors. Only present for kBatchNormTraining.
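  // Specifically, batch norm computes (x - mean) / sqrt(variance + epsilon).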
  float epsilon = 24;

  // An integer value representing the index of the feature dimension.
  // Only present for kBatchNormTraining.
  int64 feature_index = 25;

  // Represents a unique identifier for each Send/Recv instruction pair, or
  // optionally for collective instructions (AllReduce, CollectivePermute,
  // AllToAll). A non-positive channel_id is equivalent to no channel id.
  int64 channel_id = 26;

  // The string representation of the infeed configuration.
  bytes infeed_config = 27;

  // Name of an external target (e.g., a global symbol) to call; only present
  // for kCustomCall.
  string custom_call_target = 28;

  // Shape of the outfeed request.
  xla.ShapeProto outfeed_shape = 29;

  // Describes the dimension numbers used for a dot operation.
  xla.DotDimensionNumbers dot_dimension_numbers = 30;

  // FFT type (FFT, IFFT, etc.).
  xla.FftType fft_type = 31;

  // FFT length.
  repeated int64 fft_length = 32;

  // Comparison direction, only used for kCompare.
  string comparison_direction = 63;

  // Gather dimension numbers.
  xla.GatherDimensionNumbers gather_dimension_numbers = 33;
  repeated int64 gather_slice_sizes = 34;

  // Compute Host.
  string channel_name = 41;
  int64 cost_estimate_ns = 42;

  // The id of this instruction.
  int64 id = 35;

  repeated int64 operand_ids = 36;
  repeated int64 control_predecessor_ids = 37;
  repeated int64 called_computation_ids = 38;

  xla.OpSharding sharding = 40;

  // Backend configuration for the instruction. Has backend-specific meaning.
  bytes backend_config = 43;

  // Cross replica op fields.
  repeated ReplicaGroup replica_groups = 49;
  // Deprecated, but kept for backward compatibility; use channel_id instead.
  // A non-positive all_reduce_id is equivalent to no all_reduce_id.
  int64 all_reduce_id = 45 [deprecated = true];

  // If true, interprets ids in ReplicaGroup as global device ids, which is
  // the linearized id `replica_id * partition_count + partition_id`.
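  // For example (illustrative), with partition_count = 4, replica 1 /
  // partition 2 is global device id 1 * 4 + 2 = 6.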
  bool use_global_device_ids = 71;

  // Whether this Send/Recv instruction transfers data to/from the host. Only
  // present for Send and Recv instructions and their SendDone and RecvDone
  // partners.
  bool is_host_transfer = 47;

  // Whether this Sort instruction should be stable.
  bool is_stable = 60;

  xla.ScatterDimensionNumbers scatter_dimension_numbers = 48;

  // Precision configuration for the instruction. Has backend-specific meaning.
  xla.PrecisionConfig precision_config = 51;

  // Collective permute field.
  repeated SourceTarget source_target_pairs = 52;

  // Sharding for kDomain instructions.
  xla.OpSharding domain_entry_sharding = 54;
  xla.OpSharding domain_exit_sharding = 55;

  // For a custom call, indicates that the layouts are constrained. If
  // constrain_layout is true then the 'shape' field must contain a layout, and
  // 'operand_shapes_with_layout' must contain a shape with layout for each
  // operand.
  bool constrain_layout = 56;
  repeated xla.ShapeProto operand_shapes_with_layout = 57;

  // Options for TriangularSolve.
  xla.TriangularSolveOptions triangular_solve_options = 59;

  // Options for Cholesky.
  xla.CholeskyOptions cholesky_options = 62;

  // Describes how parameters behave with regard to replicas.
  xla.ParameterReplication parameter_replication = 61;

  // If set, the given instruction is run in parallel on e.g. multiple CPU
  // cores. The outermost dimension gets split up into
  // outer_dimension_partitions[0] pieces, the next-outermost dim gets split
  // into outer_dimension_partitions[1] pieces, etc.
  //
  // It's illegal to partition a dimension into more shards than there are
  // elements in that dimension.
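  // For example (illustrative, assuming an even split),
  // outer_dimension_partitions = [2, 3] splits an f32[4,9,5] operand into
  // 2 * 3 = 6 shards, each covering 2 elements of the outermost dimension
  // and 3 of the next.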
  repeated int64 outer_dimension_partitions = 64;

  // Whether the kCustomCall instruction has side effects; only present for
  // kCustomCall.
  bool custom_call_has_side_effect = 65;

  // A list of CustomCallOutputOperandAliasing pairs that specifies aliasing
  // between output and operand buffers for kCustomCall.
  repeated xla.CustomCallOutputOperandAliasing
      custom_call_output_operand_aliasing = 74;

  // Specifies the desired schedule for the custom call. Only present for
  // kCustomCall.
  CustomCallSchedule custom_call_schedule = 76;

  // The delta value for kRngGetAndUpdateState.
  int64 delta = 66;

  // Specifies whether the gather/scatter indices are guaranteed to be sorted
  // by the caller.
  bool indices_are_sorted = 67;

  // Frontend attributes to pass to the XLA backend.
  xla.FrontendAttributes frontend_attributes = 68;

  // Specifies whether all elements updated are guaranteed to be unique by
  // the caller.
  bool unique_indices = 69;

  // RNG algorithm used by kRngBitGenerator.
  xla.RandomAlgorithm rng_algorithm = 70;

  // The comparison type used for kCompare.
  string comparison_type = 72;

  // Specifies whether this is a cross-program prefetch; used by kCopyStart.
  bool is_cross_program_prefetch = 73;

  // If a convolution is dynamic, a dynamic padding type will be specified.
  xla.PaddingType padding_type = 75;

  // The API version used by the custom call function. Only present for
  // kCustomCall.
  // TODO(b/189822916): Remove this field when all clients are migrated to the
  // status-returning API.
  CustomCallApiVersion custom_call_api_version = 77;
}

// Serialization of HloComputation.
message HloComputationProto {
  reserved 3;
  reserved "root_name";

  string name = 1;

  // The array of instructions is always in a valid dependency order, where
  // operands appear before their users.
  repeated HloInstructionProto instructions = 2;

  // The program shape (with layout) of this computation.
  xla.ProgramShapeProto program_shape = 4;

  // The id of this computation.
  int64 id = 5;

  // The id of the root of the computation.
  int64 root_id = 6;
}

// Serialization of an HLO schedule. An HLO schedule contains a total order of
// instructions for each non-fusion computation in the module.
message HloScheduleProto {
  message InstructionSequence {
    repeated int64 instruction_ids = 1;
  }

  // Map from computation id to sequence.
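  // For example (illustrative, in text format), a computation with id 1 whose
  // three instructions are scheduled in order might carry:
  //
  //   sequences {
  //     key: 1
  //     value { instruction_ids: [2, 3, 4] }
  //   }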
  map<int64, InstructionSequence> sequences = 1;
}

enum Kind {
  // Define an UNDEFINED_ALIAS equal to zero to get around the default-0 proto3
  // behavior and missing has_*() APIs.
  UNDEFINED_ALIAS = 0;
  // The buffers may or may not alias at runtime.
  MAY_ALIAS = 1;
  // The buffers must alias at runtime.
  MUST_ALIAS = 2;
}

message HloInputOutputAliasProto {
  // The following proto describes a pair of aliased buffers: an input
  // (described by a parameter number and a ShapeIndex within the parameter)
  // and an output (described by a ShapeIndex of the root
  // instruction). For example:
  //
  // entry = {
  //  output_shape_index={1},
  //  parameter_number=0,
  //  parameter_shape_index={1, 2},
  // }
  //
  // This entry indicates that the first parameter's {1, 2} element is
  // aliased with the {1} element of the root instruction.
  message AliasEntryProto {
    // ShapeIndex of the root hlo.
    repeated int64 output_shape_index = 1;
    // Number of the parameter in the entry computation.
    int64 parameter_number = 2;
    // ShapeIndex of the parameter instruction.
    repeated int64 parameter_shape_index = 3;
    // The kind of alias to be set up.
    Kind kind = 4;
  }

  repeated AliasEntryProto entries = 1;
}

message DynamicParameterBindingProto {
  // A list of bindings which indicates that the `target_dim_num` in
  // the subshape `target_param_index` of parameter `target_param_num`
  // is a dynamic dimension and its real dynamic size is represented
  // by `dynamic_param_index` in parameter `dynamic_param_num`.
  //
  // As an example, imagine we have a program:
  //
  // ENTRY main {
  //   a = f32[] parameter(0)
  //   b = f32[10] parameter(1)
  //   ROOT root = (f32[], f32[10]) tuple(%a, %b)
  // }
  //
  // Let's say 'b' (parameter 1) has a dynamic shape with an upper bound of
  // 10 whose real size is determined at runtime, and 'a' represents the real
  // size of b's first dimension.
  //
  // In this case, the fields are set as follows ('a' is the dynamic
  // parameter, 'b' holds the target dimension):
  // dynamic_param_num = 0
  // dynamic_param_index = {}
  // target_param_num = 1
  // target_param_index = {}
  // target_param_dim_num = 0
  message Binding {
    int64 dynamic_param_num = 1;
    repeated int64 dynamic_param_index = 2;
    int64 target_param_num = 3;
    repeated int64 target_param_index = 4;
    int64 target_param_dim_num = 5;
  }

  repeated Binding entries = 1;
}

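// Presumably identifies a buffer to prefetch across program executions: a
// parameter (by number) and a ShapeIndex into it; cf. the
// is_cross_program_prefetch flag on kCopyStart above.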
message CrossProgramPrefetch {
  int64 parameter = 1;
  repeated int64 index = 2;
}

// Serialization of HloModule.
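// A minimal instance, sketched in text format (names and ids illustrative):
//
//   name: "m"
//   entry_computation_name: "main"
//   entry_computation_id: 1
//   computations { name: "main" id: 1 root_id: 4 instructions { ... } }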
message HloModuleProto {
  string name = 1;
  string entry_computation_name = 2;
  int64 entry_computation_id = 6;

  // The array of computations is always in a valid dependency order, where
  // callees appear before their callers.
  repeated HloComputationProto computations = 3;

  // The host program shape (with layout) of the entry computation.
  xla.ProgramShapeProto host_program_shape = 4;

  // The id of this module.
  int64 id = 5;

  // The schedule for this module.
  HloScheduleProto schedule = 7;

  // Describes alias information between inputs and outputs.
  HloInputOutputAliasProto input_output_alias = 8;

  DynamicParameterBindingProto dynamic_parameter_binding = 9;

  repeated CrossProgramPrefetch cross_program_prefetches = 10;

  // True if the module contains dynamic computation.
  bool is_dynamic = 11;
}

// Serialization of LogicalBuffer.
message LogicalBufferProto {
  // Location represents an instruction and its shape index, which uniquely
  // identifies a point where a buffer is needed.
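  // For example (illustrative), shape_index = [1, 0] names element 0 of tuple
  // element 1 of the instruction's (tuple-shaped) output.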
  message Location {
    // NOTE: module_name isn't necessary, since all LogicalBuffers are
    // associated with a single HloModule.
    string computation_name = 1;
    string instruction_name = 2;
    repeated int64 shape_index = 3;
  }

  int64 id = 1;
  int64 size = 2;

  // The location where the buffer is defined.
  Location defined_at = 3;

  int64 color = 4;
}

// Serialization of BufferAllocation.
message BufferAllocationProto {
  // Assigned represents a single LogicalBuffer that is assigned to this
  // BufferAllocation.
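  // For example (illustrative), two 512-byte logical buffers packed into a
  // single 1024-byte allocation might be assigned offset = 0 and offset = 512.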
  message Assigned {
    int64 logical_buffer_id = 1;
    int64 offset = 2;
    int64 size = 3;
  }

  int64 index = 1;
  int64 size = 2;
  bool is_thread_local = 3;
  bool is_tuple = 11;
  bool is_entry_computation_parameter = 5;
  bool is_constant = 12;
  int64 parameter_number = 6;
  repeated int64 parameter_shape_index = 10;
  bool maybe_live_out = 7;
  int64 color = 8;
  repeated Assigned assigned = 9;
}

// A trace of a HeapSimulator run.
message HeapSimulatorTrace {
  // The trace includes a list of events, where each event describes one action
  // performed by the heap simulator.
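  // For example (illustrative), a run in which buffer 2 reuses buffer 1's
  // memory might emit: ALLOC(1), SHARE_WITH(2 -> 1), FREE(2), FREE(1).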
  message Event {
    enum Kind {
      ALLOC = 0;  // A memory region was allocated for the buffer.
      FREE = 1;   // A memory region was freed for the buffer.

      // A buffer was shared with another (canonical) buffer. This is similar
      // to ALLOC, except that instead of allocating a new region of memory,
      // the memory region of the canonical buffer is directly re-used.
      // Multiple buffers may share with the same canonical buffer. The
      // lifetime of the canonical buffer is extended to the union of all
      // lifetimes.
      SHARE_WITH = 2;
    }
    Kind kind = 1;

    // The id of the LogicalBuffer that the event applies to.
    int64 buffer_id = 2;

    // The HloInstruction whose processing caused this event to occur,
    // identified by its computation and instruction names. E.g. buffers
    // defined by instruction A are allocated when processing A.
    string computation_name = 3;
    string instruction_name = 4;

    // The id of the canonical LogicalBuffer that the buffer shares with. Only
    // set for SHARE_WITH events.
    int64 share_with_canonical_id = 5;
  }
  repeated Event events = 1;
  bool whole_module_simulation = 2;
  int64 buffer_allocation_index = 3;
}

// An abstraction representing a set of HLO modules built to run concurrently
// across different devices.
message HloModuleGroupProto {
  string name = 1;
  repeated HloModuleProto hlo_modules = 2;
}

// Serialization of BufferAssignment.
message BufferAssignmentProto {
  // Alias represents a source LogicalBuffer, and the buffer location that
  // aliases it.
  message BufferAlias {
    int64 source_buffer_id = 1;
    LogicalBufferProto.Location location = 2;
  }

  repeated LogicalBufferProto logical_buffers = 1;
  repeated BufferAlias buffer_aliases = 2;
  repeated BufferAllocationProto buffer_allocations = 3;
  repeated HeapSimulatorTrace heap_simulator_traces = 4;
}

// Grouping message that contains all of the information above.
message HloProto {
  reserved 2;
  reserved "hlo_ordering";

  HloModuleProto hlo_module = 1;
  BufferAssignmentProto buffer_assignment = 3;
}

// Encapsulates HloProto together with the arguments, result, and
// execution_platform. This message is used for purposes such as
// analysis/replay/file-storage.
message HloSnapshot {
  // The hlo graph.
  HloProto hlo = 1;

  // The arguments passed to the graph.
  repeated LiteralProto arguments = 2;

  // The result of the graph.
  LiteralProto result = 3;

  // The name of the platform used to run the graph.
  string execution_platform = 4;
}

// Metadata for an HLO module. Dumped after HLO passes and before LLO lowering
// with filename module_####.metadata.textproto, where #### is
// canonical_module_id.
message HloModuleMetadataProto {
  // Uniquely identifies an HloModuleMetadata. Equal to the first unique_id
  // of the module (a module may go through multiple unique_ids). If a module
  // is partitioned into multiple modules, those modules will each have a new
  // HloModuleMetadata with a different canonical_module_id.
  int64 canonical_module_id = 1;

  // Name of the module group that the module is part of.
  string module_group_name = 2;

  // The canonical module id of the module that this one is partitioned from,
  // if applicable.
  int64 original_module_id = 3;

  // The canonical module ids of the modules that this one is partitioned into,
  // if applicable.
  repeated int64 partitioned_module_ids = 4;

  // Metadata for the HLO passes that are run on the module.
  repeated HloPassMetadata pass_metadata = 5;
}

// Metadata for one run of an HLO pass on a module. Provides more information
// when processing debug dumps of HloProtos about the order of HLO passes and
// various other stats like duration. `pass_id` may also be used to identify a
// particular run of a pass in debug info that propagates through stages of
// compilation.
message HloPassMetadata {
  // For a given module, pass_id uniquely identifies a run of an HLO pass on
  // that module. Note that a pass_id may not always refer to the same pass
  // because the order of passes during compilation may change. For finding
  // metadata for a particular pass, pass_name and pipeline_name would be more
  // reliable, although note that they may not be unique.
  int64 pass_id = 1;
  string pass_name = 2;
  string pipeline_name = 3;

  // Filenames of the dumps of the module after this pass ran. The module may
  // be dumped in multiple formats, and the order of formats in this field
  // will stay consistent across passes.
  repeated string dump_filenames = 4;

  // Return value of pass.Run(). True if this pass changed the module, or, in
  // the case where the module was run through this pass as part of a module
  // group, true if this pass changed any module in the same module group.
  bool module_changed = 5;

  // The unique_id of the module that this pass is run on. May be different
  // from the canonical_module_id of the HloModuleMetadata that this
  // HloPassMetadata is inside.
  int64 module_id = 6;

  // If the module went through this pass as part of a module group, this is
  // set to the ids of all the modules in the module group. Empty otherwise.
  repeated int64 module_group_module_ids = 7;

  // Timestamps before and after the pass is run. Note they may be equal.
  int64 start_timestamp_usec = 8;
  int64 end_timestamp_usec = 9;
}
644