/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_

#include <stack>

#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/tf2xla/xla_argument.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

class XlaContext;

// The XlaCompiler class is responsible for compiling a self-contained
// subgraph of a TensorFlow computation using the XLA linear algebra runtime.
// It performs a symbolic execution of the graph starting from specific input
// shapes, using a JIT device to convert operators into XLA computations.
//
// XlaCompiler is typically invoked from an `XlaLaunch` operator once the
// shapes of all input parameters to the computation are known. This is
// because the symbolic execution requires known shapes for all operations.
//
// XlaCompiler compiles TensorFlow graphs that receive inputs via _Arg nodes
// and return outputs via _Retval nodes.
//
// The XlaCompiler requires one Argument struct for each _Arg index; the
// struct describes that argument. Arguments can be compile-time constants
// (kind kConstant), run-time parameters (kind kParameter), or resources
// (kind kResource).
//
// Only kParameter and initialized kResource arguments become runtime
// parameters to the generated XLA computation.
//
// The run-time outputs of the XLA computation are arranged in the following
// order:
//   +------------------+-----------------------------------------+
//   |  _Retval values  |  Updated values of kResource arguments  |
//   +------------------+-----------------------------------------+
// _Retval values are ordered by _Retval index, whereas kResource values are
// ordered by the original _Arg position of the variable.
//
// If a shape representation function is provided as part of
// XlaCompiler::Options, kParameter arguments and return values of an
// entry computation are reshaped in accordance with the shape function.
// Arguments and return values of a non-entry computation are not reshaped.
// Variable resource arguments are passed and returned in reshaped form, even
// for non-entry computations. This feature allows TensorFlow to keep
// on-device tensors with a different shape from their representation inside
// the XLA computation.
//
// In computation outputs, updated kResource values are placed at the end.
// When emitting While loop bodies, we must ensure that the loop body has
// identical input and output signatures. By passing variable values
// at the end of the argument list and using the
// `return_updated_values_for_all_resources` option, we can ensure that the
// input and output values of resources appear at the same positions.
//
// Resources are passed as parameters or returned as resource updates in
// "packed" form.
// kStack resources are packed as (array, size of stack) XLA tuples.
// kTensorArray resources without gradients are packed as the array that
// backs the TensorArray. If gradients are present (`tensor_array_gradients`),
// the packed representation is an (array, gradient0, gradient1, ...) tuple,
// where gradient_k is the value of the k-th gradient in the
// `tensor_array_gradients` ordered set.
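//
// Example (a minimal usage sketch, not taken from the TensorFlow codebase;
// the XlaArgument field names `kind`, `type`, and `shape`, the
// CompilationResult field `computation`, and the DEVICE_CPU_XLA_JIT constant
// are assumptions, and `client`, `flib_def`, and `graph` are hypothetical
// objects obtained elsewhere):
//
//   XlaCompiler::Options options;
//   options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
//   options.client = client;        // an xla::Client
//   options.flib_def = &flib_def;   // function definitions used by `graph`
//   XlaCompiler compiler(options);
//
//   std::vector<XlaCompiler::Argument> args(1);
//   args[0].kind = XlaCompiler::Argument::kParameter;
//   args[0].type = DT_FLOAT;
//   args[0].shape = TensorShape({2, 3});
//
//   XlaCompiler::CompilationResult result;
//   TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(),
//                                            "my_graph", std::move(graph),
//                                            args, &result));
//   // result.computation now holds the compiled xla::XlaComputation.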
class XlaCompiler {
 public:
  using Argument = ::tensorflow::XlaArgument;

  // Options pertaining to an individual call to CompileGraph() or
  // CompileFunction().
  struct CompileOptions {
    // If `use_tuple_arg` is true, a single tuple parameter will be used for
    // all arguments; if false, each argument gets its own parameter.
    bool use_tuple_arg = false;

    // If 'return_updated_values_for_all_resources' is true, then updated
    // values of all resource arguments will be included in the
    // 'resource_updates' of the computation, even if the resource was not
    // modified by the computation. Used when compiling loop bodies to ensure
    // the input and output signatures match.
    bool return_updated_values_for_all_resources = false;

    // If 'always_return_tuple' is true, then the output of a computation will
    // always be a tuple. Otherwise, a single-element output will not be
    // wrapped in a tuple.
    bool always_return_tuple = true;

    // True when compiling the entry computation, false for subcomputations
    // (while, call, etc.).
    bool is_entry_computation = true;

    // True when we should add an XLA token input and output to the
    // graph/function.
    bool add_token_input_output = false;

    // Resource updates are converted into XLA inputs/outputs. The two
    // buffers are aliased with each other if this option is true.
    bool alias_resource_update = false;
  };
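
  // Example (an illustrative sketch, not prescribed by this header): options
  // typically used when compiling a While loop body, where the input and
  // output signatures must match (see the class comment above).
  //
  //   XlaCompiler::CompileOptions loop_body_options;
  //   loop_body_options.use_tuple_arg = true;
  //   loop_body_options.return_updated_values_for_all_resources = true;
  //   loop_body_options.is_entry_computation = false;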

  using OutputDescription = ::tensorflow::XlaOutputDescription;

  using ResourceUpdate = ::tensorflow::XlaResourceUpdate;

  using CompilationResult = ::tensorflow::XlaCompilationResult;

  typedef std::function<StatusOr<xla::Shape>(const TensorShape&, DataType,
                                             bool)>
      ShapeRepresentationFn;
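
  // Example (an illustrative sketch; `TensorShapeToXLAShape` is assumed to be
  // available from tensorflow/compiler/tf2xla/shape_util.h, and the trailing
  // bool parameter is ignored here):
  //
  //   XlaCompiler::ShapeRepresentationFn default_representation =
  //       [](const TensorShape& shape, DataType dtype,
  //          bool) -> StatusOr<xla::Shape> {
  //     xla::Shape xla_shape;
  //     TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
  //     return xla_shape;
  //   };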
  struct Options {
    // Name of the compilation device to use. It must be set by the caller.
    // The default empty value is invalid.
    DeviceType device_type = DeviceType("");

    // The device to use during compilation to execute instructions on, for
    // example for auto-tuning.
    // Valid values are defined by `xla::Backend::devices_ordinal_supported()`.
    // -1 indicates the default device should be used.
    int device_ordinal = -1;

    xla::Client* client = nullptr;

    // Function library in which to find function definitions. Must be
    // non-null.
    const FunctionLibraryDefinition* flib_def = nullptr;

    // The graph def version to be compiled.
    int graph_def_version = TF_GRAPH_DEF_VERSION;

    // If 'allow_cpu_custom_calls' is true, kernels may make use of
    // CustomCall() for CPU.
    bool allow_cpu_custom_calls = false;

    // If both this and 'allow_cpu_custom_calls' are true then tf.fake_quant_*
    // ops will be emitted as custom calls to a 'fake_quant_with_min_max_vars'
    // function accepting the input, min, max, num_bits, and narrow_range
    // values as runtime arguments.
    bool custom_fake_quant_op_calls = false;

    // If set, variables are represented to XLA as the shape given by this
    // shape function. Variables are reshaped to this shape on write, and
    // reshaped to their original shape on read.
    ShapeRepresentationFn shape_representation_fn;

    // If not nullptr, populate_resource_manager is called with the
    // compilation device's resource manager when the compilation
    // device is created, and can be used to create metadata objects
    // that can be accessed by XLA op kernels.
    std::function<Status(ResourceMgr*)>* populate_resource_manager = nullptr;

    // If not nullptr, this memory allocator can be used by the compiler for
    // temporary allocations it might want to make during compilation.
    //
    // For example, the compiler may want to try out different algorithms and
    // choose the fastest one, and it might run those algorithms over buffers
    // created using this allocator.
    //
    // The compiler can function correctly without an explicit allocator given
    // here, but on some devices (notably, GPUs), TensorFlow tends to eagerly
    // allocate most or all available memory on the device, leaving none for
    // the compiler to access, unless it can use TensorFlow's allocator.
    // This must be a shared_ptr, as it is passed all the way down to the
    // cluster compilation. This allows asynchronous compilation to hold a
    // reference until the compilation is finished.
    std::shared_ptr<se::DeviceMemoryAllocator> device_allocator;

    // Alias input and output buffers for parameters that are passed through
    // XLA modules without being changed.
    bool alias_passthrough_params = false;

    // Enable detailed logging of compilation metadata.
    bool detailed_logging = true;
  };

  explicit XlaCompiler(Options options);

  ~XlaCompiler();

  // Helper function to populate an XlaCompiler::Argument from an XlaResource.
  static void PopulateArgumentFromResource(const XlaResource& resource,
                                           Argument* arg);

  // Compiles the TensorFlow function identified by `fn_name_attrs` into an
  // XLA computation, described by `result`. `args` describes the function's
  // arguments.
  Status CompileFunction(const CompileOptions& options,
                         const NameAttrList& fn_name_attrs,
                         absl::Span<const Argument> args,
                         CompilationResult* result);

  // Compiles a tensorflow::Graph into an xla::XlaComputation.
  // Similar to CompileFunction, but takes a Graph as input rather than a
  // function.
  Status CompileGraph(
      const CompileOptions& options, string const& name,
      std::unique_ptr<Graph> graph, absl::Span<const Argument> args,
      CompilationResult* result);

  // Returns the shape of the XLA parameter for an argument 'arg'.
  // See the class comment for more details about the argument passing
  // convention.
  Status XLAShapeForArgument(
      const Argument& arg, bool is_entry_computation,
      const absl::optional<xla::HloSharding>& arg_sharding,
      xla::Shape* xla_shape) const;

  // Retrieves the channel handle associated with `key`. Allocates
  // a new channel handle if none exists.
  // Channel handles can be used to communicate between different
  // computations. Computations that communicate should be compiled with the
  // same XlaCompiler.
  Status GetChannelHandle(const string& key, xla::ChannelHandle* channel);

  // Retrieves the host-to-device channel handle associated with `key`.
  // Allocates a new channel handle if none exists.
  Status GetHostToDeviceChannelHandle(const string& key,
                                      xla::ChannelHandle* channel);

  // Retrieves the device-to-host channel handle associated with `key`.
  // Allocates a new channel handle if none exists.
  Status GetDeviceToHostChannelHandle(const string& key,
                                      xla::ChannelHandle* channel);

  // Sets the shapes and types for the device-to-host transfer associated
  // with 'key'.
  Status SetDeviceToHostMetadata(const string& key,
                                 absl::Span<const DataType> types,
                                 absl::Span<const TensorShape> shapes);

  // Gets the shapes of the device-to-host transfer associated with 'key'.
  Status GetDeviceToHostShapes(const string& key,
                               std::vector<TensorShape>* shapes) const;

  // Sets the shapes and types for the host-to-device transfer associated
  // with 'key'.
  Status SetHostToDeviceMetadata(const string& key,
                                 absl::Span<const DataType> types,
                                 absl::Span<const TensorShape> shapes);

  // In order to avoid deadlocks from dependencies in host computations, it can
  // be necessary to enforce a partial order on the execution of HostCompute
  // Ops. In particular it may be necessary to constrain the SendToHost for one
  // HostCompute to run before blocking on the RecvAtHost for another
  // HostCompute. The compiler maintains a mapping from 'host_compute_name' to
  // handle, where the handle is an 'output' of the HostCompute Op
  // corresponding to 'host_compute_name'. Another HostCompute Op that needs to
  // be sequenced later can add the handle as an 'input' to enforce the
  // constraints.
  // 'host_compute_name' can be any string the client wishes to use to identify
  // a given HostCompute Op as long as the names are unique within the
  // compilation.
  Status GetHostComputeControlDependency(const string& host_compute_name,
                                         xla::XlaOp* handle);
  Status SetHostComputeControlDependency(const string& host_compute_name,
                                         const xla::XlaOp& handle);
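
  // Example (an illustrative sketch): sequencing one HostCompute Op after
  // another. The name "host_compute_0" and the `first_op_output` variable are
  // hypothetical.
  //
  //   // When emitting the first HostCompute Op, record one of its outputs:
  //   TF_RETURN_IF_ERROR(compiler->SetHostComputeControlDependency(
  //       "host_compute_0", first_op_output));
  //   // When emitting a HostCompute Op that must run later, fetch the handle
  //   // and add it as an extra input to that Op:
  //   xla::XlaOp ordering_handle;
  //   TF_RETURN_IF_ERROR(compiler->GetHostComputeControlDependency(
  //       "host_compute_0", &ordering_handle));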

  const Options& options() const { return options_; }
  xla::Client* client() const { return options_.client; }
  FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }

  void PushNodeTokenMapping();
  Status PopNodeTokenMapping();
  Status SetNodeToken(const string& node_name, const xla::XlaOp& op);
  StatusOr<xla::XlaOp> GetNodeToken(const string& node_name);

  // Sets the function body `fbody` to the one registered as `function`.
  Status FindFunctionBody(const NameAttrList& function,
                          const FunctionBody** fbody,
                          const ConfigProto** config_proto = nullptr);

 private:
  // Returns the optimized graph object in this function body.
  std::unique_ptr<Graph> GetGraph(const FunctionBody* fbody);

  // Builds XLA computations for each of the arguments to the computation.
  // `args` are the arguments to the computation.
  Status BuildArguments(const Graph& graph,
                        const std::vector<XlaCompiler::Argument>& args,
                        bool use_tuple_arg, xla::XlaBuilder* builder,
                        XlaContext* context,
                        const std::map<int, xla::OpSharding>& arg_shardings,
                        std::vector<XlaExpression>* arg_expressions,
                        std::vector<int>* input_to_args,
                        std::vector<xla::Shape>* input_shapes,
                        bool is_entry_computation);

  // Graph compiler needs to know how to get an optimized graph from a
  // function body.
  friend class GraphCompiler;
  friend class XlaCompilerTest;

  Options options_;

  // Status set to non-OK in the constructor if initialization fails.
  Status initialization_status_;

  // Returns the next step sequence number.
  int64 NextStepId();

  // Internal sequence number for steps executed on the compilation device.
  int64 next_step_id_;

  XlaCompilationDevice* device_;  // Owned by device_mgr_
  StaticDeviceMgr device_mgr_;

  // To avoid copying the client's function library, use a local function
  // library and runtime for functions created as part of the functionalize
  // control flow transformation.
  std::unique_ptr<FunctionLibraryDefinition> local_flib_def_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> local_pflr_;

  FunctionLibraryRuntime* local_flib_runtime_;  // owned by local_pflr_.
  FunctionLibraryRuntime* flib_runtime_;        // owned by pflr_.

  struct SignatureHash {
    uint64 operator()(
        const std::pair<string, std::vector<Argument>>& signature) const;
  };

  std::unordered_map<std::pair<string, std::vector<Argument>>,
                     CompilationResult, SignatureHash>
      cache_;

  std::unordered_map<string, xla::ChannelHandle> channels_;

  std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_sends_;
  std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_recvs_;

  std::unordered_map<string, xla::XlaOp> host_compute_control_output_;

  // This is used to store a <node name, token output> mapping. Side-effecting
  // ops call SetNodeToken() to record their token outputs, so later
  // side-effecting ops can use GetNodeToken() to retrieve them and use them
  // as token inputs.
  //
  // It's a stack because we need a mapping like this for each level of nested
  // CompileGraph() call. In CompileGraph(), we will push a new mapping onto
  // the stack, and pop the mapping before returning.
  std::stack<std::map<string, xla::XlaOp>> node_token_mapping_stack_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_