• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16syntax = "proto3";
17
18package xla;
19
20import "tensorflow/compiler/xla/service/hlo.proto";
21import "tensorflow/compiler/xla/xla_data.proto";
22
// Debugging options for XLA. These options may change at any time - there are
// no guarantees about backward or forward compatibility for these fields.
message DebugOptions {
  // Show addresses of HLO ops in graph dump.
  bool xla_hlo_graph_addresses = 2;

  // Instrument the computation to collect per-HLO cycle counts.
  bool xla_hlo_profile = 9;

  // List of HLO passes to disable/enable. These names must exactly match the
  // pass names as specified by the HloPassInterface::name() method.
  //
  // At least one of xla_disable_hlo_passes and xla_enable_hlo_passes_only must
  // be empty.
  repeated string xla_disable_hlo_passes = 30;
  repeated string xla_enable_hlo_passes_only = 124;

  // Disables all HLO passes.  Note that some passes are necessary for
  // correctness and the invariants that must be satisfied by "fully optimized"
  // HLO are different for different devices and may change over time.  The only
  // "guarantee", such as it is, is that if you compile XLA and dump the
  // optimized HLO for some graph, you should be able to run it again on the
  // same device with the same build of XLA.
  bool xla_disable_all_hlo_passes = 104;

  // Numerical optimization level for the XLA compiler backend; the specific
  // interpretation of this value is left to the backends.
  int32 xla_backend_optimization_level = 31;

  // Embed the compiler IR as a string in the executable.
  bool xla_embed_ir_in_executable = 33;

  // Eliminate implicit broadcasts when lowering user computations to HLO
  // instructions; use explicit broadcast instead.
  bool xla_eliminate_hlo_implicit_broadcast = 35;

  // When generating calls to Eigen in the CPU backend, use multi-threaded Eigen
  // mode.
  bool xla_cpu_multi_thread_eigen = 60;

  // Path to directory with cuda/ptx tools and libraries.
  string xla_gpu_cuda_data_dir = 61;

  // Enable flush-to-zero semantics in the GPU backend.
  bool xla_gpu_ftz = 62;

  // Disable multi-streaming in the GPU backend.
  bool xla_gpu_disable_multi_streaming = 63;

  // Debugging feature: if enabled, the GPU backend will assign HLO operators to
  // randomly chosen streams. This is intended to trigger concurrency bugs.
  bool xla_gpu_use_random_streams = 134;

  // If true, in LLVM-based backends, emit !alias.scope metadata in
  // generated IR.
  bool xla_llvm_enable_alias_scope_metadata = 70;

  // If true, in LLVM-based backends, emit !noalias metadata in the
  // generated IR.
  bool xla_llvm_enable_noalias_metadata = 71;

  // If true, in LLVM-based backends, emit !invariant.load metadata in
  // the generated IR.
  bool xla_llvm_enable_invariant_load_metadata = 72;

  // If true, a set of expensive LLVM optimization passes will not be run.
  bool xla_llvm_disable_expensive_passes = 73;

  reserved 80;  // Was hlo_reduce_precision_options

  // This is used by ClientLibraryTestBase::ComputeAndCompare*. If true, the
  // computation will run n! times with all permutations of layouts for the
  // output shape in rank n. For example, with a 3D shape, all permutations of
  // the set {0, 1, 2} are tried.
  bool xla_test_all_output_layouts = 90;

  // This is used by ClientLibraryTestBase::ComputeAndCompare*. If true, the
  // computation will run for all permutations of layouts of all input
  // arguments. For example, with 2 input arguments in 2D and 4D shapes, the
  // computation will run 2! * 4! times.
  bool xla_test_all_input_layouts = 91;

  // Assign colors based on sharding information when generating the Graphviz
  // HLO graph.
  bool xla_hlo_graph_sharding_color = 92;

  reserved 93;  // Was xla_hlo_tfgraph_device_scopes

  // If true, the GPU backend is free to use cudnn for HLO batch normalization
  // ops.
  bool xla_gpu_use_cudnn_batchnorm = 94;

  // Generate calls to MKL-DNN in the CPU backend.
  bool xla_cpu_use_mkl_dnn = 97;

  // Maximum kernel unroll factor for the GPU backend.
  int32 xla_gpu_max_kernel_unroll_factor = 98;

  // When true, "unsafe" mathematical optimizations are enabled. These
  // transformations include but are not limited to:
  //
  //  - Reducing the precision of operations (e.g. using an approximate sin
  //    function, or transforming x/y into x * (1/y)).
  //  - Assuming that operations never produce or consume NaN or +/- Inf (this
  //    behavior can be adjusted using xla_cpu_fast_math_honor_{nans|infs}).
  //  - Assuming that +0 and -0 are indistinguishable.
  bool xla_cpu_enable_fast_math = 99;

  // When xla_cpu_enable_fast_math is true then this controls whether we allow
  // operations to produce NaNs.  Ignored when xla_cpu_enable_fast_math is
  // false.
  bool xla_cpu_fast_math_honor_nans = 120;

  // When xla_cpu_enable_fast_math is true then this controls whether we allow
  // operations to produce infinities. Ignored when xla_cpu_enable_fast_math is
  // false.
  bool xla_cpu_fast_math_honor_infs = 121;

  // When xla_cpu_enable_fast_math is true then this controls whether we forbid
  // using the reciprocal of an argument instead of division. Ignored when
  // xla_cpu_enable_fast_math is false.
  bool xla_cpu_fast_math_honor_division = 126;

  // When xla_cpu_enable_fast_math is true then this controls whether we forbid
  // approximating calculations for functions. Ignored when
  // xla_cpu_enable_fast_math is false.
  bool xla_cpu_fast_math_honor_functions = 129;

  // When false we lower the Minimum and Maximum hlos in the CPU backend such
  // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NaN.  In other words, if this
  // flag is false we always propagate NaNs through Min and Max.
  //
  // Note, this does not correspond to the exact same behavior as the gpu flag
  // below!
  bool xla_cpu_enable_fast_min_max = 140;

  // When true we lower the Minimum and Maximum hlos in the GPU backend such
  // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NotNaN.  In other words, if this
  // flag is true we don't propagate NaNs through Min and Max.
  //
  // Note, this does not correspond to the exact same behavior as the cpu flag
  // above!
  bool xla_gpu_enable_fast_min_max = 100;

  // Allows xla to increase the output precision of floating point operations.
  bool xla_allow_excess_precision = 122;

  // Crashes the program when any kind of verification fails, instead of just
  // logging the failures. One example is cross checking of convolution results
  // among different algorithms.
  bool xla_gpu_crash_on_verification_failures = 101;

  // 0:   Disable gemm and convolution autotuning.
  // 1:   Enable autotuning, but disable correctness checking.
  // 2:   Also set output buffers to random numbers during autotuning.
  // 3:   Also reset output buffers to random numbers after autotuning each
  //      algorithm.
  // 4+:  Also check for correct outputs and for out-of-bounds reads/writes.
  //
  // Default: 4.
  int32 xla_gpu_autotune_level = 123;

  // Force the host platform to pretend that there are these many host
  // "devices".  All these devices are backed by the same threadpool.  Defaults
  // to 1.
  //
  // Setting this to anything other than 1 can increase overhead from context
  // switching but we let the user override this behavior to help run tests on
  // the host that run models in parallel across multiple devices.
  int32 xla_force_host_platform_device_count = 102;

  // If set to true XLA:GPU invokes `ptxas` with -O0 (default is -O3).
  bool xla_gpu_disable_gpuasm_optimizations = 103;

  // Enable fast math with eigen in the HLO evaluator.
  bool xla_hlo_evaluator_use_fast_path = 106;

  // Temporary option to allow support for both the R1 and the scalar index
  // versions of DynamicSlice and DynamicUpdateSlice. Only used for testing.
  bool xla_allow_scalar_index_dynamic_ops = 107;

  // Possible locations of the step marker selected by
  // xla_step_marker_location. The numeric values are not in declaration order
  // (3 is declared before 2); they are part of the wire format and must not be
  // renumbered.
  enum StepMarkerLocation {
    // Generate a step marker at the program entry. This handles the case where
    // each step is done by one or multiple program execution(s). Only the first
    // program will be tagged for generating a step marker at the program entry.
    // This is the default.
    STEP_MARK_AT_ENTRY = 0;
    // Generate a step marker at each iteration of the top level while loop,
    // which is assumed to be a training loop.
    STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP = 1;
    // Generate a step marker at each iteration of the second level while loops,
    // which is assumed to be a training or eval loop.
    STEP_MARK_AT_SECOND_LEVEL_WHILE_LOOP = 3;
    // No step marker generated.
    STEP_MARK_NONE = 2;
  }
  // Option to emit a target-specific marker to indicate the start of a training
  // step. The location of the marker (if any) is determined by the option
  // value.
  StepMarkerLocation xla_step_marker_location = 108;

  //
  // BEGIN flags controlling dumping HLO modules for debugging.
  //
  // When dumping is enabled, HLO modules are dumped at the very beginning and
  // end of compilation, and optionally also during the pass pipeline.
  //
  // In general, if you set one of these flags, we will try to infer reasonable
  // defaults for the others.  For example:
  //
  //  * Setting --xla_dump_to=/tmp/foo without specifying a format
  //    with --xla_dump_hlo_as_* will turn on --xla_dump_hlo_as_text.
  //
  //  * Setting --xla_dump_hlo_as_text without specifying --xla_dump_to will
  //    dump to stdout.
  //

  // Directory to dump into.
  string xla_dump_to = 109;

  // If specified, will only dump modules which match this regexp.
  string xla_dump_hlo_module_re = 110;

  // If this flag is specified, will also dump HLO before and after passes that
  // match this regular expression.  Set to .* to dump before/after all passes.
  string xla_dump_hlo_pass_re = 111;

  // Specifies the format that HLO is dumped in.  Multiple of these may be
  // specified.
  bool xla_dump_hlo_as_text = 112;
  bool xla_dump_hlo_as_proto = 113;
  bool xla_dump_hlo_as_dot = 114;
  bool xla_dump_hlo_as_url = 115;

  // Dump HLO graphs as an HTML (DOT -> SVG inlined in HTML)
  bool xla_dump_hlo_as_html = 116;

  // Dump the visualization of the fusion progress.
  bool xla_dump_fusion_visualization = 149;

  // If true, every time an HLO module is run, we will dump an HloSnapshot
  // (essentially, a serialized module plus its inputs) to the --xla_dump_to
  // directory.
  bool xla_dump_hlo_snapshots = 118;

  // Include a timestamp in the dumped filenames.
  bool xla_dump_include_timestamp = 131;

  // Max number of hlo module dumps in a directory. Set to < 0 for unbounded.
  int32 xla_dump_max_hlo_modules = 132;

  // Dump HloModuleMetadata as a text proto for each HLO module.
  bool xla_dump_module_metadata = 144;

  // GZip-compress protos dumped via --xla_dump_hlo_as_proto.
  bool xla_dump_compress_protos = 151;

  //
  // END flags controlling dumping HLO modules.
  //

  // Overrides for XLA GPU's convolution layout heuristic.
  bool xla_gpu_force_conv_nchw = 125;
  bool xla_gpu_force_conv_nhwc = 146;

  // Paths to files with ptx code.
  repeated string xla_gpu_ptx_file = 127;

  // Whether to dump llvm ir when compiling to ptx.
  bool xla_gpu_dump_llvmir = 155;

  // Denylist for cuDNN convolutions.
  string xla_gpu_algorithm_denylist_path = 128;

  // Guarantee run-to-run determinism from reductions on XLA:GPU.
  bool xla_gpu_deterministic_reductions = 130;

  // Debug options that trigger execution errors when NaN or Inf are detected.
  bool xla_tpu_detect_nan = 135;
  bool xla_tpu_detect_inf = 136;

  // True if TraceMe annotations are enabled for XLA:CPU.
  bool xla_cpu_enable_xprof_traceme = 137;

  // It is usually preferable to not fallback to the driver; it can consume more
  // memory, or have bugs.
  bool xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found = 138;

  // Extra parameters to pass the GPU assembler.
  string xla_gpu_asm_extra_flags = 141;

  // Per-heap size constraint. New heaps will be created if per-heap max size is
  // reached.
  int32 xla_multiheap_size_constraint_per_heap = 142;

  // Enable detailed logging into vlog and xla dumping. If this is disabled, no
  // compilation summary will be printed in the end of computation and no hlo
  // modules will be dumped.
  bool xla_detailed_logging_and_dumping = 143;

  // Overrides normal multi-threaded compilation setting to use this many
  // threads. Setting to 0 (the default value) means no enforcement.
  int32 xla_gpu_force_compilation_parallelism = 147;

  // Guarantees run-to-run determinism. At present, the HLO ops Scatter and
  // SelectAndScatter do not have deterministic XLA:GPU implementations.
  // Compilation errors out if these ops are encountered.
  bool xla_gpu_deterministic_ops = 148;

  // Paths to files with LLVM code.
  repeated string xla_gpu_llvm_ir_file = 150;

  // Convert synchronous all-reduces ops into asynchronous.
  bool xla_gpu_enable_async_all_reduce = 152;

  // Disable dumping metadata in HLO dumps.
  bool xla_dump_disable_metadata = 153;

  // If this flag is specified, will only dump HLO before and after passes in
  // the pass pipeline that matches this regular expression. Default empty value
  // enables dumping in all pipelines.
  string xla_dump_hlo_pipeline_re = 154;

  // Next id: 156

  // Extra options to pass to the compilation backend (e.g. LLVM); specific
  // interpretation of these values is left to the backend.
  map<string, string> xla_backend_extra_options = 500;

  // Field numbers of removed options; kept reserved so they are never reused.
  reserved 5, 117, 133,
      139;  // were xla_hlo_dump_as_graphdef, xla_dump_to,
            // xla_gpu_use_horizontal_fusion, and
            // xla_gpu_unsafe_fallback_to_driver_on_ptxas_error
}
357
// These settings control how XLA compiles and/or runs code.  Not all settings
// will have an effect on every platform.
//
// When adding new fields, keep in mind that boolean fields default to false.
message ExecutionOptions {
  // This optional field's layout is used as a hint when storing the output of
  // this computation.  Subsequent transfers of this output array to the client
  // may be faster when using this layout.
  //
  // We use a Shape here to accommodate computations that return a tuple.
  ShapeProto shape_with_output_layout = 2;

  // Used to seed random-number generators used in this computation.  If this is
  // 0, we generate a seed ourselves.
  //
  // TODO(b/32083678): Changing the seed unnecessarily forces a recompilation.
  uint64 seed = 3;

  // XLA debugging options for this computation; see DebugOptions.
  DebugOptions debug_options = 4;

  // This optional field specifies a particular set of devices to run the
  // computation on. The computation will be partitioned across these devices.
  // If not provided, the default device will be chosen.
  repeated DeviceHandle device_handles = 5;

  // Number of replicas of the computation to run. If zero, uses the default
  // number of replicas for the XLA service.
  int32 num_replicas = 6;

  // This optional field specifies the device assignment if known at compile
  // time.
  DeviceAssignmentProto device_assignment = 7;

  // Alias input and output buffers for parameters that are passed-through XLA
  // modules without being changed.
  bool alias_passthrough_params = 8;

  // Number of partitions of the computation to run (model parallelism).
  // If zero, uses the default number of partitions for the XLA service.
  int32 num_partitions = 9;

  // Used to identify a set of programs that should be launched together.
  int32 launch_id = 10;

  // Indicates whether to use SPMD (true) or MPMD (false) partitioning when
  // num_partitions > 1 and XLA is requested to partition the input program.
  bool use_spmd_partitioning = 11;

  // If set, deduplicate hlo into function calls to reduce binary size. Only
  // works on TPU.
  bool deduplicate_hlo = 12;

  reserved 13;  // Was broadcast_replicated_parameters_via_collectives
}
412
// Request message asking for device handles.
// NOTE(review): presumably paired with GetDeviceHandlesResponse; the service
// definition is not in this file.
message GetDeviceHandlesRequest {
  // Presumably the number of device handles requested - TODO confirm.
  int64 device_count = 1;
}
416
// Response message carrying a list of device handles.
message GetDeviceHandlesResponse {
  // The returned device handles (DeviceHandle is declared outside this file).
  repeated DeviceHandle device_handles = 1;
}
420
// Request message naming a piece of global data to return to the client as a
// literal (see TransferToClientResponse).
message TransferToClientRequest {
  // Handle of the data to transfer.
  GlobalDataHandle data = 1;

  // This optional field directs the service to return the literal in this
  // layout. A shape is used to hold the layout to accommodate tuples.
  ShapeProto shape_with_layout = 2;
}
428
// Response message carrying the transferred data as a literal.
message TransferToClientResponse {
  // The literal returned to the client.
  LiteralProto literal = 1;
}
432
// Request message carrying a literal to transfer to the service.
message TransferToServerRequest {
  // The literal to transfer.
  LiteralProto literal = 1;
  // Presumably the device that should hold the data - TODO confirm.
  DeviceHandle device_handle = 2;
}
437
// Response message carrying a handle to global data.
message TransferToServerResponse {
  // Presumably refers to the data created by the transfer - TODO confirm.
  GlobalDataHandle data = 1;
}
441
// Request message carrying a literal destined for an infeed.
message TransferToInfeedRequest {
  // The literal to transfer.
  LiteralProto literal = 1;
  // Presumably selects the replica whose infeed receives the literal - TODO
  // confirm.
  int64 replica_id = 2;
  // NOTE(review): how this interacts with replica_id is not visible here.
  DeviceHandle device_handle = 3;
}
447
// Empty response message; it declares no fields.
message TransferToInfeedResponse {}
449
// Request message asking for a literal from an outfeed (see
// TransferFromOutfeedResponse).
message TransferFromOutfeedRequest {
  // This optional field directs the service to return the literal in this
  // layout. A shape is used to hold the layout to accommodate tuples.
  ShapeProto shape_with_layout = 1;

  // Presumably selects the replica whose outfeed is read - TODO confirm.
  int64 replica_id = 2;
  // NOTE(review): how this interacts with replica_id is not visible here.
  DeviceHandle device_handle = 3;
}
458
// Response message carrying the literal obtained from the outfeed.
message TransferFromOutfeedResponse {
  // The returned literal.
  LiteralProto literal = 1;
}
462
// Request message identifying a device to reset.
message ResetDeviceRequest {
  // Handle of the device the request applies to.
  DeviceHandle device_handle = 1;
}
466
// Empty response message; it declares no fields.
message ResetDeviceResponse {}
468
// Request message asking for statistics about a computation (see
// ComputationStatsResponse).
message ComputationGraphStatsRequest {
  // The HLO module to analyze.
  HloModuleProto computation = 1;
  // Debugging options accompanying the request; see DebugOptions.
  DebugOptions debug_options = 2;
}
473
// Response message carrying computation statistics.
message ComputationStatsResponse {
  // The statistics (ComputationStats is declared outside this file).
  ComputationStats stats = 1;
}
477
// Request message asking for a new channel handle of a given type.
message CreateChannelHandleRequest {
  // Type of the channel to create.
  ChannelHandle.ChannelType channel_type = 1;
}
481
// Response message carrying a channel handle.
message CreateChannelHandleResponse {
  // The channel handle returned to the caller.
  ChannelHandle channel = 1;
}
485
// Request message listing global data handles to unregister.
message UnregisterRequest {
  // Handles the request applies to.
  repeated GlobalDataHandle data = 1;
}
489
// Empty response message; it declares no fields.
message UnregisterResponse {}
491
// Request message asking XLA to compile a computation; the resulting
// executable is referred to by the handle in CompileResponse.
message CompileRequest {
  // The graph to be compiled.
  HloModuleProto computation = 1;

  // Options that affect how XLA compiles code to service this request.
  ExecutionOptions execution_options = 2;

  // The layouts of the input arguments. If not set, the default layout will be
  // used. Although the real arguments are not needed in compilation, the
  // layouts of the arguments can affect the compilation.
  repeated ShapeProto input_shape_with_layout = 3;
}
504
// Response message for a compilation request.
message CompileResponse {
  // The handle to the executable.
  ExecutionHandle handle = 1;
}
509
// Request message asking to run an executable identified by a handle with the
// given arguments.
message ExecuteRequest {
  // Handle of the executable to run.
  ExecutionHandle handle = 1;

  // The shape and layout of the arguments must be the same as those of the
  // executable's parameters.
  repeated GlobalDataHandle arguments = 2;
}
517
// Request message that bundles a computation with its arguments and execution
// options in a single message.
//
// TODO(b/118493728): Remove this and ExecuteGraphParallelRequest and replace
// the uses with calls to Compile and Execute.
message ExecuteGraphRequest {
  // The HLO module to run.
  HloModuleProto computation = 1;
  // Handles of the arguments passed to the computation.
  repeated GlobalDataHandle arguments = 2;

  // Options that affect how XLA compiles and runs code to service this request.
  ExecutionOptions execution_options = 3;
}
527
// Request message that groups several ExecuteGraphRequest messages.
// NOTE(review): presumably they are executed in parallel, as the name
// suggests - confirm against the service implementation.
message ExecuteGraphParallelRequest {
  // The individual requests.
  repeated ExecuteGraphRequest requests = 1;
}
531
// Response message for an execution request.
message ExecuteResponse {
  // Handle to the output of the execution.
  GlobalDataHandle output = 1;
  // Profile of the execution (ExecutionProfile is declared outside this file).
  ExecutionProfile profile = 2;
}
536
// Response message that groups several ExecuteResponse messages.
message ExecuteParallelResponse {
  // The individual responses.
  // NOTE(review): presumably one per request, in request order - confirm.
  repeated ExecuteResponse responses = 1;
}
540
// Request message identifying an execution to wait for.
message WaitForExecutionRequest {
  // Handle of the execution.
  ExecutionHandle execution = 1;
}
544
// Response message for WaitForExecutionRequest; it has the same fields as
// ExecuteResponse.
message WaitForExecutionResponse {
  // Handle to the output of the execution.
  GlobalDataHandle output = 1;
  // Profile of the execution (ExecutionProfile is declared outside this file).
  ExecutionProfile profile = 2;
}
549
// Request message asking for the value of a computation as a literal (see
// ComputeConstantResponse).
message ComputeConstantGraphRequest {
  // The HLO module to evaluate.
  HloModuleProto computation = 1;
  // Presumably the layout requested for the returned literal - TODO confirm.
  LayoutProto output_layout = 2;
}
554
// Response message carrying the computed constant.
message ComputeConstantResponse {
  // A LiteralProto is returned directly for this request.
  LiteralProto literal = 1;
}
559
// Request message naming a tuple whose elements should be returned as
// individual handles (see DeconstructTupleResponse).
//
// NOTE(review): the only field uses number 2 and number 1 is neither used nor
// reserved. If 1 belonged to a removed field it should be marked `reserved` -
// confirm against the file history.
message DeconstructTupleRequest {
  // Handle of the tuple to deconstruct.
  GlobalDataHandle tuple_handle = 2;
}
563
// Response message carrying the handles of a tuple's elements.
message DeconstructTupleResponse {
  // Handles of the tuple's elements.
  // NOTE(review): presumably in tuple order - confirm.
  repeated GlobalDataHandle element_handles = 1;
}
567
// Request message describing data to load from a ColumnIO tablet.
message LoadDataRequest {
  // Describes the path of the ColumnIO tablet to load.
  string columnio_tablet_path = 1;

  // Describes the field to load within the ColumnIO tablet.
  string columnio_field = 2;

  // Individual element shape, excluding rows.
  ShapeProto element_shape = 3;

  // Warning: ColumnIO does not support random-access, so use offset with
  // caution in performance-critical scenarios.
  int64 offset = 4;

  // Maximum number of elements (with shape element_shape) to load.
  int64 limit = 5;

  // If more than one item is requested (via limit > 1), then this request
  // attribute zips together the produced vectors.
  bool zip = 6;
}
589
// Response message for a data load request.
message LoadDataResponse {
  // Handle to the loaded data.
  GlobalDataHandle data = 1;
  // Shape of the loaded data.
  ShapeProto data_shape = 2;
  // Presumably the number of rows available in the source - TODO confirm.
  int64 available_rows = 3;
  // Presumably the number of rows actually loaded - TODO confirm.
  int64 rows_loaded = 4;
  // Presumably the time taken by the load, in nanoseconds - TODO confirm.
  int64 nanoseconds = 5;
}
597
// Request message asking for the shape of a piece of global data.
message GetShapeRequest {
  // Handle of the data whose shape is requested.
  GlobalDataHandle data = 1;
}
601
// Response message carrying a shape.
message GetShapeResponse {
  // The shape returned to the caller.
  ShapeProto shape = 1;
}
605
// Request message naming a piece of global data to unpack.
message UnpackRequest {
  // Handle of the data to unpack.
  GlobalDataHandle data = 1;
}
609
// Response message carrying the handles produced by an unpack request.
message UnpackResponse {
  // NOTE(review): "tied" presumably means these handles stay associated with
  // the data they were unpacked from - confirm against the service.
  repeated GlobalDataHandle tied_data = 1;
}
613