/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package xla;

import "tensorflow/compiler/xla/service/hlo.proto";
import "tensorflow/compiler/xla/xla_data.proto";

// Debugging options for XLA. These options may change at any time - there are
// no guarantees about backward or forward compatibility for these fields.
message DebugOptions {
  // Show addresses of HLO ops in graph dump.
  bool xla_hlo_graph_addresses = 2;

  // Instrument the computation to collect per-HLO cycle counts.
  bool xla_hlo_profile = 9;

  // List of HLO passes to disable/enable. These names must exactly match the
  // pass names as specified by the HloPassInterface::name() method.
  //
  // At least one of xla_disable_hlo_passes and xla_enable_hlo_passes_only must
  // be empty.
  repeated string xla_disable_hlo_passes = 30;
  repeated string xla_enable_hlo_passes_only = 124;

  // Disables all HLO passes.  Note that some passes are necessary for
  // correctness and the invariants that must be satisfied by "fully optimized"
  // HLO are different for different devices and may change over time.  The only
  // "guarantee", such as it is, is that if you compile XLA and dump the
  // optimized HLO for some graph, you should be able to run it again on the
  // same device with the same build of XLA.
  bool xla_disable_all_hlo_passes = 104;

  // Numerical optimization level for the XLA compiler backend; the specific
  // interpretation of this value is left to the backends.
  int32 xla_backend_optimization_level = 31;

  // Embed the compiler IR as a string in the executable.
  bool xla_embed_ir_in_executable = 33;

  // Eliminate implicit broadcasts when lowering user computations to HLO
  // instructions; use explicit broadcast instead.
  bool xla_eliminate_hlo_implicit_broadcast = 35;

  // When generating calls to Eigen in the CPU backend, use multi-threaded Eigen
  // mode.
  bool xla_cpu_multi_thread_eigen = 60;

  // Path to directory with cuda/ptx tools and libraries.
  string xla_gpu_cuda_data_dir = 61;

  // Enable flush-to-zero semantics in the GPU backend.
  bool xla_gpu_ftz = 62;

  // Disable multi-streaming in the GPU backend.
  bool xla_gpu_disable_multi_streaming = 63;

  // If true, in LLVM-based backends, emit !alias.scope metadata in
  // generated IR.
  bool xla_llvm_enable_alias_scope_metadata = 70;

  // If true, in LLVM-based backends, emit !noalias metadata in the
  // generated IR.
  bool xla_llvm_enable_noalias_metadata = 71;

  // If true, in LLVM-based backends, emit !invariant.load metadata in
  // the generated IR.
  bool xla_llvm_enable_invariant_load_metadata = 72;

  // If true, a set of expensive LLVM optimization passes will not be run.
  bool xla_llvm_disable_expensive_passes = 73;

  reserved 80;  // Was hlo_reduce_precision_options

  // This is used by ClientLibraryTestBase::ComputeAndCompare*. If true, the
  // computation will run n! times with all permutations of layouts for the
  // output shape in rank n. For example, with a 3D shape, all permutations of
  // the set {0, 1, 2} are tried.
  bool xla_test_all_output_layouts = 90;

  // This is used by ClientLibraryTestBase::ComputeAndCompare*. If true, the
  // computation will run for all permutations of layouts of all input
  // arguments. For example, with 2 input arguments in 2D and 4D shapes, the
  // computation will run 2! * 4! times.
  bool xla_test_all_input_layouts = 91;

  // Assign colors based on sharding information when generating the Graphviz
  // HLO graph.
  bool xla_hlo_graph_sharding_color = 92;

  reserved 93;  // Was xla_hlo_tfgraph_device_scopes

  // If true, the GPU backend is free to use cudnn for HLO batch normalization
  // ops.
  bool xla_gpu_use_cudnn_batchnorm = 94;

  // Generate calls to MKL-DNN in the CPU backend.
  bool xla_cpu_use_mkl_dnn = 97;

  // Maximum kernel unroll factor for the GPU backend.
  int32 xla_gpu_max_kernel_unroll_factor = 98;

  // When true, "unsafe" mathematical optimizations are enabled. These
  // transformations include but are not limited to:
  //
  //  - Reducing the precision of operations (e.g. using an approximate sin
  //    function, or transforming x/y into x * (1/y)).
  //  - Assuming that operations never produce or consume NaN or +/- Inf (this
  //    behavior can be adjusted using xla_cpu_fast_math_honor_{nans|infs}).
  //  - Assuming that +0 and -0 are indistinguishable.
  bool xla_cpu_enable_fast_math = 99;

  // When xla_cpu_enable_fast_math is true, this controls whether we allow
  // operations to produce NaNs.  Ignored when xla_cpu_enable_fast_math is
  // false.
  bool xla_cpu_fast_math_honor_nans = 120;

  // When xla_cpu_enable_fast_math is true, this controls whether we allow
  // operations to produce infinities. Ignored when xla_cpu_enable_fast_math is
  // false.
  bool xla_cpu_fast_math_honor_infs = 121;

  // When xla_cpu_enable_fast_math is true, this controls whether we forbid
  // using the reciprocal of an argument instead of division. Ignored when
  // xla_cpu_enable_fast_math is false.
  bool xla_cpu_fast_math_honor_division = 126;

  // When xla_cpu_enable_fast_math is true, this controls whether we forbid
  // approximate calculations for functions. Ignored when
  // xla_cpu_enable_fast_math is false.
  bool xla_cpu_fast_math_honor_functions = 129;

  // When true we lower the Minimum and Maximum HLOs in the GPU backend such
  // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NotNaN.  In other words, if this
  // flag is true we don't propagate NaNs through Min and Max.
  bool xla_gpu_enable_fast_min_max = 100;

  // Allows XLA to increase the output precision of floating point operations.
  bool xla_allow_excess_precision = 122;

  // Crashes the program when any kind of verification fails, instead of just
  // logging the failures. One example is cross checking of convolution results
  // among different algorithms.
  bool xla_gpu_crash_on_verification_failures = 101;

  // GEMM and Convolution auto-tuning level. Set to 0 to disable auto-tuning.
  int32 xla_gpu_autotune_level = 123;

  // Force the host platform to pretend that there are this many host
  // "devices".  All these devices are backed by the same threadpool.  Defaults
  // to 1.
  //
  // Setting this to anything other than 1 can increase overhead from context
  // switching, but we let the user override this behavior to help run tests on
  // the host that run models in parallel across multiple devices.
  int32 xla_force_host_platform_device_count = 102;

  // If set to true, XLA:GPU invokes `ptxas` with -O0 (default is -O3).
  bool xla_gpu_disable_gpuasm_optimizations = 103;

  // Enable fast math with eigen in the HLO evaluator.
  bool xla_hlo_evaluator_use_fast_path = 106;

  // Temporary option to allow support for both the R1 and the scalar index
  // versions of DynamicSlice and DynamicUpdateSlice. Only used for testing.
  bool xla_allow_scalar_index_dynamic_ops = 107;

  enum StepMarkerLocation {
    // Generate a step marker at the program entry. This handles the case where
    // each step is done by one or more program executions. Only the first
    // program will be tagged for generating a step marker at the program entry.
    // This is the default.
    STEP_MARK_AT_ENTRY = 0;
    // Generate a step marker at each iteration of the top-level while loop,
    // which is assumed to be a training loop.
    STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP = 1;
    // Generate a step marker at each iteration of the second-level while loops,
    // which are assumed to be training or eval loops.
    STEP_MARK_AT_SECOND_LEVEL_WHILE_LOOP = 3;
    // No step marker generated.
    STEP_MARK_NONE = 2;
  }
  // Option to emit a target-specific marker to indicate the start of a training
  // step. The location of the marker (if any) is determined by the option
  // value.
  StepMarkerLocation xla_step_marker_location = 108;

  //
  // BEGIN flags controlling dumping HLO modules for debugging.
  //
  // When dumping is enabled, HLO modules are dumped at the very beginning and
  // end of compilation, and optionally also during the pass pipeline.
  //
  // In general, if you set one of these flags, we will try to infer reasonable
  // defaults for the others.  For example:
  //
  //  * Setting --xla_dump_to=/tmp/foo without specifying a format
  //    with --xla_dump_hlo_as_* will turn on --xla_dump_hlo_as_text.
  //
  //  * Setting --xla_dump_hlo_as_text without specifying --xla_dump_to will
  //    dump to stdout.
  //

  // Directory to dump into.
  string xla_dump_to = 109;

  // If specified, will only dump modules which match this regexp.
  string xla_dump_hlo_module_re = 110;

  // If this flag is specified, will also dump HLO before and after passes that
  // match this regular expression.  Set to .* to dump before/after all passes.
  string xla_dump_hlo_pass_re = 111;

  // Specifies the format that HLO is dumped in.  Multiple of these may be
  // specified.
  bool xla_dump_hlo_as_text = 112;
  bool xla_dump_hlo_as_proto = 113;
  bool xla_dump_hlo_as_dot = 114;
  bool xla_dump_hlo_as_url = 115;

  // Dump HLO graphs as HTML (DOT rendered to SVG and inlined in HTML).
  bool xla_dump_hlo_as_html = 116;

  // If true, every time an HLO module is run, we will dump an HloSnapshot
  // (essentially, a serialized module plus its inputs) to the --xla_dump_to
  // directory.
  bool xla_dump_hlo_snapshots = 118;

  // Include a timestamp in the dumped filenames.
  bool xla_dump_include_timestamp = 131;

  // Max number of HLO module dumps in a directory. Set to < 0 for unbounded.
  int32 xla_dump_max_hlo_modules = 132;

  //
  // END flags controlling dumping HLO modules.
  //

  // If true, force convolutions in the GPU backend to use the NCHW layout.
  bool xla_gpu_force_conv_nchw = 125;

  // Paths to files with ptx code.
  repeated string xla_gpu_ptx_file = 127;

  // Blacklist for cuDNN convolutions.
  string xla_gpu_algorithm_blacklist_path = 128;

  // Guarantee run-to-run determinism from reductions on XLA:GPU.
  bool xla_gpu_deterministic_reductions = 130;
  // Next id: 133

  // Extra options to pass to the compilation backend (e.g. LLVM); specific
  // interpretation of these values is left to the backend.
  map<string, string> xla_backend_extra_options = 500;

  reserved 117;  // was xla_dump_to
  reserved 5;    // Was xla_hlo_dump_as_graphdef
}
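
// Illustrative example (not part of the service definition): a DebugOptions
// message that enables HLO text dumps before and after every pass, written in
// protobuf text format. The dump directory is a placeholder path.
//
//   xla_dump_to: "/tmp/xla_dump"
//   xla_dump_hlo_as_text: true
//   xla_dump_hlo_pass_re: ".*"
//   xla_hlo_profile: true
//
// Per the comments above, setting xla_dump_to by itself would already turn on
// xla_dump_hlo_as_text by default.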

// These settings control how XLA compiles and/or runs code.  Not all settings
// will have an effect on every platform.
//
// When adding new fields, keep in mind that boolean fields default to false.
message ExecutionOptions {
  // This optional field's layout is used as a hint when storing the output of
  // this computation.  Subsequent transfers of this output array to the client
  // may be faster when using this layout.
  //
  // We use a Shape here to accommodate computations that return a tuple.
  ShapeProto shape_with_output_layout = 2;

  // Used to seed random-number generators used in this computation.  If this is
  // 0, we generate a seed ourselves.
  //
  // TODO(b/32083678): Changing the seed unnecessarily forces a recompilation.
  uint64 seed = 3;

  DebugOptions debug_options = 4;

  // This optional field specifies a particular set of devices to run the
  // computation on. The computation will be partitioned across these devices.
  // If not provided, the default device will be chosen.
  repeated DeviceHandle device_handles = 5;

  // Number of replicas of the computation to run. If zero, uses the default
  // number of replicas for the XLA service.
  int32 num_replicas = 6;

  // This optional field specifies the device assignment if known at compile
  // time.
  DeviceAssignmentProto device_assignment = 7;

  // Alias input and output buffers for parameters that are passed through XLA
  // modules without being changed.
  bool alias_passthrough_params = 8;

  // Number of partitions of the computation to run (model parallelism).
  // If zero, uses the default number of partitions for the XLA service.
  int32 num_partitions = 9;
}
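
// Illustrative example (not part of the service definition): an
// ExecutionOptions message requesting two replicas with a fixed RNG seed and
// nested debug options, written in protobuf text format.
//
//   num_replicas: 2
//   seed: 42
//   debug_options {
//     xla_dump_to: "/tmp/xla_dump"
//     xla_dump_hlo_as_text: true
//   }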

message GetDeviceHandlesRequest {
  int64 device_count = 1;
}

message GetDeviceHandlesResponse {
  repeated DeviceHandle device_handles = 1;
}

message TransferToClientRequest {
  GlobalDataHandle data = 1;

  // This optional field directs the service to return the literal in this
  // layout. A shape is used to hold the layout to accommodate tuples.
  ShapeProto shape_with_layout = 2;
}
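
// Illustrative example (not part of the service definition): a
// TransferToClientRequest asking for an f32[2,3] result in row-major layout,
// written in protobuf text format. This assumes the ShapeProto/LayoutProto
// field names from xla_data.proto (element_type, dimensions,
// layout.minor_to_major) and uses a placeholder data handle.
//
//   data { handle: 1 }
//   shape_with_layout {
//     element_type: F32
//     dimensions: 2
//     dimensions: 3
//     layout { minor_to_major: 1 minor_to_major: 0 }
//   }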

message TransferToClientResponse {
  LiteralProto literal = 1;
}

message TransferToServerRequest {
  LiteralProto literal = 1;
  DeviceHandle device_handle = 2;
}

message TransferToServerResponse {
  GlobalDataHandle data = 1;
}

message TransferToInfeedRequest {
  LiteralProto literal = 1;
  int64 replica_id = 2;
  DeviceHandle device_handle = 3;
}

message TransferToInfeedResponse {}

message TransferFromOutfeedRequest {
  // This optional field directs the service to return the literal in this
  // layout. A shape is used to hold the layout to accommodate tuples.
  ShapeProto shape_with_layout = 1;

  int64 replica_id = 2;
  DeviceHandle device_handle = 3;
}

message TransferFromOutfeedResponse {
  LiteralProto literal = 1;
}

message ResetDeviceRequest {
  DeviceHandle device_handle = 1;
}

message ResetDeviceResponse {}

message ComputationGraphStatsRequest {
  HloModuleProto computation = 1;
  DebugOptions debug_options = 2;
}

message ComputationStatsResponse {
  ComputationStats stats = 1;
}

message CreateChannelHandleRequest {
  ChannelHandle.ChannelType channel_type = 1;
}

message CreateChannelHandleResponse {
  ChannelHandle channel = 1;
}

message UnregisterRequest {
  repeated GlobalDataHandle data = 1;
}

message UnregisterResponse {}

message CompileRequest {
  // The graph to be compiled.
  HloModuleProto computation = 1;

  // Options that affect how XLA compiles code to service this request.
  ExecutionOptions execution_options = 2;

  // The layouts of the input arguments. If not set, the default layout will be
  // used. Although the actual argument values are not needed for compilation,
  // the layouts of the arguments can affect the compilation.
  repeated ShapeProto input_shape_with_layout = 3;
}

message CompileResponse {
  // The handle to the executable.
  ExecutionHandle handle = 1;
}

message ExecuteRequest {
  ExecutionHandle handle = 1;

  // The shape and layout of the arguments must be the same as those of the
  // executable's parameters.
  repeated GlobalDataHandle arguments = 2;
}
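
// Illustrative example (not part of the service definition): in the
// Compile/Execute flow, the ExecutionHandle returned in CompileResponse is
// passed back in ExecuteRequest together with handles to previously
// transferred arguments. Written in protobuf text format, and assuming the
// single int64 `handle` field that ExecutionHandle and GlobalDataHandle carry
// in xla_data.proto (handle values are placeholders):
//
//   handle { handle: 1 }
//   arguments { handle: 2 }
//   arguments { handle: 3 }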

// TODO(b/118493728): Remove this and ExecuteGraphParallelRequest and replace
// the uses with calls to Compile and Execute.
message ExecuteGraphRequest {
  HloModuleProto computation = 1;
  repeated GlobalDataHandle arguments = 2;

  // Options that affect how XLA compiles and runs code to service this request.
  ExecutionOptions execution_options = 3;
}

message ExecuteGraphParallelRequest {
  repeated ExecuteGraphRequest requests = 1;
}

message ExecuteResponse {
  GlobalDataHandle output = 1;
  ExecutionProfile profile = 2;
}

message ExecuteParallelResponse {
  repeated ExecuteResponse responses = 1;
}

message WaitForExecutionRequest {
  ExecutionHandle execution = 1;
}

message WaitForExecutionResponse {
  GlobalDataHandle output = 1;
  ExecutionProfile profile = 2;
}

message ComputeConstantGraphRequest {
  HloModuleProto computation = 1;
  LayoutProto output_layout = 2;
}

message ComputeConstantResponse {
  // A LiteralProto is returned directly for this request.
  LiteralProto literal = 1;
}

message DeconstructTupleRequest {
  GlobalDataHandle tuple_handle = 2;
}

message DeconstructTupleResponse {
  repeated GlobalDataHandle element_handles = 1;
}

message LoadDataRequest {
  // Describes the path of the ColumnIO tablet to load.
  string columnio_tablet_path = 1;

  // Describes the field to load within the ColumnIO tablet.
  string columnio_field = 2;

  // Individual element shape, excluding rows.
  ShapeProto element_shape = 3;

  // Warning: ColumnIO does not support random access, so use offset with
  // caution in performance-critical scenarios.
  int64 offset = 4;

  // Maximum number of elements (with shape element_shape) to load.
  int64 limit = 5;

  // If more than one item is requested (via limit > 1), then this request
  // attribute zips together the produced vectors.
  bool zip = 6;
}

message LoadDataResponse {
  GlobalDataHandle data = 1;
  ShapeProto data_shape = 2;
  int64 available_rows = 3;
  int64 rows_loaded = 4;
  int64 nanoseconds = 5;
}

message GetShapeRequest {
  GlobalDataHandle data = 1;
}

message GetShapeResponse {
  ShapeProto shape = 1;
}

message UnpackRequest {
  GlobalDataHandle data = 1;
}

message UnpackResponse {
  repeated GlobalDataHandle tied_data = 1;
}