/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

import "tensorflow/compiler/xla/xla_data.proto";
import "tensorflow/compiler/xla/service/session.proto";

package xla;

// Options for the HLO insert-reduce-precision-operations pass.
message HloReducePrecisionOptions {
  // Where and when the reduce-precision operations will be added.
  enum Location {
    // Add reduce-precision operations to the inputs of selected instructions.
    // This is done before any optimization occurs.
    OP_INPUTS = 0;
    // Add reduce-precision operations to the outputs of selected instructions.
    // This is done before any optimization occurs.
    OP_OUTPUTS = 1;
    // After operation-fusion occurs, add reduce-precision operations to the
    // outputs of any selected instructions that have not been fused into
    // fusion instructions.
    UNFUSED_OP_OUTPUTS = 2;
    // After operation-fusion occurs, add reduce-precision operations to the
    // inputs of any fusion instructions that contain operations matching the
    // selection criteria.
    FUSION_INPUTS_BY_CONTENT = 3;
    // After operation-fusion occurs, add reduce-precision operations to the
    // outputs of any fusion instructions that contain operations matching the
    // selection criteria.
    FUSION_OUTPUTS_BY_CONTENT = 4;
  }
  Location location = 1;

  // Exponent and mantissa bit counts for the reduced precision.
  uint32 exponent_bits = 2;
  uint32 mantissa_bits = 3;

  // Operations matching these opcodes should be suffixed with reduce-precision
  // operations.
  repeated uint32 opcodes_to_suffix = 4;

  // Operations with names containing these substrings should be suffixed with
  // reduce-precision operations.
  repeated string opname_substrings_to_suffix = 5;
}
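
// A minimal configuration sketch in protobuf text format (the values below
// are illustrative assumptions, not recommended settings). This instance
// rewrites unfused outputs to bfloat16-like precision (8 exponent bits,
// 7 mantissa bits) for any op whose name contains "multiply":
//
//   location: UNFUSED_OP_OUTPUTS
//   exponent_bits: 8
//   mantissa_bits: 7
//   opname_substrings_to_suffix: "multiply"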

// Debugging options for XLA. These options may change at any time - there are
// no guarantees about backward or forward compatibility for these fields.
message DebugOptions {
  // HLO modules matching this regex will be dumped to a .dot file throughout
  // various stages in compilation (file names are LOG(INFO)'d). Set to ".*" to
  // dump *all* HLO modules.
  string xla_generate_hlo_graph = 1;

  // Show addresses of HLO ops in graph dump.
  bool xla_hlo_graph_addresses = 2;

  // Path to dump HLO graphs to.
  string xla_hlo_graph_path = 4;

  // Dump HLO graphs as TensorFlow GraphDefs.
  bool xla_hlo_dump_as_graphdef = 5;

  // HLO modules matching this regex will be dumped to LOG(INFO). Set to ".*"
  // to dump *all* HLO modules.
  string xla_log_hlo_text = 6;

  // Dump all HLO modules as text into the provided directory path.
  string xla_generate_hlo_text_to = 7;

  // Dump HLO after all HLO passes are executed as proto binary into this
  // directory.
  string xla_dump_optimized_hlo_proto_to = 8;

  // Instrument the computation to collect per-HLO cycle counts.
  bool xla_hlo_profile = 9;

  // Dumps computations that XLA executes into the provided directory path.
  string xla_dump_computations_to = 10;

  // Dumps parameters and results of computations that XLA executes into the
  // provided directory path.
  string xla_dump_executions_to = 11;

  // List of HLO passes to disable. These names must exactly match the pass
  // names as specified by the HloPassInterface::name() method.
  repeated string xla_disable_hlo_passes = 30;

  // Numerical optimization level for the XLA compiler backend; the specific
  // interpretation of this value is left to the backends.
  int32 xla_backend_optimization_level = 31;

  // When true, "unsafe" mathematical optimizations are enabled. These
  // transformations include but are not limited to:
  //
  //  - Reducing the precision of operations (e.g. using an approximate sin
  //    function, or transforming x/y into x * (1/y)).
  //  - Assuming that operations never produce or consume NaN or +/- Inf.
  //  - Assuming that +0 and -0 are indistinguishable.
  bool xla_enable_fast_math = 32;

  // Embed the compiler IR as a string in the executable.
  bool xla_embed_ir_in_executable = 33;

  // Dump the compiler IR into this directory as individual files.
  string xla_dump_ir_to = 34;

  // Eliminate implicit broadcasts when lowering user computations to HLO
  // instructions; use explicit broadcast instead.
  bool xla_eliminate_hlo_implicit_broadcast = 35;

  // When generating calls to Eigen in the CPU backend, use multi-threaded
  // Eigen mode.
  bool xla_cpu_multi_thread_eigen = 60;

  // Path to directory with cuda/ptx tools and libraries.
  string xla_gpu_cuda_data_dir = 61;

  // Enable flush-to-zero semantics in the GPU backend.
  bool xla_gpu_ftz = 62;

  // Disable multi-streaming in the GPU backend.
  bool xla_gpu_disable_multi_streaming = 63;

  // If true, in LLVM-based backends, emit !alias.scope metadata in
  // generated IR.
  bool xla_llvm_enable_alias_scope_metadata = 70;

  // If true, in LLVM-based backends, emit !noalias metadata in the
  // generated IR.
  bool xla_llvm_enable_noalias_metadata = 71;

  // If true, in LLVM-based backends, emit !invariant.load metadata in
  // the generated IR.
  bool xla_llvm_enable_invariant_load_metadata = 72;

  // If true, a set of expensive LLVM optimization passes will not be run.
  bool xla_llvm_disable_expensive_passes = 73;

  // Options for inserting reduce-precision operations for numerical
  // experimentation.  This is a repeated field, as we may want to have
  // multiple passes with different parameters.
  repeated HloReducePrecisionOptions hlo_reduce_precision_options = 80;

  // This is used by ClientLibraryTestBase::ComputeAndCompare*. If true, the
  // computation will run n! times with all permutations of layouts for the
  // output shape of rank n. For example, with a 3D shape, all permutations of
  // the set {0, 1, 2} are tried.
  bool xla_test_all_output_layouts = 90;

  // This is used by ClientLibraryTestBase::ComputeAndCompare*. If true, the
  // computation will run for all permutations of layouts of all input
  // arguments. For example, with 2 input arguments in 2D and 4D shapes, the
  // computation will run 2! * 4! times.
  bool xla_test_all_input_layouts = 91;

  // Assign colors based on sharding information when generating the Graphviz
  // HLO graph.
  bool xla_hlo_graph_sharding_color = 92;

  // Prefix the name scopes of the TF graph exports with "devX" device
  // assignments, if available.
  bool xla_hlo_tfgraph_device_scopes = 93;

  // If true, the GPU backend is free to use cudnn for HLO batch normalization
  // ops.
  bool xla_gpu_use_cudnn_batchnorm = 94;

  // Dump HLO before any HLO passes are executed as proto binary into this
  // directory.
  string xla_dump_unoptimized_hlo_proto_to = 95;

  // Dump HLO after each pass as an HloProto in binary file format into this
  // directory.
  string xla_dump_per_pass_hlo_proto_to = 96;

  // Extra options to pass to the compilation backend; specific interpretation
  // of these values is left to the backend.
  map<string, string> xla_backend_extra_options = 500;
}
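
// An illustrative DebugOptions instance in protobuf text format. The regex,
// path, pass name, and backend option shown are hypothetical placeholders:
//
//   xla_generate_hlo_graph: ".*"
//   xla_hlo_graph_path: "/tmp/xla_graphs"
//   xla_hlo_profile: true
//   xla_disable_hlo_passes: "some-pass-name"
//   xla_backend_extra_options { key: "some-option" value: "on" }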

// These settings control how XLA compiles and/or runs code.  Not all settings
// will have an effect on every platform.
//
// When adding new fields, keep in mind that boolean fields default to false.
message ExecutionOptions {
  // This optional field's layout is used as a hint when storing the output of
  // this computation.  Subsequent transfers of this output array to the client
  // may be faster when using this layout.
  //
  // We use a Shape here to accommodate computations that return a tuple.
  Shape shape_with_output_layout = 2;

  // Used to seed random-number generators used in this computation.  If this
  // is 0, we generate a seed ourselves.
  //
  // TODO(b/32083678): Changing the seed unnecessarily forces a recompilation.
  uint64 seed = 3;

  DebugOptions debug_options = 4;

  // This optional field specifies a particular set of devices to run the
  // computation on. The computation will be partitioned across these devices.
  // If not provided, the default device will be chosen.
  repeated DeviceHandle device_handles = 5;
}
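
// A sketch of an ExecutionOptions message in protobuf text format. The Shape
// and Layout fields used here (element_type, dimensions, minor_to_major) come
// from xla_data.proto; the concrete values are illustrative assumptions:
//
//   shape_with_output_layout {
//     element_type: F32
//     dimensions: 2
//     dimensions: 3
//     layout { minor_to_major: 1 minor_to_major: 0 }
//   }
//   seed: 42
//   debug_options { xla_hlo_profile: true }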

message SnapshotComputationRequest {
  ComputationHandle computation = 1;
}

message SnapshotComputationResponse {
  SessionModule module = 1;
}

message LoadComputationSnapshotRequest {
  SessionModule module = 1;
}

message LoadComputationSnapshotResponse {
  ComputationHandle computation = 1;
}

message GetDeviceHandlesRequest {
  int64 device_count = 1;
}

message GetDeviceHandlesResponse {
  repeated DeviceHandle device_handles = 1;
}

message TransferToClientRequest {
  GlobalDataHandle data = 1;

  // This optional field directs the service to return the literal in this
  // layout. A shape is used to hold the layout to accommodate tuples.
  Shape shape_with_layout = 2;
}
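
// A TransferToClientRequest sketch in protobuf text format, showing how a
// tuple Shape carries a layout per element (the handle and shapes are
// illustrative assumptions; Shape is defined in xla_data.proto):
//
//   data { handle: 1 }
//   shape_with_layout {
//     element_type: TUPLE
//     tuple_shapes {
//       element_type: F32
//       dimensions: 4
//       layout { minor_to_major: 0 }
//     }
//     tuple_shapes {
//       element_type: S32
//       dimensions: 2
//       dimensions: 2
//       layout { minor_to_major: 1 minor_to_major: 0 }
//     }
//   }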

message TransferToClientResponse {
  LiteralProto literal = 1;
}

message TransferToServerRequest {
  LiteralProto literal = 1;
  DeviceHandle device_handle = 2;
}

message TransferToServerResponse {
  GlobalDataHandle data = 1;
}

message TransferToInfeedRequest {
  LiteralProto literal = 1;
  int64 replica_id = 2;
  DeviceHandle device_handle = 3;
}

message TransferToInfeedResponse {
}

message TransferFromOutfeedRequest {
  // This optional field directs the service to return the literal in this
  // layout. A shape is used to hold the layout to accommodate tuples.
  Shape shape_with_layout = 1;

  int64 replica_id = 2;
  DeviceHandle device_handle = 3;
}

message TransferFromOutfeedResponse {
  LiteralProto literal = 1;
}

message ResetDeviceRequest {
  DeviceHandle device_handle = 1;
}

message ResetDeviceResponse {
}

message ComputationStatsRequest {
  ComputationHandle computation = 1;
  DebugOptions debug_options = 2;
}

message ComputationStatsResponse {
  ComputationStats stats = 1;
}

message ComputationRequest {
  string name = 1;
}

message ComputationResponse {
  ComputationHandle computation = 1;
}

message CreateChannelHandleRequest {
}

message CreateChannelHandleResponse {
  ChannelHandle channel = 1;
}

message UnregisterRequest {
  GlobalDataHandle data = 1;
}

message UnregisterResponse {
}

message SetReturnValueRequest {
  ComputationHandle computation = 1;
  ComputationDataHandle operand = 2;
}

message SetReturnValueResponse {
}

message ExecuteRequest {
  reserved 3, 4;

  ComputationHandle computation = 1;
  repeated GlobalDataHandle arguments = 2;

  // Options that affect how XLA compiles and runs code to service this
  // request.
  ExecutionOptions execution_options = 5;
}
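
// A minimal ExecuteRequest sketch in protobuf text format. Handle values are
// illustrative assumptions; real handles are issued by the service when the
// computation and its argument data are registered:
//
//   computation { handle: 10 }
//   arguments { handle: 1 }
//   arguments { handle: 2 }
//   execution_options { seed: 42 }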

message ExecuteParallelRequest {
  repeated ExecuteRequest requests = 1;
}

message ExecuteResponse {
  GlobalDataHandle output = 1;
  ExecutionProfile profile = 2;
}

message ExecuteParallelResponse {
  repeated ExecuteResponse responses = 1;
}

message ExecuteAsyncRequest {
  reserved 3, 4;

  ComputationHandle computation = 1;
  repeated GlobalDataHandle arguments = 2;

  // Options that affect how XLA compiles and runs code to service this
  // request.
  ExecutionOptions execution_options = 6;
}

message ExecuteAsyncResponse {
  // A handle to the execution launched asynchronously.
  ExecutionHandle execution = 1;
}

message WaitForExecutionRequest {
  ExecutionHandle execution = 1;
}

message WaitForExecutionResponse {
  GlobalDataHandle output = 1;
  ExecutionProfile profile = 2;
}
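
// The asynchronous flow pairs the two request/response pairs above: an
// ExecuteAsyncRequest returns an ExecutionHandle without blocking, and that
// handle is then sent in a WaitForExecutionRequest to wait for the output.
// A sketch in protobuf text format (handle values are illustrative
// assumptions):
//
//   // 1. ExecuteAsyncRequest:
//   computation { handle: 10 }
//   arguments { handle: 1 }
//   // ...ExecuteAsyncResponse: execution { handle: 7 }
//
//   // 2. WaitForExecutionRequest:
//   execution { handle: 7 }
//   // ...WaitForExecutionResponse: output { handle: 3 }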

message IsConstantRequest {
  ComputationHandle computation = 1;
  ComputationDataHandle operand = 2;
  int64 num_parameters = 3;
}

message IsConstantResponse {
  bool is_constant = 1;
}

message ComputeConstantRequest {
  ComputationHandle computation = 1;
  ComputationDataHandle operand = 2;
  Layout output_layout = 3;
  repeated LiteralProto parameters = 4;
}
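
// A ComputeConstantRequest sketch in protobuf text format. The parameter
// literal assumes the scalar F32 encoding of LiteralProto from
// xla_data.proto; handle values and the layout are illustrative assumptions:
//
//   computation { handle: 10 }
//   operand { handle: 4 }
//   output_layout { minor_to_major: 1 minor_to_major: 0 }
//   parameters {
//     shape { element_type: F32 }
//     f32s: 1.5
//   }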

message ComputeConstantResponse {
  // A LiteralProto is returned directly for this request, instead of a
  // ComputationDataHandle.
  LiteralProto literal = 1;
}

message DeconstructTupleRequest {
  GlobalDataHandle tuple_handle = 2;
}

message DeconstructTupleResponse {
  repeated GlobalDataHandle element_handles = 1;
}

message LoadDataRequest {
  // Describes the path of the ColumnIO tablet to load.
  string columnio_tablet_path = 1;

  // Describes the field to load within the ColumnIO tablet.
  string columnio_field = 2;

  // Individual element shape, excluding rows.
  Shape element_shape = 3;

  // Warning: ColumnIO does not support random access, so use offset with
  // caution in performance-critical scenarios.
  int64 offset = 4;

  // Maximum number of elements (with shape element_shape) to load.
  int64 limit = 5;

  // If more than one item is requested (via limit > 1), then this request
  // attribute zips together the produced vectors.
  bool zip = 6;
}
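
// A LoadDataRequest sketch in protobuf text format (the tablet path and field
// name are hypothetical placeholders):
//
//   columnio_tablet_path: "/path/to/tablet"
//   columnio_field: "some_field"
//   element_shape { element_type: F32 dimensions: 128 }
//   offset: 0
//   limit: 1000
//   zip: true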

message LoadDataResponse {
  GlobalDataHandle data = 1;
  Shape data_shape = 2;
  int64 available_rows = 3;
  int64 rows_loaded = 4;
  int64 nanoseconds = 5;
}

message SpecializeRequest {
  ComputationHandle computation = 1;
  repeated GlobalDataHandle arguments = 2;
}

message SpecializeResponse {
}

message GetShapeRequest {
  GlobalDataHandle data = 1;
}

message GetShapeResponse {
  Shape shape = 1;
}

message GetComputationShapeRequest {
  ComputationHandle computation = 1;
}

message GetComputationShapeResponse {
  ProgramShape program_shape = 1;
}

message UnpackRequest {
  GlobalDataHandle data = 1;
}

message UnpackResponse {
  repeated GlobalDataHandle tied_data = 1;
}