/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package xla;

import "tensorflow/compiler/xla/service/hlo.proto";
import "tensorflow/compiler/xla/xla_data.proto";

// Options for the HLO insert-reduce-precision-operations pass.
message HloReducePrecisionOptions {
  // Where and when the reduce-precision operations will be added.
  enum Location {
    // Add reduce-precision operations to the inputs of selected instructions.
    // This is done before any optimization occurs.
    OP_INPUTS = 0;
    // Add reduce-precision operations to the outputs of selected instructions.
    // This is done before any optimization occurs.
    OP_OUTPUTS = 1;
    // After operation-fusion occurs, add reduce-precision operations to the
    // outputs of any selected instructions that have not been fused into
    // fusion instructions.
    UNFUSED_OP_OUTPUTS = 2;
    // After operation-fusion occurs, add reduce-precision operations to the
    // inputs of any fusion instructions that contain operations matching the
    // selection criteria.
    FUSION_INPUTS_BY_CONTENT = 3;
    // After operation-fusion occurs, add reduce-precision operations to the
    // outputs of any fusion instructions that contain operations matching the
    // selection criteria.
    FUSION_OUTPUTS_BY_CONTENT = 4;
  }
  Location location = 1;

  // Exponent and mantissa bit counts for the reduced precision.
  uint32 exponent_bits = 2;
  uint32 mantissa_bits = 3;

  // Operations matching these opcodes should be suffixed with reduce-precision
  // operations.
  repeated uint32 opcodes_to_suffix = 4;

  // Operations with names containing these substrings should be suffixed with
  // reduce-precision operations.
  repeated string opname_substrings_to_suffix = 5;
}
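
// For example (values here are purely illustrative), a text-format
// HloReducePrecisionOptions that rounds the outputs of unfused instructions
// whose names contain "tanh" to a bfloat16-like precision (8 exponent bits,
// 7 mantissa bits) might look like:
//
//   location: UNFUSED_OP_OUTPUTS
//   exponent_bits: 8
//   mantissa_bits: 7
//   opname_substrings_to_suffix: "tanh"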

// Debugging options for XLA. These options may change at any time - there are
// no guarantees about backward or forward compatibility for these fields.
message DebugOptions {
  // Show addresses of HLO ops in graph dump.
  bool xla_hlo_graph_addresses = 2;

  // Instrument the computation to collect per-HLO cycle counts.
  bool xla_hlo_profile = 9;

  // List of HLO passes to disable. These names must exactly match the pass
  // names as specified by the HloPassInterface::name() method.
  repeated string xla_disable_hlo_passes = 30;

  // Disables all HLO passes.  Note that some passes are necessary for
  // correctness and the invariants that must be satisfied by "fully optimized"
  // HLO are different for different devices and may change over time.  The only
  // "guarantee", such as it is, is that if you compile XLA and dump the
  // optimized HLO for some graph, you should be able to run it again on the
  // same device with the same build of XLA.
  bool xla_disable_all_hlo_passes = 104;

  // Numerical optimization level for the XLA compiler backend; the specific
  // interpretation of this value is left to the backends.
  int32 xla_backend_optimization_level = 31;

  // Embed the compiler IR as a string in the executable.
  bool xla_embed_ir_in_executable = 33;

  // Eliminate implicit broadcasts when lowering user computations to HLO
  // instructions; use explicit broadcast instead.
  bool xla_eliminate_hlo_implicit_broadcast = 35;

  // When generating calls to Eigen in the CPU backend, use multi-threaded Eigen
  // mode.
  bool xla_cpu_multi_thread_eigen = 60;

  // Path to directory with cuda/ptx tools and libraries.
  string xla_gpu_cuda_data_dir = 61;

  // Enable flush-to-zero semantics in the GPU backend.
  bool xla_gpu_ftz = 62;

  // Disable multi-streaming in the GPU backend.
  bool xla_gpu_disable_multi_streaming = 63;

  // If true, in LLVM-based backends, emit !alias.scope metadata in
  // generated IR.
  bool xla_llvm_enable_alias_scope_metadata = 70;

  // If true, in LLVM-based backends, emit !noalias metadata in the
  // generated IR.
  bool xla_llvm_enable_noalias_metadata = 71;

  // If true, in LLVM-based backends, emit !invariant.load metadata in
  // the generated IR.
  bool xla_llvm_enable_invariant_load_metadata = 72;

  // If true, a set of expensive LLVM optimization passes will not be run.
  bool xla_llvm_disable_expensive_passes = 73;

  // Options for inserting reduce-precision operations for numerical
  // experimentation.  This is a repeated field, as we may want to have
  // multiple passes with different parameters.
  repeated HloReducePrecisionOptions hlo_reduce_precision_options = 80;

  // This is used by ClientLibraryTestBase::ComputeAndCompare*. If true, the
  // computation will run n! times with all permutations of layouts for the
  // output shape in rank n. For example, with a 3D shape, all permutations of
  // the set {0, 1, 2} are tried.
  bool xla_test_all_output_layouts = 90;

  // This is used by ClientLibraryTestBase::ComputeAndCompare*. If true, the
  // computation will run for all permutations of layouts of all input
  // arguments. For example, with 2 input arguments in 2D and 4D shapes, the
  // computation will run 2! * 4! times.
  bool xla_test_all_input_layouts = 91;

  // Assign colors based on sharding information when generating the Graphviz
  // HLO graph.
  bool xla_hlo_graph_sharding_color = 92;

  reserved 93;  // Was xla_hlo_tfgraph_device_scopes

  // If true, the GPU backend is free to use cudnn for HLO batch normalization
  // ops.
  bool xla_gpu_use_cudnn_batchnorm = 94;

  // Generate calls to MKL-DNN in the CPU backend.
  bool xla_cpu_use_mkl_dnn = 97;

  // Maximum kernel unroll factor for the GPU backend.
  int32 xla_gpu_max_kernel_unroll_factor = 98;

  // When true, "unsafe" mathematical optimizations are enabled. These
  // transformations include but are not limited to:
  //
  //  - Reducing the precision of operations (e.g. using an approximate sin
  //    function, or transforming x/y into x * (1/y)).
  //  - Assuming that operations never produce or consume NaN or +/- Inf (this
  //    behavior can be adjusted using xla_cpu_fast_math_honor_{nans|infs}).
  //  - Assuming that +0 and -0 are indistinguishable.
  bool xla_cpu_enable_fast_math = 99;

  // When xla_cpu_enable_fast_math is true then this controls whether we allow
  // operations to produce NaNs.  Ignored when xla_cpu_enable_fast_math is
  // false.
  bool xla_cpu_fast_math_honor_nans = 120;

  // When xla_cpu_enable_fast_math is true then this controls whether we allow
  // operations to produce infinities.  Ignored when xla_cpu_enable_fast_math is
  // false.
  bool xla_cpu_fast_math_honor_infs = 121;

  // When true we lower the Minimum and Maximum hlos in the GPU backend such
  // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NotNaN.  In other words, if this
  // flag is true we don't propagate NaNs through Min and Max.
  bool xla_gpu_enable_fast_min_max = 100;

  // Crashes the program when any kind of verification fails, instead of just
  // logging the failures. One example is cross checking of convolution results
  // among different algorithms.
  bool xla_gpu_crash_on_verification_failures = 101;

  // Force the host platform to pretend that there are this many host
  // "devices".  All these devices are backed by the same threadpool.  Defaults
  // to 1.
  //
  // Setting this to anything other than 1 can increase overhead from context
  // switching, but we let the user override this behavior to help run tests on
  // the host that run models in parallel across multiple devices.
  int32 xla_force_host_platform_device_count = 102;

  // If set to true XLA:GPU invokes `ptxas` with -O0 (default is -O3).
  bool xla_gpu_disable_ptxas_optimizations = 103;

  // Enable fast math with eigen in the HLO evaluator.
  bool xla_hlo_evaluator_use_fast_path = 106;

  // Temporary option to allow support for both the R1 and the scalar index
  // versions of DynamicSlice and DynamicUpdateSlice. Only used for testing.
  bool xla_allow_scalar_index_dynamic_ops = 107;

  enum StepMarkerLocation {
    // Generate a step marker at program entry. This handles the case where
    // each step is done by one or multiple program executions. Only the first
    // program will be tagged for generating a step marker at program entry.
    // This is the default.
    STEP_MARK_AT_ENTRY = 0;
    // Generate a step marker at each iteration of the top-level while loop,
    // which is assumed to be a training loop.
    STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP = 1;
    // No step marker.
    STEP_MARK_NONE = 2;
  }
  // Option to emit a target-specific marker to indicate the start of a training
  // step. The location of the marker (if any) is determined by the option
  // value.
  StepMarkerLocation xla_step_marker_location = 108;

  //
  // BEGIN flags controlling dumping HLO modules for debugging.
  //
  // When dumping is enabled, HLO modules are dumped at the very beginning and
  // end of compilation, and optionally also during the pass pipeline.
  //
  // In general, if you set one of these flags, we will try to infer reasonable
  // defaults for the others.  For example:
  //
  //  * Setting --xla_dump_to=/tmp/foo without specifying a format
  //    with --xla_dump_hlo_as_* will turn on --xla_dump_hlo_as_text.
  //
  //  * Setting --xla_dump_hlo_as_text without specifying --xla_dump_to will
  //    dump to stdout.
  //

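  // As one illustrative sketch (the directory path is only an example), a
  // DebugOptions message in text format that dumps every module as text,
  // before and after every pass, could look like:
  //
  //   xla_dump_to: "/tmp/xla_dump"
  //   xla_dump_hlo_as_text: true
  //   xla_dump_hlo_pass_re: ".*"
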
  // Directory to dump into.
  string xla_dump_to = 109;

  // If specified, will only dump modules which match this regexp.
  string xla_dump_hlo_module_re = 110;

  // If this flag is specified, will also dump HLO before and after passes that
  // match this regular expression.  Set to .* to dump before/after all passes.
  string xla_dump_hlo_pass_re = 111;

  // Specifies the format that HLO is dumped in.  Multiple of these may be
  // specified.
  bool xla_dump_hlo_as_text = 112;
  bool xla_dump_hlo_as_proto = 113;
  bool xla_dump_hlo_as_dot = 114;
  bool xla_dump_hlo_as_url = 115;

  // Dump HLO graphs as an HTML (DOT -> SVG inlined in HTML)
  bool xla_dump_hlo_as_html = 116;

  // If true, every time an HLO module is run, we will dump an HloSnapshot
  // (essentially, a serialized module plus its inputs) to the --xla_dump_to
  // directory.
  bool xla_dump_hlo_snapshots = 118;

  //
  // END flags controlling dumping HLO modules.
  //

  // Next id: 121

  // Extra options to pass to the compilation backend (e.g. LLVM); specific
  // interpretation of these values is left to the backend.
  map<string, string> xla_backend_extra_options = 500;

  reserved 117;  // was xla_dump_to
  reserved 5;    // Was xla_hlo_dump_as_graphdef
}

// These settings control how XLA compiles and/or runs code.  Not all settings
// will have an effect on every platform.
//
// When adding new fields, keep in mind that boolean fields default to false.
message ExecutionOptions {
  // This optional field's layout is used as a hint when storing the output of
  // this computation.  Subsequent transfers of this output array to the client
  // may be faster when using this layout.
  //
  // We use a Shape here to accommodate computations that return a tuple.
  ShapeProto shape_with_output_layout = 2;

  // Used to seed random-number generators used in this computation.  If this is
  // 0, we generate a seed ourselves.
  //
  // TODO(b/32083678): Changing the seed unnecessarily forces a recompilation.
  uint64 seed = 3;

  DebugOptions debug_options = 4;

  // This optional field specifies a particular set of devices to run the
  // computation on. The computation will be partitioned across these devices.
  // If not provided, the default device will be chosen.
  repeated DeviceHandle device_handles = 5;

  // Number of replicas of the computation to run. If zero, uses the default
  // number of replicas for the XLA service.
  int32 num_replicas = 6;

  // This optional field specifies the device assignment if known at compile
  // time.
  DeviceAssignmentProto device_assignment = 7;
}
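
// As an illustrative sketch only (the seed and replica count are arbitrary),
// an ExecutionOptions message in text format might look like:
//
//   seed: 42
//   num_replicas: 1
//   debug_options {
//     xla_hlo_profile: true
//   }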

message GetDeviceHandlesRequest {
  int64 device_count = 1;
}

message GetDeviceHandlesResponse {
  repeated DeviceHandle device_handles = 1;
}

message TransferToClientRequest {
  GlobalDataHandle data = 1;

  // This optional field directs the service to return the literal in this
  // layout. A shape is used to hold the layout to accommodate tuples.
  ShapeProto shape_with_layout = 2;
}

message TransferToClientResponse {
  LiteralProto literal = 1;
}

message TransferToServerRequest {
  LiteralProto literal = 1;
  DeviceHandle device_handle = 2;
}

message TransferToServerResponse {
  GlobalDataHandle data = 1;
}

message TransferToInfeedRequest {
  LiteralProto literal = 1;
  int64 replica_id = 2;
  DeviceHandle device_handle = 3;
}

message TransferToInfeedResponse {}

message TransferFromOutfeedRequest {
  // This optional field directs the service to return the literal in this
  // layout. A shape is used to hold the layout to accommodate tuples.
  ShapeProto shape_with_layout = 1;

  int64 replica_id = 2;
  DeviceHandle device_handle = 3;
}

message TransferFromOutfeedResponse {
  LiteralProto literal = 1;
}

message ResetDeviceRequest {
  DeviceHandle device_handle = 1;
}

message ResetDeviceResponse {}

message ComputationGraphStatsRequest {
  HloModuleProto computation = 1;
  DebugOptions debug_options = 2;
}

message ComputationStatsResponse {
  ComputationStats stats = 1;
}

message CreateChannelHandleRequest {
  ChannelHandle.ChannelType channel_type = 1;
}

message CreateChannelHandleResponse {
  ChannelHandle channel = 1;
}

message UnregisterRequest {
  repeated GlobalDataHandle data = 1;
}

message UnregisterResponse {}

message CompileRequest {
  // The graph to be compiled.
  HloModuleProto computation = 1;

  // Options that affect how XLA compiles code to service this request.
  ExecutionOptions execution_options = 2;

  // The layouts of the input arguments. If not set, the default layout will be
  // used. Although the real arguments are not needed in compilation, the
  // layouts of the arguments can affect the compilation.
  repeated ShapeProto input_shape_with_layout = 3;
}

message CompileResponse {
  // The handle to the executable.
  ExecutionHandle handle = 1;
}

message ExecuteRequest {
  ExecutionHandle handle = 1;

  // The shape and layout of the arguments must be the same as those of the
  // executable's parameters.
  repeated GlobalDataHandle arguments = 2;
}
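
// A purely illustrative sketch of the two-step flow (the handle values are
// made up, and this assumes ExecutionHandle and GlobalDataHandle each wrap a
// single int64 `handle` field as defined in xla_data.proto): a CompileRequest
// carrying the HloModuleProto and ExecutionOptions returns an ExecutionHandle,
// which is then reused in one or more ExecuteRequests, e.g. in text format:
//
//   # ExecuteRequest
//   handle { handle: 17 }
//   arguments { handle: 42 }
//   arguments { handle: 43 }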

// TODO(b/118493728): Remove this and ExecuteGraphParallelRequest and replace
// the uses with calls to Compile and Execute.
message ExecuteGraphRequest {
  HloModuleProto computation = 1;
  repeated GlobalDataHandle arguments = 2;

  // Options that affect how XLA compiles and runs code to service this request.
  ExecutionOptions execution_options = 3;
}

message ExecuteGraphParallelRequest {
  repeated ExecuteGraphRequest requests = 1;
}

message ExecuteResponse {
  GlobalDataHandle output = 1;
  ExecutionProfile profile = 2;
}

message ExecuteParallelResponse {
  repeated ExecuteResponse responses = 1;
}

message WaitForExecutionRequest {
  ExecutionHandle execution = 1;
}

message WaitForExecutionResponse {
  GlobalDataHandle output = 1;
  ExecutionProfile profile = 2;
}

message ComputeConstantGraphRequest {
  HloModuleProto computation = 1;
  LayoutProto output_layout = 2;
}

message ComputeConstantResponse {
  // A LiteralProto is returned directly for this request.
  LiteralProto literal = 1;
}

message DeconstructTupleRequest {
  GlobalDataHandle tuple_handle = 2;
}

message DeconstructTupleResponse {
  repeated GlobalDataHandle element_handles = 1;
}

message LoadDataRequest {
  // Describes the path of the ColumnIO tablet to load.
  string columnio_tablet_path = 1;

  // Describes the field to load within the ColumnIO tablet.
  string columnio_field = 2;

  // Individual element shape, excluding rows.
  ShapeProto element_shape = 3;

  // Warning: ColumnIO does not support random-access, so use offset with
  // caution in performance-critical scenarios.
  int64 offset = 4;

  // Maximum number of elements (with shape element_shape) to load.
  int64 limit = 5;

  // If more than one item is requested (via limit > 1), then this request
  // attribute zips together the produced vectors.
  bool zip = 6;
}
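
// An illustrative sketch only (the tablet path and field name are made up, and
// the element_shape text assumes ShapeProto's element_type/dimensions fields
// from xla_data.proto): a LoadDataRequest in text format could look like:
//
//   columnio_tablet_path: "/path/to/tablet"
//   columnio_field: "features"
//   element_shape { element_type: F32 dimensions: 128 }
//   limit: 1000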

message LoadDataResponse {
  GlobalDataHandle data = 1;
  ShapeProto data_shape = 2;
  int64 available_rows = 3;
  int64 rows_loaded = 4;
  int64 nanoseconds = 5;
}

message GetShapeRequest {
  GlobalDataHandle data = 1;
}

message GetShapeResponse {
  ShapeProto shape = 1;
}

message UnpackRequest {
  GlobalDataHandle data = 1;
}

message UnpackResponse {
  repeated GlobalDataHandle tied_data = 1;
}